cvx_analytics/
wasserstein.rs

//! Wasserstein (optimal transport) distance between distributions.
//!
//! Measures the "cost" of transforming one probability distribution into another,
//! respecting the geometry of the underlying space. Unlike L2 distance between
//! histograms, Wasserstein accounts for which bins are *close* vs *far*.
//!
//! # Why Wasserstein for CVX
//!
//! Region distributions at two timestamps are histograms over K regions.
//! L2 treats all regions as equally distant. Wasserstein uses the actual
//! distances between region centroids — shifting mass between neighboring
//! regions costs less than between distant ones.
//!
//! # Implementations
//!
//! - [`sliced_wasserstein`]: Fast approximation via random 1D projections.
//!   O(n_proj × K × (D + log K)) for K bins with D-dimensional centroids.
//! - [`wasserstein_1d`]: Exact W₁ between histograms on K unit-spaced 1D bins. O(K).
//! - `emd_1d`: Earth Mover's Distance on sorted 1D values — listed here, but not
//!   implemented in this module.
//!
//! # References
//!
//! - Villani, C. (2008). *Optimal Transport: Old and New*. Springer.
//! - Bonneel, N. et al. (2015). Sliced and Radon Wasserstein barycenters. *JMIV*.

/// Exact Wasserstein-1 (Earth Mover's Distance) between two histograms defined
/// over the same K unit-spaced bins on a line.
///
/// Bin `i` is treated as sitting at position `i`, so moving one unit of mass
/// between adjacent bins costs 1. In 1D, W₁ equals the integral of the absolute
/// CDF difference; on this grid that is `Σ_i |CDF_a(i) − CDF_b(i)|`.
///
/// Both distributions should be non-negative and sum to the same total
/// (typically 1.0 for probability distributions). This is not checked: if the
/// totals differ, the final cumulative mismatch is added to the result.
///
/// Returns 0.0 for empty inputs.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
///
/// # Complexity
///
/// O(K) — a single cumulative sweep; no sorting is performed.
pub fn wasserstein_1d(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "distributions must have equal length");

    // Keep the two CDFs as separate running sums (rather than one running
    // difference) so the rounding behaviour matches summing each CDF directly.
    let mut cum_a = 0.0;
    let mut cum_b = 0.0;
    let mut distance = 0.0;
    for (&wa, &wb) in a.iter().zip(b) {
        cum_a += wa;
        cum_b += wb;
        distance += (cum_a - cum_b).abs();
    }
    distance
}

/// Sliced Wasserstein-1 distance between two weightings of the same K support points.
///
/// Bin `i` places weight `a[i]` (resp. `b[i]`) at the point `centroids[i]`. For each
/// random unit direction, every centroid is projected onto that direction. The exact
/// 1D W₁ between the two projected weightings is then computed from the *projected
/// positions*, so the distance scales with the geometry of the centroids and not just
/// their ordering. The result is the average over all usable directions, which is a
/// Monte Carlo estimate of the Sliced Wasserstein distance.
///
/// Directions are sampled uniformly on the unit sphere: each component is a standard
/// normal draw (Box–Muller over a xorshift64 stream), and the vector is then normalised.
///
/// # Arguments
///
/// * `a` - First distribution: weights over K bins (must sum to ~1.0).
/// * `b` - Second distribution: same K bins, same total mass as `a`.
/// * `centroids` - K centroid vectors (one per bin), all of the same dimensionality D.
/// * `n_projections` - Number of random 1D projections (more = more accurate).
/// * `seed` - Random seed for reproducibility. A seed of 0 is remapped to a fixed
///   non-zero state, because xorshift64 never leaves the all-zero state.
///
/// # Returns
///
/// The averaged distance. Returns 0.0 when K == 0 or when D == 0 (all centroids coincide).
///
/// # Panics
///
/// Panics in any of these cases:
/// * `a`, `b` and `centroids` differ in length;
/// * the centroids do not all share the same dimensionality;
/// * `n_projections` is 0 while K > 0.
///
/// # Complexity
///
/// O(n_proj × K × (D + log K)) where D = centroid dimensionality.
pub fn sliced_wasserstein(
    a: &[f64],
    b: &[f64],
    centroids: &[&[f32]],
    n_projections: usize,
    seed: u64,
) -> f64 {
    let k = a.len();
    assert_eq!(k, b.len(), "distributions must have equal length");
    assert_eq!(k, centroids.len(), "must have one centroid per bin");
    if k == 0 {
        return 0.0;
    }
    assert!(n_projections > 0, "n_projections must be positive");

    let dim = centroids[0].len();
    assert!(
        centroids.iter().all(|c| c.len() == dim),
        "all centroids must have the same dimensionality"
    );

    // xorshift64 maps 0 to 0, so a zero seed would repeat one direction forever.
    let mut rng_state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };

    // Buffers reused across projections to avoid per-iteration allocation.
    let mut direction = vec![0.0f64; dim];
    let mut projected: Vec<(f64, f64, f64)> = Vec::with_capacity(k);
    let mut total = 0.0;
    let mut used = 0usize;

    for _ in 0..n_projections {
        // Gaussian components + normalisation give a uniform direction on the sphere.
        for d in direction.iter_mut() {
            *d = next_standard_normal(&mut rng_state);
        }
        let norm = direction.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            // Direction undefined (always the case when D == 0); don't count it.
            continue;
        }
        for d in direction.iter_mut() {
            *d /= norm;
        }

        // (projected position, weight in a, weight in b) for every bin.
        projected.clear();
        projected.extend(centroids.iter().zip(a).zip(b).map(|((c, &wa), &wb)| {
            let pos: f64 = c
                .iter()
                .zip(&direction)
                .map(|(&cv, &dv)| f64::from(cv) * dv)
                .sum();
            (pos, wa, wb)
        }));
        // total_cmp never panics (unlike partial_cmp().unwrap() on NaN); ties
        // contribute a zero gap, so an unstable sort is fine.
        projected.sort_unstable_by(|x, y| x.0.total_cmp(&y.0));

        total += w1_on_sorted_line(&projected);
        used += 1;
    }

    if used == 0 {
        // Every direction was degenerate: in practice D == 0, where all
        // centroids coincide and the distance is 0.
        return 0.0;
    }
    total / used as f64
}

/// Advances a xorshift64 state and returns a uniform sample in (0, 1].
///
/// The state must be non-zero. xorshift64 never maps a non-zero state to zero,
/// so the returned sample is strictly positive.
fn next_unit_open_closed(state: &mut u64) -> f64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    *state as f64 / u64::MAX as f64
}

/// Draws one standard normal sample using the Box–Muller transform.
fn next_standard_normal(state: &mut u64) -> f64 {
    // u1 lies in (0, 1], so ln(u1) is finite.
    let u1 = next_unit_open_closed(state);
    let u2 = next_unit_open_closed(state);
    (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
}

/// Exact W₁ between two weightings of points on a line.
///
/// Input is a slice of `(position, weight_a, weight_b)` triples sorted by position.
/// Computes ∫ |F_a(x) − F_b(x)| dx as Σ |cumulative(a − b)| × (gap to the next point).
/// If the totals differ, mass beyond the last point is not charged.
fn w1_on_sorted_line(points: &[(f64, f64, f64)]) -> f64 {
    let mut cdf_diff = 0.0;
    let mut cost = 0.0;
    for pair in points.windows(2) {
        let (x, wa, wb) = pair[0];
        cdf_diff += wa - wb;
        cost += cdf_diff.abs() * (pair[1].0 - x);
    }
    cost
}

136/// Compute Wasserstein drift between two region distributions.
137///
138/// Convenience function that wraps [`sliced_wasserstein`] for the common
139/// use case of comparing region distributions at two time points.
140///
141/// # Arguments
142///
143/// * `dist_t1` - Region distribution at time T₁ (K floats summing to ~1.0).
144/// * `dist_t2` - Region distribution at time T₂.
145/// * `centroids` - Region centroid vectors from `index.regions(level)`.
146/// * `n_projections` - Number of projections (default: 50).
147pub fn wasserstein_drift(
148    dist_t1: &[f32],
149    dist_t2: &[f32],
150    centroids: &[&[f32]],
151    n_projections: usize,
152) -> f64 {
153    let a: Vec<f64> = dist_t1.iter().map(|&v| v as f64).collect();
154    let b: Vec<f64> = dist_t2.iter().map(|&v| v as f64).collect();
155    sliced_wasserstein(&a, &b, centroids, n_projections, 42)
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons that should hold up to rounding.
    const EPS: f64 = 1e-10;

    /// Three 2D centroids (right, up, left) shared by several sliced tests.
    fn cross_centroids() -> Vec<&'static [f32]> {
        vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 1.0], &[-1.0f32, 0.0]]
    }

    #[test]
    fn w1d_identical_distributions() {
        let uniform = [0.25; 4];
        assert!(wasserstein_1d(&uniform, &uniform).abs() < EPS);
    }

    #[test]
    fn w1d_shifted_mass() {
        // All mass in the first bin vs all in the last: it crosses 3 unit gaps.
        let d = wasserstein_1d(&[1.0, 0.0, 0.0, 0.0], &[0.0, 0.0, 0.0, 1.0]);
        assert!((d - 3.0).abs() < EPS, "expected 3.0, got {d}");
    }

    #[test]
    fn w1d_adjacent_shift() {
        // Mass moves exactly one bin to the right.
        let d = wasserstein_1d(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!((d - 1.0).abs() < EPS, "expected 1.0, got {d}");
    }

    #[test]
    fn w1d_symmetric() {
        let (p, q) = ([0.5, 0.3, 0.2], [0.1, 0.4, 0.5]);
        let forward = wasserstein_1d(&p, &q);
        let backward = wasserstein_1d(&q, &p);
        assert!((forward - backward).abs() < EPS);
    }

    #[test]
    fn w1d_non_negative() {
        assert!(wasserstein_1d(&[0.7, 0.2, 0.1], &[0.1, 0.3, 0.6]) >= 0.0);
    }

    #[test]
    fn w1d_triangle_inequality() {
        let a = [0.5, 0.3, 0.2];
        let b = [0.1, 0.4, 0.5];
        let c = [0.3, 0.3, 0.4];
        let d_ab = wasserstein_1d(&a, &b);
        let d_bc = wasserstein_1d(&b, &c);
        let d_ac = wasserstein_1d(&a, &c);
        let detour = d_ab + d_bc;
        assert!(
            d_ac <= detour + EPS,
            "triangle inequality: d(a,c)={d_ac} > d(a,b)+d(b,c)={}",
            detour
        );
    }

    #[test]
    fn sliced_identical_zero() {
        let weights = [0.5, 0.3, 0.2];
        let d = sliced_wasserstein(&weights, &weights, &cross_centroids(), 100, 42);
        assert!(
            d < 1e-10,
            "identical distributions should have distance ~0, got {d}"
        );
    }

    #[test]
    fn sliced_different_positive() {
        let collinear: Vec<&[f32]> =
            vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 0.0], &[-1.0f32, 0.0]];
        let d = sliced_wasserstein(&[1.0, 0.0, 0.0], &[0.0, 0.0, 1.0], &collinear, 100, 42);
        assert!(
            d > 0.1,
            "different distributions should have positive distance, got {d}"
        );
    }

    #[test]
    fn sliced_symmetric() {
        let skewed: Vec<&[f32]> =
            vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 1.0], &[-1.0f32, -1.0]];
        let a = [0.6, 0.3, 0.1];
        let b = [0.1, 0.2, 0.7];
        let d_ab = sliced_wasserstein(&a, &b, &skewed, 200, 42);
        let d_ba = sliced_wasserstein(&b, &a, &skewed, 200, 42);
        assert!(
            (d_ab - d_ba).abs() < 0.05,
            "should be approximately symmetric: {d_ab} vs {d_ba}"
        );
    }

    #[test]
    fn drift_convenience() {
        let before: Vec<f32> = vec![0.5, 0.3, 0.2];
        let after: Vec<f32> = vec![0.2, 0.3, 0.5];
        assert!(wasserstein_drift(&before, &after, &cross_centroids(), 100) > 0.0);
    }
}