1use cvx_core::error::QueryError;
6use cvx_core::types::TemporalFilter;
7use cvx_core::{StorageBackend, TemporalIndexAccess};
8
9use cvx_analytics::calculus;
10use cvx_analytics::cohort;
11use cvx_analytics::counterfactual;
12use cvx_analytics::granger;
13use cvx_analytics::motifs;
14use cvx_analytics::ode;
15use cvx_analytics::pelt::{self, PeltConfig};
16use cvx_analytics::temporal_join;
17
18use crate::types::*;
19
/// Executes [`TemporalQuery`] values against a temporal index, keeping the
/// storage backend alongside it.
///
/// The trait bounds live on the `impl` block rather than on the struct
/// definition, so the type can be named in other generic code without
/// restating them. Loosening struct bounds this way is backward-compatible.
pub struct QueryEngine<I, S> {
    /// Temporal index that every query is dispatched to.
    index: I,
    /// Storage backend held next to the index; `execute` does not read it,
    /// it is only exposed through `store()`.
    store: S,
}
29
30impl<I: TemporalIndexAccess, S: StorageBackend> QueryEngine<I, S> {
31 pub fn new(index: I, store: S) -> Self {
33 Self { index, store }
34 }
35
36 pub fn execute(&self, query: TemporalQuery) -> Result<QueryResult, QueryError> {
38 execute_query(&self.index, query)
39 }
40
41 pub fn index(&self) -> &I {
43 &self.index
44 }
45
46 pub fn store(&self) -> &S {
48 &self.store
49 }
50}
51
52pub fn execute_query(
57 index: &dyn TemporalIndexAccess,
58 query: TemporalQuery,
59) -> Result<QueryResult, QueryError> {
60 match query {
61 TemporalQuery::SnapshotKnn {
62 vector,
63 timestamp,
64 k,
65 } => {
66 let results = index.search_raw(
67 &vector,
68 k,
69 TemporalFilter::Snapshot(timestamp),
70 1.0,
71 timestamp,
72 );
73 Ok(QueryResult::Knn(
74 results
75 .into_iter()
76 .map(|(node_id, score)| KnnResult {
77 entity_id: index.entity_id(node_id),
78 timestamp: index.timestamp(node_id),
79 score,
80 })
81 .collect(),
82 ))
83 }
84
85 TemporalQuery::RangeKnn {
86 vector,
87 start,
88 end,
89 k,
90 alpha,
91 } => {
92 let mid = start + (end - start) / 2;
93 let results =
94 index.search_raw(&vector, k, TemporalFilter::Range(start, end), alpha, mid);
95 Ok(QueryResult::Knn(
96 results
97 .into_iter()
98 .map(|(node_id, score)| KnnResult {
99 entity_id: index.entity_id(node_id),
100 timestamp: index.timestamp(node_id),
101 score,
102 })
103 .collect(),
104 ))
105 }
106
107 TemporalQuery::Trajectory { entity_id, filter } => {
108 let traj = index.trajectory(entity_id, filter);
109 let points = traj
110 .iter()
111 .map(|&(ts, node_id)| {
112 cvx_core::TemporalPoint::new(entity_id, ts, index.vector(node_id))
113 })
114 .collect();
115 Ok(QueryResult::Trajectory(points))
116 }
117
118 TemporalQuery::Velocity {
119 entity_id,
120 timestamp,
121 } => do_velocity(index, entity_id, timestamp),
122
123 TemporalQuery::Prediction {
124 entity_id,
125 target_timestamp,
126 } => do_prediction(index, entity_id, target_timestamp),
127
128 TemporalQuery::ChangePointDetect {
129 entity_id,
130 start,
131 end,
132 } => do_change_points(index, entity_id, start, end),
133
134 TemporalQuery::DriftQuant {
135 entity_id,
136 t1,
137 t2,
138 top_n,
139 } => do_drift_quant(index, entity_id, t1, t2, top_n),
140
141 TemporalQuery::Analogy {
142 entity_a,
143 t1,
144 t2,
145 entity_b,
146 t3,
147 } => do_analogy(index, entity_a, t1, t2, entity_b, t3),
148
149 TemporalQuery::Counterfactual {
150 entity_id,
151 change_point,
152 } => do_counterfactual(index, entity_id, change_point),
153
154 TemporalQuery::GrangerCausality {
155 entity_a,
156 entity_b,
157 max_lag,
158 significance,
159 } => do_granger(index, entity_a, entity_b, max_lag, significance),
160
161 TemporalQuery::DiscoverMotifs {
162 entity_id,
163 window,
164 max_motifs,
165 } => do_discover_motifs(index, entity_id, window, max_motifs),
166
167 TemporalQuery::DiscoverDiscords {
168 entity_id,
169 window,
170 max_discords,
171 } => do_discover_discords(index, entity_id, window, max_discords),
172
173 TemporalQuery::TemporalJoin {
174 entity_a,
175 entity_b,
176 epsilon,
177 window_us,
178 } => do_temporal_join(index, entity_a, entity_b, epsilon, window_us),
179
180 TemporalQuery::CohortDrift {
181 entity_ids,
182 t1,
183 t2,
184 top_n,
185 } => do_cohort_drift(index, &entity_ids, t1, t2, top_n),
186
187 TemporalQuery::CausalSearch {
188 vector,
189 k,
190 filter,
191 alpha,
192 query_timestamp,
193 temporal_context,
194 } => do_causal_search(
195 index,
196 &vector,
197 k,
198 filter,
199 alpha,
200 query_timestamp,
201 temporal_context,
202 ),
203 }
204}
205
206fn build_traj(
209 index: &dyn TemporalIndexAccess,
210 entity_id: u64,
211 filter: TemporalFilter,
212) -> (Vec<(i64, u32)>, Vec<Vec<f32>>) {
213 let traj_data = index.trajectory(entity_id, filter);
214 let vectors: Vec<Vec<f32>> = traj_data
215 .iter()
216 .map(|&(_, node_id)| index.vector(node_id))
217 .collect();
218 (traj_data, vectors)
219}
220
/// Pairs each timestamp in `traj_data` with a borrowed view of the vector at
/// the same position in `vectors`, giving the `(timestamp, &[f32])` series shape.
///
/// If the two inputs differ in length, the extra tail of the longer one is
/// ignored.
fn to_slices<'a>(traj_data: &'a [(i64, u32)], vectors: &'a [Vec<f32>]) -> Vec<(i64, &'a [f32])> {
    let timestamps = traj_data.iter().map(|&(ts, _)| ts);
    let views = vectors.iter().map(Vec::as_slice);
    timestamps.zip(views).collect()
}
228
229fn find_nearest(
230 index: &dyn TemporalIndexAccess,
231 entity_id: u64,
232 timestamp: i64,
233) -> Result<u32, QueryError> {
234 let traj = index.trajectory(entity_id, TemporalFilter::All);
235 if traj.is_empty() {
236 return Err(QueryError::EntityNotFound(entity_id));
237 }
238 let (_, node_id) = traj
239 .iter()
240 .min_by_key(|&&(ts, _)| (ts - timestamp).unsigned_abs())
241 .unwrap();
242 Ok(*node_id)
243}
244
245fn do_velocity(
246 index: &dyn TemporalIndexAccess,
247 entity_id: u64,
248 timestamp: i64,
249) -> Result<QueryResult, QueryError> {
250 let (traj_data, vectors) = build_traj(index, entity_id, TemporalFilter::All);
251 if traj_data.len() < 2 {
252 return Err(QueryError::InsufficientData {
253 needed: 2,
254 have: traj_data.len(),
255 });
256 }
257 let traj = to_slices(&traj_data, &vectors);
258 let vel = calculus::velocity(&traj, timestamp).map_err(|_| QueryError::InsufficientData {
259 needed: 2,
260 have: traj.len(),
261 })?;
262 Ok(QueryResult::Velocity(vel))
263}
264
265fn do_prediction(
266 index: &dyn TemporalIndexAccess,
267 entity_id: u64,
268 target_timestamp: i64,
269) -> Result<QueryResult, QueryError> {
270 let (traj_data, vectors) = build_traj(index, entity_id, TemporalFilter::All);
271 if traj_data.len() < 2 {
272 return Err(QueryError::InsufficientData {
273 needed: 2,
274 have: traj_data.len(),
275 });
276 }
277 let traj = to_slices(&traj_data, &vectors);
278 let predicted = ode::linear_extrapolate(&traj, target_timestamp).map_err(|_| {
279 QueryError::InsufficientData {
280 needed: 2,
281 have: traj.len(),
282 }
283 })?;
284 Ok(QueryResult::Prediction(PredictionResult {
285 vector: predicted,
286 timestamp: target_timestamp,
287 method: PredictionMethod::Linear,
288 }))
289}
290
291fn do_change_points(
292 index: &dyn TemporalIndexAccess,
293 entity_id: u64,
294 start: i64,
295 end: i64,
296) -> Result<QueryResult, QueryError> {
297 let (traj_data, vectors) = build_traj(index, entity_id, TemporalFilter::Range(start, end));
298 if traj_data.len() < 4 {
299 return Ok(QueryResult::ChangePoints(Vec::new()));
300 }
301 let traj = to_slices(&traj_data, &vectors);
302 let cps = pelt::detect(entity_id, &traj, &PeltConfig::default());
303 Ok(QueryResult::ChangePoints(cps))
304}
305
306fn do_drift_quant(
307 index: &dyn TemporalIndexAccess,
308 entity_id: u64,
309 t1: i64,
310 t2: i64,
311 top_n: usize,
312) -> Result<QueryResult, QueryError> {
313 let p1 = find_nearest(index, entity_id, t1)?;
314 let p2 = find_nearest(index, entity_id, t2)?;
315 let v1 = index.vector(p1);
316 let v2 = index.vector(p2);
317 let report = calculus::drift_report(&v1, &v2, top_n);
318 Ok(QueryResult::Drift(DriftResult {
319 l2_magnitude: report.l2_magnitude,
320 cosine_drift: report.cosine_drift,
321 top_dimensions: report.top_dimensions,
322 }))
323}
324
325fn do_counterfactual(
326 index: &dyn TemporalIndexAccess,
327 entity_id: u64,
328 change_point: i64,
329) -> Result<QueryResult, QueryError> {
330 let (td_pre, vecs_pre) = build_traj(index, entity_id, TemporalFilter::Before(change_point));
331 let (td_post, vecs_post) = build_traj(index, entity_id, TemporalFilter::After(change_point));
332
333 if td_pre.len() < 2 {
334 return Err(QueryError::InsufficientData {
335 needed: 2,
336 have: td_pre.len(),
337 });
338 }
339 if td_post.is_empty() {
340 return Err(QueryError::InsufficientData { needed: 1, have: 0 });
341 }
342
343 let pre = to_slices(&td_pre, &vecs_pre);
344 let post = to_slices(&td_post, &vecs_post);
345
346 let result =
347 counterfactual::counterfactual_trajectory(&pre, &post, change_point).map_err(|_| {
348 QueryError::InsufficientData {
349 needed: 2,
350 have: td_pre.len(),
351 }
352 })?;
353
354 Ok(QueryResult::Counterfactual(CounterfactualQueryResult {
355 change_point,
356 total_divergence: result.total_divergence,
357 max_divergence_time: result.max_divergence_time,
358 max_divergence_value: result.max_divergence_value,
359 divergence_curve: result.divergence_curve,
360 method: format!("{:?}", result.method),
361 }))
362}
363
364fn do_granger(
365 index: &dyn TemporalIndexAccess,
366 entity_a: u64,
367 entity_b: u64,
368 max_lag: usize,
369 significance: f64,
370) -> Result<QueryResult, QueryError> {
371 let (td_a, vecs_a) = build_traj(index, entity_a, TemporalFilter::All);
372 let (td_b, vecs_b) = build_traj(index, entity_b, TemporalFilter::All);
373
374 if td_a.is_empty() {
375 return Err(QueryError::EntityNotFound(entity_a));
376 }
377 if td_b.is_empty() {
378 return Err(QueryError::EntityNotFound(entity_b));
379 }
380
381 let traj_a = to_slices(&td_a, &vecs_a);
382 let traj_b = to_slices(&td_b, &vecs_b);
383
384 let result =
385 granger::granger_causality(&traj_a, &traj_b, max_lag, significance).map_err(|_| {
386 QueryError::InsufficientData {
387 needed: max_lag + 3,
388 have: traj_a.len().min(traj_b.len()),
389 }
390 })?;
391
392 let direction = match result.direction {
393 granger::GrangerDirection::AToB => "a_to_b",
394 granger::GrangerDirection::BToA => "b_to_a",
395 granger::GrangerDirection::Bidirectional => "bidirectional",
396 granger::GrangerDirection::None => "none",
397 };
398
399 Ok(QueryResult::Granger(GrangerCausalityResult {
400 direction: direction.to_string(),
401 optimal_lag: result.optimal_lag,
402 f_statistic: result.f_statistic,
403 p_value: result.p_value,
404 effect_size: result.effect_size,
405 per_dimension_a_to_b: result.per_dimension_a_to_b,
406 per_dimension_b_to_a: result.per_dimension_b_to_a,
407 }))
408}
409
410fn do_discover_motifs(
411 index: &dyn TemporalIndexAccess,
412 entity_id: u64,
413 window: usize,
414 max_motifs: usize,
415) -> Result<QueryResult, QueryError> {
416 let (td, vecs) = build_traj(index, entity_id, TemporalFilter::All);
417 if td.is_empty() {
418 return Err(QueryError::EntityNotFound(entity_id));
419 }
420 let traj = to_slices(&td, &vecs);
421 let found = motifs::discover_motifs(&traj, window, max_motifs, 0.5).map_err(|_| {
422 QueryError::InsufficientData {
423 needed: 2 * window,
424 have: traj.len(),
425 }
426 })?;
427
428 Ok(QueryResult::Motifs(
429 found
430 .into_iter()
431 .map(|m| MotifResult {
432 canonical_index: m.canonical_index,
433 occurrences: m
434 .occurrences
435 .into_iter()
436 .map(|o| MotifOccurrenceResult {
437 start_index: o.start_index,
438 timestamp: o.timestamp,
439 distance: o.distance,
440 })
441 .collect(),
442 period: m.period,
443 mean_match_distance: m.mean_match_distance,
444 })
445 .collect(),
446 ))
447}
448
449fn do_discover_discords(
450 index: &dyn TemporalIndexAccess,
451 entity_id: u64,
452 window: usize,
453 max_discords: usize,
454) -> Result<QueryResult, QueryError> {
455 let (td, vecs) = build_traj(index, entity_id, TemporalFilter::All);
456 if td.is_empty() {
457 return Err(QueryError::EntityNotFound(entity_id));
458 }
459 let traj = to_slices(&td, &vecs);
460 let found = motifs::discover_discords(&traj, window, max_discords).map_err(|_| {
461 QueryError::InsufficientData {
462 needed: 2 * window,
463 have: traj.len(),
464 }
465 })?;
466
467 Ok(QueryResult::Discords(
468 found
469 .into_iter()
470 .map(|d| DiscordResult {
471 start_index: d.start_index,
472 timestamp: d.timestamp,
473 nn_distance: d.nn_distance,
474 })
475 .collect(),
476 ))
477}
478
479fn do_temporal_join(
480 index: &dyn TemporalIndexAccess,
481 entity_a: u64,
482 entity_b: u64,
483 epsilon: f32,
484 window_us: i64,
485) -> Result<QueryResult, QueryError> {
486 let (td_a, vecs_a) = build_traj(index, entity_a, TemporalFilter::All);
487 let (td_b, vecs_b) = build_traj(index, entity_b, TemporalFilter::All);
488
489 if td_a.is_empty() {
490 return Err(QueryError::EntityNotFound(entity_a));
491 }
492 if td_b.is_empty() {
493 return Err(QueryError::EntityNotFound(entity_b));
494 }
495
496 let traj_a = to_slices(&td_a, &vecs_a);
497 let traj_b = to_slices(&td_b, &vecs_b);
498
499 let joins = temporal_join::temporal_join(&traj_a, &traj_b, epsilon, window_us)
500 .map_err(|_| QueryError::InsufficientData { needed: 1, have: 0 })?;
501
502 Ok(QueryResult::TemporalJoin(
503 joins
504 .into_iter()
505 .map(|j| TemporalJoinResultEntry {
506 start: j.start,
507 end: j.end,
508 mean_distance: j.mean_distance,
509 min_distance: j.min_distance,
510 points_a: j.points_a,
511 points_b: j.points_b,
512 })
513 .collect(),
514 ))
515}
516
517fn do_cohort_drift(
518 index: &dyn TemporalIndexAccess,
519 entity_ids: &[u64],
520 t1: i64,
521 t2: i64,
522 top_n: usize,
523) -> Result<QueryResult, QueryError> {
524 #[allow(clippy::type_complexity)]
526 let mut traj_data: Vec<(Vec<(i64, u32)>, Vec<Vec<f32>>)> = Vec::new();
527 let mut valid_ids: Vec<u64> = Vec::new();
528
529 for &eid in entity_ids {
530 let (td, vecs) = build_traj(index, eid, TemporalFilter::All);
531 if !td.is_empty() {
532 traj_data.push((td, vecs));
533 valid_ids.push(eid);
534 }
535 }
536
537 if valid_ids.len() < 2 {
538 return Err(QueryError::InsufficientData {
539 needed: 2,
540 have: valid_ids.len(),
541 });
542 }
543
544 let slice_trajs: Vec<Vec<(i64, &[f32])>> = traj_data
546 .iter()
547 .map(|(td, vecs)| to_slices(td, vecs))
548 .collect();
549
550 #[allow(clippy::type_complexity)]
551 let input: Vec<(u64, &[(i64, &[f32])])> = valid_ids
552 .iter()
553 .zip(slice_trajs.iter())
554 .map(|(&eid, st)| (eid, st.as_slice()))
555 .collect();
556
557 let report =
558 cohort::cohort_drift(&input, t1, t2, top_n).map_err(|_| QueryError::InsufficientData {
559 needed: 2,
560 have: valid_ids.len(),
561 })?;
562
563 Ok(QueryResult::CohortDrift(CohortDriftResult {
564 n_entities: report.n_entities,
565 mean_drift_l2: report.mean_drift_l2,
566 median_drift_l2: report.median_drift_l2,
567 std_drift_l2: report.std_drift_l2,
568 centroid_l2_magnitude: report.centroid_drift.l2_magnitude,
569 centroid_cosine_drift: report.centroid_drift.cosine_drift,
570 dispersion_t1: report.dispersion_t1,
571 dispersion_t2: report.dispersion_t2,
572 dispersion_change: report.dispersion_change,
573 convergence_score: report.convergence_score,
574 top_dimensions: report.top_dimensions,
575 outliers: report
576 .outliers
577 .into_iter()
578 .map(|o| CohortOutlierResult {
579 entity_id: o.entity_id,
580 drift_magnitude: o.drift_magnitude,
581 z_score: o.z_score,
582 drift_direction_alignment: o.drift_direction_alignment,
583 })
584 .collect(),
585 }))
586}
587
588fn do_analogy(
589 index: &dyn TemporalIndexAccess,
590 entity_a: u64,
591 t1: i64,
592 t2: i64,
593 entity_b: u64,
594 t3: i64,
595) -> Result<QueryResult, QueryError> {
596 let a1 = find_nearest(index, entity_a, t1)?;
597 let a2 = find_nearest(index, entity_a, t2)?;
598 let b3 = find_nearest(index, entity_b, t3)?;
599 let va1 = index.vector(a1);
600 let va2 = index.vector(a2);
601 let vb3 = index.vector(b3);
602 let result: Vec<f32> = vb3
603 .iter()
604 .zip(va2.iter().zip(va1.iter()))
605 .map(|(&b, (&a2, &a1))| b + (a2 - a1))
606 .collect();
607 Ok(QueryResult::Analogy(result))
608}
609
610fn do_causal_search(
615 index: &dyn TemporalIndexAccess,
616 vector: &[f32],
617 k: usize,
618 filter: TemporalFilter,
619 alpha: f32,
620 query_timestamp: i64,
621 temporal_context: usize,
622) -> Result<QueryResult, QueryError> {
623 let results = index.search_raw(vector, k, filter, alpha, query_timestamp);
624
625 let causal: Vec<CausalSearchResultEntry> = results
626 .into_iter()
627 .map(|(node_id, score)| {
628 let entity_id = index.entity_id(node_id);
629 let _ts = index.timestamp(node_id);
630
631 let traj = index.trajectory(entity_id, TemporalFilter::All);
633
634 let pos = traj.iter().position(|&(_, nid)| nid == node_id);
636
637 let (successors, predecessors) = match pos {
638 Some(p) => {
639 let succ: Vec<(u32, i64)> = traj[p + 1..]
640 .iter()
641 .take(temporal_context)
642 .map(|&(t, nid)| (nid, t))
643 .collect();
644 let pred: Vec<(u32, i64)> = traj[..p]
645 .iter()
646 .rev()
647 .take(temporal_context)
648 .map(|&(t, nid)| (nid, t))
649 .rev()
650 .collect();
651 (succ, pred)
652 }
653 None => (Vec::new(), Vec::new()),
654 };
655
656 CausalSearchResultEntry {
657 node_id,
658 score,
659 entity_id,
660 successors,
661 predecessors,
662 }
663 })
664 .collect();
665
666 Ok(QueryResult::CausalSearch(causal))
667}
668
669#[cfg(test)]
670mod tests {
671 use super::*;
672 use cvx_core::TemporalPoint;
673 use cvx_index::hnsw::HnswConfig;
674 use cvx_index::hnsw::temporal::TemporalHnsw;
675 use cvx_index::metrics::L2Distance;
676 use cvx_storage::memory::InMemoryStore;
677
678 fn setup_engine(
679 n_entities: u64,
680 points_per_entity: usize,
681 dim: usize,
682 ) -> QueryEngine<TemporalHnsw<L2Distance>, InMemoryStore> {
683 let config = HnswConfig {
684 m: 16,
685 ef_construction: 200,
686 ef_search: 100,
687 ..Default::default()
688 };
689 let mut index = TemporalHnsw::new(config, L2Distance);
690 let store = InMemoryStore::new();
691
692 for e in 0..n_entities {
693 for i in 0..points_per_entity {
694 let ts = (i as i64) * 1_000_000;
695 let v: Vec<f32> = (0..dim)
696 .map(|d| (e as f32 * 10.0) + (i as f32 * 0.1) + (d as f32 * 0.01))
697 .collect();
698 index.insert(e, ts, &v);
699 store.put(0, &TemporalPoint::new(e, ts, v)).unwrap();
700 }
701 }
702
703 QueryEngine::new(index, store)
704 }
705
706 #[test]
707 fn snapshot_knn_returns_at_timestamp() {
708 let engine = setup_engine(5, 20, 4);
709 let result = engine
710 .execute(TemporalQuery::SnapshotKnn {
711 vector: vec![0.0; 4],
712 timestamp: 5_000_000,
713 k: 3,
714 })
715 .unwrap();
716
717 if let QueryResult::Knn(results) = result {
718 for r in &results {
719 assert_eq!(r.timestamp, 5_000_000);
720 }
721 } else {
722 panic!("expected Knn result");
723 }
724 }
725
726 #[test]
727 fn range_knn_returns_in_range() {
728 let engine = setup_engine(5, 20, 4);
729 let result = engine
730 .execute(TemporalQuery::RangeKnn {
731 vector: vec![0.0; 4],
732 start: 3_000_000,
733 end: 7_000_000,
734 k: 10,
735 alpha: 1.0,
736 })
737 .unwrap();
738
739 if let QueryResult::Knn(results) = result {
740 assert!(!results.is_empty());
741 for r in &results {
742 assert!(
743 r.timestamp >= 3_000_000 && r.timestamp <= 7_000_000,
744 "ts {} out of range",
745 r.timestamp
746 );
747 }
748 } else {
749 panic!("expected Knn result");
750 }
751 }
752
753 #[test]
754 fn trajectory_returns_all_points_ordered() {
755 let engine = setup_engine(3, 20, 4);
756 let result = engine
757 .execute(TemporalQuery::Trajectory {
758 entity_id: 1,
759 filter: TemporalFilter::All,
760 })
761 .unwrap();
762
763 if let QueryResult::Trajectory(points) = result {
764 assert_eq!(points.len(), 20);
765 for w in points.windows(2) {
766 assert!(w[0].timestamp() <= w[1].timestamp());
767 }
768 for p in &points {
769 assert_eq!(p.entity_id(), 1);
770 }
771 } else {
772 panic!("expected Trajectory result");
773 }
774 }
775
776 #[test]
777 fn trajectory_with_range_filter() {
778 let engine = setup_engine(1, 20, 4);
779 let result = engine
780 .execute(TemporalQuery::Trajectory {
781 entity_id: 0,
782 filter: TemporalFilter::Range(5_000_000, 10_000_000),
783 })
784 .unwrap();
785
786 if let QueryResult::Trajectory(points) = result {
787 assert_eq!(points.len(), 6);
788 } else {
789 panic!("expected Trajectory result");
790 }
791 }
792
793 #[test]
794 fn velocity_returns_vector() {
795 let engine = setup_engine(1, 20, 4);
796 let result = engine
797 .execute(TemporalQuery::Velocity {
798 entity_id: 0,
799 timestamp: 10_000_000,
800 })
801 .unwrap();
802
803 if let QueryResult::Velocity(vel) = result {
804 assert_eq!(vel.len(), 4);
805 for &v in &vel {
806 assert!(v.is_finite());
807 }
808 } else {
809 panic!("expected Velocity result");
810 }
811 }
812
813 #[test]
814 fn velocity_insufficient_data() {
815 let config = HnswConfig::default();
816 let index = TemporalHnsw::new(config, L2Distance);
817 let store = InMemoryStore::new();
818 let engine = QueryEngine::new(index, store);
819
820 let result = engine.execute(TemporalQuery::Velocity {
821 entity_id: 999,
822 timestamp: 0,
823 });
824 assert!(result.is_err());
825 }
826
827 #[test]
828 fn prediction_linear_extrapolation() {
829 let engine = setup_engine(1, 20, 4);
830 let result = engine
831 .execute(TemporalQuery::Prediction {
832 entity_id: 0,
833 target_timestamp: 25_000_000,
834 })
835 .unwrap();
836
837 if let QueryResult::Prediction(pred) = result {
838 assert_eq!(pred.vector.len(), 4);
839 assert_eq!(pred.timestamp, 25_000_000);
840 assert!(matches!(pred.method, PredictionMethod::Linear));
841 } else {
842 panic!("expected Prediction result");
843 }
844 }
845
846 #[test]
847 fn changepoint_on_stationary() {
848 let engine = setup_engine(1, 50, 2);
849 let result = engine
850 .execute(TemporalQuery::ChangePointDetect {
851 entity_id: 0,
852 start: 0,
853 end: 50_000_000,
854 })
855 .unwrap();
856
857 if let QueryResult::ChangePoints(cps) = result {
858 assert!(
859 cps.len() <= 5,
860 "too many CPs on near-linear data: {}",
861 cps.len()
862 );
863 } else {
864 panic!("expected ChangePoints result");
865 }
866 }
867
868 #[test]
869 fn drift_quant_returns_report() {
870 let engine = setup_engine(1, 20, 4);
871 let result = engine
872 .execute(TemporalQuery::DriftQuant {
873 entity_id: 0,
874 t1: 0,
875 t2: 19_000_000,
876 top_n: 3,
877 })
878 .unwrap();
879
880 if let QueryResult::Drift(drift) = result {
881 assert!(drift.l2_magnitude > 0.0);
882 assert!(drift.top_dimensions.len() <= 3);
883 } else {
884 panic!("expected Drift result");
885 }
886 }
887
888 #[test]
889 fn analogy_computes_displacement() {
890 let engine = setup_engine(3, 20, 4);
891 let result = engine
892 .execute(TemporalQuery::Analogy {
893 entity_a: 0,
894 t1: 0,
895 t2: 10_000_000,
896 entity_b: 1,
897 t3: 5_000_000,
898 })
899 .unwrap();
900
901 if let QueryResult::Analogy(vec) = result {
902 assert_eq!(vec.len(), 4);
903 for &v in &vec {
904 assert!(v.is_finite());
905 }
906 } else {
907 panic!("expected Analogy result");
908 }
909 }
910
911 #[test]
912 fn cohort_drift_via_engine() {
913 let engine = setup_engine(5, 20, 4);
914 let result = engine
915 .execute(TemporalQuery::CohortDrift {
916 entity_ids: vec![0, 1, 2, 3, 4],
917 t1: 0,
918 t2: 19_000_000,
919 top_n: 3,
920 })
921 .unwrap();
922
923 if let QueryResult::CohortDrift(report) = result {
924 assert_eq!(report.n_entities, 5);
925 assert!(report.mean_drift_l2 > 0.0);
926 assert!(report.top_dimensions.len() <= 3);
927 assert!(
928 report.convergence_score > 0.0,
929 "entities with similar drift patterns should show some convergence"
930 );
931 } else {
932 panic!("expected CohortDrift result");
933 }
934 }
935
936 #[test]
937 fn cohort_drift_insufficient_entities() {
938 let engine = setup_engine(1, 10, 4);
939 let result = engine.execute(TemporalQuery::CohortDrift {
940 entity_ids: vec![0],
941 t1: 0,
942 t2: 9_000_000,
943 top_n: 3,
944 });
945 assert!(result.is_err());
946 }
947
948 #[test]
949 fn analogy_unknown_entity() {
950 let engine = setup_engine(1, 10, 4);
951 let result = engine.execute(TemporalQuery::Analogy {
952 entity_a: 999,
953 t1: 0,
954 t2: 1000,
955 entity_b: 0,
956 t3: 0,
957 });
958 assert!(result.is_err());
959 }
960}