// cvx_index/hnsw/concurrent.rs

//! Thread-safe concurrent HNSW index.
//!
//! Wraps [`TemporalHnsw`] with a [`parking_lot::RwLock`] for concurrent access:
//! - Multiple readers can search simultaneously (read lock)
//! - A single writer can insert (write lock)
//!
//! This is the main entry point for production use.
//!
//! # Example
//!
//! ```
//! use cvx_index::hnsw::{ConcurrentTemporalHnsw, HnswConfig};
//! use cvx_index::metrics::L2Distance;
//! use cvx_core::TemporalFilter;
//! use std::sync::Arc;
//!
//! let config = HnswConfig::default();
//! let index = Arc::new(ConcurrentTemporalHnsw::new(config, L2Distance));
//!
//! // Insert (takes write lock)
//! index.insert(1, 1000, &[1.0, 0.0, 0.0]);
//!
//! // Search (takes read lock) — can run from multiple threads
//! let results = index.search(&[1.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 1000);
//! assert_eq!(results.len(), 1);
//! ```

28use cvx_core::{DistanceMetric, TemporalFilter};
29use parking_lot::{Mutex, RwLock};
30
31use super::HnswConfig;
32use super::temporal::TemporalHnsw;
33
/// A pending insert waiting in the queue.
///
/// Buffered by `queue_insert()` and applied to the underlying index when
/// `flush_inserts()` drains the queue under a single write lock.
struct PendingInsert {
    // Entity this point belongs to.
    entity_id: u64,
    // Timestamp of the point (same units the caller uses elsewhere in the index).
    timestamp: i64,
    // Owned copy of the vector — the queue must outlive the caller's borrow.
    vector: Vec<f32>,
}
40
/// Thread-safe spatiotemporal HNSW index with insert queue (RFC-002-04).
///
/// Uses a two-tier approach to reduce write lock contention:
/// - Inserts are queued into a `Mutex<Vec<...>>` (sub-microsecond)
/// - `flush_inserts()` drains the queue under a single write lock
/// - Searches always acquire a read lock (concurrent, unblocked during queue drain)
///
/// For immediate visibility, use `insert()` which still takes the write lock directly.
/// For high-throughput ingestion, use `queue_insert()` + `flush_inserts()`.
pub struct ConcurrentTemporalHnsw<D: DistanceMetric> {
    // The single-threaded index, guarded for many-readers / one-writer access.
    inner: RwLock<TemporalHnsw<D>>,
    /// Insert queue for batched commits (RFC-002-04, Option A).
    // Guarded by its own Mutex so queuing never contends with the RwLock above.
    insert_queue: Mutex<Vec<PendingInsert>>,
}
55
56impl<D: DistanceMetric> ConcurrentTemporalHnsw<D> {
57    /// Create a new empty concurrent index.
58    pub fn new(config: HnswConfig, metric: D) -> Self {
59        Self {
60            inner: RwLock::new(TemporalHnsw::new(config, metric)),
61            insert_queue: Mutex::new(Vec::new()),
62        }
63    }
64
65    /// Number of points in the index.
66    pub fn len(&self) -> usize {
67        self.inner.read().len()
68    }
69
70    /// Whether the index is empty.
71    pub fn is_empty(&self) -> bool {
72        self.inner.read().is_empty()
73    }
74
75    /// Insert a temporal point (exclusive write lock).
76    pub fn insert(&self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
77        self.inner.write().insert(entity_id, timestamp, vector)
78    }
79
80    /// Search with temporal filtering (shared read lock).
81    pub fn search(
82        &self,
83        query: &[f32],
84        k: usize,
85        filter: TemporalFilter,
86        alpha: f32,
87        query_timestamp: i64,
88    ) -> Vec<(u32, f32)> {
89        self.inner
90            .read()
91            .search(query, k, filter, alpha, query_timestamp)
92    }
93
94    /// Retrieve trajectory for an entity (shared read lock).
95    pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
96        self.inner.read().trajectory(entity_id, filter)
97    }
98
99    /// Get timestamp for a node (shared read lock).
100    pub fn timestamp(&self, node_id: u32) -> i64 {
101        self.inner.read().timestamp(node_id)
102    }
103
104    /// Get entity_id for a node (shared read lock).
105    pub fn entity_id(&self, node_id: u32) -> u64 {
106        self.inner.read().entity_id(node_id)
107    }
108
109    /// Get vector for a node (shared read lock).
110    pub fn vector(&self, node_id: u32) -> Vec<f32> {
111        self.inner.read().vector(node_id).to_vec()
112    }
113
114    // ─── Centering (RFC-012 Part B) ──────────────────────────────────
115
116    /// Compute the centroid (mean vector) of all indexed vectors.
117    pub fn compute_centroid(&self) -> Option<Vec<f32>> {
118        self.inner.read().compute_centroid()
119    }
120
121    /// Set the centroid for anisotropy correction (write lock).
122    pub fn set_centroid(&self, centroid: Vec<f32>) {
123        self.inner.write().set_centroid(centroid);
124    }
125
126    /// Clear the centroid (write lock).
127    pub fn clear_centroid(&self) {
128        self.inner.write().clear_centroid();
129    }
130
131    /// Get the current centroid, if set.
132    pub fn centroid(&self) -> Option<Vec<f32>> {
133        self.inner.read().centroid().map(|c| c.to_vec())
134    }
135
136    /// Return a centered copy of the given vector (vec - centroid).
137    pub fn centered_vector(&self, vec: &[f32]) -> Vec<f32> {
138        self.inner.read().centered_vector(vec)
139    }
140
141    /// Queue an insert for batched processing (RFC-002-04).
142    ///
143    /// This only takes a `Mutex` (sub-microsecond), NOT the write lock.
144    /// The insert becomes visible after `flush_inserts()` is called.
145    pub fn queue_insert(&self, entity_id: u64, timestamp: i64, vector: Vec<f32>) {
146        self.insert_queue.lock().push(PendingInsert {
147            entity_id,
148            timestamp,
149            vector,
150        });
151    }
152
153    /// Number of pending inserts in the queue.
154    pub fn pending_inserts(&self) -> usize {
155        self.insert_queue.lock().len()
156    }
157
158    /// Flush all queued inserts, applying them under a single write lock.
159    ///
160    /// Returns the number of inserts applied.
161    pub fn flush_inserts(&self) -> usize {
162        let pending: Vec<PendingInsert> = {
163            let mut queue = self.insert_queue.lock();
164            std::mem::take(&mut *queue)
165        };
166
167        if pending.is_empty() {
168            return 0;
169        }
170
171        let count = pending.len();
172        let mut inner = self.inner.write();
173        for p in pending {
174            inner.insert(p.entity_id, p.timestamp, &p.vector);
175        }
176        count
177    }
178}
179
180impl<D: DistanceMetric> cvx_core::TemporalIndexAccess for ConcurrentTemporalHnsw<D> {
181    fn search_raw(
182        &self,
183        query: &[f32],
184        k: usize,
185        filter: TemporalFilter,
186        alpha: f32,
187        query_timestamp: i64,
188    ) -> Vec<(u32, f32)> {
189        self.inner
190            .read()
191            .search(query, k, filter, alpha, query_timestamp)
192    }
193
194    fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
195        self.inner.read().trajectory(entity_id, filter)
196    }
197
198    fn vector(&self, node_id: u32) -> Vec<f32> {
199        self.inner.read().vector(node_id).to_vec()
200    }
201
202    fn entity_id(&self, node_id: u32) -> u64 {
203        self.inner.read().entity_id(node_id)
204    }
205
206    fn timestamp(&self, node_id: u32) -> i64 {
207        self.inner.read().timestamp(node_id)
208    }
209
210    fn len(&self) -> usize {
211        self.inner.read().len()
212    }
213
214    fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
215        self.inner.read().regions(level)
216    }
217
218    fn region_members(
219        &self,
220        region_hub: u32,
221        level: usize,
222        filter: cvx_core::TemporalFilter,
223    ) -> Vec<(u32, u64, i64)> {
224        self.inner.read().region_members(region_hub, level, filter)
225    }
226
227    fn region_assignments(
228        &self,
229        level: usize,
230        filter: cvx_core::TemporalFilter,
231    ) -> std::collections::HashMap<u32, Vec<(u64, i64)>> {
232        self.inner.read().region_assignments(level, filter)
233    }
234
235    fn region_trajectory(
236        &self,
237        entity_id: u64,
238        level: usize,
239        window_days: i64,
240        alpha: f32,
241    ) -> Vec<(i64, Vec<f32>)> {
242        self.inner
243            .read()
244            .region_trajectory(entity_id, level, window_days, alpha)
245    }
246}
247
248impl<D: DistanceMetric> cvx_core::IndexBackend for ConcurrentTemporalHnsw<D> {
249    fn insert(
250        &self,
251        entity_id: u64,
252        vector: &[f32],
253        timestamp: i64,
254    ) -> Result<u32, cvx_core::error::IndexError> {
255        Ok(self.inner.write().insert(entity_id, timestamp, vector))
256    }
257
258    fn search(
259        &self,
260        query: &[f32],
261        k: usize,
262        filter: TemporalFilter,
263        alpha: f32,
264        query_timestamp: i64,
265    ) -> Result<Vec<cvx_core::ScoredResult>, cvx_core::error::QueryError> {
266        let inner = self.inner.read();
267        let raw_results = inner.search(query, k, filter, alpha, query_timestamp);
268
269        let results = raw_results
270            .into_iter()
271            .map(|(node_id, combined_score)| {
272                let entity_id = inner.entity_id(node_id);
273                let timestamp = inner.timestamp(node_id);
274                let vector = inner.vector(node_id).to_vec();
275                let point = cvx_core::TemporalPoint::new(entity_id, timestamp, vector);
276
277                // Decompose combined score into semantic and temporal components
278                let temporal_dist = inner.temporal_distance_normalized(timestamp, query_timestamp);
279                let semantic_dist = if alpha > 0.0 {
280                    (combined_score - (1.0 - alpha) * temporal_dist) / alpha
281                } else {
282                    0.0
283                };
284
285                cvx_core::ScoredResult::new(point, semantic_dist, temporal_dist, combined_score)
286            })
287            .collect();
288
289        Ok(results)
290    }
291
292    fn remove(&self, _point_id: u64) -> Result<(), cvx_core::error::IndexError> {
293        // Removal not yet supported in HNSW — mark as tombstone in future
294        Ok(())
295    }
296
297    fn len(&self) -> usize {
298        self.inner.read().len()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::L2Distance;
    use std::sync::Arc;
    use std::thread;

    // Shared fixture: a small index behind `Arc` so tests can hand clones of
    // the handle to spawned threads.
    fn make_concurrent_index() -> Arc<ConcurrentTemporalHnsw<L2Distance>> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 200,
            ef_search: 100,
            ..Default::default()
        };
        Arc::new(ConcurrentTemporalHnsw::new(config, L2Distance))
    }

    // Smoke test: insert + search with no concurrency involved.
    #[test]
    fn single_thread_basic() {
        let index = make_concurrent_index();
        index.insert(1, 1000, &[1.0, 0.0, 0.0]);
        index.insert(2, 2000, &[0.0, 1.0, 0.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 1.0, 0);
        assert_eq!(results.len(), 2);
        // Node 0 is the first insert (entity 1), an exact match for the query.
        assert_eq!(results[0].0, 0); // closest
    }

    // Multiple reader threads search simultaneously under the shared read lock.
    #[test]
    fn concurrent_readers() {
        let index = make_concurrent_index();

        // Insert some data first
        for i in 0..100u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        // Spawn 8 reader threads
        let n_threads = 8;
        let mut handles = Vec::new();

        for t in 0..n_threads {
            let idx = Arc::clone(&index);
            handles.push(thread::spawn(move || {
                let query = [t as f32, 0.0, 0.0];
                for _ in 0..100 {
                    let results = idx.search(&query, 5, TemporalFilter::All, 1.0, 0);
                    assert_eq!(results.len(), 5);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    // One writer inserting while 8 readers search: readers must never block
    // permanently or observe an empty index (it is pre-populated).
    #[test]
    fn concurrent_readers_and_writer() {
        let index = make_concurrent_index();

        // Pre-populate with some data
        for i in 0..50u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let idx_writer = Arc::clone(&index);
        let idx_readers: Vec<_> = (0..8).map(|_| Arc::clone(&index)).collect();

        // Writer thread
        let writer = thread::spawn(move || {
            for i in 50..150u64 {
                idx_writer.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
            }
        });

        // Reader threads
        let readers: Vec<_> = idx_readers
            .into_iter()
            .map(|idx| {
                thread::spawn(move || {
                    let query = [50.0, 0.0, 0.0];
                    for _ in 0..50 {
                        let results = idx.search(&query, 5, TemporalFilter::All, 1.0, 0);
                        // Results should always be non-empty since we pre-populated
                        assert!(!results.is_empty());
                    }
                })
            })
            .collect();

        writer.join().unwrap();
        for r in readers {
            r.join().unwrap();
        }

        // Verify final state
        assert_eq!(index.len(), 150);
    }

    // Temporal range filtering must hold under concurrent search: every
    // returned node's timestamp lies inside the requested window.
    #[test]
    fn concurrent_search_with_temporal_filter() {
        let index = make_concurrent_index();

        for i in 0..200u64 {
            index.insert(i % 10, (i * 100) as i64, &[i as f32, 0.0]);
        }

        let mut handles = Vec::new();
        for t in 0..8 {
            let idx = Arc::clone(&index);
            handles.push(thread::spawn(move || {
                let filter = TemporalFilter::Range(1000, 5000);
                for _ in 0..50 {
                    let results = idx.search(&[t as f32 * 10.0, 0.0], 5, filter, 0.5, 3000);
                    // All results should have timestamps in range
                    for &(id, _) in &results {
                        let ts = idx.timestamp(id);
                        assert!(
                            (1000..=5000).contains(&ts),
                            "timestamp {ts} out of [1000, 5000]"
                        );
                    }
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    // Queue path (RFC-002-04): queued inserts are invisible until flushed,
    // then all become visible at once.
    #[test]
    fn queue_insert_and_flush() {
        let index = make_concurrent_index();

        // Queue 100 inserts (no write lock held)
        for i in 0..100u64 {
            index.queue_insert(i, (i * 100) as i64, vec![i as f32, 0.0, 0.0]);
        }

        // Not yet visible
        assert_eq!(index.len(), 0);
        assert_eq!(index.pending_inserts(), 100);

        // Flush applies them all under a single write lock
        let flushed = index.flush_inserts();
        assert_eq!(flushed, 100);
        assert_eq!(index.len(), 100);
        assert_eq!(index.pending_inserts(), 0);

        // Searchable after flush
        let results = index.search(&[50.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
        assert_eq!(results.len(), 5);
    }

    // Queuing inserts must not block concurrent searches (only the queue
    // Mutex is taken, never the index write lock).
    #[test]
    fn queue_insert_concurrent_with_search() {
        let index = make_concurrent_index();

        // Pre-populate so searches always return results
        for i in 0..50u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let idx_queue = Arc::clone(&index);
        let idx_search: Vec<_> = (0..4).map(|_| Arc::clone(&index)).collect();

        // Queue thread: queues inserts without blocking searches
        let queue_thread = thread::spawn(move || {
            for i in 50..200u64 {
                idx_queue.queue_insert(i, (i * 100) as i64, vec![i as f32, 0.0, 0.0]);
            }
        });

        // Search threads: run concurrently with queue thread
        let search_threads: Vec<_> = idx_search
            .into_iter()
            .map(|idx| {
                thread::spawn(move || {
                    for _ in 0..50 {
                        let results = idx.search(&[25.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
                        assert!(!results.is_empty());
                    }
                })
            })
            .collect();

        queue_thread.join().unwrap();
        for t in search_threads {
            t.join().unwrap();
        }

        // Flush and verify
        let flushed = index.flush_inserts();
        assert_eq!(flushed, 150);
        assert_eq!(index.len(), 200);
    }

    // Trajectory reads are consistent when performed from several threads.
    #[test]
    fn trajectory_concurrent() {
        let index = make_concurrent_index();

        // Insert trajectory for entity 1
        for i in 0..50u32 {
            index.insert(1, (i as i64) * 100, &[i as f32]);
        }

        let mut handles = Vec::new();
        for _ in 0..4 {
            let idx = Arc::clone(&index);
            handles.push(thread::spawn(move || {
                let traj = idx.trajectory(1, TemporalFilter::All);
                assert_eq!(traj.len(), 50);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    // ─── Centering thread-safety ─────────────────────────────────

    // Concurrent reads of the centroid and centered_vector after a single
    // set_centroid call.
    #[test]
    fn centroid_concurrent_read_write() {
        let index = make_concurrent_index();
        for i in 0..100u64 {
            index.insert(i, (i * 100) as i64, &[i as f32, 0.0, 0.0]);
        }

        let centroid = index.compute_centroid().unwrap();
        index.set_centroid(centroid.clone());

        // Concurrent reads of centroid + centered_vector
        let mut handles = Vec::new();
        for _ in 0..8 {
            let idx = Arc::clone(&index);
            let c = centroid.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..50 {
                    let got = idx.centroid().unwrap();
                    assert_eq!(got.len(), c.len());
                    let centered = idx.centered_vector(&[50.0, 0.0, 0.0]);
                    assert_eq!(centered.len(), 3);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    // NOTE(review): despite the name, this test is single-threaded — it only
    // exercises set/clear round-tripping through the lock wrappers.
    #[test]
    fn clear_centroid_concurrent() {
        let index = make_concurrent_index();
        index.insert(1, 1000, &[1.0, 2.0]);
        index.set_centroid(vec![0.5, 1.0]);
        assert!(index.centroid().is_some());
        index.clear_centroid();
        assert!(index.centroid().is_none());
    }

    // ─── Region delegations ─────────────────────────────────────

    // Region queries (via the private `inner` field, visible to this module)
    // from several threads at once.
    #[test]
    fn regions_concurrent() {
        let index = make_concurrent_index();
        for i in 0..200u64 {
            index.insert(i % 4, (i * 100) as i64, &[i as f32, (i * 2) as f32, 0.0]);
        }

        let mut handles = Vec::new();
        for _ in 0..4 {
            let idx = Arc::clone(&index);
            handles.push(thread::spawn(move || {
                let regions = idx.inner.read().regions(1);
                assert!(!regions.is_empty());
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    // Every inserted point must be assigned to exactly one region at level 1.
    #[test]
    fn region_assignments_concurrent() {
        let index = make_concurrent_index();
        for i in 0..200u64 {
            index.insert(i % 4, (i * 100) as i64, &[i as f32, 0.0]);
        }

        let assignments = index
            .inner
            .read()
            .region_assignments(1, TemporalFilter::All);
        let total: usize = assignments.values().map(|v| v.len()).sum();
        assert_eq!(total, 200);
    }

    // ─── Accessor coverage ──────────────────────────────────────

    // The per-node accessors round-trip the values given at insert time.
    #[test]
    fn entity_id_and_vector_accessors() {
        let index = make_concurrent_index();
        index.insert(42, 1000, &[1.0, 2.0, 3.0]);

        assert_eq!(index.entity_id(0), 42);
        assert_eq!(index.timestamp(0), 1000);
        assert_eq!(index.vector(0), vec![1.0, 2.0, 3.0]);
        assert!(!index.is_empty());
        assert_eq!(index.len(), 1);
    }
}
615}