1use std::collections::BTreeMap;
20use std::path::Path;
21
22use cvx_core::traits::DistanceMetric;
23use cvx_core::types::TemporalFilter;
24
25use super::HnswConfig;
26use super::temporal::TemporalHnsw;
27
28#[derive(Debug, Clone)]
32pub struct PartitionConfig {
33 pub partition_duration_us: i64,
36 pub hnsw_config: HnswConfig,
38}
39
40impl Default for PartitionConfig {
41 fn default() -> Self {
42 Self {
43 partition_duration_us: 7 * 24 * 3600 * 1_000_000, hnsw_config: HnswConfig::default(),
45 }
46 }
47}
48
49struct Partition<D: DistanceMetric> {
53 start_us: i64,
55 end_us: i64,
57 hnsw: TemporalHnsw<D>,
59 id_offset: u32,
62 point_count: usize,
64}
65
66impl<D: DistanceMetric> Partition<D> {
67 fn overlaps(&self, filter: &TemporalFilter) -> bool {
69 match filter {
70 TemporalFilter::All => true,
71 TemporalFilter::Snapshot(t) => *t >= self.start_us && *t < self.end_us,
72 TemporalFilter::Range(start, end) => *start < self.end_us && *end >= self.start_us,
73 TemporalFilter::Before(t) => *t >= self.start_us,
74 TemporalFilter::After(t) => *t < self.end_us,
75 }
76 }
77
78 fn contains_timestamp(&self, timestamp: i64) -> bool {
80 timestamp >= self.start_us && timestamp < self.end_us
81 }
82}
83
84pub struct PartitionedTemporalHnsw<D: DistanceMetric> {
91 partitions: Vec<Partition<D>>,
93 config: PartitionConfig,
95 global_entity_index: BTreeMap<u64, Vec<(usize, u32)>>,
97 total_points: usize,
99 next_global_id: u32,
101 metric: D,
103}
104
105impl<D: DistanceMetric + Clone> PartitionedTemporalHnsw<D> {
106 pub fn new(config: PartitionConfig, metric: D) -> Self {
108 Self {
109 partitions: Vec::new(),
110 config,
111 global_entity_index: BTreeMap::new(),
112 total_points: 0,
113 next_global_id: 0,
114 metric,
115 }
116 }
117
118 pub fn from_single(index: TemporalHnsw<D>, config: PartitionConfig, metric: D) -> Self {
122 let point_count = index.len();
123 let (min_ts, max_ts) = if point_count > 0 {
124 let mut min = i64::MAX;
125 let mut max = i64::MIN;
126 for i in 0..point_count {
127 let ts = index.timestamp(i as u32);
128 min = min.min(ts);
129 max = max.max(ts);
130 }
131 (min, max)
132 } else {
133 (0, 0)
134 };
135
136 let mut global_entity_index = BTreeMap::new();
138 for i in 0..point_count {
139 let eid = index.entity_id(i as u32);
140 global_entity_index
141 .entry(eid)
142 .or_insert_with(Vec::new)
143 .push((0usize, i as u32));
144 }
145
146 let partition = Partition {
147 start_us: min_ts,
148 end_us: max_ts + 1,
149 hnsw: index,
150 id_offset: 0,
151 point_count,
152 };
153
154 Self {
155 partitions: vec![partition],
156 config,
157 global_entity_index,
158 total_points: point_count,
159 next_global_id: point_count as u32,
160 metric,
161 }
162 }
163
164 pub fn insert(&mut self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
166 let part_idx = self.ensure_partition_for(timestamp);
167
168 let local_id = self.partitions[part_idx]
169 .hnsw
170 .insert(entity_id, timestamp, vector);
171
172 let global_id = self.partitions[part_idx].id_offset + local_id;
173 self.partitions[part_idx].point_count += 1;
174
175 self.global_entity_index
177 .entry(entity_id)
178 .or_default()
179 .push((part_idx, local_id));
180
181 self.total_points += 1;
182 self.next_global_id = self.next_global_id.max(global_id + 1);
183
184 global_id
185 }
186
187 pub fn search(
189 &self,
190 query: &[f32],
191 k: usize,
192 filter: TemporalFilter,
193 alpha: f32,
194 query_timestamp: i64,
195 ) -> Vec<(u32, f32)> {
196 let mut all_results: Vec<(u32, f32)> = Vec::new();
197
198 for part in &self.partitions {
199 if !part.overlaps(&filter) {
200 continue;
201 }
202
203 let local_results = part.hnsw.search(query, k, filter, alpha, query_timestamp);
204
205 for (local_id, score) in local_results {
207 all_results.push((part.id_offset + local_id, score));
208 }
209 }
210
211 all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
213 all_results.truncate(k);
214 all_results
215 }
216
217 pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
219 let Some(entries) = self.global_entity_index.get(&entity_id) else {
220 return Vec::new();
221 };
222
223 let mut result: Vec<(i64, u32)> = Vec::new();
224
225 for &(part_idx, local_id) in entries {
226 let part = &self.partitions[part_idx];
227 let ts = part.hnsw.timestamp(local_id);
228 if filter.matches(ts) {
229 result.push((ts, part.id_offset + local_id));
230 }
231 }
232
233 result.sort_by_key(|&(ts, _)| ts);
234 result
235 }
236
237 pub fn vector(&self, global_id: u32) -> Vec<f32> {
239 let (part, local_id) = self.resolve_global_id(global_id);
240 part.hnsw.vector(local_id).to_vec()
241 }
242
243 pub fn entity_id(&self, global_id: u32) -> u64 {
245 let (part, local_id) = self.resolve_global_id(global_id);
246 part.hnsw.entity_id(local_id)
247 }
248
249 pub fn timestamp(&self, global_id: u32) -> i64 {
251 let (part, local_id) = self.resolve_global_id(global_id);
252 part.hnsw.timestamp(local_id)
253 }
254
255 pub fn len(&self) -> usize {
257 self.total_points
258 }
259
260 pub fn is_empty(&self) -> bool {
262 self.total_points == 0
263 }
264
265 pub fn num_partitions(&self) -> usize {
267 self.partitions.len()
268 }
269
270 pub fn partition_info(&self) -> Vec<(i64, i64, usize)> {
272 self.partitions
273 .iter()
274 .map(|p| (p.start_us, p.end_us, p.point_count))
275 .collect()
276 }
277
278 pub fn save(&self, dir: &Path) -> std::io::Result<()> {
280 std::fs::create_dir_all(dir)?;
281
282 let meta = PartitionMeta {
284 partition_duration_us: self.config.partition_duration_us,
285 num_partitions: self.partitions.len(),
286 partitions: self
287 .partitions
288 .iter()
289 .map(|p| PartitionMetaEntry {
290 start_us: p.start_us,
291 end_us: p.end_us,
292 id_offset: p.id_offset,
293 point_count: p.point_count,
294 })
295 .collect(),
296 };
297 let meta_bytes = postcard::to_allocvec(&meta).map_err(std::io::Error::other)?;
298 std::fs::write(dir.join("partitions.meta"), meta_bytes)?;
299
300 for (i, part) in self.partitions.iter().enumerate() {
302 let path = dir.join(format!("partition_{i}.bin"));
303 part.hnsw.save(&path)?;
304 }
305
306 Ok(())
307 }
308
309 pub fn load(dir: &Path, metric: D) -> std::io::Result<Self> {
311 let meta_bytes = std::fs::read(dir.join("partitions.meta"))?;
312 let meta: PartitionMeta = postcard::from_bytes(&meta_bytes)
313 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
314
315 let mut partitions = Vec::with_capacity(meta.num_partitions);
316 let mut global_entity_index: BTreeMap<u64, Vec<(usize, u32)>> = BTreeMap::new();
317 let mut total_points = 0;
318 let mut next_global_id: u32 = 0;
319
320 for (i, pm) in meta.partitions.iter().enumerate() {
321 let path = dir.join(format!("partition_{i}.bin"));
322 let hnsw = TemporalHnsw::load(&path, metric.clone())?;
323
324 for local_id in 0..hnsw.len() as u32 {
326 let eid = hnsw.entity_id(local_id);
327 global_entity_index
328 .entry(eid)
329 .or_default()
330 .push((i, local_id));
331 }
332
333 total_points += pm.point_count;
334 next_global_id = next_global_id.max(pm.id_offset + hnsw.len() as u32);
335
336 partitions.push(Partition {
337 start_us: pm.start_us,
338 end_us: pm.end_us,
339 hnsw,
340 id_offset: pm.id_offset,
341 point_count: pm.point_count,
342 });
343 }
344
345 let config = PartitionConfig {
346 partition_duration_us: meta.partition_duration_us,
347 hnsw_config: HnswConfig::default(),
348 };
349
350 Ok(Self {
351 partitions,
352 config,
353 global_entity_index,
354 total_points,
355 next_global_id,
356 metric,
357 })
358 }
359
360 fn ensure_partition_for(&mut self, timestamp: i64) -> usize {
364 for (i, part) in self.partitions.iter().enumerate() {
366 if part.contains_timestamp(timestamp) {
367 return i;
368 }
369 }
370
371 let dur = self.config.partition_duration_us;
373 let start = if dur > 0 {
374 (timestamp / dur) * dur
375 } else {
376 timestamp
377 };
378 let end = start + dur;
379
380 let id_offset = self.next_global_id;
381
382 let partition = Partition {
383 start_us: start,
384 end_us: end,
385 hnsw: TemporalHnsw::new(self.config.hnsw_config.clone(), self.metric.clone()),
386 id_offset,
387 point_count: 0,
388 };
389
390 self.partitions.push(partition);
391
392 let idx = self.partitions.len() - 1;
394 self.partitions.sort_by_key(|p| p.start_us);
395
396 self.partitions
398 .iter()
399 .position(|p| p.start_us == start)
400 .unwrap_or(idx)
401 }
402
403 fn resolve_global_id(&self, global_id: u32) -> (&Partition<D>, u32) {
405 for part in &self.partitions {
406 if global_id >= part.id_offset && global_id < part.id_offset + part.hnsw.len() as u32 {
407 return (part, global_id - part.id_offset);
408 }
409 }
410 panic!(
411 "global_id {global_id} not found in any partition (total: {})",
412 self.total_points
413 );
414 }
415}
416
417impl<D: DistanceMetric + Clone> cvx_core::TemporalIndexAccess for PartitionedTemporalHnsw<D> {
420 fn search_raw(
421 &self,
422 query: &[f32],
423 k: usize,
424 filter: TemporalFilter,
425 alpha: f32,
426 query_timestamp: i64,
427 ) -> Vec<(u32, f32)> {
428 self.search(query, k, filter, alpha, query_timestamp)
429 }
430
431 fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
432 self.trajectory(entity_id, filter)
433 }
434
435 fn vector(&self, node_id: u32) -> Vec<f32> {
436 self.vector(node_id)
437 }
438
439 fn entity_id(&self, node_id: u32) -> u64 {
440 self.entity_id(node_id)
441 }
442
443 fn timestamp(&self, node_id: u32) -> i64 {
444 self.timestamp(node_id)
445 }
446
447 fn len(&self) -> usize {
448 self.len()
449 }
450}
451
452#[derive(serde::Serialize, serde::Deserialize)]
455struct PartitionMeta {
456 partition_duration_us: i64,
457 num_partitions: usize,
458 partitions: Vec<PartitionMetaEntry>,
459}
460
461#[derive(serde::Serialize, serde::Deserialize)]
462struct PartitionMetaEntry {
463 start_us: i64,
464 end_us: i64,
465 id_offset: u32,
466 point_count: usize,
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::L2Distance;

    /// Configuration with seven-second partitions, small enough for fixtures.
    fn default_config() -> PartitionConfig {
        PartitionConfig {
            partition_duration_us: 7_000_000,
            hnsw_config: HnswConfig::default(),
        }
    }

    /// Empty index built from [`default_config`].
    fn fresh_index() -> PartitionedTemporalHnsw<L2Distance> {
        PartitionedTemporalHnsw::new(default_config(), L2Distance)
    }

    /// Empty partition covering `[start_us, end_us)`.
    fn empty_partition(start_us: i64, end_us: i64) -> Partition<L2Distance> {
        Partition {
            start_us,
            end_us,
            hnsw: TemporalHnsw::new(HnswConfig::default(), L2Distance),
            id_offset: 0,
            point_count: 0,
        }
    }

    // ---- insertion ----

    #[test]
    fn insert_single_partition() {
        let mut index = fresh_index();

        // Ten points, 0.1 s apart: all inside the first seven-second slice.
        for entity in 0..10u64 {
            let ts = entity as i64 * 100_000;
            index.insert(entity, ts, &[entity as f32, 0.0, 0.0]);
        }

        assert_eq!(index.len(), 10);
        assert_eq!(index.num_partitions(), 1);
    }

    #[test]
    fn insert_multiple_partitions() {
        let mut index = fresh_index();

        // Thirty points, one second apart, span several seven-second slices.
        for step in 0..30u64 {
            let ts = step as i64 * 1_000_000;
            index.insert(0, ts, &[step as f32, 0.0, 0.0]);
        }

        assert_eq!(index.len(), 30);
        assert!(
            index.num_partitions() >= 3,
            "expected >= 3 partitions, got {}",
            index.num_partitions()
        );
    }

    // ---- search ----

    #[test]
    fn search_across_partitions() {
        let mut index = fresh_index();

        for entity in 0..30u64 {
            index.insert(
                entity,
                entity as i64 * 1_000_000,
                &[entity as f32, 0.0, 0.0],
            );
        }

        let hits = index.search(&[5.0, 0.0, 0.0], 3, TemporalFilter::All, 1.0, 0);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn search_with_temporal_filter_prunes_partitions() {
        let mut index = fresh_index();

        // Fourteen seconds of data: two partitions.
        for entity in 0..14u64 {
            index.insert(
                entity,
                entity as i64 * 1_000_000,
                &[entity as f32, 0.0, 0.0],
            );
        }

        let query = [3.0, 0.0, 0.0];
        let first_slice = TemporalFilter::Range(0, 6_999_999);
        let hits = index.search(&query, 5, first_slice, 1.0, 3_000_000);

        for &(global_id, _) in &hits {
            let ts = index.timestamp(global_id);
            assert!(ts < 7_000_000, "got timestamp {ts} outside partition 0");
        }
    }

    // ---- trajectories ----

    #[test]
    fn trajectory_across_partitions() {
        let mut index = fresh_index();

        for step in 0..21u64 {
            index.insert(42, step as i64 * 1_000_000, &[step as f32, 0.0]);
        }

        let traj = index.trajectory(42, TemporalFilter::All);
        assert_eq!(traj.len(), 21, "should get all 21 points");

        assert!(
            traj.windows(2).all(|pair| pair[0].0 <= pair[1].0),
            "trajectory should be sorted"
        );
    }

    #[test]
    fn trajectory_with_filter() {
        let mut index = fresh_index();

        for step in 0..21u64 {
            index.insert(42, step as i64 * 1_000_000, &[step as f32]);
        }

        let traj = index.trajectory(42, TemporalFilter::Range(5_000_000, 15_000_000));
        for &(ts, _) in &traj {
            assert!(
                (5_000_000..=15_000_000).contains(&ts),
                "ts {ts} outside filter range"
            );
        }
        assert!(!traj.is_empty());
    }

    #[test]
    fn trajectory_unknown_entity() {
        let index = fresh_index();
        assert!(index.trajectory(999, TemporalFilter::All).is_empty());
    }

    // ---- global id resolution ----

    #[test]
    fn vector_entity_timestamp_resolution() {
        let mut index = fresh_index();

        let first = index.insert(1, 0, &[1.0, 2.0, 3.0]);
        // Eight seconds in: lands in the second partition.
        let second = index.insert(2, 8_000_000, &[4.0, 5.0, 6.0]);

        assert_eq!(index.entity_id(first), 1);
        assert_eq!(index.entity_id(second), 2);
        assert_eq!(index.timestamp(first), 0);
        assert_eq!(index.timestamp(second), 8_000_000);
        assert_eq!(index.vector(first), vec![1.0, 2.0, 3.0]);
        assert_eq!(index.vector(second), vec![4.0, 5.0, 6.0]);
    }

    // ---- trait object access ----

    #[test]
    fn trait_search_works() {
        let mut index = fresh_index();
        for entity in 0..20u64 {
            index.insert(entity, entity as i64 * 500_000, &[entity as f32, 0.0]);
        }

        let as_trait: &dyn cvx_core::TemporalIndexAccess = &index;
        let hits = as_trait.search_raw(&[10.0, 0.0], 3, TemporalFilter::All, 1.0, 0);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn trait_trajectory_works() {
        let mut index = fresh_index();
        for step in 0..10u64 {
            index.insert(42, step as i64 * 1_000_000, &[step as f32]);
        }

        let as_trait: &dyn cvx_core::TemporalIndexAccess = &index;
        assert_eq!(as_trait.trajectory(42, TemporalFilter::All).len(), 10);
    }

    // ---- partition / filter overlap ----

    #[test]
    fn partition_overlap_all() {
        let part = empty_partition(100, 200);
        assert!(part.overlaps(&TemporalFilter::All));
    }

    #[test]
    fn partition_overlap_range() {
        let part = empty_partition(100, 200);

        // Straddles the end of the partition.
        assert!(part.overlaps(&TemporalFilter::Range(150, 250)));
        // Starts exactly at the exclusive end.
        assert!(!part.overlaps(&TemporalFilter::Range(200, 300)));
        // Ends just before the start.
        assert!(!part.overlaps(&TemporalFilter::Range(0, 99)));
    }

    #[test]
    fn partition_overlap_before_after() {
        let part = empty_partition(100, 200);

        assert!(part.overlaps(&TemporalFilter::Before(150)));
        assert!(!part.overlaps(&TemporalFilter::Before(99)));
        assert!(part.overlaps(&TemporalFilter::After(150)));
        assert!(!part.overlaps(&TemporalFilter::After(200)));
    }

    // ---- persistence ----

    #[test]
    fn save_and_load_roundtrip() {
        let mut index = fresh_index();

        for step in 0..20u64 {
            let x = step as f32;
            index.insert(step % 5, step as i64 * 1_000_000, &[x, x.sin()]);
        }

        let dir = tempfile::tempdir().unwrap();
        index.save(dir.path()).unwrap();

        let loaded = PartitionedTemporalHnsw::load(dir.path(), L2Distance).unwrap();

        assert_eq!(loaded.len(), 20);
        assert_eq!(loaded.num_partitions(), index.num_partitions());

        let before = index.trajectory(0, TemporalFilter::All);
        let after = loaded.trajectory(0, TemporalFilter::All);
        assert_eq!(before.len(), after.len());
    }

    // ---- conversion from a single index ----

    #[test]
    fn from_single_preserves_data() {
        let mut single = TemporalHnsw::new(HnswConfig::default(), L2Distance);

        for step in 0..15u64 {
            single.insert(step % 3, step as i64 * 1000, &[step as f32, 0.0]);
        }

        let partitioned =
            PartitionedTemporalHnsw::from_single(single, default_config(), L2Distance);

        assert_eq!(partitioned.len(), 15);
        assert_eq!(partitioned.num_partitions(), 1);

        for eid in 0..3 {
            let traj = partitioned.trajectory(eid, TemporalFilter::All);
            assert_eq!(traj.len(), 5, "entity {eid} should have 5 points");
        }
    }

    // ---- edge cases ----

    #[test]
    fn empty_index() {
        let index = fresh_index();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert_eq!(index.num_partitions(), 0);
    }

    #[test]
    fn out_of_order_inserts() {
        let mut index = fresh_index();

        // Newest first, so each insert opens an earlier partition.
        for (ts, value) in [(20_000_000, 20.0), (10_000_000, 10.0), (0, 0.0)] {
            index.insert(1, ts, &[value]);
        }

        let traj = index.trajectory(1, TemporalFilter::All);
        assert_eq!(traj.len(), 3);
        assert!(traj[0].0 <= traj[1].0);
        assert!(traj[1].0 <= traj[2].0);
    }

    #[test]
    fn partition_info() {
        let mut index = fresh_index();

        index.insert(1, 0, &[1.0]);
        index.insert(2, 8_000_000, &[2.0]);

        let info = index.partition_info();
        assert_eq!(info.len(), 2);
        // One point in each of the two partitions.
        assert_eq!(info[0].2, 1);
        assert_eq!(info[1].2, 1);
    }
}