1use cvx_core::DistanceMetric;
7
8use super::simd_ops::l2_squared_simd;
9
10#[derive(Clone, Copy)]
27pub struct L2Distance;
28
29impl DistanceMetric for L2Distance {
30 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
31 l2_squared_simd(a, b)
32 }
33
34 fn name(&self) -> &str {
35 "l2"
36 }
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar ground truth: squared L2 accumulated in `f64`, independent of any SIMD path.
    fn reference_l2_squared(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                let diff = f64::from(x) - f64::from(y);
                diff * diff
            })
            .sum()
    }

    /// Asserts `actual` is within a relative tolerance of `expected` (absolute near zero).
    fn assert_close(actual: f32, expected: f64) {
        let tol = 1e-4 * expected.abs().max(1.0);
        assert!(
            (f64::from(actual) - expected).abs() <= tol,
            "got {actual}, expected {expected} (tol {tol})"
        );
    }

    #[test]
    fn identical_vectors_zero() {
        let d = L2Distance;
        let v = [1.0_f32, 2.0, 3.0];
        assert!(d.distance(&v, &v).abs() < 1e-6);
    }

    // 3-4-5 triangle: the metric returns the *squared* distance, 25, not 5.
    #[test]
    fn known_distance_is_squared() {
        let d = L2Distance;
        let a = [3.0_f32, 0.0];
        let b = [0.0_f32, 4.0];
        assert!((d.distance(&a, &b) - 25.0).abs() < 1e-5);
    }

    #[test]
    fn d768_matches_scalar_reference() {
        let d = L2Distance;
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..768).map(|i| ((768 - i) as f32) * 0.001).collect();
        assert_close(d.distance(&a, &b), reference_l2_squared(&a, &b));
    }

    // Lengths that are not multiples of common SIMD lane widths exercise the
    // kernel's tail handling. Lane-aligned lengths such as 768 never reach it.
    #[test]
    fn odd_lengths_match_scalar_reference() {
        let d = L2Distance;
        for len in 1..=33 {
            let a: Vec<f32> = (0..len).map(|i| (i as f32) * 0.5 - 3.0).collect();
            let b: Vec<f32> = (0..len).map(|i| 2.0 - (i as f32) * 0.25).collect();
            assert_close(d.distance(&a, &b), reference_l2_squared(&a, &b));
        }
    }
}
67
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn non_negative(
            a in prop::collection::vec(-100.0f32..100.0, 32..=32),
            b in prop::collection::vec(-100.0f32..100.0, 32..=32),
        ) {
            prop_assert!(L2Distance.distance(&a, &b) >= -1e-5);
        }

        #[test]
        fn symmetric(
            a in prop::collection::vec(-100.0f32..100.0, 32..=32),
            b in prop::collection::vec(-100.0f32..100.0, 32..=32),
        ) {
            let ab = L2Distance.distance(&a, &b);
            let ba = L2Distance.distance(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-3);
        }

        #[test]
        fn identity(
            a in prop::collection::vec(-100.0f32..100.0, 32..=32),
        ) {
            prop_assert!(L2Distance.distance(&a, &a).abs() < 1e-3);
        }

        // `distance` returns squared L2, which does NOT satisfy the triangle
        // inequality. Its square root is the true Euclidean metric and must.
        #[test]
        fn triangle_inequality(
            a in prop::collection::vec(-10.0f32..10.0, 16..=16),
            b in prop::collection::vec(-10.0f32..10.0, 16..=16),
            c in prop::collection::vec(-10.0f32..10.0, 16..=16),
        ) {
            let d = L2Distance;
            let ab = d.distance(&a, &b).sqrt();
            let bc = d.distance(&b, &c).sqrt();
            let ac = d.distance(&a, &c).sqrt();
            prop_assert!(ac <= ab + bc + 1e-3, "triangle inequality: {ac} <= {ab} + {bc}");
        }

        // Random lengths (mostly not lane-aligned) so SIMD tail handling is covered.
        // The result is compared against a scalar f64 reference with a relative
        // tolerance.
        #[test]
        fn matches_scalar_reference(
            (a, b) in (1usize..=67).prop_flat_map(|n| (
                prop::collection::vec(-100.0f32..100.0, n),
                prop::collection::vec(-100.0f32..100.0, n),
            )),
        ) {
            let expected: f64 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| {
                    let diff = f64::from(x) - f64::from(y);
                    diff * diff
                })
                .sum();
            let got = f64::from(L2Distance.distance(&a, &b));
            let tol = 1e-4 * expected.max(1.0);
            prop_assert!(
                (got - expected).abs() <= tol,
                "len {}: got {got}, expected {expected} (tol {tol})",
                a.len()
            );
        }
    }
}