1#[cfg(feature = "torch-backend")]
6use std::sync::Arc;
7
8use cvx_core::error::AnalyticsError;
9use cvx_core::traits::AnalyticsBackend;
10use cvx_core::types::{ChangePoint, CpdMethod, TemporalPoint};
11
12use crate::calculus;
13use crate::ode;
14use crate::pelt::{self, PeltConfig};
15
/// Default analytics backend: changepoint detection via PELT/BOCPD,
/// prediction via linear extrapolation, optionally upgraded to a Torch
/// neural ODE model when the `torch-backend` feature is enabled.
pub struct DefaultAnalytics {
    // Tuning parameters for the PELT changepoint detector.
    pelt_config: PeltConfig,
    // Optional neural ODE model; when set, `predict` tries it before
    // falling back to linear extrapolation. Shared via Arc so the model
    // can be reused across backends without copying.
    #[cfg(feature = "torch-backend")]
    torch_model: Option<Arc<crate::torch_ode::TorchOdeModel>>,
}
25
26impl DefaultAnalytics {
27 pub fn new() -> Self {
29 Self {
30 pelt_config: PeltConfig::default(),
31 #[cfg(feature = "torch-backend")]
32 torch_model: None,
33 }
34 }
35
36 pub fn with_pelt_config(pelt_config: PeltConfig) -> Self {
38 Self {
39 pelt_config,
40 #[cfg(feature = "torch-backend")]
41 torch_model: None,
42 }
43 }
44
45 #[cfg(feature = "torch-backend")]
47 pub fn with_torch_model(
48 pelt_config: PeltConfig,
49 model: Arc<crate::torch_ode::TorchOdeModel>,
50 ) -> Self {
51 Self {
52 pelt_config,
53 torch_model: Some(model),
54 }
55 }
56
57 pub fn has_neural_ode(&self) -> bool {
59 #[cfg(feature = "torch-backend")]
60 {
61 self.torch_model.is_some()
62 }
63 #[cfg(not(feature = "torch-backend"))]
64 {
65 false
66 }
67 }
68}
69
70impl Default for DefaultAnalytics {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76fn to_trajectory(points: &[TemporalPoint]) -> Vec<(i64, &[f32])> {
78 points.iter().map(|p| (p.timestamp(), p.vector())).collect()
79}
80
81impl AnalyticsBackend for DefaultAnalytics {
82 fn predict(
83 &self,
84 trajectory: &[TemporalPoint],
85 target_timestamp: i64,
86 ) -> Result<TemporalPoint, AnalyticsError> {
87 if trajectory.len() < 2 {
88 return Err(AnalyticsError::InsufficientData {
89 needed: 2,
90 have: trajectory.len(),
91 });
92 }
93
94 let entity_id = trajectory.last().unwrap().entity_id();
95 let traj = to_trajectory(trajectory);
96
97 #[cfg(feature = "torch-backend")]
99 if let Some(ref model) = self.torch_model {
100 match model.predict(&traj, target_timestamp) {
101 Ok(predicted) => {
102 return Ok(TemporalPoint::new(entity_id, target_timestamp, predicted));
103 }
104 Err(e) => {
105 tracing::warn!("Neural ODE prediction failed, falling back to linear: {e}");
106 }
107 }
108 }
109
110 let predicted = ode::linear_extrapolate(&traj, target_timestamp)?;
112 Ok(TemporalPoint::new(entity_id, target_timestamp, predicted))
113 }
114
115 fn detect_changepoints(
116 &self,
117 trajectory: &[TemporalPoint],
118 method: CpdMethod,
119 ) -> Result<Vec<ChangePoint>, AnalyticsError> {
120 if trajectory.is_empty() {
121 return Ok(Vec::new());
122 }
123
124 let entity_id = trajectory[0].entity_id();
125 let traj = to_trajectory(trajectory);
126
127 match method {
128 CpdMethod::Pelt => Ok(pelt::detect(entity_id, &traj, &self.pelt_config)),
129 CpdMethod::Bocpd => {
130 let mut detector = crate::bocpd::BocpdDetector::new(
132 entity_id,
133 crate::bocpd::BocpdConfig::default(),
134 );
135 let mut cps = Vec::new();
136 for p in trajectory {
137 if let Some(cp) = detector.observe(p.timestamp(), p.vector()) {
138 cps.push(cp);
139 }
140 }
141 Ok(cps)
142 }
143 }
144 }
145
146 fn velocity(
147 &self,
148 trajectory: &[TemporalPoint],
149 timestamp: i64,
150 ) -> Result<Vec<f32>, AnalyticsError> {
151 let traj = to_trajectory(trajectory);
152 calculus::velocity(&traj, timestamp)
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` points 1s apart, each a 3-d vector growing linearly.
    fn make_trajectory(n: usize, entity_id: u64) -> Vec<TemporalPoint> {
        let mut points = Vec::with_capacity(n);
        for i in 0..n {
            points.push(TemporalPoint::new(
                entity_id,
                (i as i64) * 1000,
                vec![i as f32 * 0.1; 3],
            ));
        }
        points
    }

    /// Builds a 100-point trajectory with a level shift at index 50.
    fn step_trajectory(before: Vec<f32>, after: Vec<f32>) -> Vec<TemporalPoint> {
        (0..100i64)
            .map(|i| {
                let v = if i < 50 { before.clone() } else { after.clone() };
                TemporalPoint::new(1, i * 1000, v)
            })
            .collect()
    }

    #[test]
    fn predict_returns_point_at_target() {
        let backend = DefaultAnalytics::new();
        let traj = make_trajectory(20, 1);
        let result = backend.predict(&traj, 25_000).unwrap();
        assert_eq!(result.entity_id(), 1);
        assert_eq!(result.timestamp(), 25_000);
        assert_eq!(result.dim(), 3);
    }

    #[test]
    fn predict_insufficient_data() {
        let backend = DefaultAnalytics::new();
        let traj = make_trajectory(1, 1);
        assert!(backend.predict(&traj, 5000).is_err());
    }

    #[test]
    fn detect_changepoints_pelt_near_linear() {
        let backend = DefaultAnalytics::new();
        let traj = make_trajectory(100, 1);
        let cps = backend.detect_changepoints(&traj, CpdMethod::Pelt).unwrap();
        assert!(
            cps.len() <= 10,
            "too many CPs on near-linear data: {}",
            cps.len()
        );
    }

    #[test]
    fn detect_changepoints_pelt_with_change() {
        let backend = DefaultAnalytics::new();
        let traj = step_trajectory(vec![0.0, 0.0], vec![10.0, 10.0]);
        let cps = backend.detect_changepoints(&traj, CpdMethod::Pelt).unwrap();
        assert!(!cps.is_empty(), "should detect the planted change");
    }

    #[test]
    fn detect_changepoints_bocpd() {
        let backend = DefaultAnalytics::new();
        let traj = step_trajectory(vec![0.0], vec![10.0]);
        let cps = backend
            .detect_changepoints(&traj, CpdMethod::Bocpd)
            .unwrap();
        assert!(!cps.is_empty());
    }

    #[test]
    fn velocity_linear_trajectory() {
        let backend = DefaultAnalytics::new();
        let traj = make_trajectory(20, 1);
        let vel = backend.velocity(&traj, 10_000).unwrap();
        assert_eq!(vel.len(), 3);
        assert!(vel.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn velocity_insufficient_data() {
        let backend = DefaultAnalytics::new();
        let traj = make_trajectory(1, 1);
        assert!(backend.velocity(&traj, 0).is_err());
    }

    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<DefaultAnalytics>();
    }
}