1use std::collections::BTreeMap;
36use std::io::{Read, Write};
37use std::path::Path;
38
39use cvx_core::{DistanceMetric, TemporalFilter};
40use roaring::RoaringBitmap;
41use serde::{Deserialize, Serialize};
42
43use super::{HnswConfig, HnswGraph, HnswSnapshot};
44
45pub struct TemporalHnsw<D: DistanceMetric> {
50 graph: HnswGraph<D>,
51 timestamps: Vec<i64>,
53 entity_ids: Vec<u64>,
55 entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
57 min_timestamp: i64,
59 max_timestamp: i64,
60 metadata_store: Option<super::metadata_store::MetadataStore>,
62 centroid: Option<Vec<f32>>,
68 rewards: Vec<f32>,
72}
73
74impl<D: DistanceMetric> TemporalHnsw<D> {
75 pub fn new(config: HnswConfig, metric: D) -> Self {
77 Self {
78 graph: HnswGraph::new(config, metric),
79 timestamps: Vec::new(),
80 entity_ids: Vec::new(),
81 entity_index: BTreeMap::new(),
82 min_timestamp: i64::MAX,
83 max_timestamp: i64::MIN,
84 metadata_store: None,
85 centroid: None,
86 rewards: Vec::new(),
87 }
88 }
89
90 pub fn len(&self) -> usize {
92 self.graph.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.graph.is_empty()
98 }
99
100 pub fn entity_last_node(&self, entity_id: u64) -> Option<u32> {
102 self.entity_index
103 .get(&entity_id)
104 .and_then(|pts| pts.last().map(|&(_, nid)| nid))
105 }
106
107 pub fn set_ef_construction(&mut self, ef: usize) {
109 self.graph.set_ef_construction(ef);
110 }
111
112 pub fn set_ef_search(&mut self, ef: usize) {
114 self.graph.set_ef_search(ef);
115 }
116
117 pub fn config(&self) -> &HnswConfig {
119 self.graph.config()
120 }
121
122 pub fn enable_scalar_quantization(&mut self, min_val: f32, max_val: f32) {
124 self.graph.enable_scalar_quantization(min_val, max_val);
125 }
126
127 pub fn disable_scalar_quantization(&mut self) {
129 self.graph.disable_scalar_quantization();
130 }
131
132 pub fn is_quantized(&self) -> bool {
134 self.graph.is_quantized()
135 }
136
137 pub fn insert(&mut self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
141 self.insert_with_reward(entity_id, timestamp, vector, f32::NAN)
142 }
143
144 pub fn bulk_insert_parallel(
152 &mut self,
153 entity_ids: &[u64],
154 timestamps: &[i64],
155 vectors: &[&[f32]],
156 ) -> usize {
157 use rayon::prelude::*;
158
159 let n = entity_ids.len();
160 if n == 0 {
161 return 0;
162 }
163
164 let seed_count = n.min(100);
166 for i in 0..seed_count {
167 self.insert(entity_ids[i], timestamps[i], vectors[i]);
168 }
169
170 if seed_count >= n {
171 return n;
172 }
173
174 let batch_size = 256;
176 let remaining = &vectors[seed_count..];
177 let remaining_eids = &entity_ids[seed_count..];
178 let remaining_ts = ×tamps[seed_count..];
179
180 for batch_start in (0..remaining.len()).step_by(batch_size) {
181 let batch_end = (batch_start + batch_size).min(remaining.len());
182 let batch_vecs = &remaining[batch_start..batch_end];
183
184 let neighbor_lists: Vec<Vec<(u32, f32)>> = batch_vecs
186 .par_iter()
187 .map(|vec| self.graph.search(vec, self.graph.config().ef_construction))
188 .collect();
189
190 for (i, neighbors) in neighbor_lists.into_iter().enumerate() {
192 let idx = batch_start + i;
193 let eid = remaining_eids[idx];
194 let ts = remaining_ts[idx];
195 let vec = remaining[idx];
196
197 let node_id = self.graph.len() as u32;
198 let level = self.graph.random_level();
200 self.graph.push_node(vec, level);
201 self.timestamps.push(ts);
202 self.entity_ids.push(eid);
203 self.rewards.push(f32::NAN);
204
205 self.graph.connect_node(node_id, &neighbors, level);
207
208 self.entity_index
210 .entry(eid)
211 .or_default()
212 .push((ts, node_id));
213 self.min_timestamp = self.min_timestamp.min(ts);
214 self.max_timestamp = self.max_timestamp.max(ts);
215
216 if let Some(ref mut store) = self.metadata_store {
217 store.push_empty();
218 }
219 }
220 }
221
222 n
223 }
224
225 pub fn insert_with_reward(
230 &mut self,
231 entity_id: u64,
232 timestamp: i64,
233 vector: &[f32],
234 reward: f32,
235 ) -> u32 {
236 let node_id = self.graph.len() as u32;
237 self.graph.insert(node_id, vector);
238 self.timestamps.push(timestamp);
239 self.entity_ids.push(entity_id);
240 self.rewards.push(reward);
241
242 self.entity_index
244 .entry(entity_id)
245 .or_default()
246 .push((timestamp, node_id));
247
248 self.min_timestamp = self.min_timestamp.min(timestamp);
250 self.max_timestamp = self.max_timestamp.max(timestamp);
251
252 if let Some(ref mut store) = self.metadata_store {
254 store.push_empty();
255 }
256
257 node_id
258 }
259
260 pub fn insert_with_metadata(
262 &mut self,
263 entity_id: u64,
264 timestamp: i64,
265 vector: &[f32],
266 metadata: std::collections::HashMap<String, String>,
267 ) -> u32 {
268 if self.metadata_store.is_none() {
270 let mut store = super::metadata_store::MetadataStore::new();
271 for _ in 0..self.graph.len() {
273 store.push_empty();
274 }
275 self.metadata_store = Some(store);
276 }
277
278 let node_id = self.graph.len() as u32;
279 self.graph.insert(node_id, vector);
280 self.timestamps.push(timestamp);
281 self.entity_ids.push(entity_id);
282 self.rewards.push(f32::NAN);
283
284 self.entity_index
285 .entry(entity_id)
286 .or_default()
287 .push((timestamp, node_id));
288
289 self.min_timestamp = self.min_timestamp.min(timestamp);
290 self.max_timestamp = self.max_timestamp.max(timestamp);
291
292 if let Some(ref mut store) = self.metadata_store {
293 store.push(metadata);
294 }
295
296 node_id
297 }
298
299 pub fn node_metadata(&self, node_id: u32) -> std::collections::HashMap<String, String> {
301 self.metadata_store
302 .as_ref()
303 .map(|s| s.get(node_id).clone())
304 .unwrap_or_default()
305 }
306
307 pub fn build_filter_bitmap(&self, filter: &TemporalFilter) -> RoaringBitmap {
309 let mut bitmap = RoaringBitmap::new();
310 for (i, &ts) in self.timestamps.iter().enumerate() {
311 if filter.matches(ts) {
312 bitmap.insert(i as u32);
313 }
314 }
315 bitmap
316 }
317
318 pub fn temporal_distance_normalized(&self, t1: i64, t2: i64) -> f32 {
322 let range = (self.max_timestamp - self.min_timestamp).max(1) as f64;
323 let diff = (t1 as f64 - t2 as f64).abs();
324 (diff / range) as f32
325 }
326
327 pub(crate) fn normalize_semantic_distance(&self, d: f32) -> f32 {
332 (d / 2.0).min(1.0)
335 }
336
337 pub(crate) fn recency_penalty(&self, node_timestamp: i64, recency_lambda: f32) -> f32 {
347 if recency_lambda <= 0.0 {
348 return 0.0;
349 }
350 let age = self.temporal_distance_normalized(node_timestamp, self.max_timestamp);
351 1.0 - (-recency_lambda * age).exp()
352 }
353
354 #[allow(clippy::too_many_arguments)]
362 pub fn search_with_recency(
363 &self,
364 query: &[f32],
365 k: usize,
366 filter: TemporalFilter,
367 alpha: f32,
368 query_timestamp: i64,
369 recency_lambda: f32,
370 recency_weight: f32,
371 ) -> Vec<(u32, f32)> {
372 if self.is_empty() {
373 return Vec::new();
374 }
375
376 let bitmap = self.build_filter_bitmap(&filter);
377 if bitmap.is_empty() {
378 return Vec::new();
379 }
380
381 let over_fetch = k * 4;
382 let candidates = self
383 .graph
384 .search_filtered(query, over_fetch, |id| bitmap.contains(id));
385
386 let mut scored: Vec<(u32, f32)> = candidates
387 .into_iter()
388 .map(|(id, sem_dist)| {
389 let sem_norm = self.normalize_semantic_distance(sem_dist);
390 let t_dist = self
391 .temporal_distance_normalized(self.timestamps[id as usize], query_timestamp);
392 let recency = self.recency_penalty(self.timestamps[id as usize], recency_lambda);
393
394 let combined = alpha * sem_norm + (1.0 - alpha) * t_dist + recency_weight * recency;
395 (id, combined)
396 })
397 .collect();
398
399 scored.sort_by(|a, b| a.1.total_cmp(&b.1));
400 scored.truncate(k);
401 scored
402 }
403
404 pub fn reward(&self, node_id: u32) -> f32 {
408 self.rewards
409 .get(node_id as usize)
410 .copied()
411 .unwrap_or(f32::NAN)
412 }
413
414 pub fn set_reward(&mut self, node_id: u32, reward: f32) {
418 if let Some(r) = self.rewards.get_mut(node_id as usize) {
419 *r = reward;
420 }
421 }
422
423 pub fn build_reward_bitmap(&self, min_reward: f32) -> RoaringBitmap {
425 let mut bitmap = RoaringBitmap::new();
426 for (i, &r) in self.rewards.iter().enumerate() {
427 if !r.is_nan() && r >= min_reward {
428 bitmap.insert(i as u32);
429 }
430 }
431 bitmap
432 }
433
434 pub fn search_with_reward(
438 &self,
439 query: &[f32],
440 k: usize,
441 filter: TemporalFilter,
442 alpha: f32,
443 query_timestamp: i64,
444 min_reward: f32,
445 ) -> Vec<(u32, f32)> {
446 if self.is_empty() {
447 return Vec::new();
448 }
449
450 let temporal_bitmap = self.build_filter_bitmap(&filter);
451 let reward_bitmap = self.build_reward_bitmap(min_reward);
452 let combined = temporal_bitmap & reward_bitmap;
453
454 if combined.is_empty() {
455 return Vec::new();
456 }
457
458 let candidates = self
459 .graph
460 .search_filtered(query, k, |id| combined.contains(id));
461
462 if alpha >= 1.0 {
463 return candidates;
464 }
465
466 let mut scored: Vec<(u32, f32)> = candidates
467 .into_iter()
468 .map(|(id, sem_dist)| {
469 let t_dist = self
470 .temporal_distance_normalized(self.timestamps[id as usize], query_timestamp);
471 (id, alpha * sem_dist + (1.0 - alpha) * t_dist)
472 })
473 .collect();
474
475 scored.sort_by(|a, b| a.1.total_cmp(&b.1));
476 scored.truncate(k);
477 scored
478 }
479
480 pub fn compute_centroid(&self) -> Option<Vec<f32>> {
487 let n = self.graph.len();
488 if n == 0 {
489 return None;
490 }
491
492 let dim = self.graph.vector(0).len();
493 let mut sum = vec![0.0f64; dim];
494
495 for i in 0..n {
496 let v = self.graph.vector(i as u32);
497 for (s, &val) in sum.iter_mut().zip(v.iter()) {
498 *s += val as f64;
499 }
500 }
501
502 let inv_n = 1.0 / n as f64;
503 Some(sum.into_iter().map(|s| (s * inv_n) as f32).collect())
504 }
505
506 pub fn set_centroid(&mut self, centroid: Vec<f32>) {
515 self.centroid = Some(centroid);
516 }
517
518 pub fn clear_centroid(&mut self) {
520 self.centroid = None;
521 }
522
523 pub fn centroid(&self) -> Option<&[f32]> {
525 self.centroid.as_deref()
526 }
527
528 pub fn centered_vector(&self, vec: &[f32]) -> Vec<f32> {
532 match &self.centroid {
533 Some(c) => vec.iter().zip(c.iter()).map(|(v, m)| v - m).collect(),
534 None => vec.to_vec(),
535 }
536 }
537
538 pub fn search(
548 &self,
549 query: &[f32],
550 k: usize,
551 filter: TemporalFilter,
552 alpha: f32,
553 query_timestamp: i64,
554 ) -> Vec<(u32, f32)> {
555 if self.is_empty() {
556 return Vec::new();
557 }
558
559 let bitmap = self.build_filter_bitmap(&filter);
561 if bitmap.is_empty() {
562 return Vec::new();
563 }
564
565 if alpha >= 1.0 {
566 return self
568 .graph
569 .search_filtered(query, k, |id| bitmap.contains(id));
570 }
571
572 let over_fetch = k * 4;
574 let candidates = self
575 .graph
576 .search_filtered(query, over_fetch, |id| bitmap.contains(id));
577
578 let mut scored: Vec<(u32, f32)> = candidates
580 .into_iter()
581 .map(|(id, sem_dist)| {
582 let sem_norm = self.normalize_semantic_distance(sem_dist);
583 let t_dist = self
584 .temporal_distance_normalized(self.timestamps[id as usize], query_timestamp);
585 let combined = alpha * sem_norm + (1.0 - alpha) * t_dist;
586 (id, combined)
587 })
588 .collect();
589
590 scored.sort_by(|a, b| a.1.total_cmp(&b.1));
591 scored.truncate(k);
592 scored
593 }
594
595 pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
599 let Some(points) = self.entity_index.get(&entity_id) else {
600 return Vec::new();
601 };
602
603 let mut result: Vec<(i64, u32)> = points
604 .iter()
605 .filter(|&&(ts, _)| filter.matches(ts))
606 .copied()
607 .collect();
608
609 result.sort_by_key(|&(ts, _)| ts);
610 result
611 }
612
613 pub fn timestamp(&self, node_id: u32) -> i64 {
615 self.timestamps[node_id as usize]
616 }
617
618 pub fn entity_id(&self, node_id: u32) -> u64 {
620 self.entity_ids[node_id as usize]
621 }
622
623 pub fn vector(&self, node_id: u32) -> &[f32] {
625 self.graph.vector(node_id)
626 }
627
628 pub fn bitmap_memory_bytes(&self) -> usize {
632 let bitmap = self.build_filter_bitmap(&TemporalFilter::All);
633 bitmap.serialized_size()
634 }
635
636 pub fn graph(&self) -> &HnswGraph<D> {
638 &self.graph
639 }
640
641 pub fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
648 let hubs = self.graph.nodes_at_level(level);
649 let n = self.graph.len();
650
651 let mut counts = vec![0usize; hubs.len()];
653 let hub_set: std::collections::HashMap<u32, usize> =
654 hubs.iter().enumerate().map(|(i, &h)| (h, i)).collect();
655
656 for node_id in 0..n as u32 {
657 if let Some(hub) = self.graph.assign_region(self.graph.vector(node_id), level) {
658 if let Some(&idx) = hub_set.get(&hub) {
659 counts[idx] += 1;
660 }
661 }
662 }
663
664 hubs.iter()
665 .enumerate()
666 .map(|(i, &hub_id)| (hub_id, self.graph.vector(hub_id).to_vec(), counts[i]))
667 .collect()
668 }
669
670 pub fn region_trajectory(
679 &self,
680 entity_id: u64,
681 level: usize,
682 window_days: i64,
683 alpha: f32,
684 ) -> Vec<(i64, Vec<f32>)> {
685 let hubs = self.graph.nodes_at_level(level);
686 let k = hubs.len();
687 if k == 0 {
688 return Vec::new();
689 }
690
691 let hub_index: std::collections::HashMap<u32, usize> =
693 hubs.iter().enumerate().map(|(i, &h)| (h, i)).collect();
694
695 let posts = self.trajectory(entity_id, TemporalFilter::All);
697 if posts.is_empty() {
698 return Vec::new();
699 }
700
701 let assignments: Vec<(i64, usize)> = posts
703 .iter()
704 .filter_map(|&(ts, node_id)| {
705 let vec = self.graph.vector(node_id);
706 self.graph
707 .assign_region(vec, level)
708 .and_then(|hub| hub_index.get(&hub).map(|&idx| (ts, idx)))
709 })
710 .collect();
711
712 if assignments.is_empty() {
713 return Vec::new();
714 }
715
716 let t_start = assignments[0].0;
718 let t_end = assignments.last().unwrap().0;
719 let mut result = Vec::new();
720 let mut ema_state: Vec<f32> = vec![0.0; k];
721 let mut first = true;
722
723 let mut window_start = t_start;
724 while window_start <= t_end {
725 let window_end = window_start + window_days;
726
727 let mut counts = vec![0.0f32; k];
729 let mut n_in_window = 0.0f32;
730 for &(ts, region_idx) in &assignments {
731 if ts >= window_start && ts < window_end {
732 counts[region_idx] += 1.0;
733 n_in_window += 1.0;
734 }
735 }
736
737 if n_in_window > 0.0 {
738 for c in &mut counts {
740 *c /= n_in_window;
741 }
742
743 if first {
745 ema_state = counts;
746 first = false;
747 } else {
748 for i in 0..k {
749 ema_state[i] = alpha * counts[i] + (1.0 - alpha) * ema_state[i];
750 }
751 }
752
753 let mid_ts = window_start + window_days / 2;
754 result.push((mid_ts, ema_state.clone()));
755 }
756
757 window_start = window_end;
758 }
759
760 result
761 }
762
763 pub fn region_members(
771 &self,
772 region_hub: u32,
773 level: usize,
774 filter: TemporalFilter,
775 ) -> Vec<(u32, u64, i64)> {
776 let mut members = Vec::new();
777 for node_id in 0..self.graph.len() as u32 {
778 let ts = self.timestamps[node_id as usize];
779 if !filter.matches(ts) {
780 continue;
781 }
782 let vec = self.graph.vector(node_id);
783 if let Some(assigned_hub) = self.graph.assign_region(vec, level) {
784 if assigned_hub == region_hub {
785 let eid = self.entity_ids[node_id as usize];
786 members.push((node_id, eid, ts));
787 }
788 }
789 }
790 members
791 }
792
793 pub fn region_assignments(
798 &self,
799 level: usize,
800 filter: TemporalFilter,
801 ) -> std::collections::HashMap<u32, Vec<(u64, i64)>> {
802 let mut assignments: std::collections::HashMap<u32, Vec<(u64, i64)>> =
803 std::collections::HashMap::new();
804
805 for node_id in 0..self.graph.len() as u32 {
806 let ts = self.timestamps[node_id as usize];
807 if !filter.matches(ts) {
808 continue;
809 }
810 let vec = self.graph.vector(node_id);
811 if let Some(hub) = self.graph.assign_region(vec, level) {
812 let eid = self.entity_ids[node_id as usize];
813 assignments.entry(hub).or_default().push((eid, ts));
814 }
815 }
816
817 assignments
818 }
819}
820
/// Format version written into every snapshot produced by `save`; `load`
/// rejects snapshots carrying a larger version.
const SNAPSHOT_VERSION: u32 = 2u32;
823
824#[derive(Serialize, Deserialize)]
826struct TemporalSnapshot {
827 #[serde(default = "default_snapshot_version")]
831 version: u32,
832 graph: HnswSnapshot,
833 timestamps: Vec<i64>,
834 entity_ids: Vec<u64>,
835 entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
836 min_timestamp: i64,
837 max_timestamp: i64,
838 #[serde(default)]
839 metadata_store: Option<super::metadata_store::MetadataStore>,
840 #[serde(default)]
842 centroid: Option<Vec<f32>>,
843 #[serde(default)]
845 rewards: Vec<f32>,
846}
847
/// Version assumed for snapshots that carry no explicit version field.
fn default_snapshot_version() -> u32 {
    const UNVERSIONED: u32 = 1;
    UNVERSIONED
}
851
852impl<D: DistanceMetric> TemporalHnsw<D> {
853 pub fn save(&self, path: &Path) -> std::io::Result<()> {
858 let snapshot = TemporalSnapshot {
859 version: SNAPSHOT_VERSION,
860 graph: self.graph.to_snapshot(),
861 timestamps: self.timestamps.clone(),
862 entity_ids: self.entity_ids.clone(),
863 entity_index: self.entity_index.clone(),
864 min_timestamp: self.min_timestamp,
865 max_timestamp: self.max_timestamp,
866 metadata_store: self.metadata_store.clone(),
867 centroid: self.centroid.clone(),
868 rewards: self.rewards.clone(),
869 };
870
871 let bytes = postcard::to_allocvec(&snapshot).map_err(std::io::Error::other)?;
872
873 let mut file = std::fs::File::create(path)?;
874 file.write_all(&bytes)?;
875 Ok(())
876 }
877
878 pub fn load(path: &Path, metric: D) -> std::io::Result<Self> {
883 let mut file = std::fs::File::open(path)?;
884 let mut bytes = Vec::new();
885 file.read_to_end(&mut bytes)?;
886
887 let snapshot: TemporalSnapshot = postcard::from_bytes(&bytes)
888 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
889
890 if snapshot.version > SNAPSHOT_VERSION {
891 return Err(std::io::Error::new(
892 std::io::ErrorKind::InvalidData,
893 format!(
894 "Snapshot version {} is newer than supported version {}. \
895 Please upgrade chronos-vector.",
896 snapshot.version, SNAPSHOT_VERSION
897 ),
898 ));
899 }
900
901 let n_points = snapshot.timestamps.len();
902 let rewards = if snapshot.rewards.is_empty() {
903 vec![f32::NAN; n_points]
905 } else {
906 snapshot.rewards
907 };
908
909 Ok(Self {
910 graph: HnswGraph::from_snapshot(snapshot.graph, metric),
911 timestamps: snapshot.timestamps,
912 entity_ids: snapshot.entity_ids,
913 entity_index: snapshot.entity_index,
914 min_timestamp: snapshot.min_timestamp,
915 max_timestamp: snapshot.max_timestamp,
916 metadata_store: snapshot.metadata_store,
917 centroid: snapshot.centroid,
918 rewards,
919 })
920 }
921}
922
923impl<D: DistanceMetric> cvx_core::TemporalIndexAccess for TemporalHnsw<D> {
924 fn search_raw(
925 &self,
926 query: &[f32],
927 k: usize,
928 filter: TemporalFilter,
929 alpha: f32,
930 query_timestamp: i64,
931 ) -> Vec<(u32, f32)> {
932 self.search(query, k, filter, alpha, query_timestamp)
933 }
934
935 fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
936 self.trajectory(entity_id, filter)
937 }
938
939 fn vector(&self, node_id: u32) -> Vec<f32> {
940 self.graph.vector(node_id).to_vec()
941 }
942
943 fn entity_id(&self, node_id: u32) -> u64 {
944 self.entity_ids[node_id as usize]
945 }
946
947 fn timestamp(&self, node_id: u32) -> i64 {
948 self.timestamps[node_id as usize]
949 }
950
951 fn len(&self) -> usize {
952 self.graph.len()
953 }
954
955 fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
956 self.regions(level)
957 }
958
959 fn region_members(
960 &self,
961 region_hub: u32,
962 level: usize,
963 filter: TemporalFilter,
964 ) -> Vec<(u32, u64, i64)> {
965 self.region_members(region_hub, level, filter)
966 }
967
968 fn region_trajectory(
969 &self,
970 entity_id: u64,
971 level: usize,
972 window_days: i64,
973 alpha: f32,
974 ) -> Vec<(i64, Vec<f32>)> {
975 self.region_trajectory(entity_id, level, window_days, alpha)
976 }
977
978 fn metadata(&self, node_id: u32) -> std::collections::HashMap<String, String> {
979 self.node_metadata(node_id)
980 }
981
982 fn search_with_metadata(
983 &self,
984 query: &[f32],
985 k: usize,
986 filter: TemporalFilter,
987 alpha: f32,
988 query_timestamp: i64,
989 metadata_filter: &cvx_core::types::MetadataFilter,
990 ) -> Vec<(u32, f32)> {
991 if metadata_filter.is_empty() {
992 return self.search(query, k, filter, alpha, query_timestamp);
993 }
994
995 match &self.metadata_store {
996 Some(store) => {
997 let temporal_bitmap = self.build_filter_bitmap(&filter);
999 let metadata_bitmap = store.build_filter_bitmap(metadata_filter);
1000 let combined = temporal_bitmap & metadata_bitmap;
1001
1002 if combined.is_empty() {
1003 return Vec::new();
1004 }
1005
1006 let candidates = self
1008 .graph
1009 .search_filtered(query, k, |id| combined.contains(id));
1010
1011 if alpha >= 1.0 {
1012 return candidates;
1013 }
1014
1015 let mut scored: Vec<(u32, f32)> = candidates
1017 .into_iter()
1018 .map(|(id, sem_dist)| {
1019 let t_dist = self.temporal_distance_normalized(
1020 self.timestamps[id as usize],
1021 query_timestamp,
1022 );
1023 let combined_score = alpha * sem_dist + (1.0 - alpha) * t_dist;
1024 (id, combined_score)
1025 })
1026 .collect();
1027
1028 scored.sort_by(|a, b| a.1.total_cmp(&b.1));
1029 scored.truncate(k);
1030 scored
1031 }
1032 None => {
1033 self.search(query, k, filter, alpha, query_timestamp)
1035 }
1036 }
1037 }
1038}
1039
1040#[cfg(test)]
1041mod tests {
1042 use super::*;
1043 use crate::metrics::{CosineDistance, L2Distance};
1044
1045 fn make_temporal_index() -> TemporalHnsw<L2Distance> {
1046 let config = HnswConfig {
1047 m: 16,
1048 ef_construction: 200,
1049 ef_search: 100,
1050 ..Default::default()
1051 };
1052 TemporalHnsw::new(config, L2Distance)
1053 }
1054
// A fresh index is empty and searching it yields nothing.
#[test]
fn empty_index() {
    let idx = make_temporal_index();
    assert!(idx.is_empty());
    assert_eq!(idx.len(), 0);

    let hits = idx.search(&[1.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
    assert!(hits.is_empty());
}
1065
// A single insert gets node id 0 and is readable through every accessor.
#[test]
fn insert_and_metadata() {
    let mut idx = make_temporal_index();
    let node = idx.insert(42, 1000, &[1.0, 0.0, 0.0]);

    assert_eq!(node, 0);
    assert_eq!(idx.len(), 1);
    assert_eq!(idx.timestamp(0), 1000);
    assert_eq!(idx.entity_id(0), 42);
    assert_eq!(idx.vector(0), &[1.0, 0.0, 0.0]);
}
1076
// A snapshot filter keeps only the nodes stamped with exactly that time.
#[test]
fn snapshot_knn_returns_only_matching_timestamp() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0]);
    idx.insert(2, 2000, &[0.9, 0.1]);
    idx.insert(3, 1000, &[0.8, 0.2]);

    let hits = idx.search(&[1.0, 0.0], 10, TemporalFilter::Snapshot(1000), 1.0, 1000);

    assert_eq!(hits.len(), 2);
    for (node, _) in hits.iter().copied() {
        assert_eq!(idx.timestamp(node), 1000);
    }
}
1096
// A snapshot at a time no node carries returns nothing.
#[test]
fn snapshot_knn_no_match_returns_empty() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0]);
    idx.insert(2, 2000, &[0.9, 0.1]);

    let hits = idx.search(&[1.0, 0.0], 10, TemporalFilter::Snapshot(5000), 1.0, 5000);
    assert!(hits.is_empty());
}
1106
// A range filter keeps only nodes whose timestamp lies inside the range.
#[test]
fn range_knn_returns_only_in_range() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0]);
    idx.insert(2, 2000, &[0.9, 0.1]);
    idx.insert(3, 3000, &[0.8, 0.2]);
    idx.insert(4, 4000, &[0.7, 0.3]);

    let window = TemporalFilter::Range(1500, 3500);
    let hits = idx.search(&[1.0, 0.0], 10, window, 1.0, 2000);

    assert_eq!(hits.len(), 2);
    for (node, _) in hits.iter().copied() {
        let ts = idx.timestamp(node);
        assert!((1500..=3500).contains(&ts), "timestamp {ts} out of range");
    }
}
1132
// With alpha = 1.0 the ranking follows vector distance only.
#[test]
fn alpha_1_is_pure_semantic() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0]);
    idx.insert(2, 5000, &[0.99, 0.01]);
    idx.insert(3, 100, &[0.0, 1.0]);

    let hits = idx.search(&[1.0, 0.0], 3, TemporalFilter::All, 1.0, 1000);

    // Exact match first, near match second, orthogonal vector last.
    assert_eq!(hits[0].0, 0);
    assert_eq!(hits[1].0, 1);
    assert_eq!(hits[2].0, 2);
}
1149
// With identical vectors, alpha = 0.5 ranks the temporally closer node first.
#[test]
fn alpha_0_5_prefers_temporally_closer() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0, 0.0]);
    idx.insert(2, 5000, &[1.0, 0.0, 0.0]);

    let query_ts = 4900;
    let hits = idx.search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 0.5, query_ts);

    // Node 1 (ts 5000) is nearer to 4900 than node 0 (ts 1000).
    assert_eq!(hits[0].0, 1);
    assert_eq!(hits[1].0, 0);
}
1164
// On random data, mixing in time (alpha = 0.5) must not return results that
// are on average further from the query time than pure semantic search.
#[test]
fn alpha_0_5_returns_temporally_closer_than_alpha_1() {
    let mut index = make_temporal_index();
    let dim = 8;
    let mut rng = rand::rng();

    // 100 random vectors, one per entity, spaced 1000 time units apart.
    for i in 0..100u64 {
        let ts = (i as i64) * 1000;
        let v: Vec<f32> = (0..dim)
            .map(|_| rand::Rng::random::<f32>(&mut rng))
            .collect();
        index.insert(i, ts, &v);
    }

    let query: Vec<f32> = (0..dim)
        .map(|_| rand::Rng::random::<f32>(&mut rng))
        .collect();
    let query_ts = 50_000;
    let k = 10;

    let results_pure = index.search(&query, k, TemporalFilter::All, 1.0, query_ts);
    let results_mixed = index.search(&query, k, TemporalFilter::All, 0.5, query_ts);

    // Mean absolute gap between result timestamps and the query time.
    let mean_time_gap = |hits: &[(u32, f32)]| -> f64 {
        let total: f64 = hits
            .iter()
            .map(|&(id, _)| (index.timestamp(id) - query_ts).unsigned_abs() as f64)
            .sum::<f64>();
        total / k as f64
    };
    let avg_tdist_pure = mean_time_gap(&results_pure);
    let avg_tdist_mixed = mean_time_gap(&results_mixed);

    assert!(
        avg_tdist_mixed <= avg_tdist_pure,
        "alpha=0.5 avg temporal dist ({avg_tdist_mixed:.0}) should be <= alpha=1.0 ({avg_tdist_pure:.0})"
    );
}
1206
// With alpha = 1.0 and no filtering, recall against brute force stays >= 0.90.
#[test]
fn alpha_1_matches_vanilla_recall() {
    let dim = 32;
    let n = 1000u32;
    let k = 10;

    let mut temporal = TemporalHnsw::new(
        HnswConfig {
            m: 16,
            ef_construction: 200,
            ef_search: 100,
            ..Default::default()
        },
        L2Distance,
    );
    let mut rng = rand::rng();

    // Generate all vectors first, then insert them 100 time units apart.
    let mut vectors: Vec<Vec<f32>> = Vec::new();
    for _ in 0..n {
        let v: Vec<f32> = (0..dim)
            .map(|_| rand::Rng::random::<f32>(&mut rng))
            .collect();
        vectors.push(v);
    }
    for (i, v) in vectors.iter().enumerate() {
        temporal.insert(i as u64, (i as i64) * 100, v);
    }

    let n_queries = 50;
    let mut recall_sum = 0.0;
    for _ in 0..n_queries {
        let query: Vec<f32> = (0..dim)
            .map(|_| rand::Rng::random::<f32>(&mut rng))
            .collect();
        let approx = temporal.search(&query, k, TemporalFilter::All, 1.0, 0);
        let exact = temporal.graph().brute_force_knn(&query, k);
        recall_sum += super::super::recall_at_k(&approx, &exact);
    }

    let avg_recall = recall_sum / n_queries as f64;
    assert!(
        avg_recall >= 0.90,
        "alpha=1.0 recall = {avg_recall:.3}, expected >= 0.90 (vanilla parity)"
    );
}
1255
// Range-filtered search keeps recall >= 0.90 against a brute-force ground
// truth restricted to the same time range.
#[test]
fn range_knn_recall() {
    let dim = 32;
    let n = 1000u32;
    let k = 10;
    let config = HnswConfig {
        m: 16,
        ef_construction: 200,
        ef_search: 200,
        ..Default::default()
    };

    let mut index = TemporalHnsw::new(config, L2Distance);
    let mut rng = rand::rng();

    for i in 0..n {
        let ts = (i as i64) * 100;
        let v: Vec<f32> = (0..dim)
            .map(|_| rand::Rng::random::<f32>(&mut rng))
            .collect();
        index.insert(i as u64, ts, &v);
    }

    let filter = TemporalFilter::Range(25_000, 75_000);
    let bitmap = index.build_filter_bitmap(&filter);

    let n_queries = 50;
    let mut total_recall = 0.0;

    for _ in 0..n_queries {
        let query: Vec<f32> = (0..dim)
            .map(|_| rand::Rng::random::<f32>(&mut rng))
            .collect();
        let results = index.search(&query, k, filter, 1.0, 50_000);

        // One exhaustive scan per query, narrowed to the filtered ids;
        // rescanning the whole index for every candidate id is quadratic.
        let mut truth: Vec<(u32, f32)> = index
            .graph()
            .brute_force_knn(&query, n as usize)
            .iter()
            .copied()
            .filter(|&(id, _)| bitmap.contains(id))
            .collect();
        // Ties broken by id, matching an id-ordered list sorted stably.
        truth.sort_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
        truth.truncate(k);

        total_recall += super::super::recall_at_k(&results, &truth);
    }

    let avg_recall = total_recall / n_queries as f64;
    assert!(
        avg_recall >= 0.90,
        "range kNN recall = {avg_recall:.3}, expected >= 0.90"
    );
}
1322
// An entity's trajectory holds all of its points, time-ordered, and nothing
// belonging to another entity.
#[test]
fn trajectory_returns_all_entity_points_ordered() {
    let mut idx = make_temporal_index();

    // Interleave two entities so their node ids alternate.
    for i in 0..100u32 {
        let base = (i as i64) * 1000;
        idx.insert(1, base, &[i as f32, 0.0]);
        idx.insert(2, base + 500, &[0.0, i as f32]);
    }

    let traj = idx.trajectory(1, TemporalFilter::All);
    assert_eq!(traj.len(), 100);

    for pair in traj.windows(2) {
        let (earlier, later) = (pair[0].0, pair[1].0);
        assert!(
            earlier <= later,
            "trajectory not ordered: {} > {}",
            earlier,
            later
        );
    }

    for (_, node) in traj.iter().copied() {
        assert_eq!(idx.entity_id(node), 1);
    }
}
1353
// A range filter on a trajectory keeps both end points of the range.
#[test]
fn trajectory_with_range_filter() {
    let mut idx = make_temporal_index();
    for i in 0..50u32 {
        idx.insert(1, (i as i64) * 100, &[i as f32]);
    }

    let traj = idx.trajectory(1, TemporalFilter::Range(1000, 3000));

    // Timestamps 1000, 1100, ..., 3000.
    assert_eq!(traj.len(), 21);
    for (ts, _) in traj.iter().copied() {
        assert!((1000..=3000).contains(&ts));
    }
}
1370
// Asking for an entity that was never inserted yields an empty trajectory.
#[test]
fn trajectory_unknown_entity_returns_empty() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0]);

    let traj = idx.trajectory(999, TemporalFilter::All);
    assert!(traj.is_empty());
}
1377
// The full-index bitmap serializes to less than one byte per vector.
#[test]
fn bitmap_memory_under_1_byte_per_vector() {
    let mut idx = make_temporal_index();
    for i in 0..10_000u32 {
        idx.insert(i as u64, i as i64, &[i as f32]);
    }

    let bytes_per_vector = idx.bitmap_memory_bytes() as f64 / 10_000.0;
    assert!(
        bytes_per_vector < 1.0,
        "bitmap uses {bytes_per_vector:.2} bytes/vector, expected < 1.0"
    );
}
1396
// `Before(t)` keeps nodes with timestamp <= t.
#[test]
fn before_filter() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0]);
    idx.insert(2, 2000, &[0.9, 0.1]);
    idx.insert(3, 3000, &[0.8, 0.2]);

    let hits = idx.search(&[1.0, 0.0], 10, TemporalFilter::Before(2000), 1.0, 1000);

    assert_eq!(hits.len(), 2);
    for (node, _) in hits.iter().copied() {
        assert!(idx.timestamp(node) <= 2000);
    }
}
1412
// `After(t)` keeps nodes with timestamp >= t.
#[test]
fn after_filter() {
    let mut idx = make_temporal_index();
    idx.insert(1, 1000, &[1.0, 0.0]);
    idx.insert(2, 2000, &[0.9, 0.1]);
    idx.insert(3, 3000, &[0.8, 0.2]);

    let hits = idx.search(&[1.0, 0.0], 10, TemporalFilter::After(2000), 1.0, 3000);

    assert_eq!(hits.len(), 2);
    for (node, _) in hits.iter().copied() {
        assert!(idx.timestamp(node) >= 2000);
    }
}
1426
// The index is generic over the metric: cosine distance ranks as expected.
#[test]
fn works_with_cosine_metric() {
    let mut idx = TemporalHnsw::new(
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            ..Default::default()
        },
        CosineDistance,
    );

    idx.insert(1, 1000, &[1.0, 0.0, 0.0]);
    idx.insert(2, 2000, &[0.99, 0.01, 0.0]);
    idx.insert(3, 3000, &[0.0, 0.0, 1.0]);

    let hits = idx.search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 1.0, 0);
    assert_eq!(hits[0].0, 0);
    assert_eq!(hits[1].0, 1);
}
1447
// Metadata passed at insert time is returned unchanged by `node_metadata`.
#[test]
fn insert_with_metadata_stores_and_retrieves() {
    let mut idx = TemporalHnsw::new(HnswConfig::default(), L2Distance);

    let mut fields = std::collections::HashMap::new();
    fields.insert("reward".to_string(), "0.8".to_string());
    fields.insert("step_index".to_string(), "0".to_string());

    let node = idx.insert_with_metadata(1, 1000, &[1.0, 0.0, 0.0], fields);

    let stored = idx.node_metadata(node);
    assert_eq!(stored.get("reward").unwrap(), "0.8");
    assert_eq!(stored.get("step_index").unwrap(), "0");
}
1465
/// Mixing a plain `insert` with a later `insert_with_metadata` must work:
/// the node inserted without metadata reports an empty map, and the node
/// inserted with metadata reports what was stored.
#[test]
fn insert_with_metadata_enables_store_lazily() {
    let mut index = TemporalHnsw::new(HnswConfig::default(), L2Distance);

    // First point goes in through the plain path, carrying no metadata.
    index.insert(1, 1000, &[1.0, 0.0, 0.0]);

    let meta = std::collections::HashMap::from([("reward".to_string(), "0.9".to_string())]);
    let id = index.insert_with_metadata(2, 2000, &[0.0, 1.0, 0.0], meta);

    let plain = index.node_metadata(0);
    assert!(plain.is_empty());

    let tagged = index.node_metadata(id);
    assert_eq!(tagged.get("reward").unwrap(), "0.9");
}
1484
/// `search_with_metadata` with a `reward >= 0.5` filter must return at least
/// one hit, and every hit's stored `reward` metadata must satisfy the filter.
#[test]
fn search_with_metadata_filter() {
    use cvx_core::TemporalIndexAccess;
    use cvx_core::types::MetadataFilter;

    let mut index = TemporalHnsw::new(
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            ..Default::default()
        },
        L2Distance,
    );

    // Ten points along the first axis; point `i` carries reward `i * 0.1`.
    for i in 0..10u64 {
        let reward_text = (i as f64 * 0.1).to_string();
        let meta = std::collections::HashMap::from([
            ("reward".to_string(), reward_text),
            ("step_index".to_string(), "0".to_string()),
        ]);
        let ts = i as i64 * 1000;
        index.insert_with_metadata(i, ts, &[i as f32, 0.0, 0.0], meta);
    }

    let filter = MetadataFilter::new().gte("reward", 0.5);
    let query = [7.0, 0.0, 0.0];
    let results = index.search_with_metadata(&query, 5, TemporalFilter::All, 1.0, 0, &filter);

    for hit in &results {
        let nid = hit.0;
        let stored = index.node_metadata(nid);
        let reward: f64 = stored.get("reward").unwrap().parse().unwrap();
        assert!(reward >= 0.5, "node {nid} has reward {reward} < 0.5");
    }
    assert!(!results.is_empty(), "should find some results");
}
1519
1520 fn make_region_index() -> TemporalHnsw<L2Distance> {
1524 let config = HnswConfig {
1525 m: 4,
1526 ef_construction: 50,
1527 ef_search: 50,
1528 ..Default::default()
1529 };
1530 let mut index = TemporalHnsw::new(config, L2Distance);
1531 let mut rng = rand::rng();
1532 for i in 0..200u64 {
1534 let v: Vec<f32> = (0..8).map(|_| rand::Rng::random::<f32>(&mut rng)).collect();
1535 let entity = i % 4;
1536 index.insert(entity, i as i64 * 1000, &v);
1537 }
1538 index
1539 }
1540
/// With no temporal restriction, the member lists returned by
/// `region_assignments` must add up to the number of points in the index.
#[test]
fn region_assignments_covers_all_nodes() {
    let index = make_region_index();
    let assignments = index.region_assignments(1, TemporalFilter::All);

    let total = assignments
        .values()
        .fold(0usize, |acc, members| acc + members.len());

    assert_eq!(
        total,
        index.len(),
        "sum of all region member counts ({total}) must equal index size ({})",
        index.len()
    );
}
1555
/// Cross-checks `region_assignments` against `regions` at the same level:
/// for every hub reported by `regions`, the number of members assigned to
/// that hub must equal the count `regions` reports (a hub missing from the
/// assignments counts as zero members).
#[test]
fn region_assignments_consistent_with_regions_counts() {
    let index = make_region_index();
    let level = 1;
    let regions = index.regions(level);
    let assignments = index.region_assignments(level, TemporalFilter::All);

    // Iterate by reference; each entry is `(hub id, <unused>, member count)`.
    for &(hub_id, _, count) in &regions {
        let assigned_count = assignments.get(&hub_id).map_or(0, |v| v.len());
        assert_eq!(
            assigned_count, count,
            "region hub {hub_id}: region_assignments has {assigned_count} members but regions() reports {count}"
        );
    }
}
1571
/// A `Range` filter must shrink the set of assigned members compared with
/// `All`, and every member that survives must have a timestamp inside the
/// requested range.
#[test]
fn region_assignments_temporal_filter_reduces_count() {
    let index = make_region_index();
    let level = 1;

    let unfiltered = index.region_assignments(level, TemporalFilter::All);
    let total_all = unfiltered
        .values()
        .fold(0usize, |acc, members| acc + members.len());

    let filtered = index.region_assignments(level, TemporalFilter::Range(50_000, 150_000));
    let total_filtered = filtered
        .values()
        .fold(0usize, |acc, members| acc + members.len());

    assert!(
        total_filtered < total_all,
        "Range filter should reduce total members: filtered={total_filtered}, all={total_all}"
    );

    // Walk every member of every region in one pass.
    for &(_eid, ts) in filtered.values().flatten() {
        assert!(
            (50_000..=150_000).contains(&ts),
            "filtered result has timestamp {ts} outside [50000, 150000]"
        );
    }
}
1599
/// Checks that the total number of assigned members equals the index size
/// and that every hub key in the assignments is a node of the requested
/// level.
///
/// Note: the `(entity, timestamp)` pairs are collected into a set, but the
/// result of each insertion is deliberately ignored, so duplicate pairs are
/// not themselves treated as a failure here.
#[test]
fn region_assignments_each_member_in_exactly_one_region() {
    let index = make_region_index();
    let level = 1;
    let assignments = index.region_assignments(level, TemporalFilter::All);

    let mut seen = std::collections::HashSet::new();
    let mut total = 0usize;
    for &(eid, ts) in assignments.values().flatten() {
        total += 1;
        let _ = seen.insert((eid, ts));
    }

    assert_eq!(
        total,
        index.len(),
        "total assigned members ({total}) != index size ({}); a node appeared in multiple or no regions",
        index.len()
    );

    let level_hubs: std::collections::HashSet<u32> =
        index.graph().nodes_at_level(level).into_iter().collect();
    let hubs: std::collections::HashSet<u32> = assignments.keys().copied().collect();
    hubs.iter().for_each(|hub| {
        assert!(
            level_hubs.contains(hub),
            "assignment hub {hub} is not a level-{level} node"
        );
    });
}
1637
/// An index with no inserted points has no centroid to compute.
#[test]
fn compute_centroid_empty_index() {
    let index = make_temporal_index();
    let centroid = index.compute_centroid();
    assert!(centroid.is_none());
}
1645
/// With a single inserted vector, the computed centroid is that vector.
#[test]
fn compute_centroid_single_vector() {
    let only = [3.0, 4.0, 5.0];

    let mut index = make_temporal_index();
    index.insert(1, 1000, &only);

    let centroid = index.compute_centroid().unwrap();
    assert_eq!(centroid, only.to_vec());
}
1653
/// The centroid of `[2, 0]` and `[4, 6]` is their component-wise mean,
/// `[3, 3]` (checked to within `1e-6`).
#[test]
fn compute_centroid_mean_of_vectors() {
    let mut index = make_temporal_index();
    index.insert(1, 1000, &[2.0, 0.0]);
    index.insert(2, 2000, &[4.0, 6.0]);

    let centroid = index.compute_centroid().unwrap();
    // Both components are expected to be 3.0.
    for dim in 0..2 {
        assert!((centroid[dim] - 3.0).abs() < 1e-6);
    }
}
1663
/// Walks the centroid through its lifecycle: absent on a fresh index,
/// present (with the exact values) after `set_centroid`, absent again after
/// `clear_centroid`.
#[test]
fn set_and_clear_centroid() {
    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 2.0]);

    // Inserting points alone does not set a centroid.
    let initial = index.centroid();
    assert!(initial.is_none());

    index.set_centroid(vec![0.5, 1.0]);
    let stored = index.centroid();
    assert!(stored.is_some());
    assert_eq!(stored.unwrap(), &[0.5, 1.0]);

    index.clear_centroid();
    let cleared = index.centroid();
    assert!(cleared.is_none());
}
1678
/// With centroid `[0.5, 1.0]` set, `centered_vector([3, 5])` must yield
/// `[2.5, 4.0]`, i.e. the input minus the centroid (within `1e-6`).
#[test]
fn centered_vector_subtracts_centroid() {
    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 2.0]);
    index.set_centroid(vec![0.5, 1.0]);

    let centered = index.centered_vector(&[3.0, 5.0]);
    for (dim, want) in [(0, 2.5), (1, 4.0)] {
        assert!((centered[dim] - want).abs() < 1e-6);
    }
}
1689
/// When no centroid has been set, `centered_vector` returns its input
/// unchanged.
#[test]
fn centered_vector_without_centroid_is_identity() {
    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 2.0]);

    let query = [3.0, 5.0];
    let centered = index.centered_vector(&query);
    assert_eq!(centered, query.to_vec());
}
1698
/// Round-trips an index with an explicit centroid through `save`/`load` and
/// checks that the loaded index reports the same centroid.
///
/// The snapshot file name embeds the process id so that concurrent test
/// processes sharing the system temp directory do not overwrite each other's
/// file, and the file is removed before the assertions run so that a failing
/// `load` or assertion cannot leave it behind.
#[test]
fn centroid_survives_save_load() {
    let path = std::env::temp_dir().join(format!(
        "test_centroid_snapshot_{}.cvx",
        std::process::id()
    ));

    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 2.0, 3.0]);
    index.insert(2, 2000, &[4.0, 5.0, 6.0]);
    index.set_centroid(vec![2.5, 3.5, 4.5]);

    index.save(&path).unwrap();

    let loaded = TemporalHnsw::load(&path, L2Distance);
    // Best-effort cleanup before any check can panic.
    // NOTE(review): assumes `load` no longer needs the file once it returns.
    std::fs::remove_file(&path).ok();

    let loaded = loaded.unwrap();
    assert_eq!(loaded.centroid().unwrap(), &[2.5, 3.5, 4.5]);
}
1716
/// Saves an index on which no centroid was ever set and checks that the
/// loaded index also reports no centroid.
///
/// The snapshot file name embeds the process id so that concurrent test
/// processes sharing the system temp directory do not overwrite each other's
/// file, and the file is removed before the assertions run so that a failing
/// `load` or assertion cannot leave it behind.
#[test]
fn load_without_centroid_is_none() {
    let path = std::env::temp_dir().join(format!(
        "test_no_centroid_snapshot_{}.cvx",
        std::process::id()
    ));

    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 0.0]);
    index.save(&path).unwrap();

    let loaded = TemporalHnsw::load(&path, L2Distance);
    // Best-effort cleanup before any check can panic.
    // NOTE(review): assumes `load` no longer needs the file once it returns.
    std::fs::remove_file(&path).ok();

    let loaded = loaded.unwrap();
    assert!(loaded.centroid().is_none());
}
1734
/// Inserts 1000 vectors of the form `[10 + i/1000, 20 - i/1000]` and checks
/// that the computed centroid lands within `0.01` of the expected mean
/// `[10.4995, 19.5005]`.
#[test]
fn compute_centroid_precision_with_many_vectors() {
    let mut index = TemporalHnsw::new(
        HnswConfig {
            m: 4,
            ef_construction: 20,
            ef_search: 10,
            ..Default::default()
        },
        L2Distance,
    );

    for i in 0..1000u64 {
        // Same small offset added to one component and subtracted from the other.
        let offset = i as f32 * 0.001;
        let v = [10.0 + offset, 20.0 - offset];
        index.insert(i, i as i64, &v);
    }

    let centroid = index.compute_centroid().unwrap();

    let first_error = (centroid[0] - 10.4995).abs();
    assert!(
        first_error < 0.01,
        "centroid[0] = {}, expected ~10.4995",
        centroid[0]
    );

    let second_error = (centroid[1] - 19.5005).abs();
    assert!(
        second_error < 0.01,
        "centroid[1] = {}, expected ~19.5005",
        centroid[1]
    );
}
1765
/// A `MetadataFilter` with no conditions must not exclude anything: asking
/// for 3 neighbours out of 5 points (inserted without metadata) returns 3.
#[test]
fn search_with_empty_metadata_filter_matches_all() {
    use cvx_core::TemporalIndexAccess;
    use cvx_core::types::MetadataFilter;

    let mut index = TemporalHnsw::new(HnswConfig::default(), L2Distance);
    (0..5u64).for_each(|i| {
        index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
    });

    let no_conditions = MetadataFilter::new();
    let query = [2.0, 0.0];
    let results =
        index.search_with_metadata(&query, 3, TemporalFilter::All, 1.0, 0, &no_conditions);

    assert_eq!(results.len(), 3);
}
1783
/// A point added through plain `insert` reports a NaN reward, while a point
/// added through `insert_with_reward` reports the reward it was given.
#[test]
fn insert_with_reward_stores_reward() {
    let mut index = make_temporal_index();
    let n0 = index.insert(1, 1000, &[1.0, 0.0]);
    let n1 = index.insert_with_reward(2, 2000, &[0.0, 1.0], 0.8);

    let unset = index.reward(n0);
    assert!(unset.is_nan());

    let stored = index.reward(n1);
    assert!((stored - 0.8).abs() < 1e-6);
}
1795
/// `set_reward` can attach a reward to a node after insertion: the reward
/// starts out as NaN and reads back as the assigned value afterwards.
#[test]
fn set_reward_updates_retroactively() {
    let mut index = make_temporal_index();
    let n0 = index.insert(1, 1000, &[1.0, 0.0]);

    let before = index.reward(n0);
    assert!(before.is_nan());

    index.set_reward(n0, 0.95);

    let after = index.reward(n0);
    assert!((after - 0.95).abs() < 1e-6);
}
1805
/// `search_with_reward` with a threshold of `0.5` must return at least one
/// hit and only hits whose stored reward is `>= 0.5`.
#[test]
fn search_with_reward_filters() {
    let mut index = make_temporal_index();
    // Point `i` sits at `[i, 0, 0]` and carries reward `i * 0.1`.
    for i in 0..10u64 {
        let ts = i as i64 * 1000;
        let position = [i as f32, 0.0, 0.0];
        let reward = i as f32 * 0.1;
        index.insert_with_reward(i, ts, &position, reward);
    }

    let query = [7.0, 0.0, 0.0];
    let results = index.search_with_reward(&query, 5, TemporalFilter::All, 1.0, 0, 0.5);

    assert!(!results.is_empty());
    for hit in &results {
        let node_id = hit.0;
        let r = index.reward(node_id);
        assert!(r >= 0.5, "node {node_id} has reward {r} < 0.5");
    }
}
1823
/// When every stored reward (`0.1`) is below the requested threshold
/// (`0.9`), `search_with_reward` returns nothing.
#[test]
fn search_with_reward_no_matches() {
    let mut index = make_temporal_index();
    (0..5u64).for_each(|i| {
        index.insert_with_reward(i, i as i64 * 1000, &[i as f32, 0.0], 0.1);
    });

    let query = [2.0, 0.0];
    let results = index.search_with_reward(&query, 5, TemporalFilter::All, 1.0, 0, 0.9);
    assert!(results.is_empty());
}
1834
/// Round-trips rewards through `save`/`load`: an explicit reward is restored
/// and a node inserted without a reward still reads back as NaN.
///
/// The snapshot file name embeds the process id so that concurrent test
/// processes sharing the system temp directory do not overwrite each other's
/// file, and the file is removed before the assertions run so that a failing
/// `load` or assertion cannot leave it behind.
#[test]
fn reward_survives_save_load() {
    let path = std::env::temp_dir().join(format!(
        "test_reward_snapshot_{}.cvx",
        std::process::id()
    ));

    let mut index = make_temporal_index();
    index.insert_with_reward(1, 1000, &[1.0, 0.0], 0.75);
    // Plain insert: this node carries no explicit reward.
    index.insert(2, 2000, &[0.0, 1.0]);
    index.save(&path).unwrap();

    let loaded = TemporalHnsw::load(&path, L2Distance);
    // Best-effort cleanup before any check can panic.
    // NOTE(review): assumes `load` no longer needs the file once it returns.
    std::fs::remove_file(&path).ok();

    let loaded = loaded.unwrap();
    assert!((loaded.reward(0) - 0.75).abs() < 1e-6);
    assert!(loaded.reward(1).is_nan());
}
1851
/// Pins `normalize_semantic_distance` on the fixture index: raw distances
/// 0, 1 and 2 map to 0, 0.5 and 1, and a larger raw distance (4) is clamped
/// to 1.
#[test]
fn normalize_semantic_distance_clamps() {
    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 0.0]);

    let cases = [(0.0, 0.0), (1.0, 0.5), (2.0, 1.0), (4.0, 1.0)];
    for (raw, expected) in cases {
        let normalized = index.normalize_semantic_distance(raw);
        assert!((normalized - expected).abs() < 1e-6);
    }
}
1866
/// Calling `recency_penalty` with `0.0` as its second argument must yield
/// a penalty of (approximately) zero.
#[test]
fn recency_penalty_zero_lambda() {
    let mut index = make_temporal_index();
    index.insert(1, 1000, &[1.0, 0.0]);
    index.insert(2, 2000, &[0.0, 1.0]);

    let penalty = index.recency_penalty(1000, 0.0);
    assert!(penalty.abs() < 1e-6);
}
1875
/// With points spanning timestamps 0..=9000, the penalty for the newest
/// timestamp must be strictly lower than the penalty for the oldest.
#[test]
fn recency_penalty_recent_is_lower() {
    let mut index = make_temporal_index();
    (0..10u64).for_each(|i| {
        let ts = (i * 1000) as i64;
        index.insert(i, ts, &[i as f32, 0.0]);
    });

    let recent = index.recency_penalty(9000, 1.0);
    let old = index.recency_penalty(0, 1.0);

    assert!(
        recent < old,
        "recent penalty ({recent}) should be < old penalty ({old})"
    );
}
1889
/// Two points share the same vector but differ in timestamp (1000 vs 9000).
/// `search_with_recency` must return both and rank the later one (result
/// id 1) first.
#[test]
fn search_with_recency_prefers_recent() {
    // Identical vectors, so only the timestamps differ between the two points.
    let shared = [1.0, 0.0, 0.0];

    let mut index = make_temporal_index();
    index.insert(1, 1000, &shared);
    index.insert(2, 9000, &shared);

    let results =
        index.search_with_recency(&shared, 2, TemporalFilter::All, 1.0, 0, 2.0, 0.5);

    assert_eq!(results.len(), 2);
    let top = results[0].0;
    assert_eq!(top, 1, "recent node should rank first");
}
1915
/// Runs `search` with `0.5` and `4900` as its last two arguments over two
/// points and checks that both come back. Only the result count is
/// asserted, not the ranking.
#[test]
fn search_normalized_distances_balanced() {
    let mut index = make_temporal_index();
    for (entity, ts, vector) in [(1, 1000, [1.0, 0.0, 0.0]), (2, 5000, [0.0, 1.0, 0.0])] {
        index.insert(entity, ts, &vector);
    }

    let query = [0.9, 0.1, 0.0];
    let results = index.search(&query, 2, TemporalFilter::All, 0.5, 4900);

    assert_eq!(results.len(), 2);
}
1936
/// `bulk_insert_parallel` over 500 random 16-dimensional vectors must report
/// 500 insertions, leave the index with 500 points, and produce an index
/// that can answer a 5-nearest-neighbour query.
#[test]
fn bulk_insert_parallel_basic() {
    let mut index = TemporalHnsw::new(
        HnswConfig {
            m: 8,
            ef_construction: 50,
            ef_search: 50,
            ..Default::default()
        },
        L2Distance,
    );

    let n = 500;
    let dim = 16;
    let mut rng = rand::rng();

    // Ten entities in round-robin order, timestamps spaced 100 apart.
    let eids: Vec<u64> = (0..n).map(|i| i as u64 % 10).collect();
    let tss: Vec<i64> = (0..n).map(|i| i as i64 * 100).collect();

    let mut vecs: Vec<Vec<f32>> = Vec::new();
    for _ in 0..n {
        let row = std::iter::repeat_with(|| rand::Rng::random::<f32>(&mut rng))
            .take(dim)
            .collect();
        vecs.push(row);
    }
    let vec_refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();

    let count = index.bulk_insert_parallel(&eids, &tss, &vec_refs);
    assert_eq!(count, n);
    assert_eq!(index.len(), n);

    let results = index.search(&vecs[0], 5, TemporalFilter::All, 1.0, 0);
    assert_eq!(results.len(), 5);
}
1971
/// Builds an index of 1000 random 32-dimensional vectors with
/// `bulk_insert_parallel`, then compares `search` against
/// `brute_force_knn` on 20 random queries. The average recall@10 must be at
/// least 0.80.
#[test]
fn bulk_insert_parallel_recall() {
    let mut index = TemporalHnsw::new(
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 100,
            ..Default::default()
        },
        L2Distance,
    );

    let n = 1000;
    let dim = 32;
    let mut rng = rand::rng();
    // One generator shared by the indexed data and the queries.
    let mut random_vector = || -> Vec<f32> {
        std::iter::repeat_with(|| rand::Rng::random::<f32>(&mut rng))
            .take(dim)
            .collect()
    };

    let eids: Vec<u64> = (0..n).map(|i| i as u64).collect();
    let tss: Vec<i64> = (0..n).map(|i| i as i64 * 100).collect();
    let vecs: Vec<Vec<f32>> = (0..n).map(|_| random_vector()).collect();
    let vec_refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();

    index.bulk_insert_parallel(&eids, &tss, &vec_refs);

    let k = 10;
    let n_queries = 20;
    let mut total_recall = 0.0;
    for _ in 0..n_queries {
        let query = random_vector();
        let results = index.search(&query, k, TemporalFilter::All, 1.0, 0);
        let truth = index.graph().brute_force_knn(&query, k);
        total_recall += super::super::recall_at_k(&results, &truth);
    }

    let avg_recall = total_recall / n_queries as f64;
    assert!(
        avg_recall >= 0.80,
        "parallel bulk_insert recall = {avg_recall:.3}, expected >= 0.80"
    );
}
2016}