1use std::cmp::Reverse;
8use std::collections::{BinaryHeap, HashSet};
9use std::path::Path;
10
11use cvx_core::traits::DistanceMetric;
12use cvx_core::types::TemporalFilter;
13
14use super::HnswConfig;
15use super::temporal::TemporalHnsw;
16use super::temporal_edges::TemporalEdgeLayer;
17use super::typed_edges::{EdgeType, TypedEdgeStore};
18
/// One hit produced by `TemporalGraphIndex::causal_search`, together with its temporal context.
///
/// `successors` and `predecessors` hold `(node_id, timestamp)` pairs for the nodes reached by
/// walking the temporal edge layer forward and backward from `node_id`.
///
/// `PartialEq` is derived so results can be compared directly (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
pub struct CausalSearchResult {
    /// Identifier of the matched node.
    pub node_id: u32,
    /// Score reported by the underlying search for this node.
    pub score: f32,
    /// Entity that owns the matched node.
    pub entity_id: u64,
    /// `(node_id, timestamp)` pairs found by walking forward from `node_id`.
    pub successors: Vec<(u32, i64)>,
    /// `(node_id, timestamp)` pairs found by walking backward from `node_id`.
    pub predecessors: Vec<(u32, i64)>,
}
35
/// Temporal HNSW index combined with a per-node temporal edge layer and a typed-edge store.
pub struct TemporalGraphIndex<D: DistanceMetric> {
    /// Wrapped temporal HNSW index; most methods of this type delegate to it.
    inner: TemporalHnsw<D>,
    /// Temporal predecessor/successor links; one `register` call is made per node.
    edges: TemporalEdgeLayer,
    /// Typed edges (e.g. `EdgeType::CausedSuccess`) added through `add_typed_edge`.
    typed_edges: TypedEdgeStore,
}
50
51impl<D: DistanceMetric + Clone> TemporalGraphIndex<D> {
52 pub fn new(config: HnswConfig, metric: D) -> Self {
54 Self {
55 inner: TemporalHnsw::new(config, metric),
56 edges: TemporalEdgeLayer::new(),
57 typed_edges: TypedEdgeStore::new(),
58 }
59 }
60
61 pub fn from_temporal_hnsw(inner: TemporalHnsw<D>) -> Self {
65 let mut edges = TemporalEdgeLayer::with_capacity(inner.len());
66
67 let mut pred_map: Vec<Option<u32>> = vec![None; inner.len()];
73
74 for nid in 0..inner.len() as u32 {
75 let eid = inner.entity_id(nid);
76 let traj = inner.trajectory(eid, TemporalFilter::All);
78 let my_ts = inner.timestamp(nid);
79
80 let prev = traj
81 .iter()
82 .filter(|&&(ts, id)| ts < my_ts || (ts == my_ts && id < nid))
83 .max_by_key(|&&(ts, _)| ts)
84 .map(|&(_, id)| id);
85
86 pred_map[nid as usize] = prev;
87 }
88
89 for nid in 0..inner.len() as u32 {
90 edges.register(nid, pred_map[nid as usize]);
91 }
92
93 Self {
94 inner,
95 edges,
96 typed_edges: TypedEdgeStore::new(),
97 }
98 }
99
100 pub fn insert(&mut self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
102 let last_node = self.inner.entity_last_node(entity_id);
103 let node_id = self.inner.insert(entity_id, timestamp, vector);
104 self.edges.register(node_id, last_node);
105 node_id
106 }
107
108 pub fn insert_with_reward(
110 &mut self,
111 entity_id: u64,
112 timestamp: i64,
113 vector: &[f32],
114 reward: f32,
115 ) -> u32 {
116 let last_node = self.inner.entity_last_node(entity_id);
117 let node_id = self
118 .inner
119 .insert_with_reward(entity_id, timestamp, vector, reward);
120 self.edges.register(node_id, last_node);
121 node_id
122 }
123
    /// Runs the inner index's `search` with the given arguments and returns its
    /// `(node_id, score)` pairs unchanged. The temporal edge layer is not consulted.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        filter: TemporalFilter,
        alpha: f32,
        query_timestamp: i64,
    ) -> Vec<(u32, f32)> {
        self.inner.search(query, k, filter, alpha, query_timestamp)
    }
135
136 pub fn causal_search(
144 &self,
145 query: &[f32],
146 k: usize,
147 filter: TemporalFilter,
148 alpha: f32,
149 query_timestamp: i64,
150 temporal_context: usize,
151 ) -> Vec<CausalSearchResult> {
152 let results = self.inner.search(query, k, filter, alpha, query_timestamp);
153
154 results
155 .into_iter()
156 .map(|(node_id, score)| {
157 let entity_id = self.inner.entity_id(node_id);
158
159 let succ_ids = self.edges.walk_forward(node_id, temporal_context);
160 let successors: Vec<(u32, i64)> = succ_ids
161 .into_iter()
162 .map(|nid| (nid, self.inner.timestamp(nid)))
163 .collect();
164
165 let pred_ids = self.edges.walk_backward(node_id, temporal_context);
166 let predecessors: Vec<(u32, i64)> = pred_ids
167 .into_iter()
168 .map(|nid| (nid, self.inner.timestamp(nid)))
169 .collect();
170
171 CausalSearchResult {
172 node_id,
173 score,
174 entity_id,
175 successors,
176 predecessors,
177 }
178 })
179 .collect()
180 }
181
182 pub fn hybrid_search(
191 &self,
192 query: &[f32],
193 k: usize,
194 filter: TemporalFilter,
195 alpha: f32,
196 beta: f32,
197 query_timestamp: i64,
198 ) -> Vec<(u32, f32)> {
199 let graph = self.inner.graph();
200
201 if graph.is_empty() {
202 return Vec::new();
203 }
204
205 let entry = match graph.entry_point() {
206 Some(ep) => ep,
207 None => return Vec::new(),
208 };
209
210 let bitmap = self.inner.build_filter_bitmap(&filter);
211 let ef = graph.config().ef_search.max(k);
212
213 let max_level = graph.max_level();
215 let mut current = entry;
216 let mut current_dist = graph.distance_to(current, query);
217
218 for level in (1..=max_level).rev() {
219 let mut improved = true;
220 while improved {
221 improved = false;
222 for &neighbor in graph.neighbors_at_level(current, level) {
223 let d = graph.distance_to(neighbor, query);
224 if d < current_dist {
225 current = neighbor;
226 current_dist = d;
227 improved = true;
228 }
229 }
230 }
231 }
232
233 let mut candidates: BinaryHeap<Reverse<(OrderedF32, u32)>> = BinaryHeap::new();
237 let mut results: BinaryHeap<(OrderedF32, u32)> = BinaryHeap::new();
238 let mut visited: HashSet<u32> = HashSet::new();
239
240 let entry_dist = graph.distance_to(current, query);
241 candidates.push(Reverse((OrderedF32(entry_dist), current)));
242 if bitmap.contains(current) {
243 results.push((OrderedF32(entry_dist), current));
244 }
245 visited.insert(current);
246
247 while let Some(Reverse((OrderedF32(c_dist), c_id))) = candidates.pop() {
248 let farthest_dist = results
249 .peek()
250 .map(|(OrderedF32(d), _)| *d)
251 .unwrap_or(f32::MAX);
252 if c_dist > farthest_dist && results.len() >= ef {
253 break;
254 }
255
256 let semantic_neighbors = graph.neighbors_at_level(c_id, 0);
258
259 let temporal_neighbors: Vec<u32> = if beta > 0.0 {
261 self.edges.temporal_neighbors(c_id).collect()
262 } else {
263 Vec::new()
264 };
265
266 for &neighbor in semantic_neighbors.iter().chain(temporal_neighbors.iter()) {
268 if !visited.insert(neighbor) {
269 continue;
270 }
271
272 if !bitmap.contains(neighbor) {
274 continue;
275 }
276
277 let mut dist = graph.distance_to(neighbor, query);
278
279 if alpha < 1.0 {
281 let t_dist = self.inner.temporal_distance_normalized(
282 self.inner.timestamp(neighbor),
283 query_timestamp,
284 );
285 dist = alpha * dist + (1.0 - alpha) * t_dist;
286 }
287
288 let farthest = results
289 .peek()
290 .map(|(OrderedF32(d), _)| *d)
291 .unwrap_or(f32::MAX);
292 if dist < farthest || results.len() < ef {
293 candidates.push(Reverse((OrderedF32(dist), neighbor)));
294 results.push((OrderedF32(dist), neighbor));
295 if results.len() > ef {
296 results.pop();
297 }
298 }
299 }
300 }
301
302 let mut final_results: Vec<(u32, f32)> = results
304 .into_iter()
305 .map(|(OrderedF32(d), nid)| (nid, d))
306 .collect();
307 final_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
308 final_results.truncate(k);
309 final_results
310 }
311
    /// Returns a shared reference to the wrapped [`TemporalHnsw`].
    pub fn inner(&self) -> &TemporalHnsw<D> {
        &self.inner
    }
318
    /// Returns a shared reference to the temporal edge layer.
    pub fn edges(&self) -> &TemporalEdgeLayer {
        &self.edges
    }
323
    /// Returns the inner index's trajectory for `entity_id` under `filter`, as
    /// `(timestamp, node_id)` pairs.
    pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
        self.inner.trajectory(entity_id, filter)
    }
328
    /// Returns the vector the inner index holds for `node_id`.
    pub fn vector(&self, node_id: u32) -> &[f32] {
        self.inner.vector(node_id)
    }
333
    /// Returns the entity id the inner index associates with `node_id`.
    pub fn entity_id(&self, node_id: u32) -> u64 {
        self.inner.entity_id(node_id)
    }
338
    /// Returns the timestamp the inner index associates with `node_id`.
    pub fn timestamp(&self, node_id: u32) -> i64 {
        self.inner.timestamp(node_id)
    }
343
    /// Returns the inner index's `len()`.
    pub fn len(&self) -> usize {
        self.inner.len()
    }
348
    /// Returns the inner index's `is_empty()`.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
353
    /// Returns a mutable reference to the wrapped [`TemporalHnsw`].
    ///
    /// Inserting through this handle bypasses the `edges.register` call made by
    /// [`Self::insert`] and [`Self::insert_with_reward`], so no temporal edge entry is
    /// registered for such nodes.
    pub fn inner_mut(&mut self) -> &mut TemporalHnsw<D> {
        &mut self.inner
    }
360
    /// Returns the inner index's configuration.
    pub fn config(&self) -> &super::HnswConfig {
        self.inner.config()
    }
365
    /// Forwards `ef` to the inner index's `set_ef_construction`.
    pub fn set_ef_construction(&mut self, ef: usize) {
        self.inner.set_ef_construction(ef);
    }
370
    /// Forwards `ef` to the inner index's `set_ef_search`.
    pub fn set_ef_search(&mut self, ef: usize) {
        self.inner.set_ef_search(ef);
    }
375
    /// Forwards `min_val` and `max_val` to the inner index's `enable_scalar_quantization`.
    pub fn enable_scalar_quantization(&mut self, min_val: f32, max_val: f32) {
        self.inner.enable_scalar_quantization(min_val, max_val);
    }
380
    /// Calls the inner index's `disable_scalar_quantization`.
    pub fn disable_scalar_quantization(&mut self) {
        self.inner.disable_scalar_quantization();
    }
385
    /// Runs the inner index's `search_with_recency` with the given arguments and returns
    /// its `(node_id, score)` pairs unchanged. The temporal edge layer is not consulted.
    #[allow(clippy::too_many_arguments)]
    pub fn search_with_recency(
        &self,
        query: &[f32],
        k: usize,
        filter: TemporalFilter,
        alpha: f32,
        query_timestamp: i64,
        recency_lambda: f32,
        recency_weight: f32,
    ) -> Vec<(u32, f32)> {
        self.inner.search_with_recency(
            query,
            k,
            filter,
            alpha,
            query_timestamp,
            recency_lambda,
            recency_weight,
        )
    }
410
    /// Returns the reward the inner index stores for `node_id`.
    ///
    /// The tests in this file expect NaN for a node inserted without a reward.
    pub fn reward(&self, node_id: u32) -> f32 {
        self.inner.reward(node_id)
    }
417
    /// Forwards `node_id` and `reward` to the inner index's `set_reward`.
    pub fn set_reward(&mut self, node_id: u32, reward: f32) {
        self.inner.set_reward(node_id, reward);
    }
422
    /// Runs the inner index's `search_with_reward` with the given arguments and returns
    /// its `(node_id, score)` pairs unchanged.
    ///
    /// The tests in this file expect every returned node to have a reward of at least
    /// `min_reward`.
    pub fn search_with_reward(
        &self,
        query: &[f32],
        k: usize,
        filter: TemporalFilter,
        alpha: f32,
        query_timestamp: i64,
        min_reward: f32,
    ) -> Vec<(u32, f32)> {
        self.inner
            .search_with_reward(query, k, filter, alpha, query_timestamp, min_reward)
    }
436
    /// Returns the result of the inner index's `compute_centroid`.
    pub fn compute_centroid(&self) -> Option<Vec<f32>> {
        self.inner.compute_centroid()
    }
443
    /// Forwards `centroid` to the inner index's `set_centroid`.
    pub fn set_centroid(&mut self, centroid: Vec<f32>) {
        self.inner.set_centroid(centroid);
    }
448
    /// Calls the inner index's `clear_centroid`.
    pub fn clear_centroid(&mut self) {
        self.inner.clear_centroid();
    }
453
    /// Returns the centroid currently held by the inner index, if any.
    pub fn centroid(&self) -> Option<&[f32]> {
        self.inner.centroid()
    }
458
    /// Returns the result of the inner index's `centered_vector` for `vec`.
    ///
    /// The tests in this file expect `vec` minus the centroid set via `set_centroid`.
    pub fn centered_vector(&self, vec: &[f32]) -> Vec<f32> {
        self.inner.centered_vector(vec)
    }
463
    /// Returns the result of the inner index's `regions` for `level`.
    // NOTE(review): the tuple presumably means (region id, region vector, member count);
    // confirm against `TemporalHnsw::regions`.
    pub fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
        self.inner.regions(level)
    }
470
    /// Returns the result of the inner index's `region_assignments` for `level` and `filter`.
    // NOTE(review): values are presumably (entity_id, timestamp) pairs keyed by region id;
    // confirm against `TemporalHnsw::region_assignments`.
    pub fn region_assignments(
        &self,
        level: usize,
        filter: TemporalFilter,
    ) -> std::collections::HashMap<u32, Vec<(u64, i64)>> {
        self.inner.region_assignments(level, filter)
    }
479
    /// Returns the result of the inner index's `region_trajectory` with the given arguments.
    // NOTE(review): the meaning of `window_days` and `alpha` is defined by
    // `TemporalHnsw::region_trajectory`; confirm there.
    pub fn region_trajectory(
        &self,
        entity_id: u64,
        level: usize,
        window_days: i64,
        alpha: f32,
    ) -> Vec<(i64, Vec<f32>)> {
        self.inner
            .region_trajectory(entity_id, level, window_days, alpha)
    }
491
    /// Returns a shared reference to the typed-edge store.
    pub fn typed_edges(&self) -> &TypedEdgeStore {
        &self.typed_edges
    }
498
    /// Returns a mutable reference to the typed-edge store.
    pub fn typed_edges_mut(&mut self) -> &mut TypedEdgeStore {
        &mut self.typed_edges
    }
503
    /// Adds an edge to the typed-edge store via `TypedEdgeStore::add_edge`.
    ///
    /// `source` and `target` are not checked against the index here; the tests in this
    /// file pass targets that were never inserted.
    pub fn add_typed_edge(&mut self, source: u32, target: u32, edge_type: EdgeType, weight: f32) {
        self.typed_edges.add_edge(source, target, edge_type, weight);
    }
508
    /// Returns the typed-edge store's `success_score` for `node_id`.
    pub fn success_score(&self, node_id: u32) -> f32 {
        self.typed_edges.success_score(node_id)
    }
515
516 pub fn save(&self, dir: &Path) -> std::io::Result<()> {
518 std::fs::create_dir_all(dir)?;
519 self.inner.save(&dir.join("index.bin"))?;
520 let edge_bytes = postcard::to_allocvec(&self.edges).map_err(std::io::Error::other)?;
521 std::fs::write(dir.join("temporal_edges.bin"), edge_bytes)?;
522 let typed_bytes =
523 postcard::to_allocvec(&self.typed_edges).map_err(std::io::Error::other)?;
524 std::fs::write(dir.join("typed_edges.bin"), typed_bytes)?;
525 Ok(())
526 }
527
528 pub fn load(dir: &Path, metric: D) -> std::io::Result<Self> {
530 let inner = TemporalHnsw::load(&dir.join("index.bin"), metric)?;
531 let edge_bytes = std::fs::read(dir.join("temporal_edges.bin"))?;
532 let edges: TemporalEdgeLayer = postcard::from_bytes(&edge_bytes)
533 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
534 let typed_edges = if dir.join("typed_edges.bin").exists() {
536 let typed_bytes = std::fs::read(dir.join("typed_edges.bin"))?;
537 postcard::from_bytes(&typed_bytes)
538 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?
539 } else {
540 TypedEdgeStore::new()
541 };
542 Ok(Self {
543 inner,
544 edges,
545 typed_edges,
546 })
547 }
548
549 #[allow(clippy::too_many_arguments)]
562 pub fn scored_search(
563 &self,
564 query: &[f32],
565 k: usize,
566 filter: TemporalFilter,
567 query_timestamp: i64,
568 weights: &super::bayesian_scorer::ScoringWeights,
569 query_region: Option<u32>,
570 ) -> Vec<(u32, f32)> {
571 use super::bayesian_scorer::{CandidateFeatures, rerank};
572
573 if self.inner.is_empty() {
574 return Vec::new();
575 }
576
577 let over_fetch = k * 4;
579 let candidates = self
580 .inner
581 .search(query, over_fetch, filter, 1.0, query_timestamp);
582
583 let features: Vec<CandidateFeatures> = candidates
585 .iter()
586 .map(|&(node_id, raw_distance)| {
587 let ts = self.inner.timestamp(node_id);
588 let sem_norm = self.inner.normalize_semantic_distance(raw_distance);
589 let recency = self.inner.recency_penalty(ts, 1.0);
590 let reward = self.inner.reward(node_id);
591 let success = self.typed_edges.success_score(node_id);
592
593 let region_match = query_region
594 .map(|qr| {
595 let candidate_vec = self.inner.vector(node_id);
597 self.inner
598 .graph()
599 .assign_region(candidate_vec, 1)
600 .map(|cr| cr == qr)
601 .unwrap_or(false)
602 })
603 .unwrap_or(false);
604
605 CandidateFeatures {
606 node_id,
607 raw_distance,
608 similarity: sem_norm,
609 recency,
610 reward,
611 success_score: success,
612 region_match,
613 }
614 })
615 .collect();
616
617 rerank(&features, weights, k)
619 }
620
    /// Returns the result of the inner graph's `assign_region` for `vector` at `level`.
    pub fn assign_region(&self, vector: &[f32], level: usize) -> Option<u32> {
        self.inner.graph().assign_region(vector, level)
    }
625
626 pub fn trajectory_search(
648 &self,
649 recent_trajectory: &[(i64, &[f32])],
650 k: usize,
651 signature_depth: usize,
652 ) -> Vec<(u64, f64, usize)> {
653 use cvx_analytics::signatures::{SignatureConfig, compute_signature, signature_distance};
654
655 if recent_trajectory.len() < 2 {
656 return Vec::new();
657 }
658
659 let config = SignatureConfig {
660 depth: signature_depth,
661 time_augmentation: false,
662 };
663
664 let query_sig = match compute_signature(recent_trajectory, &config) {
666 Ok(sig) => sig,
667 Err(_) => return Vec::new(),
668 };
669
670 let mut entity_ids: Vec<u64> = Vec::new();
672 let mut seen = std::collections::HashSet::new();
673 for i in 0..self.inner.len() {
674 let eid = self.inner.entity_id(i as u32);
675 if seen.insert(eid) {
676 entity_ids.push(eid);
677 }
678 }
679
680 let mut scored: Vec<(u64, f64, usize)> = entity_ids
682 .iter()
683 .filter_map(|&eid| {
684 let traj = self.inner.trajectory(eid, TemporalFilter::All);
685 if traj.len() < 2 {
686 return None;
687 }
688 let ep_traj: Vec<(i64, Vec<f32>)> = traj
689 .iter()
690 .map(|&(ts, nid)| {
691 let v = self.inner.vector(nid);
692 (ts, v.to_vec())
693 })
694 .collect();
695 let ep_refs: Vec<(i64, &[f32])> =
697 ep_traj.iter().map(|(ts, v)| (*ts, v.as_slice())).collect();
698 let ep_sig = compute_signature(&ep_refs, &config).ok()?;
699 let dist = signature_distance(&query_sig, &ep_sig);
700 Some((eid, dist, ep_traj.len()))
701 })
702 .collect();
703
704 scored.sort_by(|a, b| a.1.total_cmp(&b.1));
705 scored.truncate(k);
706 scored
707 }
708}
709
710impl<D: DistanceMetric + Clone> cvx_core::TemporalIndexAccess for TemporalGraphIndex<D> {
713 fn search_raw(
714 &self,
715 query: &[f32],
716 k: usize,
717 filter: TemporalFilter,
718 alpha: f32,
719 query_timestamp: i64,
720 ) -> Vec<(u32, f32)> {
721 self.hybrid_search(query, k, filter, alpha, 0.3, query_timestamp)
723 }
724
725 fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
726 self.inner.trajectory(entity_id, filter)
727 }
728
729 fn vector(&self, node_id: u32) -> Vec<f32> {
730 self.inner.vector(node_id).to_vec()
731 }
732
733 fn entity_id(&self, node_id: u32) -> u64 {
734 self.inner.entity_id(node_id)
735 }
736
737 fn timestamp(&self, node_id: u32) -> i64 {
738 self.inner.timestamp(node_id)
739 }
740
741 fn len(&self) -> usize {
742 self.inner.len()
743 }
744}
745
/// `f32` wrapper with a genuine total order, so scores can be stored in a `BinaryHeap`.
///
/// Ordering and equality both follow [`f32::total_cmp`]: NaN compares equal to itself and
/// a positive NaN sorts after every other value, including `+inf`. This keeps the `Ord`
/// and `Eq` contracts intact even when a distance computation yields NaN.
#[derive(Debug, Clone, Copy)]
struct OrderedF32(f32);

impl PartialEq for OrderedF32 {
    // Defined through `cmp` so that `==` and `Ord` can never disagree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedF32 {}

impl PartialOrd for OrderedF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
766
767#[cfg(test)]
770mod tests {
771 use super::*;
772 use crate::metrics::L2Distance;
773
774 fn setup_index(
775 n_entities: u64,
776 points_per_entity: usize,
777 dim: usize,
778 ) -> TemporalGraphIndex<L2Distance> {
779 let config = HnswConfig {
780 m: 16,
781 ef_construction: 100,
782 ef_search: 50,
783 ..Default::default()
784 };
785 let mut index = TemporalGraphIndex::new(config, L2Distance);
786
787 for e in 0..n_entities {
788 for i in 0..points_per_entity {
789 let ts = (i as i64) * 1_000_000;
790 let v: Vec<f32> = (0..dim)
791 .map(|d| (e as f32 * 10.0) + (i as f32 * 0.1) + (d as f32 * 0.01))
792 .collect();
793 index.insert(e, ts, &v);
794 }
795 }
796
797 index
798 }
799
800 #[test]
803 fn insert_creates_temporal_edges() {
804 let index = setup_index(1, 5, 3);
805
806 assert_eq!(index.len(), 5);
807 assert_eq!(index.edges().len(), 5);
808
809 assert_eq!(index.edges().successor(0), Some(1));
811 assert_eq!(index.edges().successor(3), Some(4));
812 assert_eq!(index.edges().predecessor(4), Some(3));
813 assert_eq!(index.edges().successor(4), None);
814 }
815
816 #[test]
817 fn multi_entity_edges_isolated() {
818 let index = setup_index(3, 5, 3);
819
820 for i in 0..4u32 {
823 let succ = index.edges().successor(i);
824 assert!(succ.is_some());
825 let succ_entity = index.entity_id(succ.unwrap());
827 let my_entity = index.entity_id(i);
828 assert_eq!(
829 succ_entity, my_entity,
830 "edge from node {i} crosses entities"
831 );
832 }
833 }
834
835 #[test]
838 fn causal_search_returns_context() {
839 let index = setup_index(3, 10, 4);
840
841 let results = index.causal_search(
842 &[0.5, 0.05, 0.005, 0.001],
843 3,
844 TemporalFilter::All,
845 1.0,
846 5_000_000,
847 3, );
849
850 assert_eq!(results.len(), 3);
851
852 for r in &results {
853 assert!(
856 !r.successors.is_empty() || !r.predecessors.is_empty(),
857 "node {} should have some temporal context",
858 r.node_id
859 );
860
861 for w in r.successors.windows(2) {
863 assert!(w[0].1 <= w[1].1, "successors should be time-ordered");
864 }
865 }
866 }
867
868 #[test]
869 fn causal_search_successors_same_entity() {
870 let index = setup_index(5, 10, 3);
871
872 let results = index.causal_search(&[0.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0, 5);
873
874 for r in &results {
875 for &(succ_id, _) in &r.successors {
876 assert_eq!(
877 index.entity_id(succ_id),
878 r.entity_id,
879 "successor should be same entity"
880 );
881 }
882 for &(pred_id, _) in &r.predecessors {
883 assert_eq!(
884 index.entity_id(pred_id),
885 r.entity_id,
886 "predecessor should be same entity"
887 );
888 }
889 }
890 }
891
892 #[test]
895 fn hybrid_search_beta_zero_matches_standard() {
896 let index = setup_index(5, 20, 4);
897 let query = [5.0f32, 0.05, 0.005, 0.001];
898
899 let standard = index.search(&query, 10, TemporalFilter::All, 1.0, 0);
900 let hybrid = index.hybrid_search(&query, 10, TemporalFilter::All, 1.0, 0.0, 0);
901
902 assert_eq!(standard.len(), hybrid.len());
905
906 assert_eq!(
908 standard[0].0, hybrid[0].0,
909 "top result should match between standard and hybrid (beta=0)"
910 );
911 }
912
913 #[test]
914 fn hybrid_search_with_temporal_edges() {
915 let index = setup_index(3, 20, 4);
916 let query = [0.5f32, 0.05, 0.005, 0.001];
917
918 let results = index.hybrid_search(
919 &query,
920 10,
921 TemporalFilter::All,
922 1.0,
923 0.5, 5_000_000,
925 );
926
927 assert!(!results.is_empty());
928 assert!(results.len() <= 10);
929
930 for &(nid, score) in &results {
932 assert!((nid as usize) < index.len());
933 assert!(score >= 0.0);
934 assert!(score.is_finite());
935 }
936 }
937
938 #[test]
939 fn hybrid_search_respects_temporal_filter() {
940 let index = setup_index(3, 20, 4);
941 let query = [1.0f32, 0.1, 0.01, 0.001];
942
943 let results = index.hybrid_search(
944 &query,
945 10,
946 TemporalFilter::Range(5_000_000, 15_000_000),
947 1.0,
948 0.5,
949 10_000_000,
950 );
951
952 for &(nid, _) in &results {
953 let ts = index.timestamp(nid);
954 assert!(
955 (5_000_000..=15_000_000).contains(&ts),
956 "ts {ts} outside filter range"
957 );
958 }
959 }
960
961 #[test]
964 fn trait_search_works() {
965 let index = setup_index(3, 10, 4);
966 let trait_ref: &dyn cvx_core::TemporalIndexAccess = &index;
967
968 let results = trait_ref.search_raw(&[0.0; 4], 5, TemporalFilter::All, 1.0, 0);
969 assert_eq!(results.len(), 5);
970 }
971
972 #[test]
973 fn trait_trajectory_works() {
974 let index = setup_index(3, 10, 4);
975 let trait_ref: &dyn cvx_core::TemporalIndexAccess = &index;
976
977 let traj = trait_ref.trajectory(0, TemporalFilter::All);
978 assert_eq!(traj.len(), 10);
979 }
980
981 #[test]
984 fn from_temporal_hnsw_preserves_edges() {
985 let config = HnswConfig::default();
986 let mut hnsw = TemporalHnsw::new(config, L2Distance);
987
988 for i in 0..10u64 {
989 hnsw.insert(i % 3, i as i64 * 1000, &[i as f32, 0.0]);
990 }
991
992 let graph_index = TemporalGraphIndex::from_temporal_hnsw(hnsw);
993
994 assert_eq!(graph_index.len(), 10);
995 assert_eq!(graph_index.edges().len(), 10);
996
997 for nid in 0..10u32 {
999 if let Some(succ) = graph_index.edges().successor(nid) {
1000 assert_eq!(
1001 graph_index.entity_id(succ),
1002 graph_index.entity_id(nid),
1003 "edge from {nid} crosses entities after migration"
1004 );
1005 }
1006 }
1007 }
1008
1009 #[test]
1012 fn save_load_roundtrip() {
1013 let index = setup_index(3, 10, 3);
1014
1015 let dir = tempfile::tempdir().unwrap();
1016 index.save(dir.path()).unwrap();
1017
1018 let loaded = TemporalGraphIndex::load(dir.path(), L2Distance).unwrap();
1019
1020 assert_eq!(loaded.len(), 30);
1021 assert_eq!(loaded.edges().len(), 30);
1022
1023 for nid in 0..30u32 {
1025 assert_eq!(
1026 loaded.edges().successor(nid),
1027 index.edges().successor(nid),
1028 "successor mismatch at node {nid}"
1029 );
1030 }
1031 }
1032
1033 #[test]
1036 fn empty_index() {
1037 let config = HnswConfig::default();
1038 let index = TemporalGraphIndex::new(config, L2Distance);
1039
1040 assert!(index.is_empty());
1041 let results = index.hybrid_search(&[0.0; 3], 5, TemporalFilter::All, 1.0, 0.5, 0);
1042 assert!(results.is_empty());
1043
1044 let causal = index.causal_search(&[0.0; 3], 5, TemporalFilter::All, 1.0, 0, 3);
1045 assert!(causal.is_empty());
1046 }
1047
1048 #[test]
1049 fn single_point() {
1050 let config = HnswConfig::default();
1051 let mut index = TemporalGraphIndex::new(config, L2Distance);
1052 index.insert(1, 1000, &[1.0, 2.0, 3.0]);
1053
1054 let causal = index.causal_search(&[1.0, 2.0, 3.0], 1, TemporalFilter::All, 1.0, 0, 5);
1055 assert_eq!(causal.len(), 1);
1056 assert!(causal[0].successors.is_empty());
1057 assert!(causal[0].predecessors.is_empty());
1058 }
1059
1060 #[test]
1063 fn config_and_ef_delegation() {
1064 let config = HnswConfig {
1065 m: 8,
1066 ef_construction: 100,
1067 ef_search: 50,
1068 ..Default::default()
1069 };
1070 let mut index = TemporalGraphIndex::new(config, L2Distance);
1071 assert_eq!(index.config().m, 8);
1072 assert_eq!(index.config().ef_construction, 100);
1073
1074 index.set_ef_construction(150);
1075 assert_eq!(index.config().ef_construction, 150);
1076
1077 index.set_ef_search(200);
1078 assert_eq!(index.config().ef_search, 200);
1079 }
1080
1081 #[test]
1082 fn centering_delegation() {
1083 let config = HnswConfig::default();
1084 let mut index = TemporalGraphIndex::new(config, L2Distance);
1085 index.insert(1, 1000, &[2.0, 4.0]);
1086 index.insert(2, 2000, &[4.0, 6.0]);
1087
1088 let centroid = index.compute_centroid().unwrap();
1089 assert!((centroid[0] - 3.0).abs() < 1e-6);
1090
1091 index.set_centroid(vec![3.0, 5.0]);
1092 assert_eq!(index.centroid().unwrap(), &[3.0, 5.0]);
1093
1094 let centered = index.centered_vector(&[5.0, 8.0]);
1095 assert!((centered[0] - 2.0).abs() < 1e-6);
1096 assert!((centered[1] - 3.0).abs() < 1e-6);
1097
1098 index.clear_centroid();
1099 assert!(index.centroid().is_none());
1100 }
1101
1102 #[test]
1103 fn reward_delegation() {
1104 let config = HnswConfig::default();
1105 let mut index = TemporalGraphIndex::new(config, L2Distance);
1106 let n0 = index.insert(1, 1000, &[1.0, 0.0]);
1107 let n1 = index.insert_with_reward(2, 2000, &[0.0, 1.0], 0.8);
1108
1109 assert!(index.reward(n0).is_nan());
1110 assert!((index.reward(n1) - 0.8).abs() < 1e-6);
1111
1112 index.set_reward(n0, 0.95);
1113 assert!((index.reward(n0) - 0.95).abs() < 1e-6);
1114 }
1115
1116 #[test]
1117 fn search_with_reward_delegation() {
1118 let config = HnswConfig::default();
1119 let mut index = TemporalGraphIndex::new(config, L2Distance);
1120 for i in 0..10u64 {
1121 index.insert_with_reward(i, i as i64 * 1000, &[i as f32, 0.0, 0.0], i as f32 * 0.1);
1122 }
1123
1124 let results =
1125 index.search_with_reward(&[7.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0, 0.5);
1126 assert!(!results.is_empty());
1127 for &(node_id, _) in &results {
1128 assert!(
1129 index.reward(node_id) >= 0.5,
1130 "node {node_id} reward {} < 0.5",
1131 index.reward(node_id)
1132 );
1133 }
1134 }
1135
1136 #[test]
1137 fn region_delegation() {
1138 let config = HnswConfig {
1139 m: 4,
1140 ef_construction: 50,
1141 ef_search: 50,
1142 ..Default::default()
1143 };
1144 let mut index = TemporalGraphIndex::new(config, L2Distance);
1145 let mut rng = rand::rng();
1146 for i in 0..200u64 {
1147 let v: Vec<f32> = (0..8).map(|_| rand::Rng::random::<f32>(&mut rng)).collect();
1148 index.insert(i % 4, i as i64 * 1000, &v);
1149 }
1150
1151 let regions = index.regions(1);
1152 assert!(!regions.is_empty());
1153
1154 let assignments = index.region_assignments(1, TemporalFilter::All);
1155 let total: usize = assignments.values().map(|v| v.len()).sum();
1156 assert_eq!(total, 200);
1157 }
1158
1159 #[test]
1160 fn scalar_quantization_delegation() {
1161 let config = HnswConfig::default();
1162 let mut index = TemporalGraphIndex::new(config, L2Distance);
1163 index.insert(1, 1000, &[1.0, 0.0]);
1164 index.insert(2, 2000, &[0.0, 1.0]);
1165
1166 index.enable_scalar_quantization(-1.0, 1.0);
1167 let results = index.search(&[1.0, 0.0], 2, TemporalFilter::All, 1.0, 0);
1168 assert_eq!(results.len(), 2);
1169
1170 index.disable_scalar_quantization();
1171 let results = index.search(&[1.0, 0.0], 2, TemporalFilter::All, 1.0, 0);
1172 assert_eq!(results.len(), 2);
1173 }
1174
1175 #[test]
1176 fn insert_with_reward_creates_temporal_edges() {
1177 let config = HnswConfig::default();
1178 let mut index = TemporalGraphIndex::new(config, L2Distance);
1179
1180 for i in 0..5u32 {
1182 index.insert_with_reward(1, i as i64 * 100, &[i as f32, 0.0], i as f32 * 0.2);
1183 }
1184
1185 let edges = index.edges();
1187 assert!(edges.successor(0).is_some());
1188 assert!(edges.predecessor(4).is_some());
1189
1190 let results = index.causal_search(&[0.0, 0.0], 1, TemporalFilter::All, 1.0, 0, 3);
1192 assert_eq!(results.len(), 1);
1193 assert!(!results[0].successors.is_empty());
1194 }
1195
1196 #[test]
1199 fn scored_search_basic() {
1200 use crate::hnsw::ScoringWeights;
1201
1202 let config = HnswConfig::default();
1203 let mut index = TemporalGraphIndex::new(config, L2Distance);
1204
1205 for i in 0..20u64 {
1206 index.insert(i % 3, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
1207 }
1208
1209 let weights = ScoringWeights::default();
1210 let results =
1211 index.scored_search(&[10.0, 0.0, 0.0], 5, TemporalFilter::All, 0, &weights, None);
1212 assert_eq!(results.len(), 5);
1213 }
1214
1215 #[test]
1216 fn scored_search_reward_boosts() {
1217 use crate::hnsw::ScoringWeights;
1218
1219 let config = HnswConfig::default();
1220 let mut index = TemporalGraphIndex::new(config, L2Distance);
1221
1222 index.insert_with_reward(1, 1000, &[1.0, 0.0, 0.0], 0.9);
1224 index.insert_with_reward(2, 2000, &[1.01, 0.0, 0.0], 0.1);
1225
1226 let weights = ScoringWeights {
1227 similarity: 1.0,
1228 reward: 0.5, ..ScoringWeights::default()
1230 };
1231
1232 let results =
1233 index.scored_search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 0, &weights, None);
1234
1235 assert_eq!(results.len(), 2);
1236 assert_eq!(results[0].0, 0, "high-reward node should rank first");
1238 }
1239
1240 #[test]
1241 fn scored_search_success_score_from_typed_edges() {
1242 use crate::hnsw::{EdgeType, ScoringWeights};
1243
1244 let config = HnswConfig::default();
1245 let mut index = TemporalGraphIndex::new(config, L2Distance);
1246
1247 let n0 = index.insert(1, 1000, &[1.0, 0.0]);
1248 let n1 = index.insert(2, 2000, &[1.01, 0.0]);
1249
1250 index.add_typed_edge(n0, 10, EdgeType::CausedSuccess, 1.0);
1252 index.add_typed_edge(n0, 11, EdgeType::CausedSuccess, 1.0);
1253 index.add_typed_edge(n1, 12, EdgeType::CausedFailure, 1.0);
1254 index.add_typed_edge(n1, 13, EdgeType::CausedFailure, 1.0);
1255
1256 let weights = ScoringWeights {
1257 similarity: 1.0,
1258 success: 0.5,
1259 ..ScoringWeights::default()
1260 };
1261
1262 let results = index.scored_search(&[1.0, 0.0], 2, TemporalFilter::All, 0, &weights, None);
1263
1264 assert_eq!(results[0].0, n0);
1266 }
1267
1268 #[test]
1269 fn assign_region_works() {
1270 let config = HnswConfig {
1271 m: 4,
1272 ef_construction: 50,
1273 ef_search: 50,
1274 ..Default::default()
1275 };
1276 let mut index = TemporalGraphIndex::new(config, L2Distance);
1277
1278 let mut rng = rand::rng();
1279 for i in 0..200u64 {
1280 let v: Vec<f32> = (0..8).map(|_| rand::Rng::random::<f32>(&mut rng)).collect();
1281 index.insert(i % 4, (i * 100) as i64, &v);
1282 }
1283
1284 let query: Vec<f32> = (0..8).map(|_| rand::Rng::random::<f32>(&mut rng)).collect();
1285 let region = index.assign_region(&query, 1);
1286 assert!(region.is_some());
1287 }
1288}