cvx_analytics/
anchor_index.rs

1//! Anchor-space invariant index (RFC-011).
2//!
3//! Indexes entities in anchor-projected space (ℝᴷ) where K = number of anchors.
4//! Vectors from DIFFERENT embedding models are directly comparable when projected
5//! through the same anchor set, enabling cross-model search and trajectory analysis.
6//!
7//! # Why flat scan?
8//!
9//! K is typically 5-20 (number of clinical/semantic anchors). For 1M points in
10//! ℝ¹⁰, a flat scan with SIMD L2 takes ~10ms — fast enough for most use cases.
11
12use std::collections::BTreeMap;
13use std::path::Path;
14
15use serde::{Deserialize, Serialize};
16
17use crate::anchor::{AnchorMetric, project_to_anchors};
18use crate::calculus::{drift_magnitude_l2, drift_report};
19use cvx_core::types::TemporalFilter;
20
21// ─── Configuration ──────────────────────────────────────────────────
22
23/// Configuration for an anchor set.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct AnchorSetConfig {
26    /// Unique identifier for this anchor set.
27    pub anchor_set_id: u32,
28    /// Human-readable name (e.g., "clinical_anchors_v1").
29    pub name: String,
30    /// Distance metric for projection.
31    pub metric: AnchorMetricSerde,
32}
33
34/// Serializable anchor metric (mirrors `AnchorMetric`).
35#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
36pub enum AnchorMetricSerde {
37    /// Cosine distance.
38    Cosine,
39    /// L2 distance.
40    L2,
41}
42
43impl From<AnchorMetricSerde> for AnchorMetric {
44    fn from(m: AnchorMetricSerde) -> Self {
45        match m {
46            AnchorMetricSerde::Cosine => AnchorMetric::Cosine,
47            AnchorMetricSerde::L2 => AnchorMetric::L2,
48        }
49    }
50}
51
52// ─── Drift report in anchor space ───────────────────────────────────
53
/// Summary of how one entity moved between two snapshots, measured in anchor space.
#[derive(Clone, Debug)]
pub struct AnchorDriftReport {
    /// Later minus earlier coordinate for each anchor: a positive value means
    /// the entity moved away from that anchor, a negative one means it approached.
    pub per_anchor_delta: Vec<f32>,
    /// L2 magnitude of the drift vector in anchor space.
    pub l2_magnitude: f32,
    /// Cosine drift between the two snapshots in anchor space.
    pub cosine_drift: f32,
    /// Position of the anchor whose coordinate changed the most in absolute value.
    pub dominant_anchor: usize,
    /// Source model that produced the snapshot used for `t1`.
    pub model_t1: u32,
    /// Source model that produced the snapshot used for `t2`.
    pub model_t2: u32,
}
70
71// ─── AnchorSpaceIndex ───────────────────────────────────────────────
72
73/// An index operating in anchor-projected space (ℝᴷ).
74///
75/// Stores pre-projected vectors from potentially multiple embedding models.
76/// All comparable because they use the same anchor set.
77pub struct AnchorSpaceIndex {
78    /// Anchor set configuration.
79    config: AnchorSetConfig,
80    /// Number of anchors (= dimensionality of projected space).
81    k: usize,
82    /// Projected vectors: node_id → Vec<f32> of length K.
83    projected_vectors: Vec<Vec<f32>>,
84    /// Source model identifier per node.
85    source_model: Vec<u32>,
86    /// Entity ID per node.
87    entity_ids: Vec<u64>,
88    /// Timestamp per node.
89    timestamps: Vec<i64>,
90    /// Entity index: entity_id → sorted vec of (timestamp, node_id).
91    entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
92}
93
94impl AnchorSpaceIndex {
95    /// Create a new empty anchor space index.
96    pub fn new(config: AnchorSetConfig, k: usize) -> Self {
97        Self {
98            config,
99            k,
100            projected_vectors: Vec::new(),
101            source_model: Vec::new(),
102            entity_ids: Vec::new(),
103            timestamps: Vec::new(),
104            entity_index: BTreeMap::new(),
105        }
106    }
107
108    /// Insert a raw vector, projecting it to anchor space.
109    ///
110    /// `model_anchors` are the anchor vectors embedded in the SAME model
111    /// as the input vector. These may differ from anchors of other models.
112    pub fn insert(
113        &mut self,
114        entity_id: u64,
115        timestamp: i64,
116        vector: &[f32],
117        model_anchors: &[&[f32]],
118        model_id: u32,
119    ) -> u32 {
120        // Project single point via project_to_anchors
121        let traj = [(timestamp, vector)];
122        let projected = project_to_anchors(&traj, model_anchors, self.config.metric.into());
123
124        let proj_vec = projected.into_iter().next().unwrap().1;
125        self.insert_projected(entity_id, timestamp, proj_vec, model_id)
126    }
127
128    /// Insert a pre-projected vector (already in ℝᴷ).
129    pub fn insert_projected(
130        &mut self,
131        entity_id: u64,
132        timestamp: i64,
133        projected: Vec<f32>,
134        model_id: u32,
135    ) -> u32 {
136        assert_eq!(
137            projected.len(),
138            self.k,
139            "projected vector dim {} != anchor count {}",
140            projected.len(),
141            self.k
142        );
143
144        let node_id = self.projected_vectors.len() as u32;
145        self.projected_vectors.push(projected);
146        self.source_model.push(model_id);
147        self.entity_ids.push(entity_id);
148        self.timestamps.push(timestamp);
149
150        self.entity_index
151            .entry(entity_id)
152            .or_default()
153            .push((timestamp, node_id));
154
155        node_id
156    }
157
158    /// Search in anchor space: flat scan kNN by L2 distance in ℝᴷ.
159    ///
160    /// Cross-model: results may come from ANY source model.
161    pub fn search(
162        &self,
163        query_projected: &[f32],
164        k: usize,
165        filter: TemporalFilter,
166    ) -> Vec<(u32, f32)> {
167        let mut results: Vec<(u32, f32)> = self
168            .projected_vectors
169            .iter()
170            .enumerate()
171            .filter(|(i, _)| filter.matches(self.timestamps[*i]))
172            .map(|(i, v)| (i as u32, drift_magnitude_l2(query_projected, v)))
173            .collect();
174
175        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
176        results.truncate(k);
177        results
178    }
179
180    /// Retrieve trajectory in anchor space for an entity.
181    pub fn trajectory(&self, entity_id: u64, filter: TemporalFilter) -> Vec<(i64, Vec<f32>)> {
182        let Some(entries) = self.entity_index.get(&entity_id) else {
183            return Vec::new();
184        };
185
186        let mut result: Vec<(i64, Vec<f32>)> = entries
187            .iter()
188            .filter(|(ts, _)| filter.matches(*ts))
189            .map(|&(ts, nid)| (ts, self.projected_vectors[nid as usize].clone()))
190            .collect();
191
192        result.sort_by_key(|&(ts, _)| ts);
193        result
194    }
195
196    /// Cross-model trajectory: separate trajectories per source model.
197    pub fn cross_model_trajectory(
198        &self,
199        entity_id: u64,
200        filter: TemporalFilter,
201    ) -> BTreeMap<u32, Vec<(i64, Vec<f32>)>> {
202        let Some(entries) = self.entity_index.get(&entity_id) else {
203            return BTreeMap::new();
204        };
205
206        let mut by_model: BTreeMap<u32, Vec<(i64, Vec<f32>)>> = BTreeMap::new();
207
208        for &(ts, nid) in entries {
209            if !filter.matches(ts) {
210                continue;
211            }
212            let model = self.source_model[nid as usize];
213            by_model
214                .entry(model)
215                .or_default()
216                .push((ts, self.projected_vectors[nid as usize].clone()));
217        }
218
219        // Sort each model's trajectory by timestamp
220        for traj in by_model.values_mut() {
221            traj.sort_by_key(|&(ts, _)| ts);
222        }
223
224        by_model
225    }
226
227    /// Compute drift in anchor space between two timestamps.
228    pub fn anchor_drift(&self, entity_id: u64, t1: i64, t2: i64) -> Option<AnchorDriftReport> {
229        let entries = self.entity_index.get(&entity_id)?;
230
231        // Find nearest point to t1 and t2
232        let (_, nid1) = entries
233            .iter()
234            .min_by_key(|&&(ts, _)| (ts - t1).unsigned_abs())?;
235        let (_, nid2) = entries
236            .iter()
237            .min_by_key(|&&(ts, _)| (ts - t2).unsigned_abs())?;
238
239        let v1 = &self.projected_vectors[*nid1 as usize];
240        let v2 = &self.projected_vectors[*nid2 as usize];
241
242        let per_anchor_delta: Vec<f32> = v2.iter().zip(v1.iter()).map(|(a, b)| a - b).collect();
243        let report = drift_report(v1, v2, self.k);
244
245        let dominant_anchor = per_anchor_delta
246            .iter()
247            .enumerate()
248            .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap())
249            .map(|(i, _)| i)
250            .unwrap_or(0);
251
252        Some(AnchorDriftReport {
253            per_anchor_delta,
254            l2_magnitude: report.l2_magnitude,
255            cosine_drift: report.cosine_drift,
256            dominant_anchor,
257            model_t1: self.source_model[*nid1 as usize],
258            model_t2: self.source_model[*nid2 as usize],
259        })
260    }
261
262    /// Number of indexed points.
263    pub fn len(&self) -> usize {
264        self.projected_vectors.len()
265    }
266
267    /// Whether the index is empty.
268    pub fn is_empty(&self) -> bool {
269        self.projected_vectors.is_empty()
270    }
271
272    /// Number of unique entities.
273    pub fn n_entities(&self) -> usize {
274        self.entity_index.len()
275    }
276
277    /// Get entity ID for a node.
278    pub fn entity_id(&self, node_id: u32) -> u64 {
279        self.entity_ids[node_id as usize]
280    }
281
282    /// Get timestamp for a node.
283    pub fn timestamp(&self, node_id: u32) -> i64 {
284        self.timestamps[node_id as usize]
285    }
286
287    /// Get source model for a node.
288    pub fn source_model(&self, node_id: u32) -> u32 {
289        self.source_model[node_id as usize]
290    }
291
292    /// Get projected vector for a node.
293    pub fn projected_vector(&self, node_id: u32) -> &[f32] {
294        &self.projected_vectors[node_id as usize]
295    }
296
297    /// Anchor set config.
298    pub fn config(&self) -> &AnchorSetConfig {
299        &self.config
300    }
301
302    /// Dimensionality of the projected space (= number of anchors).
303    pub fn k(&self) -> usize {
304        self.k
305    }
306
307    /// Save to file via postcard.
308    pub fn save(&self, path: &Path) -> std::io::Result<()> {
309        let snapshot = AnchorSpaceSnapshot {
310            config: self.config.clone(),
311            k: self.k,
312            projected_vectors: self.projected_vectors.clone(),
313            source_model: self.source_model.clone(),
314            entity_ids: self.entity_ids.clone(),
315            timestamps: self.timestamps.clone(),
316            entity_index: self.entity_index.clone(),
317        };
318        let bytes = postcard::to_allocvec(&snapshot).map_err(std::io::Error::other)?;
319        std::fs::write(path, bytes)
320    }
321
322    /// Load from file.
323    pub fn load(path: &Path) -> std::io::Result<Self> {
324        let bytes = std::fs::read(path)?;
325        let snapshot: AnchorSpaceSnapshot = postcard::from_bytes(&bytes)
326            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
327        Ok(Self {
328            config: snapshot.config,
329            k: snapshot.k,
330            projected_vectors: snapshot.projected_vectors,
331            source_model: snapshot.source_model,
332            entity_ids: snapshot.entity_ids,
333            timestamps: snapshot.timestamps,
334            entity_index: snapshot.entity_index,
335        })
336    }
337}
338
339#[derive(Serialize, Deserialize)]
340struct AnchorSpaceSnapshot {
341    config: AnchorSetConfig,
342    k: usize,
343    projected_vectors: Vec<Vec<f32>>,
344    source_model: Vec<u32>,
345    entity_ids: Vec<u64>,
346    timestamps: Vec<i64>,
347    entity_index: BTreeMap<u64, Vec<(i64, u32)>>,
348}
349
350// ─── Tests ──────────────────────────────────────────────────────────
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine-metric anchor set configuration shared by every test.
    fn test_config() -> AnchorSetConfig {
        let name = String::from("test_anchors");
        AnchorSetConfig {
            anchor_set_id: 1,
            name,
            metric: AnchorMetricSerde::Cosine,
        }
    }

    // ─── Basic operations ───────────────────────────────────────

    #[test]
    fn new_empty() {
        let fresh = AnchorSpaceIndex::new(test_config(), 3);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert_eq!(fresh.k(), 3);
    }

    #[test]
    fn insert_projected() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 3);
        let coords = vec![0.1, 0.5, 0.3];
        let node = idx.insert_projected(42, 1000, coords, 0);

        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entity_id(node), 42);
        assert_eq!(idx.timestamp(node), 1000);
        assert_eq!(idx.source_model(node), 0);
        assert_eq!(idx.projected_vector(node), &[0.1, 0.5, 0.3]);
    }

    #[test]
    fn insert_with_projection() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 2);

        let vector = [1.0f32, 0.0, 0.0];
        let anchor_x = [1.0f32, 0.0, 0.0];
        let anchor_y = [0.0f32, 1.0, 0.0];
        let anchor_refs: [&[f32]; 2] = [&anchor_x, &anchor_y];

        let node = idx.insert(42, 1000, &vector, &anchor_refs, 0);
        assert_eq!(idx.len(), 1);

        let proj = idx.projected_vector(node);
        assert_eq!(proj.len(), 2);
        // The vector coincides with anchor 0 (cosine distance 0) and is
        // orthogonal to anchor 1 (cosine distance 1).
        assert!(
            proj[0] < 0.01,
            "should be close to anchor 0, got {}",
            proj[0]
        );
        assert!(
            (proj[1] - 1.0).abs() < 0.01,
            "should be far from anchor 1, got {}",
            proj[1]
        );
    }

    // ─── Search ─────────────────────────────────────────────────

    #[test]
    fn search_finds_nearest() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 3);

        // Five points spread along the first anchor axis.
        (0..5u32).for_each(|i| {
            idx.insert_projected(i as u64, i as i64 * 1000, vec![i as f32, 0.0, 0.0], 0);
        });

        // The query sits just past point 2.
        let hits = idx.search(&[2.1, 0.0, 0.0], 3, TemporalFilter::All);
        assert_eq!(hits.len(), 3);
        let (closest, _) = hits[0];
        assert_eq!(closest, 2);
    }

    #[test]
    fn search_with_temporal_filter() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 2);

        (0..10u32).for_each(|i| {
            idx.insert_projected(i as u64, i as i64 * 1000, vec![i as f32, 0.0], 0);
        });

        let hits = idx.search(&[5.0, 0.0], 10, TemporalFilter::Range(3000, 7000));
        for ts in hits.iter().map(|&(nid, _)| idx.timestamp(nid)) {
            assert!((3000..=7000).contains(&ts), "ts {ts} outside range");
        }
    }

    // ─── Cross-model search ─────────────────────────────────────

    #[test]
    fn cross_model_search() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 2);

        // Entities 1 (model 0) and 2 (model 1) are close together;
        // entity 3 (model 0) is far away.
        idx.insert_projected(1, 1000, vec![0.1, 0.9], 0);
        idx.insert_projected(2, 1000, vec![0.1, 0.8], 1);
        idx.insert_projected(3, 1000, vec![5.0, 5.0], 0);

        let hits = idx.search(&[0.1, 0.85], 2, TemporalFilter::All);
        assert_eq!(hits.len(), 2);

        // The two nearest points come from different models.
        let seen = |model: u32| hits.iter().any(|&(nid, _)| idx.source_model(nid) == model);
        assert!(
            seen(0) && seen(1),
            "search should return results from both models"
        );
    }

    // ─── Trajectory ─────────────────────────────────────────────

    #[test]
    fn trajectory_in_anchor_space() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 2);

        for step in 0..5u64 {
            idx.insert_projected(42, step as i64 * 1000, vec![step as f32 * 0.1, 0.5], 0);
        }

        let traj = idx.trajectory(42, TemporalFilter::All);
        assert_eq!(traj.len(), 5);
        // Timestamps never decrease along the trajectory.
        assert!(traj.windows(2).all(|pair| pair[0].0 <= pair[1].0));
    }

    #[test]
    fn cross_model_trajectory() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 2);

        // One entity observed at two timestamps by each of two models.
        let samples = [
            (1000, vec![0.1, 0.9], 0),
            (2000, vec![0.2, 0.8], 0),
            (1000, vec![0.15, 0.85], 1),
            (2000, vec![0.25, 0.75], 1),
        ];
        for (ts, coords, model) in samples {
            idx.insert_projected(42, ts, coords, model);
        }

        let by_model = idx.cross_model_trajectory(42, TemporalFilter::All);
        assert_eq!(by_model.len(), 2);
        assert_eq!(by_model[&0].len(), 2);
        assert_eq!(by_model[&1].len(), 2);
    }

    // ─── Anchor drift ───────────────────────────────────────────

    #[test]
    fn anchor_drift_approaching() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 3);

        // Only the distance to anchor 0 shrinks between the two snapshots.
        idx.insert_projected(1, 1000, vec![1.0, 0.5, 0.5], 0);
        idx.insert_projected(1, 2000, vec![0.5, 0.5, 0.5], 0);

        let drift = idx.anchor_drift(1, 1000, 2000).unwrap();
        assert!(
            drift.per_anchor_delta[0] < 0.0,
            "should be approaching anchor 0"
        );
        assert_eq!(drift.dominant_anchor, 0);
        assert!(drift.l2_magnitude > 0.0);
    }

    #[test]
    fn anchor_drift_cross_model() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 2);

        // The earlier snapshot comes from model 0, the later one from model 1.
        idx.insert_projected(1, 1000, vec![0.8, 0.2], 0);
        idx.insert_projected(1, 2000, vec![0.3, 0.7], 1);

        let drift = idx.anchor_drift(1, 1000, 2000).unwrap();
        assert_eq!(drift.model_t1, 0);
        assert_eq!(drift.model_t2, 1);
        assert!(drift.l2_magnitude > 0.0);
    }

    #[test]
    fn anchor_drift_unknown_entity() {
        let idx = AnchorSpaceIndex::new(test_config(), 2);
        assert!(idx.anchor_drift(999, 0, 1000).is_none());
    }

    // ─── Persistence ────────────────────────────────────────────

    #[test]
    fn save_load_roundtrip() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 3);

        // Ten points over three entities and two models.
        for i in 0..10u32 {
            let coords = vec![i as f32 * 0.1, 0.5, 0.3];
            idx.insert_projected(i as u64 % 3, i as i64 * 1000, coords, i % 2);
        }

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("anchor_index.bin");
        idx.save(&path).unwrap();

        let restored = AnchorSpaceIndex::load(&path).unwrap();
        assert_eq!(restored.len(), 10);
        assert_eq!(restored.k(), 3);
        assert_eq!(restored.n_entities(), 3);

        // Both copies rank the same nodes in the same order.
        let query = [0.5, 0.5, 0.3];
        let node_ids = |hits: &[(u32, f32)]| hits.iter().map(|hit| hit.0).collect::<Vec<u32>>();
        let before = idx.search(&query, 3, TemporalFilter::All);
        let after = restored.search(&query, 3, TemporalFilter::All);
        assert_eq!(before.len(), after.len());
        assert_eq!(node_ids(&before), node_ids(&after));
    }

    // ─── Edge cases ─────────────────────────────────────────────

    #[test]
    #[should_panic(expected = "projected vector dim")]
    fn insert_wrong_dim_panics() {
        let mut idx = AnchorSpaceIndex::new(test_config(), 3);
        // Two coordinates supplied where three are required.
        idx.insert_projected(1, 1000, vec![0.1, 0.2], 0);
    }

    #[test]
    fn trajectory_unknown_entity() {
        let idx = AnchorSpaceIndex::new(test_config(), 2);
        let traj = idx.trajectory(999, TemporalFilter::All);
        assert!(traj.is_empty());
    }

    #[test]
    fn search_empty_index() {
        let idx = AnchorSpaceIndex::new(test_config(), 3);
        let hits = idx.search(&[0.0, 0.0, 0.0], 5, TemporalFilter::All);
        assert!(hits.is_empty());
    }
}