cvx_index/hnsw/
mod.rs

1//! Hierarchical Navigable Small World (HNSW) graph index.
2//!
3//! A multi-layer graph structure for approximate nearest neighbor search.
4//! Based on Malkov & Yashunin (2018) with single-threaded insert/search.
5//!
6//! This is a **vanilla HNSW** — no temporal filtering, decay, or timestamp graph.
7//! Those are added in Layer 4+.
8//!
9//! # Algorithm Overview
10//!
11//! - **Insert**: assign random level, greedily descend from top to target level,
12//!   then do beam search at each level to find neighbors
13//! - **Search**: greedy descend from top to level 0, then beam search on level 0
14//! - **Levels**: higher levels are sparser (fewer nodes), lower levels are denser
15//!
16//! # Example
17//!
18//! ```
19//! use cvx_index::hnsw::{HnswGraph, HnswConfig};
20//! use cvx_index::metrics::CosineDistance;
21//!
22//! let config = HnswConfig { m: 16, ef_construction: 200, ef_search: 50, ..Default::default() };
23//! let mut graph = HnswGraph::new(config, CosineDistance);
24//!
25//! // Insert vectors
26//! graph.insert(0, &[1.0, 0.0, 0.0]);
27//! graph.insert(1, &[0.9, 0.1, 0.0]);
28//! graph.insert(2, &[0.0, 1.0, 0.0]);
29//!
30//! // Search
31//! let results = graph.search(&[1.0, 0.0, 0.0], 2);
32//! assert_eq!(results[0].0, 0); // closest is itself
33//! assert_eq!(results[1].0, 1); // second closest
34//! ```
35
36pub mod bayesian_scorer;
37pub mod concurrent;
38pub mod metadata_store;
39pub mod optimized;
40pub mod partitioned;
41pub mod region_mdp;
42pub mod streaming;
43pub mod temporal;
44pub mod temporal_edges;
45pub mod temporal_graph;
46pub mod temporal_lsh;
47pub mod typed_edges;
48
49pub use bayesian_scorer::{CandidateFeatures, ScoringWeights, WeightLearner};
50pub use concurrent::ConcurrentTemporalHnsw;
51pub use region_mdp::RegionMdp;
52pub use temporal::TemporalHnsw;
53pub use temporal_edges::TemporalEdgeLayer;
54pub use temporal_graph::{CausalSearchResult, TemporalGraphIndex};
55pub use typed_edges::{EdgeType, TypedEdgeStore};
56
57use std::cmp::Reverse;
58use std::collections::BinaryHeap;
59
60use cvx_core::DistanceMetric;
61use rand::rngs::SmallRng;
62use rand::{Rng, SeedableRng};
63use serde::{Deserialize, Serialize};
64use smallvec::SmallVec;
65
66/// HNSW index configuration.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct HnswConfig {
69    /// Maximum connections per node per layer (except layer 0 which gets 2*M).
70    pub m: usize,
71    /// Search width during construction.
72    pub ef_construction: usize,
73    /// Default search width during queries.
74    pub ef_search: usize,
75    /// Maximum level (auto-calculated if 0).
76    pub max_level: usize,
77    /// Level generation multiplier: 1 / ln(M).
78    pub level_mult: f64,
79    /// Use heuristic neighbor selection (Malkov §4.2) for better connectivity.
80    /// When false, uses simple closest-M selection.
81    pub use_heuristic: bool,
82}
83
84impl Default for HnswConfig {
85    fn default() -> Self {
86        let m = 16;
87        Self {
88            m,
89            ef_construction: 200,
90            ef_search: 50,
91            max_level: 0,
92            level_mult: 1.0 / (m as f64).ln(),
93            use_heuristic: true,
94        }
95    }
96}
97
98/// Neighbor list: inline up to M entries (no heap allocation).
99type NeighborList = SmallVec<[u32; 16]>;
100
101/// A node in the HNSW graph.
102#[derive(Serialize, Deserialize)]
103pub(crate) struct HnswNode {
104    /// The vector data.
105    vector: Vec<f32>,
106    /// Neighbors at each level this node participates in.
107    /// `neighbors[0]` = layer 0 (max 2*M neighbors), `neighbors[1]` = layer 1 (max M), etc.
108    neighbors: Vec<NeighborList>,
109}
110
111impl optimized::NodeVectors for [HnswNode] {
112    fn get_vector(&self, id: u32) -> &[f32] {
113        &self[id as usize].vector
114    }
115}
116
117/// HNSW graph index for approximate nearest neighbor search.
118///
119/// Optionally stores scalar-quantized codes (uint8) for accelerated distance
120/// computation during construction and search. When enabled, candidate exploration
121/// uses fast integer distances on codes, while final neighbor selection uses exact
122/// float32 distances for quality. See RFC-005 §3.
123pub struct HnswGraph<D: DistanceMetric> {
124    config: HnswConfig,
125    metric: D,
126    nodes: Vec<HnswNode>,
127    entry_point: Option<u32>,
128    max_level: usize,
129    rng: SmallRng,
130    /// Scalar-quantized codes: node_id → uint8 code (same dim as vectors).
131    /// When Some, `distance_fast` uses integer arithmetic (~4× faster).
132    sq_codes: Option<Vec<Vec<u8>>>,
133    /// Quantization parameters: (min_val, scale) for encoding/decoding.
134    sq_params: (f32, f32),
135}
136
137impl<D: DistanceMetric> HnswGraph<D> {
138    /// Create a new empty HNSW graph.
139    pub fn new(config: HnswConfig, metric: D) -> Self {
140        Self {
141            config,
142            metric,
143            nodes: Vec::new(),
144            entry_point: None,
145            max_level: 0,
146            rng: SmallRng::from_os_rng(),
147            sq_codes: None,
148            sq_params: (-1.0, 127.5), // default for L2-normalized vectors: [-1,1]→[0,255]
149        }
150    }
151
152    /// Enable scalar quantization for accelerated distance computation.
153    ///
154    /// When enabled, each inserted vector is also encoded as uint8 and
155    /// candidate distances during `search_layer` use fast integer arithmetic.
156    /// Final neighbor selection still uses exact float32 distances.
157    ///
158    /// For L2-normalized embeddings (range [-1, 1]), use default parameters.
159    /// For unnormalized data, provide the expected min/max range.
160    pub fn enable_scalar_quantization(&mut self, min_val: f32, max_val: f32) {
161        let range = max_val - min_val;
162        self.sq_params = (min_val, if range > 0.0 { 255.0 / range } else { 1.0 });
163
164        // Encode existing nodes
165        let codes: Vec<Vec<u8>> = self
166            .nodes
167            .iter()
168            .map(|node| Self::encode_sq(&node.vector, self.sq_params.0, self.sq_params.1))
169            .collect();
170        self.sq_codes = Some(codes);
171    }
172
173    /// Disable scalar quantization (revert to exact distances only).
174    pub fn disable_scalar_quantization(&mut self) {
175        self.sq_codes = None;
176    }
177
178    /// Whether scalar quantization is active.
179    pub fn is_quantized(&self) -> bool {
180        self.sq_codes.is_some()
181    }
182
183    /// Encode a vector to uint8 using scalar quantization.
184    #[inline]
185    fn encode_sq(vector: &[f32], min_val: f32, scale: f32) -> Vec<u8> {
186        vector
187            .iter()
188            .map(|&v| ((v - min_val) * scale).clamp(0.0, 255.0) as u8)
189            .collect()
190    }
191
192    /// Fast L2 distance on uint8 codes (auto-vectorized by LLVM).
193    #[inline]
194    fn distance_sq(a: &[u8], b: &[u8]) -> f32 {
195        let mut sum: u32 = 0;
196        for i in 0..a.len() {
197            let diff = a[i] as i32 - b[i] as i32;
198            sum += (diff * diff) as u32;
199        }
200        sum as f32 // skip sqrt — monotonic, preserves ordering
201    }
202
203    /// Number of vectors in the index.
204    pub fn len(&self) -> usize {
205        self.nodes.len()
206    }
207
208    /// Whether the index is empty.
209    pub fn is_empty(&self) -> bool {
210        self.nodes.is_empty()
211    }
212
213    /// Set ef_construction at runtime (e.g., lower for bulk load, higher for incremental).
214    pub fn set_ef_construction(&mut self, ef: usize) {
215        self.config.ef_construction = ef;
216    }
217
218    /// Set ef_search at runtime.
219    pub fn set_ef_search(&mut self, ef: usize) {
220        self.config.ef_search = ef;
221    }
222
223    /// Generate a random level for a new node.
224    ///
225    /// Capped at 32 (supports up to M^32 ≈ 10^38 nodes). See RFC-002-07.
226    pub(crate) fn random_level(&mut self) -> usize {
227        let r: f64 = self.rng.random();
228        let level = (-r.ln() * self.config.level_mult).floor() as usize;
229        level.min(32)
230    }
231
232    /// Max neighbors allowed at a given level.
233    fn max_neighbors(&self, level: usize) -> usize {
234        if level == 0 {
235            self.config.m * 2
236        } else {
237            self.config.m
238        }
239    }
240
241    /// Allocate a node without connecting it (used by bulk_insert_parallel).
242    pub(crate) fn push_node(&mut self, vector: &[f32], level: usize) {
243        let node = HnswNode {
244            vector: vector.to_vec(),
245            neighbors: (0..=level).map(|_| NeighborList::new()).collect(),
246        };
247        self.nodes.push(node);
248
249        if let Some(ref mut codes) = self.sq_codes {
250            codes.push(Self::encode_sq(vector, self.sq_params.0, self.sq_params.1));
251        }
252
253        if self.nodes.len() == 1 {
254            self.entry_point = Some(0);
255            self.max_level = level;
256        } else if level > self.max_level {
257            self.entry_point = Some((self.nodes.len() - 1) as u32);
258            self.max_level = level;
259        }
260    }
261
262    /// Connect a pre-allocated node using pre-computed neighbor candidates.
263    pub(crate) fn connect_node(&mut self, id: u32, candidates: &[(u32, f32)], level: usize) {
264        let insert_from = level.min(self.max_level);
265        for lev in (0..=insert_from).rev() {
266            let max_n = self.max_neighbors(lev);
267            let selected: Vec<u32> = if self.config.use_heuristic {
268                optimized::select_neighbors_heuristic(
269                    &self.metric,
270                    candidates,
271                    self.nodes.as_slice(),
272                    max_n,
273                    false,
274                )
275            } else {
276                candidates.iter().take(max_n).map(|&(n, _)| n).collect()
277            };
278
279            for &neighbor_id in &selected {
280                if neighbor_id == id {
281                    continue;
282                }
283                if lev < self.nodes[id as usize].neighbors.len()
284                    && !self.nodes[id as usize].neighbors[lev].contains(&neighbor_id)
285                {
286                    self.nodes[id as usize].neighbors[lev].push(neighbor_id);
287                }
288                if lev < self.nodes[neighbor_id as usize].neighbors.len()
289                    && !self.nodes[neighbor_id as usize].neighbors[lev].contains(&id)
290                {
291                    self.nodes[neighbor_id as usize].neighbors[lev].push(id);
292                    let count = self.nodes[neighbor_id as usize].neighbors[lev].len();
293                    if count > max_n {
294                        self.prune_neighbors(neighbor_id, lev, max_n);
295                    }
296                }
297            }
298        }
299    }
300
301    /// Insert a vector into the index.
302    ///
303    /// The `id` should be a unique sequential identifier starting from 0.
304    /// Panics if `id != self.len()` (must insert in order).
305    pub fn insert(&mut self, id: u32, vector: &[f32]) {
306        assert_eq!(
307            id as usize,
308            self.nodes.len(),
309            "must insert sequentially: expected id {}, got {id}",
310            self.nodes.len()
311        );
312
313        let level = self.random_level();
314        let node = HnswNode {
315            vector: vector.to_vec(),
316            neighbors: (0..=level).map(|_| NeighborList::new()).collect(),
317        };
318        self.nodes.push(node);
319
320        // Store SQ code if quantization is enabled
321        if let Some(ref mut codes) = self.sq_codes {
322            codes.push(Self::encode_sq(vector, self.sq_params.0, self.sq_params.1));
323        }
324
325        // First node
326        if self.nodes.len() == 1 {
327            self.entry_point = Some(0);
328            self.max_level = level;
329            return;
330        }
331
332        let entry = self.entry_point.unwrap();
333        let mut current = entry;
334
335        // Phase 1: greedy descend from top level to node's level + 1
336        for lev in (level + 1..=self.max_level).rev() {
337            current = self.greedy_closest(current, vector, lev);
338        }
339
340        // Phase 2: insert at each level from min(level, max_level) down to 0
341        let insert_from = level.min(self.max_level);
342        for lev in (0..=insert_from).rev() {
343            let neighbors = self.search_layer(current, vector, self.config.ef_construction, lev);
344
345            // Select best M neighbors
346            let max_n = self.max_neighbors(lev);
347            let mut selected: Vec<u32> = if self.config.use_heuristic {
348                optimized::select_neighbors_heuristic(
349                    &self.metric,
350                    &neighbors,
351                    self.nodes.as_slice(),
352                    max_n,
353                    false,
354                )
355            } else {
356                neighbors.iter().take(max_n).map(|&(n, _)| n).collect()
357            };
358
359            // Safety: ensure at least one connection at every level
360            if selected.is_empty() {
361                selected.push(current);
362            }
363
364            // Add bidirectional connections
365            for &neighbor_id in &selected {
366                // Avoid self-loops
367                if neighbor_id == id {
368                    continue;
369                }
370                // Avoid duplicate edges
371                if !self.nodes[id as usize].neighbors[lev].contains(&neighbor_id) {
372                    self.nodes[id as usize].neighbors[lev].push(neighbor_id);
373                }
374                if !self.nodes[neighbor_id as usize].neighbors[lev].contains(&id) {
375                    self.nodes[neighbor_id as usize].neighbors[lev].push(id);
376                }
377
378                // Prune neighbor's list if over capacity
379                let neighbor_count = self.nodes[neighbor_id as usize].neighbors[lev].len();
380                if neighbor_count > max_n {
381                    self.prune_neighbors(neighbor_id, lev, max_n);
382                }
383            }
384
385            // Use closest as entry for next lower level
386            if let Some(&(closest, _)) = neighbors.first() {
387                current = closest;
388            }
389        }
390
391        // Ensure the new node has at least one connection at level 0.
392        // This prevents disconnected components from forming.
393        if self.nodes[id as usize].neighbors[0].is_empty() {
394            // Find closest node via brute force scan (rare case, only for disconnected nodes)
395            let mut best_id = entry;
396            let mut best_dist = self.distance(entry, vector);
397            for i in 0..self.nodes.len() - 1 {
398                let d = self.distance(i as u32, vector);
399                if d < best_dist {
400                    best_dist = d;
401                    best_id = i as u32;
402                }
403            }
404            self.nodes[id as usize].neighbors[0].push(best_id);
405            self.nodes[best_id as usize].neighbors[0].push(id);
406        }
407
408        // Update entry point if new node has higher level
409        if level > self.max_level {
410            self.entry_point = Some(id);
411            self.max_level = level;
412        }
413    }
414
415    /// Greedy search for the single closest node at a given level.
416    fn greedy_closest(&self, start: u32, query: &[f32], level: usize) -> u32 {
417        let query_code = self
418            .sq_codes
419            .as_ref()
420            .map(|_| Self::encode_sq(query, self.sq_params.0, self.sq_params.1));
421        let qc = query_code.as_deref();
422
423        let mut current = start;
424        let mut current_dist = self.distance_fast(current, qc, query);
425
426        loop {
427            let mut changed = false;
428            let neighbors = self.neighbors_at(current, level);
429            for &neighbor in neighbors {
430                let dist = self.distance_fast(neighbor, qc, query);
431                if dist < current_dist {
432                    current = neighbor;
433                    current_dist = dist;
434                    changed = true;
435                }
436            }
437            if !changed {
438                return current;
439            }
440        }
441    }
442
443    /// Beam search at a single level. Returns candidates sorted by distance (ascending).
444    ///
445    /// When scalar quantization is enabled, candidate exploration uses fast uint8
446    /// distances. Final results are re-ranked with exact float32 distances.
447    fn search_layer(&self, entry: u32, query: &[f32], ef: usize, level: usize) -> Vec<(u32, f32)> {
448        // Pre-encode query for SQ if enabled
449        let query_code = self
450            .sq_codes
451            .as_ref()
452            .map(|_| Self::encode_sq(query, self.sq_params.0, self.sq_params.1));
453        let qc = query_code.as_deref();
454
455        let entry_dist = self.distance_fast(entry, qc, query);
456
457        // Min-heap for candidates to explore (closest first)
458        let mut candidates: BinaryHeap<Reverse<OrdF32Entry>> = BinaryHeap::new();
459        // Max-heap for results (farthest first, so we can evict)
460        let mut results: BinaryHeap<OrdF32Entry> = BinaryHeap::new();
461        // Visited set
462        let mut visited = vec![false; self.nodes.len()];
463
464        candidates.push(Reverse(OrdF32Entry(entry_dist, entry)));
465        results.push(OrdF32Entry(entry_dist, entry));
466        visited[entry as usize] = true;
467
468        while let Some(Reverse(OrdF32Entry(c_dist, c_id))) = candidates.pop() {
469            // If closest candidate is farther than farthest result, stop
470            let farthest_result = results.peek().map(|e| e.0).unwrap_or(f32::INFINITY);
471            if c_dist > farthest_result {
472                break;
473            }
474
475            let neighbors = self.neighbors_at(c_id, level);
476            for &neighbor in neighbors {
477                if visited[neighbor as usize] {
478                    continue;
479                }
480                visited[neighbor as usize] = true;
481
482                let dist = self.distance_fast(neighbor, qc, query);
483                let farthest_result = results.peek().map(|e| e.0).unwrap_or(f32::INFINITY);
484
485                if dist < farthest_result || results.len() < ef {
486                    candidates.push(Reverse(OrdF32Entry(dist, neighbor)));
487                    results.push(OrdF32Entry(dist, neighbor));
488                    if results.len() > ef {
489                        results.pop(); // remove farthest
490                    }
491                }
492            }
493        }
494
495        // Re-rank with exact distances when SQ was used (quality matters for final results)
496        let mut result_vec: Vec<(u32, f32)> = if self.sq_codes.is_some() {
497            results
498                .into_iter()
499                .map(|e| (e.1, self.distance(e.1, query)))
500                .collect()
501        } else {
502            results.into_iter().map(|e| (e.1, e.0)).collect()
503        };
504        result_vec.sort_by(|a, b| a.1.total_cmp(&b.1));
505        result_vec
506    }
507
508    /// Prune a node's neighbor list to keep only the best `max_n`.
509    ///
510    /// Uses heuristic selection when enabled (diverse directions),
511    /// otherwise keeps the closest `max_n` by distance.
512    fn prune_neighbors(&mut self, node_id: u32, level: usize, max_n: usize) {
513        let node_vec = self.nodes[node_id as usize].vector.clone();
514        let scored: Vec<(u32, f32)> = self.nodes[node_id as usize].neighbors[level]
515            .iter()
516            .map(|&n| {
517                (
518                    n,
519                    self.metric
520                        .distance(&node_vec, &self.nodes[n as usize].vector),
521                )
522            })
523            .collect();
524
525        let pruned = if self.config.use_heuristic {
526            optimized::select_neighbors_heuristic(
527                &self.metric,
528                &scored,
529                self.nodes.as_slice(),
530                max_n,
531                false,
532            )
533        } else {
534            let mut s = scored;
535            s.sort_by(|a, b| a.1.total_cmp(&b.1));
536            s.truncate(max_n);
537            s.iter().map(|&(n, _)| n).collect()
538        };
539
540        self.nodes[node_id as usize].neighbors[level] = pruned.into_iter().collect();
541    }
542
543    /// Search for the k nearest neighbors of `query`.
544    ///
545    /// Returns a Vec of `(node_id, distance)` sorted by distance ascending.
546    pub fn search(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
547        if self.nodes.is_empty() {
548            return Vec::new();
549        }
550
551        let entry = self.entry_point.unwrap();
552        let mut current = entry;
553
554        // Greedy descend from top to level 1
555        for lev in (1..=self.max_level).rev() {
556            current = self.greedy_closest(current, query, lev);
557        }
558
559        // Beam search on level 0
560        let mut results = self.search_layer(current, query, self.config.ef_search.max(k), 0);
561        results.truncate(k);
562        results
563    }
564
565    /// Search with a predicate filter.
566    ///
567    /// Like [`search`](Self::search), but only returns nodes where `filter(node_id)` is true.
568    /// The HNSW graph is still traversed through filtered-out nodes (they act as bridges),
569    /// but only matching nodes appear in the results.
570    pub fn search_filtered(
571        &self,
572        query: &[f32],
573        k: usize,
574        filter: impl Fn(u32) -> bool,
575    ) -> Vec<(u32, f32)> {
576        if self.nodes.is_empty() {
577            return Vec::new();
578        }
579
580        let entry = self.entry_point.unwrap();
581        let mut current = entry;
582
583        // Greedy descend from top to level 1
584        for lev in (1..=self.max_level).rev() {
585            current = self.greedy_closest(current, query, lev);
586        }
587
588        // Beam search on level 0, collecting all candidates
589        let ef = self.config.ef_search.max(k * 4); // over-fetch to compensate for filtering
590        let all_candidates = self.search_layer(current, query, ef, 0);
591
592        // Apply filter and take top-k
593        let mut results: Vec<(u32, f32)> = all_candidates
594            .into_iter()
595            .filter(|&(id, _)| filter(id))
596            .collect();
597        results.truncate(k);
598        results
599    }
600
601    /// Get the stored vector for a node.
602    pub fn vector(&self, node_id: u32) -> &[f32] {
603        &self.nodes[node_id as usize].vector
604    }
605
606    /// Get the configuration.
607    pub fn config(&self) -> &HnswConfig {
608        &self.config
609    }
610
611    /// Get the entry point node ID.
612    pub fn entry_point(&self) -> Option<u32> {
613        self.entry_point
614    }
615
616    /// Get the maximum level in the graph.
617    pub fn max_level(&self) -> usize {
618        self.max_level
619    }
620
621    /// Get the neighbor list for a node at a specific level (public accessor).
622    pub fn neighbors_at_level(&self, node_id: u32, level: usize) -> &[u32] {
623        self.neighbors_at(node_id, level)
624    }
625
626    /// Compute distance between a stored node and a query vector (public accessor).
627    pub fn distance_to(&self, node_id: u32, query: &[f32]) -> f32 {
628        self.distance(node_id, query)
629    }
630
631    /// Get all neighbor lists for a node (one per level).
632    pub fn all_neighbors(&self, node_id: u32) -> Vec<Vec<u32>> {
633        let node = &self.nodes[node_id as usize];
634        node.neighbors.iter().map(|n| n.to_vec()).collect()
635    }
636
637    /// Return node IDs present at the given HNSW level (RFC-004).
638    ///
639    /// These are the natural "hub" nodes of the graph hierarchy.
640    /// Level 0 = all nodes, higher levels = fewer, more connected hubs.
641    /// Count follows geometric distribution: ~N/M^level.
642    pub fn nodes_at_level(&self, level: usize) -> Vec<u32> {
643        (0..self.nodes.len() as u32)
644            .filter(|&id| self.nodes[id as usize].neighbors.len() > level)
645            .collect()
646    }
647
648    /// Assign a vector to its nearest hub at the given level (RFC-004).
649    ///
650    /// Uses greedy descent from the entry point — O(log N).
651    /// Returns the node_id of the nearest hub at that level.
652    pub fn assign_region(&self, vector: &[f32], level: usize) -> Option<u32> {
653        if self.nodes.is_empty() {
654            return None;
655        }
656
657        let entry = self.entry_point.unwrap();
658        let mut current = entry;
659
660        // Greedy descend from top to target level + 1
661        for lev in (level + 1..=self.max_level).rev() {
662            current = self.greedy_closest(current, vector, lev);
663        }
664
665        // At the target level, find the closest hub
666        if level <= self.max_level {
667            current = self.greedy_closest(current, vector, level);
668        }
669
670        // Ensure result is actually at the target level
671        if self.nodes[current as usize].neighbors.len() > level {
672            Some(current)
673        } else {
674            // Fallback: search among known hubs at this level
675            let hubs = self.nodes_at_level(level);
676            hubs.into_iter().min_by(|&a, &b| {
677                self.distance(a, vector)
678                    .total_cmp(&self.distance(b, vector))
679            })
680        }
681    }
682
683    /// Compute distance between a stored node and a query vector.
684    ///
685    /// When scalar quantization is enabled, uses fast uint8 distances
686    /// for candidate exploration (~4× faster). Falls back to exact
687    /// float32 distance when SQ is disabled.
688    #[inline]
689    fn distance(&self, node_id: u32, query: &[f32]) -> f32 {
690        self.metric
691            .distance(&self.nodes[node_id as usize].vector, query)
692    }
693
694    /// Fast approximate distance using scalar-quantized codes.
695    ///
696    /// Returns the exact distance if SQ is not enabled.
697    #[inline]
698    fn distance_fast(&self, node_id: u32, query_code: Option<&[u8]>, query: &[f32]) -> f32 {
699        if let (Some(codes), Some(qc)) = (&self.sq_codes, query_code) {
700            Self::distance_sq(&codes[node_id as usize], qc)
701        } else {
702            self.distance(node_id, query)
703        }
704    }
705
706    /// Get the neighbor list for a node at a given level.
707    #[inline]
708    fn neighbors_at(&self, node_id: u32, level: usize) -> &[u32] {
709        let node = &self.nodes[node_id as usize];
710        if level < node.neighbors.len() {
711            &node.neighbors[level]
712        } else {
713            &[]
714        }
715    }
716
717    /// Check graph invariant: all nodes are reachable from entry point at level 0.
718    ///
719    /// Returns the number of reachable nodes. Should equal `self.len()`.
720    pub fn count_reachable(&self) -> usize {
721        if self.nodes.is_empty() {
722            return 0;
723        }
724        let entry = self.entry_point.unwrap();
725        let mut visited = vec![false; self.nodes.len()];
726        let mut stack = vec![entry];
727        visited[entry as usize] = true;
728        let mut count = 1usize;
729
730        while let Some(node) = stack.pop() {
731            for &neighbor in self.neighbors_at(node, 0) {
732                if !visited[neighbor as usize] {
733                    visited[neighbor as usize] = true;
734                    count += 1;
735                    stack.push(neighbor);
736                }
737            }
738        }
739        count
740    }
741
742    /// Brute-force kNN for ground truth comparison.
743    ///
744    /// Returns `(node_id, distance)` sorted ascending. O(N) per query.
745    pub fn brute_force_knn(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
746        let mut all: Vec<(u32, f32)> = self
747            .nodes
748            .iter()
749            .enumerate()
750            .map(|(i, node)| (i as u32, self.metric.distance(&node.vector, query)))
751            .collect();
752        all.sort_by(|a, b| a.1.total_cmp(&b.1));
753        all.truncate(k);
754        all
755    }
756}
757
758/// Serializable snapshot of an HNSW graph (excludes metric + RNG).
759///
760/// Used by `TemporalHnsw::save` / `TemporalHnsw::load` for index persistence.
761#[derive(Serialize, Deserialize)]
762pub(crate) struct HnswSnapshot {
763    pub(crate) config: HnswConfig,
764    pub(crate) nodes: Vec<HnswNode>,
765    pub(crate) entry_point: Option<u32>,
766    pub(crate) max_level: usize,
767    pub(crate) sq_codes: Option<Vec<Vec<u8>>>,
768    pub(crate) sq_params: (f32, f32),
769}
770
771impl<D: DistanceMetric> HnswGraph<D> {
772    /// Create a serializable snapshot (excludes metric and RNG).
773    pub(crate) fn to_snapshot(&self) -> HnswSnapshot {
774        HnswSnapshot {
775            config: self.config.clone(),
776            nodes: self
777                .nodes
778                .iter()
779                .map(|n| HnswNode {
780                    vector: n.vector.clone(),
781                    neighbors: n.neighbors.clone(),
782                })
783                .collect(),
784            entry_point: self.entry_point,
785            max_level: self.max_level,
786            sq_codes: self.sq_codes.clone(),
787            sq_params: self.sq_params,
788        }
789    }
790
791    /// Restore from a snapshot, providing the metric.
792    pub(crate) fn from_snapshot(snapshot: HnswSnapshot, metric: D) -> Self {
793        Self {
794            config: snapshot.config,
795            metric,
796            nodes: snapshot.nodes,
797            entry_point: snapshot.entry_point,
798            max_level: snapshot.max_level,
799            rng: SmallRng::from_os_rng(),
800            sq_codes: snapshot.sq_codes,
801            sq_params: snapshot.sq_params,
802        }
803    }
804}
805
/// `(distance, node_id)` pair with a total order, usable in a `BinaryHeap`.
///
/// Ordering is by `f32::total_cmp` on the distance, then by node id. Equality
/// agrees with that order, so `0.0` and `-0.0` are distinct and a NaN equals
/// itself when the bit patterns match.
#[derive(Clone, Copy)]
struct OrdF32Entry(f32, u32);

impl Ord for OrdF32Entry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.0.total_cmp(&other.0) {
            // Same distance: fall back to the node id for a deterministic order.
            std::cmp::Ordering::Equal => self.1.cmp(&other.1),
            unequal => unequal,
        }
    }
}

impl PartialOrd for OrdF32Entry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for OrdF32Entry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32Entry {}
829
/// Recall@k: the number of entries in `approximate` whose id appears in
/// `ground_truth`, divided by `ground_truth.len()`.
///
/// An empty `ground_truth` yields `0.0` (the divisor is clamped to 1).
pub fn recall_at_k(approximate: &[(u32, f32)], ground_truth: &[(u32, f32)]) -> f64 {
    let mut truth_ids = std::collections::HashSet::with_capacity(ground_truth.len());
    for &(id, _) in ground_truth {
        truth_ids.insert(id);
    }

    let mut hits = 0usize;
    for (id, _) in approximate {
        if truth_ids.contains(id) {
            hits += 1;
        }
    }

    let denominator = ground_truth.len().max(1);
    hits as f64 / denominator as f64
}
840
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{CosineDistance, L2Distance};

    /// Build an empty L2 graph from the three tunables the tests vary; every
    /// other `HnswConfig` field keeps its default value.
    fn make_graph(m: usize, ef_c: usize, ef_s: usize) -> HnswGraph<L2Distance> {
        HnswGraph::new(
            HnswConfig {
                m,
                ef_construction: ef_c,
                ef_search: ef_s,
                ..Default::default()
            },
            L2Distance,
        )
    }

    /// Draw one vector of `dim` samples from `rng.random::<f32>()`.
    fn random_vector(rng: &mut impl rand::Rng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.random::<f32>()).collect()
    }

    /// Draw `count` random vectors of dimension `dim`, one after another.
    fn random_vectors(rng: &mut impl rand::Rng, count: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..count)
            .map(|_| random_vector(&mut *rng, dim))
            .collect()
    }

    /// Insert every vector into `graph`, using its position in the slice as node id.
    fn insert_all(graph: &mut HnswGraph<L2Distance>, vectors: &[Vec<f32>]) {
        for (position, vector) in vectors.iter().enumerate() {
            graph.insert(position as u32, vector);
        }
    }

    /// Average `recall_of(query)` over `n_queries` freshly drawn random queries
    /// of dimension `dim`.
    fn mean_recall(
        rng: &mut impl rand::Rng,
        dim: usize,
        n_queries: usize,
        mut recall_of: impl FnMut(&[f32]) -> f64,
    ) -> f64 {
        let mut total = 0.0;
        for _ in 0..n_queries {
            let query = random_vector(&mut *rng, dim);
            total += recall_of(&query);
        }
        total / n_queries as f64
    }

    /// A fresh graph reports itself empty and returns no search hits.
    #[test]
    fn empty_graph() {
        let graph = make_graph(16, 200, 50);
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);
        let hits = graph.search(&[1.0, 2.0], 5);
        assert_eq!(hits, vec![]);
    }

    /// With a single node, searching for its own vector returns it at ~0 distance.
    #[test]
    fn single_insert_and_search() {
        let point: [f32; 3] = [1.0, 0.0, 0.0];
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &point);

        let results = graph.search(&point, 1);
        assert_eq!(results.len(), 1);
        let (found_id, found_dist) = results[0];
        assert_eq!(found_id, 0);
        assert!(found_dist < 1e-5); // exact match
    }

    /// Three points come back ordered from nearest to farthest.
    #[test]
    fn three_vectors_correct_order() {
        let points: [[f32; 2]; 3] = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]];
        let mut graph = make_graph(16, 200, 50);
        for (position, point) in points.iter().enumerate() {
            graph.insert(position as u32, point);
        }

        let results = graph.search(&[1.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        // exact match, then the close point, then the far one
        let order: Vec<u32> = results.iter().map(|&(node, _)| node).collect();
        assert_eq!(order, [0, 1, 2]);
    }

    /// 100 points on an anti-diagonal line are all reachable.
    #[test]
    fn all_nodes_reachable_100() {
        let mut graph = make_graph(16, 200, 50);
        let count = 100u32;
        for node in 0..count {
            let x = node as f32;
            let y = (count - node) as f32;
            graph.insert(node, &[x, y]);
        }
        assert_eq!(graph.count_reachable(), 100);
    }

    /// 1000 points on the unit circle are all reachable.
    #[test]
    fn all_nodes_reachable_1000() {
        let mut graph = make_graph(16, 200, 50);
        for node in 0..1000u32 {
            let (sin, cos) = ((node as f32) * 0.1).sin_cos();
            graph.insert(node, &[cos, sin]);
        }
        assert_eq!(graph.count_reachable(), 1000);
    }

    /// 1K random vectors in 32 dimensions: average recall@10 must reach 0.90.
    #[test]
    fn recall_at_10_random_1k_d32() {
        let dim = 32;
        let mut rng = rand::rng();
        let mut graph = make_graph(16, 200, 50);

        // Index 1000 random vectors.
        let vectors = random_vectors(&mut rng, 1000, dim);
        insert_all(&mut graph, &vectors);

        // Average recall over 100 random queries against brute force.
        let k = 10;
        let avg_recall = mean_recall(&mut rng, dim, 100, |query| {
            recall_at_k(&graph.search(query, k), &graph.brute_force_knn(query, k))
        });

        assert!(
            avg_recall >= 0.90,
            "recall@10 = {avg_recall:.3}, expected >= 0.90"
        );
    }

    /// 10K random vectors in 128 dimensions: at least 98% reachable and
    /// average recall@10 of at least 0.85.
    #[test]
    fn recall_at_10_random_10k_d128() {
        let n = 10_000u32;
        let dim = 128;
        let k = 10;
        let config = HnswConfig {
            m: 16,
            ef_construction: 200,
            ef_search: 200, // higher ef_search for better recall
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, L2Distance);

        let mut rng = rand::rng();
        let vectors = random_vectors(&mut rng, n as usize, dim);
        insert_all(&mut graph, &vectors);

        // Pruning can split off small components, so full reachability is not
        // required here. This is a known HNSW limitation that improves with
        // higher M and ef_construction; the test only demands 98%.
        let reachable = graph.count_reachable();
        let min_reachable = (n as usize) * 98 / 100;
        assert!(
            reachable >= min_reachable,
            "reachable: {reachable} / {n}, expected >= 98%"
        );

        let avg_recall = mean_recall(&mut rng, dim, 50, |query| {
            recall_at_k(&graph.search(query, k), &graph.brute_force_knn(query, k))
        });
        assert!(
            avg_recall >= 0.85,
            "recall@10 on 10K D=128 = {avg_recall:.3}, expected >= 0.85"
        );
    }

    /// The graph also works when built with `CosineDistance`.
    #[test]
    fn works_with_cosine_distance() {
        let mut graph = HnswGraph::new(
            HnswConfig {
                m: 16,
                ef_construction: 100,
                ef_search: 50,
                ..Default::default()
            },
            CosineDistance,
        );

        let points: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.99, 0.01, 0.0], [0.0, 0.0, 1.0]];
        for (position, point) in points.iter().enumerate() {
            graph.insert(position as u32, point);
        }

        let results = graph.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
    }

    /// Asking for more neighbours than there are nodes returns every node.
    #[test]
    fn search_k_larger_than_n() {
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &[1.0, 0.0]);
        graph.insert(1, &[0.0, 1.0]);

        let returned = graph.search(&[1.0, 0.0], 10).len();
        assert_eq!(returned, 2); // only 2 nodes exist
    }

    /// Two of the three true neighbours found gives a recall of 2/3.
    #[test]
    fn recall_helper_correct() {
        let recall = recall_at_k(
            &[(0, 0.1), (1, 0.2), (2, 0.3)],
            &[(0, 0.1), (1, 0.2), (3, 0.25)],
        );
        let expected = 2.0 / 3.0;
        assert!((recall - expected).abs() < 1e-10);
    }

    // ─── Accessor coverage ────────────────────────────────────────

    /// `vector` hands back exactly what was inserted under each id.
    #[test]
    fn vector_accessor() {
        let first: [f32; 3] = [1.0, 2.0, 3.0];
        let second: [f32; 3] = [4.0, 5.0, 6.0];

        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &first);
        graph.insert(1, &second);

        assert_eq!(graph.vector(0), &first);
        assert_eq!(graph.vector(1), &second);
    }

    /// Level 0 holds every node; level 1 is populated but sparser.
    #[test]
    fn nodes_at_level() {
        let mut graph = make_graph(4, 50, 50);
        // Insert enough nodes to get some at level 1+
        for node in 0..200u32 {
            graph.insert(node, &[node as f32, (200 - node) as f32]);
        }

        let level0_count = graph.nodes_at_level(0).len();
        assert_eq!(level0_count, 200);

        let l1 = graph.nodes_at_level(1);
        assert!(!l1.is_empty(), "should have some level-1 nodes");
        assert!(l1.len() < 200, "level 1 should be sparser than level 0");
    }

    /// `assign_region` at level 1 returns a node that lives on level 1.
    #[test]
    fn assign_region_returns_hub() {
        let mut graph = make_graph(4, 50, 50);
        let mut rng = rand::rng();
        for node in 0..200u32 {
            let vector = random_vector(&mut rng, 8);
            graph.insert(node, &vector);
        }

        let query = random_vector(&mut rng, 8);
        let hub = graph
            .assign_region(&query, 1)
            .expect("should find a hub at level 1");

        // Hub should be a node at level 1
        assert!(graph.nodes_at_level(1).contains(&hub));
    }

    /// Filtered search only returns nodes accepted by the predicate.
    #[test]
    fn search_filtered_respects_predicate() {
        let mut graph = make_graph(16, 200, 100);
        for node in 0..100u32 {
            graph.insert(node, &[node as f32, 0.0]);
        }

        // Only allow even-numbered nodes
        let results = graph.search_filtered(&[50.0, 0.0], 5, |id| id % 2 == 0);
        assert_eq!(results.len(), 5);
        results
            .iter()
            .for_each(|&(id, _)| assert_eq!(id % 2, 0, "node {id} should be even"));
    }

    /// A graph restored from its snapshot keeps its vectors and stays searchable.
    #[test]
    fn snapshot_round_trip() {
        let mut graph = make_graph(16, 100, 50);
        for node in 0..50u32 {
            graph.insert(node, &[node as f32, (50 - node) as f32]);
        }

        let restored = HnswGraph::from_snapshot(graph.to_snapshot(), L2Distance);

        assert_eq!(restored.len(), 50);
        assert_eq!(restored.vector(0), &[0.0, 50.0]);
        assert_eq!(restored.vector(49), &[49.0, 1.0]);

        // Search should work on restored graph
        let hit_count = restored.search(&[25.0, 25.0], 3).len();
        assert_eq!(hit_count, 3);
    }

    /// `distance_to` is ~0 for a node's own vector and > 1 for the orthogonal one.
    #[test]
    fn distance_to_node() {
        let probe: [f32; 2] = [1.0, 0.0];
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &probe);
        graph.insert(1, &[0.0, 1.0]);

        let d = graph.distance_to(0, &probe);
        assert!(d < 1e-5, "distance to self should be ~0, got {d}");

        let d2 = graph.distance_to(1, &probe);
        assert!(d2 > 1.0, "distance to orthogonal should be > 1, got {d2}");
    }

    /// 100K vectors D=128, recall@10 ≥ 0.95 (Layer 2 exit criterion)
    #[test]
    #[ignore] // slow: ~10s, run with `cargo test -- --ignored`
    fn recall_100k_d128() {
        let n = 100_000u32;
        let dim = 128;
        let k = 10;
        let config = HnswConfig {
            m: 16,
            ef_construction: 200,
            ef_search: 100,
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, L2Distance);

        let mut rng = rand::rng();
        let vectors = random_vectors(&mut rng, n as usize, dim);
        insert_all(&mut graph, &vectors);

        assert_eq!(
            graph.count_reachable(),
            n as usize,
            "not all nodes reachable"
        );

        let avg_recall = mean_recall(&mut rng, dim, 100, |query| {
            recall_at_k(&graph.search(query, k), &graph.brute_force_knn(query, k))
        });
        assert!(
            avg_recall >= 0.95,
            "recall@10 on 100K D=128 = {avg_recall:.3}, expected >= 0.95"
        );
    }
}