1use cvx_core::{DistanceMetric, TemporalFilter};
29use parking_lot::{Mutex, RwLock};
30
31use super::HnswConfig;
32use super::temporal::TemporalHnsw;
33
/// One buffered insert: the arguments of a deferred `insert` call, stored
/// until the queue is drained into the index.
struct PendingInsert {
    /// Entity the vector belongs to.
    entity_id: u64,
    /// Timestamp to associate with the vector.
    timestamp: i64,
    /// Owned copy of the embedding, so the caller's buffer need not outlive
    /// the queue entry.
    vector: Vec<f32>,
}
40
41pub struct ConcurrentTemporalHnsw<D: DistanceMetric> {
51 inner: RwLock<TemporalHnsw<D>>,
52 insert_queue: Mutex<Vec<PendingInsert>>,
54}
55
56impl<D: DistanceMetric> ConcurrentTemporalHnsw<D> {
57 pub fn new(config: HnswConfig, metric: D) -> Self {
59 Self {
60 inner: RwLock::new(TemporalHnsw::new(config, metric)),
61 insert_queue: Mutex::new(Vec::new()),
62 }
63 }
64
65 pub fn len(&self) -> usize {
67 self.inner.read().len()
68 }
69
70 pub fn is_empty(&self) -> bool {
72 self.inner.read().is_empty()
73 }
74
75 pub fn insert(&self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
77 self.inner.write().insert(entity_id, timestamp, vector)
78 }
79
80 pub fn search(
82 &self,
83 query: &[f32],
84 k: usize,
85 filter: TemporalFilter,
86 alpha: f32,
87 query_timestamp: i64,
88 ) -> Vec<(u32, f32)> {
89 self.inner
90 .read()
91 .search(query, k, filter, alpha, query_timestamp)
92 }
93
94 pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
96 self.inner.read().trajectory(entity_id, filter)
97 }
98
99 pub fn timestamp(&self, node_id: u32) -> i64 {
101 self.inner.read().timestamp(node_id)
102 }
103
104 pub fn entity_id(&self, node_id: u32) -> u64 {
106 self.inner.read().entity_id(node_id)
107 }
108
109 pub fn vector(&self, node_id: u32) -> Vec<f32> {
111 self.inner.read().vector(node_id).to_vec()
112 }
113
114 pub fn compute_centroid(&self) -> Option<Vec<f32>> {
118 self.inner.read().compute_centroid()
119 }
120
121 pub fn set_centroid(&self, centroid: Vec<f32>) {
123 self.inner.write().set_centroid(centroid);
124 }
125
126 pub fn clear_centroid(&self) {
128 self.inner.write().clear_centroid();
129 }
130
131 pub fn centroid(&self) -> Option<Vec<f32>> {
133 self.inner.read().centroid().map(|c| c.to_vec())
134 }
135
136 pub fn centered_vector(&self, vec: &[f32]) -> Vec<f32> {
138 self.inner.read().centered_vector(vec)
139 }
140
141 pub fn queue_insert(&self, entity_id: u64, timestamp: i64, vector: Vec<f32>) {
146 self.insert_queue.lock().push(PendingInsert {
147 entity_id,
148 timestamp,
149 vector,
150 });
151 }
152
153 pub fn pending_inserts(&self) -> usize {
155 self.insert_queue.lock().len()
156 }
157
158 pub fn flush_inserts(&self) -> usize {
162 let pending: Vec<PendingInsert> = {
163 let mut queue = self.insert_queue.lock();
164 std::mem::take(&mut *queue)
165 };
166
167 if pending.is_empty() {
168 return 0;
169 }
170
171 let count = pending.len();
172 let mut inner = self.inner.write();
173 for p in pending {
174 inner.insert(p.entity_id, p.timestamp, &p.vector);
175 }
176 count
177 }
178}
179
180impl<D: DistanceMetric> cvx_core::TemporalIndexAccess for ConcurrentTemporalHnsw<D> {
181 fn search_raw(
182 &self,
183 query: &[f32],
184 k: usize,
185 filter: TemporalFilter,
186 alpha: f32,
187 query_timestamp: i64,
188 ) -> Vec<(u32, f32)> {
189 self.inner
190 .read()
191 .search(query, k, filter, alpha, query_timestamp)
192 }
193
194 fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
195 self.inner.read().trajectory(entity_id, filter)
196 }
197
198 fn vector(&self, node_id: u32) -> Vec<f32> {
199 self.inner.read().vector(node_id).to_vec()
200 }
201
202 fn entity_id(&self, node_id: u32) -> u64 {
203 self.inner.read().entity_id(node_id)
204 }
205
206 fn timestamp(&self, node_id: u32) -> i64 {
207 self.inner.read().timestamp(node_id)
208 }
209
210 fn len(&self) -> usize {
211 self.inner.read().len()
212 }
213
214 fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
215 self.inner.read().regions(level)
216 }
217
218 fn region_members(
219 &self,
220 region_hub: u32,
221 level: usize,
222 filter: cvx_core::TemporalFilter,
223 ) -> Vec<(u32, u64, i64)> {
224 self.inner.read().region_members(region_hub, level, filter)
225 }
226
227 fn region_assignments(
228 &self,
229 level: usize,
230 filter: cvx_core::TemporalFilter,
231 ) -> std::collections::HashMap<u32, Vec<(u64, i64)>> {
232 self.inner.read().region_assignments(level, filter)
233 }
234
235 fn region_trajectory(
236 &self,
237 entity_id: u64,
238 level: usize,
239 window_days: i64,
240 alpha: f32,
241 ) -> Vec<(i64, Vec<f32>)> {
242 self.inner
243 .read()
244 .region_trajectory(entity_id, level, window_days, alpha)
245 }
246}
247
248impl<D: DistanceMetric> cvx_core::IndexBackend for ConcurrentTemporalHnsw<D> {
249 fn insert(
250 &self,
251 entity_id: u64,
252 vector: &[f32],
253 timestamp: i64,
254 ) -> Result<u32, cvx_core::error::IndexError> {
255 Ok(self.inner.write().insert(entity_id, timestamp, vector))
256 }
257
258 fn search(
259 &self,
260 query: &[f32],
261 k: usize,
262 filter: TemporalFilter,
263 alpha: f32,
264 query_timestamp: i64,
265 ) -> Result<Vec<cvx_core::ScoredResult>, cvx_core::error::QueryError> {
266 let inner = self.inner.read();
267 let raw_results = inner.search(query, k, filter, alpha, query_timestamp);
268
269 let results = raw_results
270 .into_iter()
271 .map(|(node_id, combined_score)| {
272 let entity_id = inner.entity_id(node_id);
273 let timestamp = inner.timestamp(node_id);
274 let vector = inner.vector(node_id).to_vec();
275 let point = cvx_core::TemporalPoint::new(entity_id, timestamp, vector);
276
277 let temporal_dist = inner.temporal_distance_normalized(timestamp, query_timestamp);
279 let semantic_dist = if alpha > 0.0 {
280 (combined_score - (1.0 - alpha) * temporal_dist) / alpha
281 } else {
282 0.0
283 };
284
285 cvx_core::ScoredResult::new(point, semantic_dist, temporal_dist, combined_score)
286 })
287 .collect();
288
289 Ok(results)
290 }
291
292 fn remove(&self, _point_id: u64) -> Result<(), cvx_core::error::IndexError> {
293 Ok(())
295 }
296
297 fn len(&self) -> usize {
298 self.inner.read().len()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::L2Distance;
    use std::sync::Arc;
    use std::thread;

    /// Builds the shared fixture: an empty L2 index behind an `Arc` so it can
    /// be handed to worker threads.
    fn make_concurrent_index() -> Arc<ConcurrentTemporalHnsw<L2Distance>> {
        let index = ConcurrentTemporalHnsw::new(
            HnswConfig {
                m: 16,
                ef_construction: 200,
                ef_search: 100,
                ..Default::default()
            },
            L2Distance,
        );
        Arc::new(index)
    }

    /// Joins every worker, re-raising a panic from any of them.
    fn join_all(workers: Vec<thread::JoinHandle<()>>) {
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Insert followed by search on a single thread.
    #[test]
    fn single_thread_basic() {
        let index = make_concurrent_index();
        index.insert(1, 1000, &[1.0, 0.0, 0.0]);
        index.insert(2, 2000, &[0.0, 1.0, 0.0]);

        let hits = index.search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 1.0, 0);
        assert_eq!(hits.len(), 2);
        // The query equals the first inserted vector, which is expected to
        // come back first as node 0.
        let (nearest, _) = hits[0];
        assert_eq!(nearest, 0);
    }

    /// Many threads searching a pre-populated index at the same time.
    #[test]
    fn concurrent_readers() {
        let index = make_concurrent_index();

        for i in 0..100u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let n_threads = 8;
        let workers: Vec<_> = (0..n_threads)
            .map(|t| {
                let idx = Arc::clone(&index);
                thread::spawn(move || {
                    let query = [t as f32, 0.0, 0.0];
                    for _ in 0..100 {
                        let hits = idx.search(&query, 5, TemporalFilter::All, 1.0, 0);
                        assert_eq!(hits.len(), 5);
                    }
                })
            })
            .collect();

        join_all(workers);
    }

    /// One writer inserting while several readers search.
    #[test]
    fn concurrent_readers_and_writer() {
        let index = make_concurrent_index();

        for i in 0..50u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let writer = {
            let idx = Arc::clone(&index);
            thread::spawn(move || {
                for i in 50..150u64 {
                    idx.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
                }
            })
        };

        let readers: Vec<_> = (0..8)
            .map(|_| {
                let idx = Arc::clone(&index);
                thread::spawn(move || {
                    let query = [50.0, 0.0, 0.0];
                    for _ in 0..50 {
                        let hits = idx.search(&query, 5, TemporalFilter::All, 1.0, 0);
                        // The index already holds 50 points, so a search must
                        // always find something regardless of writer progress.
                        assert!(!hits.is_empty());
                    }
                })
            })
            .collect();

        writer.join().unwrap();
        join_all(readers);

        // 50 seeded + 100 written concurrently.
        assert_eq!(index.len(), 150);
    }

    /// Concurrent searches with a time-range filter only return in-range nodes.
    #[test]
    fn concurrent_search_with_temporal_filter() {
        let index = make_concurrent_index();

        for i in 0..200u64 {
            index.insert(i % 10, (i * 100) as i64, &[i as f32, 0.0]);
        }

        let workers: Vec<_> = (0..8)
            .map(|t| {
                let idx = Arc::clone(&index);
                thread::spawn(move || {
                    let filter = TemporalFilter::Range(1000, 5000);
                    let query = [t as f32 * 10.0, 0.0];
                    for _ in 0..50 {
                        let hits = idx.search(&query, 5, filter, 0.5, 3000);
                        for &(id, _) in &hits {
                            let ts = idx.timestamp(id);
                            assert!(
                                (1000..=5000).contains(&ts),
                                "timestamp {ts} out of [1000, 5000]"
                            );
                        }
                    }
                })
            })
            .collect();

        join_all(workers);
    }

    /// Queued inserts stay invisible until flushed, then all become visible.
    #[test]
    fn queue_insert_and_flush() {
        let index = make_concurrent_index();

        for i in 0..100u64 {
            index.queue_insert(i, (i * 100) as i64, vec![i as f32, 0.0, 0.0]);
        }

        // Nothing has reached the index yet.
        assert_eq!(index.len(), 0);
        assert_eq!(index.pending_inserts(), 100);

        let flushed = index.flush_inserts();
        assert_eq!(flushed, 100);
        assert_eq!(index.len(), 100);
        assert_eq!(index.pending_inserts(), 0);

        let hits = index.search(&[50.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
        assert_eq!(hits.len(), 5);
    }

    /// Queueing from one thread while others search the index.
    #[test]
    fn queue_insert_concurrent_with_search() {
        let index = make_concurrent_index();

        for i in 0..50u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let producer = {
            let idx = Arc::clone(&index);
            thread::spawn(move || {
                for i in 50..200u64 {
                    idx.queue_insert(i, (i * 100) as i64, vec![i as f32, 0.0, 0.0]);
                }
            })
        };

        let searchers: Vec<_> = (0..4)
            .map(|_| {
                let idx = Arc::clone(&index);
                thread::spawn(move || {
                    for _ in 0..50 {
                        let hits =
                            idx.search(&[25.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
                        assert!(!hits.is_empty());
                    }
                })
            })
            .collect();

        producer.join().unwrap();
        join_all(searchers);

        // All 150 queued points are applied by a single flush.
        let flushed = index.flush_inserts();
        assert_eq!(flushed, 150);
        assert_eq!(index.len(), 200);
    }

    /// Several threads reading the same entity's trajectory.
    #[test]
    fn trajectory_concurrent() {
        let index = make_concurrent_index();

        for step in 0..50u32 {
            index.insert(1, (step as i64) * 100, &[step as f32]);
        }

        let workers: Vec<_> = (0..4)
            .map(|_| {
                let idx = Arc::clone(&index);
                thread::spawn(move || {
                    let traj = idx.trajectory(1, TemporalFilter::All);
                    assert_eq!(traj.len(), 50);
                })
            })
            .collect();

        join_all(workers);
    }

    /// Centroid accessors are usable from many threads once a centroid is set.
    #[test]
    fn centroid_concurrent_read_write() {
        let index = make_concurrent_index();
        for i in 0..100u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let centroid = index.compute_centroid().unwrap();
        index.set_centroid(centroid.clone());

        let workers: Vec<_> = (0..8)
            .map(|_| {
                let idx = Arc::clone(&index);
                let expected = centroid.clone();
                thread::spawn(move || {
                    for _ in 0..50 {
                        let got = idx.centroid().unwrap();
                        assert_eq!(got.len(), expected.len());
                        let centered = idx.centered_vector(&[50.0, 0.0, 0.0]);
                        assert_eq!(centered.len(), 3);
                    }
                })
            })
            .collect();

        join_all(workers);
    }

    /// Setting and then clearing the centroid through the wrapper.
    #[test]
    fn clear_centroid_concurrent() {
        let index = make_concurrent_index();
        index.insert(1, 1000, &[1.0, 2.0]);

        index.set_centroid(vec![0.5, 1.0]);
        assert!(index.centroid().is_some());

        index.clear_centroid();
        assert!(index.centroid().is_none());
    }

    /// Region listing read from several threads via the inner lock.
    #[test]
    fn regions_concurrent() {
        let index = make_concurrent_index();
        for i in 0..200u64 {
            index.insert(i % 4, (i * 100) as i64, &[i as f32, (i * 2) as f32, 0.0]);
        }

        let workers: Vec<_> = (0..4)
            .map(|_| {
                let idx = Arc::clone(&index);
                thread::spawn(move || {
                    let regions = idx.inner.read().regions(1);
                    assert!(!regions.is_empty());
                })
            })
            .collect();

        join_all(workers);
    }

    /// Region assignments account for every inserted point.
    #[test]
    fn region_assignments_concurrent() {
        let index = make_concurrent_index();
        for i in 0..200u64 {
            index.insert(i % 4, (i * 100) as i64, &[i as f32, 0.0]);
        }

        let assignments = index
            .inner
            .read()
            .region_assignments(1, TemporalFilter::All);
        let total: usize = assignments.values().map(|members| members.len()).sum();
        assert_eq!(total, 200);
    }

    /// Per-node accessors return what was inserted.
    #[test]
    fn entity_id_and_vector_accessors() {
        let index = make_concurrent_index();
        index.insert(42, 1000, &[1.0, 2.0, 3.0]);

        assert_eq!(index.entity_id(0), 42);
        assert_eq!(index.timestamp(0), 1000);
        assert_eq!(index.vector(0), vec![1.0, 2.0, 3.0]);
        assert!(!index.is_empty());
        assert_eq!(index.len(), 1);
    }
}