cvx_index/metrics/
dot_product.rs1use cvx_core::DistanceMetric;
7
8use super::simd_ops::dot_product_simd;
9
/// Distance derived from the inner product: `distance(a, b) = -(a · b)`.
///
/// Negating the dot product means that vectors with a larger inner product
/// (more similar) get a *smaller* distance, so results can be ranked in
/// ascending order like any other distance.
///
/// This is not a true metric:
/// - values can be negative (e.g. parallel vectors);
/// - `distance(a, a)` is `-|a|²`, not zero;
/// - the triangle inequality does not hold.
///
/// If the inputs happen to be unit-normalized, the ordering it produces
/// matches cosine similarity. This type does not normalize anything itself.
///
/// The metric reports its name as `"dot"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DotProductDistance;
28
29impl DistanceMetric for DotProductDistance {
30 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
31 -dot_product_simd(a, b)
32 }
33
34 fn name(&self) -> &str {
35 "dot"
36 }
37}
38
// Unit tests pinning the sign convention of `DotProductDistance`:
// distance = -(a · b), so similar vectors score low and opposed vectors high.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn same_direction_is_negative() {
        let d = DotProductDistance;
        let a = [1.0_f32, 0.0];
        let b = [2.0_f32, 0.0];
        assert!(d.distance(&a, &b) < 0.0);
    }

    #[test]
    fn orthogonal_is_zero() {
        let d = DotProductDistance;
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!(d.distance(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn opposite_is_positive() {
        let d = DotProductDistance;
        let a = [1.0_f32, 0.0];
        let b = [-1.0_f32, 0.0];
        assert!(d.distance(&a, &b) > 0.0);
    }

    #[test]
    fn more_similar_has_lower_distance() {
        let d = DotProductDistance;
        let query = [1.0_f32, 0.0];
        let close = [0.9_f32, 0.1];
        let far = [0.1_f32, 0.9];
        assert!(d.distance(&query, &close) < d.distance(&query, &far));
    }

    #[test]
    fn distance_is_exact_negated_dot() {
        // 3*2 + 4*5 = 26; small integers are exact in f32 regardless of
        // the order in which the products are accumulated.
        let d = DotProductDistance;
        let a = [3.0_f32, 4.0];
        let b = [2.0_f32, 5.0];
        assert_eq!(d.distance(&a, &b), -26.0);
    }

    #[test]
    fn name_is_dot() {
        assert_eq!(DotProductDistance.name(), "dot");
    }
}
76
// Property tests for `DotProductDistance`.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Swapping the operands must not change the distance beyond rounding.
        // Lengths vary (1..=67) so both full SIMD lanes and scalar tails are
        // exercised, not just one fixed width.
        #[test]
        fn symmetric(
            (a, b) in (1usize..=67).prop_flat_map(|len| (
                prop::collection::vec(-100.0f32..100.0, len),
                prop::collection::vec(-100.0f32..100.0, len),
            )),
        ) {
            let ab = DotProductDistance.distance(&a, &b);
            let ba = DotProductDistance.distance(&b, &a);
            // With values up to ±100 the sum can reach ~1e5+, where one f32 ULP
            // is far above a fixed 1e-3. Scale the tolerance by the magnitude of
            // the summed terms so a reordered accumulation is still accepted.
            let scale: f32 = a.iter().zip(&b).map(|(x, y)| (x * y).abs()).sum();
            prop_assert!(
                (ab - ba).abs() <= 1e-5 * (1.0 + scale),
                "asymmetric: d(a,b)={} d(b,a)={}",
                ab,
                ba
            );
        }
    }
}