1use std::collections::BTreeMap;
13use std::path::Path;
14
15use serde::{Deserialize, Serialize};
16
17use crate::anchor::{AnchorMetric, project_to_anchors};
18use crate::calculus::{drift_magnitude_l2, drift_report};
19use cvx_core::types::TemporalFilter;
20
/// Identity and projection metric of one anchor set.
///
/// Stored verbatim inside the postcard snapshot written by `AnchorSpaceIndex::save`.
/// postcard is not self-describing (fields are encoded positionally), so the field
/// order and types below are part of the on-disk format: reordering or retyping
/// them makes previously saved indexes unreadable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnchorSetConfig {
    /// Numeric identifier of the anchor set.
    pub anchor_set_id: u32,
    /// Human-readable name of the anchor set.
    pub name: String,
    /// Metric used by `AnchorSpaceIndex::insert` when projecting raw vectors onto anchors.
    pub metric: AnchorMetricSerde,
}
33
/// Serializable counterpart of [`AnchorMetric`], stored inside [`AnchorSetConfig`].
///
/// Converted to the runtime [`AnchorMetric`] through the `From` impl below. Unit
/// variants are encoded by their index, so variant order is part of the snapshot
/// format: append new variants at the end instead of reordering.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum AnchorMetricSerde {
    /// Converts to `AnchorMetric::Cosine`.
    Cosine,
    /// Converts to `AnchorMetric::L2`.
    L2,
}
42
43impl From<AnchorMetricSerde> for AnchorMetric {
44 fn from(m: AnchorMetricSerde) -> Self {
45 match m {
46 AnchorMetricSerde::Cosine => AnchorMetric::Cosine,
47 AnchorMetricSerde::L2 => AnchorMetric::L2,
48 }
49 }
50}
51
/// Result of `AnchorSpaceIndex::anchor_drift`: how one entity's anchor-space
/// coordinates changed between the stored observations nearest to `t1` and `t2`.
#[derive(Debug, Clone)]
pub struct AnchorDriftReport {
    /// Per-anchor change `v(t2) - v(t1)`, one entry per anchor coordinate.
    pub per_anchor_delta: Vec<f32>,
    /// L2 drift between the two projected vectors, as reported by `drift_report`.
    pub l2_magnitude: f32,
    /// Cosine drift between the two projected vectors, as reported by `drift_report`.
    pub cosine_drift: f32,
    /// Index of the anchor whose delta has the largest absolute value
    /// (0 when `per_anchor_delta` is empty).
    pub dominant_anchor: usize,
    /// Source model of the observation chosen for `t1`.
    pub model_t1: u32,
    /// Source model of the observation chosen for `t2`.
    pub model_t2: u32,
}
70
71pub struct AnchorSpaceIndex {
78 config: AnchorSetConfig,
80 k: usize,
82 projected_vectors: Vec<Vec<f32>>,
84 source_model: Vec<u32>,
86 entity_ids: Vec<u64>,
88 timestamps: Vec<i64>,
90 entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
92}
93
94impl AnchorSpaceIndex {
95 pub fn new(config: AnchorSetConfig, k: usize) -> Self {
97 Self {
98 config,
99 k,
100 projected_vectors: Vec::new(),
101 source_model: Vec::new(),
102 entity_ids: Vec::new(),
103 timestamps: Vec::new(),
104 entity_index: BTreeMap::new(),
105 }
106 }
107
108 pub fn insert(
113 &mut self,
114 entity_id: u64,
115 timestamp: i64,
116 vector: &[f32],
117 model_anchors: &[&[f32]],
118 model_id: u32,
119 ) -> u32 {
120 let traj = [(timestamp, vector)];
122 let projected = project_to_anchors(&traj, model_anchors, self.config.metric.into());
123
124 let proj_vec = projected.into_iter().next().unwrap().1;
125 self.insert_projected(entity_id, timestamp, proj_vec, model_id)
126 }
127
128 pub fn insert_projected(
130 &mut self,
131 entity_id: u64,
132 timestamp: i64,
133 projected: Vec<f32>,
134 model_id: u32,
135 ) -> u32 {
136 assert_eq!(
137 projected.len(),
138 self.k,
139 "projected vector dim {} != anchor count {}",
140 projected.len(),
141 self.k
142 );
143
144 let node_id = self.projected_vectors.len() as u32;
145 self.projected_vectors.push(projected);
146 self.source_model.push(model_id);
147 self.entity_ids.push(entity_id);
148 self.timestamps.push(timestamp);
149
150 self.entity_index
151 .entry(entity_id)
152 .or_default()
153 .push((timestamp, node_id));
154
155 node_id
156 }
157
158 pub fn search(
162 &self,
163 query_projected: &[f32],
164 k: usize,
165 filter: TemporalFilter,
166 ) -> Vec<(u32, f32)> {
167 let mut results: Vec<(u32, f32)> = self
168 .projected_vectors
169 .iter()
170 .enumerate()
171 .filter(|(i, _)| filter.matches(self.timestamps[*i]))
172 .map(|(i, v)| (i as u32, drift_magnitude_l2(query_projected, v)))
173 .collect();
174
175 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
176 results.truncate(k);
177 results
178 }
179
180 pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, Vec<f32>)> {
182 let Some(entries) = self.entity_index.get(&entity_id) else {
183 return Vec::new();
184 };
185
186 let mut result: Vec<(i64, Vec<f32>)> = entries
187 .iter()
188 .filter(|(ts, _)| filter.matches(*ts))
189 .map(|&(ts, nid)| (ts, self.projected_vectors[nid as usize].clone()))
190 .collect();
191
192 result.sort_by_key(|&(ts, _)| ts);
193 result
194 }
195
196 pub fn cross_model_trajectory(
198 &self,
199 entity_id: u64,
200 filter: TemporalFilter,
201 ) -> BTreeMap<u32, Vec<(i64, Vec<f32>)>> {
202 let Some(entries) = self.entity_index.get(&entity_id) else {
203 return BTreeMap::new();
204 };
205
206 let mut by_model: BTreeMap<u32, Vec<(i64, Vec<f32>)>> = BTreeMap::new();
207
208 for &(ts, nid) in entries {
209 if !filter.matches(ts) {
210 continue;
211 }
212 let model = self.source_model[nid as usize];
213 by_model
214 .entry(model)
215 .or_default()
216 .push((ts, self.projected_vectors[nid as usize].clone()));
217 }
218
219 for traj in by_model.values_mut() {
221 traj.sort_by_key(|&(ts, _)| ts);
222 }
223
224 by_model
225 }
226
227 pub fn anchor_drift(&self, entity_id: u64, t1: i64, t2: i64) -> Option<AnchorDriftReport> {
229 let entries = self.entity_index.get(&entity_id)?;
230
231 let (_, nid1) = entries
233 .iter()
234 .min_by_key(|&&(ts, _)| (ts - t1).unsigned_abs())?;
235 let (_, nid2) = entries
236 .iter()
237 .min_by_key(|&&(ts, _)| (ts - t2).unsigned_abs())?;
238
239 let v1 = &self.projected_vectors[*nid1 as usize];
240 let v2 = &self.projected_vectors[*nid2 as usize];
241
242 let per_anchor_delta: Vec<f32> = v2.iter().zip(v1.iter()).map(|(a, b)| a - b).collect();
243 let report = drift_report(v1, v2, self.k);
244
245 let dominant_anchor = per_anchor_delta
246 .iter()
247 .enumerate()
248 .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap())
249 .map(|(i, _)| i)
250 .unwrap_or(0);
251
252 Some(AnchorDriftReport {
253 per_anchor_delta,
254 l2_magnitude: report.l2_magnitude,
255 cosine_drift: report.cosine_drift,
256 dominant_anchor,
257 model_t1: self.source_model[*nid1 as usize],
258 model_t2: self.source_model[*nid2 as usize],
259 })
260 }
261
262 pub fn len(&self) -> usize {
264 self.projected_vectors.len()
265 }
266
267 pub fn is_empty(&self) -> bool {
269 self.projected_vectors.is_empty()
270 }
271
272 pub fn n_entities(&self) -> usize {
274 self.entity_index.len()
275 }
276
277 pub fn entity_id(&self, node_id: u32) -> u64 {
279 self.entity_ids[node_id as usize]
280 }
281
282 pub fn timestamp(&self, node_id: u32) -> i64 {
284 self.timestamps[node_id as usize]
285 }
286
287 pub fn source_model(&self, node_id: u32) -> u32 {
289 self.source_model[node_id as usize]
290 }
291
292 pub fn projected_vector(&self, node_id: u32) -> &[f32] {
294 &self.projected_vectors[node_id as usize]
295 }
296
297 pub fn config(&self) -> &AnchorSetConfig {
299 &self.config
300 }
301
302 pub fn k(&self) -> usize {
304 self.k
305 }
306
307 pub fn save(&self, path: &Path) -> std::io::Result<()> {
309 let snapshot = AnchorSpaceSnapshot {
310 config: self.config.clone(),
311 k: self.k,
312 projected_vectors: self.projected_vectors.clone(),
313 source_model: self.source_model.clone(),
314 entity_ids: self.entity_ids.clone(),
315 timestamps: self.timestamps.clone(),
316 entity_index: self.entity_index.clone(),
317 };
318 let bytes = postcard::to_allocvec(&snapshot).map_err(std::io::Error::other)?;
319 std::fs::write(path, bytes)
320 }
321
322 pub fn load(path: &Path) -> std::io::Result<Self> {
324 let bytes = std::fs::read(path)?;
325 let snapshot: AnchorSpaceSnapshot = postcard::from_bytes(&bytes)
326 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
327 Ok(Self {
328 config: snapshot.config,
329 k: snapshot.k,
330 projected_vectors: snapshot.projected_vectors,
331 source_model: snapshot.source_model,
332 entity_ids: snapshot.entity_ids,
333 timestamps: snapshot.timestamps,
334 entity_index: snapshot.entity_index,
335 })
336 }
337}
338
/// On-disk form of [`AnchorSpaceIndex`]: written by `AnchorSpaceIndex::save` and
/// decoded by `AnchorSpaceIndex::load` with postcard.
///
/// A field-for-field copy of the index state. postcard is not self-describing,
/// so the field order and types below are the file format; changing them makes
/// existing saved indexes unreadable.
#[derive(Serialize, Deserialize)]
struct AnchorSpaceSnapshot {
    // Anchor-set identity and projection metric.
    config: AnchorSetConfig,
    // Number of anchors (dimension of each projected vector).
    k: usize,
    // Per-node columns, all indexed by node id.
    projected_vectors: Vec<Vec<f32>>,
    source_model: Vec<u32>,
    entity_ids: Vec<u64>,
    timestamps: Vec<i64>,
    // Entity id -> (timestamp, node id) pairs in insertion order.
    entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine-metric anchor-set config shared by all tests.
    fn test_config() -> AnchorSetConfig {
        AnchorSetConfig {
            anchor_set_id: 1,
            name: "test_anchors".to_string(),
            metric: AnchorMetricSerde::Cosine,
        }
    }

    #[test]
    fn new_empty() {
        let index = AnchorSpaceIndex::new(test_config(), 3);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert_eq!(index.k(), 3);
    }

    #[test]
    fn insert_projected() {
        let mut index = AnchorSpaceIndex::new(test_config(), 3);
        let id = index.insert_projected(42, 1000, vec![0.1, 0.5, 0.3], 0);

        assert_eq!(index.len(), 1);
        assert_eq!(index.entity_id(id), 42);
        assert_eq!(index.timestamp(id), 1000);
        assert_eq!(index.source_model(id), 0);
        assert_eq!(index.projected_vector(id), &[0.1, 0.5, 0.3]);
    }

    #[test]
    fn insert_with_projection() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);

        let vector = [1.0f32, 0.0, 0.0];
        let anchors: Vec<Vec<f32>> = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let anchor_refs: Vec<&[f32]> = anchors.iter().map(|a| a.as_slice()).collect();

        let id = index.insert(42, 1000, &vector, &anchor_refs, 0);

        assert_eq!(index.len(), 1);
        let proj = index.projected_vector(id);
        assert_eq!(proj.len(), 2);
        // Cosine coordinates are distances: ~0 for the identical anchor,
        // ~1 for the orthogonal one.
        assert!(
            proj[0] < 0.01,
            "should be close to anchor 0, got {}",
            proj[0]
        );
        assert!(
            (proj[1] - 1.0).abs() < 0.01,
            "should be far from anchor 1, got {}",
            proj[1]
        );
    }

    #[test]
    fn search_finds_nearest() {
        let mut index = AnchorSpaceIndex::new(test_config(), 3);

        // Nodes sit on the first axis at 0.0, 1.0, ..., 4.0.
        for i in 0..5u32 {
            index.insert_projected(i as u64, i as i64 * 1000, vec![i as f32, 0.0, 0.0], 0);
        }

        let results = index.search(&[2.1, 0.0, 0.0], 3, TemporalFilter::All);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 2); // 2.0 is the closest point to 2.1
        assert!(
            results.windows(2).all(|w| w[0].1 <= w[1].1),
            "results must be sorted by ascending distance"
        );
    }

    #[test]
    fn search_k_bounds() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);
        for i in 0..3u32 {
            index.insert_projected(i as u64, 0, vec![i as f32, 0.0], 0);
        }

        // k larger than the index returns every node, nearest first.
        let ids: Vec<u32> = index
            .search(&[0.0, 0.0], 10, TemporalFilter::All)
            .iter()
            .map(|&(id, _)| id)
            .collect();
        assert_eq!(ids, vec![0, 1, 2]);

        assert!(index.search(&[0.0, 0.0], 0, TemporalFilter::All).is_empty());
    }

    #[test]
    fn search_with_temporal_filter() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);

        for i in 0..10u32 {
            index.insert_projected(i as u64, i as i64 * 1000, vec![i as f32, 0.0], 0);
        }

        let results = index.search(&[5.0, 0.0], 10, TemporalFilter::Range(3000, 7000));
        // Without this the loop below would pass vacuously on an empty result.
        assert!(!results.is_empty(), "nodes inside the range must be returned");
        for &(nid, _) in &results {
            let ts = index.timestamp(nid);
            assert!((3000..=7000).contains(&ts), "ts {ts} outside range");
        }
    }

    #[test]
    fn cross_model_search() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);

        // Two nearby points from different models plus one far-away outlier.
        index.insert_projected(1, 1000, vec![0.1, 0.9], 0);
        index.insert_projected(2, 1000, vec![0.1, 0.8], 1);
        index.insert_projected(3, 1000, vec![5.0, 5.0], 0);

        let results = index.search(&[0.1, 0.85], 2, TemporalFilter::All);
        assert_eq!(results.len(), 2);

        let model_0 = results.iter().any(|&(nid, _)| index.source_model(nid) == 0);
        let model_1 = results.iter().any(|&(nid, _)| index.source_model(nid) == 1);
        assert!(
            model_0 && model_1,
            "search should return results from both models"
        );
    }

    #[test]
    fn trajectory_in_anchor_space() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);

        // Insert out of time order so the sort is actually exercised.
        for t in [3u64, 0, 4, 1, 2] {
            index.insert_projected(42, t as i64 * 1000, vec![t as f32 * 0.1, 0.5], 0);
        }

        let traj = index.trajectory(42, TemporalFilter::All);
        let timestamps: Vec<i64> = traj.iter().map(|(ts, _)| *ts).collect();
        assert_eq!(timestamps, vec![0, 1000, 2000, 3000, 4000]);
        for (ts, v) in &traj {
            let t = (*ts / 1000) as u64;
            assert_eq!(v[0], t as f32 * 0.1, "vector must stay paired with its timestamp");
        }
    }

    #[test]
    fn cross_model_trajectory() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);

        index.insert_projected(42, 1000, vec![0.1, 0.9], 0);
        index.insert_projected(42, 2000, vec![0.2, 0.8], 0);
        // Model 1 observations arrive newest-first.
        index.insert_projected(42, 2000, vec![0.25, 0.75], 1);
        index.insert_projected(42, 1000, vec![0.15, 0.85], 1);

        let by_model = index.cross_model_trajectory(42, TemporalFilter::All);
        assert_eq!(by_model.len(), 2);
        assert_eq!(by_model[&0].len(), 2);
        assert_eq!(by_model[&1].len(), 2);
        assert_eq!(by_model[&1][0], (1000, vec![0.15, 0.85]));
        assert_eq!(by_model[&1][1], (2000, vec![0.25, 0.75]));
    }

    #[test]
    fn anchor_drift_approaching() {
        let mut index = AnchorSpaceIndex::new(test_config(), 3);

        // The first coordinate (distance to anchor 0) shrinks over time.
        index.insert_projected(1, 1000, vec![1.0, 0.5, 0.5], 0);
        index.insert_projected(1, 2000, vec![0.5, 0.5, 0.5], 0);

        let report = index.anchor_drift(1, 1000, 2000).unwrap();
        assert!(
            report.per_anchor_delta[0] < 0.0,
            "should be approaching anchor 0"
        );
        assert_eq!(report.dominant_anchor, 0);
        assert!(report.l2_magnitude > 0.0);
    }

    #[test]
    fn anchor_drift_cross_model() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);

        // Same entity observed by model 0, then by model 1.
        index.insert_projected(1, 1000, vec![0.8, 0.2], 0);
        index.insert_projected(1, 2000, vec![0.3, 0.7], 1);

        let report = index.anchor_drift(1, 1000, 2000).unwrap();
        assert_eq!(report.model_t1, 0);
        assert_eq!(report.model_t2, 1);
        assert!(report.l2_magnitude > 0.0);
    }

    #[test]
    fn anchor_drift_snaps_to_nearest_timestamp() {
        let mut index = AnchorSpaceIndex::new(test_config(), 2);
        index.insert_projected(1, 1000, vec![0.8, 0.2], 0);
        index.insert_projected(1, 2000, vec![0.3, 0.7], 1);

        // Neither query time is stored; each must snap to the closer observation.
        let report = index.anchor_drift(1, 1100, 1900).unwrap();
        assert_eq!(report.model_t1, 0);
        assert_eq!(report.model_t2, 1);
        assert_eq!(report.per_anchor_delta.len(), 2);
    }

    #[test]
    fn anchor_drift_unknown_entity() {
        let index = AnchorSpaceIndex::new(test_config(), 2);
        assert!(index.anchor_drift(999, 0, 1000).is_none());
    }

    #[test]
    fn save_load_roundtrip() {
        let mut index = AnchorSpaceIndex::new(test_config(), 3);

        // 10 nodes spread over 3 entities and 2 models.
        for i in 0..10u32 {
            index.insert_projected(
                i as u64 % 3,
                i as i64 * 1000,
                vec![i as f32 * 0.1, 0.5, 0.3],
                i % 2,
            );
        }

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("anchor_index.bin");
        index.save(&path).unwrap();

        let loaded = AnchorSpaceIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 10);
        assert_eq!(loaded.k(), 3);
        assert_eq!(loaded.n_entities(), 3);

        // The loaded index must answer queries identically.
        let orig_results = index.search(&[0.5, 0.5, 0.3], 3, TemporalFilter::All);
        let loaded_results = loaded.search(&[0.5, 0.5, 0.3], 3, TemporalFilter::All);
        assert_eq!(orig_results, loaded_results);

        for entity in 0..3u64 {
            assert_eq!(
                index.trajectory(entity, TemporalFilter::All),
                loaded.trajectory(entity, TemporalFilter::All)
            );
        }
    }

    #[test]
    fn load_rejects_garbage() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("garbage.bin");
        std::fs::write(&path, [0xFFu8; 3]).unwrap();

        let err = AnchorSpaceIndex::load(&path)
            .err()
            .expect("garbage bytes must not load");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    #[should_panic(expected = "projected vector dim")]
    fn insert_wrong_dim_panics() {
        let mut index = AnchorSpaceIndex::new(test_config(), 3);
        index.insert_projected(1, 1000, vec![0.1, 0.2], 0); // 2 coordinates, k = 3
    }

    #[test]
    fn trajectory_unknown_entity() {
        let index = AnchorSpaceIndex::new(test_config(), 2);
        assert!(index.trajectory(999, TemporalFilter::All).is_empty());
    }

    #[test]
    fn search_empty_index() {
        let index = AnchorSpaceIndex::new(test_config(), 3);
        let results = index.search(&[0.0, 0.0, 0.0], 5, TemporalFilter::All);
        assert!(results.is_empty());
    }
}
596}