1use std::path::Path;
15
16use cvx_core::traits::DistanceMetric;
17use cvx_core::types::TemporalFilter;
18
19use super::partitioned::{PartitionConfig, PartitionedTemporalHnsw};
20
21#[derive(Debug, Clone)]
25pub struct StreamingConfig {
26 pub buffer_capacity: usize,
28 pub partition_config: PartitionConfig,
30}
31
32impl Default for StreamingConfig {
33 fn default() -> Self {
34 Self {
35 buffer_capacity: 10_000,
36 partition_config: PartitionConfig::default(),
37 }
38 }
39}
40
41struct HotBuffer {
47 points: Vec<BufferedPoint>,
48}
49
50#[derive(Clone)]
52struct BufferedPoint {
53 entity_id: u64,
54 timestamp: i64,
55 vector: Vec<f32>,
56 global_id: u32,
58}
59
60impl HotBuffer {
61 fn new() -> Self {
62 Self { points: Vec::new() }
63 }
64
65 fn push(&mut self, entity_id: u64, timestamp: i64, vector: Vec<f32>, global_id: u32) {
66 self.points.push(BufferedPoint {
67 entity_id,
68 timestamp,
69 vector,
70 global_id,
71 });
72 }
73
74 fn len(&self) -> usize {
75 self.points.len()
76 }
77
78 fn is_empty(&self) -> bool {
79 self.points.is_empty()
80 }
81
82 fn drain(&mut self) -> Vec<BufferedPoint> {
84 std::mem::take(&mut self.points)
85 }
86
87 fn brute_force_search(
89 &self,
90 query: &[f32],
91 k: usize,
92 filter: &TemporalFilter,
93 metric: &dyn DistanceMetric,
94 ) -> Vec<(u32, f32)> {
95 let mut results: Vec<(u32, f32)> = self
96 .points
97 .iter()
98 .filter(|p| filter.matches(p.timestamp))
99 .map(|p| {
100 let dist = metric.distance(query, &p.vector);
101 (p.global_id, dist)
102 })
103 .collect();
104
105 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
106 results.truncate(k);
107 results
108 }
109
110 fn trajectory(&self, entity_id: u64, filter: &TemporalFilter) -> Vec<(i64, u32)> {
112 self.points
113 .iter()
114 .filter(|p| p.entity_id == entity_id && filter.matches(p.timestamp))
115 .map(|p| (p.timestamp, p.global_id))
116 .collect()
117 }
118
119 fn find(&self, global_id: u32) -> Option<&BufferedPoint> {
121 self.points.iter().find(|p| p.global_id == global_id)
122 }
123}
124
125pub struct StreamingTemporalHnsw<D: DistanceMetric + Clone> {
130 buffer: HotBuffer,
132 compacted: PartitionedTemporalHnsw<D>,
134 config: StreamingConfig,
136 metric: D,
138 next_global_id: u32,
140 compaction_count: usize,
142}
143
144impl<D: DistanceMetric + Clone> StreamingTemporalHnsw<D> {
145 pub fn new(config: StreamingConfig, metric: D) -> Self {
147 let compacted =
148 PartitionedTemporalHnsw::new(config.partition_config.clone(), metric.clone());
149 Self {
150 buffer: HotBuffer::new(),
151 compacted,
152 config,
153 metric,
154 next_global_id: 0,
155 compaction_count: 0,
156 }
157 }
158
159 pub fn insert(&mut self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
163 let global_id = self.next_global_id;
164 self.next_global_id += 1;
165
166 self.buffer
167 .push(entity_id, timestamp, vector.to_vec(), global_id);
168
169 if self.buffer.len() >= self.config.buffer_capacity {
171 self.compact();
172 }
173
174 global_id
175 }
176
177 pub fn compact(&mut self) {
179 let points = self.buffer.drain();
180 if points.is_empty() {
181 return;
182 }
183
184 for p in points {
185 self.compacted.insert(p.entity_id, p.timestamp, &p.vector);
186 }
187
188 self.compaction_count += 1;
189 }
190
191 pub fn search(
193 &self,
194 query: &[f32],
195 k: usize,
196 filter: TemporalFilter,
197 alpha: f32,
198 query_timestamp: i64,
199 ) -> Vec<(u32, f32)> {
200 let mut results = self
202 .compacted
203 .search(query, k, filter, alpha, query_timestamp);
204
205 let buffer_results = self
207 .buffer
208 .brute_force_search(query, k, &filter, &self.metric);
209 results.extend(buffer_results);
210
211 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
213 results.truncate(k);
214 results
215 }
216
217 pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
219 let mut traj = self.compacted.trajectory(entity_id, filter);
220 let buffer_traj = self.buffer.trajectory(entity_id, &filter);
221 traj.extend(buffer_traj);
222 traj.sort_by_key(|&(ts, _)| ts);
223 traj
224 }
225
226 pub fn vector(&self, global_id: u32) -> Vec<f32> {
228 if let Some(p) = self.buffer.find(global_id) {
229 return p.vector.clone();
230 }
231 self.compacted.vector(global_id)
232 }
233
234 pub fn entity_id(&self, global_id: u32) -> u64 {
236 if let Some(p) = self.buffer.find(global_id) {
237 return p.entity_id;
238 }
239 self.compacted.entity_id(global_id)
240 }
241
242 pub fn timestamp(&self, global_id: u32) -> i64 {
244 if let Some(p) = self.buffer.find(global_id) {
245 return p.timestamp;
246 }
247 self.compacted.timestamp(global_id)
248 }
249
250 pub fn len(&self) -> usize {
252 self.buffer.len() + self.compacted.len()
253 }
254
255 pub fn is_empty(&self) -> bool {
257 self.buffer.is_empty() && self.compacted.is_empty()
258 }
259
260 pub fn buffer_len(&self) -> usize {
262 self.buffer.len()
263 }
264
265 pub fn compacted_len(&self) -> usize {
267 self.compacted.len()
268 }
269
270 pub fn compaction_count(&self) -> usize {
272 self.compaction_count
273 }
274
275 pub fn save(&mut self, dir: &Path) -> std::io::Result<()> {
277 self.compact();
279 self.compacted.save(dir)
280 }
281}
282
283impl<D: DistanceMetric + Clone> cvx_core::TemporalIndexAccess for StreamingTemporalHnsw<D> {
286 fn search_raw(
287 &self,
288 query: &[f32],
289 k: usize,
290 filter: TemporalFilter,
291 alpha: f32,
292 query_timestamp: i64,
293 ) -> Vec<(u32, f32)> {
294 self.search(query, k, filter, alpha, query_timestamp)
295 }
296
297 fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
298 self.trajectory(entity_id, filter)
299 }
300
301 fn vector(&self, node_id: u32) -> Vec<f32> {
302 self.vector(node_id)
303 }
304
305 fn entity_id(&self, node_id: u32) -> u64 {
306 self.entity_id(node_id)
307 }
308
309 fn timestamp(&self, node_id: u32) -> i64 {
310 self.timestamp(node_id)
311 }
312
313 fn len(&self) -> usize {
314 self.len()
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::L2Distance;

    /// Test config with the given buffer capacity and
    /// `partition_duration_us = 10_000_000`.
    fn test_config(buffer_cap: usize) -> StreamingConfig {
        StreamingConfig {
            buffer_capacity: buffer_cap,
            partition_config: PartitionConfig {
                partition_duration_us: 10_000_000,
                ..Default::default()
            },
        }
    }

    #[test]
    fn new_empty() {
        let index = StreamingTemporalHnsw::new(test_config(100), L2Distance);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert_eq!(index.buffer_len(), 0);
        assert_eq!(index.compacted_len(), 0);
    }

    #[test]
    fn insert_into_buffer() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);

        for i in 0..10u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }

        assert_eq!(index.len(), 10);
        assert!(!index.is_empty());
        assert_eq!(index.buffer_len(), 10);
        assert_eq!(index.compacted_len(), 0);
    }

    #[test]
    fn auto_compaction_on_capacity() {
        let mut index = StreamingTemporalHnsw::new(test_config(5), L2Distance);

        // The fifth insert fills the buffer to capacity and triggers compaction.
        for i in 0..5u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }

        assert_eq!(index.compaction_count(), 1);
        assert_eq!(index.buffer_len(), 0);
        assert_eq!(index.compacted_len(), 5);
        assert_eq!(index.len(), 5);
    }

    #[test]
    fn manual_compaction() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);

        for i in 0..10u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }

        assert_eq!(index.buffer_len(), 10);
        index.compact();
        assert_eq!(index.buffer_len(), 0);
        assert_eq!(index.compacted_len(), 10);
        assert_eq!(index.compaction_count(), 1);
    }

    #[test]
    fn search_buffer_only() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);

        let ids: Vec<u32> = (0..10u64)
            .map(|i| index.insert(i, i as i64 * 1000, &[i as f32, 0.0]))
            .collect();

        let results = index.search(&[5.0, 0.0], 3, TemporalFilter::All, 1.0, 0);
        assert_eq!(results.len(), 3);
        // Buffered points are scored exactly, so the identical vector must rank
        // first at distance 0.
        assert_eq!(results[0].0, ids[5]);
        assert_eq!(results[0].1, 0.0);
        assert!(results.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn search_compacted_only() {
        let mut index = StreamingTemporalHnsw::new(test_config(5), L2Distance);

        for i in 0..10u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }
        // Capacity 5 means both batches were already auto-compacted; this call
        // is a no-op that must leave the index intact.
        index.compact();
        assert_eq!(index.buffer_len(), 0);
        assert_eq!(index.compacted_len(), 10);

        let results = index.search(&[5.0, 0.0], 3, TemporalFilter::All, 1.0, 0);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn search_merged_buffer_and_compacted() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);

        for i in 0..5u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }
        index.compact();

        for i in 5..10u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }

        assert_eq!(index.compacted_len(), 5);
        assert_eq!(index.buffer_len(), 5);

        let results = index.search(&[5.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
        assert_eq!(results.len(), 5);
        // The merged list must be ordered by ascending score.
        assert!(results.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn trajectory_across_buffer_and_compacted() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);

        for i in 0..5u64 {
            index.insert(42, i as i64 * 1000, &[i as f32]);
        }
        index.compact();
        for i in 5..10u64 {
            index.insert(42, i as i64 * 1000, &[i as f32]);
        }

        let traj = index.trajectory(42, TemporalFilter::All);
        assert_eq!(traj.len(), 10, "should find all 10 points");

        // Every timestamp appears exactly once, in chronological order, whether
        // it came from the compacted half or from the buffer.
        let timestamps: Vec<i64> = traj.iter().map(|&(ts, _)| ts).collect();
        let expected: Vec<i64> = (0..10).map(|i| i * 1000).collect();
        assert_eq!(timestamps, expected);
    }

    #[test]
    fn resolve_ids_in_buffer() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);
        let id = index.insert(42, 1000, &[1.0, 2.0]);

        assert_eq!(index.entity_id(id), 42);
        assert_eq!(index.timestamp(id), 1000);
        assert_eq!(index.vector(id), vec![1.0, 2.0]);
    }

    #[test]
    fn resolve_ids_after_compaction() {
        let mut index = StreamingTemporalHnsw::new(test_config(2), L2Distance);
        let _id0 = index.insert(1, 100, &[1.0]);
        let _id1 = index.insert(2, 200, &[2.0]);

        // Capacity 2: the second insert must have moved both points out of the
        // buffer, so everything below is resolved by the compacted index.
        assert_eq!(index.buffer_len(), 0);
        assert_eq!(index.compaction_count(), 1);

        // Resolution is checked through `trajectory` rather than through the
        // ids `insert` returned.
        let traj_1 = index.trajectory(1, TemporalFilter::All);
        let traj_2 = index.trajectory(2, TemporalFilter::All);
        assert_eq!(traj_1.len(), 1);
        assert_eq!(traj_2.len(), 1);
        assert_eq!(traj_1[0].0, 100);
        assert_eq!(traj_2[0].0, 200);
    }

    #[test]
    fn trait_search() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);
        for i in 0..10u64 {
            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
        }

        let trait_ref: &dyn cvx_core::TemporalIndexAccess = &index;
        let results = trait_ref.search_raw(&[5.0, 0.0], 3, TemporalFilter::All, 1.0, 0);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn compact_empty_buffer() {
        let mut index = StreamingTemporalHnsw::new(test_config(100), L2Distance);
        index.compact();
        assert_eq!(index.compaction_count(), 0);
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn multiple_compactions() {
        let mut index = StreamingTemporalHnsw::new(test_config(5), L2Distance);

        for i in 0..20u64 {
            index.insert(i % 3, i as i64 * 1000, &[i as f32]);
        }

        // Inserts 5, 10, 15 and 20 each fill the buffer to capacity.
        assert_eq!(index.compaction_count(), 4, "should compact every 5 inserts");
        assert_eq!(index.buffer_len(), 0);
        assert_eq!(index.len(), 20);
    }
}