1use crate::calculus::{DriftReport, drift_magnitude_l2, drift_report};
16use cvx_core::error::AnalyticsError;
17
/// Aggregate drift statistics for a cohort of entities between two timestamps.
///
/// Produced by [`cohort_drift`]; every field is computed over the entities
/// that function was able to include in the analysis.
#[derive(Debug, Clone)]
pub struct CohortDriftReport {
    /// Number of entities that contributed to the report.
    pub n_entities: usize,
    /// Mean of the per-entity drift magnitudes (`drift_magnitude_l2` between
    /// each entity's vector at `t1` and at `t2`).
    pub mean_drift_l2: f32,
    /// Median of the per-entity drift magnitudes.
    pub median_drift_l2: f32,
    /// Sample standard deviation (n − 1 denominator) of the per-entity drift
    /// magnitudes.
    pub std_drift_l2: f32,
    /// Result of `drift_report` between the cohort centroid at `t1` and the
    /// cohort centroid at `t2`.
    pub centroid_drift: DriftReport,
    /// Mean `drift_magnitude_l2` distance from each entity to the cohort
    /// centroid at `t1`.
    pub dispersion_t1: f32,
    /// Mean `drift_magnitude_l2` distance from each entity to the cohort
    /// centroid at `t2`.
    pub dispersion_t2: f32,
    /// `dispersion_t2 - dispersion_t1`; negative when the cohort is tighter
    /// around its centroid at `t2` than at `t1`.
    pub dispersion_change: f32,
    /// Mean cosine similarity between each entity's drift vector and the
    /// cohort's mean drift vector (see `compute_convergence_score`).
    pub convergence_score: f32,
    /// `(dimension index, |mean per-dimension delta|)` pairs, sorted by
    /// descending magnitude and truncated to the requested `top_n`.
    pub top_dimensions: Vec<(usize, f32)>,
    /// Entities whose drift magnitude has an absolute z-score above 2.0.
    pub outliers: Vec<CohortOutlier>,
}
46
/// An entity whose drift magnitude deviates strongly from the rest of its
/// cohort, as reported in [`CohortDriftReport::outliers`].
#[derive(Debug, Clone)]
pub struct CohortOutlier {
    /// Identifier of the entity, as supplied by the caller of `cohort_drift`.
    pub entity_id: u64,
    /// The entity's own drift magnitude between the two timestamps.
    pub drift_magnitude: f32,
    /// `(drift_magnitude - cohort mean) / cohort standard deviation`.
    pub z_score: f32,
    /// Cosine similarity, clamped to `[-1, 1]`, between this entity's drift
    /// vector and the cohort's mean drift vector; `0.0` when either vector
    /// has (near-)zero norm.
    pub drift_direction_alignment: f32,
}
59
/// Returns the vector whose timestamp is closest to `target`.
///
/// `trajectory` is a slice of `(timestamp, vector)` pairs; it does not need
/// to be sorted. When several entries are equally close, the one that appears
/// first in the slice wins. Returns `None` only for an empty trajectory.
///
/// The distance is computed with [`i64::abs_diff`], so timestamps that are
/// further apart than `i64::MAX` (e.g. `i64::MIN` vs. `i64::MAX`) are handled
/// without overflow.
pub fn nearest_vector_at<'a>(trajectory: &'a [(i64, &'a [f32])], target: i64) -> Option<&'a [f32]> {
    trajectory
        .iter()
        // `min_by_key` keeps the first of several equal minima.
        .min_by_key(|(ts, _)| ts.abs_diff(target))
        .map(|&(_, vector)| vector)
}
76
/// Computes the component-wise mean of `vectors`.
///
/// The output dimension is that of the first vector. Returns an empty `Vec`
/// when `vectors` is empty. Callers are expected to pass vectors of equal
/// length; if they do not, components beyond the first vector's dimension are
/// ignored and missing components count as zero, rather than panicking on an
/// out-of-bounds index.
fn centroid(vectors: &[&[f32]]) -> Vec<f32> {
    let Some(first) = vectors.first() else {
        return Vec::new();
    };
    let count = vectors.len() as f32;

    let mut sums = vec![0.0f32; first.len()];
    for vector in vectors {
        // `zip` stops at the shorter side, so ragged input cannot index past `sums`.
        for (acc, &component) in sums.iter_mut().zip(vector.iter()) {
            *acc += component;
        }
    }
    for acc in &mut sums {
        *acc /= count;
    }
    sums
}
95
/// Measures how consistently a set of drift vectors point in the same direction.
///
/// The score is the mean cosine similarity between each drift vector and the
/// mean drift vector, so it lies in `[-1, 1]`:
///
/// * close to `1.0` when all entities drift in the same direction,
/// * `0.0` when fewer than two vectors are given, when the mean drift has
///   (near-)zero norm, or when every individual drift has (near-)zero norm.
///
/// Drift vectors with a norm below `1e-12` are excluded from the average.
/// The dimension is taken from the first vector; any vector of a different
/// length is clipped to that dimension instead of triggering an
/// out-of-bounds panic.
fn compute_convergence_score(drift_vectors: &[Vec<f32>]) -> f32 {
    // Euclidean norm of a vector.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    // The leading `dim` components of `v` (all of `v` if it is shorter).
    fn leading(v: &[f32], dim: usize) -> &[f32] {
        &v[..v.len().min(dim)]
    }

    if drift_vectors.len() < 2 {
        return 0.0;
    }
    let dim = drift_vectors[0].len();
    let count = drift_vectors.len() as f32;

    // Mean drift vector across the cohort.
    let mut mean_dir = vec![0.0f32; dim];
    for dv in drift_vectors {
        for (acc, &val) in mean_dir.iter_mut().zip(dv) {
            *acc += val;
        }
    }
    for acc in &mut mean_dir {
        *acc /= count;
    }

    let mean_norm = l2_norm(&mean_dir);
    if mean_norm < 1e-12 {
        // Drifts cancel out: there is no common direction to compare against.
        return 0.0;
    }

    let (total_sim, valid) = drift_vectors
        .iter()
        .filter_map(|dv| {
            let dv = leading(dv, dim);
            let dv_norm = l2_norm(dv);
            if dv_norm < 1e-12 {
                // A stationary entity has no direction; leave it out of the average.
                return None;
            }
            let dot: f32 = dv.iter().zip(&mean_dir).map(|(a, b)| a * b).sum();
            Some((dot / (dv_norm * mean_norm)).clamp(-1.0, 1.0))
        })
        .fold((0.0f32, 0usize), |(sum, n), sim| (sum + sim, n + 1));

    if valid == 0 {
        0.0
    } else {
        total_sim / valid as f32
    }
}
141
142#[allow(clippy::type_complexity)]
157pub fn cohort_drift(
158 trajectories: &[(u64, &[(i64, &[f32])])],
159 t1: i64,
160 t2: i64,
161 top_n: usize,
162) -> Result<CohortDriftReport, AnalyticsError> {
163 #[allow(clippy::type_complexity)]
165 let mut entity_data: Vec<(u64, Vec<f32>, Vec<f32>, Vec<f32>)> = Vec::new();
166
167 for &(entity_id, traj) in trajectories {
168 let Some(v1) = nearest_vector_at(traj, t1) else {
169 continue;
170 };
171 let Some(v2) = nearest_vector_at(traj, t2) else {
172 continue;
173 };
174 if v1.len() != v2.len() {
175 continue;
176 }
177 let drift_vec: Vec<f32> = v2.iter().zip(v1.iter()).map(|(a, b)| a - b).collect();
178 entity_data.push((entity_id, v1.to_vec(), v2.to_vec(), drift_vec));
179 }
180
181 let n = entity_data.len();
182 if n < 2 {
183 return Err(AnalyticsError::InsufficientData { needed: 2, have: n });
184 }
185
186 let drift_magnitudes: Vec<f32> = entity_data
189 .iter()
190 .map(|(_, v1, v2, _)| drift_magnitude_l2(v1, v2))
191 .collect();
192
193 let mean_drift_l2 = drift_magnitudes.iter().sum::<f32>() / n as f32;
194
195 let mut sorted_mags = drift_magnitudes.clone();
196 sorted_mags.sort_by(|a, b| a.partial_cmp(b).unwrap());
197 let median_drift_l2 = if n % 2 == 0 {
198 (sorted_mags[n / 2 - 1] + sorted_mags[n / 2]) / 2.0
199 } else {
200 sorted_mags[n / 2]
201 };
202
203 let variance: f32 = drift_magnitudes
204 .iter()
205 .map(|m| (m - mean_drift_l2) * (m - mean_drift_l2))
206 .sum::<f32>()
207 / (n - 1) as f32;
208 let std_drift_l2 = variance.sqrt();
209
210 let vectors_t1: Vec<&[f32]> = entity_data
213 .iter()
214 .map(|(_, v1, _, _)| v1.as_slice())
215 .collect();
216 let vectors_t2: Vec<&[f32]> = entity_data
217 .iter()
218 .map(|(_, _, v2, _)| v2.as_slice())
219 .collect();
220
221 let centroid_t1 = centroid(&vectors_t1);
222 let centroid_t2 = centroid(&vectors_t2);
223 let centroid_drift = drift_report(¢roid_t1, ¢roid_t2, top_n);
224
225 let dispersion_t1 = vectors_t1
228 .iter()
229 .map(|v| drift_magnitude_l2(v, ¢roid_t1))
230 .sum::<f32>()
231 / n as f32;
232
233 let dispersion_t2 = vectors_t2
234 .iter()
235 .map(|v| drift_magnitude_l2(v, ¢roid_t2))
236 .sum::<f32>()
237 / n as f32;
238
239 let dispersion_change = dispersion_t2 - dispersion_t1;
240
241 let drift_vectors: Vec<Vec<f32>> = entity_data.iter().map(|(_, _, _, dv)| dv.clone()).collect();
244 let convergence_score = compute_convergence_score(&drift_vectors);
245
246 let dim = entity_data[0].3.len();
249 let mut mean_delta = vec![0.0f32; dim];
250 for (_, _, _, dv) in &entity_data {
251 for (i, &val) in dv.iter().enumerate() {
252 mean_delta[i] += val;
253 }
254 }
255 for val in &mut mean_delta {
256 *val /= n as f32;
257 }
258
259 let mut dim_changes: Vec<(usize, f32)> = mean_delta
260 .iter()
261 .enumerate()
262 .map(|(i, &v)| (i, v.abs()))
263 .collect();
264 dim_changes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
265 dim_changes.truncate(top_n);
266
267 let mean_drift_dir: Vec<f32> = mean_delta.clone();
271 let mean_dir_norm: f32 = mean_drift_dir.iter().map(|x| x * x).sum::<f32>().sqrt();
272
273 let outliers: Vec<CohortOutlier> = entity_data
274 .iter()
275 .zip(drift_magnitudes.iter())
276 .filter_map(|((entity_id, _, _, dv), &mag)| {
277 let z = if std_drift_l2 > 1e-12 {
278 (mag - mean_drift_l2) / std_drift_l2
279 } else {
280 0.0
281 };
282
283 if z.abs() <= 2.0 {
284 return None;
285 }
286
287 let alignment = if mean_dir_norm > 1e-12 {
288 let dv_norm: f32 = dv.iter().map(|x| x * x).sum::<f32>().sqrt();
289 if dv_norm > 1e-12 {
290 let dot: f32 = dv
291 .iter()
292 .zip(mean_drift_dir.iter())
293 .map(|(a, b)| a * b)
294 .sum();
295 (dot / (dv_norm * mean_dir_norm)).clamp(-1.0, 1.0)
296 } else {
297 0.0
298 }
299 } else {
300 0.0
301 };
302
303 Some(CohortOutlier {
304 entity_id: *entity_id,
305 drift_magnitude: mag,
306 z_score: z,
307 drift_direction_alignment: alignment,
308 })
309 })
310 .collect();
311
312 Ok(CohortDriftReport {
313 n_entities: n,
314 mean_drift_l2,
315 median_drift_l2,
316 std_drift_l2,
317 centroid_drift,
318 dispersion_t1,
319 dispersion_t2,
320 dispersion_change,
321 convergence_score,
322 top_dimensions: dim_changes,
323 outliers,
324 })
325}
326
// Unit tests for the cohort drift helpers and for `cohort_drift` itself.
#[cfg(test)]
// Lint allowances for the fixture-building code below.
#[allow(
    clippy::type_complexity,
    clippy::needless_range_loop,
    clippy::useless_vec
)]
mod tests {
    use super::*;

    // Converts an owned trajectory into the borrowed `(timestamp, slice)` form
    // that the functions under test accept.
    fn as_refs(points: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        points.iter().map(|(t, v)| (*t, v.as_slice())).collect()
    }

    // ---- nearest_vector_at ----

    #[test]
    fn nearest_vector_empty_trajectory() {
        let traj: Vec<(i64, &[f32])> = vec![];
        assert!(nearest_vector_at(&traj, 100).is_none());
    }

    #[test]
    fn nearest_vector_exact_match() {
        let owned = vec![
            (100i64, vec![1.0f32, 0.0]),
            (200, vec![0.0, 1.0]),
            (300, vec![1.0, 1.0]),
        ];
        let traj = as_refs(&owned);
        let v = nearest_vector_at(&traj, 200).unwrap();
        assert_eq!(v, &[0.0, 1.0]);
    }

    #[test]
    fn nearest_vector_between_timestamps() {
        let owned = vec![
            (100i64, vec![1.0f32, 0.0]),
            (200, vec![0.0, 1.0]),
            (300, vec![1.0, 1.0]),
        ];
        let traj = as_refs(&owned);
        let v = nearest_vector_at(&traj, 190).unwrap();
        // 190 is 10 away from 200 but 90 away from 100, so the point at 200 wins.
        assert_eq!(v, &[0.0, 1.0]);
    }

    #[test]
    fn nearest_vector_before_first() {
        // A target earlier than every point still resolves to the only point.
        let owned = vec![(100i64, vec![1.0f32, 2.0])];
        let traj = as_refs(&owned);
        let v = nearest_vector_at(&traj, 0).unwrap();
        assert_eq!(v, &[1.0, 2.0]);
    }

    // ---- centroid ----

    #[test]
    fn centroid_single_vector() {
        let v = vec![2.0f32, 4.0, 6.0];
        let c = centroid(&[v.as_slice()]);
        assert_eq!(c, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn centroid_two_vectors() {
        let v1 = vec![0.0f32, 0.0];
        let v2 = vec![2.0, 4.0];
        let c = centroid(&[v1.as_slice(), v2.as_slice()]);
        assert!((c[0] - 1.0).abs() < 1e-6);
        assert!((c[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn centroid_empty() {
        let c = centroid(&[]);
        assert!(c.is_empty());
    }

    // ---- compute_convergence_score ----

    #[test]
    fn convergence_all_same_direction() {
        let drifts = vec![
            vec![1.0f32, 0.0, 0.0],
            vec![2.0, 0.0, 0.0],
            vec![0.5, 0.0, 0.0],
        ];
        let score = compute_convergence_score(&drifts);
        assert!((score - 1.0).abs() < 1e-6, "expected ~1.0, got {score}");
    }

    #[test]
    fn convergence_opposite_directions() {
        let drifts = vec![vec![1.0f32, 0.0], vec![-1.0, 0.0]];
        let score = compute_convergence_score(&drifts);
        assert!(
            // The two drifts cancel, so the mean direction has zero norm.
            score.abs() < 1e-6,
            "expected ~0.0 for zero mean, got {score}"
        );
    }

    #[test]
    fn convergence_orthogonal_directions() {
        let drifts = vec![
            // Four unit drifts along ±x and ±y sum to the zero vector.
            vec![1.0f32, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];
        let score = compute_convergence_score(&drifts);
        assert!(score.abs() < 1e-6, "expected ~0.0, got {score}");
    }

    #[test]
    fn convergence_too_few_vectors() {
        let drifts = vec![vec![1.0f32]];
        assert_eq!(compute_convergence_score(&drifts), 0.0);
    }

    // ---- cohort_drift ----

    #[test]
    fn cohort_drift_insufficient_data() {
        // A single entity is below the two-entity minimum.
        let traj1 = vec![(100i64, vec![1.0f32, 0.0])];
        let refs1 = as_refs(&traj1);

        let trajectories: Vec<(u64, &[(i64, &[f32])])> = vec![(1, &refs1)];
        let result = cohort_drift(&trajectories, 100, 200, 5);
        assert!(result.is_err());
        match result.unwrap_err() {
            AnalyticsError::InsufficientData { needed, have } => {
                assert_eq!(needed, 2);
                assert_eq!(have, 1);
            }
            other => panic!("expected InsufficientData, got {other:?}"),
        }
    }

    #[test]
    fn cohort_drift_uniform_shift() {
        let dim = 3;
        let n_entities = 10;
        let shift = 0.1f32;

        // Every entity moves by `shift` along dimension 0 only.
        let mut owned_trajs: Vec<Vec<(i64, Vec<f32>)>> = Vec::new();
        for i in 0..n_entities {
            let base: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32 * 0.1).collect();
            let shifted: Vec<f32> = base
                .iter()
                .enumerate()
                .map(|(d, &v)| if d == 0 { v + shift } else { v })
                .collect();
            owned_trajs.push(vec![(1000, base), (2000, shifted)]);
        }

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 5).unwrap();

        assert_eq!(report.n_entities, n_entities);

        // Identical drift for every entity: mean and median equal the shift, spread is ~0.
        assert!(
            (report.mean_drift_l2 - shift).abs() < 1e-5,
            "expected mean drift ~{shift}, got {}",
            report.mean_drift_l2
        );
        assert!(
            (report.median_drift_l2 - shift).abs() < 1e-5,
            "expected median ~{shift}, got {}",
            report.median_drift_l2
        );
        assert!(
            report.std_drift_l2 < 1e-5,
            "expected std ~0 for uniform shift, got {}",
            report.std_drift_l2
        );

        // All drift vectors point the same way.
        assert!(
            report.convergence_score > 0.99,
            "expected convergence ~1.0, got {}",
            report.convergence_score
        );

        // Dimension 0 is the only one that changed.
        assert_eq!(report.top_dimensions[0].0, 0);

        assert!(
            report.outliers.is_empty(),
            "expected no outliers, got {}",
            report.outliers.len()
        );
    }

    #[test]
    fn cohort_drift_convergence_detected() {
        // Three spread-out entities all end up at the same point (0.5, 0.5).
        let owned_trajs = [
            vec![(1000i64, vec![0.0f32, 0.0]), (2000, vec![0.5, 0.5])],
            vec![(1000, vec![2.0, 0.0]), (2000, vec![0.5, 0.5])],
            vec![(1000, vec![0.0, 2.0]), (2000, vec![0.5, 0.5])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 2).unwrap();

        assert!(
            report.dispersion_change < 0.0,
            "expected negative dispersion change (convergence), got {}",
            report.dispersion_change
        );
        assert!(
            report.dispersion_t2 < report.dispersion_t1,
            "t2 dispersion ({}) should be less than t1 ({})",
            report.dispersion_t2,
            report.dispersion_t1
        );
    }

    #[test]
    fn cohort_drift_divergence_detected() {
        // Three entities start at the same point (0.5, 0.5) and move apart.
        let owned_trajs = [
            vec![(1000i64, vec![0.5f32, 0.5]), (2000, vec![0.0, 0.0])],
            vec![(1000, vec![0.5, 0.5]), (2000, vec![2.0, 0.0])],
            vec![(1000, vec![0.5, 0.5]), (2000, vec![0.0, 2.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 2).unwrap();

        assert!(
            report.dispersion_change > 0.0,
            "expected positive dispersion change (divergence), got {}",
            report.dispersion_change
        );
    }

    #[test]
    fn cohort_drift_outlier_detection() {
        let dim = 4;
        let mut owned_trajs: Vec<Vec<(i64, Vec<f32>)>> = Vec::new();

        // Entities 0..=8 barely move: a 0.01 shift along dimension 0.
        for i in 0..9u64 {
            let base: Vec<f32> = vec![i as f32 * 0.1; dim];
            let shifted: Vec<f32> = base
                .iter()
                .enumerate()
                .map(|(d, &v)| if d == 0 { v + 0.01 } else { v })
                .collect();
            owned_trajs.push(vec![(1000, base), (2000, shifted)]);
        }

        // Entity 9 jumps by 10.0 along dimension 0 and should stand out.
        let base = vec![0.5f32; dim];
        let shifted: Vec<f32> = base
            .iter()
            .enumerate()
            .map(|(d, &v)| if d == 0 { v + 10.0 } else { v })
            .collect();
        owned_trajs.push(vec![(1000, base), (2000, shifted)]);

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 3).unwrap();

        assert_eq!(report.n_entities, 10);
        assert!(!report.outliers.is_empty(), "expected at least one outlier");

        // Entity ids are the enumeration indices, so the large mover has id 9.
        let outlier = report.outliers.iter().find(|o| o.entity_id == 9);
        assert!(outlier.is_some(), "entity 9 should be flagged as outlier");
        let outlier = outlier.unwrap();
        assert!(
            outlier.z_score > 2.0,
            "outlier z-score should be > 2.0, got {}",
            outlier.z_score
        );
        assert!(
            outlier.drift_magnitude > 9.0,
            "outlier drift should be large, got {}",
            outlier.drift_magnitude
        );
    }

    #[test]
    fn cohort_drift_centroid_drift_matches_manual() {
        // Centroid moves from (1, 0) at t1 to (2, 0) at t2.
        let owned_trajs = [
            vec![(1000i64, vec![0.0f32, 0.0]), (2000, vec![1.0, 0.0])],
            vec![(1000, vec![2.0, 0.0]), (2000, vec![3.0, 0.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 2).unwrap();

        assert!(
            (report.centroid_drift.l2_magnitude - 1.0).abs() < 1e-5,
            "expected centroid drift 1.0, got {}",
            report.centroid_drift.l2_magnitude
        );
    }

    #[test]
    fn cohort_drift_no_data_at_one_timepoint() {
        // Entities 0 and 2 have a single point each. `nearest_vector_at` still
        // returns that point for both targets, so no entity is dropped.
        let owned_trajs = [
            vec![(1000i64, vec![1.0f32, 0.0])],
            vec![(1000, vec![2.0, 0.0]), (2000, vec![3.0, 0.0])],
            vec![(2000i64, vec![4.0, 0.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let result = cohort_drift(&trajectories, 1000, 2000, 2);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().n_entities, 3);
    }

    #[test]
    fn cohort_drift_stationary_cohort() {
        // Every entity has identical vectors at both timestamps.
        let owned_trajs = [
            vec![
                (1000i64, vec![1.0f32, 2.0, 3.0]),
                (2000, vec![1.0, 2.0, 3.0]),
            ],
            vec![(1000, vec![4.0, 5.0, 6.0]), (2000, vec![4.0, 5.0, 6.0])],
            vec![(1000, vec![7.0, 8.0, 9.0]), (2000, vec![7.0, 8.0, 9.0])],
        ];

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 3).unwrap();

        assert!(
            report.mean_drift_l2 < 1e-6,
            "stationary cohort should have ~0 drift"
        );
        assert!(report.median_drift_l2 < 1e-6);
        assert!(report.centroid_drift.l2_magnitude < 1e-6);
        assert!(
            (report.dispersion_change).abs() < 1e-6,
            "dispersion should not change"
        );
        assert!(report.outliers.is_empty());
    }

    #[test]
    fn cohort_drift_high_dimensional() {
        let dim = 128;
        let n_entities = 20;
        let mut owned_trajs = Vec::new();

        // Sine-based base vectors, each shifted by 0.05 in every dimension.
        for i in 0..n_entities {
            let base: Vec<f32> = (0..dim)
                .map(|d| ((i * dim + d) as f32 * 0.01).sin())
                .collect();
            let shifted: Vec<f32> = base.iter().map(|v| v + 0.05).collect();
            owned_trajs.push(vec![(1000i64, base), (2000, shifted)]);
        }

        let ref_trajs: Vec<Vec<(i64, &[f32])>> = owned_trajs.iter().map(|t| as_refs(t)).collect();
        let trajectories: Vec<(u64, &[(i64, &[f32])])> = ref_trajs
            .iter()
            .enumerate()
            .map(|(i, t)| (i as u64, t.as_slice()))
            .collect();

        let report = cohort_drift(&trajectories, 1000, 2000, 10).unwrap();

        assert_eq!(report.n_entities, n_entities);
        assert!(report.mean_drift_l2 > 0.0);
        // `top_n` = 10 caps the list even though 128 dimensions changed.
        assert_eq!(report.top_dimensions.len(), 10);
        assert!(
            report.convergence_score > 0.95,
            "uniform shift should give high convergence, got {}",
            report.convergence_score
        );
    }
}