// cvx_query/engine.rs

1//! Query execution engine.
2//!
3//! Orchestrates index, storage, and analytics to execute temporal queries.
4
5use cvx_core::error::QueryError;
6use cvx_core::types::TemporalFilter;
7use cvx_core::{StorageBackend, TemporalIndexAccess};
8
9use cvx_analytics::calculus;
10use cvx_analytics::cohort;
11use cvx_analytics::counterfactual;
12use cvx_analytics::granger;
13use cvx_analytics::motifs;
14use cvx_analytics::ode;
15use cvx_analytics::pelt::{self, PeltConfig};
16use cvx_analytics::temporal_join;
17
18use crate::types::*;
19
20/// Query engine that coordinates index + storage + analytics.
21///
22/// Generic over the index `I` (any `TemporalIndexAccess` impl) and storage `S`.
23/// This allows the same engine to work with both `TemporalHnsw` (single-threaded)
24/// and `ConcurrentTemporalHnsw` (thread-safe).
pub struct QueryEngine<I: TemporalIndexAccess, S: StorageBackend> {
    // Temporal vector index answering kNN / trajectory / vector lookups.
    index: I,
    // Storage backend; exposed via `store()` but not consulted by `execute` itself.
    store: S,
}
29
impl<I: TemporalIndexAccess, S: StorageBackend> QueryEngine<I, S> {
    /// Create a new query engine.
    pub fn new(index: I, store: S) -> Self {
        Self { index, store }
    }

    /// Execute a temporal query.
    ///
    /// Delegates to the free function [`execute_query`], which operates on the
    /// index alone; the store is not touched here.
    pub fn execute(&self, query: TemporalQuery) -> Result<QueryResult, QueryError> {
        execute_query(&self.index, query)
    }

    /// Access the underlying index.
    pub fn index(&self) -> &I {
        &self.index
    }

    /// Access the underlying store.
    pub fn store(&self) -> &S {
        &self.store
    }
}
51
52/// Execute a temporal query against any index without owning it.
53///
54/// This is the primary entry point for API handlers that share index/store
55/// via `Arc<AppState>`.
pub fn execute_query(
    index: &dyn TemporalIndexAccess,
    query: TemporalQuery,
) -> Result<QueryResult, QueryError> {
    match query {
        // Point-in-time kNN. Alpha is fixed at 1.0 (pure vector similarity)
        // because the snapshot filter already pins results to `timestamp`.
        TemporalQuery::SnapshotKnn {
            vector,
            timestamp,
            k,
        } => {
            let results = index.search_raw(
                &vector,
                k,
                TemporalFilter::Snapshot(timestamp),
                1.0,
                timestamp,
            );
            Ok(QueryResult::Knn(
                results
                    .into_iter()
                    .map(|(node_id, score)| KnnResult {
                        entity_id: index.entity_id(node_id),
                        timestamp: index.timestamp(node_id),
                        score,
                    })
                    .collect(),
            ))
        }

        // kNN over a time range; the range midpoint (computed without
        // overflowing `start + end`) is the temporal reference used for the
        // alpha-weighted score.
        TemporalQuery::RangeKnn {
            vector,
            start,
            end,
            k,
            alpha,
        } => {
            let mid = start + (end - start) / 2;
            let results =
                index.search_raw(&vector, k, TemporalFilter::Range(start, end), alpha, mid);
            Ok(QueryResult::Knn(
                results
                    .into_iter()
                    .map(|(node_id, score)| KnnResult {
                        entity_id: index.entity_id(node_id),
                        timestamp: index.timestamp(node_id),
                        score,
                    })
                    .collect(),
            ))
        }

        // Materialize an entity's (filtered) trajectory as owned points.
        TemporalQuery::Trajectory { entity_id, filter } => {
            let traj = index.trajectory(entity_id, filter);
            let points = traj
                .iter()
                .map(|&(ts, node_id)| {
                    cvx_core::TemporalPoint::new(entity_id, ts, index.vector(node_id))
                })
                .collect();
            Ok(QueryResult::Trajectory(points))
        }

        // The remaining variants delegate to the analytics helpers below.
        TemporalQuery::Velocity {
            entity_id,
            timestamp,
        } => do_velocity(index, entity_id, timestamp),

        TemporalQuery::Prediction {
            entity_id,
            target_timestamp,
        } => do_prediction(index, entity_id, target_timestamp),

        TemporalQuery::ChangePointDetect {
            entity_id,
            start,
            end,
        } => do_change_points(index, entity_id, start, end),

        TemporalQuery::DriftQuant {
            entity_id,
            t1,
            t2,
            top_n,
        } => do_drift_quant(index, entity_id, t1, t2, top_n),

        TemporalQuery::Analogy {
            entity_a,
            t1,
            t2,
            entity_b,
            t3,
        } => do_analogy(index, entity_a, t1, t2, entity_b, t3),

        TemporalQuery::Counterfactual {
            entity_id,
            change_point,
        } => do_counterfactual(index, entity_id, change_point),

        TemporalQuery::GrangerCausality {
            entity_a,
            entity_b,
            max_lag,
            significance,
        } => do_granger(index, entity_a, entity_b, max_lag, significance),

        TemporalQuery::DiscoverMotifs {
            entity_id,
            window,
            max_motifs,
        } => do_discover_motifs(index, entity_id, window, max_motifs),

        TemporalQuery::DiscoverDiscords {
            entity_id,
            window,
            max_discords,
        } => do_discover_discords(index, entity_id, window, max_discords),

        TemporalQuery::TemporalJoin {
            entity_a,
            entity_b,
            epsilon,
            window_us,
        } => do_temporal_join(index, entity_a, entity_b, epsilon, window_us),

        TemporalQuery::CohortDrift {
            entity_ids,
            t1,
            t2,
            top_n,
        } => do_cohort_drift(index, &entity_ids, t1, t2, top_n),

        TemporalQuery::CausalSearch {
            vector,
            k,
            filter,
            alpha,
            query_timestamp,
            temporal_context,
        } => do_causal_search(
            index,
            &vector,
            k,
            filter,
            alpha,
            query_timestamp,
            temporal_context,
        ),
    }
}
205
206// ─── Internal helpers ──────────────────────────────────────────────
207
208fn build_traj(
209    index: &dyn TemporalIndexAccess,
210    entity_id: u64,
211    filter: TemporalFilter,
212) -> (Vec<(i64, u32)>, Vec<Vec<f32>>) {
213    let traj_data = index.trajectory(entity_id, filter);
214    let vectors: Vec<Vec<f32>> = traj_data
215        .iter()
216        .map(|&(_, node_id)| index.vector(node_id))
217        .collect();
218    (traj_data, vectors)
219}
220
/// Pair each trajectory timestamp with a borrowed slice of its vector.
///
/// `traj_data` and `vectors` are parallel (as produced by `build_traj`);
/// the node id component of each pair is dropped.
fn to_slices<'a>(traj_data: &'a [(i64, u32)], vectors: &'a [Vec<f32>]) -> Vec<(i64, &'a [f32])> {
    let mut out = Vec::with_capacity(traj_data.len());
    for (&(ts, _), vec) in traj_data.iter().zip(vectors) {
        out.push((ts, vec.as_slice()));
    }
    out
}
228
229fn find_nearest(
230    index: &dyn TemporalIndexAccess,
231    entity_id: u64,
232    timestamp: i64,
233) -> Result<u32, QueryError> {
234    let traj = index.trajectory(entity_id, TemporalFilter::All);
235    if traj.is_empty() {
236        return Err(QueryError::EntityNotFound(entity_id));
237    }
238    let (_, node_id) = traj
239        .iter()
240        .min_by_key(|&&(ts, _)| (ts - timestamp).unsigned_abs())
241        .unwrap();
242    Ok(*node_id)
243}
244
245fn do_velocity(
246    index: &dyn TemporalIndexAccess,
247    entity_id: u64,
248    timestamp: i64,
249) -> Result<QueryResult, QueryError> {
250    let (traj_data, vectors) = build_traj(index, entity_id, TemporalFilter::All);
251    if traj_data.len() < 2 {
252        return Err(QueryError::InsufficientData {
253            needed: 2,
254            have: traj_data.len(),
255        });
256    }
257    let traj = to_slices(&traj_data, &vectors);
258    let vel = calculus::velocity(&traj, timestamp).map_err(|_| QueryError::InsufficientData {
259        needed: 2,
260        have: traj.len(),
261    })?;
262    Ok(QueryResult::Velocity(vel))
263}
264
265fn do_prediction(
266    index: &dyn TemporalIndexAccess,
267    entity_id: u64,
268    target_timestamp: i64,
269) -> Result<QueryResult, QueryError> {
270    let (traj_data, vectors) = build_traj(index, entity_id, TemporalFilter::All);
271    if traj_data.len() < 2 {
272        return Err(QueryError::InsufficientData {
273            needed: 2,
274            have: traj_data.len(),
275        });
276    }
277    let traj = to_slices(&traj_data, &vectors);
278    let predicted = ode::linear_extrapolate(&traj, target_timestamp).map_err(|_| {
279        QueryError::InsufficientData {
280            needed: 2,
281            have: traj.len(),
282        }
283    })?;
284    Ok(QueryResult::Prediction(PredictionResult {
285        vector: predicted,
286        timestamp: target_timestamp,
287        method: PredictionMethod::Linear,
288    }))
289}
290
291fn do_change_points(
292    index: &dyn TemporalIndexAccess,
293    entity_id: u64,
294    start: i64,
295    end: i64,
296) -> Result<QueryResult, QueryError> {
297    let (traj_data, vectors) = build_traj(index, entity_id, TemporalFilter::Range(start, end));
298    if traj_data.len() < 4 {
299        return Ok(QueryResult::ChangePoints(Vec::new()));
300    }
301    let traj = to_slices(&traj_data, &vectors);
302    let cps = pelt::detect(entity_id, &traj, &PeltConfig::default());
303    Ok(QueryResult::ChangePoints(cps))
304}
305
306fn do_drift_quant(
307    index: &dyn TemporalIndexAccess,
308    entity_id: u64,
309    t1: i64,
310    t2: i64,
311    top_n: usize,
312) -> Result<QueryResult, QueryError> {
313    let p1 = find_nearest(index, entity_id, t1)?;
314    let p2 = find_nearest(index, entity_id, t2)?;
315    let v1 = index.vector(p1);
316    let v2 = index.vector(p2);
317    let report = calculus::drift_report(&v1, &v2, top_n);
318    Ok(QueryResult::Drift(DriftResult {
319        l2_magnitude: report.l2_magnitude,
320        cosine_drift: report.cosine_drift,
321        top_dimensions: report.top_dimensions,
322    }))
323}
324
325fn do_counterfactual(
326    index: &dyn TemporalIndexAccess,
327    entity_id: u64,
328    change_point: i64,
329) -> Result<QueryResult, QueryError> {
330    let (td_pre, vecs_pre) = build_traj(index, entity_id, TemporalFilter::Before(change_point));
331    let (td_post, vecs_post) = build_traj(index, entity_id, TemporalFilter::After(change_point));
332
333    if td_pre.len() < 2 {
334        return Err(QueryError::InsufficientData {
335            needed: 2,
336            have: td_pre.len(),
337        });
338    }
339    if td_post.is_empty() {
340        return Err(QueryError::InsufficientData { needed: 1, have: 0 });
341    }
342
343    let pre = to_slices(&td_pre, &vecs_pre);
344    let post = to_slices(&td_post, &vecs_post);
345
346    let result =
347        counterfactual::counterfactual_trajectory(&pre, &post, change_point).map_err(|_| {
348            QueryError::InsufficientData {
349                needed: 2,
350                have: td_pre.len(),
351            }
352        })?;
353
354    Ok(QueryResult::Counterfactual(CounterfactualQueryResult {
355        change_point,
356        total_divergence: result.total_divergence,
357        max_divergence_time: result.max_divergence_time,
358        max_divergence_value: result.max_divergence_value,
359        divergence_curve: result.divergence_curve,
360        method: format!("{:?}", result.method),
361    }))
362}
363
364fn do_granger(
365    index: &dyn TemporalIndexAccess,
366    entity_a: u64,
367    entity_b: u64,
368    max_lag: usize,
369    significance: f64,
370) -> Result<QueryResult, QueryError> {
371    let (td_a, vecs_a) = build_traj(index, entity_a, TemporalFilter::All);
372    let (td_b, vecs_b) = build_traj(index, entity_b, TemporalFilter::All);
373
374    if td_a.is_empty() {
375        return Err(QueryError::EntityNotFound(entity_a));
376    }
377    if td_b.is_empty() {
378        return Err(QueryError::EntityNotFound(entity_b));
379    }
380
381    let traj_a = to_slices(&td_a, &vecs_a);
382    let traj_b = to_slices(&td_b, &vecs_b);
383
384    let result =
385        granger::granger_causality(&traj_a, &traj_b, max_lag, significance).map_err(|_| {
386            QueryError::InsufficientData {
387                needed: max_lag + 3,
388                have: traj_a.len().min(traj_b.len()),
389            }
390        })?;
391
392    let direction = match result.direction {
393        granger::GrangerDirection::AToB => "a_to_b",
394        granger::GrangerDirection::BToA => "b_to_a",
395        granger::GrangerDirection::Bidirectional => "bidirectional",
396        granger::GrangerDirection::None => "none",
397    };
398
399    Ok(QueryResult::Granger(GrangerCausalityResult {
400        direction: direction.to_string(),
401        optimal_lag: result.optimal_lag,
402        f_statistic: result.f_statistic,
403        p_value: result.p_value,
404        effect_size: result.effect_size,
405        per_dimension_a_to_b: result.per_dimension_a_to_b,
406        per_dimension_b_to_a: result.per_dimension_b_to_a,
407    }))
408}
409
410fn do_discover_motifs(
411    index: &dyn TemporalIndexAccess,
412    entity_id: u64,
413    window: usize,
414    max_motifs: usize,
415) -> Result<QueryResult, QueryError> {
416    let (td, vecs) = build_traj(index, entity_id, TemporalFilter::All);
417    if td.is_empty() {
418        return Err(QueryError::EntityNotFound(entity_id));
419    }
420    let traj = to_slices(&td, &vecs);
421    let found = motifs::discover_motifs(&traj, window, max_motifs, 0.5).map_err(|_| {
422        QueryError::InsufficientData {
423            needed: 2 * window,
424            have: traj.len(),
425        }
426    })?;
427
428    Ok(QueryResult::Motifs(
429        found
430            .into_iter()
431            .map(|m| MotifResult {
432                canonical_index: m.canonical_index,
433                occurrences: m
434                    .occurrences
435                    .into_iter()
436                    .map(|o| MotifOccurrenceResult {
437                        start_index: o.start_index,
438                        timestamp: o.timestamp,
439                        distance: o.distance,
440                    })
441                    .collect(),
442                period: m.period,
443                mean_match_distance: m.mean_match_distance,
444            })
445            .collect(),
446    ))
447}
448
449fn do_discover_discords(
450    index: &dyn TemporalIndexAccess,
451    entity_id: u64,
452    window: usize,
453    max_discords: usize,
454) -> Result<QueryResult, QueryError> {
455    let (td, vecs) = build_traj(index, entity_id, TemporalFilter::All);
456    if td.is_empty() {
457        return Err(QueryError::EntityNotFound(entity_id));
458    }
459    let traj = to_slices(&td, &vecs);
460    let found = motifs::discover_discords(&traj, window, max_discords).map_err(|_| {
461        QueryError::InsufficientData {
462            needed: 2 * window,
463            have: traj.len(),
464        }
465    })?;
466
467    Ok(QueryResult::Discords(
468        found
469            .into_iter()
470            .map(|d| DiscordResult {
471                start_index: d.start_index,
472                timestamp: d.timestamp,
473                nn_distance: d.nn_distance,
474            })
475            .collect(),
476    ))
477}
478
479fn do_temporal_join(
480    index: &dyn TemporalIndexAccess,
481    entity_a: u64,
482    entity_b: u64,
483    epsilon: f32,
484    window_us: i64,
485) -> Result<QueryResult, QueryError> {
486    let (td_a, vecs_a) = build_traj(index, entity_a, TemporalFilter::All);
487    let (td_b, vecs_b) = build_traj(index, entity_b, TemporalFilter::All);
488
489    if td_a.is_empty() {
490        return Err(QueryError::EntityNotFound(entity_a));
491    }
492    if td_b.is_empty() {
493        return Err(QueryError::EntityNotFound(entity_b));
494    }
495
496    let traj_a = to_slices(&td_a, &vecs_a);
497    let traj_b = to_slices(&td_b, &vecs_b);
498
499    let joins = temporal_join::temporal_join(&traj_a, &traj_b, epsilon, window_us)
500        .map_err(|_| QueryError::InsufficientData { needed: 1, have: 0 })?;
501
502    Ok(QueryResult::TemporalJoin(
503        joins
504            .into_iter()
505            .map(|j| TemporalJoinResultEntry {
506                start: j.start,
507                end: j.end,
508                mean_distance: j.mean_distance,
509                min_distance: j.min_distance,
510                points_a: j.points_a,
511                points_b: j.points_b,
512            })
513            .collect(),
514    ))
515}
516
517fn do_cohort_drift(
518    index: &dyn TemporalIndexAccess,
519    entity_ids: &[u64],
520    t1: i64,
521    t2: i64,
522    top_n: usize,
523) -> Result<QueryResult, QueryError> {
524    // Build trajectories for all entities
525    #[allow(clippy::type_complexity)]
526    let mut traj_data: Vec<(Vec<(i64, u32)>, Vec<Vec<f32>>)> = Vec::new();
527    let mut valid_ids: Vec<u64> = Vec::new();
528
529    for &eid in entity_ids {
530        let (td, vecs) = build_traj(index, eid, TemporalFilter::All);
531        if !td.is_empty() {
532            traj_data.push((td, vecs));
533            valid_ids.push(eid);
534        }
535    }
536
537    if valid_ids.len() < 2 {
538        return Err(QueryError::InsufficientData {
539            needed: 2,
540            have: valid_ids.len(),
541        });
542    }
543
544    // Build slice-based trajectories for the cohort function
545    let slice_trajs: Vec<Vec<(i64, &[f32])>> = traj_data
546        .iter()
547        .map(|(td, vecs)| to_slices(td, vecs))
548        .collect();
549
550    #[allow(clippy::type_complexity)]
551    let input: Vec<(u64, &[(i64, &[f32])])> = valid_ids
552        .iter()
553        .zip(slice_trajs.iter())
554        .map(|(&eid, st)| (eid, st.as_slice()))
555        .collect();
556
557    let report =
558        cohort::cohort_drift(&input, t1, t2, top_n).map_err(|_| QueryError::InsufficientData {
559            needed: 2,
560            have: valid_ids.len(),
561        })?;
562
563    Ok(QueryResult::CohortDrift(CohortDriftResult {
564        n_entities: report.n_entities,
565        mean_drift_l2: report.mean_drift_l2,
566        median_drift_l2: report.median_drift_l2,
567        std_drift_l2: report.std_drift_l2,
568        centroid_l2_magnitude: report.centroid_drift.l2_magnitude,
569        centroid_cosine_drift: report.centroid_drift.cosine_drift,
570        dispersion_t1: report.dispersion_t1,
571        dispersion_t2: report.dispersion_t2,
572        dispersion_change: report.dispersion_change,
573        convergence_score: report.convergence_score,
574        top_dimensions: report.top_dimensions,
575        outliers: report
576            .outliers
577            .into_iter()
578            .map(|o| CohortOutlierResult {
579                entity_id: o.entity_id,
580                drift_magnitude: o.drift_magnitude,
581                z_score: o.z_score,
582                drift_direction_alignment: o.drift_direction_alignment,
583            })
584            .collect(),
585    }))
586}
587
588fn do_analogy(
589    index: &dyn TemporalIndexAccess,
590    entity_a: u64,
591    t1: i64,
592    t2: i64,
593    entity_b: u64,
594    t3: i64,
595) -> Result<QueryResult, QueryError> {
596    let a1 = find_nearest(index, entity_a, t1)?;
597    let a2 = find_nearest(index, entity_a, t2)?;
598    let b3 = find_nearest(index, entity_b, t3)?;
599    let va1 = index.vector(a1);
600    let va2 = index.vector(a2);
601    let vb3 = index.vector(b3);
602    let result: Vec<f32> = vb3
603        .iter()
604        .zip(va2.iter().zip(va1.iter()))
605        .map(|(&b, (&a2, &a1))| b + (a2 - a1))
606        .collect();
607    Ok(QueryResult::Analogy(result))
608}
609
610/// Causal search: kNN + temporal context via trajectory walk.
611///
612/// Works with ANY TemporalIndexAccess (not just TemporalGraphIndex).
613/// For full hybrid beam search, use TemporalGraphIndex directly.
614fn do_causal_search(
615    index: &dyn TemporalIndexAccess,
616    vector: &[f32],
617    k: usize,
618    filter: TemporalFilter,
619    alpha: f32,
620    query_timestamp: i64,
621    temporal_context: usize,
622) -> Result<QueryResult, QueryError> {
623    let results = index.search_raw(vector, k, filter, alpha, query_timestamp);
624
625    let causal: Vec<CausalSearchResultEntry> = results
626        .into_iter()
627        .map(|(node_id, score)| {
628            let entity_id = index.entity_id(node_id);
629            let _ts = index.timestamp(node_id);
630
631            // Get entity's full trajectory to find temporal neighbors
632            let traj = index.trajectory(entity_id, TemporalFilter::All);
633
634            // Find this node's position in the trajectory
635            let pos = traj.iter().position(|&(_, nid)| nid == node_id);
636
637            let (successors, predecessors) = match pos {
638                Some(p) => {
639                    let succ: Vec<(u32, i64)> = traj[p + 1..]
640                        .iter()
641                        .take(temporal_context)
642                        .map(|&(t, nid)| (nid, t))
643                        .collect();
644                    let pred: Vec<(u32, i64)> = traj[..p]
645                        .iter()
646                        .rev()
647                        .take(temporal_context)
648                        .map(|&(t, nid)| (nid, t))
649                        .rev()
650                        .collect();
651                    (succ, pred)
652                }
653                None => (Vec::new(), Vec::new()),
654            };
655
656            CausalSearchResultEntry {
657                node_id,
658                score,
659                entity_id,
660                successors,
661                predecessors,
662            }
663        })
664        .collect();
665
666    Ok(QueryResult::CausalSearch(causal))
667}
668
#[cfg(test)]
mod tests {
    use super::*;
    use cvx_core::TemporalPoint;
    use cvx_index::hnsw::HnswConfig;
    use cvx_index::hnsw::temporal::TemporalHnsw;
    use cvx_index::metrics::L2Distance;
    use cvx_storage::memory::InMemoryStore;

    /// Build an engine over a synthetic dataset.
    ///
    /// Each of `n_entities` entities gets `points_per_entity` samples spaced
    /// 1_000_000 us apart; dimension `d` of entity `e` at step `i` is
    /// `e*10 + i*0.1 + d*0.01`, so trajectories drift linearly and entities
    /// are well separated in vector space.
    fn setup_engine(
        n_entities: u64,
        points_per_entity: usize,
        dim: usize,
    ) -> QueryEngine<TemporalHnsw<L2Distance>, InMemoryStore> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 200,
            ef_search: 100,
            ..Default::default()
        };
        let mut index = TemporalHnsw::new(config, L2Distance);
        let store = InMemoryStore::new();

        for e in 0..n_entities {
            for i in 0..points_per_entity {
                let ts = (i as i64) * 1_000_000;
                let v: Vec<f32> = (0..dim)
                    .map(|d| (e as f32 * 10.0) + (i as f32 * 0.1) + (d as f32 * 0.01))
                    .collect();
                index.insert(e, ts, &v);
                store.put(0, &TemporalPoint::new(e, ts, v)).unwrap();
            }
        }

        QueryEngine::new(index, store)
    }

    // SnapshotKnn must only return points stored at exactly the queried time.
    #[test]
    fn snapshot_knn_returns_at_timestamp() {
        let engine = setup_engine(5, 20, 4);
        let result = engine
            .execute(TemporalQuery::SnapshotKnn {
                vector: vec![0.0; 4],
                timestamp: 5_000_000,
                k: 3,
            })
            .unwrap();

        if let QueryResult::Knn(results) = result {
            for r in &results {
                assert_eq!(r.timestamp, 5_000_000);
            }
        } else {
            panic!("expected Knn result");
        }
    }

    // RangeKnn results must fall inside the inclusive [start, end] window.
    #[test]
    fn range_knn_returns_in_range() {
        let engine = setup_engine(5, 20, 4);
        let result = engine
            .execute(TemporalQuery::RangeKnn {
                vector: vec![0.0; 4],
                start: 3_000_000,
                end: 7_000_000,
                k: 10,
                alpha: 1.0,
            })
            .unwrap();

        if let QueryResult::Knn(results) = result {
            assert!(!results.is_empty());
            for r in &results {
                assert!(
                    r.timestamp >= 3_000_000 && r.timestamp <= 7_000_000,
                    "ts {} out of range",
                    r.timestamp
                );
            }
        } else {
            panic!("expected Knn result");
        }
    }

    // Trajectory should return every point, in non-decreasing timestamp order,
    // all belonging to the requested entity.
    #[test]
    fn trajectory_returns_all_points_ordered() {
        let engine = setup_engine(3, 20, 4);
        let result = engine
            .execute(TemporalQuery::Trajectory {
                entity_id: 1,
                filter: TemporalFilter::All,
            })
            .unwrap();

        if let QueryResult::Trajectory(points) = result {
            assert_eq!(points.len(), 20);
            for w in points.windows(2) {
                assert!(w[0].timestamp() <= w[1].timestamp());
            }
            for p in &points {
                assert_eq!(p.entity_id(), 1);
            }
        } else {
            panic!("expected Trajectory result");
        }
    }

    // Range(5s, 10s) over 1s-spaced points covers 6 samples (inclusive bounds).
    #[test]
    fn trajectory_with_range_filter() {
        let engine = setup_engine(1, 20, 4);
        let result = engine
            .execute(TemporalQuery::Trajectory {
                entity_id: 0,
                filter: TemporalFilter::Range(5_000_000, 10_000_000),
            })
            .unwrap();

        if let QueryResult::Trajectory(points) = result {
            assert_eq!(points.len(), 6);
        } else {
            panic!("expected Trajectory result");
        }
    }

    // Velocity on a linearly drifting entity yields a finite vector of dim size.
    #[test]
    fn velocity_returns_vector() {
        let engine = setup_engine(1, 20, 4);
        let result = engine
            .execute(TemporalQuery::Velocity {
                entity_id: 0,
                timestamp: 10_000_000,
            })
            .unwrap();

        if let QueryResult::Velocity(vel) = result {
            assert_eq!(vel.len(), 4);
            for &v in &vel {
                assert!(v.is_finite());
            }
        } else {
            panic!("expected Velocity result");
        }
    }

    // An empty index cannot supply the two points velocity needs.
    #[test]
    fn velocity_insufficient_data() {
        let config = HnswConfig::default();
        let index = TemporalHnsw::new(config, L2Distance);
        let store = InMemoryStore::new();
        let engine = QueryEngine::new(index, store);

        let result = engine.execute(TemporalQuery::Velocity {
            entity_id: 999,
            timestamp: 0,
        });
        assert!(result.is_err());
    }

    // Extrapolating past the last observation (19s) to 25s must succeed.
    #[test]
    fn prediction_linear_extrapolation() {
        let engine = setup_engine(1, 20, 4);
        let result = engine
            .execute(TemporalQuery::Prediction {
                entity_id: 0,
                target_timestamp: 25_000_000,
            })
            .unwrap();

        if let QueryResult::Prediction(pred) = result {
            assert_eq!(pred.vector.len(), 4);
            assert_eq!(pred.timestamp, 25_000_000);
            assert!(matches!(pred.method, PredictionMethod::Linear));
        } else {
            panic!("expected Prediction result");
        }
    }

    // Near-linear data should not produce a flood of change points.
    #[test]
    fn changepoint_on_stationary() {
        let engine = setup_engine(1, 50, 2);
        let result = engine
            .execute(TemporalQuery::ChangePointDetect {
                entity_id: 0,
                start: 0,
                end: 50_000_000,
            })
            .unwrap();

        if let QueryResult::ChangePoints(cps) = result {
            assert!(
                cps.len() <= 5,
                "too many CPs on near-linear data: {}",
                cps.len()
            );
        } else {
            panic!("expected ChangePoints result");
        }
    }

    // Drift between first and last snapshots must be nonzero with top_n capped.
    #[test]
    fn drift_quant_returns_report() {
        let engine = setup_engine(1, 20, 4);
        let result = engine
            .execute(TemporalQuery::DriftQuant {
                entity_id: 0,
                t1: 0,
                t2: 19_000_000,
                top_n: 3,
            })
            .unwrap();

        if let QueryResult::Drift(drift) = result {
            assert!(drift.l2_magnitude > 0.0);
            assert!(drift.top_dimensions.len() <= 3);
        } else {
            panic!("expected Drift result");
        }
    }

    // Analogy output is B(t3) + (A(t2) - A(t1)) and must stay finite.
    #[test]
    fn analogy_computes_displacement() {
        let engine = setup_engine(3, 20, 4);
        let result = engine
            .execute(TemporalQuery::Analogy {
                entity_a: 0,
                t1: 0,
                t2: 10_000_000,
                entity_b: 1,
                t3: 5_000_000,
            })
            .unwrap();

        if let QueryResult::Analogy(vec) = result {
            assert_eq!(vec.len(), 4);
            for &v in &vec {
                assert!(v.is_finite());
            }
        } else {
            panic!("expected Analogy result");
        }
    }

    // All five entities drift identically, so the cohort should show
    // consistent drift and a positive convergence score.
    #[test]
    fn cohort_drift_via_engine() {
        let engine = setup_engine(5, 20, 4);
        let result = engine
            .execute(TemporalQuery::CohortDrift {
                entity_ids: vec![0, 1, 2, 3, 4],
                t1: 0,
                t2: 19_000_000,
                top_n: 3,
            })
            .unwrap();

        if let QueryResult::CohortDrift(report) = result {
            assert_eq!(report.n_entities, 5);
            assert!(report.mean_drift_l2 > 0.0);
            assert!(report.top_dimensions.len() <= 3);
            assert!(
                report.convergence_score > 0.0,
                "entities with similar drift patterns should show some convergence"
            );
        } else {
            panic!("expected CohortDrift result");
        }
    }

    // A single-entity cohort is below the two-entity minimum.
    #[test]
    fn cohort_drift_insufficient_entities() {
        let engine = setup_engine(1, 10, 4);
        let result = engine.execute(TemporalQuery::CohortDrift {
            entity_ids: vec![0],
            t1: 0,
            t2: 9_000_000,
            top_n: 3,
        });
        assert!(result.is_err());
    }

    // Analogy with a nonexistent entity A must surface an error.
    #[test]
    fn analogy_unknown_entity() {
        let engine = setup_engine(1, 10, 4);
        let result = engine.execute(TemporalQuery::Analogy {
            entity_a: 999,
            t1: 0,
            t2: 1000,
            entity_b: 0,
            t3: 0,
        });
        assert!(result.is_err());
    }
}