cvx_index/metrics/
dot_product.rs

//! Dot product distance metric (for maximum inner product search).
//!
//! Returns the **negative** dot product so that smaller values mean more similar
//! vectors, consistent with the other distance metrics, where lower is better.
5
6use cvx_core::DistanceMetric;
7
8use super::simd_ops::dot_product_simd;
9
/// Negative dot product distance: $d(a, b) = -a \cdot b$.
///
/// Used for Maximum Inner Product Search (MIPS). The negation ensures that
/// the most similar vectors (highest dot product) have the smallest distance.
///
/// Note that this is not a metric in the mathematical sense:
///
/// - values can be negative;
/// - `d(a, a) = -‖a‖²` rather than zero;
/// - the triangle inequality does not hold.
///
/// Vector magnitude affects the result. For unit-normalized inputs the
/// ordering is the same as ranking by cosine similarity.
///
/// # Example
///
/// ```
/// use cvx_core::DistanceMetric;
/// use cvx_index::metrics::DotProductDistance;
///
/// let d = DotProductDistance;
/// let a = [1.0_f32, 0.0];
/// let b = [1.0_f32, 0.0];
/// assert!(d.distance(&a, &b) < 0.0); // same direction → negative (= similar)
/// ```
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DotProductDistance;
28
29impl DistanceMetric for DotProductDistance {
30    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
31        -dot_product_simd(a, b)
32    }
33
34    fn name(&self) -> &str {
35        "dot"
36    }
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for evaluating the metric under test.
    fn dist(a: &[f32], b: &[f32]) -> f32 {
        DotProductDistance.distance(a, b)
    }

    #[test]
    fn same_direction_is_negative() {
        assert!(dist(&[1.0, 0.0], &[2.0, 0.0]) < 0.0);
    }

    #[test]
    fn orthogonal_is_zero() {
        assert!(dist(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
    }

    #[test]
    fn opposite_is_positive() {
        assert!(dist(&[1.0, 0.0], &[-1.0, 0.0]) > 0.0);
    }

    #[test]
    fn more_similar_has_lower_distance() {
        let query = [1.0, 0.0];
        let close = [0.9, 0.1];
        let far = [0.1, 0.9];
        assert!(dist(&query, &close) < dist(&query, &far));
    }

    #[test]
    fn equals_negated_inner_product() {
        // -(1*3 + 2*4) = -11, which is exactly representable in f32.
        assert!((dist(&[1.0, 2.0], &[3.0, 4.0]) + 11.0).abs() < 1e-6);
    }

    #[test]
    fn handles_length_not_multiple_of_lane_count() {
        // 37 is not a multiple of the common SIMD lane counts (4, 8, 16), so a
        // chunked implementation also has to handle a remainder tail correctly.
        let a: Vec<f32> = (0..37).map(|i| i as f32 * 0.5).collect();
        let b: Vec<f32> = (0..37).map(|i| 1.0 - i as f32 * 0.25).collect();
        let expected: f64 = -a
            .iter()
            .zip(&b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum::<f64>();
        let got = f64::from(dist(&a, &b));
        let tol = 1e-5 * expected.abs().max(1.0);
        assert!(
            (got - expected).abs() <= tol,
            "got {}, expected {}",
            got,
            expected
        );
    }

    #[test]
    fn name_is_dot() {
        assert_eq!(DotProductDistance.name(), "dot");
    }
}
76
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy yielding two vectors of the same random length in `1..=64`.
    ///
    /// This exercises the kernel on lengths both with and without a remainder
    /// relative to typical SIMD lane counts.
    fn equal_len_pair() -> impl Strategy<Value = (Vec<f32>, Vec<f32>)> {
        (1usize..=64).prop_flat_map(|n| {
            (
                prop::collection::vec(-100.0f32..100.0, n),
                prop::collection::vec(-100.0f32..100.0, n),
            )
        })
    }

    proptest! {
        /// Dot product distance is symmetric: d(a, b) == d(b, a), because the
        /// inner product is commutative and negation preserves equality.
        #[test]
        fn symmetric(
            a in prop::collection::vec(-100.0f32..100.0, 32..=32),
            b in prop::collection::vec(-100.0f32..100.0, 32..=32),
        ) {
            let ab = DotProductDistance.distance(&a, &b);
            let ba = DotProductDistance.distance(&b, &a);
            // |a·b| can reach 32 * 100 * 100 = 320_000, where a fixed absolute
            // epsilon is below one f32 ULP; scale the tolerance with magnitude,
            // but never make it stricter than the original 1e-3.
            let tol = (1e-6 * ab.abs()).max(1e-3);
            prop_assert!(
                (ab - ba).abs() <= tol,
                "d(a,b) = {}, d(b,a) = {}",
                ab,
                ba
            );
        }

        /// The distance matches the negated f64 scalar inner product for
        /// arbitrary (equal) lengths.
        #[test]
        fn matches_scalar_reference(pair in equal_len_pair()) {
            let (a, b) = pair;
            let expected: f64 = -a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| f64::from(x) * f64::from(y))
                .sum::<f64>();
            // Bound the error by sum(|a_i * b_i|), so cancellation in the true
            // sum cannot produce spurious failures.
            let magnitude: f64 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| (f64::from(x) * f64::from(y)).abs())
                .sum();
            let got = f64::from(DotProductDistance.distance(&a, &b));
            prop_assert!(
                (got - expected).abs() <= 1e-4 * magnitude.max(1.0),
                "got {}, expected {}",
                got,
                expected
            );
        }
    }
}