cvx_analytics/
procrustes.rs

1#![allow(clippy::needless_range_loop)]
2//! Procrustes alignment for cross-model embedding comparison (RFC-012 P10).
3//!
4//! When switching embedding models (e.g., MentalRoBERTa → all-MiniLM),
5//! vectors from the old model are incompatible with the new model's space.
6//! Procrustes alignment finds the optimal orthogonal rotation R that
7//! minimizes ||A - BR||² where A is the target space and B is the source.
8//!
9//! # Algorithm
10//!
11//! Given N corresponding vector pairs (a_i, b_i):
12//! 1. Center both sets: A' = A - mean(A), B' = B - mean(B)
13//! 2. Compute cross-covariance: M = A'^T B'
14//! 3. SVD: M = U Σ V^T
15//! 4. Rotation: R = V U^T
16//! 5. Scale: s = trace(Σ) / trace(B'^T B')
17//! 6. Transform: b_aligned = s × (b - mean(B)) × R + mean(A)
18//!
19//! # Example
20//!
21//! ```
22//! use cvx_analytics::procrustes::{ProcrustesTransform, fit_procrustes};
23//!
24//! let source = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
25//! let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0], vec![-1.0, 1.0]];
26//!
27//! let transform = fit_procrustes(&source, &target).unwrap();
28//! let aligned = transform.apply(&[1.0, 0.0]);
29//! // aligned should be close to [0.0, 1.0] (90° rotation)
30//! ```
31
/// A fitted Procrustes transformation mapping source-space vectors into the
/// target space.
///
/// Produced by [`fit_procrustes`] and applied with
/// [`ProcrustesTransform::apply`], which evaluates
/// `scale × (v − source_mean) × rotation + target_mean`.
#[derive(Debug, Clone, PartialEq)]
pub struct ProcrustesTransform {
    /// Rotation matrix (D × D), stored row-major as `rotation[row][col]`.
    /// Vectors are treated as rows and multiplied from the left (`v × R`).
    pub rotation: Vec<Vec<f64>>,
    /// Uniform scale factor applied together with the rotation.
    pub scale: f64,
    /// Source centroid (subtracted before rotation).
    pub source_mean: Vec<f64>,
    /// Target centroid (added after rotation).
    pub target_mean: Vec<f64>,
    /// Dimensionality D shared by the source and target spaces.
    pub dim: usize,
    /// Alignment error over the fitting pairs: the Frobenius norm of the
    /// centered residual divided by the number of pairs.
    pub error: f64,
}
48
49impl ProcrustesTransform {
50    /// Apply the transform to a source vector → target space.
51    pub fn apply(&self, source_vec: &[f32]) -> Vec<f32> {
52        let d = self.dim;
53        assert_eq!(source_vec.len(), d);
54
55        // Center
56        let centered: Vec<f64> = source_vec
57            .iter()
58            .zip(&self.source_mean)
59            .map(|(&v, &m)| v as f64 - m)
60            .collect();
61
62        // Rotate + scale
63        let mut rotated = vec![0.0f64; d];
64        for i in 0..d {
65            for j in 0..d {
66                rotated[i] += centered[j] * self.rotation[j][i];
67            }
68            rotated[i] *= self.scale;
69        }
70
71        // Translate to target space
72        rotated
73            .iter()
74            .zip(&self.target_mean)
75            .map(|(&r, &m)| (r + m) as f32)
76            .collect()
77    }
78
79    /// Apply to a batch of vectors.
80    pub fn apply_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
81        vectors.iter().map(|v| self.apply(v)).collect()
82    }
83}
84
85/// Fit a Procrustes transformation from source to target vectors.
86///
87/// Both sets must have the same number of vectors and same dimensionality.
88/// Vectors should be corresponding pairs (same entity in different models).
89///
90/// Returns `None` if inputs are empty or have mismatched dimensions.
91pub fn fit_procrustes(source: &[Vec<f32>], target: &[Vec<f32>]) -> Option<ProcrustesTransform> {
92    let n = source.len();
93    if n == 0 || n != target.len() {
94        return None;
95    }
96    let d = source[0].len();
97    if d == 0 || target[0].len() != d {
98        return None;
99    }
100
101    // 1. Compute centroids
102    let mut src_mean = vec![0.0f64; d];
103    let mut tgt_mean = vec![0.0f64; d];
104    for i in 0..n {
105        for j in 0..d {
106            src_mean[j] += source[i][j] as f64;
107            tgt_mean[j] += target[i][j] as f64;
108        }
109    }
110    let inv_n = 1.0 / n as f64;
111    for j in 0..d {
112        src_mean[j] *= inv_n;
113        tgt_mean[j] *= inv_n;
114    }
115
116    // 2. Center
117    let mut a = vec![vec![0.0f64; d]; n]; // target centered
118    let mut b = vec![vec![0.0f64; d]; n]; // source centered
119    for i in 0..n {
120        for j in 0..d {
121            a[i][j] = target[i][j] as f64 - tgt_mean[j];
122            b[i][j] = source[i][j] as f64 - src_mean[j];
123        }
124    }
125
126    // 3. Cross-covariance M = A^T B (D × D)
127    let mut m = vec![vec![0.0f64; d]; d];
128    for i in 0..n {
129        for j in 0..d {
130            for k in 0..d {
131                m[j][k] += a[i][j] * b[i][k];
132            }
133        }
134    }
135
136    // 4. SVD via Jacobi iterations (simple, works for moderate D)
137    let (u, sigma, vt) = svd_jacobi(&m, d);
138
139    // 5. Rotation R = V U^T (V is rows of vt transposed)
140    let mut rotation = vec![vec![0.0f64; d]; d];
141    for i in 0..d {
142        for j in 0..d {
143            for k in 0..d {
144                rotation[i][j] += vt[k][i] * u[j][k]; // V^T^T * U^T = V * U^T
145            }
146        }
147    }
148
149    // 6. Scale
150    let trace_sigma: f64 = sigma.iter().sum();
151    let mut trace_btb = 0.0f64;
152    for i in 0..n {
153        for j in 0..d {
154            trace_btb += b[i][j] * b[i][j];
155        }
156    }
157    let scale = if trace_btb > 1e-12 {
158        trace_sigma / trace_btb
159    } else {
160        1.0
161    };
162
163    // 7. Compute error
164    let mut error = 0.0f64;
165    for i in 0..n {
166        let aligned = apply_rotation(&b[i], &rotation, scale);
167        for j in 0..d {
168            let diff = a[i][j] - aligned[j];
169            error += diff * diff;
170        }
171    }
172    error = error.sqrt() / n as f64;
173
174    Some(ProcrustesTransform {
175        rotation,
176        scale,
177        source_mean: src_mean,
178        target_mean: tgt_mean,
179        dim: d,
180        error,
181    })
182}
183
/// Multiply the row vector `vec` by `rotation` (`vec × R`) and scale the result.
///
/// Output element `i` is `scale × Σ_j vec[j] · rotation[j][i]`; the output
/// has the same length as `vec`.
fn apply_rotation(vec: &[f64], rotation: &[Vec<f64>], scale: f64) -> Vec<f64> {
    (0..vec.len())
        .map(|col| {
            // Accumulate from 0.0 in row order so the summation order is fixed.
            let dot = vec
                .iter()
                .enumerate()
                .fold(0.0f64, |acc, (row, &x)| acc + x * rotation[row][col]);
            dot * scale
        })
        .collect()
}
195
196/// Simple Jacobi SVD for small-to-moderate matrices.
197/// Returns (U, singular_values, V^T).
198fn svd_jacobi(m: &[Vec<f64>], d: usize) -> (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>) {
199    // Compute M^T M
200    let mut mtm = vec![vec![0.0f64; d]; d];
201    for i in 0..d {
202        for j in 0..d {
203            for k in 0..d {
204                mtm[i][j] += m[k][i] * m[k][j];
205            }
206        }
207    }
208
209    // Eigendecomposition of M^T M via Jacobi rotations
210    let (eigenvalues, v) = jacobi_eigendecomposition(&mtm, d);
211
212    // Singular values = sqrt(eigenvalues)
213    let sigma: Vec<f64> = eigenvalues.iter().map(|&e| e.max(0.0).sqrt()).collect();
214
215    // U = M V Σ^{-1}
216    let mut u = vec![vec![0.0f64; d]; d];
217    for i in 0..d {
218        for j in 0..d {
219            let mut sum = 0.0f64;
220            for k in 0..d {
221                sum += m[i][k] * v[k][j];
222            }
223            u[i][j] = if sigma[j] > 1e-12 {
224                sum / sigma[j]
225            } else {
226                0.0
227            };
228        }
229    }
230
231    // V^T
232    let mut vt = vec![vec![0.0f64; d]; d];
233    for i in 0..d {
234        for j in 0..d {
235            vt[i][j] = v[j][i];
236        }
237    }
238
239    (u, sigma, vt)
240}
241
/// Jacobi eigendecomposition for symmetric matrices.
///
/// Repeatedly picks the largest off-diagonal entry `(p, q)` of the upper
/// triangle and applies a plane rotation that zeroes it, accumulating the
/// rotations into the eigenvector matrix. Stops once the largest
/// off-diagonal magnitude is at most `1e-12` times the largest magnitude in
/// the input, or after `100 · d²` rotations.
///
/// `a` is expected to be a symmetric `d × d` matrix; only the upper triangle
/// is searched for pivots.
///
/// Returns `(eigenvalues, V)`, where `eigenvalues[i]` belongs to the
/// eigenvector stored in column `i` of `V` (`V[row][i]`). Eigenvalues are
/// returned in diagonal order, not sorted.
fn jacobi_eigendecomposition(a: &[Vec<f64>], d: usize) -> (Vec<f64>, Vec<Vec<f64>>) {
    // Convergence is judged relative to the matrix scale so that uniformly
    // tiny (or huge) inputs are diagonalized to the same relative accuracy.
    const REL_TOL: f64 = 1e-12;

    let mut mat = a.to_vec();
    let mut v = vec![vec![0.0f64; d]; d];
    for i in 0..d {
        v[i][i] = 1.0;
    }

    let magnitude = a
        .iter()
        .flat_map(|row| row.iter())
        .fold(0.0f64, |acc, &x| acc.max(x.abs()));
    let tol = REL_TOL * magnitude;

    let max_iter = 100 * d * d;
    for _ in 0..max_iter {
        // Find the largest off-diagonal element (upper triangle).
        let mut max_val = 0.0f64;
        let mut p = 0;
        let mut q = 1;
        for i in 0..d {
            for j in (i + 1)..d {
                let off = mat[i][j].abs();
                if off > max_val {
                    max_val = off;
                    p = i;
                    q = j;
                }
            }
        }
        // `<=` so that an already-diagonal (or all-zero) matrix exits at once.
        if max_val <= tol {
            break;
        }

        let app = mat[p][p];
        let aqq = mat[q][q];
        let apq = mat[p][q];

        // Rotation angle with tan(2θ) = 2·a_pq / (a_pp − a_qq). Only exactly
        // equal diagonals need the π/4 special case; for any non-zero
        // difference `atan` already yields the annihilating angle.
        let diff = app - aqq;
        let theta = if diff == 0.0 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * apq / diff).atan()
        };

        let (sin_t, cos_t) = theta.sin_cos();

        // Rotate rows/columns p and q in place. Each (i, p)/(i, q) pair is
        // read before it is written, so no scratch copy of the matrix is needed.
        for i in 0..d {
            if i != p && i != q {
                let aip = mat[i][p];
                let aiq = mat[i][q];
                let new_ip = cos_t * aip + sin_t * aiq;
                let new_iq = -sin_t * aip + cos_t * aiq;
                mat[i][p] = new_ip;
                mat[p][i] = new_ip;
                mat[i][q] = new_iq;
                mat[q][i] = new_iq;
            }
        }
        mat[p][p] = cos_t * cos_t * app + 2.0 * sin_t * cos_t * apq + sin_t * sin_t * aqq;
        mat[q][q] = sin_t * sin_t * app - 2.0 * sin_t * cos_t * apq + cos_t * cos_t * aqq;
        mat[p][q] = 0.0;
        mat[q][p] = 0.0;

        // Accumulate the rotation into the eigenvector columns p and q.
        for i in 0..d {
            let vip = v[i][p];
            let viq = v[i][q];
            v[i][p] = cos_t * vip + sin_t * viq;
            v[i][q] = -sin_t * vip + cos_t * viq;
        }
    }

    let eigenvalues: Vec<f64> = (0..d).map(|i| mat[i][i]).collect();
    (eigenvalues, v)
}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Fitting a point set onto itself should give unit scale and ~zero error.
    #[test]
    fn identity_alignment() {
        let raw: [[f32; 2]; 4] = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]];
        let points: Vec<Vec<f32>> = raw.iter().map(|p| p.to_vec()).collect();

        let fitted = fit_procrustes(&points, &points).unwrap();
        assert!(fitted.error < 0.01, "error = {}", fitted.error);
        assert!((fitted.scale - 1.0).abs() < 0.01, "scale = {}", fitted.scale);

        // The first point must map (approximately) onto itself.
        let mapped = fitted.apply(&points[0]);
        assert!((mapped[0] - 1.0).abs() < 0.1);
        assert!((mapped[1] - 0.0).abs() < 0.1);
    }

    /// The target is the source rotated 90° counter-clockwise.
    #[test]
    fn rotation_90_degrees() {
        let source = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];
        let target = vec![
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
            vec![1.0, 0.0],
        ];

        let fitted = fit_procrustes(&source, &target).unwrap();
        assert!(fitted.error < 0.1, "error = {}", fitted.error);

        let aligned = fitted.apply(&[1.0, 0.0]);
        let x_ok = (aligned[0] - 0.0).abs() < 0.2;
        let y_ok = (aligned[1] - 1.0).abs() < 0.2;
        assert!(x_ok && y_ok, "expected ~[0, 1], got {aligned:?}",);
    }

    /// An 8-dimensional set and a slightly perturbed copy should align well.
    #[test]
    fn higher_dimension() {
        /// Deterministic LCG step; yields values in roughly [-0.5, 0.0].
        fn next_sample(state: &mut u64) -> f32 {
            *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            ((*state >> 33) as f32) / (u32::MAX as f32) - 0.5
        }

        let d = 8;
        let n = 20;
        let mut rng = 42u64;

        // Draw all source coordinates first, then the perturbations, so the
        // generator is consumed in a fixed order.
        let mut source: Vec<Vec<f32>> = Vec::with_capacity(n);
        for _ in 0..n {
            let row: Vec<f32> = (0..d).map(|_| next_sample(&mut rng)).collect();
            source.push(row);
        }
        let mut target: Vec<Vec<f32>> = Vec::with_capacity(n);
        for row in &source {
            let noisy: Vec<f32> = row
                .iter()
                .map(|&x| x + next_sample(&mut rng) * 0.01)
                .collect();
            target.push(noisy);
        }

        let fitted = fit_procrustes(&source, &target).unwrap();
        assert!(fitted.error < 0.1, "error = {} (d={d}, n={n})", fitted.error);
    }

    /// No vectors at all cannot be fitted.
    #[test]
    fn empty_input() {
        let nothing: Vec<Vec<f32>> = Vec::new();
        assert!(fit_procrustes(&nothing, &nothing).is_none());
    }

    /// Source and target sets of different lengths are rejected.
    #[test]
    fn mismatched_sizes() {
        let one = vec![vec![1.0, 0.0]];
        let two = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert!(fit_procrustes(&one, &two).is_none());
    }
}