1use pulp::Arch;
7
8pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
14 assert_eq!(a.len(), b.len(), "vectors must have equal length");
15
16 let arch = Arch::new();
17 arch.dispatch(DotProduct(a, b))
18}
19
20pub fn norm_squared_simd(a: &[f32]) -> f32 {
22 dot_product_simd(a, a)
23}
24
25pub fn l2_squared_simd(a: &[f32], b: &[f32]) -> f32 {
31 assert_eq!(a.len(), b.len(), "vectors must have equal length");
32
33 let arch = Arch::new();
34 arch.dispatch(L2Squared(a, b))
35}
36
37struct DotProduct<'a>(&'a [f32], &'a [f32]);
38
39impl pulp::WithSimd for DotProduct<'_> {
40 type Output = f32;
41
42 #[inline(always)]
43 fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
44 let (a_head, a_tail) = S::as_simd_f32s(self.0);
45 let (b_head, b_tail) = S::as_simd_f32s(self.1);
46
47 let mut acc = simd.splat_f32s(0.0);
48 for (&a_chunk, &b_chunk) in a_head.iter().zip(b_head.iter()) {
49 acc = simd.mul_add_e_f32s(a_chunk, b_chunk, acc);
50 }
51
52 let mut sum = simd.reduce_sum_f32s(acc);
53 for (&a_val, &b_val) in a_tail.iter().zip(b_tail.iter()) {
54 sum = f32::mul_add(a_val, b_val, sum);
55 }
56 sum
57 }
58}
59
60struct L2Squared<'a>(&'a [f32], &'a [f32]);
61
62impl pulp::WithSimd for L2Squared<'_> {
63 type Output = f32;
64
65 #[inline(always)]
66 fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
67 let (a_head, a_tail) = S::as_simd_f32s(self.0);
68 let (b_head, b_tail) = S::as_simd_f32s(self.1);
69
70 let mut acc_aa = simd.splat_f32s(0.0);
73 let mut acc_bb = simd.splat_f32s(0.0);
74 let mut acc_ab = simd.splat_f32s(0.0);
75
76 for (&a_chunk, &b_chunk) in a_head.iter().zip(b_head.iter()) {
77 acc_aa = simd.mul_add_e_f32s(a_chunk, a_chunk, acc_aa);
78 acc_bb = simd.mul_add_e_f32s(b_chunk, b_chunk, acc_bb);
79 acc_ab = simd.mul_add_e_f32s(a_chunk, b_chunk, acc_ab);
80 }
81
82 let mut sum_aa = simd.reduce_sum_f32s(acc_aa);
83 let mut sum_bb = simd.reduce_sum_f32s(acc_bb);
84 let mut sum_ab = simd.reduce_sum_f32s(acc_ab);
85
86 for (&a_val, &b_val) in a_tail.iter().zip(b_tail.iter()) {
87 sum_aa = f32::mul_add(a_val, a_val, sum_aa);
88 sum_bb = f32::mul_add(b_val, b_val, sum_bb);
89 sum_ab = f32::mul_add(a_val, b_val, sum_ab);
90 }
91
92 (sum_aa + sum_bb - 2.0 * sum_ab).max(0.0)
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// 1·4 + 2·5 + 3·6 = 32.
    #[test]
    fn dot_product_basic() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        let error = (dot_product_simd(&lhs, &rhs) - 32.0).abs();
        assert!(error < 1e-5);
    }

    /// Distinct coordinate axes have a zero dot product.
    #[test]
    fn dot_product_orthogonal() {
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];
        assert!(dot_product_simd(&x_axis, &y_axis).abs() < 1e-6);
    }

    /// A vector is at distance zero from itself.
    #[test]
    fn l2_squared_identical() {
        let v = [1.0, 2.0, 3.0];
        let distance = l2_squared_simd(&v, &v);
        assert!(distance.abs() < 1e-6);
    }

    /// (1 − 0)² + (0 − 1)² = 2.
    #[test]
    fn l2_squared_known() {
        let p = [1.0, 0.0];
        let q = [0.0, 1.0];
        let error = (l2_squared_simd(&p, &q) - 2.0).abs();
        assert!(error < 1e-5);
    }

    /// A unit vector has squared norm 1.
    #[test]
    fn norm_squared_unit_vector() {
        let unit = [1.0, 0.0, 0.0];
        let error = (norm_squared_simd(&unit) - 1.0).abs();
        assert!(error < 1e-6);
    }

    /// Smoke test: both kernels run to completion on 768-element inputs.
    #[test]
    fn handles_d768() {
        const DIM: i32 = 768;
        let ascending: Vec<f32> = (0..DIM).map(|i| (i as f32) * 0.001).collect();
        let descending: Vec<f32> = (0..DIM).map(|i| ((DIM - i) as f32) * 0.001).collect();
        let _ = dot_product_simd(&ascending, &descending);
        let _ = l2_squared_simd(&ascending, &descending);
    }

    /// Mismatched lengths must trip the length assertion.
    #[test]
    #[should_panic(expected = "vectors must have equal length")]
    fn dot_product_panics_on_mismatch() {
        dot_product_simd(&[1.0, 2.0], &[1.0]);
    }
}
152
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Number of components in every generated vector.
    const DIM: usize = 64;

    /// Strategy for `DIM`-element vectors with components in `-100.0..100.0`.
    fn components() -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(-100.0f32..100.0, DIM..=DIM)
    }

    proptest! {
        /// Swapping the operands must not change the dot product.
        #[test]
        fn dot_product_commutative(a in components(), b in components()) {
            let ab = dot_product_simd(&a, &b);
            let ba = dot_product_simd(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-3, "a·b={ab}, b·a={ba}");
        }

        /// A squared distance is never (meaningfully) negative.
        #[test]
        fn l2_squared_non_negative(a in components(), b in components()) {
            prop_assert!(l2_squared_simd(&a, &b) >= -1e-6);
        }

        /// Swapping the operands must not change the squared distance.
        #[test]
        fn l2_squared_symmetric(a in components(), b in components()) {
            let ab = l2_squared_simd(&a, &b);
            let ba = l2_squared_simd(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-3, "d(a,b)={ab}, d(b,a)={ba}");
        }

        /// Every vector is at distance zero from itself.
        #[test]
        fn l2_squared_identity(a in components()) {
            prop_assert!(l2_squared_simd(&a, &a).abs() < 1e-3);
        }

        /// A squared norm is never (meaningfully) negative.
        #[test]
        fn norm_squared_non_negative(a in components()) {
            prop_assert!(norm_squared_simd(&a) >= -1e-6);
        }
    }
}