cvx_index/hnsw/
optimized.rs

1//! Advanced HNSW optimizations.
2//!
3//! ## Heuristic Neighbor Selection (Malkov §4.2)
4//!
5//! Instead of keeping the M closest neighbors, selects neighbors that provide
6//! good graph connectivity by preferring diverse directions over raw proximity.
7//!
8//! ## Time-Decay Edge Weights
9//!
10//! Edge weights decay exponentially with age: `w(e, t) = w0 * exp(-λ * age)`.
11//! During search, decayed weights penalize stale connections.
12//!
13//! ## Backup Neighbors
14//!
15//! Each node maintains a secondary neighbor list used when primary neighbors
16//! are removed (e.g., node expiration in streaming scenarios).
17//!
18//! ## Index Persistence
19//!
20//! Serialize/deserialize the HNSW graph for crash recovery and snapshots.
21
22use cvx_core::DistanceMetric;
23
24use super::HnswGraph;
25
/// Read-only access to node vectors by ID.
///
/// Lets callers such as [`select_neighbors_heuristic`] borrow vectors from
/// whatever storage already holds them, avoiding full-collection clones.
pub trait NodeVectors {
    /// Get the vector for a given node ID.
    ///
    /// How an unknown `id` is handled is up to the implementor; the
    /// `[Vec<f32>]` implementation in this file indexes directly and so
    /// panics when `id` is out of bounds.
    fn get_vector(&self, id: u32) -> &[f32];
}
38
39/// Implementation for a slice of Vec<f32>.
40impl NodeVectors for [Vec<f32>] {
41    fn get_vector(&self, id: u32) -> &[f32] {
42        &self[id as usize]
43    }
44}
45
46/// Select neighbors using the heuristic from Malkov & Yashunin (2018) §4.2.
47///
48/// Greedily selects neighbors closer to target than to any already-selected
49/// neighbor, producing a diverse set with better graph connectivity.
50///
51/// Returns at most `m` neighbor IDs.
52pub fn select_neighbors_heuristic<D: DistanceMetric, N: NodeVectors + ?Sized>(
53    metric: &D,
54    candidates: &[(u32, f32)], // (node_id, distance_to_target)
55    node_vectors: &N,
56    m: usize,
57    extend_candidates: bool,
58) -> Vec<u32> {
59    if candidates.is_empty() || m == 0 {
60        return Vec::new();
61    }
62
63    // Working set sorted by distance (closest first)
64    let mut working: Vec<(u32, f32)> = candidates.to_vec();
65    working.sort_by(|a, b| a.1.total_cmp(&b.1));
66
67    let mut selected: Vec<u32> = Vec::with_capacity(m);
68    let mut selected_vectors: Vec<&[f32]> = Vec::with_capacity(m);
69
70    for &(cand_id, cand_dist) in &working {
71        if selected.len() >= m {
72            break;
73        }
74
75        // Check if candidate is closer to target than to any selected neighbor.
76        // Strict inequality per Malkov Algorithm 4 (RFC-002-03): forces
77        // genuinely diverse directions instead of accepting equidistant candidates.
78        let cand_vec = node_vectors.get_vector(cand_id);
79        let is_good = selected_vectors.iter().all(|&sel_vec| {
80            let dist_to_selected = metric.distance(cand_vec, sel_vec);
81            cand_dist < dist_to_selected
82        });
83
84        if is_good || (extend_candidates && selected.len() < m / 2) {
85            selected.push(cand_id);
86            selected_vectors.push(cand_vec);
87        }
88    }
89
90    // If we didn't fill up, add closest remaining candidates
91    if selected.len() < m {
92        for &(cand_id, _) in &working {
93            if selected.len() >= m {
94                break;
95            }
96            if !selected.contains(&cand_id) {
97                selected.push(cand_id);
98            }
99        }
100    }
101
102    selected
103}
104
/// Time-decay weight for an edge.
///
/// Returns `exp(-lambda * age)`, where `age = current_time - edge_timestamp`
/// is expressed in the same units as the timestamps.
///
/// * An edge stamped in the future (`edge_timestamp > current_time`) is
///   treated as having age `0`, so its weight is `1.0`.
/// * The age is computed with a saturating subtraction, so extreme timestamp
///   pairs (e.g. `i64::MIN` vs `i64::MAX`) cannot overflow; they clamp to the
///   largest representable age instead.
pub fn time_decay_weight(edge_timestamp: i64, current_time: i64, lambda: f64) -> f32 {
    // A plain `current_time - edge_timestamp` overflows `i64` when the two
    // timestamps are far apart with opposite signs.
    let age = current_time.saturating_sub(edge_timestamp).max(0) as f64;
    (-lambda * age).exp() as f32
}
112
113/// Apply time-decay to a distance score.
114///
115/// The effective distance increases for stale edges:
116/// `d_effective = d_raw / decay_weight`
117pub fn decay_adjusted_distance(
118    raw_distance: f32,
119    edge_timestamp: i64,
120    current_time: i64,
121    lambda: f64,
122) -> f32 {
123    let weight = time_decay_weight(edge_timestamp, current_time, lambda);
124    if weight > 1e-10 {
125        raw_distance / weight
126    } else {
127        f32::INFINITY
128    }
129}
130
/// Backup neighbor storage for handling node expiration.
///
/// Keeps two lists per node: the `primary` neighbors and a `backup` list
/// that is drawn from when primaries are reported as expired.
#[derive(Debug, Clone)]
pub struct BackupNeighbors {
    /// Primary neighbors (from HNSW construction).
    pub primary: Vec<u32>,
    /// Backup neighbors (next-best candidates), in preference order.
    pub backup: Vec<u32>,
}

impl BackupNeighbors {
    /// Create with primary and backup lists.
    pub fn new(primary: Vec<u32>, backup: Vec<u32>) -> Self {
        Self { primary, backup }
    }

    /// Get active neighbors, replacing expired primaries with backups.
    ///
    /// Non-expired primaries are returned first, in their original order.
    /// Backups are then appended in order until the result is as long as the
    /// primary list or the backups run out. A backup that is expired, or that
    /// is already in the result, is skipped and does not use up a
    /// replacement slot.
    ///
    /// `is_expired` is called for every primary, and for backups only while
    /// replacements are still needed.
    pub fn active_neighbors(&self, is_expired: &dyn Fn(u32) -> bool) -> Vec<u32> {
        let mut active: Vec<u32> = self
            .primary
            .iter()
            .copied()
            .filter(|&id| !is_expired(id))
            .collect();

        // Refill up to the original primary count. The length check comes
        // first so no backup is inspected once enough neighbors are present.
        let target = self.primary.len();
        for &candidate in &self.backup {
            if active.len() >= target {
                break;
            }
            if !is_expired(candidate) && !active.contains(&candidate) {
                active.push(candidate);
            }
        }

        active
    }
}
166
/// Serialized HNSW graph for persistence.
///
/// Plain-data snapshot of a graph: its vectors, per-level neighbor lists and
/// top-level metadata. Derives serde's `Serialize`/`Deserialize`, so it can
/// be written with any serde-compatible format.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializedGraph {
    /// The `m` value taken from the graph's configuration (presumably the
    /// HNSW max-connections parameter; confirm against the config type).
    pub m: usize,
    /// Number of nodes.
    pub num_nodes: usize,
    /// Entry point node ID (`None` when the graph reports no entry point).
    pub entry_point: Option<u32>,
    /// Maximum level in the graph.
    pub max_level: usize,
    /// All vectors (flattened: `[node_0_dim_0, node_0_dim_1, ..., node_1_dim_0, ...]`).
    pub vectors: Vec<f32>,
    /// Vector dimensionality (`0` for an empty graph).
    pub dim: usize,
    /// Neighbor lists: `[node_0_level_0, node_0_level_1, ..., node_1_level_0, ...]`
    /// Encoded as: for each node, number of levels, then for each level, count + neighbor IDs.
    pub neighbors: Vec<u32>,
}
188
189/// Serialize an HNSW graph to a portable format.
190pub fn serialize_graph<D: DistanceMetric>(graph: &HnswGraph<D>) -> SerializedGraph {
191    let num_nodes = graph.len();
192    let dim = if num_nodes > 0 {
193        graph.vector(0).len()
194    } else {
195        0
196    };
197
198    let mut vectors = Vec::with_capacity(num_nodes * dim);
199    let mut neighbors = Vec::new();
200
201    for i in 0..num_nodes {
202        let id = i as u32;
203        vectors.extend_from_slice(graph.vector(id));
204
205        // Encode neighbor lists for this node
206        let node_neighbors = graph.all_neighbors(id);
207        neighbors.push(node_neighbors.len() as u32); // number of levels
208        for level_neighbors in &node_neighbors {
209            neighbors.push(level_neighbors.len() as u32); // count
210            neighbors.extend(level_neighbors.iter());
211        }
212    }
213
214    SerializedGraph {
215        m: graph.config().m,
216        num_nodes,
217        entry_point: graph.entry_point(),
218        max_level: graph.max_level(),
219        vectors,
220        dim,
221        neighbors,
222    }
223}
224
// Unit tests for the neighbor-selection heuristic, the time-decay helpers,
// backup-neighbor substitution and graph serialization.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::L2Distance;

    // ─── Heuristic neighbor selection ───────────────────────────────

    // Two near-identical candidates (0 and 1) compete with two candidates in
    // other directions; with m = 3 the result must contain nodes 2 and 3.
    #[test]
    fn heuristic_selects_diverse_neighbors() {
        let metric = L2Distance;
        let target = vec![0.0, 0.0];
        let node_vectors = vec![
            vec![1.0, 0.0],   // 0: right, dist=1.0
            vec![0.95, 0.05], // 1: almost same as 0, dist≈0.91
            vec![0.0, 1.0],   // 2: up, dist=1.0
            vec![-1.0, 0.0],  // 3: left, dist=1.0
        ];

        // Candidate list pairs each node ID with its distance to the target.
        let candidates = vec![
            (0, metric.distance(&target, &node_vectors[0])),
            (1, metric.distance(&target, &node_vectors[1])),
            (2, metric.distance(&target, &node_vectors[2])),
            (3, metric.distance(&target, &node_vectors[3])),
        ];

        let selected =
            select_neighbors_heuristic(&metric, &candidates, node_vectors.as_slice(), 3, false);
        assert_eq!(selected.len(), 3);

        // The heuristic should exclude one of {0, 1} since they're in the same direction
        // Node 1 is closest, so it gets selected first. Then 2 and 3 add diversity.
        assert!(
            selected.contains(&2),
            "should include node 2 (up direction)"
        );
        assert!(
            selected.contains(&3),
            "should include node 3 (left direction)"
        );
    }

    // No candidates in, no neighbors out.
    #[test]
    fn heuristic_empty_candidates() {
        let metric = L2Distance;
        let empty: &[Vec<f32>] = &[];
        let result = select_neighbors_heuristic(&metric, &[], empty, 5, false);
        assert!(result.is_empty());
    }

    // When m exceeds the candidate count, every candidate is returned.
    #[test]
    fn heuristic_fewer_candidates_than_m() {
        let metric = L2Distance;
        let vectors = vec![vec![1.0], vec![2.0]];
        let candidates = vec![(0, 1.0), (1, 2.0)];
        let selected =
            select_neighbors_heuristic(&metric, &candidates, vectors.as_slice(), 5, false);
        assert_eq!(selected.len(), 2); // only 2 available
    }

    // ─── Time-decay weights ─────────────────────────────────────────

    // Zero age gives a weight of exp(0) = 1.
    #[test]
    fn decay_weight_same_time_is_one() {
        let w = time_decay_weight(1000, 1000, 0.001);
        assert!((w - 1.0).abs() < 1e-6);
    }

    // Same lambda, different ages: the older edge gets the smaller weight.
    #[test]
    fn decay_weight_decreases_with_age() {
        let w1 = time_decay_weight(900, 1000, 0.01);
        let w2 = time_decay_weight(500, 1000, 0.01);
        assert!(w1 > w2, "newer edge should have higher weight");
    }

    // Same age, different lambdas: the larger lambda gives the smaller weight.
    #[test]
    fn decay_weight_high_lambda_fast_decay() {
        let w_slow = time_decay_weight(0, 1000, 0.001);
        let w_fast = time_decay_weight(0, 1000, 0.01);
        assert!(w_slow > w_fast, "higher lambda should decay faster");
    }

    // Dividing by a smaller weight makes the effective distance larger.
    #[test]
    fn decay_adjusted_distance_increases_with_age() {
        let d_new = decay_adjusted_distance(1.0, 900, 1000, 0.01);
        let d_old = decay_adjusted_distance(1.0, 0, 1000, 0.01);
        assert!(
            d_old > d_new,
            "older edges should have larger effective distance"
        );
    }

    // ─── Backup neighbors ───────────────────────────────────────────

    // One expired primary is replaced by the first backup.
    #[test]
    fn backup_fills_expired_primary() {
        let bn = BackupNeighbors::new(vec![1, 2, 3], vec![10, 11]);

        // Node 2 expired
        let active = bn.active_neighbors(&|id| id == 2);
        assert_eq!(active.len(), 3); // 1, 3 from primary + 10 from backup
        assert!(active.contains(&1));
        assert!(active.contains(&3));
        assert!(active.contains(&10));
        assert!(!active.contains(&2));
    }

    // With nothing expired the primary list is returned unchanged.
    #[test]
    fn backup_no_expired() {
        let bn = BackupNeighbors::new(vec![1, 2, 3], vec![10]);
        let active = bn.active_neighbors(&|_| false);
        assert_eq!(active, vec![1, 2, 3]);
    }

    // Both primaries expired: only as many backups as there were primaries
    // are used, so backup 12 is left out.
    #[test]
    fn backup_all_expired_uses_all_backups() {
        let bn = BackupNeighbors::new(vec![1, 2], vec![10, 11, 12]);
        let active = bn.active_neighbors(&|id| id <= 2);
        assert_eq!(active, vec![10, 11]);
    }

    // ─── Serialization ──────────────────────────────────────────────

    // An empty graph serializes to zero nodes and no entry point.
    #[test]
    fn serialize_empty_graph() {
        let config = super::super::HnswConfig::default();
        let graph = HnswGraph::new(config, L2Distance);
        let serialized = serialize_graph(&graph);
        assert_eq!(serialized.num_nodes, 0);
        assert_eq!(serialized.entry_point, None);
    }

    // Inserting 50 two-dimensional vectors must be reflected in the counts,
    // the dimension, the config value and the flattened vector data.
    #[test]
    fn serialize_graph_preserves_data() {
        let config = super::super::HnswConfig {
            m: 8,
            ef_construction: 100,
            ef_search: 50,
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, L2Distance);

        for i in 0..50u32 {
            graph.insert(i, &[i as f32, (50 - i) as f32]);
        }

        let serialized = serialize_graph(&graph);
        assert_eq!(serialized.num_nodes, 50);
        assert_eq!(serialized.dim, 2);
        assert_eq!(serialized.vectors.len(), 100); // 50 * 2
        assert!(serialized.entry_point.is_some());
        assert_eq!(serialized.m, 8);

        // Verify first vector
        assert!((serialized.vectors[0] - 0.0).abs() < 1e-6);
        assert!((serialized.vectors[1] - 50.0).abs() < 1e-6);
    }

    // The snapshot survives a JSON encode/decode cycle unchanged.
    #[test]
    fn serialized_graph_is_serde_roundtrip() {
        let config = super::super::HnswConfig {
            m: 4,
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, L2Distance);
        for i in 0..10u32 {
            graph.insert(i, &[i as f32]);
        }

        let serialized = serialize_graph(&graph);
        let json = serde_json::to_string(&serialized).unwrap();
        let recovered: SerializedGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(recovered.num_nodes, 10);
        assert_eq!(recovered.vectors, serialized.vectors);
        assert_eq!(recovered.neighbors, serialized.neighbors);
    }
}