cvx_index/hnsw/
temporal.rs

1//! Spatiotemporal HNSW index (ST-HNSW).
2//!
3//! Wraps the vanilla [`HnswGraph`] with temporal awareness:
4//! - **Roaring Bitmaps** for O(1) temporal pre-filtering
5//! - **Composite distance**: $d_{ST} = \alpha \cdot d_{sem} + (1 - \alpha) \cdot d_{time}$
6//! - **TemporalFilter** integration: snapshot kNN, range kNN
7//! - **Trajectory retrieval**: all points for an entity ordered by time
8//!
9//! # Example
10//!
11//! ```
12//! use cvx_index::hnsw::{HnswConfig, TemporalHnsw};
13//! use cvx_index::metrics::L2Distance;
14//! use cvx_core::TemporalFilter;
15//!
16//! let config = HnswConfig::default();
17//! let mut index = TemporalHnsw::new(config, L2Distance);
18//!
19//! // Insert vectors with entity_id and timestamp
20//! index.insert(1, 1000, &[1.0, 0.0, 0.0]);
21//! index.insert(1, 2000, &[0.9, 0.1, 0.0]);
22//! index.insert(2, 1500, &[0.0, 1.0, 0.0]);
23//!
24//! // Temporal range search with alpha=0.5
25//! let results = index.search(
26//!     &[1.0, 0.0, 0.0],
27//!     2,
28//!     TemporalFilter::Range(900, 1600),
29//!     0.5,
30//!     1000, // query timestamp for temporal distance
31//! );
32//! assert_eq!(results.len(), 2); // only 2 points in [900, 1600]
33//! ```
34
35use std::collections::BTreeMap;
36use std::io::{Read, Write};
37use std::path::Path;
38
39use cvx_core::{DistanceMetric, TemporalFilter};
40use roaring::RoaringBitmap;
41use serde::{Deserialize, Serialize};
42
43use super::{HnswConfig, HnswGraph, HnswSnapshot};
44
/// Spatiotemporal HNSW index.
///
/// Each inserted point has an `entity_id`, a `timestamp`, and a vector.
/// Internal node IDs are assigned sequentially (the new node's id is the
/// graph length at insertion time), so every per-node `Vec` below is indexed
/// directly by `node_id`.
pub struct TemporalHnsw<D: DistanceMetric> {
    graph: HnswGraph<D>,
    /// node_id → timestamp
    timestamps: Vec<i64>,
    /// node_id → entity_id
    entity_ids: Vec<u64>,
    /// entity_id → vec of (timestamp, node_id), in insertion order.
    ///
    /// Inserts append without re-sorting, so a list is only ordered by
    /// timestamp if that entity's points were inserted chronologically;
    /// readers that need time order (e.g. `trajectory`) sort explicitly.
    entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
    /// Global temporal range for normalization.
    ///
    /// Starts as the empty range (`i64::MAX`, `i64::MIN`) and is widened by
    /// every insert.
    min_timestamp: i64,
    max_timestamp: i64,
    /// Optional per-node metadata store.
    ///
    /// Created lazily by `insert_with_metadata`, which backfills one empty
    /// entry per previously inserted node.
    metadata_store: Option<super::metadata_store::MetadataStore>,
    /// Optional centroid for anisotropy correction (RFC-012 Part B).
    ///
    /// Intended to amplify the discriminative signal that is otherwise
    /// compressed by the dominant "average text" direction. It is subtracted
    /// by `centered_vector()`. NOTE(review): the search methods in this impl
    /// hand the query to the graph uncentered — confirm whether callers are
    /// expected to center queries/vectors themselves.
    centroid: Option<Vec<f32>>,
    /// Optional per-node reward for outcome-aware search (RFC-012 P4).
    ///
    /// NaN means "no reward assigned". Stored parallel to timestamps/entity_ids.
    rewards: Vec<f32>,
}
73
74impl<D: DistanceMetric> TemporalHnsw<D> {
75    /// Create a new empty spatiotemporal index.
76    pub fn new(config: HnswConfig, metric: D) -> Self {
77        Self {
78            graph: HnswGraph::new(config, metric),
79            timestamps: Vec::new(),
80            entity_ids: Vec::new(),
81            entity_index: BTreeMap::new(),
82            min_timestamp: i64::MAX,
83            max_timestamp: i64::MIN,
84            metadata_store: None,
85            centroid: None,
86            rewards: Vec::new(),
87        }
88    }
89
90    /// Number of points in the index.
91    pub fn len(&self) -> usize {
92        self.graph.len()
93    }
94
95    /// Whether the index is empty.
96    pub fn is_empty(&self) -> bool {
97        self.graph.is_empty()
98    }
99
100    /// Get the last (most recent) node_id for an entity, or None if not found.
101    pub fn entity_last_node(&self, entity_id: u64) -> Option<u32> {
102        self.entity_index
103            .get(&entity_id)
104            .and_then(|pts| pts.last().map(|&(_, nid)| nid))
105    }
106
107    /// Set ef_construction at runtime (lower for bulk load, higher for quality).
108    pub fn set_ef_construction(&mut self, ef: usize) {
109        self.graph.set_ef_construction(ef);
110    }
111
112    /// Set ef_search at runtime.
113    pub fn set_ef_search(&mut self, ef: usize) {
114        self.graph.set_ef_search(ef);
115    }
116
117    /// Get the current configuration.
118    pub fn config(&self) -> &HnswConfig {
119        self.graph.config()
120    }
121
122    /// Enable scalar quantization for faster distance computation.
123    pub fn enable_scalar_quantization(&mut self, min_val: f32, max_val: f32) {
124        self.graph.enable_scalar_quantization(min_val, max_val);
125    }
126
127    /// Disable scalar quantization.
128    pub fn disable_scalar_quantization(&mut self) {
129        self.graph.disable_scalar_quantization();
130    }
131
132    /// Whether scalar quantization is active.
133    pub fn is_quantized(&self) -> bool {
134        self.graph.is_quantized()
135    }
136
137    /// Insert a temporal point into the index.
138    ///
139    /// Returns the internal node_id assigned to this point.
140    pub fn insert(&mut self, entity_id: u64, timestamp: i64, vector: &[f32]) -> u32 {
141        self.insert_with_reward(entity_id, timestamp, vector, f32::NAN)
142    }
143
144    /// Bulk insert multiple points with parallel distance computation (RFC-012 P9).
145    ///
146    /// Faster than sequential `insert()` calls for large batches. Uses rayon
147    /// to parallelize neighbor search across chunks while keeping graph
148    /// modifications sequential.
149    ///
150    /// Returns the number of points inserted.
151    pub fn bulk_insert_parallel(
152        &mut self,
153        entity_ids: &[u64],
154        timestamps: &[i64],
155        vectors: &[&[f32]],
156    ) -> usize {
157        use rayon::prelude::*;
158
159        let n = entity_ids.len();
160        if n == 0 {
161            return 0;
162        }
163
164        // Phase 1: Insert first 100 points sequentially to build initial graph
165        let seed_count = n.min(100);
166        for i in 0..seed_count {
167            self.insert(entity_ids[i], timestamps[i], vectors[i]);
168        }
169
170        if seed_count >= n {
171            return n;
172        }
173
174        // Phase 2: For remaining points, compute neighbors in parallel batches
175        let batch_size = 256;
176        let remaining = &vectors[seed_count..];
177        let remaining_eids = &entity_ids[seed_count..];
178        let remaining_ts = &timestamps[seed_count..];
179
180        for batch_start in (0..remaining.len()).step_by(batch_size) {
181            let batch_end = (batch_start + batch_size).min(remaining.len());
182            let batch_vecs = &remaining[batch_start..batch_end];
183
184            // Parallel: compute nearest neighbors for each vector in the batch
185            let neighbor_lists: Vec<Vec<(u32, f32)>> = batch_vecs
186                .par_iter()
187                .map(|vec| self.graph.search(vec, self.graph.config().ef_construction))
188                .collect();
189
190            // Sequential: insert nodes and connect using pre-computed neighbors
191            for (i, neighbors) in neighbor_lists.into_iter().enumerate() {
192                let idx = batch_start + i;
193                let eid = remaining_eids[idx];
194                let ts = remaining_ts[idx];
195                let vec = remaining[idx];
196
197                let node_id = self.graph.len() as u32;
198                // Allocate node
199                let level = self.graph.random_level();
200                self.graph.push_node(vec, level);
201                self.timestamps.push(ts);
202                self.entity_ids.push(eid);
203                self.rewards.push(f32::NAN);
204
205                // Connect using pre-computed neighbors
206                self.graph.connect_node(node_id, &neighbors, level);
207
208                // Update entity index
209                self.entity_index
210                    .entry(eid)
211                    .or_default()
212                    .push((ts, node_id));
213                self.min_timestamp = self.min_timestamp.min(ts);
214                self.max_timestamp = self.max_timestamp.max(ts);
215
216                if let Some(ref mut store) = self.metadata_store {
217                    store.push_empty();
218                }
219            }
220        }
221
222        n
223    }
224
225    /// Insert a temporal point with an outcome reward.
226    ///
227    /// `reward` annotates this point with an outcome signal (e.g., 0.0-1.0).
228    /// Use `f32::NAN` for "no reward assigned".
229    pub fn insert_with_reward(
230        &mut self,
231        entity_id: u64,
232        timestamp: i64,
233        vector: &[f32],
234        reward: f32,
235    ) -> u32 {
236        let node_id = self.graph.len() as u32;
237        self.graph.insert(node_id, vector);
238        self.timestamps.push(timestamp);
239        self.entity_ids.push(entity_id);
240        self.rewards.push(reward);
241
242        // Update entity index
243        self.entity_index
244            .entry(entity_id)
245            .or_default()
246            .push((timestamp, node_id));
247
248        // Update temporal range
249        self.min_timestamp = self.min_timestamp.min(timestamp);
250        self.max_timestamp = self.max_timestamp.max(timestamp);
251
252        // Store metadata (empty if store not enabled)
253        if let Some(ref mut store) = self.metadata_store {
254            store.push_empty();
255        }
256
257        node_id
258    }
259
260    /// Insert a temporal point with metadata.
261    pub fn insert_with_metadata(
262        &mut self,
263        entity_id: u64,
264        timestamp: i64,
265        vector: &[f32],
266        metadata: std::collections::HashMap<String, String>,
267    ) -> u32 {
268        // Enable metadata store on first metadata insert
269        if self.metadata_store.is_none() {
270            let mut store = super::metadata_store::MetadataStore::new();
271            // Backfill empty entries for existing nodes
272            for _ in 0..self.graph.len() {
273                store.push_empty();
274            }
275            self.metadata_store = Some(store);
276        }
277
278        let node_id = self.graph.len() as u32;
279        self.graph.insert(node_id, vector);
280        self.timestamps.push(timestamp);
281        self.entity_ids.push(entity_id);
282        self.rewards.push(f32::NAN);
283
284        self.entity_index
285            .entry(entity_id)
286            .or_default()
287            .push((timestamp, node_id));
288
289        self.min_timestamp = self.min_timestamp.min(timestamp);
290        self.max_timestamp = self.max_timestamp.max(timestamp);
291
292        if let Some(ref mut store) = self.metadata_store {
293            store.push(metadata);
294        }
295
296        node_id
297    }
298
299    /// Get metadata for a node. Returns empty map if metadata store not enabled.
300    pub fn node_metadata(&self, node_id: u32) -> std::collections::HashMap<String, String> {
301        self.metadata_store
302            .as_ref()
303            .map(|s| s.get(node_id).clone())
304            .unwrap_or_default()
305    }
306
307    /// Build a Roaring Bitmap of node IDs matching the temporal filter.
308    pub fn build_filter_bitmap(&self, filter: &TemporalFilter) -> RoaringBitmap {
309        let mut bitmap = RoaringBitmap::new();
310        for (i, &ts) in self.timestamps.iter().enumerate() {
311            if filter.matches(ts) {
312                bitmap.insert(i as u32);
313            }
314        }
315        bitmap
316    }
317
318    /// Compute normalized temporal distance between two timestamps.
319    ///
320    /// Returns a value in `[0.0, 1.0]` where 0 = same timestamp, 1 = max range.
321    pub fn temporal_distance_normalized(&self, t1: i64, t2: i64) -> f32 {
322        let range = (self.max_timestamp - self.min_timestamp).max(1) as f64;
323        let diff = (t1 as f64 - t2 as f64).abs();
324        (diff / range) as f32
325    }
326
327    /// Normalize semantic distance to [0, 1] range (RFC-012 P8).
328    ///
329    /// Cosine distance ∈ [0, 2], L2 distance ∈ [0, ∞). This clamps and
330    /// scales to [0, 1] so it's comparable with temporal distance [0, 1].
331    pub(crate) fn normalize_semantic_distance(&self, d: f32) -> f32 {
332        // Cosine: [0, 2] → [0, 1] by halving. L2: clamp to [0, 4] then /4.
333        // Both produce [0, 1]. For most embeddings, distances rarely exceed 2.
334        (d / 2.0).min(1.0)
335    }
336
337    /// Compute recency penalty for a node (RFC-012 P7).
338    ///
339    /// Returns a value in `[0.0, 1.0]` where 0 = most recent, 1 = oldest.
340    /// Uses exponential decay: `1 - exp(-λ · age)` where age is normalized.
341    ///
342    /// `recency_lambda` controls decay speed:
343    /// - λ = 0: no recency effect
344    /// - λ = 1: moderate decay
345    /// - λ = 3: strong decay (old nodes heavily penalized)
346    pub(crate) fn recency_penalty(&self, node_timestamp: i64, recency_lambda: f32) -> f32 {
347        if recency_lambda <= 0.0 {
348            return 0.0;
349        }
350        let age = self.temporal_distance_normalized(node_timestamp, self.max_timestamp);
351        1.0 - (-recency_lambda * age).exp()
352    }
353
354    /// Search with full composite scoring (RFC-012 P7 + P8).
355    ///
356    /// Enhanced distance: `d = α·d_sem_norm + (1-α)·d_temporal + γ·recency`
357    ///
358    /// - `alpha`: semantic vs temporal weight (1.0 = pure semantic)
359    /// - `recency_lambda`: recency decay strength (0.0 = off, 1.0 = moderate, 3.0 = strong)
360    /// - `recency_weight`: weight of recency term in composite score (0.0-1.0)
361    #[allow(clippy::too_many_arguments)]
362    pub fn search_with_recency(
363        &self,
364        query: &[f32],
365        k: usize,
366        filter: TemporalFilter,
367        alpha: f32,
368        query_timestamp: i64,
369        recency_lambda: f32,
370        recency_weight: f32,
371    ) -> Vec<(u32, f32)> {
372        if self.is_empty() {
373            return Vec::new();
374        }
375
376        let bitmap = self.build_filter_bitmap(&filter);
377        if bitmap.is_empty() {
378            return Vec::new();
379        }
380
381        let over_fetch = k * 4;
382        let candidates = self
383            .graph
384            .search_filtered(query, over_fetch, |id| bitmap.contains(id));
385
386        let mut scored: Vec<(u32, f32)> = candidates
387            .into_iter()
388            .map(|(id, sem_dist)| {
389                let sem_norm = self.normalize_semantic_distance(sem_dist);
390                let t_dist = self
391                    .temporal_distance_normalized(self.timestamps[id as usize], query_timestamp);
392                let recency = self.recency_penalty(self.timestamps[id as usize], recency_lambda);
393
394                let combined = alpha * sem_norm + (1.0 - alpha) * t_dist + recency_weight * recency;
395                (id, combined)
396            })
397            .collect();
398
399        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
400        scored.truncate(k);
401        scored
402    }
403
404    // ─── Outcome / Reward (RFC-012 P4) ──────────────────────────────
405
406    /// Get the reward for a node. Returns NaN if no reward was assigned.
407    pub fn reward(&self, node_id: u32) -> f32 {
408        self.rewards
409            .get(node_id as usize)
410            .copied()
411            .unwrap_or(f32::NAN)
412    }
413
414    /// Set the reward for a node retroactively.
415    ///
416    /// Useful for annotating outcomes after an episode completes.
417    pub fn set_reward(&mut self, node_id: u32, reward: f32) {
418        if let Some(r) = self.rewards.get_mut(node_id as usize) {
419            *r = reward;
420        }
421    }
422
423    /// Build a bitmap of node_ids with reward >= min_reward.
424    pub fn build_reward_bitmap(&self, min_reward: f32) -> RoaringBitmap {
425        let mut bitmap = RoaringBitmap::new();
426        for (i, &r) in self.rewards.iter().enumerate() {
427            if !r.is_nan() && r >= min_reward {
428                bitmap.insert(i as u32);
429            }
430        }
431        bitmap
432    }
433
434    /// Search with reward filtering: only return nodes with reward >= min_reward.
435    ///
436    /// Combines temporal filter + reward filter as bitmap pre-filter.
437    pub fn search_with_reward(
438        &self,
439        query: &[f32],
440        k: usize,
441        filter: TemporalFilter,
442        alpha: f32,
443        query_timestamp: i64,
444        min_reward: f32,
445    ) -> Vec<(u32, f32)> {
446        if self.is_empty() {
447            return Vec::new();
448        }
449
450        let temporal_bitmap = self.build_filter_bitmap(&filter);
451        let reward_bitmap = self.build_reward_bitmap(min_reward);
452        let combined = temporal_bitmap & reward_bitmap;
453
454        if combined.is_empty() {
455            return Vec::new();
456        }
457
458        let candidates = self
459            .graph
460            .search_filtered(query, k, |id| combined.contains(id));
461
462        if alpha >= 1.0 {
463            return candidates;
464        }
465
466        let mut scored: Vec<(u32, f32)> = candidates
467            .into_iter()
468            .map(|(id, sem_dist)| {
469                let t_dist = self
470                    .temporal_distance_normalized(self.timestamps[id as usize], query_timestamp);
471                (id, alpha * sem_dist + (1.0 - alpha) * t_dist)
472            })
473            .collect();
474
475        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
476        scored.truncate(k);
477        scored
478    }
479
480    // ─── Centering (RFC-012 Part B) ──────────────────────────────────
481
482    /// Compute the centroid (mean vector) of all indexed vectors.
483    ///
484    /// Single O(N×D) pass over stored vectors. Returns `None` if the index
485    /// is empty.
486    pub fn compute_centroid(&self) -> Option<Vec<f32>> {
487        let n = self.graph.len();
488        if n == 0 {
489            return None;
490        }
491
492        let dim = self.graph.vector(0).len();
493        let mut sum = vec![0.0f64; dim];
494
495        for i in 0..n {
496            let v = self.graph.vector(i as u32);
497            for (s, &val) in sum.iter_mut().zip(v.iter()) {
498                *s += val as f64;
499            }
500        }
501
502        let inv_n = 1.0 / n as f64;
503        Some(sum.into_iter().map(|s| (s * inv_n) as f32).collect())
504    }
505
506    /// Set the centroid for anisotropy correction.
507    ///
508    /// Once set, `centered_vector()` subtracts this from any vector,
509    /// and search operations use centered distances. The centroid is
510    /// serialized with the index snapshot.
511    ///
512    /// You can provide an externally computed centroid (e.g., from a
513    /// larger corpus) or use `compute_centroid()` for the index contents.
514    pub fn set_centroid(&mut self, centroid: Vec<f32>) {
515        self.centroid = Some(centroid);
516    }
517
518    /// Clear the centroid, reverting to raw (uncentered) distances.
519    pub fn clear_centroid(&mut self) {
520        self.centroid = None;
521    }
522
523    /// Get the current centroid, if set.
524    pub fn centroid(&self) -> Option<&[f32]> {
525        self.centroid.as_deref()
526    }
527
528    /// Return a centered copy of the given vector (vec - centroid).
529    ///
530    /// If no centroid is set, returns the vector unchanged (cloned).
531    pub fn centered_vector(&self, vec: &[f32]) -> Vec<f32> {
532        match &self.centroid {
533            Some(c) => vec.iter().zip(c.iter()).map(|(v, m)| v - m).collect(),
534            None => vec.to_vec(),
535        }
536    }
537
538    /// Search for the k nearest neighbors with temporal filtering and composite scoring.
539    ///
540    /// - `query`: the query vector
541    /// - `k`: number of results
542    /// - `filter`: temporal constraint (Snapshot, Range, Before, After, All)
543    /// - `alpha`: weight for semantic distance (1.0 = pure semantic, 0.0 = pure temporal)
544    /// - `query_timestamp`: reference timestamp for temporal distance computation
545    ///
546    /// Returns `(node_id, combined_score)` sorted by combined score ascending.
547    pub fn search(
548        &self,
549        query: &[f32],
550        k: usize,
551        filter: TemporalFilter,
552        alpha: f32,
553        query_timestamp: i64,
554    ) -> Vec<(u32, f32)> {
555        if self.is_empty() {
556            return Vec::new();
557        }
558
559        // Build bitmap of temporally valid nodes
560        let bitmap = self.build_filter_bitmap(&filter);
561        if bitmap.is_empty() {
562            return Vec::new();
563        }
564
565        if alpha >= 1.0 {
566            // Pure semantic: just filter, no re-ranking needed
567            return self
568                .graph
569                .search_filtered(query, k, |id| bitmap.contains(id));
570        }
571
572        // Get more candidates than needed for re-ranking
573        let over_fetch = k * 4;
574        let candidates = self
575            .graph
576            .search_filtered(query, over_fetch, |id| bitmap.contains(id));
577
578        // Re-rank with composite distance (P8: normalized scales)
579        let mut scored: Vec<(u32, f32)> = candidates
580            .into_iter()
581            .map(|(id, sem_dist)| {
582                let sem_norm = self.normalize_semantic_distance(sem_dist);
583                let t_dist = self
584                    .temporal_distance_normalized(self.timestamps[id as usize], query_timestamp);
585                let combined = alpha * sem_norm + (1.0 - alpha) * t_dist;
586                (id, combined)
587            })
588            .collect();
589
590        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
591        scored.truncate(k);
592        scored
593    }
594
595    /// Retrieve the full trajectory for an entity within a time range.
596    ///
597    /// Returns `(timestamp, node_id)` pairs sorted by timestamp ascending.
598    pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
599        let Some(points) = self.entity_index.get(&entity_id) else {
600            return Vec::new();
601        };
602
603        let mut result: Vec<(i64, u32)> = points
604            .iter()
605            .filter(|&&(ts, _)| filter.matches(ts))
606            .copied()
607            .collect();
608
609        result.sort_by_key(|&(ts, _)| ts);
610        result
611    }
612
613    /// Get the timestamp for a node.
614    pub fn timestamp(&self, node_id: u32) -> i64 {
615        self.timestamps[node_id as usize]
616    }
617
618    /// Get the entity_id for a node.
619    pub fn entity_id(&self, node_id: u32) -> u64 {
620        self.entity_ids[node_id as usize]
621    }
622
623    /// Get the vector for a node.
624    pub fn vector(&self, node_id: u32) -> &[f32] {
625        self.graph.vector(node_id)
626    }
627
628    /// Approximate memory usage of the Roaring Bitmaps for a full-index filter.
629    ///
630    /// Useful for verifying the < 1 byte/vector target.
631    pub fn bitmap_memory_bytes(&self) -> usize {
632        let bitmap = self.build_filter_bitmap(&TemporalFilter::All);
633        bitmap.serialized_size()
634    }
635
636    /// Access the underlying HNSW graph (for recall comparisons, etc.).
637    pub fn graph(&self) -> &HnswGraph<D> {
638        &self.graph
639    }
640
641    // ─── Semantic Regions (RFC-004) ────────────────────────────────
642
643    /// Get semantic regions at a given HNSW level.
644    ///
645    /// Returns `(hub_node_id, hub_vector, n_assigned_nodes)` for each region.
646    /// Use level 2-3 for interpretable granularity (~N/M^L regions).
647    pub fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
648        let hubs = self.graph.nodes_at_level(level);
649        let n = self.graph.len();
650
651        // Count assignments: for each node, find nearest hub
652        let mut counts = vec![0usize; hubs.len()];
653        let hub_set: std::collections::HashMap<u32, usize> =
654            hubs.iter().enumerate().map(|(i, &h)| (h, i)).collect();
655
656        for node_id in 0..n as u32 {
657            if let Some(hub) = self.graph.assign_region(self.graph.vector(node_id), level) {
658                if let Some(&idx) = hub_set.get(&hub) {
659                    counts[idx] += 1;
660                }
661            }
662        }
663
664        hubs.iter()
665            .enumerate()
666            .map(|(i, &hub_id)| (hub_id, self.graph.vector(hub_id).to_vec(), counts[i]))
667            .collect()
668    }
669
670    /// Compute smoothed region-distribution trajectory for an entity (RFC-004).
671    ///
672    /// - `level`: HNSW level for region granularity
673    /// - `window_days`: sliding window in timestamp units (same scale as ingested timestamps)
674    /// - `alpha`: EMA smoothing factor (0.3 typical, higher = more reactive)
675    ///
676    /// Returns `(timestamp, distribution)` where distribution is a Vec<f32> of length K
677    /// (number of regions) that sums to ~1.0.
678    pub fn region_trajectory(
679        &self,
680        entity_id: u64,
681        level: usize,
682        window_days: i64,
683        alpha: f32,
684    ) -> Vec<(i64, Vec<f32>)> {
685        let hubs = self.graph.nodes_at_level(level);
686        let k = hubs.len();
687        if k == 0 {
688            return Vec::new();
689        }
690
691        // Map hub_node_id → region index
692        let hub_index: std::collections::HashMap<u32, usize> =
693            hubs.iter().enumerate().map(|(i, &h)| (h, i)).collect();
694
695        // Get entity's posts sorted by time
696        let posts = self.trajectory(entity_id, TemporalFilter::All);
697        if posts.is_empty() {
698            return Vec::new();
699        }
700
701        // Assign each post to a region
702        let assignments: Vec<(i64, usize)> = posts
703            .iter()
704            .filter_map(|&(ts, node_id)| {
705                let vec = self.graph.vector(node_id);
706                self.graph
707                    .assign_region(vec, level)
708                    .and_then(|hub| hub_index.get(&hub).map(|&idx| (ts, idx)))
709            })
710            .collect();
711
712        if assignments.is_empty() {
713            return Vec::new();
714        }
715
716        // Group by time windows
717        let t_start = assignments[0].0;
718        let t_end = assignments.last().unwrap().0;
719        let mut result = Vec::new();
720        let mut ema_state: Vec<f32> = vec![0.0; k];
721        let mut first = true;
722
723        let mut window_start = t_start;
724        while window_start <= t_end {
725            let window_end = window_start + window_days;
726
727            // Count posts per region in this window
728            let mut counts = vec![0.0f32; k];
729            let mut n_in_window = 0.0f32;
730            for &(ts, region_idx) in &assignments {
731                if ts >= window_start && ts < window_end {
732                    counts[region_idx] += 1.0;
733                    n_in_window += 1.0;
734                }
735            }
736
737            if n_in_window > 0.0 {
738                // Normalize to distribution
739                for c in &mut counts {
740                    *c /= n_in_window;
741                }
742
743                // EMA smoothing
744                if first {
745                    ema_state = counts;
746                    first = false;
747                } else {
748                    for i in 0..k {
749                        ema_state[i] = alpha * counts[i] + (1.0 - alpha) * ema_state[i];
750                    }
751                }
752
753                let mid_ts = window_start + window_days / 2;
754                result.push((mid_ts, ema_state.clone()));
755            }
756
757            window_start = window_end;
758        }
759
760        result
761    }
762
763    /// Get points assigned to a specific region, optionally time-filtered (RFC-004, RFC-005).
764    ///
765    /// Returns `(node_id, entity_id, timestamp)` for all points in the region.
766    /// This is the "SELECT * FROM points WHERE region = R" equivalent.
767    ///
768    /// **Performance warning**: This does a full scan of all nodes. For multiple regions,
769    /// use `region_assignments()` instead (single scan for all regions).
770    pub fn region_members(
771        &self,
772        region_hub: u32,
773        level: usize,
774        filter: TemporalFilter,
775    ) -> Vec<(u32, u64, i64)> {
776        let mut members = Vec::new();
777        for node_id in 0..self.graph.len() as u32 {
778            let ts = self.timestamps[node_id as usize];
779            if !filter.matches(ts) {
780                continue;
781            }
782            let vec = self.graph.vector(node_id);
783            if let Some(assigned_hub) = self.graph.assign_region(vec, level) {
784                if assigned_hub == region_hub {
785                    let eid = self.entity_ids[node_id as usize];
786                    members.push((node_id, eid, ts));
787                }
788            }
789        }
790        members
791    }
792
793    /// Assign ALL nodes to their regions in a single pass, optionally time-filtered.
794    ///
795    /// Returns a HashMap from hub_id → Vec<(entity_id, timestamp)>.
796    /// This is O(N) — one full scan instead of O(N × K) for K `region_members` calls.
797    pub fn region_assignments(
798        &self,
799        level: usize,
800        filter: TemporalFilter,
801    ) -> std::collections::HashMap<u32, Vec<(u64, i64)>> {
802        let mut assignments: std::collections::HashMap<u32, Vec<(u64, i64)>> =
803            std::collections::HashMap::new();
804
805        for node_id in 0..self.graph.len() as u32 {
806            let ts = self.timestamps[node_id as usize];
807            if !filter.matches(ts) {
808                continue;
809            }
810            let vec = self.graph.vector(node_id);
811            if let Some(hub) = self.graph.assign_region(vec, level) {
812                let eid = self.entity_ids[node_id as usize];
813                assignments.entry(hub).or_default().push((eid, ts));
814            }
815        }
816
817        assignments
818    }
819}
820
/// Current snapshot format version, written by `save`. Increment when adding fields.
///
/// `load` rejects any snapshot whose version is greater than this value.
const SNAPSHOT_VERSION: u32 = 2;
823
/// Serializable snapshot of a TemporalHnsw index.
///
/// Encoded with postcard by `save` and decoded by `load`. Postcard writes
/// struct fields in declaration order without field names, so the field
/// order below is part of the on-disk format. Only ever append fields;
/// never reorder or remove them.
///
/// NOTE(review): postcard is not self-describing. The `#[serde(default)]`
/// attributes may therefore not let older snapshots that lack the trailing
/// fields decode: postcard may report an unexpected end of input instead of
/// applying the default. Likewise, a v1 payload with no leading `version`
/// field would have its first bytes read positionally as `version`.
/// Confirm backward compatibility with a real v1 fixture.
#[derive(Serialize, Deserialize)]
struct TemporalSnapshot {
    /// Format version for forward compatibility (RFC-012 P5).
    /// v1: original (graph, timestamps, entity_ids, entity_index, min/max_timestamp)
    /// v2: + metadata_store, centroid, rewards
    #[serde(default = "default_snapshot_version")]
    version: u32,
    graph: HnswSnapshot,
    timestamps: Vec<i64>,
    entity_ids: Vec<u64>,
    entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
    min_timestamp: i64,
    max_timestamp: i64,
    /// Per-node metadata; `None` when no point was inserted with metadata.
    #[serde(default)]
    metadata_store: Option<super::metadata_store::MetadataStore>,
    /// Centroid for anisotropy correction (RFC-012 Part B).
    #[serde(default)]
    centroid: Option<Vec<f32>>,
    /// Per-node reward for outcome-aware search (RFC-012 P4).
    /// An empty vec is treated by `load` as "no rewards" and filled with NaN.
    #[serde(default)]
    rewards: Vec<f32>,
}
847
/// Serde default for `TemporalSnapshot::version`.
///
/// Payloads that carry no version field are treated as format v1.
fn default_snapshot_version() -> u32 {
    const LEGACY_VERSION: u32 = 1;
    LEGACY_VERSION
}
851
852impl<D: DistanceMetric> TemporalHnsw<D> {
853    /// Save the index to a file using postcard binary encoding.
854    ///
855    /// The distance metric is NOT serialized (it's stateless).
856    /// On load, you must provide the same metric type.
857    pub fn save(&self, path: &Path) -> std::io::Result<()> {
858        let snapshot = TemporalSnapshot {
859            version: SNAPSHOT_VERSION,
860            graph: self.graph.to_snapshot(),
861            timestamps: self.timestamps.clone(),
862            entity_ids: self.entity_ids.clone(),
863            entity_index: self.entity_index.clone(),
864            min_timestamp: self.min_timestamp,
865            max_timestamp: self.max_timestamp,
866            metadata_store: self.metadata_store.clone(),
867            centroid: self.centroid.clone(),
868            rewards: self.rewards.clone(),
869        };
870
871        let bytes = postcard::to_allocvec(&snapshot).map_err(std::io::Error::other)?;
872
873        let mut file = std::fs::File::create(path)?;
874        file.write_all(&bytes)?;
875        Ok(())
876    }
877
878    /// Load an index from a file, providing the distance metric.
879    ///
880    /// Supports all snapshot versions. Unknown future versions produce
881    /// a clear error instead of silent corruption.
882    pub fn load(path: &Path, metric: D) -> std::io::Result<Self> {
883        let mut file = std::fs::File::open(path)?;
884        let mut bytes = Vec::new();
885        file.read_to_end(&mut bytes)?;
886
887        let snapshot: TemporalSnapshot = postcard::from_bytes(&bytes)
888            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
889
890        if snapshot.version > SNAPSHOT_VERSION {
891            return Err(std::io::Error::new(
892                std::io::ErrorKind::InvalidData,
893                format!(
894                    "Snapshot version {} is newer than supported version {}. \
895                     Please upgrade chronos-vector.",
896                    snapshot.version, SNAPSHOT_VERSION
897                ),
898            ));
899        }
900
901        let n_points = snapshot.timestamps.len();
902        let rewards = if snapshot.rewards.is_empty() {
903            // Backward compat: old snapshots have no rewards → fill with NaN
904            vec![f32::NAN; n_points]
905        } else {
906            snapshot.rewards
907        };
908
909        Ok(Self {
910            graph: HnswGraph::from_snapshot(snapshot.graph, metric),
911            timestamps: snapshot.timestamps,
912            entity_ids: snapshot.entity_ids,
913            entity_index: snapshot.entity_index,
914            min_timestamp: snapshot.min_timestamp,
915            max_timestamp: snapshot.max_timestamp,
916            metadata_store: snapshot.metadata_store,
917            centroid: snapshot.centroid,
918            rewards,
919        })
920    }
921}
922
923impl<D: DistanceMetric> cvx_core::TemporalIndexAccess for TemporalHnsw<D> {
924    fn search_raw(
925        &self,
926        query: &[f32],
927        k: usize,
928        filter: TemporalFilter,
929        alpha: f32,
930        query_timestamp: i64,
931    ) -> Vec<(u32, f32)> {
932        self.search(query, k, filter, alpha, query_timestamp)
933    }
934
935    fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, u32)> {
936        self.trajectory(entity_id, filter)
937    }
938
939    fn vector(&self, node_id: u32) -> Vec<f32> {
940        self.graph.vector(node_id).to_vec()
941    }
942
943    fn entity_id(&self, node_id: u32) -> u64 {
944        self.entity_ids[node_id as usize]
945    }
946
947    fn timestamp(&self, node_id: u32) -> i64 {
948        self.timestamps[node_id as usize]
949    }
950
951    fn len(&self) -> usize {
952        self.graph.len()
953    }
954
955    fn regions(&self, level: usize) -> Vec<(u32, Vec<f32>, usize)> {
956        self.regions(level)
957    }
958
959    fn region_members(
960        &self,
961        region_hub: u32,
962        level: usize,
963        filter: TemporalFilter,
964    ) -> Vec<(u32, u64, i64)> {
965        self.region_members(region_hub, level, filter)
966    }
967
968    fn region_trajectory(
969        &self,
970        entity_id: u64,
971        level: usize,
972        window_days: i64,
973        alpha: f32,
974    ) -> Vec<(i64, Vec<f32>)> {
975        self.region_trajectory(entity_id, level, window_days, alpha)
976    }
977
978    fn metadata(&self, node_id: u32) -> std::collections::HashMap<String, String> {
979        self.node_metadata(node_id)
980    }
981
982    fn search_with_metadata(
983        &self,
984        query: &[f32],
985        k: usize,
986        filter: TemporalFilter,
987        alpha: f32,
988        query_timestamp: i64,
989        metadata_filter: &cvx_core::types::MetadataFilter,
990    ) -> Vec<(u32, f32)> {
991        if metadata_filter.is_empty() {
992            return self.search(query, k, filter, alpha, query_timestamp);
993        }
994
995        match &self.metadata_store {
996            Some(store) => {
997                // Pre-filter: build combined temporal + metadata bitmap
998                let temporal_bitmap = self.build_filter_bitmap(&filter);
999                let metadata_bitmap = store.build_filter_bitmap(metadata_filter);
1000                let combined = temporal_bitmap & metadata_bitmap;
1001
1002                if combined.is_empty() {
1003                    return Vec::new();
1004                }
1005
1006                // Search with combined bitmap
1007                let candidates = self
1008                    .graph
1009                    .search_filtered(query, k, |id| combined.contains(id));
1010
1011                if alpha >= 1.0 {
1012                    return candidates;
1013                }
1014
1015                // Re-rank with composite distance
1016                let mut scored: Vec<(u32, f32)> = candidates
1017                    .into_iter()
1018                    .map(|(id, sem_dist)| {
1019                        let t_dist = self.temporal_distance_normalized(
1020                            self.timestamps[id as usize],
1021                            query_timestamp,
1022                        );
1023                        let combined_score = alpha * sem_dist + (1.0 - alpha) * t_dist;
1024                        (id, combined_score)
1025                    })
1026                    .collect();
1027
1028                scored.sort_by(|a, b| a.1.total_cmp(&b.1));
1029                scored.truncate(k);
1030                scored
1031            }
1032            None => {
1033                // No metadata store: fall back to search without metadata
1034                self.search(query, k, filter, alpha, query_timestamp)
1035            }
1036        }
1037    }
1038}
1039
1040#[cfg(test)]
1041mod tests {
1042    use super::*;
1043    use crate::metrics::{CosineDistance, L2Distance};
1044
1045    fn make_temporal_index() -> TemporalHnsw<L2Distance> {
1046        let config = HnswConfig {
1047            m: 16,
1048            ef_construction: 200,
1049            ef_search: 100,
1050            ..Default::default()
1051        };
1052        TemporalHnsw::new(config, L2Distance)
1053    }
1054
1055    // ─── Basic functionality ────────────────────────────────────────────
1056
1057    #[test]
1058    fn empty_index() {
1059        let index = make_temporal_index();
1060        assert!(index.is_empty());
1061        assert_eq!(index.len(), 0);
1062        let results = index.search(&[1.0, 0.0], 5, TemporalFilter::All, 1.0, 0);
1063        assert!(results.is_empty());
1064    }
1065
1066    #[test]
1067    fn insert_and_metadata() {
1068        let mut index = make_temporal_index();
1069        let id = index.insert(42, 1000, &[1.0, 0.0, 0.0]);
1070        assert_eq!(id, 0);
1071        assert_eq!(index.len(), 1);
1072        assert_eq!(index.timestamp(0), 1000);
1073        assert_eq!(index.entity_id(0), 42);
1074        assert_eq!(index.vector(0), &[1.0, 0.0, 0.0]);
1075    }
1076
1077    // ─── Snapshot kNN ───────────────────────────────────────────────────
1078
1079    #[test]
1080    fn snapshot_knn_returns_only_matching_timestamp() {
1081        let mut index = make_temporal_index();
1082        // Entity 1 at t=1000
1083        index.insert(1, 1000, &[1.0, 0.0]);
1084        // Entity 2 at t=2000
1085        index.insert(2, 2000, &[0.9, 0.1]);
1086        // Entity 3 at t=1000
1087        index.insert(3, 1000, &[0.8, 0.2]);
1088
1089        let results = index.search(&[1.0, 0.0], 10, TemporalFilter::Snapshot(1000), 1.0, 1000);
1090        assert_eq!(results.len(), 2);
1091        // Both results should have timestamp 1000
1092        for &(id, _) in &results {
1093            assert_eq!(index.timestamp(id), 1000);
1094        }
1095    }
1096
1097    #[test]
1098    fn snapshot_knn_no_match_returns_empty() {
1099        let mut index = make_temporal_index();
1100        index.insert(1, 1000, &[1.0, 0.0]);
1101        index.insert(2, 2000, &[0.9, 0.1]);
1102
1103        let results = index.search(&[1.0, 0.0], 10, TemporalFilter::Snapshot(5000), 1.0, 5000);
1104        assert!(results.is_empty());
1105    }
1106
1107    // ─── Range kNN ──────────────────────────────────────────────────────
1108
1109    #[test]
1110    fn range_knn_returns_only_in_range() {
1111        let mut index = make_temporal_index();
1112        index.insert(1, 1000, &[1.0, 0.0]);
1113        index.insert(2, 2000, &[0.9, 0.1]);
1114        index.insert(3, 3000, &[0.8, 0.2]);
1115        index.insert(4, 4000, &[0.7, 0.3]);
1116
1117        let results = index.search(
1118            &[1.0, 0.0],
1119            10,
1120            TemporalFilter::Range(1500, 3500),
1121            1.0,
1122            2000,
1123        );
1124
1125        // Only t=2000 and t=3000 should match
1126        assert_eq!(results.len(), 2);
1127        for &(id, _) in &results {
1128            let ts = index.timestamp(id);
1129            assert!((1500..=3500).contains(&ts), "timestamp {ts} out of range");
1130        }
1131    }
1132
1133    // ─── Composite distance ─────────────────────────────────────────────
1134
1135    #[test]
1136    fn alpha_1_is_pure_semantic() {
1137        let mut index = make_temporal_index();
1138        // Same vector, different times
1139        index.insert(1, 1000, &[1.0, 0.0]);
1140        index.insert(2, 5000, &[0.99, 0.01]);
1141        index.insert(3, 100, &[0.0, 1.0]);
1142
1143        let results = index.search(&[1.0, 0.0], 3, TemporalFilter::All, 1.0, 1000);
1144        // Pure semantic: [1.0, 0.0] is closest to itself, then [0.99, 0.01], then [0.0, 1.0]
1145        assert_eq!(results[0].0, 0); // entity 1
1146        assert_eq!(results[1].0, 1); // entity 2
1147        assert_eq!(results[2].0, 2); // entity 3
1148    }
1149
1150    #[test]
1151    fn alpha_0_5_prefers_temporally_closer() {
1152        let mut index = make_temporal_index();
1153        // Two vectors equidistant semantically but at different timestamps
1154        index.insert(1, 1000, &[1.0, 0.0, 0.0]); // far in time
1155        index.insert(2, 5000, &[1.0, 0.0, 0.0]); // close in time
1156
1157        let query_ts = 4900;
1158        let results = index.search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 0.5, query_ts);
1159
1160        // With alpha=0.5 and equal semantic distance, the temporally closer one wins
1161        assert_eq!(results[0].0, 1); // entity 2 at t=5000 is closer to query_ts=4900
1162        assert_eq!(results[1].0, 0); // entity 1 at t=1000 is farther
1163    }
1164
1165    #[test]
1166    fn alpha_0_5_returns_temporally_closer_than_alpha_1() {
1167        let mut index = make_temporal_index();
1168        let dim = 8;
1169        let mut rng = rand::rng();
1170
1171        // Insert 100 points at various timestamps
1172        for i in 0..100u64 {
1173            let ts = (i as i64) * 1000;
1174            let v: Vec<f32> = (0..dim)
1175                .map(|_| rand::Rng::random::<f32>(&mut rng))
1176                .collect();
1177            index.insert(i, ts, &v);
1178        }
1179
1180        let query: Vec<f32> = (0..dim)
1181            .map(|_| rand::Rng::random::<f32>(&mut rng))
1182            .collect();
1183        let query_ts = 50_000; // middle of the range
1184        let k = 10;
1185
1186        let results_pure = index.search(&query, k, TemporalFilter::All, 1.0, query_ts);
1187        let results_mixed = index.search(&query, k, TemporalFilter::All, 0.5, query_ts);
1188
1189        // Average temporal distance of results
1190        let avg_tdist_pure: f64 = results_pure
1191            .iter()
1192            .map(|&(id, _)| (index.timestamp(id) - query_ts).unsigned_abs() as f64)
1193            .sum::<f64>()
1194            / k as f64;
1195        let avg_tdist_mixed: f64 = results_mixed
1196            .iter()
1197            .map(|&(id, _)| (index.timestamp(id) - query_ts).unsigned_abs() as f64)
1198            .sum::<f64>()
1199            / k as f64;
1200
1201        assert!(
1202            avg_tdist_mixed <= avg_tdist_pure,
1203            "alpha=0.5 avg temporal dist ({avg_tdist_mixed:.0}) should be <= alpha=1.0 ({avg_tdist_pure:.0})"
1204        );
1205    }
1206
1207    // ─── Alpha=1.0 parity with vanilla HNSW ────────────────────────────
1208
1209    #[test]
1210    fn alpha_1_matches_vanilla_recall() {
1211        let dim = 32;
1212        let n = 1000u32;
1213        let k = 10;
1214        let config = HnswConfig {
1215            m: 16,
1216            ef_construction: 200,
1217            ef_search: 100,
1218            ..Default::default()
1219        };
1220
1221        let mut temporal = TemporalHnsw::new(config, L2Distance);
1222        let mut rng = rand::rng();
1223        let vectors: Vec<Vec<f32>> = (0..n)
1224            .map(|_| {
1225                (0..dim)
1226                    .map(|_| rand::Rng::random::<f32>(&mut rng))
1227                    .collect()
1228            })
1229            .collect();
1230
1231        for (i, v) in vectors.iter().enumerate() {
1232            temporal.insert(i as u64, (i as i64) * 100, v);
1233        }
1234
1235        // Compare temporal search (alpha=1.0, All) with brute force
1236        let n_queries = 50;
1237        let mut total_recall = 0.0;
1238
1239        for _ in 0..n_queries {
1240            let query: Vec<f32> = (0..dim)
1241                .map(|_| rand::Rng::random::<f32>(&mut rng))
1242                .collect();
1243            let temporal_results = temporal.search(&query, k, TemporalFilter::All, 1.0, 0);
1244            let truth = temporal.graph().brute_force_knn(&query, k);
1245            let recall = super::super::recall_at_k(&temporal_results, &truth);
1246            total_recall += recall;
1247        }
1248
1249        let avg_recall = total_recall / n_queries as f64;
1250        assert!(
1251            avg_recall >= 0.90,
1252            "alpha=1.0 recall = {avg_recall:.3}, expected >= 0.90 (vanilla parity)"
1253        );
1254    }
1255
1256    // ─── Temporal filtering recall ──────────────────────────────────────
1257
1258    #[test]
1259    fn range_knn_recall() {
1260        let dim = 32;
1261        let n = 1000u32;
1262        let k = 10;
1263        let config = HnswConfig {
1264            m: 16,
1265            ef_construction: 200,
1266            ef_search: 200,
1267            ..Default::default()
1268        };
1269
1270        let mut index = TemporalHnsw::new(config, L2Distance);
1271        let mut rng = rand::rng();
1272
1273        for i in 0..n {
1274            let ts = (i as i64) * 100;
1275            let v: Vec<f32> = (0..dim)
1276                .map(|_| rand::Rng::random::<f32>(&mut rng))
1277                .collect();
1278            index.insert(i as u64, ts, &v);
1279        }
1280
1281        // Filter to middle 50% of timestamps
1282        let filter = TemporalFilter::Range(25_000, 75_000);
1283        let bitmap = index.build_filter_bitmap(&filter);
1284
1285        let n_queries = 50;
1286        let mut total_recall = 0.0;
1287
1288        for _ in 0..n_queries {
1289            let query: Vec<f32> = (0..dim)
1290                .map(|_| rand::Rng::random::<f32>(&mut rng))
1291                .collect();
1292            let results = index.search(&query, k, filter, 1.0, 50_000);
1293
1294            // Brute-force ground truth within filter
1295            let mut truth: Vec<(u32, f32)> = (0..n)
1296                .filter(|&i| bitmap.contains(i))
1297                .map(|i| {
1298                    (
1299                        i,
1300                        index
1301                            .graph()
1302                            .brute_force_knn(&query, n as usize)
1303                            .iter()
1304                            .find(|&&(id, _)| id == i)
1305                            .map(|&(_, d)| d)
1306                            .unwrap_or(f32::INFINITY),
1307                    )
1308                })
1309                .collect();
1310            truth.sort_by(|a, b| a.1.total_cmp(&b.1));
1311            truth.truncate(k);
1312
1313            total_recall += super::super::recall_at_k(&results, &truth);
1314        }
1315
1316        let avg_recall = total_recall / n_queries as f64;
1317        assert!(
1318            avg_recall >= 0.90,
1319            "range kNN recall = {avg_recall:.3}, expected >= 0.90"
1320        );
1321    }
1322
1323    // ─── Trajectory retrieval ───────────────────────────────────────────
1324
1325    #[test]
1326    fn trajectory_returns_all_entity_points_ordered() {
1327        let mut index = make_temporal_index();
1328
1329        // Insert 100 points for entity 1, interleaved with other entities
1330        for i in 0..100u32 {
1331            index.insert(1, (i as i64) * 1000, &[i as f32, 0.0]);
1332            index.insert(2, (i as i64) * 1000 + 500, &[0.0, i as f32]);
1333        }
1334
1335        let traj = index.trajectory(1, TemporalFilter::All);
1336        assert_eq!(traj.len(), 100);
1337
1338        // Verify ordering
1339        for window in traj.windows(2) {
1340            assert!(
1341                window[0].0 <= window[1].0,
1342                "trajectory not ordered: {} > {}",
1343                window[0].0,
1344                window[1].0
1345            );
1346        }
1347
1348        // Verify all belong to entity 1
1349        for &(_, node_id) in &traj {
1350            assert_eq!(index.entity_id(node_id), 1);
1351        }
1352    }
1353
1354    #[test]
1355    fn trajectory_with_range_filter() {
1356        let mut index = make_temporal_index();
1357
1358        for i in 0..50u32 {
1359            index.insert(1, (i as i64) * 100, &[i as f32]);
1360        }
1361
1362        let traj = index.trajectory(1, TemporalFilter::Range(1000, 3000));
1363
1364        // timestamps 1000, 1100, ..., 3000 → 21 points
1365        assert_eq!(traj.len(), 21);
1366        for &(ts, _) in &traj {
1367            assert!((1000..=3000).contains(&ts));
1368        }
1369    }
1370
1371    #[test]
1372    fn trajectory_unknown_entity_returns_empty() {
1373        let mut index = make_temporal_index();
1374        index.insert(1, 1000, &[1.0]);
1375        assert!(index.trajectory(999, TemporalFilter::All).is_empty());
1376    }
1377
1378    // ─── Roaring Bitmap memory ──────────────────────────────────────────
1379
1380    #[test]
1381    fn bitmap_memory_under_1_byte_per_vector() {
1382        let mut index = make_temporal_index();
1383
1384        // Insert 10K points
1385        for i in 0..10_000u32 {
1386            index.insert(i as u64, i as i64, &[i as f32]);
1387        }
1388
1389        let mem = index.bitmap_memory_bytes();
1390        let bytes_per_vector = mem as f64 / 10_000.0;
1391        assert!(
1392            bytes_per_vector < 1.0,
1393            "bitmap uses {bytes_per_vector:.2} bytes/vector, expected < 1.0"
1394        );
1395    }
1396
1397    // ─── Before/After filters ───────────────────────────────────────────
1398
1399    #[test]
1400    fn before_filter() {
1401        let mut index = make_temporal_index();
1402        index.insert(1, 1000, &[1.0, 0.0]);
1403        index.insert(2, 2000, &[0.9, 0.1]);
1404        index.insert(3, 3000, &[0.8, 0.2]);
1405
1406        let results = index.search(&[1.0, 0.0], 10, TemporalFilter::Before(2000), 1.0, 1000);
1407        assert_eq!(results.len(), 2);
1408        for &(id, _) in &results {
1409            assert!(index.timestamp(id) <= 2000);
1410        }
1411    }
1412
1413    #[test]
1414    fn after_filter() {
1415        let mut index = make_temporal_index();
1416        index.insert(1, 1000, &[1.0, 0.0]);
1417        index.insert(2, 2000, &[0.9, 0.1]);
1418        index.insert(3, 3000, &[0.8, 0.2]);
1419
1420        let results = index.search(&[1.0, 0.0], 10, TemporalFilter::After(2000), 1.0, 3000);
1421        assert_eq!(results.len(), 2);
1422        for &(id, _) in &results {
1423            assert!(index.timestamp(id) >= 2000);
1424        }
1425    }
1426
1427    // ─── Cosine metric works ────────────────────────────────────────────
1428
1429    #[test]
1430    fn works_with_cosine_metric() {
1431        let config = HnswConfig {
1432            m: 16,
1433            ef_construction: 100,
1434            ef_search: 50,
1435            ..Default::default()
1436        };
1437        let mut index = TemporalHnsw::new(config, CosineDistance);
1438
1439        index.insert(1, 1000, &[1.0, 0.0, 0.0]);
1440        index.insert(2, 2000, &[0.99, 0.01, 0.0]);
1441        index.insert(3, 3000, &[0.0, 0.0, 1.0]);
1442
1443        let results = index.search(&[1.0, 0.0, 0.0], 2, TemporalFilter::All, 1.0, 0);
1444        assert_eq!(results[0].0, 0);
1445        assert_eq!(results[1].0, 1);
1446    }
1447
1448    // ─── Metadata integration ───────────────────────────────────────
1449
1450    #[test]
1451    fn insert_with_metadata_stores_and_retrieves() {
1452        let config = HnswConfig::default();
1453        let mut index = TemporalHnsw::new(config, L2Distance);
1454
1455        let mut meta = std::collections::HashMap::new();
1456        meta.insert("reward".to_string(), "0.8".to_string());
1457        meta.insert("step_index".to_string(), "0".to_string());
1458
1459        let id = index.insert_with_metadata(1, 1000, &[1.0, 0.0, 0.0], meta);
1460
1461        let retrieved = index.node_metadata(id);
1462        assert_eq!(retrieved.get("reward").unwrap(), "0.8");
1463        assert_eq!(retrieved.get("step_index").unwrap(), "0");
1464    }
1465
1466    #[test]
1467    fn insert_with_metadata_enables_store_lazily() {
1468        let config = HnswConfig::default();
1469        let mut index = TemporalHnsw::new(config, L2Distance);
1470
1471        // First insert without metadata
1472        index.insert(1, 1000, &[1.0, 0.0, 0.0]);
1473
1474        // Second insert with metadata — should enable store and backfill
1475        let mut meta = std::collections::HashMap::new();
1476        meta.insert("reward".to_string(), "0.9".to_string());
1477        let id = index.insert_with_metadata(2, 2000, &[0.0, 1.0, 0.0], meta);
1478
1479        // First node should have empty metadata
1480        assert!(index.node_metadata(0).is_empty());
1481        // Second node should have metadata
1482        assert_eq!(index.node_metadata(id).get("reward").unwrap(), "0.9");
1483    }
1484
1485    #[test]
1486    fn search_with_metadata_filter() {
1487        use cvx_core::TemporalIndexAccess;
1488        use cvx_core::types::MetadataFilter;
1489
1490        let config = HnswConfig {
1491            m: 16,
1492            ef_construction: 100,
1493            ef_search: 50,
1494            ..Default::default()
1495        };
1496        let mut index = TemporalHnsw::new(config, L2Distance);
1497
1498        // Insert 10 points: 5 with reward >= 0.5, 5 with reward < 0.5
1499        for i in 0..10u64 {
1500            let mut meta = std::collections::HashMap::new();
1501            meta.insert("reward".to_string(), format!("{}", i as f64 * 0.1));
1502            meta.insert("step_index".to_string(), "0".to_string());
1503            index.insert_with_metadata(i, i as i64 * 1000, &[i as f32, 0.0, 0.0], meta);
1504        }
1505
1506        // Search with metadata filter: reward >= 0.5
1507        let filter = MetadataFilter::new().gte("reward", 0.5);
1508        let results =
1509            index.search_with_metadata(&[7.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0, &filter);
1510
1511        // All results should have reward >= 0.5
1512        for &(nid, _) in &results {
1513            let meta = index.node_metadata(nid);
1514            let reward: f64 = meta.get("reward").unwrap().parse().unwrap();
1515            assert!(reward >= 0.5, "node {nid} has reward {reward} < 0.5");
1516        }
1517        assert!(!results.is_empty(), "should find some results");
1518    }
1519
1520    // ─── Region assignments ──────────────────────────────────────────
1521
1522    /// Build an index with enough points so that level-1 hubs exist.
1523    fn make_region_index() -> TemporalHnsw<L2Distance> {
1524        let config = HnswConfig {
1525            m: 4,
1526            ef_construction: 50,
1527            ef_search: 50,
1528            ..Default::default()
1529        };
1530        let mut index = TemporalHnsw::new(config, L2Distance);
1531        let mut rng = rand::rng();
1532        // 200 points across 4 entities, timestamps 0..199_000
1533        for i in 0..200u64 {
1534            let v: Vec<f32> = (0..8).map(|_| rand::Rng::random::<f32>(&mut rng)).collect();
1535            let entity = i % 4;
1536            index.insert(entity, i as i64 * 1000, &v);
1537        }
1538        index
1539    }
1540
1541    #[test]
1542    fn region_assignments_covers_all_nodes() {
1543        let index = make_region_index();
1544        let level = 1;
1545        let assignments = index.region_assignments(level, TemporalFilter::All);
1546
1547        let total: usize = assignments.values().map(|v| v.len()).sum();
1548        assert_eq!(
1549            total,
1550            index.len(),
1551            "sum of all region member counts ({total}) must equal index size ({})",
1552            index.len()
1553        );
1554    }
1555
1556    #[test]
1557    fn region_assignments_consistent_with_regions_counts() {
1558        let index = make_region_index();
1559        let level = 1;
1560        let regions = index.regions(level);
1561        let assignments = index.region_assignments(level, TemporalFilter::All);
1562
1563        for &(hub_id, _, count) in &regions {
1564            let assigned_count = assignments.get(&hub_id).map_or(0, |v| v.len());
1565            assert_eq!(
1566                assigned_count, count,
1567                "region hub {hub_id}: region_assignments has {assigned_count} members but regions() reports {count}"
1568            );
1569        }
1570    }
1571
1572    #[test]
1573    fn region_assignments_temporal_filter_reduces_count() {
1574        let index = make_region_index();
1575        let level = 1;
1576
1577        let all = index.region_assignments(level, TemporalFilter::All);
1578        let total_all: usize = all.values().map(|v| v.len()).sum();
1579
1580        // Filter to middle 50% of timestamps (50_000..150_000)
1581        let filtered = index.region_assignments(level, TemporalFilter::Range(50_000, 150_000));
1582        let total_filtered: usize = filtered.values().map(|v| v.len()).sum();
1583
1584        assert!(
1585            total_filtered < total_all,
1586            "Range filter should reduce total members: filtered={total_filtered}, all={total_all}"
1587        );
1588
1589        // Verify every member in filtered results has a timestamp within the range
1590        for members in filtered.values() {
1591            for &(_eid, ts) in members {
1592                assert!(
1593                    (50_000..=150_000).contains(&ts),
1594                    "filtered result has timestamp {ts} outside [50000, 150000]"
1595                );
1596            }
1597        }
1598    }
1599
1600    #[test]
1601    fn region_assignments_each_member_in_exactly_one_region() {
1602        let index = make_region_index();
1603        let level = 1;
1604        let assignments = index.region_assignments(level, TemporalFilter::All);
1605
1606        // Collect (entity_id, timestamp) across all regions and check for duplicates
1607        let mut seen = std::collections::HashSet::new();
1608        let mut total = 0usize;
1609        for members in assignments.values() {
1610            for &(eid, ts) in members {
1611                total += 1;
1612                let _inserted = seen.insert((eid, ts));
1613                // Note: same (eid, ts) can appear if entity has multiple nodes at same ts,
1614                // so we count total instead and verify it matches index.len()
1615            }
1616        }
1617
1618        // Each node_id maps to exactly one region, so total must equal index size
1619        assert_eq!(
1620            total,
1621            index.len(),
1622            "total assigned members ({total}) != index size ({}); a node appeared in multiple or no regions",
1623            index.len()
1624        );
1625
1626        // Additionally, no hub should appear in two different regions' keys that don't exist at level
1627        let hubs: std::collections::HashSet<u32> = assignments.keys().copied().collect();
1628        let level_hubs: std::collections::HashSet<u32> =
1629            index.graph().nodes_at_level(level).into_iter().collect();
1630        for hub in &hubs {
1631            assert!(
1632                level_hubs.contains(hub),
1633                "assignment hub {hub} is not a level-{level} node"
1634            );
1635        }
1636    }
1637
1638    // ─── Centering (RFC-012 Part B) ─────────────────────────────────
1639
1640    #[test]
1641    fn compute_centroid_empty_index() {
1642        let index = make_temporal_index();
1643        assert!(index.compute_centroid().is_none());
1644    }
1645
1646    #[test]
1647    fn compute_centroid_single_vector() {
1648        let mut index = make_temporal_index();
1649        index.insert(1, 1000, &[3.0, 4.0, 5.0]);
1650        let centroid = index.compute_centroid().unwrap();
1651        assert_eq!(centroid, vec![3.0, 4.0, 5.0]);
1652    }
1653
1654    #[test]
1655    fn compute_centroid_mean_of_vectors() {
1656        let mut index = make_temporal_index();
1657        index.insert(1, 1000, &[2.0, 0.0]);
1658        index.insert(2, 2000, &[4.0, 6.0]);
1659        let centroid = index.compute_centroid().unwrap();
1660        assert!((centroid[0] - 3.0).abs() < 1e-6);
1661        assert!((centroid[1] - 3.0).abs() < 1e-6);
1662    }
1663
1664    #[test]
1665    fn set_and_clear_centroid() {
1666        let mut index = make_temporal_index();
1667        index.insert(1, 1000, &[1.0, 2.0]);
1668
1669        assert!(index.centroid().is_none());
1670
1671        index.set_centroid(vec![0.5, 1.0]);
1672        assert!(index.centroid().is_some());
1673        assert_eq!(index.centroid().unwrap(), &[0.5, 1.0]);
1674
1675        index.clear_centroid();
1676        assert!(index.centroid().is_none());
1677    }
1678
1679    #[test]
1680    fn centered_vector_subtracts_centroid() {
1681        let mut index = make_temporal_index();
1682        index.insert(1, 1000, &[1.0, 2.0]);
1683        index.set_centroid(vec![0.5, 1.0]);
1684
1685        let centered = index.centered_vector(&[3.0, 5.0]);
1686        assert!((centered[0] - 2.5).abs() < 1e-6);
1687        assert!((centered[1] - 4.0).abs() < 1e-6);
1688    }
1689
1690    #[test]
1691    fn centered_vector_without_centroid_is_identity() {
1692        let mut index = make_temporal_index();
1693        index.insert(1, 1000, &[1.0, 2.0]);
1694        // No centroid set
1695        let centered = index.centered_vector(&[3.0, 5.0]);
1696        assert_eq!(centered, vec![3.0, 5.0]);
1697    }
1698
1699    #[test]
1700    fn centroid_survives_save_load() {
1701        let dir = std::env::temp_dir();
1702        let path = dir.join("test_centroid_snapshot.cvx");
1703
1704        let mut index = make_temporal_index();
1705        index.insert(1, 1000, &[1.0, 2.0, 3.0]);
1706        index.insert(2, 2000, &[4.0, 5.0, 6.0]);
1707        index.set_centroid(vec![2.5, 3.5, 4.5]);
1708
1709        index.save(&path).unwrap();
1710
1711        let loaded = TemporalHnsw::load(&path, L2Distance).unwrap();
1712        assert_eq!(loaded.centroid().unwrap(), &[2.5, 3.5, 4.5]);
1713
1714        std::fs::remove_file(&path).ok();
1715    }
1716
1717    #[test]
1718    fn load_without_centroid_is_none() {
1719        // Verifies backward compatibility: old snapshots without centroid
1720        // field deserialize with centroid = None (via #[serde(default)])
1721        let dir = std::env::temp_dir();
1722        let path = dir.join("test_no_centroid_snapshot.cvx");
1723
1724        let mut index = make_temporal_index();
1725        index.insert(1, 1000, &[1.0, 0.0]);
1726        // No centroid set
1727        index.save(&path).unwrap();
1728
1729        let loaded = TemporalHnsw::load(&path, L2Distance).unwrap();
1730        assert!(loaded.centroid().is_none());
1731
1732        std::fs::remove_file(&path).ok();
1733    }
1734
1735    #[test]
1736    fn compute_centroid_precision_with_many_vectors() {
1737        let config = HnswConfig {
1738            m: 4,
1739            ef_construction: 20,
1740            ef_search: 10,
1741            ..Default::default()
1742        };
1743        let mut index = TemporalHnsw::new(config, L2Distance);
1744
1745        // Insert 1000 vectors with known mean
1746        for i in 0..1000u64 {
1747            // Vectors centered around [10.0, 20.0] with small perturbation
1748            let v = vec![10.0 + (i as f32 * 0.001), 20.0 - (i as f32 * 0.001)];
1749            index.insert(i, i as i64, &v);
1750        }
1751
1752        let centroid = index.compute_centroid().unwrap();
1753        // Expected mean: [10.0 + 0.4995, 20.0 - 0.4995] = [10.4995, 19.5005]
1754        assert!(
1755            (centroid[0] - 10.4995).abs() < 0.01,
1756            "centroid[0] = {}, expected ~10.4995",
1757            centroid[0]
1758        );
1759        assert!(
1760            (centroid[1] - 19.5005).abs() < 0.01,
1761            "centroid[1] = {}, expected ~19.5005",
1762            centroid[1]
1763        );
1764    }
1765
1766    #[test]
1767    fn search_with_empty_metadata_filter_matches_all() {
1768        use cvx_core::TemporalIndexAccess;
1769        use cvx_core::types::MetadataFilter;
1770
1771        let config = HnswConfig::default();
1772        let mut index = TemporalHnsw::new(config, L2Distance);
1773
1774        for i in 0..5u64 {
1775            index.insert(i, i as i64 * 1000, &[i as f32, 0.0]);
1776        }
1777
1778        let filter = MetadataFilter::new();
1779        let results =
1780            index.search_with_metadata(&[2.0, 0.0], 3, TemporalFilter::All, 1.0, 0, &filter);
1781        assert_eq!(results.len(), 3);
1782    }
1783
1784    // ─── Reward / outcome-aware search (RFC-012 P4) ──────────────
1785
1786    #[test]
1787    fn insert_with_reward_stores_reward() {
1788        let mut index = make_temporal_index();
1789        let n0 = index.insert(1, 1000, &[1.0, 0.0]);
1790        let n1 = index.insert_with_reward(2, 2000, &[0.0, 1.0], 0.8);
1791
1792        assert!(index.reward(n0).is_nan()); // no reward
1793        assert!((index.reward(n1) - 0.8).abs() < 1e-6);
1794    }
1795
1796    #[test]
1797    fn set_reward_updates_retroactively() {
1798        let mut index = make_temporal_index();
1799        let n0 = index.insert(1, 1000, &[1.0, 0.0]);
1800        assert!(index.reward(n0).is_nan());
1801
1802        index.set_reward(n0, 0.95);
1803        assert!((index.reward(n0) - 0.95).abs() < 1e-6);
1804    }
1805
1806    #[test]
1807    fn search_with_reward_filters() {
1808        let mut index = make_temporal_index();
1809        // Insert 10 points with varying rewards
1810        for i in 0..10u64 {
1811            index.insert_with_reward(i, i as i64 * 1000, &[i as f32, 0.0, 0.0], i as f32 * 0.1);
1812        }
1813
1814        // min_reward=0.5 → only nodes 5..9
1815        let results =
1816            index.search_with_reward(&[7.0, 0.0, 0.0], 5, TemporalFilter::All, 1.0, 0, 0.5);
1817        assert!(!results.is_empty());
1818        for &(node_id, _) in &results {
1819            let r = index.reward(node_id);
1820            assert!(r >= 0.5, "node {node_id} has reward {r} < 0.5");
1821        }
1822    }
1823
1824    #[test]
1825    fn search_with_reward_no_matches() {
1826        let mut index = make_temporal_index();
1827        for i in 0..5u64 {
1828            index.insert_with_reward(i, i as i64 * 1000, &[i as f32, 0.0], 0.1);
1829        }
1830
1831        let results = index.search_with_reward(&[2.0, 0.0], 5, TemporalFilter::All, 1.0, 0, 0.9);
1832        assert!(results.is_empty());
1833    }
1834
1835    #[test]
1836    fn reward_survives_save_load() {
1837        let dir = std::env::temp_dir();
1838        let path = dir.join("test_reward_snapshot.cvx");
1839
1840        let mut index = make_temporal_index();
1841        index.insert_with_reward(1, 1000, &[1.0, 0.0], 0.75);
1842        index.insert(2, 2000, &[0.0, 1.0]); // no reward
1843        index.save(&path).unwrap();
1844
1845        let loaded = TemporalHnsw::load(&path, L2Distance).unwrap();
1846        assert!((loaded.reward(0) - 0.75).abs() < 1e-6);
1847        assert!(loaded.reward(1).is_nan());
1848
1849        std::fs::remove_file(&path).ok();
1850    }
1851
1852    // ─── P7: Recency + P8: Normalization ─────────────────────────
1853
1854    #[test]
1855    fn normalize_semantic_distance_clamps() {
1856        let mut index = make_temporal_index();
1857        index.insert(1, 1000, &[1.0, 0.0]);
1858
1859        // Cosine distance [0, 2] → [0, 1]
1860        assert!((index.normalize_semantic_distance(0.0) - 0.0).abs() < 1e-6);
1861        assert!((index.normalize_semantic_distance(1.0) - 0.5).abs() < 1e-6);
1862        assert!((index.normalize_semantic_distance(2.0) - 1.0).abs() < 1e-6);
1863        // Clamp: values > 2 stay at 1.0
1864        assert!((index.normalize_semantic_distance(4.0) - 1.0).abs() < 1e-6);
1865    }
1866
1867    #[test]
1868    fn recency_penalty_zero_lambda() {
1869        let mut index = make_temporal_index();
1870        index.insert(1, 1000, &[1.0, 0.0]);
1871        index.insert(2, 2000, &[0.0, 1.0]);
1872        // lambda=0 → no recency effect
1873        assert!((index.recency_penalty(1000, 0.0) - 0.0).abs() < 1e-6);
1874    }
1875
1876    #[test]
1877    fn recency_penalty_recent_is_lower() {
1878        let mut index = make_temporal_index();
1879        for i in 0..10u64 {
1880            index.insert(i, (i * 1000) as i64, &[i as f32, 0.0]);
1881        }
1882        let recent = index.recency_penalty(9000, 1.0); // most recent
1883        let old = index.recency_penalty(0, 1.0); // oldest
1884        assert!(
1885            recent < old,
1886            "recent penalty ({recent}) should be < old penalty ({old})"
1887        );
1888    }
1889
1890    #[test]
1891    fn search_with_recency_prefers_recent() {
1892        let mut index = make_temporal_index();
1893        // Two identical vectors at different times
1894        index.insert(1, 1000, &[1.0, 0.0, 0.0]);
1895        index.insert(2, 9000, &[1.0, 0.0, 0.0]); // more recent
1896
1897        let results = index.search_with_recency(
1898            &[1.0, 0.0, 0.0],
1899            2,
1900            TemporalFilter::All,
1901            1.0, // pure semantic
1902            0,
1903            2.0, // strong recency
1904            0.5, // high recency weight
1905        );
1906
1907        assert_eq!(results.len(), 2);
1908        // More recent node should score lower (better)
1909        assert_eq!(
1910            results[0].0,
1911            1, // node 1 = entity 2 at t=9000 (more recent)
1912            "recent node should rank first"
1913        );
1914    }
1915
1916    #[test]
1917    fn search_normalized_distances_balanced() {
1918        let mut index = make_temporal_index();
1919        // Semantically close but temporally far
1920        index.insert(1, 1000, &[1.0, 0.0, 0.0]);
1921        // Semantically far but temporally close
1922        index.insert(2, 5000, &[0.0, 1.0, 0.0]);
1923        // Query at t=4900 with alpha=0.5
1924        let results = index.search(
1925            &[0.9, 0.1, 0.0],
1926            2,
1927            TemporalFilter::All,
1928            0.5, // balanced
1929            4900,
1930        );
1931
1932        // With normalized scales, temporal distance is comparable to semantic
1933        // Entity 2 at t=5000 is temporally close to query_ts=4900
1934        assert_eq!(results.len(), 2);
1935    }
1936
1937    // ─── P9: Parallel bulk insert ────────────────────────────────
1938
1939    #[test]
1940    fn bulk_insert_parallel_basic() {
1941        let config = HnswConfig {
1942            m: 8,
1943            ef_construction: 50,
1944            ef_search: 50,
1945            ..Default::default()
1946        };
1947        let mut index = TemporalHnsw::new(config, L2Distance);
1948
1949        let n = 500;
1950        let dim = 16;
1951        let mut rng = rand::rng();
1952        let eids: Vec<u64> = (0..n).map(|i| i as u64 % 10).collect();
1953        let tss: Vec<i64> = (0..n).map(|i| i as i64 * 100).collect();
1954        let vecs: Vec<Vec<f32>> = (0..n)
1955            .map(|_| {
1956                (0..dim)
1957                    .map(|_| rand::Rng::random::<f32>(&mut rng))
1958                    .collect()
1959            })
1960            .collect();
1961        let vec_refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
1962
1963        let count = index.bulk_insert_parallel(&eids, &tss, &vec_refs);
1964        assert_eq!(count, n);
1965        assert_eq!(index.len(), n);
1966
1967        // Should be searchable
1968        let results = index.search(&vecs[0], 5, TemporalFilter::All, 1.0, 0);
1969        assert_eq!(results.len(), 5);
1970    }
1971
1972    #[test]
1973    fn bulk_insert_parallel_recall() {
1974        let config = HnswConfig {
1975            m: 16,
1976            ef_construction: 100,
1977            ef_search: 100,
1978            ..Default::default()
1979        };
1980        let mut index = TemporalHnsw::new(config, L2Distance);
1981
1982        let n = 1000;
1983        let dim = 32;
1984        let mut rng = rand::rng();
1985        let eids: Vec<u64> = (0..n).map(|i| i as u64).collect();
1986        let tss: Vec<i64> = (0..n).map(|i| i as i64 * 100).collect();
1987        let vecs: Vec<Vec<f32>> = (0..n)
1988            .map(|_| {
1989                (0..dim)
1990                    .map(|_| rand::Rng::random::<f32>(&mut rng))
1991                    .collect()
1992            })
1993            .collect();
1994        let vec_refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
1995
1996        index.bulk_insert_parallel(&eids, &tss, &vec_refs);
1997
1998        // Check recall
1999        let k = 10;
2000        let mut total_recall = 0.0;
2001        let n_queries = 20;
2002        for _ in 0..n_queries {
2003            let query: Vec<f32> = (0..dim)
2004                .map(|_| rand::Rng::random::<f32>(&mut rng))
2005                .collect();
2006            let results = index.search(&query, k, TemporalFilter::All, 1.0, 0);
2007            let truth = index.graph().brute_force_knn(&query, k);
2008            total_recall += super::super::recall_at_k(&results, &truth);
2009        }
2010        let avg_recall = total_recall / n_queries as f64;
2011        assert!(
2012            avg_recall >= 0.80,
2013            "parallel bulk_insert recall = {avg_recall:.3}, expected >= 0.80"
2014        );
2015    }
2016}