1use cvx_core::DistanceMetric;
23
24use super::HnswGraph;
25
/// Read-only lookup of a node's vector by its numeric id.
///
/// Neighbour selection is generic over this trait, so it does not care how
/// the vectors are actually stored.
pub trait NodeVectors {
    /// Returns the vector stored for node `id`.
    ///
    /// Implementations may panic for ids they do not know about (the slice
    /// implementation below panics on an out-of-bounds index).
    fn get_vector(&self, id: u32) -> &[f32];
}

/// Dense storage: node `id` is simply the `id`-th element of the slice.
impl NodeVectors for [Vec<f32>] {
    fn get_vector(&self, id: u32) -> &[f32] {
        let index = id as usize;
        self[index].as_slice()
    }
}
45
46pub fn select_neighbors_heuristic<D: DistanceMetric, N: NodeVectors + ?Sized>(
53 metric: &D,
54 candidates: &[(u32, f32)], node_vectors: &N,
56 m: usize,
57 extend_candidates: bool,
58) -> Vec<u32> {
59 if candidates.is_empty() || m == 0 {
60 return Vec::new();
61 }
62
63 let mut working: Vec<(u32, f32)> = candidates.to_vec();
65 working.sort_by(|a, b| a.1.total_cmp(&b.1));
66
67 let mut selected: Vec<u32> = Vec::with_capacity(m);
68 let mut selected_vectors: Vec<&[f32]> = Vec::with_capacity(m);
69
70 for &(cand_id, cand_dist) in &working {
71 if selected.len() >= m {
72 break;
73 }
74
75 let cand_vec = node_vectors.get_vector(cand_id);
79 let is_good = selected_vectors.iter().all(|&sel_vec| {
80 let dist_to_selected = metric.distance(cand_vec, sel_vec);
81 cand_dist < dist_to_selected
82 });
83
84 if is_good || (extend_candidates && selected.len() < m / 2) {
85 selected.push(cand_id);
86 selected_vectors.push(cand_vec);
87 }
88 }
89
90 if selected.len() < m {
92 for &(cand_id, _) in &working {
93 if selected.len() >= m {
94 break;
95 }
96 if !selected.contains(&cand_id) {
97 selected.push(cand_id);
98 }
99 }
100 }
101
102 selected
103}
104
/// Exponential time-decay weight `exp(-lambda * age)` for an edge.
///
/// `age` is `current_time - edge_timestamp`, clamped to zero. Edges stamped
/// at or after `current_time` therefore get full weight `1.0`, and older
/// edges decay towards `0.0` at a rate set by `lambda`.
///
/// The subtraction saturates instead of overflowing. Extreme timestamps
/// (e.g. `i64::MIN`/`i64::MAX` sentinels) cannot panic in debug builds or
/// wrap to a negative age in release builds. A wrapped age would be clamped
/// to zero and would wrongly give an ancient edge full weight.
pub fn time_decay_weight(edge_timestamp: i64, current_time: i64, lambda: f64) -> f32 {
    let age = current_time.saturating_sub(edge_timestamp).max(0) as f64;
    (-lambda * age).exp() as f32
}
112
113pub fn decay_adjusted_distance(
118 raw_distance: f32,
119 edge_timestamp: i64,
120 current_time: i64,
121 lambda: f64,
122) -> f32 {
123 let weight = time_decay_weight(edge_timestamp, current_time, lambda);
124 if weight > 1e-10 {
125 raw_distance / weight
126 } else {
127 f32::INFINITY
128 }
129}
130
/// A node's primary neighbour list plus an ordered list of backup neighbours
/// used to replace primaries that have expired.
#[derive(Debug, Clone)]
pub struct BackupNeighbors {
    /// Preferred neighbour ids.
    pub primary: Vec<u32>,
    /// Fallback neighbour ids, consulted in order when primaries expire.
    pub backup: Vec<u32>,
}

impl BackupNeighbors {
    /// Creates a neighbour set from primary and backup id lists.
    pub fn new(primary: Vec<u32>, backup: Vec<u32>) -> Self {
        Self { primary, backup }
    }

    /// Returns the currently usable neighbour ids.
    ///
    /// The non-expired primaries come first, in their original order. Live
    /// backups are then appended in backup order until the list again has
    /// `primary.len()` entries, or until the backups run out. Expired backups
    /// and ids already in the list are skipped and do not use up a
    /// replacement slot.
    pub fn active_neighbors(&self, is_expired: &dyn Fn(u32) -> bool) -> Vec<u32> {
        let target = self.primary.len();
        let mut active: Vec<u32> = self
            .primary
            .iter()
            .copied()
            .filter(|&id| !is_expired(id))
            .collect();

        // Scan all backups instead of only the first `target - active.len()`
        // of them. Otherwise a dead or duplicate backup near the front would
        // hide live ones further down the list.
        for &candidate in &self.backup {
            if active.len() >= target {
                break;
            }
            if !is_expired(candidate) && !active.contains(&candidate) {
                active.push(candidate);
            }
        }

        active
    }
}
166
/// Flat, serde-serializable snapshot of an HNSW graph, as produced by
/// `serialize_graph`.
///
/// NOTE(review): `derive(Serialize)` emits fields in declaration order, so
/// reordering or renaming fields changes the serialized format (and breaks
/// order-sensitive formats outright).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializedGraph {
    /// The `m` value from the graph's config.
    pub m: usize,
    /// Number of nodes in the graph.
    pub num_nodes: usize,
    /// Entry-point node id; `None` for an empty graph.
    pub entry_point: Option<u32>,
    /// The graph's `max_level()` at serialization time.
    pub max_level: usize,
    /// All node vectors concatenated in node-id order (ids `0..num_nodes`).
    /// Node `i` occupies `vectors[i * dim..(i + 1) * dim]`, assuming every
    /// vector has the same length as node 0's — TODO confirm this invariant.
    pub vectors: Vec<f32>,
    /// Length of node 0's vector; 0 when the graph is empty.
    pub dim: usize,
    /// Adjacency lists as one flat `u32` stream. For each node in id order:
    /// the number of levels `L`, then for each of the `L` levels a neighbour
    /// count `k` followed by `k` neighbour ids.
    pub neighbors: Vec<u32>,
}
188
189pub fn serialize_graph<D: DistanceMetric>(graph: &HnswGraph<D>) -> SerializedGraph {
191 let num_nodes = graph.len();
192 let dim = if num_nodes > 0 {
193 graph.vector(0).len()
194 } else {
195 0
196 };
197
198 let mut vectors = Vec::with_capacity(num_nodes * dim);
199 let mut neighbors = Vec::new();
200
201 for i in 0..num_nodes {
202 let id = i as u32;
203 vectors.extend_from_slice(graph.vector(id));
204
205 let node_neighbors = graph.all_neighbors(id);
207 neighbors.push(node_neighbors.len() as u32); for level_neighbors in &node_neighbors {
209 neighbors.push(level_neighbors.len() as u32); neighbors.extend(level_neighbors.iter());
211 }
212 }
213
214 SerializedGraph {
215 m: graph.config().m,
216 num_nodes,
217 entry_point: graph.entry_point(),
218 max_level: graph.max_level(),
219 vectors,
220 dim,
221 neighbors,
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::L2Distance;

    // Nodes 0 and 1 point in almost the same direction from the origin,
    // while node 2 points up and node 3 points left. With m = 3 the
    // heuristic must keep both off-axis directions, so exactly one of the
    // near-duplicates 0/1 ends up selected.
    #[test]
    fn heuristic_selects_diverse_neighbors() {
        let metric = L2Distance;
        let target = vec![0.0, 0.0];
        let node_vectors = vec![
            vec![1.0, 0.0],   // right
            vec![0.95, 0.05], // nearly the same direction as node 0
            vec![0.0, 1.0],   // up
            vec![-1.0, 0.0],  // left
        ];

        // (id, distance-to-target) pairs, the candidate format the heuristic takes.
        let candidates = vec![
            (0, metric.distance(&target, &node_vectors[0])),
            (1, metric.distance(&target, &node_vectors[1])),
            (2, metric.distance(&target, &node_vectors[2])),
            (3, metric.distance(&target, &node_vectors[3])),
        ];

        let selected =
            select_neighbors_heuristic(&metric, &candidates, node_vectors.as_slice(), 3, false);
        assert_eq!(selected.len(), 3);

        // The directions that cover other parts of the space must survive.
        assert!(
            selected.contains(&2),
            "should include node 2 (up direction)"
        );
        assert!(
            selected.contains(&3),
            "should include node 3 (left direction)"
        );
    }

    // An empty candidate list short-circuits to an empty result.
    #[test]
    fn heuristic_empty_candidates() {
        let metric = L2Distance;
        let empty: &[Vec<f32>] = &[];
        let result = select_neighbors_heuristic(&metric, &[], empty, 5, false);
        assert!(result.is_empty());
    }

    // With fewer candidates than m, every candidate is returned and the
    // result is not padded up to m.
    #[test]
    fn heuristic_fewer_candidates_than_m() {
        let metric = L2Distance;
        let vectors = vec![vec![1.0], vec![2.0]];
        let candidates = vec![(0, 1.0), (1, 2.0)];
        let selected =
            select_neighbors_heuristic(&metric, &candidates, vectors.as_slice(), 5, false);
        assert_eq!(selected.len(), 2);
    }

    // Zero age gives exp(0) == 1 regardless of lambda.
    #[test]
    fn decay_weight_same_time_is_one() {
        let w = time_decay_weight(1000, 1000, 0.001);
        assert!((w - 1.0).abs() < 1e-6);
    }

    // Weight is monotonically decreasing in edge age.
    #[test]
    fn decay_weight_decreases_with_age() {
        let w1 = time_decay_weight(900, 1000, 0.01);
        let w2 = time_decay_weight(500, 1000, 0.01);
        assert!(w1 > w2, "newer edge should have higher weight");
    }

    // For a fixed age, a larger lambda yields a smaller weight.
    #[test]
    fn decay_weight_high_lambda_fast_decay() {
        let w_slow = time_decay_weight(0, 1000, 0.001);
        let w_fast = time_decay_weight(0, 1000, 0.01);
        assert!(w_slow > w_fast, "higher lambda should decay faster");
    }

    // Dividing by a smaller weight inflates the effective distance of older edges.
    #[test]
    fn decay_adjusted_distance_increases_with_age() {
        let d_new = decay_adjusted_distance(1.0, 900, 1000, 0.01);
        let d_old = decay_adjusted_distance(1.0, 0, 1000, 0.01);
        assert!(
            d_old > d_new,
            "older edges should have larger effective distance"
        );
    }

    // An expired primary (2) is replaced by the first backup (10), keeping
    // the neighbour count at primary.len().
    #[test]
    fn backup_fills_expired_primary() {
        let bn = BackupNeighbors::new(vec![1, 2, 3], vec![10, 11]);

        let active = bn.active_neighbors(&|id| id == 2);
        assert_eq!(active.len(), 3);
        assert!(active.contains(&1));
        assert!(active.contains(&3));
        assert!(active.contains(&10));
        assert!(!active.contains(&2));
    }

    // Nothing expired: the primaries are returned unchanged and no backup is used.
    #[test]
    fn backup_no_expired() {
        let bn = BackupNeighbors::new(vec![1, 2, 3], vec![10]);
        let active = bn.active_neighbors(&|_| false);
        assert_eq!(active, vec![1, 2, 3]);
    }

    // All primaries expired: backups fill in, but only up to primary.len()
    // (so 12 is not used).
    #[test]
    fn backup_all_expired_uses_all_backups() {
        let bn = BackupNeighbors::new(vec![1, 2], vec![10, 11, 12]);
        let active = bn.active_neighbors(&|id| id <= 2);
        assert_eq!(active, vec![10, 11]);
    }

    // An empty graph serializes to zero nodes and no entry point.
    #[test]
    fn serialize_empty_graph() {
        let config = super::super::HnswConfig::default();
        let graph = HnswGraph::new(config, L2Distance);
        let serialized = serialize_graph(&graph);
        assert_eq!(serialized.num_nodes, 0);
        assert_eq!(serialized.entry_point, None);
    }

    // Node i is inserted with vector [i, 50 - i]. The snapshot must keep the
    // count, dimension, config m, and the flat vector layout (node 0 -> [0, 50]).
    #[test]
    fn serialize_graph_preserves_data() {
        let config = super::super::HnswConfig {
            m: 8,
            ef_construction: 100,
            ef_search: 50,
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, L2Distance);

        for i in 0..50u32 {
            graph.insert(i, &[i as f32, (50 - i) as f32]);
        }

        let serialized = serialize_graph(&graph);
        assert_eq!(serialized.num_nodes, 50);
        assert_eq!(serialized.dim, 2);
        assert_eq!(serialized.vectors.len(), 100); // 50 nodes * 2 dims
        assert!(serialized.entry_point.is_some());
        assert_eq!(serialized.m, 8);

        // First node's vector sits at the start of the flat buffer.
        assert!((serialized.vectors[0] - 0.0).abs() < 1e-6);
        assert!((serialized.vectors[1] - 50.0).abs() < 1e-6);
    }

    // JSON round-trip through serde preserves the vector and adjacency streams.
    #[test]
    fn serialized_graph_is_serde_roundtrip() {
        let config = super::super::HnswConfig {
            m: 4,
            ..Default::default()
        };
        let mut graph = HnswGraph::new(config, L2Distance);
        for i in 0..10u32 {
            graph.insert(i, &[i as f32]);
        }

        let serialized = serialize_graph(&graph);
        let json = serde_json::to_string(&serialized).unwrap();
        let recovered: SerializedGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(recovered.num_nodes, 10);
        assert_eq!(recovered.vectors, serialized.vectors);
        assert_eq!(recovered.neighbors, serialized.neighbors);
    }
}