1pub mod bayesian_scorer;
37pub mod concurrent;
38pub mod metadata_store;
39pub mod optimized;
40pub mod partitioned;
41pub mod region_mdp;
42pub mod streaming;
43pub mod temporal;
44pub mod temporal_edges;
45pub mod temporal_graph;
46pub mod temporal_lsh;
47pub mod typed_edges;
48
49pub use bayesian_scorer::{CandidateFeatures, ScoringWeights, WeightLearner};
50pub use concurrent::ConcurrentTemporalHnsw;
51pub use region_mdp::RegionMdp;
52pub use temporal::TemporalHnsw;
53pub use temporal_edges::TemporalEdgeLayer;
54pub use temporal_graph::{CausalSearchResult, TemporalGraphIndex};
55pub use typed_edges::{EdgeType, TypedEdgeStore};
56
57use std::cmp::Reverse;
58use std::collections::BinaryHeap;
59
60use cvx_core::DistanceMetric;
61use rand::rngs::SmallRng;
62use rand::{Rng, SeedableRng};
63use serde::{Deserialize, Serialize};
64use smallvec::SmallVec;
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct HnswConfig {
69 pub m: usize,
71 pub ef_construction: usize,
73 pub ef_search: usize,
75 pub max_level: usize,
77 pub level_mult: f64,
79 pub use_heuristic: bool,
82}
83
84impl Default for HnswConfig {
85 fn default() -> Self {
86 let m = 16;
87 Self {
88 m,
89 ef_construction: 200,
90 ef_search: 50,
91 max_level: 0,
92 level_mult: 1.0 / (m as f64).ln(),
93 use_heuristic: true,
94 }
95 }
96}
97
98type NeighborList = SmallVec<[u32; 16]>;
100
101#[derive(Serialize, Deserialize)]
103pub(crate) struct HnswNode {
104 vector: Vec<f32>,
106 neighbors: Vec<NeighborList>,
109}
110
111impl optimized::NodeVectors for [HnswNode] {
112 fn get_vector(&self, id: u32) -> &[f32] {
113 &self[id as usize].vector
114 }
115}
116
117pub struct HnswGraph<D: DistanceMetric> {
124 config: HnswConfig,
125 metric: D,
126 nodes: Vec<HnswNode>,
127 entry_point: Option<u32>,
128 max_level: usize,
129 rng: SmallRng,
130 sq_codes: Option<Vec<Vec<u8>>>,
133 sq_params: (f32, f32),
135}
136
137impl<D: DistanceMetric> HnswGraph<D> {
138 pub fn new(config: HnswConfig, metric: D) -> Self {
140 Self {
141 config,
142 metric,
143 nodes: Vec::new(),
144 entry_point: None,
145 max_level: 0,
146 rng: SmallRng::from_os_rng(),
147 sq_codes: None,
148 sq_params: (-1.0, 127.5), }
150 }
151
152 pub fn enable_scalar_quantization(&mut self, min_val: f32, max_val: f32) {
161 let range = max_val - min_val;
162 self.sq_params = (min_val, if range > 0.0 { 255.0 / range } else { 1.0 });
163
164 let codes: Vec<Vec<u8>> = self
166 .nodes
167 .iter()
168 .map(|node| Self::encode_sq(&node.vector, self.sq_params.0, self.sq_params.1))
169 .collect();
170 self.sq_codes = Some(codes);
171 }
172
173 pub fn disable_scalar_quantization(&mut self) {
175 self.sq_codes = None;
176 }
177
178 pub fn is_quantized(&self) -> bool {
180 self.sq_codes.is_some()
181 }
182
183 #[inline]
185 fn encode_sq(vector: &[f32], min_val: f32, scale: f32) -> Vec<u8> {
186 vector
187 .iter()
188 .map(|&v| ((v - min_val) * scale).clamp(0.0, 255.0) as u8)
189 .collect()
190 }
191
192 #[inline]
194 fn distance_sq(a: &[u8], b: &[u8]) -> f32 {
195 let mut sum: u32 = 0;
196 for i in 0..a.len() {
197 let diff = a[i] as i32 - b[i] as i32;
198 sum += (diff * diff) as u32;
199 }
200 sum as f32 }
202
203 pub fn len(&self) -> usize {
205 self.nodes.len()
206 }
207
208 pub fn is_empty(&self) -> bool {
210 self.nodes.is_empty()
211 }
212
213 pub fn set_ef_construction(&mut self, ef: usize) {
215 self.config.ef_construction = ef;
216 }
217
218 pub fn set_ef_search(&mut self, ef: usize) {
220 self.config.ef_search = ef;
221 }
222
223 pub(crate) fn random_level(&mut self) -> usize {
227 let r: f64 = self.rng.random();
228 let level = (-r.ln() * self.config.level_mult).floor() as usize;
229 level.min(32)
230 }
231
232 fn max_neighbors(&self, level: usize) -> usize {
234 if level == 0 {
235 self.config.m * 2
236 } else {
237 self.config.m
238 }
239 }
240
241 pub(crate) fn push_node(&mut self, vector: &[f32], level: usize) {
243 let node = HnswNode {
244 vector: vector.to_vec(),
245 neighbors: (0..=level).map(|_| NeighborList::new()).collect(),
246 };
247 self.nodes.push(node);
248
249 if let Some(ref mut codes) = self.sq_codes {
250 codes.push(Self::encode_sq(vector, self.sq_params.0, self.sq_params.1));
251 }
252
253 if self.nodes.len() == 1 {
254 self.entry_point = Some(0);
255 self.max_level = level;
256 } else if level > self.max_level {
257 self.entry_point = Some((self.nodes.len() - 1) as u32);
258 self.max_level = level;
259 }
260 }
261
262 pub(crate) fn connect_node(&mut self, id: u32, candidates: &[(u32, f32)], level: usize) {
264 let insert_from = level.min(self.max_level);
265 for lev in (0..=insert_from).rev() {
266 let max_n = self.max_neighbors(lev);
267 let selected: Vec<u32> = if self.config.use_heuristic {
268 optimized::select_neighbors_heuristic(
269 &self.metric,
270 candidates,
271 self.nodes.as_slice(),
272 max_n,
273 false,
274 )
275 } else {
276 candidates.iter().take(max_n).map(|&(n, _)| n).collect()
277 };
278
279 for &neighbor_id in &selected {
280 if neighbor_id == id {
281 continue;
282 }
283 if lev < self.nodes[id as usize].neighbors.len()
284 && !self.nodes[id as usize].neighbors[lev].contains(&neighbor_id)
285 {
286 self.nodes[id as usize].neighbors[lev].push(neighbor_id);
287 }
288 if lev < self.nodes[neighbor_id as usize].neighbors.len()
289 && !self.nodes[neighbor_id as usize].neighbors[lev].contains(&id)
290 {
291 self.nodes[neighbor_id as usize].neighbors[lev].push(id);
292 let count = self.nodes[neighbor_id as usize].neighbors[lev].len();
293 if count > max_n {
294 self.prune_neighbors(neighbor_id, lev, max_n);
295 }
296 }
297 }
298 }
299 }
300
301 pub fn insert(&mut self, id: u32, vector: &[f32]) {
306 assert_eq!(
307 id as usize,
308 self.nodes.len(),
309 "must insert sequentially: expected id {}, got {id}",
310 self.nodes.len()
311 );
312
313 let level = self.random_level();
314 let node = HnswNode {
315 vector: vector.to_vec(),
316 neighbors: (0..=level).map(|_| NeighborList::new()).collect(),
317 };
318 self.nodes.push(node);
319
320 if let Some(ref mut codes) = self.sq_codes {
322 codes.push(Self::encode_sq(vector, self.sq_params.0, self.sq_params.1));
323 }
324
325 if self.nodes.len() == 1 {
327 self.entry_point = Some(0);
328 self.max_level = level;
329 return;
330 }
331
332 let entry = self.entry_point.unwrap();
333 let mut current = entry;
334
335 for lev in (level + 1..=self.max_level).rev() {
337 current = self.greedy_closest(current, vector, lev);
338 }
339
340 let insert_from = level.min(self.max_level);
342 for lev in (0..=insert_from).rev() {
343 let neighbors = self.search_layer(current, vector, self.config.ef_construction, lev);
344
345 let max_n = self.max_neighbors(lev);
347 let mut selected: Vec<u32> = if self.config.use_heuristic {
348 optimized::select_neighbors_heuristic(
349 &self.metric,
350 &neighbors,
351 self.nodes.as_slice(),
352 max_n,
353 false,
354 )
355 } else {
356 neighbors.iter().take(max_n).map(|&(n, _)| n).collect()
357 };
358
359 if selected.is_empty() {
361 selected.push(current);
362 }
363
364 for &neighbor_id in &selected {
366 if neighbor_id == id {
368 continue;
369 }
370 if !self.nodes[id as usize].neighbors[lev].contains(&neighbor_id) {
372 self.nodes[id as usize].neighbors[lev].push(neighbor_id);
373 }
374 if !self.nodes[neighbor_id as usize].neighbors[lev].contains(&id) {
375 self.nodes[neighbor_id as usize].neighbors[lev].push(id);
376 }
377
378 let neighbor_count = self.nodes[neighbor_id as usize].neighbors[lev].len();
380 if neighbor_count > max_n {
381 self.prune_neighbors(neighbor_id, lev, max_n);
382 }
383 }
384
385 if let Some(&(closest, _)) = neighbors.first() {
387 current = closest;
388 }
389 }
390
391 if self.nodes[id as usize].neighbors[0].is_empty() {
394 let mut best_id = entry;
396 let mut best_dist = self.distance(entry, vector);
397 for i in 0..self.nodes.len() - 1 {
398 let d = self.distance(i as u32, vector);
399 if d < best_dist {
400 best_dist = d;
401 best_id = i as u32;
402 }
403 }
404 self.nodes[id as usize].neighbors[0].push(best_id);
405 self.nodes[best_id as usize].neighbors[0].push(id);
406 }
407
408 if level > self.max_level {
410 self.entry_point = Some(id);
411 self.max_level = level;
412 }
413 }
414
415 fn greedy_closest(&self, start: u32, query: &[f32], level: usize) -> u32 {
417 let query_code = self
418 .sq_codes
419 .as_ref()
420 .map(|_| Self::encode_sq(query, self.sq_params.0, self.sq_params.1));
421 let qc = query_code.as_deref();
422
423 let mut current = start;
424 let mut current_dist = self.distance_fast(current, qc, query);
425
426 loop {
427 let mut changed = false;
428 let neighbors = self.neighbors_at(current, level);
429 for &neighbor in neighbors {
430 let dist = self.distance_fast(neighbor, qc, query);
431 if dist < current_dist {
432 current = neighbor;
433 current_dist = dist;
434 changed = true;
435 }
436 }
437 if !changed {
438 return current;
439 }
440 }
441 }
442
443 fn search_layer(&self, entry: u32, query: &[f32], ef: usize, level: usize) -> Vec<(u32, f32)> {
448 let query_code = self
450 .sq_codes
451 .as_ref()
452 .map(|_| Self::encode_sq(query, self.sq_params.0, self.sq_params.1));
453 let qc = query_code.as_deref();
454
455 let entry_dist = self.distance_fast(entry, qc, query);
456
457 let mut candidates: BinaryHeap<Reverse<OrdF32Entry>> = BinaryHeap::new();
459 let mut results: BinaryHeap<OrdF32Entry> = BinaryHeap::new();
461 let mut visited = vec![false; self.nodes.len()];
463
464 candidates.push(Reverse(OrdF32Entry(entry_dist, entry)));
465 results.push(OrdF32Entry(entry_dist, entry));
466 visited[entry as usize] = true;
467
468 while let Some(Reverse(OrdF32Entry(c_dist, c_id))) = candidates.pop() {
469 let farthest_result = results.peek().map(|e| e.0).unwrap_or(f32::INFINITY);
471 if c_dist > farthest_result {
472 break;
473 }
474
475 let neighbors = self.neighbors_at(c_id, level);
476 for &neighbor in neighbors {
477 if visited[neighbor as usize] {
478 continue;
479 }
480 visited[neighbor as usize] = true;
481
482 let dist = self.distance_fast(neighbor, qc, query);
483 let farthest_result = results.peek().map(|e| e.0).unwrap_or(f32::INFINITY);
484
485 if dist < farthest_result || results.len() < ef {
486 candidates.push(Reverse(OrdF32Entry(dist, neighbor)));
487 results.push(OrdF32Entry(dist, neighbor));
488 if results.len() > ef {
489 results.pop(); }
491 }
492 }
493 }
494
495 let mut result_vec: Vec<(u32, f32)> = if self.sq_codes.is_some() {
497 results
498 .into_iter()
499 .map(|e| (e.1, self.distance(e.1, query)))
500 .collect()
501 } else {
502 results.into_iter().map(|e| (e.1, e.0)).collect()
503 };
504 result_vec.sort_by(|a, b| a.1.total_cmp(&b.1));
505 result_vec
506 }
507
508 fn prune_neighbors(&mut self, node_id: u32, level: usize, max_n: usize) {
513 let node_vec = self.nodes[node_id as usize].vector.clone();
514 let scored: Vec<(u32, f32)> = self.nodes[node_id as usize].neighbors[level]
515 .iter()
516 .map(|&n| {
517 (
518 n,
519 self.metric
520 .distance(&node_vec, &self.nodes[n as usize].vector),
521 )
522 })
523 .collect();
524
525 let pruned = if self.config.use_heuristic {
526 optimized::select_neighbors_heuristic(
527 &self.metric,
528 &scored,
529 self.nodes.as_slice(),
530 max_n,
531 false,
532 )
533 } else {
534 let mut s = scored;
535 s.sort_by(|a, b| a.1.total_cmp(&b.1));
536 s.truncate(max_n);
537 s.iter().map(|&(n, _)| n).collect()
538 };
539
540 self.nodes[node_id as usize].neighbors[level] = pruned.into_iter().collect();
541 }
542
543 pub fn search(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
547 if self.nodes.is_empty() {
548 return Vec::new();
549 }
550
551 let entry = self.entry_point.unwrap();
552 let mut current = entry;
553
554 for lev in (1..=self.max_level).rev() {
556 current = self.greedy_closest(current, query, lev);
557 }
558
559 let mut results = self.search_layer(current, query, self.config.ef_search.max(k), 0);
561 results.truncate(k);
562 results
563 }
564
565 pub fn search_filtered(
571 &self,
572 query: &[f32],
573 k: usize,
574 filter: impl Fn(u32) -> bool,
575 ) -> Vec<(u32, f32)> {
576 if self.nodes.is_empty() {
577 return Vec::new();
578 }
579
580 let entry = self.entry_point.unwrap();
581 let mut current = entry;
582
583 for lev in (1..=self.max_level).rev() {
585 current = self.greedy_closest(current, query, lev);
586 }
587
588 let ef = self.config.ef_search.max(k * 4); let all_candidates = self.search_layer(current, query, ef, 0);
591
592 let mut results: Vec<(u32, f32)> = all_candidates
594 .into_iter()
595 .filter(|&(id, _)| filter(id))
596 .collect();
597 results.truncate(k);
598 results
599 }
600
601 pub fn vector(&self, node_id: u32) -> &[f32] {
603 &self.nodes[node_id as usize].vector
604 }
605
606 pub fn config(&self) -> &HnswConfig {
608 &self.config
609 }
610
611 pub fn entry_point(&self) -> Option<u32> {
613 self.entry_point
614 }
615
616 pub fn max_level(&self) -> usize {
618 self.max_level
619 }
620
621 pub fn neighbors_at_level(&self, node_id: u32, level: usize) -> &[u32] {
623 self.neighbors_at(node_id, level)
624 }
625
626 pub fn distance_to(&self, node_id: u32, query: &[f32]) -> f32 {
628 self.distance(node_id, query)
629 }
630
631 pub fn all_neighbors(&self, node_id: u32) -> Vec<Vec<u32>> {
633 let node = &self.nodes[node_id as usize];
634 node.neighbors.iter().map(|n| n.to_vec()).collect()
635 }
636
637 pub fn nodes_at_level(&self, level: usize) -> Vec<u32> {
643 (0..self.nodes.len() as u32)
644 .filter(|&id| self.nodes[id as usize].neighbors.len() > level)
645 .collect()
646 }
647
648 pub fn assign_region(&self, vector: &[f32], level: usize) -> Option<u32> {
653 if self.nodes.is_empty() {
654 return None;
655 }
656
657 let entry = self.entry_point.unwrap();
658 let mut current = entry;
659
660 for lev in (level + 1..=self.max_level).rev() {
662 current = self.greedy_closest(current, vector, lev);
663 }
664
665 if level <= self.max_level {
667 current = self.greedy_closest(current, vector, level);
668 }
669
670 if self.nodes[current as usize].neighbors.len() > level {
672 Some(current)
673 } else {
674 let hubs = self.nodes_at_level(level);
676 hubs.into_iter().min_by(|&a, &b| {
677 self.distance(a, vector)
678 .total_cmp(&self.distance(b, vector))
679 })
680 }
681 }
682
683 #[inline]
689 fn distance(&self, node_id: u32, query: &[f32]) -> f32 {
690 self.metric
691 .distance(&self.nodes[node_id as usize].vector, query)
692 }
693
694 #[inline]
698 fn distance_fast(&self, node_id: u32, query_code: Option<&[u8]>, query: &[f32]) -> f32 {
699 if let (Some(codes), Some(qc)) = (&self.sq_codes, query_code) {
700 Self::distance_sq(&codes[node_id as usize], qc)
701 } else {
702 self.distance(node_id, query)
703 }
704 }
705
706 #[inline]
708 fn neighbors_at(&self, node_id: u32, level: usize) -> &[u32] {
709 let node = &self.nodes[node_id as usize];
710 if level < node.neighbors.len() {
711 &node.neighbors[level]
712 } else {
713 &[]
714 }
715 }
716
717 pub fn count_reachable(&self) -> usize {
721 if self.nodes.is_empty() {
722 return 0;
723 }
724 let entry = self.entry_point.unwrap();
725 let mut visited = vec![false; self.nodes.len()];
726 let mut stack = vec![entry];
727 visited[entry as usize] = true;
728 let mut count = 1usize;
729
730 while let Some(node) = stack.pop() {
731 for &neighbor in self.neighbors_at(node, 0) {
732 if !visited[neighbor as usize] {
733 visited[neighbor as usize] = true;
734 count += 1;
735 stack.push(neighbor);
736 }
737 }
738 }
739 count
740 }
741
742 pub fn brute_force_knn(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
746 let mut all: Vec<(u32, f32)> = self
747 .nodes
748 .iter()
749 .enumerate()
750 .map(|(i, node)| (i as u32, self.metric.distance(&node.vector, query)))
751 .collect();
752 all.sort_by(|a, b| a.1.total_cmp(&b.1));
753 all.truncate(k);
754 all
755 }
756}
757
758#[derive(Serialize, Deserialize)]
762pub(crate) struct HnswSnapshot {
763 pub(crate) config: HnswConfig,
764 pub(crate) nodes: Vec<HnswNode>,
765 pub(crate) entry_point: Option<u32>,
766 pub(crate) max_level: usize,
767 pub(crate) sq_codes: Option<Vec<Vec<u8>>>,
768 pub(crate) sq_params: (f32, f32),
769}
770
771impl<D: DistanceMetric> HnswGraph<D> {
772 pub(crate) fn to_snapshot(&self) -> HnswSnapshot {
774 HnswSnapshot {
775 config: self.config.clone(),
776 nodes: self
777 .nodes
778 .iter()
779 .map(|n| HnswNode {
780 vector: n.vector.clone(),
781 neighbors: n.neighbors.clone(),
782 })
783 .collect(),
784 entry_point: self.entry_point,
785 max_level: self.max_level,
786 sq_codes: self.sq_codes.clone(),
787 sq_params: self.sq_params,
788 }
789 }
790
791 pub(crate) fn from_snapshot(snapshot: HnswSnapshot, metric: D) -> Self {
793 Self {
794 config: snapshot.config,
795 metric,
796 nodes: snapshot.nodes,
797 entry_point: snapshot.entry_point,
798 max_level: snapshot.max_level,
799 rng: SmallRng::from_os_rng(),
800 sq_codes: snapshot.sq_codes,
801 sq_params: snapshot.sq_params,
802 }
803 }
804}
805
/// `(distance, node id)` pair with a total order, so it can live in a
/// `BinaryHeap`: ordered by `f32::total_cmp` on the distance, then by id.
#[derive(Clone, Copy)]
struct OrdF32Entry(f32, u32);

impl Ord for OrdF32Entry {
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        match self.0.total_cmp(&rhs.0) {
            std::cmp::Ordering::Equal => self.1.cmp(&rhs.1),
            unequal => unequal,
        }
    }
}

impl PartialOrd for OrdF32Entry {
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl PartialEq for OrdF32Entry {
    /// Bitwise equality on the distance (so `NaN == NaN`, `0.0 != -0.0`),
    /// consistent with the total order above.
    fn eq(&self, rhs: &Self) -> bool {
        (self.0.to_bits(), self.1) == (rhs.0.to_bits(), rhs.1)
    }
}

impl Eq for OrdF32Entry {}
829
/// Fraction of ground-truth ids recovered by an approximate result list.
///
/// Counts the entries of `approximate` whose id occurs in `ground_truth` and
/// divides by `ground_truth.len()` (or by 1 when that is 0). Distances are
/// ignored. Repeated ids in `approximate` are each counted.
pub fn recall_at_k(approximate: &[(u32, f32)], ground_truth: &[(u32, f32)]) -> f64 {
    let expected: std::collections::HashSet<u32> =
        ground_truth.iter().map(|pair| pair.0).collect();

    let mut hits = 0usize;
    for (id, _) in approximate {
        if expected.contains(id) {
            hits += 1;
        }
    }

    let denominator = ground_truth.len().max(1);
    hits as f64 / denominator as f64
}
840
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{CosineDistance, L2Distance};

    /// Builds an empty L2 graph with the given `m`, `ef_construction` and
    /// `ef_search`; everything else comes from `HnswConfig::default()`.
    fn make_graph(m: usize, ef_c: usize, ef_s: usize) -> HnswGraph<L2Distance> {
        HnswGraph::new(
            HnswConfig {
                m,
                ef_construction: ef_c,
                ef_search: ef_s,
                ..Default::default()
            },
            L2Distance,
        )
    }

    /// One vector of `dim` uniformly random `f32` components.
    fn random_vector(rng: &mut impl Rng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.random::<f32>()).collect()
    }

    /// Generates `n` random vectors up front, then inserts them as ids `0..n`.
    fn fill_random(graph: &mut HnswGraph<L2Distance>, rng: &mut impl Rng, n: u32, dim: usize) {
        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vector(&mut *rng, dim)).collect();
        for (i, v) in vectors.iter().enumerate() {
            graph.insert(i as u32, v);
        }
    }

    /// Average recall@k of `search` against `brute_force_knn` over
    /// `n_queries` random queries.
    fn mean_recall(
        graph: &HnswGraph<L2Distance>,
        rng: &mut impl Rng,
        dim: usize,
        k: usize,
        n_queries: usize,
    ) -> f64 {
        let mut total = 0.0;
        for _ in 0..n_queries {
            let query = random_vector(&mut *rng, dim);
            let approx = graph.search(&query, k);
            let truth = graph.brute_force_knn(&query, k);
            total += recall_at_k(&approx, &truth);
        }
        total / n_queries as f64
    }

    #[test]
    fn empty_graph() {
        let graph = make_graph(16, 200, 50);
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);
        assert!(graph.search(&[1.0, 2.0], 5).is_empty());
    }

    #[test]
    fn single_insert_and_search() {
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &[1.0, 0.0, 0.0]);

        let results = graph.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn three_vectors_correct_order() {
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &[1.0, 0.0]);
        graph.insert(1, &[0.9, 0.1]);
        graph.insert(2, &[0.0, 1.0]);

        let results = graph.search(&[1.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        // Expected ranking: exact match, near match, orthogonal vector.
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn all_nodes_reachable_100() {
        let mut graph = make_graph(16, 200, 50);
        for i in 0..100u32 {
            graph.insert(i, &[i as f32, (100 - i) as f32]);
        }
        assert_eq!(graph.count_reachable(), 100);
    }

    #[test]
    fn all_nodes_reachable_1000() {
        let mut graph = make_graph(16, 200, 50);
        for i in 0..1000u32 {
            let angle = (i as f32) * 0.1;
            graph.insert(i, &[angle.cos(), angle.sin()]);
        }
        assert_eq!(graph.count_reachable(), 1000);
    }

    #[test]
    fn recall_at_10_random_1k_d32() {
        let dim = 32;
        let mut graph = make_graph(16, 200, 50);
        let mut rng = rand::rng();
        fill_random(&mut graph, &mut rng, 1000, dim);

        let avg_recall = mean_recall(&graph, &mut rng, dim, 10, 100);
        assert!(
            avg_recall >= 0.90,
            "recall@10 = {avg_recall:.3}, expected >= 0.90"
        );
    }

    #[test]
    fn recall_at_10_random_10k_d128() {
        let dim = 128;
        let n = 10_000u32;
        let k = 10;
        let mut graph = make_graph(16, 200, 200);

        let mut rng = rand::rng();
        fill_random(&mut graph, &mut rng, n, dim);

        let reachable = graph.count_reachable();
        assert!(
            reachable >= (n as usize) * 98 / 100,
            "reachable: {reachable} / {n}, expected >= 98%"
        );

        let avg_recall = mean_recall(&graph, &mut rng, dim, k, 50);
        assert!(
            avg_recall >= 0.85,
            "recall@10 on 10K D=128 = {avg_recall:.3}, expected >= 0.85"
        );
    }

    #[test]
    fn works_with_cosine_distance() {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, CosineDistance);

        graph.insert(0, &[1.0, 0.0, 0.0]);
        graph.insert(1, &[0.99, 0.01, 0.0]);
        graph.insert(2, &[0.0, 0.0, 1.0]);

        let results = graph.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
    }

    #[test]
    fn search_k_larger_than_n() {
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &[1.0, 0.0]);
        graph.insert(1, &[0.0, 1.0]);

        // Asking for more than exists returns everything there is.
        let results = graph.search(&[1.0, 0.0], 10);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn recall_helper_correct() {
        let approx = vec![(0, 0.1), (1, 0.2), (2, 0.3)];
        let truth = vec![(0, 0.1), (1, 0.2), (3, 0.25)];
        assert!((recall_at_k(&approx, &truth) - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn vector_accessor() {
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &[1.0, 2.0, 3.0]);
        graph.insert(1, &[4.0, 5.0, 6.0]);
        assert_eq!(graph.vector(0), &[1.0, 2.0, 3.0]);
        assert_eq!(graph.vector(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn nodes_at_level() {
        let mut graph = make_graph(4, 50, 50);
        for i in 0..200u32 {
            graph.insert(i, &[i as f32, (200 - i) as f32]);
        }

        let l0 = graph.nodes_at_level(0);
        assert_eq!(l0.len(), 200);

        let l1 = graph.nodes_at_level(1);
        assert!(!l1.is_empty(), "should have some level-1 nodes");
        assert!(l1.len() < 200, "level 1 should be sparser than level 0");
    }

    #[test]
    fn assign_region_returns_hub() {
        let mut graph = make_graph(4, 50, 50);
        let mut rng = rand::rng();
        for i in 0..200u32 {
            let v = random_vector(&mut rng, 8);
            graph.insert(i, &v);
        }

        let query = random_vector(&mut rng, 8);
        let hub = graph.assign_region(&query, 1);
        assert!(hub.is_some(), "should find a hub at level 1");

        // Whatever is returned must actually live on level 1.
        let level1_nodes = graph.nodes_at_level(1);
        assert!(level1_nodes.contains(&hub.unwrap()));
    }

    #[test]
    fn search_filtered_respects_predicate() {
        let mut graph = make_graph(16, 200, 100);
        for i in 0..100u32 {
            graph.insert(i, &[i as f32, 0.0]);
        }

        let results = graph.search_filtered(&[50.0, 0.0], 5, |id| id % 2 == 0);
        assert_eq!(results.len(), 5);
        for &(id, _) in &results {
            assert_eq!(id % 2, 0, "node {id} should be even");
        }
    }

    #[test]
    fn snapshot_round_trip() {
        let mut graph = make_graph(16, 100, 50);
        for i in 0..50u32 {
            graph.insert(i, &[i as f32, (50 - i) as f32]);
        }

        let restored = HnswGraph::from_snapshot(graph.to_snapshot(), L2Distance);

        assert_eq!(restored.len(), 50);
        assert_eq!(restored.vector(0), &[0.0, 50.0]);
        assert_eq!(restored.vector(49), &[49.0, 1.0]);

        // The restored graph must still answer queries.
        let results = restored.search(&[25.0, 25.0], 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn distance_to_node() {
        let mut graph = make_graph(16, 200, 50);
        graph.insert(0, &[1.0, 0.0]);
        graph.insert(1, &[0.0, 1.0]);

        let d = graph.distance_to(0, &[1.0, 0.0]);
        assert!(d < 1e-5, "distance to self should be ~0, got {d}");

        let d2 = graph.distance_to(1, &[1.0, 0.0]);
        assert!(d2 > 1.0, "distance to orthogonal should be > 1, got {d2}");
    }

    #[test]
    #[ignore]
    fn recall_100k_d128() {
        let dim = 128;
        let n = 100_000u32;
        let k = 10;
        let mut graph = make_graph(16, 200, 100);

        let mut rng = rand::rng();
        fill_random(&mut graph, &mut rng, n, dim);

        assert_eq!(
            graph.count_reachable(),
            n as usize,
            "not all nodes reachable"
        );

        let avg_recall = mean_recall(&graph, &mut rng, dim, k, 100);
        assert!(
            avg_recall >= 0.95,
            "recall@10 on 100K D=128 = {avg_recall:.3}, expected >= 0.95"
        );
    }
}