cvx_index/hnsw/
temporal_graph.rs

1//! Temporal Graph Index — HNSW extended with temporal successor edges (RFC-010).
2//!
3//! Composites `TemporalHnsw` with `TemporalEdgeLayer` to enable:
4//! 1. **Causal search**: find semantic neighbors, then walk temporal edges for context
5//! 2. **Hybrid search**: beam search that explores both semantic AND temporal edges
6
7use std::cmp::Reverse;
8use std::collections::{BinaryHeap, HashSet};
9use std::path::Path;
10
11use cvx_core::traits::DistanceMetric;
12use cvx_core::types::TemporalFilter;
13
14use super::HnswConfig;
15use super::temporal::TemporalHnsw;
16use super::temporal_edges::TemporalEdgeLayer;
17use super::typed_edges::{EdgeType, TypedEdgeStore};
18
19// ─── Types ──────────────────────────────────────────────────────────
20
/// One semantic hit plus the temporal neighbourhood of that hit.
///
/// Produced by `TemporalGraphIndex::causal_search`, which pairs every node id
/// it walks to with that node's timestamp.
#[derive(Debug, Clone)]
pub struct CausalSearchResult {
    /// Node returned by the semantic search.
    pub node_id: u32,
    /// Score reported by the search for `node_id`.
    pub score: f32,
    /// Entity that `node_id` belongs to.
    pub entity_id: u64,
    /// `(node_id, timestamp)` pairs found by walking forward from the hit:
    /// what happened NEXT to this entity.
    pub successors: Vec<(u32, i64)>,
    /// `(node_id, timestamp)` pairs found by walking backward from the hit:
    /// what happened BEFORE.
    pub predecessors: Vec<(u32, i64)>,
}
35
36// ─── TemporalGraphIndex ─────────────────────────────────────────────
37
38/// Temporal Graph Index: HNSW with temporal successor/predecessor edges.
39///
40/// Wraps `TemporalHnsw` (untouched) with a `TemporalEdgeLayer` for
41/// causal navigation and hybrid search.
42pub struct TemporalGraphIndex<D: DistanceMetric> {
43    /// The underlying spatiotemporal HNSW index.
44    inner: TemporalHnsw<D>,
45    /// Temporal edge layer (successor/predecessor per entity).
46    edges: TemporalEdgeLayer,
47    /// Typed relational edges (RFC-013 Part B).
48    typed_edges: TypedEdgeStore,
49}
50
51impl<D: DistanceMetric + Clone> TemporalGraphIndex<D> {
52    /// Create a new empty temporal graph index.
53    pub fn new(config: HnswConfig, metric: D) -> Self {
54        Self {
55            inner: TemporalHnsw::new(config, metric),
56            edges: TemporalEdgeLayer::new(),
57            typed_edges: TypedEdgeStore::new(),
58        }
59    }
60
61    /// Create from an existing TemporalHnsw (migration path).
62    ///
63    /// Rebuilds the temporal edge layer from the entity_index.
64    pub fn from_temporal_hnsw(inner: TemporalHnsw<D>) -> Self {
65        let mut edges = TemporalEdgeLayer::with_capacity(inner.len());
66
67        // We need to register all nodes in order.
68        // The entity_index has (timestamp, node_id) sorted by timestamp per entity.
69        // But we must register in node_id order (0, 1, 2, ...).
70
71        // Build a mapping: node_id → its predecessor in the entity chain
72        let mut pred_map: Vec<Option<u32>> = vec![None; inner.len()];
73
74        for nid in 0..inner.len() as u32 {
75            let eid = inner.entity_id(nid);
76            // Find the previous node for this entity (node with closest earlier timestamp)
77            let traj = inner.trajectory(eid, TemporalFilter::All);
78            let my_ts = inner.timestamp(nid);
79
80            let prev = traj
81                .iter()
82                .filter(|&&(ts, id)| ts < my_ts || (ts == my_ts && id < nid))
83                .max_by_key(|&&(ts, _)| ts)
84                .map(|&(_, id)| id);
85
86            pred_map[nid as usize] = prev;
87        }
88
89        for nid in 0..inner.len() as u32 {
90            edges.register(nid, pred_map[nid as usize]);
91        }
92
93        Self {
94            inner,
95            edges,
96            typed_edges: TypedEdgeStore::new(),
97        }
98    }
99
100    /// Insert a temporal point.
101    pub fn insert(&mut self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
102        let last_node = self.inner.entity_last_node(entity_id);
103        let node_id = self.inner.insert(entity_id, timestamp, vector);
104        self.edges.register(node_id, last_node);
105        node_id
106    }
107
108    /// Insert a temporal point with an outcome reward.
109    pub fn insert_with_reward(
110        &mut self,
111        entity_id: u64,
112        timestamp: i64,
113        vector: &[f32],
114        reward: f32,
115    ) -> u32 {
116        let last_node = self.inner.entity_last_node(entity_id);
117        let node_id = self
118            .inner
119            .insert_with_reward(entity_id, timestamp, vector, reward);
120        self.edges.register(node_id, last_node);
121        node_id
122    }
123
124    /// Standard search (delegates to inner TemporalHnsw).
125    pub fn search(
126        &self,
127        query: &[f32],
128        k: usize,
129        filter: TemporalFilter,
130        alpha: f32,
131        query_timestamp: i64,
132    ) -> Vec<(u32, f32)> {
133        self.inner.search(query, k, filter, alpha, query_timestamp)
134    }
135
136    /// Causal search: semantic search + temporal edge context.
137    ///
138    /// Phase 1: Standard HNSW search.
139    /// Phase 2: For each result, walk temporal edges to get what happened
140    /// before and after.
141    ///
142    /// Answers: "Find similar entities, and show me what happened to them next."
143    pub fn causal_search(
144        &self,
145        query: &[f32],
146        k: usize,
147        filter: TemporalFilter,
148        alpha: f32,
149        query_timestamp: i64,
150        temporal_context: usize,
151    ) -> Vec<CausalSearchResult> {
152        let results = self.inner.search(query, k, filter, alpha, query_timestamp);
153
154        results
155            .into_iter()
156            .map(|(node_id, score)| {
157                let entity_id = self.inner.entity_id(node_id);
158
159                let succ_ids = self.edges.walk_forward(node_id, temporal_context);
160                let successors: Vec<(u32, i64)> = succ_ids
161                    .into_iter()
162                    .map(|nid| (nid, self.inner.timestamp(nid)))
163                    .collect();
164
165                let pred_ids = self.edges.walk_backward(node_id, temporal_context);
166                let predecessors: Vec<(u32, i64)> = pred_ids
167                    .into_iter()
168                    .map(|nid| (nid, self.inner.timestamp(nid)))
169                    .collect();
170
171                CausalSearchResult {
172                    node_id,
173                    score,
174                    entity_id,
175                    successors,
176                    predecessors,
177                }
178            })
179            .collect()
180    }
181
182    /// Hybrid search: beam search exploring both semantic AND temporal edges.
183    ///
184    /// At each step of the beam search on level 0, when visiting a node,
185    /// also adds its temporal neighbors to the candidate set with a
186    /// distance penalty controlled by `beta`.
187    ///
188    /// - `beta = 0.0`: pure semantic HNSW (ignores temporal edges)
189    /// - `beta = 1.0`: always follow temporal edges (aggressive temporal exploration)
190    pub fn hybrid_search(
191        &self,
192        query: &[f32],
193        k: usize,
194        filter: TemporalFilter,
195        alpha: f32,
196        beta: f32,
197        query_timestamp: i64,
198    ) -> Vec<(u32, f32)> {
199        let graph = self.inner.graph();
200
201        if graph.is_empty() {
202            return Vec::new();
203        }
204
205        let entry = match graph.entry_point() {
206            Some(ep) => ep,
207            None => return Vec::new(),
208        };
209
210        let bitmap = self.inner.build_filter_bitmap(&filter);
211        let ef = graph.config().ef_search.max(k);
212
213        // Phase 1: greedy descent from entry point to level 0
214        let max_level = graph.max_level();
215        let mut current = entry;
216        let mut current_dist = graph.distance_to(current, query);
217
218        for level in (1..=max_level).rev() {
219            let mut improved = true;
220            while improved {
221                improved = false;
222                for &neighbor in graph.neighbors_at_level(current, level) {
223                    let d = graph.distance_to(neighbor, query);
224                    if d < current_dist {
225                        current = neighbor;
226                        current_dist = d;
227                        improved = true;
228                    }
229                }
230            }
231        }
232
233        // Phase 2: hybrid beam search on level 0
234        // candidates: min-heap (closest first to explore)
235        // results: max-heap (farthest first to evict)
236        let mut candidates: BinaryHeap<Reverse<(OrderedF32, u32)>> = BinaryHeap::new();
237        let mut results: BinaryHeap<(OrderedF32, u32)> = BinaryHeap::new();
238        let mut visited: HashSet<u32> = HashSet::new();
239
240        let entry_dist = graph.distance_to(current, query);
241        candidates.push(Reverse((OrderedF32(entry_dist), current)));
242        if bitmap.contains(current) {
243            results.push((OrderedF32(entry_dist), current));
244        }
245        visited.insert(current);
246
247        while let Some(Reverse((OrderedF32(c_dist), c_id))) = candidates.pop() {
248            let farthest_dist = results
249                .peek()
250                .map(|(OrderedF32(d), _)| *d)
251                .unwrap_or(f32::MAX);
252            if c_dist > farthest_dist && results.len() >= ef {
253                break;
254            }
255
256            // Explore semantic neighbors
257            let semantic_neighbors = graph.neighbors_at_level(c_id, 0);
258
259            // Explore temporal neighbors (weighted by beta)
260            let temporal_neighbors: Vec<u32> = if beta > 0.0 {
261                self.edges.temporal_neighbors(c_id).collect()
262            } else {
263                Vec::new()
264            };
265
266            // Process all neighbors
267            for &neighbor in semantic_neighbors.iter().chain(temporal_neighbors.iter()) {
268                if !visited.insert(neighbor) {
269                    continue;
270                }
271
272                // Skip if not in temporal filter
273                if !bitmap.contains(neighbor) {
274                    continue;
275                }
276
277                let mut dist = graph.distance_to(neighbor, query);
278
279                // Apply temporal component if alpha < 1.0
280                if alpha < 1.0 {
281                    let t_dist = self.inner.temporal_distance_normalized(
282                        self.inner.timestamp(neighbor),
283                        query_timestamp,
284                    );
285                    dist = alpha * dist + (1.0 - alpha) * t_dist;
286                }
287
288                let farthest = results
289                    .peek()
290                    .map(|(OrderedF32(d), _)| *d)
291                    .unwrap_or(f32::MAX);
292                if dist < farthest || results.len() < ef {
293                    candidates.push(Reverse((OrderedF32(dist), neighbor)));
294                    results.push((OrderedF32(dist), neighbor));
295                    if results.len() > ef {
296                        results.pop();
297                    }
298                }
299            }
300        }
301
302        // Collect and sort by distance
303        let mut final_results: Vec<(u32, f32)> = results
304            .into_iter()
305            .map(|(OrderedF32(d), nid)| (nid, d))
306            .collect();
307        final_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
308        final_results.truncate(k);
309        final_results
310    }
311
312    // ─── Accessors ──────────────────────────────────────────────
313
314    /// Access the underlying TemporalHnsw.
315    pub fn inner(&self) -> &TemporalHnsw<D> {
316        &self.inner
317    }
318
319    /// Access the temporal edge layer.
320    pub fn edges(&self) -> &TemporalEdgeLayer {
321        &self.edges
322    }
323
324    /// Get trajectory for an entity.
325    pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
326        self.inner.trajectory(entity_id, filter)
327    }
328
329    /// Get vector by node ID.
330    pub fn vector(&self, node_id: u32) -> &[f32] {
331        self.inner.vector(node_id)
332    }
333
334    /// Get entity ID by node ID.
335    pub fn entity_id(&self, node_id: u32) -> u64 {
336        self.inner.entity_id(node_id)
337    }
338
339    /// Get timestamp by node ID.
340    pub fn timestamp(&self, node_id: u32) -> i64 {
341        self.inner.timestamp(node_id)
342    }
343
344    /// Total number of points.
345    pub fn len(&self) -> usize {
346        self.inner.len()
347    }
348
349    /// Whether empty.
350    pub fn is_empty(&self) -> bool {
351        self.inner.is_empty()
352    }
353
354    // ─── Delegated configuration ──────────────────────────────────
355
356    /// Mutable access to the underlying TemporalHnsw.
357    pub fn inner_mut(&mut self) -> &mut TemporalHnsw<D> {
358        &mut self.inner
359    }
360
361    /// Get HNSW config.
362    pub fn config(&self) -> &super::HnswConfig {
363        self.inner.config()
364    }
365
366    /// Set ef_construction at runtime.
367    pub fn set_ef_construction(&mut self, ef: usize) {
368        self.inner.set_ef_construction(ef);
369    }
370
371    /// Set ef_search at runtime.
372    pub fn set_ef_search(&mut self, ef: usize) {
373        self.inner.set_ef_search(ef);
374    }
375
376    /// Enable scalar quantization.
377    pub fn enable_scalar_quantization(&mut self, min_val: f32, max_val: f32) {
378        self.inner.enable_scalar_quantization(min_val, max_val);
379    }
380
381    /// Disable scalar quantization.
382    pub fn disable_scalar_quantization(&mut self) {
383        self.inner.disable_scalar_quantization();
384    }
385
386    // ─── Delegated recency search (RFC-012 P7+P8) ──────────────────
387
388    /// Search with recency bias and normalized distances.
389    #[allow(clippy::too_many_arguments)]
390    pub fn search_with_recency(
391        &self,
392        query: &[f32],
393        k: usize,
394        filter: TemporalFilter,
395        alpha: f32,
396        query_timestamp: i64,
397        recency_lambda: f32,
398        recency_weight: f32,
399    ) -> Vec<(u32, f32)> {
400        self.inner.search_with_recency(
401            query,
402            k,
403            filter,
404            alpha,
405            query_timestamp,
406            recency_lambda,
407            recency_weight,
408        )
409    }
410
411    // ─── Delegated outcome / reward (RFC-012 P4) ───────────────────
412
413    /// Get the reward for a node.
414    pub fn reward(&self, node_id: u32) -> f32 {
415        self.inner.reward(node_id)
416    }
417
418    /// Set the reward for a node retroactively.
419    pub fn set_reward(&mut self, node_id: u32, reward: f32) {
420        self.inner.set_reward(node_id, reward);
421    }
422
423    /// Search with reward pre-filtering.
424    pub fn search_with_reward(
425        &self,
426        query: &[f32],
427        k: usize,
428        filter: TemporalFilter,
429        alpha: f32,
430        query_timestamp: i64,
431        min_reward: f32,
432    ) -> Vec<(u32, f32)> {
433        self.inner
434            .search_with_reward(query, k, filter, alpha, query_timestamp, min_reward)
435    }
436
437    // ─── Delegated centering (RFC-012 Part B) ─────────────────────
438
439    /// Compute the centroid of all vectors.
440    pub fn compute_centroid(&self) -> Option<Vec<f32>> {
441        self.inner.compute_centroid()
442    }
443
444    /// Set centroid for anisotropy correction.
445    pub fn set_centroid(&mut self, centroid: Vec<f32>) {
446        self.inner.set_centroid(centroid);
447    }
448
449    /// Clear centroid.
450    pub fn clear_centroid(&mut self) {
451        self.inner.clear_centroid();
452    }
453
454    /// Get current centroid.
455    pub fn centroid(&self) -> Option<&[f32]> {
456        self.inner.centroid()
457    }
458
459    /// Center a vector by subtracting the centroid.
460    pub fn centered_vector(&self, vec: &[f32]) -> Vec<f32> {
461        self.inner.centered_vector(vec)
462    }
463
464    // ─── Delegated region operations ──────────────────────────────
465
466    /// Get semantic regions at a given HNSW level.
467    pub fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
468        self.inner.regions(level)
469    }
470
471    /// O(N) single-pass region assignments.
472    pub fn region_assignments(
473        &self,
474        level: usize,
475        filter: TemporalFilter,
476    ) -> std::collections::HashMap<u32, Vec<(u64, i64)>> {
477        self.inner.region_assignments(level, filter)
478    }
479
480    /// Smoothed region distribution trajectory for an entity.
481    pub fn region_trajectory(
482        &self,
483        entity_id: u64,
484        level: usize,
485        window_days: i64,
486        alpha: f32,
487    ) -> Vec<(i64, Vec<f32>)> {
488        self.inner
489            .region_trajectory(entity_id, level, window_days, alpha)
490    }
491
492    // ─── Typed edges (RFC-013 Part B) ───────────────────────────
493
494    /// Access the typed edge store.
495    pub fn typed_edges(&self) -> &TypedEdgeStore {
496        &self.typed_edges
497    }
498
499    /// Mutable access to the typed edge store.
500    pub fn typed_edges_mut(&mut self) -> &mut TypedEdgeStore {
501        &mut self.typed_edges
502    }
503
504    /// Add a typed edge between two nodes.
505    pub fn add_typed_edge(&mut self, source: u32, target: u32, edge_type: EdgeType, weight: f32) {
506        self.typed_edges.add_edge(source, target, edge_type, weight);
507    }
508
509    /// Get the success score of a node based on typed edges.
510    ///
511    /// Uses Beta prior: P(success) = (1 + n_success) / (2 + n_total).
512    pub fn success_score(&self, node_id: u32) -> f32 {
513        self.typed_edges.success_score(node_id)
514    }
515
516    /// Save to directory (index + temporal edges + typed edges).
517    pub fn save(&self, dir: &Path) -> std::io::Result<()> {
518        std::fs::create_dir_all(dir)?;
519        self.inner.save(&dir.join("index.bin"))?;
520        let edge_bytes = postcard::to_allocvec(&self.edges).map_err(std::io::Error::other)?;
521        std::fs::write(dir.join("temporal_edges.bin"), edge_bytes)?;
522        let typed_bytes =
523            postcard::to_allocvec(&self.typed_edges).map_err(std::io::Error::other)?;
524        std::fs::write(dir.join("typed_edges.bin"), typed_bytes)?;
525        Ok(())
526    }
527
528    /// Load from directory.
529    pub fn load(dir: &Path, metric: D) -> std::io::Result<Self> {
530        let inner = TemporalHnsw::load(&dir.join("index.bin"), metric)?;
531        let edge_bytes = std::fs::read(dir.join("temporal_edges.bin"))?;
532        let edges: TemporalEdgeLayer = postcard::from_bytes(&edge_bytes)
533            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
534        // Typed edges are optional (backward compat)
535        let typed_edges = if dir.join("typed_edges.bin").exists() {
536            let typed_bytes = std::fs::read(dir.join("typed_edges.bin"))?;
537            postcard::from_bytes(&typed_bytes)
538                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?
539        } else {
540            TypedEdgeStore::new()
541        };
542        Ok(Self {
543            inner,
544            edges,
545            typed_edges,
546        })
547    }
548
549    // ─── Scored search (RFC-013 Part D — wiring A+B+C) ────────────
550
551    /// Search with Bayesian multi-factor scoring.
552    ///
553    /// Pipeline:
554    /// 1. HNSW over-fetch 4k candidates
555    /// 2. Compute features per candidate (similarity, recency, reward, success_score, region)
556    /// 3. Bayesian rerank using `weights`
557    /// 4. Return top-k
558    ///
559    /// This integrates typed edges (success_score), reward annotations,
560    /// recency, and region membership into a single scored retrieval.
561    #[allow(clippy::too_many_arguments)]
562    pub fn scored_search(
563        &self,
564        query: &[f32],
565        k: usize,
566        filter: TemporalFilter,
567        query_timestamp: i64,
568        weights: &super::bayesian_scorer::ScoringWeights,
569        query_region: Option<u32>,
570    ) -> Vec<(u32, f32)> {
571        use super::bayesian_scorer::{CandidateFeatures, rerank};
572
573        if self.inner.is_empty() {
574            return Vec::new();
575        }
576
577        // Phase 1: HNSW over-fetch
578        let over_fetch = k * 4;
579        let candidates = self
580            .inner
581            .search(query, over_fetch, filter, 1.0, query_timestamp);
582
583        // Phase 2: Compute features
584        let features: Vec<CandidateFeatures> = candidates
585            .iter()
586            .map(|&(node_id, raw_distance)| {
587                let ts = self.inner.timestamp(node_id);
588                let sem_norm = self.inner.normalize_semantic_distance(raw_distance);
589                let recency = self.inner.recency_penalty(ts, 1.0);
590                let reward = self.inner.reward(node_id);
591                let success = self.typed_edges.success_score(node_id);
592
593                let region_match = query_region
594                    .map(|qr| {
595                        // Check if candidate is in same region
596                        let candidate_vec = self.inner.vector(node_id);
597                        self.inner
598                            .graph()
599                            .assign_region(candidate_vec, 1)
600                            .map(|cr| cr == qr)
601                            .unwrap_or(false)
602                    })
603                    .unwrap_or(false);
604
605                CandidateFeatures {
606                    node_id,
607                    raw_distance,
608                    similarity: sem_norm,
609                    recency,
610                    reward,
611                    success_score: success,
612                    region_match,
613                }
614            })
615            .collect();
616
617        // Phase 3: Bayesian rerank
618        rerank(&features, weights, k)
619    }
620
621    /// Assign a vector to a region at a given HNSW level.
622    pub fn assign_region(&self, vector: &[f32], level: usize) -> Option<u32> {
623        self.inner.graph().assign_region(vector, level)
624    }
625
626    // ─── Trajectory search (RFC-014 Opción 3) ─────────────────────
627
628    /// Search for episodes with similar trajectory SHAPE (not just similar state).
629    ///
630    /// Uses path signatures to compare the agent's recent trajectory against
631    /// stored episodes. Returns episodes whose trajectory shape is most similar,
632    /// regardless of where in embedding space they are.
633    ///
634    /// This enables "lateral thinking": finding episodes that MOVED similarly,
635    /// not just episodes that WERE in a similar position.
636    ///
637    /// Returns `(entity_id, signature_distance, trajectory_length)`.
638    /// Search for episodes with similar trajectory SHAPE (not just similar state).
639    ///
640    /// Uses path signatures to compare the agent's recent trajectory against
641    /// stored episodes. Returns episodes whose trajectory shape is most similar.
642    ///
643    /// This enables "lateral thinking": finding episodes that MOVED similarly,
644    /// not just episodes that WERE in a similar position.
645    ///
646    /// Returns `(entity_id, signature_distance, trajectory_length)`.
647    pub fn trajectory_search(
648        &self,
649        recent_trajectory: &[(i64, &[f32])],
650        k: usize,
651        signature_depth: usize,
652    ) -> Vec<(u64, f64, usize)> {
653        use cvx_analytics::signatures::{SignatureConfig, compute_signature, signature_distance};
654
655        if recent_trajectory.len() < 2 {
656            return Vec::new();
657        }
658
659        let config = SignatureConfig {
660            depth: signature_depth,
661            time_augmentation: false,
662        };
663
664        // Compute signature of query trajectory
665        let query_sig = match compute_signature(recent_trajectory, &config) {
666            Ok(sig) => sig,
667            Err(_) => return Vec::new(),
668        };
669
670        // Get all unique entity IDs
671        let mut entity_ids: Vec<u64> = Vec::new();
672        let mut seen = std::collections::HashSet::new();
673        for i in 0..self.inner.len() {
674            let eid = self.inner.entity_id(i as u32);
675            if seen.insert(eid) {
676                entity_ids.push(eid);
677            }
678        }
679
680        // Compute signature distance for each entity's trajectory
681        let mut scored: Vec<(u64, f64, usize)> = entity_ids
682            .iter()
683            .filter_map(|&eid| {
684                let traj = self.inner.trajectory(eid, TemporalFilter::All);
685                if traj.len() < 2 {
686                    return None;
687                }
688                let ep_traj: Vec<(i64, Vec<f32>)> = traj
689                    .iter()
690                    .map(|&(ts, nid)| {
691                        let v = self.inner.vector(nid);
692                        (ts, v.to_vec())
693                    })
694                    .collect();
695                // Convert to borrowed slices for compute_signature
696                let ep_refs: Vec<(i64, &[f32])> =
697                    ep_traj.iter().map(|(ts, v)| (*ts, v.as_slice())).collect();
698                let ep_sig = compute_signature(&ep_refs, &config).ok()?;
699                let dist = signature_distance(&query_sig, &ep_sig);
700                Some((eid, dist, ep_traj.len()))
701            })
702            .collect();
703
704        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
705        scored.truncate(k);
706        scored
707    }
708}
709
710// ─── TemporalIndexAccess ────────────────────────────────────────────
711
712impl<D: DistanceMetric + Clone> cvx_core::TemporalIndexAccess for TemporalGraphIndex<D> {
713    fn search_raw(
714        &self,
715        query: &[f32],
716        k: usize,
717        filter: TemporalFilter,
718        alpha: f32,
719        query_timestamp: i64,
720    ) -> Vec<(u32, f32)> {
721        // Use hybrid search with moderate beta
722        self.hybrid_search(query, k, filter, alpha, 0.3, query_timestamp)
723    }
724
725    fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
726        self.inner.trajectory(entity_id, filter)
727    }
728
729    fn vector(&self, node_id: u32) -> Vec<f32> {
730        self.inner.vector(node_id).to_vec()
731    }
732
733    fn entity_id(&self, node_id: u32) -> u64 {
734        self.inner.entity_id(node_id)
735    }
736
737    fn timestamp(&self, node_id: u32) -> i64 {
738        self.inner.timestamp(node_id)
739    }
740
741    fn len(&self) -> usize {
742        self.inner.len()
743    }
744}
745
746// ─── Ordered float helper ───────────────────────────────────────────
747
/// `f32` wrapper with a genuine total order, usable as a `BinaryHeap` key.
///
/// Ordering and equality both follow `f32::total_cmp`, so `Eq`, `PartialOrd`
/// and `Ord` agree with each other even when a NaN is present. A positive NaN
/// sorts after every other value, i.e. it is treated as the farthest distance.
#[derive(Debug, Clone, Copy)]
struct OrderedF32(f32);

impl PartialEq for OrderedF32 {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that equality can never disagree with ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedF32 {}

impl PartialOrd for OrderedF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
766
767// ─── Tests ──────────────────────────────────────────────────────────
768
769#[cfg(test)]
770mod tests {
771    use super::*;
772    use crate::metrics::L2Distance;
773
774    fn setup_index(
775        n_entities: u64,
776        points_per_entity: usize,
777        dim: usize,
778    ) -> TemporalGraphIndex<L2Distance> {
779        let config = HnswConfig {
780            m: 16,
781            ef_construction: 100,
782            ef_search: 50,
783            ..Default::default()
784        };
785        let mut index = TemporalGraphIndex::new(config, L2Distance);
786
787        for e in 0..n_entities {
788            for i in 0..points_per_entity {
789                let ts = (i as i64) * 1_000_000;
790                let v: Vec<f32> = (0..dim)
791                    .map(|d| (e as f32 * 10.0) + (i as f32 * 0.1) + (d as f32 * 0.01))
792                    .collect();
793                index.insert(e, ts, &v);
794            }
795        }
796
797        index
798    }
799
800    // ─── Basic insert + edges ───────────────────────────────────
801
802    #[test]
803    fn insert_creates_temporal_edges() {
804        let index = setup_index(1, 5, 3);
805
806        assert_eq!(index.len(), 5);
807        assert_eq!(index.edges().len(), 5);
808
809        // Chain: 0 → 1 → 2 → 3 → 4
810        assert_eq!(index.edges().successor(0), Some(1));
811        assert_eq!(index.edges().successor(3), Some(4));
812        assert_eq!(index.edges().predecessor(4), Some(3));
813        assert_eq!(index.edges().successor(4), None);
814    }
815
816    #[test]
817    fn multi_entity_edges_isolated() {
818        let index = setup_index(3, 5, 3);
819
820        // Entity 0: nodes 0-4, Entity 1: nodes 5-9, Entity 2: nodes 10-14
821        // (assuming sequential insert order)
822        for i in 0..4u32 {
823            let succ = index.edges().successor(i);
824            assert!(succ.is_some());
825            // Successor should be same entity
826            let succ_entity = index.entity_id(succ.unwrap());
827            let my_entity = index.entity_id(i);
828            assert_eq!(
829                succ_entity, my_entity,
830                "edge from node {i} crosses entities"
831            );
832        }
833    }
834
835    // ─── Causal search ──────────────────────────────────────────
836
837    #[test]
838    fn causal_search_returns_context() {
839        let index = setup_index(3, 10, 4);
840
841        let results = index.causal_search(
842            &[0.5, 0.05, 0.005, 0.001],
843            3,
844            TemporalFilter::All,
845            1.0,
846            5_000_000,
847            3, // 3 steps of temporal context
848        );
849
850        assert_eq!(results.len(), 3);
851
852        for r in &results {
853            // Each result should have temporal context
854            // (unless it's at the very end of its entity's timeline)
855            assert!(
856                !r.successors.is_empty() || !r.predecessors.is_empty(),
857                "node {} should have some temporal context",
858                r.node_id
859            );
860
861            // Verify successors are temporally ordered
862            for w in r.successors.windows(2) {
863                assert!(w[0].1 <= w[1].1, "successors should be time-ordered");
864            }
865        }
866    }
867
868    #[test]
869    fn causal_search_successors_same_entity() {
870        let index = setup_index(5, 10, 3);
871
872        let results = index.causal_search(&[0.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0, 5);
873
874        for r in &results {
875            for &(succ_id, _) in &r.successors {
876                assert_eq!(
877                    index.entity_id(succ_id),
878                    r.entity_id,
879                    "successor should be same entity"
880                );
881            }
882            for &(pred_id, _) in &r.predecessors {
883                assert_eq!(
884                    index.entity_id(pred_id),
885                    r.entity_id,
886                    "predecessor should be same entity"
887                );
888            }
889        }
890    }
891
892    // ─── Hybrid search ──────────────────────────────────────────
893
/// With beta = 0 the hybrid search is expected to agree with the standard search on the
/// number of results and on the best match.
#[test]
fn hybrid_search_beta_zero_matches_standard() {
    let query: [f32; 4] = [5.0, 0.05, 0.005, 0.001];
    let index = setup_index(5, 20, 4);

    let standard = index.search(&query, 10, TemporalFilter::All, 1.0, 0);
    let hybrid = index.hybrid_search(&query, 10, TemporalFilter::All, 1.0, 0.0, 0);

    // Lower-ranked results may differ slightly because the two beam searches are
    // implemented differently, so only the count and the top hit are compared.
    assert_eq!(standard.len(), hybrid.len());

    let (standard_top, hybrid_top) = (standard[0].0, hybrid[0].0);
    assert_eq!(
        standard_top, hybrid_top,
        "top result should match between standard and hybrid (beta=0)"
    );
}
912
/// Hybrid search with temporal edge exploration enabled should return a non-empty,
/// bounded list of valid node ids with finite, non-negative scores.
#[test]
fn hybrid_search_with_temporal_edges() {
    let query: [f32; 4] = [0.5, 0.05, 0.005, 0.001];
    let index = setup_index(3, 20, 4);

    // The 0.5 argument requests moderate temporal edge exploration.
    let hits = index.hybrid_search(&query, 10, TemporalFilter::All, 1.0, 0.5, 5_000_000);

    assert!(!hits.is_empty());
    assert!(hits.len() <= 10);

    // Each returned id must address an existing node and carry a sane score.
    let node_count = index.len();
    for (nid, score) in hits.iter().copied() {
        assert!((nid as usize) < node_count);
        assert!(score >= 0.0);
        assert!(score.is_finite());
    }
}
937
/// Results of a range-filtered hybrid search must all carry timestamps inside the range.
#[test]
fn hybrid_search_respects_temporal_filter() {
    // Inclusive bounds shared by the filter and by the check below.
    let (lo, hi) = (5_000_000, 15_000_000);

    let index = setup_index(3, 20, 4);
    let query: [f32; 4] = [1.0, 0.1, 0.01, 0.001];

    let hits = index.hybrid_search(
        &query,
        10,
        TemporalFilter::Range(lo, hi),
        1.0,
        0.5,
        10_000_000,
    );

    for ts in hits.iter().map(|&(nid, _)| index.timestamp(nid)) {
        assert!((lo..=hi).contains(&ts), "ts {ts} outside filter range");
    }
}
960
961    // ─── TemporalIndexAccess trait ──────────────────────────────
962
/// `search_raw` must be callable through a `TemporalIndexAccess` trait object and return
/// the requested number of results.
#[test]
fn trait_search_works() {
    let concrete = setup_index(3, 10, 4);
    let as_trait: &dyn cvx_core::TemporalIndexAccess = &concrete;

    let zero_query = [0.0; 4];
    let hit_count = as_trait
        .search_raw(&zero_query, 5, TemporalFilter::All, 1.0, 0)
        .len();

    assert_eq!(hit_count, 5);
}
971
/// `trajectory` must be callable through a `TemporalIndexAccess` trait object; entity 0
/// is expected to yield 10 points for this fixture.
#[test]
fn trait_trajectory_works() {
    let concrete = setup_index(3, 10, 4);
    let as_trait: &dyn cvx_core::TemporalIndexAccess = &concrete;

    let point_count = as_trait.trajectory(0, TemporalFilter::All).len();

    assert_eq!(point_count, 10);
}
980
981    // ─── from_temporal_hnsw migration ───────────────────────────
982
/// Building a graph index from an existing `TemporalHnsw` must keep all nodes and produce
/// successor links that never connect two different entities.
#[test]
fn from_temporal_hnsw_preserves_edges() {
    let mut hnsw = TemporalHnsw::new(HnswConfig::default(), L2Distance);

    // Ten points assigned round-robin to entity ids 0, 1 and 2.
    (0..10u64).for_each(|i| {
        hnsw.insert(i % 3, i as i64 * 1000, &[i as f32, 0.0]);
    });

    let graph_index = TemporalGraphIndex::from_temporal_hnsw(hnsw);
    assert_eq!(graph_index.len(), 10);

    let edges = graph_index.edges();
    assert_eq!(edges.len(), 10);

    // Only nodes that actually have a successor are checked.
    let linked = (0..10u32).filter_map(|nid| edges.successor(nid).map(|succ| (nid, succ)));
    for (nid, succ) in linked {
        assert_eq!(
            graph_index.entity_id(succ),
            graph_index.entity_id(nid),
            "edge from {nid} crosses entities after migration"
        );
    }
}
1008
1009    // ─── Save/Load ──────────────────────────────────────────────
1010
/// Saving an index to a temporary directory and loading it back must preserve the node
/// count, the edge-layer length and every node's successor.
#[test]
fn save_load_roundtrip() {
    let original = setup_index(3, 10, 3);

    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path();
    original.save(path).unwrap();

    let loaded = TemporalGraphIndex::load(path, L2Distance).unwrap();
    assert_eq!(loaded.len(), 30);

    let restored_edges = loaded.edges();
    assert_eq!(restored_edges.len(), 30);

    // Compare successor links node by node against the in-memory original.
    let expected_edges = original.edges();
    (0..30u32).for_each(|nid| {
        assert_eq!(
            restored_edges.successor(nid),
            expected_edges.successor(nid),
            "successor mismatch at node {nid}"
        );
    });
}
1032
1033    // ─── Edge cases ─────────────────────────────────────────────
1034
/// Searching a freshly created index must return no results instead of panicking.
#[test]
fn empty_index() {
    let index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);
    assert!(index.is_empty());

    let origin = [0.0; 3];

    let hybrid = index.hybrid_search(&origin, 5, TemporalFilter::All, 1.0, 0.5, 0);
    assert!(hybrid.is_empty());

    let causal = index.causal_search(&origin, 5, TemporalFilter::All, 1.0, 0, 3);
    assert!(causal.is_empty());
}
1047
/// A lone point is found by causal search but has no temporal neighbours at all.
#[test]
fn single_point() {
    let point = [1.0, 2.0, 3.0];

    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);
    index.insert(1, 1000, &point);

    let causal = index.causal_search(&point, 1, TemporalFilter::All, 1.0, 0, 5);
    assert_eq!(causal.len(), 1);

    let only_hit = &causal[0];
    assert!(only_hit.successors.is_empty());
    assert!(only_hit.predecessors.is_empty());
}
1059
1060    // ─── Delegated method coverage ───────────────────────────────
1061
/// The index must expose the configuration it was built with and reflect later changes
/// made through `set_ef_construction` / `set_ef_search`.
#[test]
fn config_and_ef_delegation() {
    let mut index = TemporalGraphIndex::new(
        HnswConfig {
            m: 8,
            ef_construction: 100,
            ef_search: 50,
            ..HnswConfig::default()
        },
        L2Distance,
    );

    // Scoped so the shared borrow ends before the setters are called.
    {
        let initial = index.config();
        assert_eq!(initial.m, 8);
        assert_eq!(initial.ef_construction, 100);
    }

    index.set_ef_construction(150);
    let ef_construction = index.config().ef_construction;
    assert_eq!(ef_construction, 150);

    index.set_ef_search(200);
    let ef_search = index.config().ef_search;
    assert_eq!(ef_search, 200);
}
1080
/// Exercises the centroid API end to end: compute, set, read back, centre a vector, clear.
#[test]
fn centering_delegation() {
    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);
    index.insert(1, 1000, &[2.0, 4.0]);
    index.insert(2, 2000, &[4.0, 6.0]);

    // First component of the computed centroid: mean of 2.0 and 4.0.
    let computed = index.compute_centroid().unwrap();
    let first_component_error = (computed[0] - 3.0).abs();
    assert!(first_component_error < 1e-6);

    index.set_centroid(vec![3.0, 5.0]);
    assert_eq!(index.centroid().unwrap(), &[3.0, 5.0]);

    // [5, 8] minus the centroid [3, 5] should be [2, 3].
    let shifted = index.centered_vector(&[5.0, 8.0]);
    let x_error = (shifted[0] - 2.0).abs();
    assert!(x_error < 1e-6);
    let y_error = (shifted[1] - 3.0).abs();
    assert!(y_error < 1e-6);

    index.clear_centroid();
    assert!(index.centroid().is_none());
}
1101
/// A plain insert reports a NaN reward, `insert_with_reward` stores the given reward, and
/// `set_reward` overwrites it afterwards.
#[test]
fn reward_delegation() {
    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);

    let unrewarded = index.insert(1, 1000, &[1.0, 0.0]);
    let rewarded = index.insert_with_reward(2, 2000, &[0.0, 1.0], 0.8);

    assert!(index.reward(unrewarded).is_nan());

    let stored_error = (index.reward(rewarded) - 0.8).abs();
    assert!(stored_error < 1e-6);

    index.set_reward(unrewarded, 0.95);
    let updated_error = (index.reward(unrewarded) - 0.95).abs();
    assert!(updated_error < 1e-6);
}
1115
/// `search_with_reward` with a 0.5 threshold must only return nodes whose stored reward is
/// at least 0.5.
#[test]
fn search_with_reward_delegation() {
    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);

    // Ten single-point entities with rewards 0.0, 0.1, ..., 0.9.
    (0..10u64).for_each(|i| {
        index.insert_with_reward(i, i as i64 * 1000, &[i as f32, 0.0, 0.0], i as f32 * 0.1);
    });

    let query = [7.0, 0.0, 0.0];
    let hits = index.search_with_reward(&query, 5, TemporalFilter::All, 1.0, 0, 0.5);
    assert!(!hits.is_empty());

    for node_id in hits.iter().map(|&(node_id, _)| node_id) {
        let reward = index.reward(node_id);
        assert!(reward >= 0.5, "node {node_id} reward {} < 0.5", reward);
    }
}
1135
/// Region queries on a populated index: `regions(1)` is non-empty and
/// `region_assignments(1, ..)` accounts for all 200 inserted points.
#[test]
fn region_delegation() {
    let mut index = TemporalGraphIndex::new(
        HnswConfig {
            m: 4,
            ef_construction: 50,
            ef_search: 50,
            ..HnswConfig::default()
        },
        L2Distance,
    );

    // 200 random 8-dimensional points spread over entity ids 0..4.
    let mut rng = rand::rng();
    for i in 0..200u64 {
        let point: Vec<f32> = std::iter::repeat_with(|| rand::Rng::random::<f32>(&mut rng))
            .take(8)
            .collect();
        index.insert(i % 4, i as i64 * 1000, &point);
    }

    assert!(!index.regions(1).is_empty());

    // Summing the member lists of every region must give back the full point count.
    let assigned = index
        .region_assignments(1, TemporalFilter::All)
        .values()
        .fold(0usize, |acc, members| acc + members.len());
    assert_eq!(assigned, 200);
}
1158
/// Search must keep returning both points while scalar quantization is enabled and after
/// it has been disabled again.
#[test]
fn scalar_quantization_delegation() {
    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);
    index.insert(1, 1000, &[1.0, 0.0]);
    index.insert(2, 2000, &[0.0, 1.0]);

    let query = [1.0, 0.0];

    index.enable_scalar_quantization(-1.0, 1.0);
    let quantized_hits = index.search(&query, 2, TemporalFilter::All, 1.0, 0).len();
    assert_eq!(quantized_hits, 2);

    index.disable_scalar_quantization();
    let plain_hits = index.search(&query, 2, TemporalFilter::All, 1.0, 0).len();
    assert_eq!(plain_hits, 2);
}
1174
/// Points added through `insert_with_reward` must be linked by temporal edges just like
/// plain inserts, so causal search can return continuations.
#[test]
fn insert_with_reward_creates_temporal_edges() {
    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);

    // Five consecutive steps of entity 1, each with its own reward.
    (0..5u32).for_each(|step| {
        index.insert_with_reward(
            1,
            step as i64 * 100,
            &[step as f32, 0.0],
            step as f32 * 0.2,
        );
    });

    // Node 0 should link forward and node 4 should link backward.
    let layer = index.edges();
    assert!(layer.successor(0).is_some());
    assert!(layer.predecessor(4).is_some());

    // The hit for this query is expected to carry successors.
    let hits = index.causal_search(&[0.0, 0.0], 1, TemporalFilter::All, 1.0, 0, 3);
    assert_eq!(hits.len(), 1);

    let first_hit = &hits[0];
    assert!(!first_hit.successors.is_empty());
}
1195
1196    // ─── Scored search (RFC-013 Part D) ──────────────────────────
1197
/// `scored_search` with default weights returns exactly the requested number of results
/// on a 20-point index.
#[test]
fn scored_search_basic() {
    use crate::hnsw::ScoringWeights;

    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);
    (0..20u64).for_each(|i| {
        index.insert(i % 3, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
    });

    let query = [10.0, 0.0, 0.0];
    let default_weights = ScoringWeights::default();
    let hit_count = index
        .scored_search(&query, 5, TemporalFilter::All, 0, &default_weights, None)
        .len();

    assert_eq!(hit_count, 5);
}
1214
/// Of two nearly identical vectors, the one stored with the higher reward should be ranked
/// first when the reward weight is non-zero.
// NOTE(review): the query equals the high-reward node's vector, so that node is also the
// closest by distance — confirm whether this test isolates the reward term.
#[test]
fn scored_search_reward_boosts() {
    use crate::hnsw::ScoringWeights;

    let mut index = TemporalGraphIndex::new(HnswConfig::default(), L2Distance);

    // First insert carries reward 0.9, second insert carries reward 0.1.
    index.insert_with_reward(1, 1000, &[1.0, 0.0, 0.0], 0.9);
    index.insert_with_reward(2, 2000, &[1.01, 0.0, 0.0], 0.1);

    // Reward contributes alongside similarity; other weights stay at their defaults.
    let reward_aware = ScoringWeights {
        similarity: 1.0,
        reward: 0.5,
        ..ScoringWeights::default()
    };

    let query = [1.0, 0.0, 0.0];
    let ranked = index.scored_search(&query, 2, TemporalFilter::All, 0, &reward_aware, None);

    assert_eq!(ranked.len(), 2);

    let top_node = ranked[0].0;
    assert_eq!(top_node, 0, "high-reward node should rank first");
}
1239
/// A node with `CausedSuccess` typed edges should outrank a near-identical node with
/// `CausedFailure` edges when the `success` weight is non-zero.
// NOTE(review): the query equals `n0`'s vector, so `n0` is also the nearest node by
// distance — confirm whether this test isolates the success term.
#[test]
fn scored_search_success_score_from_typed_edges() {
    use crate::hnsw::{EdgeType, ScoringWeights};

    let config = HnswConfig::default();
    let mut index = TemporalGraphIndex::new(config, L2Distance);

    let n0 = index.insert(1, 1000, &[1.0, 0.0]);
    let n1 = index.insert(2, 2000, &[1.01, 0.0]);

    // Node 0 has success edges, node 1 has failure edges
    index.add_typed_edge(n0, 10, EdgeType::CausedSuccess, 1.0);
    index.add_typed_edge(n0, 11, EdgeType::CausedSuccess, 1.0);
    index.add_typed_edge(n1, 12, EdgeType::CausedFailure, 1.0);
    index.add_typed_edge(n1, 13, EdgeType::CausedFailure, 1.0);

    // Success contributes alongside similarity; other weights stay at their defaults.
    let weights = ScoringWeights {
        similarity: 1.0,
        success: 0.5,
        ..ScoringWeights::default()
    };

    let results = index.scored_search(&[1.0, 0.0], 2, TemporalFilter::All, 0, &weights, None);

    // Check the length before indexing, as `scored_search_reward_boosts` does, so a missing
    // result fails with a clear message rather than an index-out-of-bounds panic.
    assert_eq!(results.len(), 2, "both inserted nodes should be returned");

    // Node 0 (high success) should rank higher than node 1 (high failure)
    assert_eq!(results[0].0, n0);
}
1267
/// `assign_region` should find a region for a random query on a populated index.
#[test]
fn assign_region_works() {
    let mut index = TemporalGraphIndex::new(
        HnswConfig {
            m: 4,
            ef_construction: 50,
            ef_search: 50,
            ..HnswConfig::default()
        },
        L2Distance,
    );

    // 200 random 8-dimensional points spread over entity ids 0..4.
    let mut rng = rand::rng();
    for i in 0..200u64 {
        let point: Vec<f32> = std::iter::repeat_with(|| rand::Rng::random::<f32>(&mut rng))
            .take(8)
            .collect();
        index.insert(i % 4, (i * 100) as i64, &point);
    }

    // The query is drawn from the same generator, after all points have been inserted.
    let query: Vec<f32> = std::iter::repeat_with(|| rand::Rng::random::<f32>(&mut rng))
        .take(8)
        .collect();

    assert!(index.assign_region(&query, 1).is_some());
}
1288}