/// Sums the absolute gaps between the running totals of two histograms.
///
/// `a` and `b` are read as masses over the same ordered bins. Walking the
/// bins left to right, the running total of each input is updated and the
/// absolute difference of the two totals is accumulated. For inputs with
/// equal total mass and unit spacing between neighbouring bins this is the
/// 1-D earth mover's (Wasserstein-1) distance. No normalisation is applied.
///
/// Returns `0.0` for empty inputs.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length.
pub fn wasserstein_1d(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "distributions must have equal length");

    // Fold state: (running total of `a`, running total of `b`, accumulated gap).
    let (_, _, gap_sum) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(run_a, run_b, acc), (&mass_a, &mass_b)| {
            let run_a = run_a + mass_a;
            let run_b = run_b + mass_b;
            (run_a, run_b, acc + (run_a - run_b).abs())
        },
    );

    gap_sum
}
54
/// Estimates a sliced, rank-based Wasserstein distance between two histograms
/// whose bins are located at `centroids`.
///
/// For each of `n_projections` pseudo-random directions:
///
/// 1. every coordinate of the direction is drawn from `[-1, 1]` and the vector
///    is scaled to unit length (directions with norm below `1e-10` are skipped);
/// 2. each centroid is projected onto the direction and the bins are sorted by
///    that projection (stable sort, so bins that compare equal keep their
///    input order);
/// 3. the absolute gaps between the running totals of `a` and `b` are summed in
///    that order, which is the same computation as [`wasserstein_1d`] on the
///    reordered masses.
///
/// Only the *order* of the projected centroids is used; the distances between
/// neighbouring projections are not used as weights.
///
/// The return value is the mean over the projections that were actually
/// evaluated. It is `0.0` when there are no bins, when `n_projections` is zero,
/// or when every direction was skipped (for example zero-dimensional
/// centroids).
///
/// # Arguments
///
/// * `a`, `b` - masses per bin; must have the same length.
/// * `centroids` - one coordinate slice per bin. The dimension is taken from
///   the first centroid; coordinates of other centroids beyond that dimension
///   are ignored.
/// * `n_projections` - number of random directions to draw.
/// * `seed` - seed of the internal xorshift generator; equal seeds give equal
///   results. A seed of `0` is replaced by a fixed non-zero constant because
///   zero is a fixed point of xorshift.
///
/// # Panics
///
/// Panics if `a`, `b` and `centroids` do not all have the same length.
pub fn sliced_wasserstein(
    a: &[f64],
    b: &[f64],
    centroids: &[&[f32]],
    n_projections: usize,
    seed: u64,
) -> f64 {
    // With an all-zero state xorshift only ever yields zero again, which would
    // make every draw -1.0 and every direction identical.
    const ZERO_SEED_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;
    // Directions shorter than this cannot be normalised reliably.
    const MIN_DIRECTION_NORM: f64 = 1e-10;

    let k = a.len();
    assert_eq!(k, b.len(), "distributions must have equal length");
    assert_eq!(k, centroids.len(), "must have one centroid per bin");
    if k == 0 || n_projections == 0 {
        // Nothing to average; avoid the 0.0 / 0.0 = NaN of an empty mean.
        return 0.0;
    }

    let dim = centroids[0].len();

    // xorshift64 step mapped to a value in [-1, 1].
    let mut rng_state = if seed == 0 { ZERO_SEED_FALLBACK } else { seed };
    let mut next_rand = || -> f64 {
        rng_state ^= rng_state << 13;
        rng_state ^= rng_state >> 7;
        rng_state ^= rng_state << 17;
        (rng_state as f64 / u64::MAX as f64) * 2.0 - 1.0
    };

    // Scratch buffers reused by every projection.
    let mut direction = vec![0.0_f64; dim];
    let mut projections: Vec<(f64, f64, f64)> = Vec::with_capacity(k);

    let mut total = 0.0;
    let mut used = 0_usize;

    for _ in 0..n_projections {
        for d in direction.iter_mut() {
            *d = next_rand();
        }
        let norm = direction.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < MIN_DIRECTION_NORM {
            continue;
        }
        for d in direction.iter_mut() {
            *d /= norm;
        }

        // (projection of the centroid, mass in `a`, mass in `b`) per bin.
        projections.clear();
        projections.extend(centroids.iter().zip(a.iter().zip(b.iter())).map(
            |(centroid, (&mass_a, &mass_b))| {
                let proj: f64 = centroid
                    .iter()
                    .zip(direction.iter())
                    .map(|(&cv, &dv)| f64::from(cv) * dv)
                    .sum();
                (proj, mass_a, mass_b)
            },
        ));

        // `total_cmp` is a total order, so a NaN projection (from a NaN
        // centroid coordinate) is placed at one end instead of panicking.
        projections.sort_by(|x, y| x.0.total_cmp(&y.0));

        // Running-total gap sum over the bins in projected order.
        let mut cum_a = 0.0;
        let mut cum_b = 0.0;
        let mut distance = 0.0;
        for &(_, mass_a, mass_b) in &projections {
            cum_a += mass_a;
            cum_b += mass_b;
            distance += (cum_a - cum_b).abs();
        }

        total += distance;
        used += 1;
    }

    // Average over the directions that contributed, not the number requested.
    if used == 0 {
        0.0
    } else {
        total / used as f64
    }
}
135
136pub fn wasserstein_drift(
148 dist_t1: &[f32],
149 dist_t2: &[f32],
150 centroids: &[&[f32]],
151 n_projections: usize,
152) -> f64 {
153 let a: Vec<f64> = dist_t1.iter().map(|&v| v as f64).collect();
154 let b: Vec<f64> = dist_t2.iter().map(|&v| v as f64).collect();
155 sliced_wasserstein(&a, &b, centroids, n_projections, 42)
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical histograms have no gap between their running totals.
    #[test]
    fn w1d_identical_distributions() {
        let a = vec![0.25, 0.25, 0.25, 0.25];
        let b = vec![0.25, 0.25, 0.25, 0.25];
        assert!((wasserstein_1d(&a, &b)).abs() < 1e-10);
    }

    /// All mass moves from the first bin to the last: three unit steps.
    #[test]
    fn w1d_shifted_mass() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 0.0, 0.0, 1.0];
        let d = wasserstein_1d(&a, &b);
        assert!((d - 3.0).abs() < 1e-10, "expected 3.0, got {d}");
    }

    /// All mass moves to the neighbouring bin: one unit step.
    #[test]
    fn w1d_adjacent_shift() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let d = wasserstein_1d(&a, &b);
        assert!((d - 1.0).abs() < 1e-10, "expected 1.0, got {d}");
    }

    /// Swapping the arguments does not change the distance.
    #[test]
    fn w1d_symmetric() {
        let a = vec![0.5, 0.3, 0.2];
        let b = vec![0.1, 0.4, 0.5];
        assert!((wasserstein_1d(&a, &b) - wasserstein_1d(&b, &a)).abs() < 1e-10);
    }

    /// The distance is never negative.
    #[test]
    fn w1d_non_negative() {
        let a = vec![0.7, 0.2, 0.1];
        let b = vec![0.1, 0.3, 0.6];
        assert!(wasserstein_1d(&a, &b) >= 0.0);
    }

    /// d(a, c) <= d(a, b) + d(b, c), up to rounding.
    #[test]
    fn w1d_triangle_inequality() {
        let a = vec![0.5, 0.3, 0.2];
        let b = vec![0.1, 0.4, 0.5];
        let c = vec![0.3, 0.3, 0.4];
        let d_ab = wasserstein_1d(&a, &b);
        let d_bc = wasserstein_1d(&b, &c);
        let d_ac = wasserstein_1d(&a, &c);
        assert!(
            d_ac <= d_ab + d_bc + 1e-10,
            "triangle inequality: d(a,c)={d_ac} > d(a,b)+d(b,c)={}",
            d_ab + d_bc
        );
    }

    /// Identical histograms stay at distance ~0 under every projection.
    #[test]
    fn sliced_identical_zero() {
        let a = vec![0.5, 0.3, 0.2];
        let b = vec![0.5, 0.3, 0.2];
        let centroids: Vec<&[f32]> =
            vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 1.0], &[-1.0f32, 0.0]];
        let d = sliced_wasserstein(&a, &b, &centroids, 100, 42);
        assert!(
            d < 1e-10,
            "identical distributions should have distance ~0, got {d}"
        );
    }

    /// Mass placed on opposite ends of a line of centroids is clearly apart.
    #[test]
    fn sliced_different_positive() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 0.0, 1.0];
        let centroids: Vec<&[f32]> =
            vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 0.0], &[-1.0f32, 0.0]];
        let d = sliced_wasserstein(&a, &b, &centroids, 100, 42);
        assert!(
            d > 0.1,
            "different distributions should have positive distance, got {d}"
        );
    }

    /// With the same seed, swapping the arguments gives (about) the same value.
    #[test]
    fn sliced_symmetric() {
        let a = vec![0.6, 0.3, 0.1];
        let b = vec![0.1, 0.2, 0.7];
        let centroids: Vec<&[f32]> =
            vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 1.0], &[-1.0f32, -1.0]];
        let d_ab = sliced_wasserstein(&a, &b, &centroids, 200, 42);
        let d_ba = sliced_wasserstein(&b, &a, &centroids, 200, 42);
        assert!(
            (d_ab - d_ba).abs() < 0.05,
            "should be approximately symmetric: {d_ab} vs {d_ba}"
        );
    }

    /// The `f32` convenience wrapper reports a positive drift for differing input.
    #[test]
    fn drift_convenience() {
        let dist_t1: Vec<f32> = vec![0.5, 0.3, 0.2];
        let dist_t2: Vec<f32> = vec![0.2, 0.3, 0.5];
        let centroids: Vec<&[f32]> =
            vec![&[1.0f32, 0.0] as &[f32], &[0.0f32, 1.0], &[-1.0f32, 0.0]];
        let d = wasserstein_drift(&dist_t1, &dist_t2, &centroids, 100);
        assert!(d > 0.0);
    }
}