cvx_analytics/
ode.rs

//! ODE solver and prediction engine.
//!
//! ## RK45 (Dormand-Prince) Adaptive Solver
//!
//! An adaptive Runge-Kutta solver for systems of ODEs that pairs a 5th-order
//! solution with an embedded 4th-order one to estimate the local error and
//! adapt the step size. Used to integrate learned dynamics
//! $\frac{dy}{dt} = f(t, y)$.
//!
//! ## Linear Extrapolation Fallback
//!
//! When no Neural ODE model is available, linear extrapolation from the last
//! two observations serves as a simple baseline predictor.
12
13use cvx_core::error::AnalyticsError;
14
/// Right-hand side `f` of a first-order ODE system `dy/dt = f(t, y)`.
///
/// Implementors describe the vector field only; solvers such as
/// [`rk45_integrate`] call [`OdeSystem::derivative`] repeatedly at
/// intermediate times and trial states.
pub trait OdeSystem {
    /// Returns `dy/dt` evaluated at time `t` and state `y`.
    ///
    /// Solvers index the result with the indices of `y`, so it should have
    /// the same length as `y`.
    fn derivative(&self, t: f64, y: &[f64]) -> Vec<f64>;
}
22
/// Tuning knobs for [`rk45_integrate`].
///
/// Step sizes are magnitudes; the integrator applies the direction of
/// integration (forward or backward in time) itself.
#[derive(Debug, Clone)]
pub struct Rk45Config {
    /// Relative error tolerance, multiplied by the larger of `|y|` before and
    /// after a step when scaling the error estimate.
    pub rtol: f64,
    /// Absolute error tolerance, added to the relative term of the error scale.
    pub atol: f64,
    /// Size of the first attempted step (capped at `h_max`).
    pub h_init: f64,
    /// Lower bound applied to adapted step sizes.
    pub h_min: f64,
    /// Upper bound applied to adapted step sizes.
    pub h_max: f64,
    /// Limit on accepted steps before the integration is abandoned.
    pub max_steps: usize,
}
39
40impl Default for Rk45Config {
41    fn default() -> Self {
42        Self {
43            rtol: 1e-6,
44            atol: 1e-9,
45            h_init: 0.01,
46            h_min: 1e-12,
47            h_max: 1.0,
48            max_steps: 100_000,
49        }
50    }
51}
52
/// Outcome of a successful [`rk45_integrate`] call.
#[derive(Debug, Clone)]
pub struct OdeResult {
    /// Time at which integration stopped.
    pub t: f64,
    /// State vector at time `t`.
    pub y: Vec<f64>,
    /// Count of accepted steps; rejected trial steps are not included.
    pub steps: usize,
}
63
64/// Integrate an ODE system from `(t0, y0)` to `t_end` using Dormand-Prince RK45.
65///
66/// The Dormand-Prince method uses 7 function evaluations per step with an
67/// embedded error estimate for adaptive step-size control.
68pub fn rk45_integrate(
69    system: &dyn OdeSystem,
70    t0: f64,
71    y0: &[f64],
72    t_end: f64,
73    config: &Rk45Config,
74) -> Result<OdeResult, AnalyticsError> {
75    let dim = y0.len();
76    let mut t = t0;
77    let mut y = y0.to_vec();
78    let mut h = config.h_init.min(config.h_max);
79    let direction = if t_end >= t0 { 1.0 } else { -1.0 };
80    h *= direction;
81
82    let mut steps = 0;
83
84    // Dormand-Prince coefficients
85    let a2 = 1.0 / 5.0;
86    let a3 = 3.0 / 10.0;
87    let a4 = 4.0 / 5.0;
88    let a5 = 8.0 / 9.0;
89
90    let b21 = 1.0 / 5.0;
91    let b31 = 3.0 / 40.0;
92    let b32 = 9.0 / 40.0;
93    let b41 = 44.0 / 45.0;
94    let b42 = -56.0 / 15.0;
95    let b43 = 32.0 / 9.0;
96    let b51 = 19372.0 / 6561.0;
97    let b52 = -25360.0 / 2187.0;
98    let b53 = 64448.0 / 6561.0;
99    let b54 = -212.0 / 729.0;
100    let b61 = 9017.0 / 3168.0;
101    let b62 = -355.0 / 33.0;
102    let b63 = 46732.0 / 5247.0;
103    let b64 = 49.0 / 176.0;
104    let b65 = -5103.0 / 18656.0;
105
106    // 5th order weights
107    let c1 = 35.0 / 384.0;
108    let c3 = 500.0 / 1113.0;
109    let c4 = 125.0 / 192.0;
110    let c5 = -2187.0 / 6784.0;
111    let c6 = 11.0 / 84.0;
112
113    // 4th order weights (for error estimation)
114    let e1 = 71.0 / 57600.0;
115    let e3 = -71.0 / 16695.0;
116    let e4 = 71.0 / 1920.0;
117    let e5 = -17253.0 / 339200.0;
118    let e6 = 22.0 / 525.0;
119    let e7 = -1.0 / 40.0;
120
121    while (t_end - t) * direction > 1e-15 {
122        if steps >= config.max_steps {
123            return Err(AnalyticsError::SolverDiverged { step: steps });
124        }
125
126        // Clamp step to not overshoot
127        if (t + h - t_end) * direction > 0.0 {
128            h = t_end - t;
129        }
130
131        // k1 = f(t, y)
132        let k1 = system.derivative(t, &y);
133
134        // k2
135        let y2: Vec<f64> = (0..dim).map(|i| y[i] + h * b21 * k1[i]).collect();
136        let k2 = system.derivative(t + a2 * h, &y2);
137
138        // k3
139        let y3: Vec<f64> = (0..dim)
140            .map(|i| y[i] + h * (b31 * k1[i] + b32 * k2[i]))
141            .collect();
142        let k3 = system.derivative(t + a3 * h, &y3);
143
144        // k4
145        let y4: Vec<f64> = (0..dim)
146            .map(|i| y[i] + h * (b41 * k1[i] + b42 * k2[i] + b43 * k3[i]))
147            .collect();
148        let k4 = system.derivative(t + a4 * h, &y4);
149
150        // k5
151        let y5: Vec<f64> = (0..dim)
152            .map(|i| y[i] + h * (b51 * k1[i] + b52 * k2[i] + b53 * k3[i] + b54 * k4[i]))
153            .collect();
154        let k5 = system.derivative(t + a5 * h, &y5);
155
156        // k6
157        let y6: Vec<f64> = (0..dim)
158            .map(|i| {
159                y[i] + h * (b61 * k1[i] + b62 * k2[i] + b63 * k3[i] + b64 * k4[i] + b65 * k5[i])
160            })
161            .collect();
162        let k6 = system.derivative(t + h, &y6);
163
164        // 5th order solution
165        let y_new: Vec<f64> = (0..dim)
166            .map(|i| y[i] + h * (c1 * k1[i] + c3 * k3[i] + c4 * k4[i] + c5 * k5[i] + c6 * k6[i]))
167            .collect();
168
169        // k7 (for error estimate)
170        let k7 = system.derivative(t + h, &y_new);
171
172        // Error estimate
173        let mut err = 0.0;
174        for i in 0..dim {
175            let ei =
176                h * (e1 * k1[i] + e3 * k3[i] + e4 * k4[i] + e5 * k5[i] + e6 * k6[i] + e7 * k7[i]);
177            let scale = config.atol + config.rtol * y_new[i].abs().max(y[i].abs());
178            err += (ei / scale) * (ei / scale);
179        }
180        err = (err / dim as f64).sqrt();
181
182        if err <= 1.0 {
183            // Accept step
184            t += h;
185            y = y_new;
186            steps += 1;
187        }
188
189        // Adjust step size
190        let safety = 0.9;
191        let factor = if err > 0.0 {
192            safety * err.powf(-0.2)
193        } else {
194            5.0
195        };
196        h *= factor.clamp(0.2, 5.0);
197        h = h.abs().clamp(config.h_min, config.h_max) * direction;
198    }
199
200    Ok(OdeResult { t, y, steps })
201}
202
203// ─── Linear Extrapolation Fallback ──────────────────────────────────
204
205/// Predict a future vector using linear extrapolation.
206///
207/// Uses the last two observations to estimate velocity and extrapolate.
208pub fn linear_extrapolate(
209    trajectory: &[(i64, &[f32])],
210    target_timestamp: i64,
211) -> Result<Vec<f32>, AnalyticsError> {
212    if trajectory.len() < 2 {
213        return Err(AnalyticsError::InsufficientData {
214            needed: 2,
215            have: trajectory.len(),
216        });
217    }
218
219    let n = trajectory.len();
220    let (t1, v1) = &trajectory[n - 2];
221    let (t2, v2) = &trajectory[n - 1];
222
223    let dt = (*t2 - *t1) as f64;
224    if dt == 0.0 {
225        return Ok(v2.to_vec());
226    }
227
228    let dt_target = (target_timestamp - *t2) as f64;
229    let ratio = dt_target / dt;
230
231    let predicted: Vec<f32> = v1
232        .iter()
233        .zip(v2.iter())
234        .map(|(&a, &b)| {
235            let vel = b as f64 - a as f64;
236            (b as f64 + vel * ratio) as f32
237        })
238        .collect();
239
240    Ok(predicted)
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    // ─── Reference systems with closed-form solutions ───────────────

    /// dy/dt = -y, whose solution is y(t) = y0 * e^{-t}.
    struct ExponentialDecay;

    impl OdeSystem for ExponentialDecay {
        fn derivative(&self, _t: f64, state: &[f64]) -> Vec<f64> {
            vec![-state[0]]
        }
    }

    /// dx/dt = v, dv/dt = -x; starting from (1, 0) the position is x(t) = cos(t).
    struct HarmonicOscillator;

    impl OdeSystem for HarmonicOscillator {
        fn derivative(&self, _t: f64, state: &[f64]) -> Vec<f64> {
            let (position, velocity) = (state[0], state[1]);
            vec![velocity, -position]
        }
    }

    /// dy/dt = y, whose solution is y(t) = y0 * e^{t}.
    struct ExponentialGrowth;

    impl OdeSystem for ExponentialGrowth {
        fn derivative(&self, _t: f64, state: &[f64]) -> Vec<f64> {
            vec![state[0]]
        }
    }

    // ─── Shared helpers ─────────────────────────────────────────────

    /// Default solver settings with only the relative tolerance overridden.
    fn with_rtol(rtol: f64) -> Rk45Config {
        Rk45Config {
            rtol,
            ..Default::default()
        }
    }

    /// Borrows owned samples as the `(timestamp, &[f32])` pairs that
    /// `linear_extrapolate` expects.
    fn as_trajectory(samples: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        samples.iter().map(|(ts, v)| (*ts, v.as_slice())).collect()
    }

    // ─── RK45 tests ─────────────────────────────────────────────────

    #[test]
    fn exponential_decay_matches_analytical() {
        let t_end: f64 = 5.0;
        let analytical = t_end.mul_add(-1.0, 0.0).exp();

        let outcome =
            rk45_integrate(&ExponentialDecay, 0.0, &[1.0], t_end, &with_rtol(1e-8)).unwrap();
        let got = outcome.y[0];

        assert!(
            (got - analytical).abs() < 1e-5,
            "RK45: {}, analytical: {analytical}",
            got
        );
    }

    #[test]
    fn harmonic_oscillator_matches_analytical() {
        // x = 1, v = 0 → x(t) = cos(t); one full period returns the state to (1, 0).
        let period = 2.0 * std::f64::consts::PI;
        let outcome =
            rk45_integrate(&HarmonicOscillator, 0.0, &[1.0, 0.0], period, &with_rtol(1e-8))
                .unwrap();
        let (x, v) = (outcome.y[0], outcome.y[1]);

        assert!(
            (x - 1.0).abs() < 1e-5,
            "x after full period: {}, expected 1.0",
            x
        );
        assert!(v.abs() < 1e-5, "v after full period: {}, expected 0.0", v);
    }

    #[test]
    fn rk45_adaptive_step_count() {
        // A smooth decay over one time unit should need only a handful of steps.
        let taken = rk45_integrate(&ExponentialDecay, 0.0, &[1.0], 1.0, &with_rtol(1e-6))
            .unwrap()
            .steps;
        assert!(taken < 100, "should need few steps, got {}", taken);
    }

    #[test]
    fn rk45_backward_integration() {
        // Start from the exact value at t = 1 and integrate back to t = 0.
        let y_at_1 = (-1.0f64).exp();
        let outcome =
            rk45_integrate(&ExponentialDecay, 1.0, &[y_at_1], 0.0, &Rk45Config::default())
                .unwrap();
        let got = outcome.y[0];
        assert!((got - 1.0).abs() < 1e-4, "backward: {}, expected 1.0", got);
    }

    #[test]
    fn rk45_exponential_growth() {
        let settings = Rk45Config {
            rtol: 1e-6,
            h_max: 0.5,
            ..Default::default()
        };
        let analytical = 3.0f64.exp();

        let got = rk45_integrate(&ExponentialGrowth, 0.0, &[1.0], 3.0, &settings)
            .unwrap()
            .y[0];
        let relative_error = (got - analytical).abs() / analytical;

        assert!(
            relative_error < 1e-5,
            "growth: {}, analytical: {analytical}",
            got
        );
    }

    // ─── Linear extrapolation ───────────────────────────────────────

    #[test]
    fn linear_extrapolate_constant() {
        let samples = [(0i64, vec![1.0f32, 2.0, 3.0]), (1000, vec![1.0, 2.0, 3.0])];
        let pred = linear_extrapolate(&as_trajectory(&samples), 5000).unwrap();
        assert_eq!(pred, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn linear_extrapolate_trend() {
        let samples = [(0i64, vec![0.0f32, 0.0]), (1000, vec![1.0, 2.0])];
        let pred = linear_extrapolate(&as_trajectory(&samples), 2000).unwrap();
        // One more interval along the same trend: [2.0, 4.0].
        assert!((pred[0] - 2.0).abs() < 1e-5);
        assert!((pred[1] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn linear_extrapolate_backward() {
        let samples = [(0i64, vec![0.0f32]), (1000, vec![10.0])];
        let pred = linear_extrapolate(&as_trajectory(&samples), -1000).unwrap();
        // Two intervals before the last sample: 10 + 10 * (-2) = -10.
        assert!((pred[0] - (-10.0)).abs() < 1e-5);
    }

    #[test]
    fn linear_extrapolate_insufficient_data() {
        let samples = [(0i64, vec![1.0f32])];
        assert!(linear_extrapolate(&as_trajectory(&samples), 1000).is_err());
    }
}
407}