// cvx_analytics/granger.rs

//! Granger causality testing for embedding trajectories.
//!
//! Tests whether entity A's movements in embedding space **precede** entity B's.
//! Uses a VAR(L) model on dimensionality-reduced trajectories.
//!
//! # Algorithm
//!
//! 1. Align trajectories to a common time grid (linear interpolation)
//! 2. For each dimension d:
//!    - Fit restricted model: `B_d(t) = Σ β_l · B_d(t-l) + ε`
//!    - Fit unrestricted model: `B_d(t) = Σ β_l · B_d(t-l) + Σ γ_l · A_d(t-l) + ε`
//!    - F-test: does the unrestricted model significantly improve?
//! 3. Combine per-dimension p-values via Fisher's method
//!
//! # References
//!
//! - Granger, C.W.J. (1969). Investigating causal relations. *Econometrica*, 37(3).
//! - Fisher, R.A. (1925). Statistical methods for research workers.

20use cvx_core::error::AnalyticsError;
21
22// ─── Types ──────────────────────────────────────────────────────────
23
/// Direction of Granger causality detected between two trajectories.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrangerDirection {
    /// A Granger-causes B: A's past significantly improves prediction of B.
    AToB,
    /// B Granger-causes A: B's past significantly improves prediction of A.
    BToA,
    /// Both directions significant at the chosen threshold.
    Bidirectional,
    /// No significant causality detected in either direction.
    None,
}
36
/// Result of a Granger causality test.
#[derive(Debug, Clone)]
pub struct GrangerResult {
    /// Detected causal direction.
    pub direction: GrangerDirection,
    /// Optimal lag (the one with lowest combined p-value for the winning direction).
    pub optimal_lag: usize,
    /// F-statistic for the winning direction at the optimal lag
    /// (for `Bidirectional`, the larger of the two directions' best F).
    pub f_statistic: f64,
    /// Combined p-value (Fisher's method) for the winning direction.
    pub p_value: f64,
    /// Partial R² improvement (effect size): proportional RSS reduction, in [0, 1].
    pub effect_size: f64,
    /// Per-dimension F-statistics for A→B, taken at the A→B direction's own
    /// best lag (which may differ from `optimal_lag` when B→A wins).
    pub per_dimension_a_to_b: Vec<f64>,
    /// Per-dimension F-statistics for B→A, taken at the B→A direction's own
    /// best lag.
    pub per_dimension_b_to_a: Vec<f64>,
}
55
56// ─── Alignment ──────────────────────────────────────────────────────
57
58/// Align two trajectories to a common time grid via linear interpolation.
59///
60/// Returns `(aligned_a, aligned_b)` with the same timestamps.
61fn align_trajectories(
62    traj_a: &[(i64, &[f32])],
63    traj_b: &[(i64, &[f32])],
64) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
65    // Use the union of timestamps, then interpolate both at those points
66    // For simplicity, use the timestamps of the trajectory with more points
67    // and interpolate the other
68    let (base, other, swap) = if traj_a.len() >= traj_b.len() {
69        (traj_a, traj_b, false)
70    } else {
71        (traj_b, traj_a, true)
72    };
73
74    let dim = base[0].1.len();
75    let timestamps: Vec<i64> = base.iter().map(|(t, _)| *t).collect();
76
77    let base_vecs: Vec<Vec<f32>> = base.iter().map(|(_, v)| v.to_vec()).collect();
78    let other_vecs: Vec<Vec<f32>> = timestamps
79        .iter()
80        .map(|&t| interpolate_at(other, t, dim))
81        .collect();
82
83    if swap {
84        (other_vecs, base_vecs)
85    } else {
86        (base_vecs, other_vecs)
87    }
88}
89
/// Linear interpolation of a trajectory at a specific timestamp.
///
/// Assumes `traj` is sorted ascending by timestamp (the public API documents
/// trajectories as sorted). Timestamps outside the covered range clamp to the
/// first/last point; an empty trajectory yields a zero vector of length `dim`.
fn interpolate_at(traj: &[(i64, &[f32])], t: i64, dim: usize) -> Vec<f32> {
    if traj.is_empty() {
        return vec![0.0; dim];
    }
    if traj.len() == 1 {
        return traj[0].1.to_vec();
    }

    // Before first point: clamp.
    if t <= traj[0].0 {
        return traj[0].1.to_vec();
    }
    // After last point: clamp.
    if t >= traj.last().unwrap().0 {
        return traj.last().unwrap().1.to_vec();
    }

    // Binary search for the first point at or after `t`. The slice is sorted
    // and this runs once per grid point during alignment, so the previous
    // linear `position()` scan made alignment O(n·m); this is O(n log m).
    let idx = traj.partition_point(|(ts, _)| *ts < t);

    if traj[idx].0 == t {
        return traj[idx].1.to_vec();
    }

    // Interior: blend the bracketing samples. idx >= 1 because t > traj[0].0,
    // and t1 > t0 because traj[idx - 1].0 < t <= traj[idx].0.
    let (t0, v0) = &traj[idx - 1];
    let (t1, v1) = &traj[idx];
    let alpha = (t - t0) as f64 / (t1 - t0) as f64;

    v0.iter()
        .zip(v1.iter())
        .map(|(&a, &b)| (a as f64 * (1.0 - alpha) + b as f64 * alpha) as f32)
        .collect()
}
127
128// ─── OLS regression ─────────────────────────────────────────────────
129
130/// Fit OLS for a single dimension and compute residual sum of squares.
131///
132/// Restricted model: `y_d(t) = Σ_l β_l · y_d(t-l) + ε`
133/// Unrestricted model: `y_d(t) = Σ_l β_l · y_d(t-l) + Σ_l γ_l · x_d(t-l) + ε`
134///
135/// Returns `(rss_restricted, rss_unrestricted, n_obs)`.
136fn ols_granger_single_dim(
137    y: &[f64], // target series (one dimension)
138    x: &[f64], // predictor series (one dimension)
139    lag: usize,
140) -> (f64, f64, usize) {
141    let n = y.len();
142    if n <= lag {
143        return (1.0, 1.0, 0);
144    }
145
146    let n_obs = n - lag;
147
148    // Build design matrices
149    // Restricted: [y(t-1), y(t-2), ..., y(t-lag)]
150    // Unrestricted: [y(t-1), ..., y(t-lag), x(t-1), ..., x(t-lag)]
151
152    let rss_r = fit_and_rss(y, &[y], lag, n_obs);
153    let rss_u = fit_and_rss(y, &[y, x], lag, n_obs);
154
155    (rss_r, rss_u, n_obs)
156}
157
/// Regress `y[lag..]` on the lagged values of each series in `predictors`
/// and return the residual sum of squares.
///
/// Design row for time t: `[p1(t-1)..p1(t-lag), p2(t-1)..p2(t-lag), ...]`.
/// Coefficients come from ridge-stabilized normal equations solved via
/// Cholesky (see `solve_ols`).
fn fit_and_rss(y: &[f64], predictors: &[&[f64]], lag: usize, n_obs: usize) -> f64 {
    let n_features = predictors.len() * lag;
    if n_features == 0 || n_obs == 0 {
        // Degenerate call: no regressors, so the "residual" is y itself.
        return y.iter().map(|v| v * v).sum();
    }

    // Assemble the design matrix (n_obs × n_features) and target vector.
    let mut design: Vec<Vec<f64>> = Vec::with_capacity(n_obs);
    let mut targets: Vec<f64> = Vec::with_capacity(n_obs);
    for t in lag..(lag + n_obs) {
        let row: Vec<f64> = predictors
            .iter()
            .flat_map(|pred| (1..=lag).map(move |l| pred[t - l]))
            .collect();
        design.push(row);
        targets.push(y[t]);
    }

    let beta = solve_ols(&design, &targets, n_features);

    // RSS of the fitted model.
    design
        .iter()
        .zip(targets.iter())
        .map(|(row, &yi)| {
            let fitted: f64 = row.iter().zip(beta.iter()).map(|(xi, bi)| xi * bi).sum();
            let resid = yi - fitted;
            resid * resid
        })
        .sum()
}

/// Ordinary least squares via ridge-stabilized normal equations.
///
/// Computes `β = (XᵀX + λI)⁻¹ Xᵀy` with a tiny λ so the Gram matrix is
/// strictly positive definite and the Cholesky solve cannot hit a zero pivot.
fn solve_ols(x: &[Vec<f64>], y: &[f64], p: usize) -> Vec<f64> {
    if x.is_empty() || p == 0 {
        return vec![0.0; p];
    }

    // Gram matrix XᵀX (p × p), accumulated row by row.
    let mut gram = vec![vec![0.0f64; p]; p];
    for row in x {
        for (i, gram_row) in gram.iter_mut().enumerate() {
            for (j, cell) in gram_row.iter_mut().enumerate() {
                *cell += row[i] * row[j];
            }
        }
    }

    // Tiny ridge term for numerical stability.
    const LAMBDA: f64 = 1e-8;
    for (i, gram_row) in gram.iter_mut().enumerate() {
        gram_row[i] += LAMBDA;
    }

    // Right-hand side Xᵀy (p × 1).
    let mut rhs = vec![0.0f64; p];
    for (row, &yi) in x.iter().zip(y.iter()) {
        for (cell, &xi) in rhs.iter_mut().zip(row.iter()) {
            *cell += xi * yi;
        }
    }

    // XᵀX + λI is positive definite, so Cholesky applies.
    cholesky_solve(&gram, &rhs)
}

/// Solve `A x = b` for symmetric positive-definite `A` via Cholesky (A = L·Lᵀ).
///
/// Near-singular pivots are clamped instead of panicking, so callers get a
/// best-effort solution even for ill-conditioned systems.
fn cholesky_solve(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = a.len();

    // Decomposition: fill the lower triangle of L column by column.
    let mut l = vec![vec![0.0f64; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let dot: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            l[i][j] = if i == j {
                let pivot = a[i][i] - dot;
                // Clamp non-positive pivots (numerically semi-definite input).
                if pivot > 0.0 {
                    pivot.sqrt()
                } else {
                    1e-12
                }
            } else if l[j][j].abs() > 1e-15 {
                (a[i][j] - dot) / l[j][j]
            } else {
                0.0
            };
        }
    }

    // Forward substitution: L z = b.
    let mut z = vec![0.0f64; n];
    for i in 0..n {
        let partial: f64 = (0..i).map(|j| l[i][j] * z[j]).sum();
        z[i] = if l[i][i].abs() > 1e-15 {
            (b[i] - partial) / l[i][i]
        } else {
            0.0
        };
    }

    // Back substitution: Lᵀ x = z.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let partial: f64 = ((i + 1)..n).map(|j| l[j][i] * x[j]).sum();
        x[i] = if l[i][i].abs() > 1e-15 {
            (z[i] - partial) / l[i][i]
        } else {
            0.0
        };
    }

    x
}
291
292// ─── F-test & Fisher's method ───────────────────────────────────────
293
/// F-statistic for a nested-model comparison.
///
/// `F = ((RSS_r - RSS_u) / q) / (RSS_u / (n - p_u))`, where `q` is the number
/// of extra parameters in the unrestricted model and `p_u` its total
/// parameter count. Degenerate inputs (no residual variance, no spare
/// degrees of freedom, no restrictions) yield 0 rather than NaN/∞, and a
/// worse unrestricted fit clamps to 0.
fn f_statistic(rss_r: f64, rss_u: f64, q: usize, n: usize, p_u: usize) -> f64 {
    if rss_u <= 0.0 || n <= p_u || q == 0 {
        return 0.0;
    }

    let improvement_per_restriction = (rss_r - rss_u) / q as f64;
    let residual_variance = rss_u / (n - p_u) as f64;

    if residual_variance <= 0.0 {
        return 0.0;
    }
    (improvement_per_restriction / residual_variance).max(0.0)
}
310
/// Upper-tail probability of the F(df1, df2) distribution at `f`.
///
/// Uses the identity `P(F > f) = I_x(df2/2, df1/2)` with
/// `x = df2 / (df2 + df1·f)`, where `I_x` is the regularized incomplete beta
/// function. Degenerate inputs (`f <= 0` or zero degrees of freedom) return
/// 1.0 (no evidence against the restricted model).
fn f_to_p(f: f64, df1: usize, df2: usize) -> f64 {
    if f <= 0.0 || df1 == 0 || df2 == 0 {
        return 1.0;
    }

    let x = df2 as f64 / (df2 as f64 + df1 as f64 * f);
    regularized_incomplete_beta(x, df2 as f64 / 2.0, df1 as f64 / 2.0)
}

/// Regularized incomplete beta function `I_x(a, b)` via Lentz's continued
/// fraction (Numerical Recipes §6.4).
///
/// Closed form: `I_x(a, b) = x^a (1-x)^b / (a·B(a,b)) · CF(x, a, b)` for
/// `x <= (a+1)/(a+b+2)`; the symmetry `I_x(a,b) = 1 - I_{1-x}(b,a)` is used
/// otherwise so the continued fraction always converges quickly.
fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    // Reflect into the rapidly-converging region.
    if x > (a + 1.0) / (a + b + 2.0) {
        return 1.0 - regularized_incomplete_beta(1.0 - x, b, a);
    }

    // Prefix x^a (1-x)^b / B(a,b), computed in log space to avoid
    // overflow/underflow. BUGFIX: the closed form's 1/a factor is applied
    // exactly once, at the final `prefix * cf / a` — the previous version
    // also subtracted ln(a) here, dividing by `a` twice and deflating every
    // p-value by a factor of 1/a (it could even produce p > 1 for a < 1).
    let ln_prefix = a * x.ln() + b * (1.0 - x).ln() - ln_beta(a, b);
    let prefix = ln_prefix.exp();

    // Modified Lentz continued-fraction evaluation.
    let mut c = 1.0;
    let mut d = 1.0 - (a + b) * x / (a + 1.0);
    if d.abs() < 1e-30 {
        d = 1e-30;
    }
    d = 1.0 / d;
    let mut cf = d;

    for m in 1..200 {
        let m_f = m as f64;

        // Even-numbered coefficient.
        let num_even = m_f * (b - m_f) * x / ((a + 2.0 * m_f - 1.0) * (a + 2.0 * m_f));
        d = 1.0 + num_even * d;
        if d.abs() < 1e-30 {
            d = 1e-30;
        }
        c = 1.0 + num_even / c;
        if c.abs() < 1e-30 {
            c = 1e-30;
        }
        d = 1.0 / d;
        cf *= d * c;

        // Odd-numbered coefficient.
        let num_odd = -(a + m_f) * (a + b + m_f) * x / ((a + 2.0 * m_f) * (a + 2.0 * m_f + 1.0));
        d = 1.0 + num_odd * d;
        if d.abs() < 1e-30 {
            d = 1e-30;
        }
        c = 1.0 + num_odd / c;
        if c.abs() < 1e-30 {
            c = 1e-30;
        }
        d = 1.0 / d;
        let delta = d * c;
        cf *= delta;

        if (delta - 1.0).abs() < 1e-10 {
            break;
        }
    }

    prefix * cf / a
}

/// Log of the Beta function: `ln B(a,b) = ln Γ(a) + ln Γ(b) - ln Γ(a+b)`.
fn ln_beta(a: f64, b: f64) -> f64 {
    ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b)
}

/// Lanczos approximation of `ln Γ(x)` (g = 7, nine coefficients) for x > 0.
///
/// Returns 0.0 for non-positive x (out of domain; unreachable from the
/// callers here, which pass half-integer degrees of freedom >= 0.5).
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let g = 7.0;
    let coefs = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    let xx = x - 1.0;
    let mut sum = coefs[0];
    for (i, &c) in coefs[1..].iter().enumerate() {
        sum += c / (xx + i as f64 + 1.0);
    }

    let t = xx + g + 0.5;
    0.5 * (2.0 * std::f64::consts::PI).ln() + (xx + 0.5) * t.ln() - t + sum.ln()
}
425
/// Fisher's method: combine independent p-values.
///
/// Computes `χ² = -2 Σ ln(p_i)` with `2k` degrees of freedom over the
/// p-values lying in `(0, 1]`; out-of-range entries are skipped. The χ² tail
/// probability is approximated with the Wilson–Hilferty cube-root normal
/// transformation. Returns 1.0 when no p-value is usable.
fn fisher_combine(p_values: &[f64]) -> f64 {
    // Accumulate χ² and count valid entries in one pass.
    let mut chi2 = 0.0f64;
    let mut k = 0usize;
    for &p in p_values {
        if p > 0.0 && p <= 1.0 {
            chi2 -= 2.0 * p.ln();
            k += 1;
        }
    }
    if k == 0 {
        return 1.0;
    }

    // Wilson–Hilferty: (χ²/df)^(1/3) is approximately N(1 - 2/(9·df), 2/(9·df)).
    let df = (2 * k) as f64;
    let mean = 1.0 - 2.0 / (9.0 * df);
    let sd = (2.0 / (9.0 * df)).sqrt();
    let z = ((chi2 / df).powf(1.0 / 3.0) - mean) / sd;

    // Upper-tail probability under the standard normal: 1 - Φ(z).
    0.5 * erfc(z / std::f64::consts::SQRT_2)
}

/// Complementary error function via Abramowitz & Stegun 7.1.26
/// (absolute error ≲ 1.5e-7).
fn erfc(x: f64) -> f64 {
    let t = 1.0 / (1.0 + 0.327_591_1 * x.abs());
    let horner = t
        * (0.254_829_592
            + t * (-0.284_496_736
                + t * (1.421_413_741 + t * (-1.453_152_027 + t * 1.061_405_429))));
    let tail = horner * (-x * x).exp();
    // The approximation covers x >= 0; use erfc(-x) = 2 - erfc(x) otherwise.
    if x < 0.0 { 2.0 - tail } else { tail }
}
462
463// ─── Core function ──────────────────────────────────────────────────
464
465/// Test Granger causality between two embedding trajectories.
466///
467/// # Arguments
468///
469/// * `traj_a` — Entity A's trajectory (sorted by timestamp)
470/// * `traj_b` — Entity B's trajectory (sorted by timestamp)
471/// * `max_lag` — Maximum lag to test (number of time steps)
472/// * `significance` — P-value threshold for significance (e.g., 0.05)
473///
474/// # Errors
475///
476/// Returns [`AnalyticsError::InsufficientData`] if trajectories are too short
477/// for the requested lag.
478pub fn granger_causality(
479    traj_a: &[(i64, &[f32])],
480    traj_b: &[(i64, &[f32])],
481    max_lag: usize,
482    significance: f64,
483) -> Result<GrangerResult, AnalyticsError> {
484    if traj_a.len() < 3 || traj_b.len() < 3 {
485        return Err(AnalyticsError::InsufficientData {
486            needed: 3,
487            have: traj_a.len().min(traj_b.len()),
488        });
489    }
490
491    let (aligned_a, aligned_b) = align_trajectories(traj_a, traj_b);
492    let n = aligned_a.len();
493    let dim = aligned_a[0].len();
494
495    if n < max_lag + 3 {
496        return Err(AnalyticsError::InsufficientData {
497            needed: max_lag + 3,
498            have: n,
499        });
500    }
501
502    // Convert to f64 per-dimension series
503    let a_dims: Vec<Vec<f64>> = (0..dim)
504        .map(|d| aligned_a.iter().map(|v| v[d] as f64).collect())
505        .collect();
506    let b_dims: Vec<Vec<f64>> = (0..dim)
507        .map(|d| aligned_b.iter().map(|v| v[d] as f64).collect())
508        .collect();
509
510    // Test each lag, keep the best
511    let mut best_a2b_p = 1.0f64;
512    let mut best_a2b_lag = 1;
513    let mut best_a2b_f = 0.0;
514    let mut best_a2b_effect = 0.0;
515    let mut best_a2b_per_dim = vec![0.0; dim];
516
517    let mut best_b2a_p = 1.0f64;
518    let mut best_b2a_lag = 1;
519    let mut best_b2a_f = 0.0;
520    let mut best_b2a_effect = 0.0;
521    let mut best_b2a_per_dim = vec![0.0; dim];
522
523    for lag in 1..=max_lag {
524        // A → B: does A's past improve prediction of B?
525        let (p_a2b, f_a2b, effect_a2b, per_dim_a2b) = test_direction(&a_dims, &b_dims, lag, n);
526
527        if p_a2b < best_a2b_p {
528            best_a2b_p = p_a2b;
529            best_a2b_lag = lag;
530            best_a2b_f = f_a2b;
531            best_a2b_effect = effect_a2b;
532            best_a2b_per_dim = per_dim_a2b;
533        }
534
535        // B → A: does B's past improve prediction of A?
536        let (p_b2a, f_b2a, effect_b2a, per_dim_b2a) = test_direction(&b_dims, &a_dims, lag, n);
537
538        if p_b2a < best_b2a_p {
539            best_b2a_p = p_b2a;
540            best_b2a_lag = lag;
541            best_b2a_f = f_b2a;
542            best_b2a_effect = effect_b2a;
543            best_b2a_per_dim = per_dim_b2a;
544        }
545    }
546
547    let a2b_sig = best_a2b_p < significance;
548    let b2a_sig = best_b2a_p < significance;
549
550    let (direction, optimal_lag, f_stat, p_val, effect) = match (a2b_sig, b2a_sig) {
551        (true, true) => (
552            GrangerDirection::Bidirectional,
553            if best_a2b_p < best_b2a_p {
554                best_a2b_lag
555            } else {
556                best_b2a_lag
557            },
558            best_a2b_f.max(best_b2a_f),
559            best_a2b_p.min(best_b2a_p),
560            best_a2b_effect.max(best_b2a_effect),
561        ),
562        (true, false) => (
563            GrangerDirection::AToB,
564            best_a2b_lag,
565            best_a2b_f,
566            best_a2b_p,
567            best_a2b_effect,
568        ),
569        (false, true) => (
570            GrangerDirection::BToA,
571            best_b2a_lag,
572            best_b2a_f,
573            best_b2a_p,
574            best_b2a_effect,
575        ),
576        (false, false) => (GrangerDirection::None, 1, 0.0, 1.0, 0.0),
577    };
578
579    Ok(GrangerResult {
580        direction,
581        optimal_lag,
582        f_statistic: f_stat,
583        p_value: p_val,
584        effect_size: effect,
585        per_dimension_a_to_b: best_a2b_per_dim,
586        per_dimension_b_to_a: best_b2a_per_dim,
587    })
588}
589
590/// Test one direction: does `cause` Granger-cause `effect`?
591///
592/// Returns `(combined_p, mean_f, effect_size, per_dim_f)`.
593fn test_direction(
594    cause_dims: &[Vec<f64>],
595    effect_dims: &[Vec<f64>],
596    lag: usize,
597    _n: usize,
598) -> (f64, f64, f64, Vec<f64>) {
599    let dim = effect_dims.len();
600    let mut p_values = Vec::with_capacity(dim);
601    let mut f_values = Vec::with_capacity(dim);
602    let mut total_rss_r = 0.0;
603    let mut total_rss_u = 0.0;
604
605    for d in 0..dim {
606        let (rss_r, rss_u, n_obs) = ols_granger_single_dim(&effect_dims[d], &cause_dims[d], lag);
607
608        let q = lag; // extra parameters
609        let p_u = 2 * lag; // total params in unrestricted model
610        let f = f_statistic(rss_r, rss_u, q, n_obs, p_u);
611        let df2 = if n_obs > p_u { n_obs - p_u } else { 1 };
612        let p = f_to_p(f, q, df2);
613
614        f_values.push(f);
615        p_values.push(p);
616        total_rss_r += rss_r;
617        total_rss_u += rss_u;
618    }
619
620    let combined_p = fisher_combine(&p_values);
621    let mean_f = if f_values.is_empty() {
622        0.0
623    } else {
624        f_values.iter().sum::<f64>() / f_values.len() as f64
625    };
626
627    // Effect size: proportional reduction in RSS
628    let effect = if total_rss_r > 0.0 {
629        ((total_rss_r - total_rss_u) / total_rss_r).max(0.0)
630    } else {
631        0.0
632    };
633
634    (combined_p, mean_f, effect, f_values)
635}
636
637// ─── Tests ──────────────────────────────────────────────────────────
638
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow an owned fixture as the `(timestamp, &[f32])` slice form the
    /// public API takes.
    fn as_refs(points: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        points.iter().map(|(t, v)| (*t, v.as_slice())).collect()
    }

    // ─── interpolation ──────────────────────────────────────────

    #[test]
    fn interpolate_at_exact_point() {
        let owned = vec![(100i64, vec![1.0f32]), (200, vec![2.0]), (300, vec![3.0])];
        let traj = as_refs(&owned);
        let v = interpolate_at(&traj, 200, 1);
        assert!((v[0] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn interpolate_at_midpoint() {
        let owned = vec![(100i64, vec![0.0f32]), (200, vec![10.0])];
        let traj = as_refs(&owned);
        let v = interpolate_at(&traj, 150, 1);
        assert!((v[0] - 5.0).abs() < 1e-4);
    }

    #[test]
    fn interpolate_at_boundary() {
        // Timestamps outside the covered range clamp to the nearest endpoint.
        let owned = vec![(100i64, vec![5.0f32]), (200, vec![10.0])];
        let traj = as_refs(&owned);
        assert!((interpolate_at(&traj, 50, 1)[0] - 5.0).abs() < 1e-6);
        assert!((interpolate_at(&traj, 300, 1)[0] - 10.0).abs() < 1e-6);
    }

    // ─── cholesky_solve ─────────────────────────────────────────

    #[test]
    fn cholesky_simple() {
        // Solve [[4,2],[2,3]] x = [1, 2]
        let a = vec![vec![4.0, 2.0], vec![2.0, 3.0]];
        let b = vec![1.0, 2.0];
        let x = cholesky_solve(&a, &b);
        // Expected: x = [-0.125, 0.75]
        assert!((x[0] - (-0.125)).abs() < 1e-6, "got {}", x[0]);
        assert!((x[1] - 0.75).abs() < 1e-6, "got {}", x[1]);
    }

    // ─── f_to_p ─────────────────────────────────────────────────

    #[test]
    fn f_to_p_zero_f() {
        // F = 0 carries no evidence against the restricted model.
        assert!((f_to_p(0.0, 5, 20) - 1.0).abs() < 0.01);
    }

    #[test]
    fn f_to_p_large_f() {
        let p = f_to_p(100.0, 5, 50);
        assert!(p < 0.001, "very large F should give very small p, got {p}");
    }

    // ─── fisher_combine ─────────────────────────────────────────

    #[test]
    fn fisher_all_significant() {
        let p = fisher_combine(&[0.01, 0.02, 0.01]);
        assert!(
            p < 0.05,
            "combined very significant p-values should be significant, got {p}"
        );
    }

    #[test]
    fn fisher_all_nonsignificant() {
        let p = fisher_combine(&[0.8, 0.9, 0.7]);
        assert!(
            p > 0.3,
            "combined non-significant should remain non-significant, got {p}"
        );
    }

    // ─── granger_causality ──────────────────────────────────────

    #[test]
    fn granger_insufficient_data() {
        // Two points per trajectory is below the minimum of 3.
        let a_owned = vec![(0i64, vec![1.0f32]), (1, vec![2.0])];
        let b_owned = vec![(0i64, vec![3.0f32]), (1, vec![4.0])];
        let a = as_refs(&a_owned);
        let b = as_refs(&b_owned);
        let result = granger_causality(&a, &b, 3, 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn granger_synthetic_a_causes_b() {
        // A is a sine wave, B is A shifted by 2 steps (A leads B)
        let n = 100;
        let lag = 2;
        let a_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| {
                let t = i as f64 * 0.2;
                (i as i64 * 1000, vec![t.sin() as f32])
            })
            .collect();

        let b_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| {
                // Lagged copy of A plus a small drift term to break exact
                // collinearity.
                let t = (i as i64 - lag as i64).max(0) as f64 * 0.2;
                (i as i64 * 1000, vec![t.sin() as f32 + 0.01 * (i as f32)])
            })
            .collect();

        let a = as_refs(&a_owned);
        let b = as_refs(&b_owned);

        let result = granger_causality(&a, &b, 5, 0.05).unwrap();

        // We expect A→B direction or at least that A→B has some signal
        assert!(
            result.per_dimension_a_to_b[0] > 0.0,
            "A should have some predictive power for B"
        );
    }

    #[test]
    fn granger_independent_series() {
        // Two completely independent random-ish series
        let n = 80;
        let a_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| {
                let v = ((i as f64 * 1.7).sin() * 100.0) as f32;
                (i as i64 * 1000, vec![v])
            })
            .collect();
        let b_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| {
                let v = ((i as f64 * 3.1 + 42.0).cos() * 100.0) as f32;
                (i as i64 * 1000, vec![v])
            })
            .collect();

        let a = as_refs(&a_owned);
        let b = as_refs(&b_owned);

        let result = granger_causality(&a, &b, 3, 0.05).unwrap();

        // Independent series should ideally show no causality
        // But with deterministic pseudo-random, there could be spurious correlation
        // Just verify the function runs without error and returns valid values
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
        assert!(result.f_statistic >= 0.0);
    }

    #[test]
    fn granger_multidimensional() {
        // 3D trajectory where A leads B in all dims
        let n = 60;
        let a_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| {
                let t = i as f64 * 0.15;
                (
                    i as i64 * 1000,
                    vec![t.sin() as f32, t.cos() as f32, (t * 0.5).sin() as f32],
                )
            })
            .collect();
        let b_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| {
                let t = (i as i64 - 3).max(0) as f64 * 0.15;
                (
                    i as i64 * 1000,
                    vec![t.sin() as f32, t.cos() as f32, (t * 0.5).sin() as f32],
                )
            })
            .collect();

        let a = as_refs(&a_owned);
        let b = as_refs(&b_owned);

        let result = granger_causality(&a, &b, 5, 0.1).unwrap();

        // One F-statistic per embedding dimension, in both directions.
        assert_eq!(result.per_dimension_a_to_b.len(), 3);
        assert_eq!(result.per_dimension_b_to_a.len(), 3);
    }

    #[test]
    fn granger_result_has_valid_fields() {
        let n = 50;
        let a_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| (i as i64 * 1000, vec![i as f32 * 0.1]))
            .collect();
        let b_owned: Vec<(i64, Vec<f32>)> = (0..n)
            .map(|i| (i as i64 * 1000, vec![i as f32 * 0.2 + 1.0]))
            .collect();

        let a = as_refs(&a_owned);
        let b = as_refs(&b_owned);

        let result = granger_causality(&a, &b, 3, 0.05).unwrap();

        // Invariants that must hold regardless of the detected direction.
        assert!(result.optimal_lag >= 1 && result.optimal_lag <= 3);
        assert!(result.f_statistic >= 0.0);
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
        assert!(result.effect_size >= 0.0 && result.effect_size <= 1.0);
    }
}