cvx_index/metrics/
l2.rs

1//! L2 (Euclidean) squared distance metric.
2//!
3//! Returns the squared Euclidean distance to avoid the sqrt cost in comparisons.
4//! Range: `[0.0, ∞)`.
5
6use cvx_core::DistanceMetric;
7
8use super::simd_ops::l2_squared_simd;
9
/// Squared Euclidean distance: $d(a, b) = \sum_i (a_i - b_i)^2$.
///
/// Returns the **squared** distance (no sqrt). Because `sqrt` is monotonic,
/// the squared value ranks neighbors in the same order as the true Euclidean
/// distance, so the sqrt is unnecessary for nearest-neighbor comparisons.
///
/// Both slices are expected to have the same dimension.
///
/// # Metric properties
///
/// The squared distance is non-negative, symmetric and zero for identical
/// inputs, but it does **not** satisfy the triangle inequality. For example,
/// on the 1-D points `0, 1, 2` it gives `d(0, 2) = 4 > d(0, 1) + d(1, 2) = 2`.
/// Take the square root when a true metric is required.
///
/// # Example
///
/// ```
/// use cvx_core::DistanceMetric;
/// use cvx_index::metrics::L2Distance;
///
/// let d = L2Distance;
/// let a = vec![1.0, 0.0];
/// let b = vec![0.0, 1.0];
/// assert!((d.distance(&a, &b) - 2.0).abs() < 1e-5); // (1-0)^2 + (0-1)^2 = 2
/// ```
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct L2Distance;
28
29impl DistanceMetric for L2Distance {
30    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
31        l2_squared_simd(a, b)
32    }
33
34    fn name(&self) -> &str {
35        "l2"
36    }
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identical_vectors_zero() {
        let d = L2Distance;
        let v = [1.0_f32, 2.0, 3.0];
        assert!(d.distance(&v, &v).abs() < 1e-6);
    }

    #[test]
    fn axis_vectors_known_distance() {
        let d = L2Distance;
        let a = [3.0_f32, 0.0];
        let b = [0.0_f32, 4.0];
        assert!((d.distance(&a, &b) - 25.0).abs() < 1e-5); // 9 + 16 = 25
    }

    #[test]
    fn d768_matches_closed_form() {
        let d = L2Distance;
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..768).map(|i| ((768 - i) as f32) * 0.001).collect();
        let result = d.distance(&a, &b);
        // a_i - b_i = 0.001 * (2i - 768), so the sum is
        // 4e-6 * sum_{k=-384}^{383} k^2 = 4e-6 * 37_748_864 = 150.995456.
        assert!(
            (result - 150.995_456).abs() < 1e-2,
            "expected ~150.995456, got {result}"
        );
    }

    #[test]
    fn non_lane_multiple_lengths_are_exact() {
        // Integer-valued inputs keep every intermediate sum exactly representable
        // in f32, so the result must equal the closed form regardless of how the
        // SIMD kernel splits the work into lanes and a tail.
        let d = L2Distance;
        for len in 1..=33_u32 {
            let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (2 * i + 1) as f32).collect();
            // a_i - b_i = -(i + 1), so the sum is 1^2 + 2^2 + ... + len^2.
            let expected = (len * (len + 1) * (2 * len + 1) / 6) as f32;
            assert_eq!(d.distance(&a, &b), expected, "len = {len}");
        }
    }
}
67
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Plain scalar sum of squared differences, used as an oracle for the SIMD kernel.
    fn scalar_l2_squared(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    /// Pairs of equal-length vectors whose length is not restricted to a
    /// SIMD-lane multiple, so the kernel's tail handling is exercised too.
    fn equal_len_pair() -> impl Strategy<Value = (Vec<f32>, Vec<f32>)> {
        (1usize..=67).prop_flat_map(|n| {
            (
                prop::collection::vec(-100.0f32..100.0, n),
                prop::collection::vec(-100.0f32..100.0, n),
            )
        })
    }

    proptest! {
        #[test]
        fn non_negative((a, b) in equal_len_pair()) {
            prop_assert!(L2Distance.distance(&a, &b) >= -1e-5);
        }

        #[test]
        fn symmetric((a, b) in equal_len_pair()) {
            let ab = L2Distance.distance(&a, &b);
            let ba = L2Distance.distance(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-3);
        }

        #[test]
        fn identity(
            a in prop::collection::vec(-100.0f32..100.0, 1..=67),
        ) {
            prop_assert!(L2Distance.distance(&a, &a).abs() < 1e-3);
        }

        /// The SIMD result must agree with a scalar reference up to
        /// summation-order rounding (relative tolerance, absolute near zero).
        #[test]
        fn matches_scalar_reference((a, b) in equal_len_pair()) {
            let simd = L2Distance.distance(&a, &b);
            let reference = scalar_l2_squared(&a, &b);
            let tol = 1e-4 * reference.abs().max(1.0);
            prop_assert!(
                (simd - reference).abs() <= tol,
                "simd {simd} vs scalar {reference}"
            );
        }

        /// Triangle inequality: d(a,c) <= d(a,b) + d(b,c). The squared distance
        /// is not a metric, so the sqrt is taken to get the true Euclidean metric.
        #[test]
        fn triangle_inequality(
            a in prop::collection::vec(-10.0f32..10.0, 16..=16),
            b in prop::collection::vec(-10.0f32..10.0, 16..=16),
            c in prop::collection::vec(-10.0f32..10.0, 16..=16),
        ) {
            let d = L2Distance;
            let ab = d.distance(&a, &b).sqrt();
            let bc = d.distance(&b, &c).sqrt();
            let ac = d.distance(&a, &c).sqrt();
            prop_assert!(ac <= ab + bc + 1e-3, "triangle inequality: {ac} <= {ab} + {bc}");
        }
    }
}