1use crate::calculus::drift_magnitude_l2;
13use cvx_core::error::AnalyticsError;
14
15#[derive(Debug, Clone)]
19pub struct CounterfactualResult {
20 pub change_point: i64,
22 pub actual: Vec<(i64, Vec<f32>)>,
24 pub counterfactual: Vec<(i64, Vec<f32>)>,
26 pub divergence_curve: Vec<(i64, f32)>,
28 pub total_divergence: f64,
30 pub max_divergence_time: i64,
32 pub max_divergence_value: f32,
34 pub method: CounterfactualMethod,
36}
37
/// Technique used to construct a counterfactual trajectory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CounterfactualMethod {
    /// A least-squares line is fitted independently to each dimension of the
    /// pre-change samples and evaluated at the post-change timestamps.
    LinearExtrapolation,
}
44
45pub fn counterfactual_trajectory(
64 pre_change: &[(i64, &[f32])],
65 post_change: &[(i64, &[f32])],
66 change_point: i64,
67) -> Result<CounterfactualResult, AnalyticsError> {
68 if pre_change.len() < 2 {
69 return Err(AnalyticsError::InsufficientData {
70 needed: 2,
71 have: pre_change.len(),
72 });
73 }
74 if post_change.is_empty() {
75 return Err(AnalyticsError::InsufficientData { needed: 1, have: 0 });
76 }
77
78 let dim = pre_change[0].1.len();
79
80 let (slopes, intercepts) = fit_linear_per_dim(pre_change);
83
84 let counterfactual: Vec<(i64, Vec<f32>)> = post_change
87 .iter()
88 .map(|&(t, _)| {
89 let t_f = t as f64;
90 let vec: Vec<f32> = (0..dim)
91 .map(|d| (slopes[d] * t_f + intercepts[d]) as f32)
92 .collect();
93 (t, vec)
94 })
95 .collect();
96
97 let actual: Vec<(i64, Vec<f32>)> = post_change.iter().map(|&(t, v)| (t, v.to_vec())).collect();
100
101 let divergence_curve: Vec<(i64, f32)> = actual
104 .iter()
105 .zip(counterfactual.iter())
106 .map(|((t, act), (_, cf))| {
107 let dist = drift_magnitude_l2(act, cf);
108 (*t, dist)
109 })
110 .collect();
111
112 let (max_time, max_val) = divergence_curve
115 .iter()
116 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
117 .map(|&(t, v)| (t, v))
118 .unwrap_or((change_point, 0.0));
119
120 let total_divergence = trapezoidal_integral(&divergence_curve);
121
122 Ok(CounterfactualResult {
123 change_point,
124 actual,
125 counterfactual,
126 divergence_curve,
127 total_divergence,
128 max_divergence_time: max_time,
129 max_divergence_value: max_val,
130 method: CounterfactualMethod::LinearExtrapolation,
131 })
132}
133
/// Ordinary least-squares fit of `value = slope * t + intercept`, computed
/// independently for every dimension of `trajectory`.
///
/// Returns `(slopes, intercepts)` with one entry per dimension. The dimension
/// count is taken from the first point.
///
/// When the timestamps have (numerically) zero spread, no slope can be
/// estimated. In that case every slope is `0.0` and every intercept is the
/// matching component of the last point, so the "line" holds that value
/// constant.
///
/// Panics if `trajectory` is empty or if any point has fewer components than
/// the first point.
fn fit_linear_per_dim(trajectory: &[(i64, &[f32])]) -> (Vec<f64>, Vec<f64>) {
    let count = trajectory.len() as f64;
    let dim = trajectory[0].1.len();

    let times: Vec<f64> = trajectory.iter().map(|&(t, _)| t as f64).collect();
    let mean_t = times.iter().sum::<f64>() / count;
    let spread_t: f64 = times.iter().map(|t| (t - mean_t) * (t - mean_t)).sum();

    if spread_t < 1e-15 {
        // Degenerate time axis: hold the most recent observation constant.
        let (_, newest) = trajectory[trajectory.len() - 1];
        let flat = vec![0.0f64; dim];
        let held = (0..dim).map(|d| newest[d] as f64).collect();
        return (flat, held);
    }

    (0..dim)
        .map(|d| {
            let mean_y = trajectory.iter().map(|(_, v)| v[d] as f64).sum::<f64>() / count;
            let cov_ty: f64 = times
                .iter()
                .zip(trajectory)
                .map(|(t, (_, v))| (t - mean_t) * (v[d] as f64 - mean_y))
                .sum();
            let slope = cov_ty / spread_t;
            (slope, mean_y - slope * mean_t)
        })
        .unzip()
}
174
/// Area under a piecewise-linear curve of `(timestamp, value)` samples,
/// computed with the trapezoidal rule over each consecutive pair.
///
/// Each segment contributes `(t1 - t0) * (v0 + v1) / 2`. Timestamps are not
/// sorted here, so a decreasing pair contributes negative area. A curve with
/// fewer than two samples encloses no area and yields `0.0`.
fn trapezoidal_integral(curve: &[(i64, f32)]) -> f64 {
    match curve {
        [] | [_] => 0.0,
        _ => curve
            .iter()
            .zip(&curve[1..])
            .fold(0.0f64, |area, (&(t0, v0), &(t1, v1))| {
                let width = (t1 - t0) as f64;
                let mean_height = (v0 as f64 + v1 as f64) / 2.0;
                area + width * mean_height
            }),
    }
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows owned `(timestamp, Vec<f32>)` points in the slice form the API takes.
    fn as_refs(points: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        points.iter().map(|(t, v)| (*t, v.as_slice())).collect()
    }

    #[test]
    fn linear_fit_perfect_line() {
        // y = 2t + 1
        let owned = vec![
            (0i64, vec![1.0f32]),
            (1, vec![3.0]),
            (2, vec![5.0]),
            (3, vec![7.0]),
        ];
        let traj = as_refs(&owned);
        let (slopes, intercepts) = fit_linear_per_dim(&traj);
        assert!(
            (slopes[0] - 2.0).abs() < 1e-6,
            "slope should be 2.0, got {}",
            slopes[0]
        );
        assert!(
            (intercepts[0] - 1.0).abs() < 1e-6,
            "intercept should be 1.0, got {}",
            intercepts[0]
        );
    }

    #[test]
    fn linear_fit_multidim() {
        // dim 0: y = t / 5 * 5 / 5 = t / 1 scaled, i.e. slope 1; dim 1: y = 10 - t.
        let owned = vec![
            (0i64, vec![0.0f32, 10.0]),
            (5, vec![5.0, 5.0]),
            (10, vec![10.0, 0.0]),
        ];
        let traj = as_refs(&owned);
        let (slopes, intercepts) = fit_linear_per_dim(&traj);
        assert!((slopes[0] - 1.0).abs() < 1e-6);
        assert!((slopes[1] - (-1.0)).abs() < 1e-6);
        assert!(intercepts[0].abs() < 1e-6);
        assert!((intercepts[1] - 10.0).abs() < 1e-6);
    }

    #[test]
    fn linear_fit_identical_timestamps_holds_last_value() {
        // Zero time spread: no slope can be estimated, so the last sample is held.
        let owned = vec![(5i64, vec![1.0f32, -2.0]), (5, vec![3.0, 4.0])];
        let traj = as_refs(&owned);
        let (slopes, intercepts) = fit_linear_per_dim(&traj);
        assert_eq!(slopes, vec![0.0, 0.0]);
        assert_eq!(intercepts, vec![3.0, 4.0]);
    }

    #[test]
    fn integral_constant() {
        let curve = vec![(0i64, 5.0f32), (100, 5.0)];
        let area = trapezoidal_integral(&curve);
        assert!((area - 500.0).abs() < 1e-6);
    }

    #[test]
    fn integral_triangle() {
        // Triangle with base 100 and height 10.
        let curve = vec![(0i64, 0.0f32), (100, 10.0)];
        let area = trapezoidal_integral(&curve);
        assert!((area - 500.0).abs() < 1e-6);
    }

    #[test]
    fn integral_multiple_segments() {
        // Ramp 0 -> 2 over 10 units (area 10), then flat at 2 for 20 units (area 40).
        let curve = vec![(0i64, 0.0f32), (10, 2.0), (30, 2.0)];
        let area = trapezoidal_integral(&curve);
        assert!((area - 50.0).abs() < 1e-9);
    }

    #[test]
    fn integral_single_point() {
        assert_eq!(trapezoidal_integral(&[(0, 5.0)]), 0.0);
    }

    #[test]
    fn integral_empty() {
        assert_eq!(trapezoidal_integral(&[]), 0.0);
    }

    #[test]
    fn counterfactual_insufficient_pre() {
        let pre_owned = vec![(100i64, vec![1.0f32])];
        let post_owned = vec![(200i64, vec![2.0f32])];
        let pre = as_refs(&pre_owned);
        let post = as_refs(&post_owned);
        assert!(counterfactual_trajectory(&pre, &post, 150).is_err());
    }

    #[test]
    fn counterfactual_empty_post() {
        let pre_owned = vec![(100i64, vec![1.0f32]), (200, vec![2.0])];
        let post: Vec<(i64, &[f32])> = vec![];
        let pre = as_refs(&pre_owned);
        assert!(counterfactual_trajectory(&pre, &post, 250).is_err());
    }

    #[test]
    fn counterfactual_linear_continuation() {
        // Pre-change: y = 0.1 per 1000 time units. Post-change: jumps to 100.
        let pre_owned: Vec<(i64, Vec<f32>)> = (0..10)
            .map(|i| (i as i64 * 1000, vec![i as f32 * 0.1]))
            .collect();
        let post_owned: Vec<(i64, Vec<f32>)> = (10..15)
            .map(|i| (i as i64 * 1000, vec![100.0]))
            .collect();

        let pre = as_refs(&pre_owned);
        let post = as_refs(&post_owned);

        let result = counterfactual_trajectory(&pre, &post, 10000).unwrap();

        assert_eq!(result.change_point, 10000);
        assert_eq!(result.actual.len(), 5);
        assert_eq!(result.counterfactual.len(), 5);
        assert_eq!(result.divergence_curve.len(), 5);

        // The counterfactual is evaluated at exactly the post-change timestamps.
        let post_times: Vec<i64> = post_owned.iter().map(|(t, _)| *t).collect();
        let cf_times: Vec<i64> = result.counterfactual.iter().map(|(t, _)| *t).collect();
        assert_eq!(cf_times, post_times);

        // Extrapolating the pre-change line to t = 10000 gives ~1.0.
        let cf_at_cp = &result.counterfactual[0].1[0];
        assert!(
            (*cf_at_cp - 1.0).abs() < 0.1,
            "counterfactual at change point should be ~1.0, got {cf_at_cp}"
        );

        // Actual is 100 while the counterfactual is ~1.0..1.4.
        assert!(
            result.max_divergence_value > 90.0,
            "divergence should be large, got {}",
            result.max_divergence_value
        );
        // The counterfactual keeps rising towards 100, so the gap is widest first.
        assert_eq!(result.max_divergence_time, 10000);

        assert!(result.total_divergence > 0.0);
    }

    #[test]
    fn counterfactual_no_change() {
        // The post-change samples continue the same line.
        let pre_owned: Vec<(i64, Vec<f32>)> = (0..10)
            .map(|i| (i as i64 * 1000, vec![i as f32 * 0.1]))
            .collect();
        let post_owned: Vec<(i64, Vec<f32>)> = (10..15)
            .map(|i| (i as i64 * 1000, vec![i as f32 * 0.1]))
            .collect();

        let pre = as_refs(&pre_owned);
        let post = as_refs(&post_owned);

        let result = counterfactual_trajectory(&pre, &post, 10000).unwrap();

        assert!(
            result.max_divergence_value < 0.1,
            "no change should have ~0 divergence, got {}",
            result.max_divergence_value
        );
    }

    #[test]
    fn counterfactual_multidim() {
        let pre_owned: Vec<(i64, Vec<f32>)> = (0..10)
            .map(|i| (i as i64 * 1000, vec![i as f32, -(i as f32)]))
            .collect();
        let post_owned: Vec<(i64, Vec<f32>)> = (10..15)
            .map(|i| (i as i64 * 1000, vec![50.0, 50.0]))
            .collect();

        let pre = as_refs(&pre_owned);
        let post = as_refs(&post_owned);

        let result = counterfactual_trajectory(&pre, &post, 10000).unwrap();

        assert_eq!(result.counterfactual[0].1.len(), 2);
        assert!(result.max_divergence_value > 10.0);
    }

    #[test]
    fn counterfactual_method_is_linear() {
        let pre_owned = vec![(0i64, vec![0.0f32]), (1000, vec![1.0])];
        let post_owned = vec![(2000i64, vec![10.0f32])];
        let pre = as_refs(&pre_owned);
        let post = as_refs(&post_owned);

        let result = counterfactual_trajectory(&pre, &post, 1500).unwrap();
        assert_eq!(result.method, CounterfactualMethod::LinearExtrapolation);
    }
}