1use cvx_core::error::AnalyticsError;
14
/// A system of first-order ordinary differential equations `y' = f(t, y)`.
///
/// Implementors provide the right-hand side `f`; solvers such as
/// `rk45_integrate` evaluate it repeatedly at trial points `(t, y)`.
pub trait OdeSystem {
    /// Evaluates the right-hand side `f(t, y)` at time `t` and state `y`.
    ///
    /// The returned vector is expected to have (at least) as many components
    /// as `y`, because the integrator reads it component-by-component using
    /// the state's indices.
    fn derivative(&self, t: f64, y: &[f64]) -> Vec<f64>;
}
22
/// Tuning parameters for the adaptive RK45 integrator (`rk45_integrate`).
#[derive(Debug, Clone)]
pub struct Rk45Config {
    /// Relative tolerance. Per component, the error weight is
    /// `atol + rtol * max(|y_old|, |y_new|)`.
    pub rtol: f64,
    /// Absolute tolerance: the floor of the per-component error weight.
    pub atol: f64,
    /// Magnitude of the first trial step (capped at `h_max`).
    pub h_init: f64,
    /// Smallest step-size magnitude the step-size controller will select.
    pub h_min: f64,
    /// Largest step-size magnitude the step-size controller will select.
    pub h_max: f64,
    /// Maximum number of accepted steps before integration is abandoned.
    pub max_steps: usize,
}
39
40impl Default for Rk45Config {
41 fn default() -> Self {
42 Self {
43 rtol: 1e-6,
44 atol: 1e-9,
45 h_init: 0.01,
46 h_min: 1e-12,
47 h_max: 1.0,
48 max_steps: 100_000,
49 }
50 }
51}
52
/// Final state returned by a successful `rk45_integrate` run.
#[derive(Debug, Clone)]
pub struct OdeResult {
    /// Time at which integration stopped (the requested end time, up to a
    /// `1e-15` absolute tolerance).
    pub t: f64,
    /// State vector at time `t`.
    pub y: Vec<f64>,
    /// Number of accepted steps; rejected trial steps are not counted.
    pub steps: usize,
}
63
64pub fn rk45_integrate(
69 system: &dyn OdeSystem,
70 t0: f64,
71 y0: &[f64],
72 t_end: f64,
73 config: &Rk45Config,
74) -> Result<OdeResult, AnalyticsError> {
75 let dim = y0.len();
76 let mut t = t0;
77 let mut y = y0.to_vec();
78 let mut h = config.h_init.min(config.h_max);
79 let direction = if t_end >= t0 { 1.0 } else { -1.0 };
80 h *= direction;
81
82 let mut steps = 0;
83
84 let a2 = 1.0 / 5.0;
86 let a3 = 3.0 / 10.0;
87 let a4 = 4.0 / 5.0;
88 let a5 = 8.0 / 9.0;
89
90 let b21 = 1.0 / 5.0;
91 let b31 = 3.0 / 40.0;
92 let b32 = 9.0 / 40.0;
93 let b41 = 44.0 / 45.0;
94 let b42 = -56.0 / 15.0;
95 let b43 = 32.0 / 9.0;
96 let b51 = 19372.0 / 6561.0;
97 let b52 = -25360.0 / 2187.0;
98 let b53 = 64448.0 / 6561.0;
99 let b54 = -212.0 / 729.0;
100 let b61 = 9017.0 / 3168.0;
101 let b62 = -355.0 / 33.0;
102 let b63 = 46732.0 / 5247.0;
103 let b64 = 49.0 / 176.0;
104 let b65 = -5103.0 / 18656.0;
105
106 let c1 = 35.0 / 384.0;
108 let c3 = 500.0 / 1113.0;
109 let c4 = 125.0 / 192.0;
110 let c5 = -2187.0 / 6784.0;
111 let c6 = 11.0 / 84.0;
112
113 let e1 = 71.0 / 57600.0;
115 let e3 = -71.0 / 16695.0;
116 let e4 = 71.0 / 1920.0;
117 let e5 = -17253.0 / 339200.0;
118 let e6 = 22.0 / 525.0;
119 let e7 = -1.0 / 40.0;
120
121 while (t_end - t) * direction > 1e-15 {
122 if steps >= config.max_steps {
123 return Err(AnalyticsError::SolverDiverged { step: steps });
124 }
125
126 if (t + h - t_end) * direction > 0.0 {
128 h = t_end - t;
129 }
130
131 let k1 = system.derivative(t, &y);
133
134 let y2: Vec<f64> = (0..dim).map(|i| y[i] + h * b21 * k1[i]).collect();
136 let k2 = system.derivative(t + a2 * h, &y2);
137
138 let y3: Vec<f64> = (0..dim)
140 .map(|i| y[i] + h * (b31 * k1[i] + b32 * k2[i]))
141 .collect();
142 let k3 = system.derivative(t + a3 * h, &y3);
143
144 let y4: Vec<f64> = (0..dim)
146 .map(|i| y[i] + h * (b41 * k1[i] + b42 * k2[i] + b43 * k3[i]))
147 .collect();
148 let k4 = system.derivative(t + a4 * h, &y4);
149
150 let y5: Vec<f64> = (0..dim)
152 .map(|i| y[i] + h * (b51 * k1[i] + b52 * k2[i] + b53 * k3[i] + b54 * k4[i]))
153 .collect();
154 let k5 = system.derivative(t + a5 * h, &y5);
155
156 let y6: Vec<f64> = (0..dim)
158 .map(|i| {
159 y[i] + h * (b61 * k1[i] + b62 * k2[i] + b63 * k3[i] + b64 * k4[i] + b65 * k5[i])
160 })
161 .collect();
162 let k6 = system.derivative(t + h, &y6);
163
164 let y_new: Vec<f64> = (0..dim)
166 .map(|i| y[i] + h * (c1 * k1[i] + c3 * k3[i] + c4 * k4[i] + c5 * k5[i] + c6 * k6[i]))
167 .collect();
168
169 let k7 = system.derivative(t + h, &y_new);
171
172 let mut err = 0.0;
174 for i in 0..dim {
175 let ei =
176 h * (e1 * k1[i] + e3 * k3[i] + e4 * k4[i] + e5 * k5[i] + e6 * k6[i] + e7 * k7[i]);
177 let scale = config.atol + config.rtol * y_new[i].abs().max(y[i].abs());
178 err += (ei / scale) * (ei / scale);
179 }
180 err = (err / dim as f64).sqrt();
181
182 if err <= 1.0 {
183 t += h;
185 y = y_new;
186 steps += 1;
187 }
188
189 let safety = 0.9;
191 let factor = if err > 0.0 {
192 safety * err.powf(-0.2)
193 } else {
194 5.0
195 };
196 h *= factor.clamp(0.2, 5.0);
197 h = h.abs().clamp(config.h_min, config.h_max) * direction;
198 }
199
200 Ok(OdeResult { t, y, steps })
201}
202
203pub fn linear_extrapolate(
209 trajectory: &[(i64, &[f32])],
210 target_timestamp: i64,
211) -> Result<Vec<f32>, AnalyticsError> {
212 if trajectory.len() < 2 {
213 return Err(AnalyticsError::InsufficientData {
214 needed: 2,
215 have: trajectory.len(),
216 });
217 }
218
219 let n = trajectory.len();
220 let (t1, v1) = &trajectory[n - 2];
221 let (t2, v2) = &trajectory[n - 1];
222
223 let dt = (*t2 - *t1) as f64;
224 if dt == 0.0 {
225 return Ok(v2.to_vec());
226 }
227
228 let dt_target = (target_timestamp - *t2) as f64;
229 let ratio = dt_target / dt;
230
231 let predicted: Vec<f32> = v1
232 .iter()
233 .zip(v2.iter())
234 .map(|(&a, &b)| {
235 let vel = b as f64 - a as f64;
236 (b as f64 + vel * ratio) as f32
237 })
238 .collect();
239
240 Ok(predicted)
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// `y' = -y`, with solution `y(t) = y0 * e^(-t)`.
    struct ExponentialDecay;

    impl OdeSystem for ExponentialDecay {
        fn derivative(&self, _t: f64, y: &[f64]) -> Vec<f64> {
            vec![-y[0]]
        }
    }

    /// `x' = v`, `v' = -x`: a unit oscillator with period `2π`.
    struct HarmonicOscillator;

    impl OdeSystem for HarmonicOscillator {
        fn derivative(&self, _t: f64, y: &[f64]) -> Vec<f64> {
            let (x, v) = (y[0], y[1]);
            vec![v, -x]
        }
    }

    /// `y' = y`, with solution `y(t) = y0 * e^t`.
    struct ExponentialGrowth;

    impl OdeSystem for ExponentialGrowth {
        fn derivative(&self, _t: f64, y: &[f64]) -> Vec<f64> {
            vec![y[0]]
        }
    }

    /// Default integrator settings with the given relative tolerance.
    fn with_rtol(rtol: f64) -> Rk45Config {
        Rk45Config {
            rtol,
            ..Default::default()
        }
    }

    /// Borrows owned sample vectors as the `(timestamp, &[f32])` pairs
    /// that `linear_extrapolate` takes.
    fn as_trajectory(points: &[(i64, Vec<f32>)]) -> Vec<(i64, &[f32])> {
        points.iter().map(|(t, v)| (*t, v.as_slice())).collect()
    }

    #[test]
    fn exponential_decay_matches_analytical() {
        let t_end = 5.0;
        let result =
            rk45_integrate(&ExponentialDecay, 0.0, &[1.0], t_end, &with_rtol(1e-8)).unwrap();
        let analytical = (-t_end).exp();

        assert!(
            (result.y[0] - analytical).abs() < 1e-5,
            "RK45: {}, analytical: {analytical}",
            result.y[0]
        );
    }

    #[test]
    fn harmonic_oscillator_matches_analytical() {
        let t_end = 2.0 * std::f64::consts::PI;
        let result =
            rk45_integrate(&HarmonicOscillator, 0.0, &[1.0, 0.0], t_end, &with_rtol(1e-8))
                .unwrap();
        let (x, v) = (result.y[0], result.y[1]);

        // One full period brings the oscillator back to (x, v) = (1, 0).
        assert!(
            (x - 1.0).abs() < 1e-5,
            "x after full period: {}, expected 1.0",
            x
        );
        assert!(
            v.abs() < 1e-5,
            "v after full period: {}, expected 0.0",
            v
        );
    }

    #[test]
    fn rk45_adaptive_step_count() {
        let result =
            rk45_integrate(&ExponentialDecay, 0.0, &[1.0], 1.0, &with_rtol(1e-6)).unwrap();

        assert!(
            result.steps < 100,
            "should need few steps, got {}",
            result.steps
        );
    }

    #[test]
    fn rk45_backward_integration() {
        // Start from the exact value at t = 1 and integrate back to t = 0.
        let start = (-1.0f64).exp();
        let result =
            rk45_integrate(&ExponentialDecay, 1.0, &[start], 0.0, &Rk45Config::default())
                .unwrap();

        assert!(
            (result.y[0] - 1.0).abs() < 1e-4,
            "backward: {}, expected 1.0",
            result.y[0]
        );
    }

    #[test]
    fn rk45_exponential_growth() {
        let config = Rk45Config {
            h_max: 0.5,
            ..with_rtol(1e-6)
        };
        let result = rk45_integrate(&ExponentialGrowth, 0.0, &[1.0], 3.0, &config).unwrap();
        let analytical = 3.0f64.exp();

        assert!(
            (result.y[0] - analytical).abs() / analytical < 1e-5,
            "growth: {}, analytical: {analytical}",
            result.y[0]
        );
    }

    #[test]
    fn linear_extrapolate_constant() {
        let points = [(0i64, vec![1.0f32, 2.0, 3.0]), (1000, vec![1.0, 2.0, 3.0])];

        let pred = linear_extrapolate(&as_trajectory(&points), 5000).unwrap();
        assert_eq!(pred, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn linear_extrapolate_trend() {
        let points = [(0i64, vec![0.0f32, 0.0]), (1000, vec![1.0, 2.0])];

        let pred = linear_extrapolate(&as_trajectory(&points), 2000).unwrap();
        assert!((pred[0] - 2.0).abs() < 1e-5);
        assert!((pred[1] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn linear_extrapolate_backward() {
        let points = [(0i64, vec![0.0f32]), (1000, vec![10.0])];

        let pred = linear_extrapolate(&as_trajectory(&points), -1000).unwrap();
        assert!((pred[0] - (-10.0)).abs() < 1e-5);
    }

    #[test]
    fn linear_extrapolate_insufficient_data() {
        let points = [(0i64, vec![1.0f32])];

        assert!(linear_extrapolate(&as_trajectory(&points), 1000).is_err());
    }
}
407}