cvx_analytics/
counterfactual.rs

1//! Counterfactual trajectory analysis.
2//!
3//! Given a detected change point, extrapolates the **pre-change trajectory**
4//! beyond the change point and compares with the actual post-change trajectory.
5//! Answers: "What *would* have happened if the change hadn't occurred?"
6//!
7//! # References
8//!
9//! - Brodersen, K. H. et al. (2015). Causal impact. *Annals of Applied Statistics*.
10//! - Abadie, A. (2021). Synthetic controls. *JEL*, 59(2).
11
12use crate::calculus::drift_magnitude_l2;
13use cvx_core::error::AnalyticsError;
14
15// ─── Types ──────────────────────────────────────────────────────────
16
/// Counterfactual analysis result.
///
/// Produced by [`counterfactual_trajectory`]. The `actual`, `counterfactual`
/// and `divergence_curve` vectors are index-aligned: entry `i` of each one
/// refers to the same post-change timestamp.
#[derive(Debug, Clone)]
pub struct CounterfactualResult {
    /// The change point timestamp, exactly as supplied by the caller.
    pub change_point: i64,
    /// Actual post-change trajectory as `(timestamp, vector)` pairs
    /// (an owned copy of the post-change input).
    pub actual: Vec<(i64, Vec<f32>)>,
    /// Counterfactual (extrapolated pre-change) trajectory, evaluated at the
    /// same timestamps as `actual`.
    pub counterfactual: Vec<(i64, Vec<f32>)>,
    /// Divergence between actual and counterfactual over time, as
    /// `(timestamp, distance)` pairs.
    pub divergence_curve: Vec<(i64, f32)>,
    /// Total divergence (area under curve, via trapezoidal rule).
    ///
    /// Expressed in divergence units multiplied by timestamp units; it is
    /// `0.0` when the curve has fewer than two points.
    pub total_divergence: f64,
    /// Timestamp of maximum divergence.
    pub max_divergence_time: i64,
    /// Maximum divergence value (the largest distance in `divergence_curve`).
    pub max_divergence_value: f32,
    /// Method used for extrapolation.
    pub method: CounterfactualMethod,
}
37
/// Method used for counterfactual extrapolation.
///
/// Recorded in [`CounterfactualResult::method`] so consumers can tell how the
/// counterfactual trajectory was generated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CounterfactualMethod {
    /// OLS linear extrapolation per dimension: each vector component is
    /// fitted with its own straight line over time and extended forward.
    LinearExtrapolation,
}
44
45// ─── Core function ──────────────────────────────────────────────────
46
47/// Compute a counterfactual trajectory analysis.
48///
49/// Splits the trajectory at `change_point`, fits a linear trend to the
50/// pre-change segment, extrapolates beyond the change point, and compares
51/// with the actual post-change data.
52///
53/// # Arguments
54///
55/// * `pre_change` — Trajectory before the change point (sorted by timestamp)
56/// * `post_change` — Trajectory after the change point (sorted by timestamp)
57/// * `change_point` — Timestamp of the detected change
58///
59/// # Errors
60///
61/// Returns [`AnalyticsError::InsufficientData`] if pre_change has < 2 points
62/// or post_change is empty.
63pub fn counterfactual_trajectory(
64    pre_change: &[(i64, &[f32])],
65    post_change: &[(i64, &[f32])],
66    change_point: i64,
67) -> Result<CounterfactualResult, AnalyticsError> {
68    if pre_change.len() < 2 {
69        return Err(AnalyticsError::InsufficientData {
70            needed: 2,
71            have: pre_change.len(),
72        });
73    }
74    if post_change.is_empty() {
75        return Err(AnalyticsError::InsufficientData { needed: 1, have: 0 });
76    }
77
78    let dim = pre_change[0].1.len();
79
80    // ── Fit linear trend per dimension on pre-change data ──
81
82    let (slopes, intercepts) = fit_linear_per_dim(pre_change);
83
84    // ── Extrapolate at post-change timestamps ──
85
86    let counterfactual: Vec<(i64, Vec<f32>)> = post_change
87        .iter()
88        .map(|&(t, _)| {
89            let t_f = t as f64;
90            let vec: Vec<f32> = (0..dim)
91                .map(|d| (slopes[d] * t_f + intercepts[d]) as f32)
92                .collect();
93            (t, vec)
94        })
95        .collect();
96
97    // ── Actual post-change trajectory ──
98
99    let actual: Vec<(i64, Vec<f32>)> = post_change.iter().map(|&(t, v)| (t, v.to_vec())).collect();
100
101    // ── Divergence curve ──
102
103    let divergence_curve: Vec<(i64, f32)> = actual
104        .iter()
105        .zip(counterfactual.iter())
106        .map(|((t, act), (_, cf))| {
107            let dist = drift_magnitude_l2(act, cf);
108            (*t, dist)
109        })
110        .collect();
111
112    // ── Aggregate metrics ──
113
114    let (max_time, max_val) = divergence_curve
115        .iter()
116        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
117        .map(|&(t, v)| (t, v))
118        .unwrap_or((change_point, 0.0));
119
120    let total_divergence = trapezoidal_integral(&divergence_curve);
121
122    Ok(CounterfactualResult {
123        change_point,
124        actual,
125        counterfactual,
126        divergence_curve,
127        total_divergence,
128        max_divergence_time: max_time,
129        max_divergence_value: max_val,
130        method: CounterfactualMethod::LinearExtrapolation,
131    })
132}
133
134// ─── Helpers ────────────────────────────────────────────────────────
135
/// Least-squares line fit, run independently for every vector dimension.
///
/// Returns `(slopes, intercepts)` so that dimension `d` is modelled as
/// `slopes[d] * t + intercepts[d]`. The number of dimensions is taken from
/// the first sample.
///
/// When the timestamps have (numerically) zero variance no slope can be
/// estimated: every slope is `0.0` and the intercepts are the values of the
/// last sample.
///
/// Indexes `trajectory[0]` and each vector up to the first vector's length,
/// so it panics on an empty trajectory or on a vector that is too short.
fn fit_linear_per_dim(trajectory: &[(i64, &[f32])]) -> (Vec<f64>, Vec<f64>) {
    let count = trajectory.len() as f64;
    let dim = trajectory[0].1.len();

    let times: Vec<f64> = trajectory.iter().map(|&(t, _)| t as f64).collect();
    let mean_t = times.iter().sum::<f64>() / count;
    let spread_t: f64 = times.iter().map(|t| (t - mean_t) * (t - mean_t)).sum();

    if spread_t < 1e-15 {
        // Degenerate time axis: hold the most recent observation constant.
        let (_, last) = trajectory[trajectory.len() - 1];
        let intercepts: Vec<f64> = last[..dim].iter().map(|&y| y as f64).collect();
        return (vec![0.0f64; dim], intercepts);
    }

    let mut slopes = Vec::with_capacity(dim);
    let mut intercepts = Vec::with_capacity(dim);

    for d in 0..dim {
        let mean_y = trajectory.iter().map(|(_, v)| v[d] as f64).sum::<f64>() / count;

        // Sum of cross-deviations between time and this dimension's values.
        let cross: f64 = times
            .iter()
            .zip(trajectory)
            .map(|(t, (_, v))| (t - mean_t) * (v[d] as f64 - mean_y))
            .sum();

        let slope = cross / spread_t;
        slopes.push(slope);
        intercepts.push(mean_y - slope * mean_t);
    }

    (slopes, intercepts)
}
174
/// Trapezoidal rule integration of a (timestamp, value) curve.
///
/// Sums, for each pair of consecutive points, the time step multiplied by the
/// mean of the two values. Curves with fewer than two points have no interval
/// to integrate over and yield `0.0`.
///
/// The result is in value units multiplied by timestamp units. Points are
/// taken in the order given; a decreasing timestamp contributes a negative
/// time step.
fn trapezoidal_integral(curve: &[(i64, f32)]) -> f64 {
    // `windows(2)` is empty for fewer than two points, so the fold returns 0.0.
    curve.windows(2).fold(0.0f64, |total, pair| {
        let (t0, v0) = pair[0];
        let (t1, v1) = pair[1];
        // Widen before subtracting: the difference of two i64 timestamps can
        // exceed the i64 range, which would panic in debug builds and wrap in
        // release builds.
        let dt = (i128::from(t1) - i128::from(t0)) as f64;
        let avg_val = (f64::from(v0) + f64::from(v1)) / 2.0;
        total + dt * avg_val
    })
}
190
191// ─── Tests ──────────────────────────────────────────────────────────
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows an owned trajectory as the `(timestamp, slice)` pairs the API takes.
    fn as_refs(points: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        let mut borrowed = Vec::with_capacity(points.len());
        for (ts, values) in points {
            borrowed.push((*ts, values.as_slice()));
        }
        borrowed
    }

    /// Builds one sample per step, spaced 1000 time units apart, whose vector
    /// is `value_at(step)`.
    fn trajectory_over(
        steps: std::ops::Range<i32>,
        value_at: impl Fn(i32) -> Vec<f32>,
    ) -> Vec<(i64, Vec<f32>)> {
        steps
            .map(|step| (i64::from(step) * 1000, value_at(step)))
            .collect()
    }

    // ─── fit_linear_per_dim ─────────────────────────────────────

    #[test]
    fn linear_fit_perfect_line() {
        // Samples lie exactly on y = 2*t + 1 (single dimension).
        let samples = vec![
            (0i64, vec![1.0f32]),
            (1, vec![3.0]),
            (2, vec![5.0]),
            (3, vec![7.0]),
        ];
        let (slopes, intercepts) = fit_linear_per_dim(&as_refs(&samples));

        let slope_err = (slopes[0] - 2.0).abs();
        assert!(slope_err < 1e-6, "slope should be 2.0, got {}", slopes[0]);

        let intercept_err = (intercepts[0] - 1.0).abs();
        assert!(
            intercept_err < 1e-6,
            "intercept should be 1.0, got {}",
            intercepts[0]
        );
    }

    #[test]
    fn linear_fit_multidim() {
        // Dimension 0 follows y = t; dimension 1 follows y = -t + 10.
        let samples = vec![
            (0i64, vec![0.0f32, 10.0]),
            (5, vec![5.0, 5.0]),
            (10, vec![10.0, 0.0]),
        ];
        let (slopes, intercepts) = fit_linear_per_dim(&as_refs(&samples));

        assert!((slopes[0] - 1.0).abs() < 1e-6);
        assert!((slopes[1] - (-1.0)).abs() < 1e-6);
        assert!(intercepts[0].abs() < 1e-6);
        assert!((intercepts[1] - 10.0).abs() < 1e-6);
    }

    // ─── trapezoidal_integral ───────────────────────────────────

    #[test]
    fn integral_constant() {
        // Rectangle of height 5 over a span of 100.
        let area = trapezoidal_integral(&[(0i64, 5.0f32), (100, 5.0)]);
        assert!((area - 500.0).abs() < 1e-6);
    }

    #[test]
    fn integral_triangle() {
        // Right triangle rising from (0, 0) to (100, 10): 0.5 * 100 * 10 = 500.
        let area = trapezoidal_integral(&[(0i64, 0.0f32), (100, 10.0)]);
        assert!((area - 500.0).abs() < 1e-6);
    }

    #[test]
    fn integral_single_point() {
        let lone_point = [(0, 5.0)];
        assert_eq!(trapezoidal_integral(&lone_point), 0.0);
    }

    // ─── counterfactual_trajectory ──────────────────────────────

    #[test]
    fn counterfactual_insufficient_pre() {
        // A single pre-change sample is not enough to fit a trend.
        let before = vec![(100i64, vec![1.0f32])];
        let after = vec![(200i64, vec![2.0f32])];

        let outcome = counterfactual_trajectory(&as_refs(&before), &as_refs(&after), 150);
        assert!(outcome.is_err());
    }

    #[test]
    fn counterfactual_empty_post() {
        let before = vec![(100i64, vec![1.0f32]), (200, vec![2.0])];
        let after: Vec<(i64, &[f32])> = Vec::new();

        let outcome = counterfactual_trajectory(&as_refs(&before), &after, 250);
        assert!(outcome.is_err());
    }

    #[test]
    fn counterfactual_linear_continuation() {
        // Before the change the value grows by 0.1 per step; afterwards the
        // observed value jumps to 100. The counterfactual must keep following
        // the original trend.
        let before = trajectory_over(0..10, |step| vec![step as f32 * 0.1]);
        let after = trajectory_over(10..15, |_| vec![100.0]); // sudden jump

        let result =
            counterfactual_trajectory(&as_refs(&before), &as_refs(&after), 10000).unwrap();

        assert_eq!(result.change_point, 10000);
        assert_eq!(result.actual.len(), 5);
        assert_eq!(result.counterfactual.len(), 5);
        assert_eq!(result.divergence_curve.len(), 5);

        // At t = 10000 the trend (slope 0.0001 per time unit) predicts ~1.0.
        let cf_at_cp = result.counterfactual[0].1[0];
        assert!(
            (cf_at_cp - 1.0).abs() < 0.1,
            "counterfactual at change point should be ~1.0, got {cf_at_cp}"
        );

        // The observed value is 100.0, so the gap must be large.
        assert!(
            result.max_divergence_value > 90.0,
            "divergence should be large, got {}",
            result.max_divergence_value
        );

        assert!(result.total_divergence > 0.0);
    }

    #[test]
    fn counterfactual_no_change() {
        // Both segments follow one and the same line, so divergence stays ~0.
        let on_trend = |step: i32| vec![step as f32 * 0.1];
        let before = trajectory_over(0..10, on_trend);
        let after = trajectory_over(10..15, on_trend);

        let result =
            counterfactual_trajectory(&as_refs(&before), &as_refs(&after), 10000).unwrap();

        assert!(
            result.max_divergence_value < 0.1,
            "no change should have ~0 divergence, got {}",
            result.max_divergence_value
        );
    }

    #[test]
    fn counterfactual_multidim() {
        let before = trajectory_over(0..10, |step| vec![step as f32, -(step as f32)]);
        // Both dimensions change abruptly.
        let after = trajectory_over(10..15, |_| vec![50.0, 50.0]);

        let result =
            counterfactual_trajectory(&as_refs(&before), &as_refs(&after), 10000).unwrap();

        assert_eq!(result.counterfactual[0].1.len(), 2);
        assert!(result.max_divergence_value > 10.0);
    }

    #[test]
    fn counterfactual_method_is_linear() {
        let before = vec![(0i64, vec![0.0f32]), (1000, vec![1.0])];
        let after = vec![(2000i64, vec![10.0f32])];

        let result =
            counterfactual_trajectory(&as_refs(&before), &as_refs(&after), 1500).unwrap();
        assert_eq!(result.method, CounterfactualMethod::LinearExtrapolation);
    }
}