// cvx_analytics/cohort.rs

//! Cohort-level temporal drift analytics.
//!
//! Measures how a **group** of entities evolves collectively in embedding space,
//! complementing the single-entity analytics in [`crate::calculus`].
//!
//! # Key metrics
//!
//! | Metric | What it captures |
//! |--------|-----------------|
//! | `centroid_drift` | How the group center moved |
//! | `dispersion_change` | Did the group spread or compress? |
//! | `convergence_score` | Are entities moving in the same direction? |
//! | `outliers` | Entities drifting abnormally vs the group |

use crate::calculus::{DriftReport, drift_magnitude_l2, drift_report};
use cvx_core::error::AnalyticsError;

// ─── Types ──────────────────────────────────────────────────────────

/// Full cohort drift analysis between two time points.
///
/// Produced by [`cohort_drift`]; aggregates per-entity drift into
/// cohort-level statistics.
#[derive(Debug, Clone)]
pub struct CohortDriftReport {
    /// Number of entities successfully analyzed (with data at both t1 and t2).
    pub n_entities: usize,
    /// Mean individual L2 drift magnitude across the cohort.
    pub mean_drift_l2: f32,
    /// Median individual L2 drift magnitude.
    pub median_drift_l2: f32,
    /// Sample standard deviation (n − 1 denominator) of individual L2 drift magnitudes.
    pub std_drift_l2: f32,
    /// Drift of the cohort centroid between t1 and t2.
    pub centroid_drift: DriftReport,
    /// Mean distance from entities to centroid at t1.
    pub dispersion_t1: f32,
    /// Mean distance from entities to centroid at t2.
    pub dispersion_t2: f32,
    /// Change in dispersion (`dispersion_t2 - dispersion_t1`):
    /// positive = diverging, negative = converging.
    pub dispersion_change: f32,
    /// Cosine alignment of individual drift vectors (0 = random, 1 = all same direction).
    pub convergence_score: f32,
    /// Top-N most changed dimensions aggregated across the cohort,
    /// as `(dimension_index, |mean delta|)` sorted by magnitude descending.
    pub top_dimensions: Vec<(usize, f32)>,
    /// Entities flagged as outliers (|z-score| > 2.0).
    pub outliers: Vec<CohortOutlier>,
}
/// An entity whose drift deviates significantly from the cohort.
#[derive(Debug, Clone)]
pub struct CohortOutlier {
    /// Entity identifier.
    pub entity_id: u64,
    /// Individual L2 drift magnitude.
    pub drift_magnitude: f32,
    /// Z-score relative to the cohort distribution (entities are flagged when |z| > 2.0).
    pub z_score: f32,
    /// Cosine similarity between this entity's drift direction and the cohort mean direction.
    pub drift_direction_alignment: f32,
}

// ─── Helpers ────────────────────────────────────────────────────────

/// Find the vector closest in time to `target` within a trajectory.
///
/// Ties (two timestamps equidistant from `target`) resolve to the earlier
/// entry, per `Iterator::min_by_key` first-minimum semantics.
///
/// Returns `None` if the trajectory is empty.
pub fn nearest_vector_at<'a>(trajectory: &'a [(i64, &'a [f32])], target: i64) -> Option<&'a [f32]> {
    // `min_by_key` already yields `None` on an empty iterator, so no explicit
    // emptiness check is needed, and returning the element directly avoids the
    // enumerate-then-reindex round-trip.
    trajectory
        .iter()
        // `unsigned_abs` keeps the distance well-defined even when the signed
        // difference would be i64::MIN.
        .min_by_key(|(ts, _)| (ts - target).unsigned_abs())
        .map(|&(_, v)| v)
}
/// Compute the centroid (element-wise mean) of a set of vectors.
///
/// Returns an empty `Vec` for an empty input. The output length is taken
/// from the first vector; indexing panics if a later vector is longer.
fn centroid(vectors: &[&[f32]]) -> Vec<f32> {
    let Some(first) = vectors.first() else {
        return Vec::new();
    };
    let count = vectors.len() as f32;

    // Accumulate element-wise sums, then divide once at the end.
    let mut acc = vec![0.0f32; first.len()];
    for row in vectors {
        for (slot, &component) in row.iter().enumerate() {
            acc[slot] += component;
        }
    }
    for sum in &mut acc {
        *sum /= count;
    }
    acc
}
/// Mean cosine similarity of all drift vectors against their mean direction.
///
/// Returns 0.0 if fewer than 2 drift vectors, or if the mean drift is zero.
/// Near-zero individual drift vectors are excluded from the average.
fn compute_convergence_score(drift_vectors: &[Vec<f32>]) -> f32 {
    // Euclidean length of a vector.
    fn norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    if drift_vectors.len() < 2 {
        return 0.0;
    }

    // Element-wise mean of all drift vectors = the cohort's mean direction.
    let count = drift_vectors.len() as f32;
    let mut mean_dir = vec![0.0f32; drift_vectors[0].len()];
    for vector in drift_vectors {
        for (slot, &component) in vector.iter().enumerate() {
            mean_dir[slot] += component;
        }
    }
    for entry in &mut mean_dir {
        *entry /= count;
    }

    let mean_norm = norm(&mean_dir);
    if mean_norm < 1e-12 {
        // Drifts cancel out: no meaningful shared direction.
        return 0.0;
    }

    // Fold each non-degenerate drift vector's cosine similarity with the
    // mean direction into (sum, contributing-count).
    let (sim_sum, contributing) =
        drift_vectors.iter().fold((0.0f32, 0usize), |(sum, cnt), vector| {
            let vector_norm = norm(vector);
            if vector_norm < 1e-12 {
                return (sum, cnt);
            }
            let dot: f32 = vector.iter().zip(&mean_dir).map(|(a, b)| a * b).sum();
            (sum + (dot / (vector_norm * mean_norm)).clamp(-1.0, 1.0), cnt + 1)
        });

    if contributing == 0 {
        0.0
    } else {
        sim_sum / contributing as f32
    }
}

// ─── Core function ──────────────────────────────────────────────────

144/// Compute cohort-level drift analysis.
145///
146/// Each trajectory in `trajectories` is `(entity_id, sorted_trajectory)` where
147/// the trajectory uses the standard CVX format `&[(i64, &[f32])]`.
148///
149/// The function finds the nearest vector to `t1` and `t2` for each entity,
150/// computes individual drift vectors, then aggregates cohort statistics.
151///
152/// # Errors
153///
154/// Returns [`AnalyticsError::InsufficientData`] if fewer than 2 entities have
155/// data at both t1 and t2.
156#[allow(clippy::type_complexity)]
157pub fn cohort_drift(
158    trajectories: &[(u64, &[(i64, &[f32])])],
159    t1: i64,
160    t2: i64,
161    top_n: usize,
162) -> Result<CohortDriftReport, AnalyticsError> {
163    // Collect per-entity data: (entity_id, vector_at_t1, vector_at_t2, drift_vector)
164    #[allow(clippy::type_complexity)]
165    let mut entity_data: Vec<(u64, Vec<f32>, Vec<f32>, Vec<f32>)> = Vec::new();
166
167    for &(entity_id, traj) in trajectories {
168        let Some(v1) = nearest_vector_at(traj, t1) else {
169            continue;
170        };
171        let Some(v2) = nearest_vector_at(traj, t2) else {
172            continue;
173        };
174        if v1.len() != v2.len() {
175            continue;
176        }
177        let drift_vec: Vec<f32> = v2.iter().zip(v1.iter()).map(|(a, b)| a - b).collect();
178        entity_data.push((entity_id, v1.to_vec(), v2.to_vec(), drift_vec));
179    }
180
181    let n = entity_data.len();
182    if n < 2 {
183        return Err(AnalyticsError::InsufficientData { needed: 2, have: n });
184    }
185
186    // ── Individual drift magnitudes ──
187
188    let drift_magnitudes: Vec<f32> = entity_data
189        .iter()
190        .map(|(_, v1, v2, _)| drift_magnitude_l2(v1, v2))
191        .collect();
192
193    let mean_drift_l2 = drift_magnitudes.iter().sum::<f32>() / n as f32;
194
195    let mut sorted_mags = drift_magnitudes.clone();
196    sorted_mags.sort_by(|a, b| a.partial_cmp(b).unwrap());
197    let median_drift_l2 = if n % 2 == 0 {
198        (sorted_mags[n / 2 - 1] + sorted_mags[n / 2]) / 2.0
199    } else {
200        sorted_mags[n / 2]
201    };
202
203    let variance: f32 = drift_magnitudes
204        .iter()
205        .map(|m| (m - mean_drift_l2) * (m - mean_drift_l2))
206        .sum::<f32>()
207        / (n - 1) as f32;
208    let std_drift_l2 = variance.sqrt();
209
210    // ── Centroid drift ──
211
212    let vectors_t1: Vec<&[f32]> = entity_data
213        .iter()
214        .map(|(_, v1, _, _)| v1.as_slice())
215        .collect();
216    let vectors_t2: Vec<&[f32]> = entity_data
217        .iter()
218        .map(|(_, _, v2, _)| v2.as_slice())
219        .collect();
220
221    let centroid_t1 = centroid(&vectors_t1);
222    let centroid_t2 = centroid(&vectors_t2);
223    let centroid_drift = drift_report(&centroid_t1, &centroid_t2, top_n);
224
225    // ── Dispersion ──
226
227    let dispersion_t1 = vectors_t1
228        .iter()
229        .map(|v| drift_magnitude_l2(v, &centroid_t1))
230        .sum::<f32>()
231        / n as f32;
232
233    let dispersion_t2 = vectors_t2
234        .iter()
235        .map(|v| drift_magnitude_l2(v, &centroid_t2))
236        .sum::<f32>()
237        / n as f32;
238
239    let dispersion_change = dispersion_t2 - dispersion_t1;
240
241    // ── Convergence score ──
242
243    let drift_vectors: Vec<Vec<f32>> = entity_data.iter().map(|(_, _, _, dv)| dv.clone()).collect();
244    let convergence_score = compute_convergence_score(&drift_vectors);
245
246    // ── Top dimensions (aggregated) ──
247
248    let dim = entity_data[0].3.len();
249    let mut mean_delta = vec![0.0f32; dim];
250    for (_, _, _, dv) in &entity_data {
251        for (i, &val) in dv.iter().enumerate() {
252            mean_delta[i] += val;
253        }
254    }
255    for val in &mut mean_delta {
256        *val /= n as f32;
257    }
258
259    let mut dim_changes: Vec<(usize, f32)> = mean_delta
260        .iter()
261        .enumerate()
262        .map(|(i, &v)| (i, v.abs()))
263        .collect();
264    dim_changes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
265    dim_changes.truncate(top_n);
266
267    // ── Outlier detection ──
268
269    // Mean drift direction for alignment computation
270    let mean_drift_dir: Vec<f32> = mean_delta.clone();
271    let mean_dir_norm: f32 = mean_drift_dir.iter().map(|x| x * x).sum::<f32>().sqrt();
272
273    let outliers: Vec<CohortOutlier> = entity_data
274        .iter()
275        .zip(drift_magnitudes.iter())
276        .filter_map(|((entity_id, _, _, dv), &mag)| {
277            let z = if std_drift_l2 > 1e-12 {
278                (mag - mean_drift_l2) / std_drift_l2
279            } else {
280                0.0
281            };
282
283            if z.abs() <= 2.0 {
284                return None;
285            }
286
287            let alignment = if mean_dir_norm > 1e-12 {
288                let dv_norm: f32 = dv.iter().map(|x| x * x).sum::<f32>().sqrt();
289                if dv_norm > 1e-12 {
290                    let dot: f32 = dv
291                        .iter()
292                        .zip(mean_drift_dir.iter())
293                        .map(|(a, b)| a * b)
294                        .sum();
295                    (dot / (dv_norm * mean_dir_norm)).clamp(-1.0, 1.0)
296                } else {
297                    0.0
298                }
299            } else {
300                0.0
301            };
302
303            Some(CohortOutlier {
304                entity_id: *entity_id,
305                drift_magnitude: mag,
306                z_score: z,
307                drift_direction_alignment: alignment,
308            })
309        })
310        .collect();
311
312    Ok(CohortDriftReport {
313        n_entities: n,
314        mean_drift_l2,
315        median_drift_l2,
316        std_drift_l2,
317        centroid_drift,
318        dispersion_t1,
319        dispersion_t2,
320        dispersion_change,
321        convergence_score,
322        top_dimensions: dim_changes,
323        outliers,
324    })
325}

// ─── Tests ──────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(
    clippy::type_complexity,
    clippy::needless_range_loop,
    clippy::useless_vec
)]
mod tests {
    use super::*;

    /// Helper: convert owned trajectory to borrowed format.
    fn as_refs(points: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        points.iter().map(|(t, v)| (*t, v.as_slice())).collect()
    }

    // ─── nearest_vector_at ──────────────────────────────────────

    #[test]
    fn nearest_vector_empty_trajectory() {
        let traj: Vec<(i64, &[f32])> = vec![];
        assert!(nearest_vector_at(&traj, 100).is_none());
    }

    #[test]
    fn nearest_vector_exact_match() {
        let owned = vec![
            (100i64, vec![1.0f32, 0.0]),
            (200, vec![0.0, 1.0]),
            (300, vec![1.0, 1.0]),
        ];
        let traj = as_refs(&owned);
        let v = nearest_vector_at(&traj, 200).unwrap();
        assert_eq!(v, &[0.0, 1.0]);
    }

    #[test]
    fn nearest_vector_between_timestamps() {
        let owned = vec![
            (100i64, vec![1.0f32, 0.0]),
            (200, vec![0.0, 1.0]),
            (300, vec![1.0, 1.0]),
        ];
        let traj = as_refs(&owned);
        // 190 is closer to 200 than to 100
        let v = nearest_vector_at(&traj, 190).unwrap();
        assert_eq!(v, &[0.0, 1.0]);
    }

    #[test]
    fn nearest_vector_before_first() {
        // Target earlier than every timestamp clamps to the first point.
        let owned = vec![(100i64, vec![1.0f32, 2.0])];
        let traj = as_refs(&owned);
        let v = nearest_vector_at(&traj, 0).unwrap();
        assert_eq!(v, &[1.0, 2.0]);
    }

    // ─── centroid ───────────────────────────────────────────────

    #[test]
    fn centroid_single_vector() {
        let v = vec![2.0f32, 4.0, 6.0];
        let c = centroid(&[v.as_slice()]);
        assert_eq!(c, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn centroid_two_vectors() {
        let v1 = vec![0.0f32, 0.0];
        let v2 = vec![2.0, 4.0];
        let c = centroid(&[v1.as_slice(), v2.as_slice()]);
        assert!((c[0] - 1.0).abs() < 1e-6);
        assert!((c[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn centroid_empty() {
        let c = centroid(&[]);
        assert!(c.is_empty());
    }

    // ─── convergence_score ──────────────────────────────────────

    #[test]
    fn convergence_all_same_direction() {
        let drifts = vec![
            vec![1.0f32, 0.0, 0.0],
            vec![2.0, 0.0, 0.0],
            vec![0.5, 0.0, 0.0],
        ];
        let score = compute_convergence_score(&drifts);
        assert!((score - 1.0).abs() < 1e-6, "expected ~1.0, got {score}");
    }

    #[test]
    fn convergence_opposite_directions() {
        let drifts = vec![vec![1.0f32, 0.0], vec![-1.0, 0.0]];
        let score = compute_convergence_score(&drifts);
        // Mean direction is [0, 0], so score should be 0
        assert!(
            score.abs() < 1e-6,
            "expected ~0.0 for zero mean, got {score}"
        );
    }

    #[test]
    fn convergence_orthogonal_directions() {
        // 4 vectors pointing in 4 orthogonal-ish directions
        let drifts = vec![
            vec![1.0f32, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];
        let score = compute_convergence_score(&drifts);
        // Mean direction is ~[0, 0], score should be ~0
        assert!(score.abs() < 1e-6, "expected ~0.0, got {score}");
    }

    #[test]
    fn convergence_too_few_vectors() {
        let drifts = vec![vec![1.0f32]];
        assert_eq!(compute_convergence_score(&drifts), 0.0);
    }

    // ─── cohort_drift — basic functionality ─────────────────────

    #[test]
    fn cohort_drift_insufficient_data() {
        // A single entity is below the minimum cohort size of 2.
        let traj1 = vec![(100i64, vec![1.0f32, 0.0])];
        let refs1 = as_refs(&traj1);

        let trajectories: Vec<(u64, &[(i64, &[f32])])> = vec![(1, &refs1)];
        let result = cohort_drift(&trajectories, 100, 200, 5);
        assert!(result.is_err());
        match result.unwrap_err() {
            AnalyticsError::InsufficientData { needed, have } => {
                assert_eq!(needed, 2);
                assert_eq!(have, 1);
            }
            other => panic!("expected InsufficientData, got {other:?}"),
        }
    }

    #[test]
    fn cohort_drift_uniform_shift() {
        // All entities shift by exactly [0.1, 0.0, 0.0]
        let dim = 3;
        let n_entities = 10;
        let shift = 0.1f32;

        let mut owned_trajs: Vec<Vec<(i64, Vec<f32>)>> = Vec::new();
        for i in 0..n_entities {
            let base: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32 * 0.1).collect();
            let shifted: Vec<f32> = base
                .iter()
                .enumerate()
                .map(|(d, &v)| if d == 0 { v + shift } else { v })
                .collect();
            owned_trajs.push(vec![(1000, base), (2000, shifted)]);
        }

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 5).unwrap();

        assert_eq!(report.n_entities, n_entities);

        // All drifts should be equal to shift magnitude
        assert!(
            (report.mean_drift_l2 - shift).abs() < 1e-5,
            "expected mean drift ~{shift}, got {}",
            report.mean_drift_l2
        );
        assert!(
            (report.median_drift_l2 - shift).abs() < 1e-5,
            "expected median ~{shift}, got {}",
            report.median_drift_l2
        );
        assert!(
            report.std_drift_l2 < 1e-5,
            "expected std ~0 for uniform shift, got {}",
            report.std_drift_l2
        );

        // All moving in same direction → convergence ~1.0
        assert!(
            report.convergence_score > 0.99,
            "expected convergence ~1.0, got {}",
            report.convergence_score
        );

        // Top dimension should be dim 0
        assert_eq!(report.top_dimensions[0].0, 0);

        // No outliers (all identical drift)
        assert!(
            report.outliers.is_empty(),
            "expected no outliers, got {}",
            report.outliers.len()
        );
    }

    #[test]
    fn cohort_drift_convergence_detected() {
        // Entities start far apart, end close together → dispersion decreases
        let owned_trajs = [
            vec![(1000i64, vec![0.0f32, 0.0]), (2000, vec![0.5, 0.5])],
            vec![(1000, vec![2.0, 0.0]), (2000, vec![0.5, 0.5])],
            vec![(1000, vec![0.0, 2.0]), (2000, vec![0.5, 0.5])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 2).unwrap();

        assert!(
            report.dispersion_change < 0.0,
            "expected negative dispersion change (convergence), got {}",
            report.dispersion_change
        );
        assert!(
            report.dispersion_t2 < report.dispersion_t1,
            "t2 dispersion ({}) should be less than t1 ({})",
            report.dispersion_t2,
            report.dispersion_t1
        );
    }

    #[test]
    fn cohort_drift_divergence_detected() {
        // Entities start close together, end far apart → dispersion increases
        let owned_trajs = [
            vec![(1000i64, vec![0.5f32, 0.5]), (2000, vec![0.0, 0.0])],
            vec![(1000, vec![0.5, 0.5]), (2000, vec![2.0, 0.0])],
            vec![(1000, vec![0.5, 0.5]), (2000, vec![0.0, 2.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 2).unwrap();

        assert!(
            report.dispersion_change > 0.0,
            "expected positive dispersion change (divergence), got {}",
            report.dispersion_change
        );
    }

    #[test]
    fn cohort_drift_outlier_detection() {
        // 9 entities with tiny drift + 1 entity with massive drift
        let dim = 4;
        let mut owned_trajs: Vec<Vec<(i64, Vec<f32>)>> = Vec::new();

        // 9 normal entities: drift of 0.01 in dim 0
        for i in 0..9u64 {
            let base: Vec<f32> = vec![i as f32 * 0.1; dim];
            let shifted: Vec<f32> = base
                .iter()
                .enumerate()
                .map(|(d, &v)| if d == 0 { v + 0.01 } else { v })
                .collect();
            owned_trajs.push(vec![(1000, base), (2000, shifted)]);
        }

        // 1 outlier: drift of 10.0 in dim 0
        let base = vec![0.5f32; dim];
        let shifted: Vec<f32> = base
            .iter()
            .enumerate()
            .map(|(d, &v)| if d == 0 { v + 10.0 } else { v })
            .collect();
        owned_trajs.push(vec![(1000, base), (2000, shifted)]);

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 3).unwrap();

        assert_eq!(report.n_entities, 10);
        assert!(!report.outliers.is_empty(), "expected at least one outlier");

        // The outlier should be entity 9
        let outlier = report.outliers.iter().find(|o| o.entity_id == 9);
        assert!(outlier.is_some(), "entity 9 should be flagged as outlier");
        let outlier = outlier.unwrap();
        assert!(
            outlier.z_score > 2.0,
            "outlier z-score should be > 2.0, got {}",
            outlier.z_score
        );
        assert!(
            outlier.drift_magnitude > 9.0,
            "outlier drift should be large, got {}",
            outlier.drift_magnitude
        );
    }

    #[test]
    fn cohort_drift_centroid_drift_matches_manual() {
        // 2 entities, manually compute expected centroid drift
        let owned_trajs = [
            vec![(1000i64, vec![0.0f32, 0.0]), (2000, vec![1.0, 0.0])],
            vec![(1000, vec![2.0, 0.0]), (2000, vec![3.0, 0.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 2).unwrap();

        // Centroid at t1 = [1.0, 0.0], centroid at t2 = [2.0, 0.0]
        // Centroid drift L2 = 1.0
        assert!(
            (report.centroid_drift.l2_magnitude - 1.0).abs() < 1e-5,
            "expected centroid drift 1.0, got {}",
            report.centroid_drift.l2_magnitude
        );
    }

    #[test]
    fn cohort_drift_no_data_at_one_timepoint() {
        // Entity 1 has data only at t1, entity 2 has data at both, entity 3 only at t2
        // Entity 1's nearest to t2=2000 will be its only point at t1=1000
        // All entities will actually be included since nearest_vector_at finds closest
        let owned_trajs = [
            vec![(1000i64, vec![1.0f32, 0.0])],
            vec![(1000, vec![2.0, 0.0]), (2000, vec![3.0, 0.0])],
            vec![(2000i64, vec![4.0, 0.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        // Should succeed with all 3 entities (nearest_vector_at finds closest available)
        let result = cohort_drift(&trajectories, 1000, 2000, 2);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().n_entities, 3);
    }

    #[test]
    fn cohort_drift_stationary_cohort() {
        // All entities stay in the same place
        let owned_trajs = [
            vec![
                (1000i64, vec![1.0f32, 2.0, 3.0]),
                (2000, vec![1.0, 2.0, 3.0]),
            ],
            vec![(1000, vec![4.0, 5.0, 6.0]), (2000, vec![4.0, 5.0, 6.0])],
            vec![(1000, vec![7.0, 8.0, 9.0]), (2000, vec![7.0, 8.0, 9.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 3).unwrap();

        assert!(
            report.mean_drift_l2 < 1e-6,
            "stationary cohort should have ~0 drift"
        );
        assert!(report.median_drift_l2 < 1e-6);
        assert!(report.centroid_drift.l2_magnitude < 1e-6);
        assert!(
            (report.dispersion_change).abs() < 1e-6,
            "dispersion should not change"
        );
        assert!(report.outliers.is_empty());
    }

    #[test]
    fn cohort_drift_high_dimensional() {
        // Sanity check with 128-dim vectors
        let dim = 128;
        let n_entities = 20;
        let mut owned_trajs = Vec::new();

        for i in 0..n_entities {
            let base: Vec<f32> = (0..dim)
                .map(|d| ((i * dim + d) as f32 * 0.01).sin())
                .collect();
            let shifted: Vec<f32> = base.iter().map(|v| v + 0.05).collect();
            owned_trajs.push(vec![(1000i64, base), (2000, shifted)]);
        }

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 10).unwrap();

        assert_eq!(report.n_entities, n_entities);
        assert!(report.mean_drift_l2 > 0.0);
        assert_eq!(report.top_dimensions.len(), 10);
        // Uniform shift → high convergence
        assert!(
            report.convergence_score > 0.95,
            "uniform shift should give high convergence, got {}",
            report.convergence_score
        );
    }
}