cvx_index/metrics/
simd_ops.rs

1//! Low-level SIMD operations using pulp.
2//!
3//! These are the building blocks for all distance metrics. Each function
4//! dispatches to the best SIMD ISA available at runtime.
5
6use pulp::Arch;
7
8/// Compute the dot product of two float slices using SIMD.
9///
10/// # Panics
11///
12/// Panics if `a.len() != b.len()`.
13pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
14    assert_eq!(a.len(), b.len(), "vectors must have equal length");
15
16    let arch = Arch::new();
17    arch.dispatch(DotProduct(a, b))
18}
19
20/// Compute the squared L2 norm of a float slice using SIMD.
21pub fn norm_squared_simd(a: &[f32]) -> f32 {
22    dot_product_simd(a, a)
23}
24
25/// Compute the sum of squared differences between two slices using SIMD.
26///
27/// # Panics
28///
29/// Panics if `a.len() != b.len()`.
30pub fn l2_squared_simd(a: &[f32], b: &[f32]) -> f32 {
31    assert_eq!(a.len(), b.len(), "vectors must have equal length");
32
33    let arch = Arch::new();
34    arch.dispatch(L2Squared(a, b))
35}
36
37struct DotProduct<'a>(&'a [f32], &'a [f32]);
38
39impl pulp::WithSimd for DotProduct<'_> {
40    type Output = f32;
41
42    #[inline(always)]
43    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
44        let (a_head, a_tail) = S::as_simd_f32s(self.0);
45        let (b_head, b_tail) = S::as_simd_f32s(self.1);
46
47        let mut acc = simd.splat_f32s(0.0);
48        for (&a_chunk, &b_chunk) in a_head.iter().zip(b_head.iter()) {
49            acc = simd.mul_add_e_f32s(a_chunk, b_chunk, acc);
50        }
51
52        let mut sum = simd.reduce_sum_f32s(acc);
53        for (&a_val, &b_val) in a_tail.iter().zip(b_tail.iter()) {
54            sum = f32::mul_add(a_val, b_val, sum);
55        }
56        sum
57    }
58}
59
60struct L2Squared<'a>(&'a [f32], &'a [f32]);
61
62impl pulp::WithSimd for L2Squared<'_> {
63    type Output = f32;
64
65    #[inline(always)]
66    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
67        let (a_head, a_tail) = S::as_simd_f32s(self.0);
68        let (b_head, b_tail) = S::as_simd_f32s(self.1);
69
70        // Compute ||a - b||² = ||a||² + ||b||² - 2(a·b)
71        // This avoids needing a SIMD subtract operation.
72        let mut acc_aa = simd.splat_f32s(0.0);
73        let mut acc_bb = simd.splat_f32s(0.0);
74        let mut acc_ab = simd.splat_f32s(0.0);
75
76        for (&a_chunk, &b_chunk) in a_head.iter().zip(b_head.iter()) {
77            acc_aa = simd.mul_add_e_f32s(a_chunk, a_chunk, acc_aa);
78            acc_bb = simd.mul_add_e_f32s(b_chunk, b_chunk, acc_bb);
79            acc_ab = simd.mul_add_e_f32s(a_chunk, b_chunk, acc_ab);
80        }
81
82        let mut sum_aa = simd.reduce_sum_f32s(acc_aa);
83        let mut sum_bb = simd.reduce_sum_f32s(acc_bb);
84        let mut sum_ab = simd.reduce_sum_f32s(acc_ab);
85
86        for (&a_val, &b_val) in a_tail.iter().zip(b_tail.iter()) {
87            sum_aa = f32::mul_add(a_val, a_val, sum_aa);
88            sum_bb = f32::mul_add(b_val, b_val, sum_bb);
89            sum_ab = f32::mul_add(a_val, b_val, sum_ab);
90        }
91
92        // ||a-b||² = ||a||² + ||b||² - 2(a·b)
93        // Clamp to 0.0 for numerical stability (can be slightly negative due to FP)
94        (sum_aa + sum_bb - 2.0 * sum_ab).max(0.0)
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar f64 reference for the dot product.
    fn dot_reference(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum()
    }

    /// Scalar f64 reference for the squared L2 distance.
    fn l2_squared_reference(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                let d = f64::from(x) - f64::from(y);
                d * d
            })
            .sum()
    }

    /// True when `got` is within relative tolerance `rel` of `expected`
    /// (absolute tolerance `rel` for expected magnitudes below 1).
    fn close(got: f32, expected: f64, rel: f64) -> bool {
        (f64::from(got) - expected).abs() <= rel * expected.abs().max(1.0)
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = dot_product_simd(&a, &b);
        assert!((result - 32.0).abs() < 1e-5); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn dot_product_orthogonal() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        let result = dot_product_simd(&a, &b);
        assert!(result.abs() < 1e-6);
    }

    #[test]
    fn l2_squared_identical() {
        let a = [1.0, 2.0, 3.0];
        assert!(l2_squared_simd(&a, &a).abs() < 1e-6);
    }

    #[test]
    fn l2_squared_known() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        let result = l2_squared_simd(&a, &b);
        assert!((result - 2.0).abs() < 1e-5); // (1-0)^2 + (0-1)^2 = 2
    }

    #[test]
    fn norm_squared_unit_vector() {
        let a = [1.0, 0.0, 0.0];
        assert!((norm_squared_simd(&a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn empty_inputs_yield_zero() {
        let empty: [f32; 0] = [];
        assert_eq!(dot_product_simd(&empty, &empty), 0.0);
        assert_eq!(l2_squared_simd(&empty, &empty), 0.0);
        assert_eq!(norm_squared_simd(&empty), 0.0);
    }

    #[test]
    fn handles_d768() {
        // 768 is divisible by the usual f32 lane counts (4, 8, 16), so this
        // exercises the full-chunk SIMD path. Check the values against a
        // scalar reference instead of only checking that nothing panics.
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..768).map(|i| ((768 - i) as f32) * 0.001).collect();

        let dot = dot_product_simd(&a, &b);
        assert!(close(dot, dot_reference(&a, &b), 1e-4), "dot={dot}");

        let l2 = l2_squared_simd(&a, &b);
        assert!(close(l2, l2_squared_reference(&a, &b), 1e-4), "l2={l2}");
    }

    #[test]
    fn handles_length_with_tail() {
        // 19 = 16 + 3: with 4-, 8- or 16-lane registers this leaves a
        // 3-element remainder, so the scalar tail loop runs too.
        let a: Vec<f32> = (0..19).map(|i| i as f32 - 9.0).collect();
        let b: Vec<f32> = (0..19).map(|i| 0.5 * i as f32).collect();

        // Σ (i-9)(i/2) = 285, Σ (i/2 - 9)² = 527.25 -- exact in f32.
        assert!((dot_product_simd(&a, &b) - 285.0).abs() < 1e-4);
        assert!((l2_squared_simd(&a, &b) - 527.25).abs() < 1e-4);
    }

    #[test]
    #[should_panic(expected = "vectors must have equal length")]
    fn dot_product_panics_on_mismatch() {
        dot_product_simd(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "vectors must have equal length")]
    fn l2_squared_panics_on_mismatch() {
        l2_squared_simd(&[1.0, 2.0], &[1.0]);
    }
}
152
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Upper bound on generated vector length. Lengths in `0..=MAX_LEN`
    /// cover empty input, remainder-only input, and inputs with both full
    /// SIMD chunks and a scalar remainder. A fixed length of 64 is a multiple
    /// of the usual f32 lane counts and never reaches the tail loop.
    const MAX_LEN: usize = 70;

    /// One vector of arbitrary length in `0..=MAX_LEN`, elements in [-100, 100).
    fn any_vec() -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(-100.0f32..100.0, 0..=MAX_LEN)
    }

    /// Two vectors sharing one arbitrary length in `0..=MAX_LEN`.
    fn equal_len_pair() -> impl Strategy<Value = (Vec<f32>, Vec<f32>)> {
        (0..=MAX_LEN).prop_flat_map(|len| {
            (
                prop::collection::vec(-100.0f32..100.0, len),
                prop::collection::vec(-100.0f32..100.0, len),
            )
        })
    }

    proptest! {
        /// Dot product is commutative: a·b == b·a
        #[test]
        fn dot_product_commutative(pair in equal_len_pair()) {
            let (a, b) = pair;
            let ab = dot_product_simd(&a, &b);
            let ba = dot_product_simd(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-3, "a·b={ab}, b·a={ba}");
        }

        /// Dot product agrees with an f64 scalar reference, tail included.
        #[test]
        fn dot_product_matches_scalar(pair in equal_len_pair()) {
            let (a, b) = pair;
            let expected: f64 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| f64::from(x) * f64::from(y))
                .sum();
            // f32 summation error scales with Σ|a_i·b_i|, not |Σ a_i·b_i|.
            let scale: f64 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| (f64::from(x) * f64::from(y)).abs())
                .sum();
            let got = f64::from(dot_product_simd(&a, &b));
            prop_assert!(
                (got - expected).abs() <= 1e-4 * (1.0 + scale),
                "got={got}, expected={expected}"
            );
        }

        /// L2 squared agrees with an f64 scalar reference, tail included.
        #[test]
        fn l2_squared_matches_scalar(pair in equal_len_pair()) {
            let (a, b) = pair;
            let expected: f64 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| {
                    let d = f64::from(x) - f64::from(y);
                    d * d
                })
                .sum();
            let got = f64::from(l2_squared_simd(&a, &b));
            prop_assert!(
                (got - expected).abs() <= 1e-4 * (1.0 + expected),
                "got={got}, expected={expected}"
            );
        }

        /// L2 squared is non-negative
        #[test]
        fn l2_squared_non_negative(pair in equal_len_pair()) {
            let (a, b) = pair;
            prop_assert!(l2_squared_simd(&a, &b) >= -1e-6);
        }

        /// L2 squared is symmetric: d(a,b) == d(b,a)
        #[test]
        fn l2_squared_symmetric(pair in equal_len_pair()) {
            let (a, b) = pair;
            let ab = l2_squared_simd(&a, &b);
            let ba = l2_squared_simd(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-3, "d(a,b)={ab}, d(b,a)={ba}");
        }

        /// L2 squared of identical vectors is zero
        #[test]
        fn l2_squared_identity(a in any_vec()) {
            prop_assert!(l2_squared_simd(&a, &a).abs() < 1e-3);
        }

        /// Norm squared is non-negative
        #[test]
        fn norm_squared_non_negative(a in any_vec()) {
            prop_assert!(norm_squared_simd(&a) >= -1e-6);
        }
    }
}