cvx_index/metrics/
cosine.rs1use cvx_core::DistanceMetric;
8
9use super::simd_ops::{dot_product_simd, norm_squared_simd};
10
/// Cosine distance metric: `1 - cos(a, b)`, where `cos` is the cosine similarity
/// of the two vectors.
///
/// The type is a stateless unit struct; use the value `CosineDistance` directly
/// (or `CosineDistance::default()`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CosineDistance;
26
27impl DistanceMetric for CosineDistance {
28 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
29 let dot = dot_product_simd(a, b);
30 let norm_a = norm_squared_simd(a).sqrt();
31 let norm_b = norm_squared_simd(b).sqrt();
32
33 let denom = norm_a * norm_b;
34 if denom < f32::EPSILON {
35 return 1.0;
37 }
38
39 let cosine_sim = dot / denom;
40 1.0 - cosine_sim.clamp(-1.0, 1.0)
42 }
43
44 fn name(&self) -> &str {
45 "cosine"
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the exact-value checks below.
    const TOL: f32 = 1e-5;

    /// Cosine distance between `a` and `b`, computed through the trait method under test.
    fn dist(a: &[f32], b: &[f32]) -> f32 {
        CosineDistance.distance(a, b)
    }

    #[test]
    fn identical_vectors_distance_zero() {
        let v = [1.0, 2.0, 3.0];
        assert!(dist(&v, &v) < TOL);
    }

    #[test]
    fn opposite_vectors_distance_two() {
        let (a, b) = ([1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]);
        assert!((dist(&a, &b) - 2.0).abs() < TOL);
    }

    #[test]
    fn orthogonal_vectors_distance_one() {
        let (a, b) = ([1.0, 0.0], [0.0, 1.0]);
        assert!((dist(&a, &b) - 1.0).abs() < TOL);
    }

    #[test]
    fn zero_vector_returns_one() {
        let (a, b) = ([0.0; 3], [1.0, 2.0, 3.0]);
        assert!((dist(&a, &b) - 1.0).abs() < TOL);
    }

    #[test]
    fn d768_works() {
        // Component i is i * 0.001 in `ascending` and (768 - i) * 0.001 in `descending`.
        let ascending: Vec<f32> = (0..768u16).map(|i| f32::from(i) * 0.001).collect();
        let descending: Vec<f32> = (1..=768u16).rev().map(|i| f32::from(i) * 0.001).collect();
        let result = dist(&ascending, &descending);
        assert!((0.0..=2.0).contains(&result));
    }

    #[test]
    fn name_is_cosine() {
        assert_eq!(CosineDistance.name(), "cosine");
    }
}
98
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy shared by every property: exactly 32 components drawn from
    /// `[0.01, 10.0)`. Every component is positive, so no generated vector has a
    /// zero norm.
    fn positive_vec32() -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(0.01f32..10.0, 32..=32)
    }

    proptest! {
        // Symmetry: swapping the arguments must not change the distance.
        #[test]
        fn symmetric(a in positive_vec32(), b in positive_vec32()) {
            let d = CosineDistance;
            let (ab, ba) = (d.distance(&a, &b), d.distance(&b, &a));
            prop_assert!((ab - ba).abs() < 1e-5, "d(a,b)={ab}, d(b,a)={ba}");
        }

        // Range: the distance stays within [0, 2], allowing a little float slack.
        #[test]
        fn in_range(a in positive_vec32(), b in positive_vec32()) {
            let dist = CosineDistance.distance(&a, &b);
            prop_assert!((-1e-5..=2.0 + 1e-5).contains(&dist), "dist={dist}");
        }

        // Identity: a vector is (numerically) at distance zero from itself.
        #[test]
        fn identity(a in positive_vec32()) {
            let d = CosineDistance;
            prop_assert!(d.distance(&a, &a) < 1e-4);
        }
    }
}