1use std::collections::HashMap;
26
27use rand::rngs::SmallRng;
28use rand::{Rng, SeedableRng};
29
/// Configuration for a `TemporalLSH` index.
///
/// Each hash table keys a point by `semantic_bits` random-hyperplane sign bits
/// combined with a temporal bucket index derived from `temporal_bucket_us`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TLSHConfig {
    /// Number of independent hash tables, each with its own hyperplanes.
    pub n_tables: usize,
    /// Number of random hyperplanes (sign bits) per table.
    pub semantic_bits: usize,
    /// Bit budget intended for the temporal component of the key.
    ///
    /// NOTE(review): the hashing code places the full temporal bucket index above
    /// the semantic bits and never masks it to this width — confirm intent.
    pub temporal_bits: usize,
    /// Width of one temporal bucket, in timestamp units (microseconds, going by the
    /// `_us` suffix). A value `<= 0` puts every timestamp in bucket 0.
    pub temporal_bucket_us: i64,
    /// How many neighbouring temporal buckets are probed on each side at query time.
    pub n_probes: usize,
}
46
47impl Default for TLSHConfig {
48 fn default() -> Self {
49 Self {
50 n_tables: 16,
51 semantic_bits: 12,
52 temporal_bits: 4,
53 temporal_bucket_us: 86_400_000_000, n_probes: 3,
55 }
56 }
57}
58
59impl TLSHConfig {
60 pub fn for_alpha(alpha: f32, dim: usize) -> Self {
64 let total_bits = 16usize;
65 let sem_bits = ((alpha * total_bits as f32).round() as usize).clamp(2, total_bits - 2);
66 let time_bits = total_bits - sem_bits;
67
68 Self {
69 n_tables: 16,
70 semantic_bits: sem_bits,
71 temporal_bits: time_bits,
72 temporal_bucket_us: 86_400_000_000,
73 n_probes: if dim > 100 { 5 } else { 3 },
74 }
75 }
76}
77
78pub struct TemporalLSH {
85 tables: Vec<HashMap<u64, Vec<u32>>>,
87 hyperplanes: Vec<Vec<Vec<f32>>>,
90 config: TLSHConfig,
92 dim: usize,
94 n_points: usize,
96}
97
98impl TemporalLSH {
99 pub fn new(dim: usize, config: TLSHConfig) -> Self {
101 let mut rng = SmallRng::seed_from_u64(42);
102
103 let hyperplanes: Vec<Vec<Vec<f32>>> = (0..config.n_tables)
104 .map(|_| {
105 (0..config.semantic_bits)
106 .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
107 .collect()
108 })
109 .collect();
110
111 let tables = (0..config.n_tables).map(|_| HashMap::new()).collect();
112
113 Self {
114 tables,
115 hyperplanes,
116 config,
117 dim,
118 n_points: 0,
119 }
120 }
121
122 pub fn build(vectors: &[&[f32]], timestamps: &[i64], config: TLSHConfig) -> Self {
124 assert_eq!(vectors.len(), timestamps.len());
125 if vectors.is_empty() {
126 return Self::new(0, config);
127 }
128
129 let dim = vectors[0].len();
130 let mut index = Self::new(dim, config);
131
132 for (i, (v, &ts)) in vectors.iter().zip(timestamps.iter()).enumerate() {
133 index.insert(i as u32, v, ts);
134 }
135
136 index
137 }
138
139 pub fn insert(&mut self, node_id: u32, vector: &[f32], timestamp: i64) {
141 for table_idx in 0..self.config.n_tables {
142 let hash = self.compute_hash(table_idx, vector, timestamp);
143 self.tables[table_idx]
144 .entry(hash)
145 .or_default()
146 .push(node_id);
147 }
148 self.n_points += 1;
149 }
150
151 pub fn query(&self, vector: &[f32], timestamp: i64) -> Vec<u32> {
155 let mut candidates = Vec::new();
156 let mut seen = std::collections::HashSet::new();
157
158 for table_idx in 0..self.config.n_tables {
159 let primary_hash = self.compute_hash(table_idx, vector, timestamp);
160
161 if let Some(ids) = self.tables[table_idx].get(&primary_hash) {
163 for &id in ids {
164 if seen.insert(id) {
165 candidates.push(id);
166 }
167 }
168 }
169
170 let temporal_bucket = self.temporal_bucket(timestamp);
172 for delta in 1..=self.config.n_probes as i64 {
173 for &dir in &[-1i64, 1] {
174 let neighbor_bucket = temporal_bucket + delta * dir;
175 let neighbor_hash = self.combine_hash(
176 table_idx,
177 &self.semantic_hash(table_idx, vector),
178 neighbor_bucket,
179 );
180 if let Some(ids) = self.tables[table_idx].get(&neighbor_hash) {
181 for &id in ids {
182 if seen.insert(id) {
183 candidates.push(id);
184 }
185 }
186 }
187 }
188 }
189
190 let sem_hash = self.semantic_hash(table_idx, vector);
192 for bit in 0..self.config.semantic_bits.min(3) {
193 let mut flipped = sem_hash.clone();
194 flipped[bit] = !flipped[bit];
195 let flipped_hash = self.combine_hash(table_idx, &flipped, temporal_bucket);
196 if let Some(ids) = self.tables[table_idx].get(&flipped_hash) {
197 for &id in ids {
198 if seen.insert(id) {
199 candidates.push(id);
200 }
201 }
202 }
203 }
204 }
205
206 candidates
207 }
208
209 pub fn len(&self) -> usize {
211 self.n_points
212 }
213
214 pub fn is_empty(&self) -> bool {
216 self.n_points == 0
217 }
218
219 pub fn memory_bytes(&self) -> usize {
221 let hyperplane_mem = self.config.n_tables
222 * self.config.semantic_bits
223 * self.dim
224 * std::mem::size_of::<f32>();
225
226 let table_mem: usize = self
227 .tables
228 .iter()
229 .map(|t| {
230 t.values()
231 .map(|v| v.len() * std::mem::size_of::<u32>() + 8)
232 .sum::<usize>()
233 + t.len() * (std::mem::size_of::<u64>() + 24)
234 })
235 .sum();
236
237 hyperplane_mem + table_mem
238 }
239
240 fn compute_hash(&self, table_idx: usize, vector: &[f32], timestamp: i64) -> u64 {
244 let sem_bits = self.semantic_hash(table_idx, vector);
245 let temp_bucket = self.temporal_bucket(timestamp);
246 self.combine_hash(table_idx, &sem_bits, temp_bucket)
247 }
248
249 fn semantic_hash(&self, table_idx: usize, vector: &[f32]) -> Vec<bool> {
251 self.hyperplanes[table_idx]
252 .iter()
253 .map(|plane| {
254 let dot: f32 = plane.iter().zip(vector.iter()).map(|(a, b)| a * b).sum();
255 dot >= 0.0
256 })
257 .collect()
258 }
259
260 fn temporal_bucket(&self, timestamp: i64) -> i64 {
262 if self.config.temporal_bucket_us > 0 {
263 timestamp / self.config.temporal_bucket_us
264 } else {
265 0
266 }
267 }
268
269 fn combine_hash(&self, _table_idx: usize, sem_bits: &[bool], temp_bucket: i64) -> u64 {
271 let mut hash: u64 = 0;
272
273 for (i, &bit) in sem_bits.iter().enumerate() {
275 if bit {
276 hash |= 1u64 << i;
277 }
278 }
279
280 let temp_hash = temp_bucket as u64;
282 hash |= temp_hash << self.config.semantic_bits;
283
284 hash
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by most tests: 4 tables, 8 semantic bits,
    /// 1_000_000-unit temporal buckets, 2 probes per side.
    fn default_config() -> TLSHConfig {
        TLSHConfig {
            n_tables: 4,
            semantic_bits: 8,
            temporal_bits: 4,
            temporal_bucket_us: 1_000_000,
            n_probes: 2,
        }
    }

    #[test]
    fn new_empty() {
        let index = TemporalLSH::new(4, default_config());
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn insert_and_query_identical() {
        let mut index = TemporalLSH::new(3, default_config());
        let v = [1.0f32, 0.0, 0.0];
        let ts = 1_000_000;

        index.insert(0, &v, ts);
        let candidates = index.query(&v, ts);

        assert!(
            candidates.contains(&0),
            "query with identical vector+timestamp should find the point"
        );
    }

    #[test]
    fn insert_multiple_query_nearest() {
        let config = default_config();
        let mut index = TemporalLSH::new(3, config);

        for i in 0..100u32 {
            let v = [i as f32 * 0.1, (i as f32 * 0.05).sin(), 0.0];
            let ts = i as i64 * 500_000;
            index.insert(i, &v, ts);
        }

        assert_eq!(index.len(), 100);

        let query_v = [5.0, (50.0 * 0.05f32).sin(), 0.0];
        let query_ts = 25_000_000;
        let candidates = index.query(&query_v, query_ts);

        assert!(!candidates.is_empty(), "should find at least one candidate");
    }

    #[test]
    fn temporal_neighbors_found_via_multiprobe() {
        let config = TLSHConfig {
            n_tables: 8,
            semantic_bits: 8,
            temporal_bits: 4,
            temporal_bucket_us: 1_000_000,
            n_probes: 3,
        };
        let mut index = TemporalLSH::new(2, config);

        index.insert(0, &[1.0, 0.0], 0);
        index.insert(1, &[1.0, 0.0], 2_000_000);

        let candidates = index.query(&[1.0, 0.0], 1_000_000);

        let found_0 = candidates.contains(&0);
        let found_1 = candidates.contains(&1);
        assert!(
            found_0 || found_1,
            "multi-probe should find at least one temporal neighbor, got {candidates:?}"
        );
    }

    #[test]
    fn similar_vectors_same_bucket() {
        let config = TLSHConfig {
            n_tables: 16,
            semantic_bits: 8,
            temporal_bits: 2,
            temporal_bucket_us: 1_000_000,
            n_probes: 1,
        };
        let mut index = TemporalLSH::new(4, config);

        index.insert(0, &[1.0, 0.0, 0.0, 0.0], 0);
        index.insert(1, &[0.99, 0.01, 0.0, 0.0], 0);
        index.insert(2, &[-1.0, 0.0, 0.0, 0.0], 0);

        let candidates = index.query(&[1.0, 0.0, 0.0, 0.0], 0);

        assert!(candidates.contains(&0));
    }

    #[test]
    fn build_from_vectors() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, 0.0]).collect();
        let timestamps: Vec<i64> = (0..50).map(|i| i as i64 * 1_000_000).collect();

        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let index = TemporalLSH::build(&refs, &timestamps, default_config());

        assert_eq!(index.len(), 50);
    }

    #[test]
    fn config_for_alpha_high() {
        let config = TLSHConfig::for_alpha(0.9, 384);
        assert!(config.semantic_bits > config.temporal_bits);
    }

    #[test]
    fn config_for_alpha_balanced() {
        let config = TLSHConfig::for_alpha(0.5, 384);
        let diff = (config.semantic_bits as i32 - config.temporal_bits as i32).unsigned_abs();
        assert!(diff <= 2, "balanced alpha should give roughly equal bits");
    }

    #[test]
    fn config_for_alpha_low() {
        let config = TLSHConfig::for_alpha(0.2, 384);
        assert!(config.temporal_bits > config.semantic_bits);
    }

    #[test]
    fn memory_estimate_grows_with_data() {
        let config = default_config();
        let mut index = TemporalLSH::new(4, config);
        let mem_empty = index.memory_bytes();

        for i in 0..100u32 {
            index.insert(i, &[i as f32, 0.0, 0.0, 0.0], i as i64 * 1000);
        }
        let mem_full = index.memory_bytes();

        assert!(
            mem_full > mem_empty,
            "memory should grow with inserted points"
        );
    }

    #[test]
    fn query_empty_index() {
        let index = TemporalLSH::new(3, default_config());
        let candidates = index.query(&[1.0, 0.0, 0.0], 0);
        assert!(candidates.is_empty());
    }

    #[test]
    fn query_results_are_deduplicated() {
        // Every table stores the id, so without de-duplication it would repeat.
        let mut index = TemporalLSH::new(2, default_config());
        index.insert(7, &[1.0, 0.0], 0);
        index.insert(8, &[1.0, 0.1], 0);

        let candidates = index.query(&[1.0, 0.0], 0);
        let unique: std::collections::HashSet<u32> = candidates.iter().copied().collect();

        assert_eq!(
            unique.len(),
            candidates.len(),
            "candidates must not repeat ids: {candidates:?}"
        );
        assert!(candidates.contains(&7));
    }

    #[test]
    fn negative_timestamps() {
        let mut index = TemporalLSH::new(2, default_config());
        index.insert(0, &[1.0, 0.0], -5_000_000);
        index.insert(1, &[1.0, 0.0], -3_000_000);

        let candidates = index.query(&[1.0, 0.0], -4_000_000);
        assert!(!candidates.is_empty(), "should handle negative timestamps");
    }

    #[test]
    fn high_dimensional() {
        let dim = 384;
        let config = TLSHConfig {
            n_tables: 4,
            semantic_bits: 12,
            temporal_bits: 4,
            temporal_bucket_us: 1_000_000,
            n_probes: 2,
        };
        let mut index = TemporalLSH::new(dim, config);

        let v: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.01).sin()).collect();
        index.insert(0, &v, 0);

        let candidates = index.query(&v, 0);
        assert!(candidates.contains(&0));
    }

    #[test]
    fn same_input_same_hash() {
        let index = TemporalLSH::new(3, default_config());
        let v = [1.0f32, 2.0, 3.0];
        let ts = 5_000_000;

        let h1 = index.compute_hash(0, &v, ts);
        let h2 = index.compute_hash(0, &v, ts);
        assert_eq!(h1, h2, "same input should produce same hash");
    }

    #[test]
    fn different_time_different_hash() {
        let index = TemporalLSH::new(3, default_config());
        let v = [1.0f32, 0.0, 0.0];

        let h1 = index.compute_hash(0, &v, 0);
        let h2 = index.compute_hash(0, &v, 10_000_000);
        assert_ne!(
            h1, h2,
            "different temporal buckets should usually give different hashes"
        );
    }
}