cvx_analytics/
backend.rs

1//! Implementation of the `AnalyticsBackend` trait.
2//!
3//! Wires together calculus, PELT, and ODE modules into the core trait contract.
4
5#[cfg(feature = "torch-backend")]
6use std::sync::Arc;
7
8use cvx_core::error::AnalyticsError;
9use cvx_core::traits::AnalyticsBackend;
10use cvx_core::types::{ChangePoint, CpdMethod, TemporalPoint};
11
12use crate::calculus;
13use crate::ode;
14use crate::pelt::{self, PeltConfig};
15
16/// Default analytics backend using pure-Rust implementations.
17///
18/// When the `torch-backend` feature is enabled and a model is loaded,
19/// prediction uses the Neural ODE. Otherwise falls back to linear extrapolation.
20pub struct DefaultAnalytics {
21    pelt_config: PeltConfig,
22    #[cfg(feature = "torch-backend")]
23    torch_model: Option<Arc<crate::torch_ode::TorchOdeModel>>,
24}
25
26impl DefaultAnalytics {
27    /// Create with default configuration.
28    pub fn new() -> Self {
29        Self {
30            pelt_config: PeltConfig::default(),
31            #[cfg(feature = "torch-backend")]
32            torch_model: None,
33        }
34    }
35
36    /// Create with custom PELT configuration.
37    pub fn with_pelt_config(pelt_config: PeltConfig) -> Self {
38        Self {
39            pelt_config,
40            #[cfg(feature = "torch-backend")]
41            torch_model: None,
42        }
43    }
44
45    /// Create with a loaded TorchScript Neural ODE model.
46    #[cfg(feature = "torch-backend")]
47    pub fn with_torch_model(
48        pelt_config: PeltConfig,
49        model: Arc<crate::torch_ode::TorchOdeModel>,
50    ) -> Self {
51        Self {
52            pelt_config,
53            torch_model: Some(model),
54        }
55    }
56
57    /// Whether a Neural ODE model is loaded.
58    pub fn has_neural_ode(&self) -> bool {
59        #[cfg(feature = "torch-backend")]
60        {
61            self.torch_model.is_some()
62        }
63        #[cfg(not(feature = "torch-backend"))]
64        {
65            false
66        }
67    }
68}
69
70impl Default for DefaultAnalytics {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76/// Convert a slice of `TemporalPoint` to the `(i64, &[f32])` format used internally.
77fn to_trajectory(points: &[TemporalPoint]) -> Vec<(i64, &[f32])> {
78    points.iter().map(|p| (p.timestamp(), p.vector())).collect()
79}
80
81impl AnalyticsBackend for DefaultAnalytics {
82    fn predict(
83        &self,
84        trajectory: &[TemporalPoint],
85        target_timestamp: i64,
86    ) -> Result<TemporalPoint, AnalyticsError> {
87        if trajectory.len() < 2 {
88            return Err(AnalyticsError::InsufficientData {
89                needed: 2,
90                have: trajectory.len(),
91            });
92        }
93
94        let entity_id = trajectory.last().unwrap().entity_id();
95        let traj = to_trajectory(trajectory);
96
97        // Try Neural ODE if available (RFC-003)
98        #[cfg(feature = "torch-backend")]
99        if let Some(ref model) = self.torch_model {
100            match model.predict(&traj, target_timestamp) {
101                Ok(predicted) => {
102                    return Ok(TemporalPoint::new(entity_id, target_timestamp, predicted));
103                }
104                Err(e) => {
105                    tracing::warn!("Neural ODE prediction failed, falling back to linear: {e}");
106                }
107            }
108        }
109
110        // Fallback: linear extrapolation
111        let predicted = ode::linear_extrapolate(&traj, target_timestamp)?;
112        Ok(TemporalPoint::new(entity_id, target_timestamp, predicted))
113    }
114
115    fn detect_changepoints(
116        &self,
117        trajectory: &[TemporalPoint],
118        method: CpdMethod,
119    ) -> Result<Vec<ChangePoint>, AnalyticsError> {
120        if trajectory.is_empty() {
121            return Ok(Vec::new());
122        }
123
124        let entity_id = trajectory[0].entity_id();
125        let traj = to_trajectory(trajectory);
126
127        match method {
128            CpdMethod::Pelt => Ok(pelt::detect(entity_id, &traj, &self.pelt_config)),
129            CpdMethod::Bocpd => {
130                // Use online detector in batch mode
131                let mut detector = crate::bocpd::BocpdDetector::new(
132                    entity_id,
133                    crate::bocpd::BocpdConfig::default(),
134                );
135                let mut cps = Vec::new();
136                for p in trajectory {
137                    if let Some(cp) = detector.observe(p.timestamp(), p.vector()) {
138                        cps.push(cp);
139                    }
140                }
141                Ok(cps)
142            }
143        }
144    }
145
146    fn velocity(
147        &self,
148        trajectory: &[TemporalPoint],
149        timestamp: i64,
150    ) -> Result<Vec<f32>, AnalyticsError> {
151        let traj = to_trajectory(trajectory);
152        calculus::velocity(&traj, timestamp)
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` three-dimensional points for `entity_id`, spaced 1000 apart in
    /// time, whose components all equal `index * 0.1`.
    fn make_trajectory(n: usize, entity_id: u64) -> Vec<TemporalPoint> {
        let mut points = Vec::with_capacity(n);
        for step in 0..n {
            let value = step as f32 * 0.1;
            points.push(TemporalPoint::new(
                entity_id,
                (step as i64) * 1000,
                vec![value; 3],
            ));
        }
        points
    }

    /// 100 points for entity 1, spaced 1000 apart: the first 50 sit at 0.0 in
    /// every component, the remaining 50 at 10.0.
    fn make_step_trajectory(dim: usize) -> Vec<TemporalPoint> {
        (0..100_i64)
            .map(|step| {
                let level = if step < 50 { 0.0 } else { 10.0 };
                TemporalPoint::new(1, step * 1000, vec![level; dim])
            })
            .collect()
    }

    #[test]
    fn predict_returns_point_at_target() {
        let history = make_trajectory(20, 1);
        let predicted = DefaultAnalytics::new().predict(&history, 25_000).unwrap();
        assert_eq!(predicted.entity_id(), 1);
        assert_eq!(predicted.timestamp(), 25_000);
        assert_eq!(predicted.dim(), 3);
    }

    #[test]
    fn predict_insufficient_data() {
        let single_point = make_trajectory(1, 1);
        let outcome = DefaultAnalytics::new().predict(&single_point, 5000);
        assert!(outcome.is_err());
    }

    #[test]
    fn detect_changepoints_pelt_near_linear() {
        let history = make_trajectory(100, 1);
        let detected = DefaultAnalytics::new()
            .detect_changepoints(&history, CpdMethod::Pelt)
            .unwrap();
        // Near-linear trajectory may have some CPs due to slope, but not many
        assert!(
            detected.len() <= 10,
            "too many CPs on near-linear data: {}",
            detected.len()
        );
    }

    #[test]
    fn detect_changepoints_pelt_with_change() {
        let history = make_step_trajectory(2);
        let detected = DefaultAnalytics::new()
            .detect_changepoints(&history, CpdMethod::Pelt)
            .unwrap();
        assert!(!detected.is_empty(), "should detect the planted change");
    }

    #[test]
    fn detect_changepoints_bocpd() {
        let history = make_step_trajectory(1);
        let detected = DefaultAnalytics::new()
            .detect_changepoints(&history, CpdMethod::Bocpd)
            .unwrap();
        // BOCPD should detect at least 1 change
        assert!(!detected.is_empty());
    }

    #[test]
    fn velocity_linear_trajectory() {
        let history = make_trajectory(20, 1);
        let vel = DefaultAnalytics::new().velocity(&history, 10_000).unwrap();
        assert_eq!(vel.len(), 3);
        // Every component of the velocity must be a finite number
        vel.iter().for_each(|v| assert!(v.is_finite()));
    }

    #[test]
    fn velocity_insufficient_data() {
        let single_point = make_trajectory(1, 1);
        let outcome = DefaultAnalytics::new().velocity(&single_point, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn is_send_sync() {
        // Compile-time check only: this fails to build if the backend ever
        // stops being `Send + Sync`.
        fn requires_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        requires_send_sync::<DefaultAnalytics>();
    }
}