cvx_index/hnsw/
temporal_lsh.rs

1//! Temporal Locality-Sensitive Hashing (T-LSH) — RFC-008 Phase 2.
2//!
3//! An auxiliary index for composite spatiotemporal queries (α < 1.0).
4//! Instead of over-fetching from the semantic HNSW and re-ranking,
5//! T-LSH generates candidates that are naturally distributed in both
6//! semantic and temporal space.
7//!
8//! # Hash function
9//!
10//! For a point `(vector, timestamp)`:
11//! ```text
12//! h_ST(v, t) = h_sem(v) ⊕ h_time(t)
13//!
14//! h_sem(v) = [sign(r₁·v), sign(r₂·v), ..., sign(rₖ·v)]  (random hyperplane LSH)
15//! h_time(t) = floor(t / bucket_size)                       (temporal bucketing)
16//!
17//! Combined: concatenate semantic_bits + temporal_bits into a u64 hash
18//! ```
19//!
20//! # References
21//!
22//! - Indyk & Motwani (1998). Approximate nearest neighbors. *STOC*.
23//! - Lv et al. (2007). Multi-probe LSH. *VLDB*.
24
25use std::collections::HashMap;
26
27use rand::rngs::SmallRng;
28use rand::{Rng, SeedableRng};
29
30// ─── Configuration ──────────────────────────────────────────────────
31
/// Tuning parameters for the T-LSH index.
///
/// The struct is plain data: every field is public and may be set directly,
/// or obtained from [`TLSHConfig::default`] / [`TLSHConfig::for_alpha`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TLSHConfig {
    /// Number of hash tables (more = higher recall, more memory).
    pub n_tables: usize,
    /// Number of semantic (random-hyperplane) hash bits per table.
    pub semantic_bits: usize,
    /// Number of temporal hash bits per table.
    pub temporal_bits: usize,
    /// Width of one temporal bucket, in microseconds.
    pub temporal_bucket_us: i64,
    /// Multi-probe depth: number of neighboring buckets to probe.
    pub n_probes: usize,
}

impl Default for TLSHConfig {
    /// 16 tables, a 12/4 semantic/temporal bit split, one-day buckets and
    /// a multi-probe depth of 3.
    fn default() -> Self {
        Self {
            n_tables: 16,
            semantic_bits: 12,
            temporal_bits: 4,
            temporal_bucket_us: 86_400_000_000, // 1 day
            n_probes: 3,
        }
    }
}

impl TLSHConfig {
    /// Total number of hash bits that `for_alpha` splits between the
    /// semantic and temporal components.
    const BIT_BUDGET: usize = 16;
    /// Minimum number of bits either component keeps, whatever `alpha` is.
    const MIN_COMPONENT_BITS: usize = 2;
    /// Vectors with more dimensions than this get a deeper multi-probe.
    const HIGH_DIM_THRESHOLD: usize = 100;

    /// Create a config tuned for a given alpha value.
    ///
    /// Higher alpha → more semantic bits, fewer temporal bits. The 16-bit
    /// budget is split as `round(alpha * 16)` semantic bits, clamped so that
    /// each component keeps at least 2 bits; the remainder goes to the
    /// temporal component.
    ///
    /// `alpha` is not validated: the float-to-integer cast saturates, so
    /// negative values and NaN behave like `0.0` and values above `1.0`
    /// behave like `1.0`.
    ///
    /// `dim` only influences the multi-probe depth (5 when `dim > 100`,
    /// otherwise 3). All remaining fields come from [`TLSHConfig::default`].
    pub fn for_alpha(alpha: f32, dim: usize) -> Self {
        let budget = Self::BIT_BUDGET;
        let floor = Self::MIN_COMPONENT_BITS;

        // `as usize` saturates: NaN and negatives become 0, +inf becomes MAX.
        let requested = (alpha * budget as f32).round() as usize;
        let semantic_bits = requested.clamp(floor, budget - floor);
        let temporal_bits = budget - semantic_bits;

        let n_probes = if dim > Self::HIGH_DIM_THRESHOLD { 5 } else { 3 };

        Self {
            semantic_bits,
            temporal_bits,
            n_probes,
            // Table count and bucket width are shared with the default config.
            ..Self::default()
        }
    }
}
77
78// ─── T-LSH Index ────────────────────────────────────────────────────
79
80/// Temporal Locality-Sensitive Hashing index.
81///
82/// Maintains multiple hash tables where each hash combines semantic
83/// (random hyperplane) and temporal (bucket) components.
84pub struct TemporalLSH {
85    /// Hash tables: `tables[table_idx][hash] → vec of node_ids`.
86    tables: Vec<HashMap<u64, Vec<u32>>>,
87    /// Random hyperplanes for semantic hashing.
88    /// Shape: `[n_tables][semantic_bits][dim]`.
89    hyperplanes: Vec<Vec<Vec<f32>>>,
90    /// Configuration.
91    config: TLSHConfig,
92    /// Dimensionality of vectors.
93    dim: usize,
94    /// Total number of indexed points.
95    n_points: usize,
96}
97
98impl TemporalLSH {
99    /// Create a new empty T-LSH index.
100    pub fn new(dim: usize, config: TLSHConfig) -> Self {
101        let mut rng = SmallRng::seed_from_u64(42);
102
103        let hyperplanes: Vec<Vec<Vec<f32>>> = (0..config.n_tables)
104            .map(|_| {
105                (0..config.semantic_bits)
106                    .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
107                    .collect()
108            })
109            .collect();
110
111        let tables = (0..config.n_tables).map(|_| HashMap::new()).collect();
112
113        Self {
114            tables,
115            hyperplanes,
116            config,
117            dim,
118            n_points: 0,
119        }
120    }
121
122    /// Build T-LSH from existing index data.
123    pub fn build(vectors: &[&[f32]], timestamps: &[i64], config: TLSHConfig) -> Self {
124        assert_eq!(vectors.len(), timestamps.len());
125        if vectors.is_empty() {
126            return Self::new(0, config);
127        }
128
129        let dim = vectors[0].len();
130        let mut index = Self::new(dim, config);
131
132        for (i, (v, &ts)) in vectors.iter().zip(timestamps.iter()).enumerate() {
133            index.insert(i as u32, v, ts);
134        }
135
136        index
137    }
138
139    /// Insert a point into all hash tables.
140    pub fn insert(&mut self, node_id: u32, vector: &[f32], timestamp: i64) {
141        for table_idx in 0..self.config.n_tables {
142            let hash = self.compute_hash(table_idx, vector, timestamp);
143            self.tables[table_idx]
144                .entry(hash)
145                .or_default()
146                .push(node_id);
147        }
148        self.n_points += 1;
149    }
150
151    /// Query: find candidate node IDs under spatiotemporal locality.
152    ///
153    /// Returns deduplicated candidate IDs from all tables + multi-probe.
154    pub fn query(&self, vector: &[f32], timestamp: i64) -> Vec<u32> {
155        let mut candidates = Vec::new();
156        let mut seen = std::collections::HashSet::new();
157
158        for table_idx in 0..self.config.n_tables {
159            let primary_hash = self.compute_hash(table_idx, vector, timestamp);
160
161            // Primary bucket
162            if let Some(ids) = self.tables[table_idx].get(&primary_hash) {
163                for &id in ids {
164                    if seen.insert(id) {
165                        candidates.push(id);
166                    }
167                }
168            }
169
170            // Multi-probe: neighboring temporal buckets
171            let temporal_bucket = self.temporal_bucket(timestamp);
172            for delta in 1..=self.config.n_probes as i64 {
173                for &dir in &[-1i64, 1] {
174                    let neighbor_bucket = temporal_bucket + delta * dir;
175                    let neighbor_hash = self.combine_hash(
176                        table_idx,
177                        &self.semantic_hash(table_idx, vector),
178                        neighbor_bucket,
179                    );
180                    if let Some(ids) = self.tables[table_idx].get(&neighbor_hash) {
181                        for &id in ids {
182                            if seen.insert(id) {
183                                candidates.push(id);
184                            }
185                        }
186                    }
187                }
188            }
189
190            // Multi-probe: flip one semantic bit
191            let sem_hash = self.semantic_hash(table_idx, vector);
192            for bit in 0..self.config.semantic_bits.min(3) {
193                let mut flipped = sem_hash.clone();
194                flipped[bit] = !flipped[bit];
195                let flipped_hash = self.combine_hash(table_idx, &flipped, temporal_bucket);
196                if let Some(ids) = self.tables[table_idx].get(&flipped_hash) {
197                    for &id in ids {
198                        if seen.insert(id) {
199                            candidates.push(id);
200                        }
201                    }
202                }
203            }
204        }
205
206        candidates
207    }
208
209    /// Number of indexed points.
210    pub fn len(&self) -> usize {
211        self.n_points
212    }
213
214    /// Whether the index is empty.
215    pub fn is_empty(&self) -> bool {
216        self.n_points == 0
217    }
218
219    /// Memory usage estimate in bytes.
220    pub fn memory_bytes(&self) -> usize {
221        let hyperplane_mem = self.config.n_tables
222            * self.config.semantic_bits
223            * self.dim
224            * std::mem::size_of::<f32>();
225
226        let table_mem: usize = self
227            .tables
228            .iter()
229            .map(|t| {
230                t.values()
231                    .map(|v| v.len() * std::mem::size_of::<u32>() + 8)
232                    .sum::<usize>()
233                    + t.len() * (std::mem::size_of::<u64>() + 24)
234            })
235            .sum();
236
237        hyperplane_mem + table_mem
238    }
239
240    // ─── Private helpers ────────────────────────────────────────
241
242    /// Compute the full hash for a point in a specific table.
243    fn compute_hash(&self, table_idx: usize, vector: &[f32], timestamp: i64) -> u64 {
244        let sem_bits = self.semantic_hash(table_idx, vector);
245        let temp_bucket = self.temporal_bucket(timestamp);
246        self.combine_hash(table_idx, &sem_bits, temp_bucket)
247    }
248
249    /// Compute semantic hash bits via random hyperplane LSH.
250    fn semantic_hash(&self, table_idx: usize, vector: &[f32]) -> Vec<bool> {
251        self.hyperplanes[table_idx]
252            .iter()
253            .map(|plane| {
254                let dot: f32 = plane.iter().zip(vector.iter()).map(|(a, b)| a * b).sum();
255                dot >= 0.0
256            })
257            .collect()
258    }
259
260    /// Compute temporal bucket index.
261    fn temporal_bucket(&self, timestamp: i64) -> i64 {
262        if self.config.temporal_bucket_us > 0 {
263            timestamp / self.config.temporal_bucket_us
264        } else {
265            0
266        }
267    }
268
269    /// Combine semantic bits and temporal bucket into a single u64 hash.
270    fn combine_hash(&self, _table_idx: usize, sem_bits: &[bool], temp_bucket: i64) -> u64 {
271        let mut hash: u64 = 0;
272
273        // Pack semantic bits into lower bits
274        for (i, &bit) in sem_bits.iter().enumerate() {
275            if bit {
276                hash |= 1u64 << i;
277            }
278        }
279
280        // Pack temporal bucket into upper bits
281        let temp_hash = temp_bucket as u64;
282        hash |= temp_hash << self.config.semantic_bits;
283
284        hash
285    }
286}
287
288// ─── Tests ──────────────────────────────────────────────────────────
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration with 1-second buckets, shared by most tests.
    fn test_config() -> TLSHConfig {
        TLSHConfig {
            n_tables: 4,
            semantic_bits: 8,
            temporal_bits: 4,
            temporal_bucket_us: 1_000_000, // 1 second for testing
            n_probes: 2,
        }
    }

    // ─── Basic operations ───────────────────────────────────────

    #[test]
    fn new_empty() {
        let index = TemporalLSH::new(4, test_config());
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn insert_and_query_identical() {
        let mut index = TemporalLSH::new(3, test_config());
        let v = [1.0f32, 0.0, 0.0];
        let ts = 1_000_000;

        index.insert(0, &v, ts);
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());

        let candidates = index.query(&v, ts);

        // The point sits in the primary bucket of every table, and the
        // result is deduplicated, so it must be reported exactly once.
        assert_eq!(
            candidates,
            vec![0],
            "query with identical vector+timestamp should find the point once"
        );
    }

    #[test]
    fn insert_multiple_query_nearest() {
        let mut index = TemporalLSH::new(3, test_config());

        // Insert 100 points at various positions and timestamps
        for i in 0..100u32 {
            let v = [i as f32 * 0.1, (i as f32 * 0.05).sin(), 0.0];
            let ts = i as i64 * 500_000; // spread over 50 seconds
            index.insert(i, &v, ts);
        }

        assert_eq!(index.len(), 100);

        // Query near point 50
        let query_v = [5.0, (50.0 * 0.05f32).sin(), 0.0];
        let query_ts = 25_000_000;
        let candidates = index.query(&query_v, query_ts);

        // Should find some candidates (may not be exact, it's LSH)
        assert!(!candidates.is_empty(), "should find at least one candidate");
    }

    // ─── Temporal locality ──────────────────────────────────────

    #[test]
    fn temporal_neighbors_found_via_multiprobe() {
        let config = TLSHConfig {
            n_tables: 8,
            semantic_bits: 8,
            temporal_bits: 4,
            temporal_bucket_us: 1_000_000, // 1 second buckets
            n_probes: 3,
        };
        let mut index = TemporalLSH::new(2, config);

        // Insert point at t=0
        index.insert(0, &[1.0, 0.0], 0);
        // Insert point at t=2s (2 buckets away)
        index.insert(1, &[1.0, 0.0], 2_000_000);

        // Query at t=1s with the same vector: the semantic hash is identical,
        // and buckets 0 and 2 are both within the probed range around
        // bucket 1, so both points must be returned.
        let candidates = index.query(&[1.0, 0.0], 1_000_000);

        assert!(
            candidates.contains(&0),
            "multi-probe should reach the previous bucket, got {candidates:?}"
        );
        assert!(
            candidates.contains(&1),
            "multi-probe should reach the next bucket, got {candidates:?}"
        );
        assert_eq!(candidates.len(), 2, "results must be deduplicated");
    }

    #[test]
    fn points_beyond_probe_depth_are_not_returned() {
        // n_probes = 2 with 1-second buckets: a point 10 buckets away, with
        // the same vector, lies outside every probed bucket.
        let mut index = TemporalLSH::new(2, test_config());
        index.insert(0, &[1.0, 0.0], 0);

        let candidates = index.query(&[1.0, 0.0], 10_000_000);
        assert!(candidates.is_empty(), "got {candidates:?}");
    }

    // ─── Semantic locality ──────────────────────────────────────

    #[test]
    fn similar_vectors_same_bucket() {
        let config = TLSHConfig {
            n_tables: 16,
            semantic_bits: 8,
            temporal_bits: 2,
            temporal_bucket_us: 1_000_000,
            n_probes: 1,
        };
        let mut index = TemporalLSH::new(4, config);

        // Two very similar vectors at the same time
        index.insert(0, &[1.0, 0.0, 0.0, 0.0], 0);
        index.insert(1, &[0.99, 0.01, 0.0, 0.0], 0);
        // One very different vector
        index.insert(2, &[-1.0, 0.0, 0.0, 0.0], 0);

        let candidates = index.query(&[1.0, 0.0, 0.0, 0.0], 0);

        // Point 0 should definitely be found (exact match)
        assert!(candidates.contains(&0));
        // Point 1 should likely be found (very similar)
        // Point 2 may or may not be found (opposite direction)
    }

    // ─── Build from data ────────────────────────────────────────

    #[test]
    fn build_from_vectors() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, 0.0]).collect();
        let timestamps: Vec<i64> = (0..50).map(|i| i as i64 * 1_000_000).collect();

        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let index = TemporalLSH::build(&refs, &timestamps, test_config());

        assert_eq!(index.len(), 50);
    }

    #[test]
    fn build_from_empty_input() {
        let index = TemporalLSH::build(&[], &[], test_config());

        assert!(index.is_empty());
        assert!(index.query(&[], 0).is_empty());
    }

    #[test]
    #[should_panic]
    fn build_rejects_length_mismatch() {
        let v = [1.0f32, 0.0];
        let refs: Vec<&[f32]> = vec![&v[..]];
        let _ = TemporalLSH::build(&refs, &[0, 1], test_config());
    }

    // ─── Config tuning ──────────────────────────────────────────

    #[test]
    fn config_for_alpha_high() {
        let config = TLSHConfig::for_alpha(0.9, 384);
        // High alpha → more semantic bits
        assert!(config.semantic_bits > config.temporal_bits);
    }

    #[test]
    fn config_for_alpha_balanced() {
        let config = TLSHConfig::for_alpha(0.5, 384);
        // Balanced alpha → roughly equal bits
        let diff = (config.semantic_bits as i32 - config.temporal_bits as i32).unsigned_abs();
        assert!(diff <= 2, "balanced alpha should give roughly equal bits");
    }

    #[test]
    fn config_for_alpha_low() {
        let config = TLSHConfig::for_alpha(0.2, 384);
        // Low alpha → more temporal bits
        assert!(config.temporal_bits > config.semantic_bits);
    }

    // ─── Memory estimate ────────────────────────────────────────

    #[test]
    fn memory_estimate_grows_with_data() {
        let mut index = TemporalLSH::new(4, test_config());
        let mem_empty = index.memory_bytes();

        for i in 0..100u32 {
            index.insert(i, &[i as f32, 0.0, 0.0, 0.0], i as i64 * 1000);
        }
        let mem_full = index.memory_bytes();

        assert!(
            mem_full > mem_empty,
            "memory should grow with inserted points"
        );
    }

    // ─── Edge cases ─────────────────────────────────────────────

    #[test]
    fn query_empty_index() {
        let index = TemporalLSH::new(3, test_config());
        let candidates = index.query(&[1.0, 0.0, 0.0], 0);
        assert!(candidates.is_empty());
    }

    #[test]
    fn negative_timestamps() {
        let mut index = TemporalLSH::new(2, test_config());
        index.insert(0, &[1.0, 0.0], -5_000_000);
        index.insert(1, &[1.0, 0.0], -3_000_000);

        // Buckets -5 and -3 are the direct neighbors of the query bucket -4,
        // and the vectors are identical, so both points must be returned.
        let candidates = index.query(&[1.0, 0.0], -4_000_000);
        assert!(
            candidates.contains(&0) && candidates.contains(&1),
            "should handle negative timestamps, got {candidates:?}"
        );
    }

    #[test]
    fn high_dimensional() {
        let dim = 384;
        let config = TLSHConfig {
            n_tables: 4,
            semantic_bits: 12,
            temporal_bits: 4,
            temporal_bucket_us: 1_000_000,
            n_probes: 2,
        };
        let mut index = TemporalLSH::new(dim, config);

        let v: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.01).sin()).collect();
        index.insert(0, &v, 0);

        let candidates = index.query(&v, 0);
        assert!(candidates.contains(&0));
    }

    // ─── Hash consistency ───────────────────────────────────────

    #[test]
    fn same_input_same_hash() {
        let index = TemporalLSH::new(3, test_config());
        let v = [1.0f32, 2.0, 3.0];
        let ts = 5_000_000;

        let h1 = index.compute_hash(0, &v, ts);
        let h2 = index.compute_hash(0, &v, ts);
        assert_eq!(h1, h2, "same input should produce same hash");
    }

    #[test]
    fn different_time_different_hash() {
        let index = TemporalLSH::new(3, test_config());
        let v = [1.0f32, 0.0, 0.0];

        // Same vector, so the semantic part is identical; buckets 0 and 10
        // differ, so the temporal part of the hash must differ.
        let h1 = index.compute_hash(0, &v, 0);
        let h2 = index.compute_hash(0, &v, 10_000_000); // 10 seconds later, different bucket

        assert_ne!(
            h1, h2,
            "different temporal buckets should usually give different hashes"
        );
    }
}