cvx_index/metrics/
cosine.rs

1//! Cosine distance metric.
2//!
3//! Cosine distance = $1 - \text{cosine\_similarity}(a, b)$
4//!
5//! Range: `[0.0, 2.0]` where 0.0 = identical direction, 1.0 = orthogonal, 2.0 = opposite.
6
7use cvx_core::DistanceMetric;
8
9use super::simd_ops::{dot_product_simd, norm_squared_simd};
10
11/// Cosine distance: $d(a, b) = 1 - \frac{a \cdot b}{\|a\| \cdot \|b\|}$.
12///
13/// # Example
14///
15/// ```
16/// use cvx_core::DistanceMetric;
17/// use cvx_index::metrics::CosineDistance;
18///
19/// let d = CosineDistance;
20/// let a = vec![1.0, 0.0];
21/// let b = vec![1.0, 0.0];
22/// assert!(d.distance(&a, &b) < 1e-5); // same direction → 0.0
23/// ```
24#[derive(Clone, Copy)]
25pub struct CosineDistance;
26
27impl DistanceMetric for CosineDistance {
28    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
29        let dot = dot_product_simd(a, b);
30        let norm_a = norm_squared_simd(a).sqrt();
31        let norm_b = norm_squared_simd(b).sqrt();
32
33        let denom = norm_a * norm_b;
34        if denom < f32::EPSILON {
35            // At least one zero vector — define distance as 1.0 (orthogonal)
36            return 1.0;
37        }
38
39        let cosine_sim = dot / denom;
40        // Clamp to [-1, 1] for numerical stability (FP rounding can exceed)
41        1.0 - cosine_sim.clamp(-1.0, 1.0)
42    }
43
44    fn name(&self) -> &str {
45        "cosine"
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for approximate float comparisons in this module.
    const TOL: f32 = 1e-5;

    /// Panics with a descriptive message unless `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected} ± {TOL}, got {actual}"
        );
    }

    #[test]
    fn identical_vectors_distance_zero() {
        let d = CosineDistance;
        let v = [1.0_f32, 2.0, 3.0];
        assert!(d.distance(&v, &v) < TOL);
    }

    #[test]
    fn scaled_vectors_distance_zero() {
        // Cosine distance ignores magnitude: a and 2a point the same way.
        let d = CosineDistance;
        let a = [1.0_f32, 2.0, 3.0];
        let b = [2.0_f32, 4.0, 6.0];
        assert!(d.distance(&a, &b) < TOL);
    }

    #[test]
    fn opposite_vectors_distance_two() {
        let d = CosineDistance;
        let a = [1.0_f32, 0.0, 0.0];
        let b = [-1.0_f32, 0.0, 0.0];
        assert_close(d.distance(&a, &b), 2.0);
    }

    #[test]
    fn orthogonal_vectors_distance_one() {
        let d = CosineDistance;
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert_close(d.distance(&a, &b), 1.0);
    }

    #[test]
    fn zero_vector_returns_one() {
        let d = CosineDistance;
        let a = [0.0_f32, 0.0, 0.0];
        let b = [1.0_f32, 2.0, 3.0];
        assert_close(d.distance(&a, &b), 1.0);
        // The zero-vector rule applies regardless of argument order.
        assert_close(d.distance(&b, &a), 1.0);
    }

    #[test]
    fn both_zero_vectors_return_one() {
        let d = CosineDistance;
        let z = [0.0_f32; 4];
        assert_close(d.distance(&z, &z), 1.0);
    }

    #[test]
    fn d768_works() {
        let d = CosineDistance;
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..768).map(|i| ((768 - i) as f32) * 0.001).collect();
        let result = d.distance(&a, &b);
        assert!((0.0..=2.0).contains(&result));
        // For these opposing ramps cos(a, b) is analytically ≈ 0.5, so the
        // distance should be ≈ 0.5 as well.
        assert!((result - 0.5).abs() < 1e-3, "result={result}");
    }

    #[test]
    fn name_is_cosine() {
        assert_eq!(CosineDistance.name(), "cosine");
    }
}
98
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing 32-dimensional vectors whose components all lie in
    /// `[0.01, 10.0)`, so a generated vector is never the zero vector.
    fn positive_vec32() -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(0.01f32..10.0, 32..=32)
    }

    proptest! {
        /// Swapping the arguments leaves the distance unchanged (within 1e-5).
        #[test]
        fn symmetric(a in positive_vec32(), b in positive_vec32()) {
            let ab = CosineDistance.distance(&a, &b);
            let ba = CosineDistance.distance(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-5, "d(a,b)={ab}, d(b,a)={ba}");
        }

        /// Every distance falls in [0.0, 2.0], allowing 1e-5 of slack at both ends.
        #[test]
        fn in_range(a in positive_vec32(), b in positive_vec32()) {
            let dist = CosineDistance.distance(&a, &b);
            let allowed = -1e-5..=2.0 + 1e-5;
            prop_assert!(allowed.contains(&dist), "dist={dist}");
        }

        /// A vector's distance to itself is (numerically) zero.
        #[test]
        fn identity(a in positive_vec32()) {
            let self_dist = CosineDistance.distance(&a, &a);
            prop_assert!(self_dist < 1e-4);
        }
    }
}
135        }
136    }
137}