cvx_index/hnsw/
metadata_store.rs

1//! In-memory metadata store for HNSW nodes.
2//!
3//! Stores `HashMap<String, String>` per node_id, enabling metadata filtering
4//! on search results without modifying the HNSW graph structure.
5
6use roaring::RoaringBitmap;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use cvx_core::types::metadata_filter::{MetadataFilter, MetadataPredicate};
11
12/// Dense metadata store with inverted index for O(1) pre-filtering.
13///
14/// Two data structures:
15/// - `entries`: node_id → metadata map (for retrieval)
16/// - `inverted`: key → value → RoaringBitmap of node_ids (for filtering)
17///
18/// The inverted index supports exact-match pre-filtering during HNSW
19/// traversal, replacing the O(4k) post-filter with O(1) bitmap lookups.
20#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct MetadataStore {
22    /// node_id → metadata. Empty HashMap for nodes without metadata.
23    entries: Vec<HashMap<String, String>>,
24    /// Inverted index: key → value → bitmap of matching node_ids.
25    /// Only populated for exact string values (not numeric ranges).
26    #[serde(default)]
27    inverted: HashMap<String, HashMap<String, RoaringBitmap>>,
28}
29
30impl MetadataStore {
31    /// Create an empty store.
32    pub fn new() -> Self {
33        Self {
34            entries: Vec::new(),
35            inverted: HashMap::new(),
36        }
37    }
38
39    /// Register metadata for a new node (must be called in order).
40    pub fn push(&mut self, metadata: HashMap<String, String>) {
41        let node_id = self.entries.len() as u32;
42        // Update inverted index
43        for (key, value) in &metadata {
44            self.inverted
45                .entry(key.clone())
46                .or_default()
47                .entry(value.clone())
48                .or_default()
49                .insert(node_id);
50        }
51        self.entries.push(metadata);
52    }
53
54    /// Register an empty metadata entry.
55    pub fn push_empty(&mut self) {
56        self.entries.push(HashMap::new());
57    }
58
59    /// Get metadata for a node.
60    pub fn get(&self, node_id: u32) -> &HashMap<String, String> {
61        static EMPTY: std::sync::LazyLock<HashMap<String, String>> =
62            std::sync::LazyLock::new(HashMap::new);
63        self.entries.get(node_id as usize).unwrap_or(&EMPTY)
64    }
65
66    /// Check if a node passes a metadata filter.
67    pub fn matches(&self, node_id: u32, filter: &MetadataFilter) -> bool {
68        if filter.is_empty() {
69            return true;
70        }
71        filter.matches(self.get(node_id))
72    }
73
74    /// Build a RoaringBitmap of node_ids matching the metadata filter.
75    ///
76    /// For `Equals` predicates: uses the inverted index for O(1) lookup.
77    /// For other predicates (Gte, Lte, Contains, Exists): falls back to
78    /// scanning entries.
79    ///
80    /// Multiple predicates are AND-combined (intersection).
81    pub fn build_filter_bitmap(&self, filter: &MetadataFilter) -> RoaringBitmap {
82        if filter.is_empty() {
83            // No filter → all nodes match
84            let mut all = RoaringBitmap::new();
85            for i in 0..self.entries.len() as u32 {
86                all.insert(i);
87            }
88            return all;
89        }
90
91        let mut result: Option<RoaringBitmap> = None;
92
93        for (field, predicate) in &filter.predicates {
94            let bitmap = match predicate {
95                MetadataPredicate::Equals(value) => {
96                    // Fast path: use inverted index
97                    self.inverted
98                        .get(field)
99                        .and_then(|values| values.get(value))
100                        .cloned()
101                        .unwrap_or_default()
102                }
103                _ => {
104                    // Slow path: scan entries
105                    let mut bm = RoaringBitmap::new();
106                    for (i, entry) in self.entries.iter().enumerate() {
107                        if predicate.matches(entry.get(field)) {
108                            bm.insert(i as u32);
109                        }
110                    }
111                    bm
112                }
113            };
114
115            result = Some(match result {
116                Some(existing) => existing & bitmap, // AND intersection
117                None => bitmap,
118            });
119        }
120
121        result.unwrap_or_default()
122    }
123
124    /// Filter a list of (node_id, score) results by metadata.
125    pub fn filter_results(
126        &self,
127        results: &[(u32, f32)],
128        filter: &MetadataFilter,
129    ) -> Vec<(u32, f32)> {
130        if filter.is_empty() {
131            return results.to_vec();
132        }
133        results
134            .iter()
135            .filter(|(nid, _)| self.matches(*nid, filter))
136            .copied()
137            .collect()
138    }
139
140    /// Number of entries.
141    pub fn len(&self) -> usize {
142        self.entries.len()
143    }
144
145    /// Whether empty.
146    pub fn is_empty(&self) -> bool {
147        self.entries.is_empty()
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned metadata map from borrowed `(key, value)` pairs.
    fn meta(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn push_and_get() {
        let mut store = MetadataStore::new();
        store.push(meta(&[("reward", "0.8"), ("step_index", "0")]));
        store.push_empty();

        assert_eq!(store.get(0).get("reward").map(String::as_str), Some("0.8"));
        assert!(store.get(1).is_empty());
        // Ids past the end resolve to the shared empty map.
        assert!(store.get(999).is_empty());
    }

    #[test]
    fn filter_results_by_metadata() {
        let mut store = MetadataStore::new();
        for i in 0..5u32 {
            let reward = format!("{}", f64::from(i) * 0.2);
            let step = i.to_string();
            store.push(meta(&[("reward", reward.as_str()), ("step_index", step.as_str())]));
        }
        let results: Vec<(u32, f32)> = (0..5u32).map(|i| (i, i as f32 * 0.1)).collect();

        // reward >= 0.5 keeps node 3 (~0.6) and node 4 (0.8), in input order.
        let filter = MetadataFilter::new().gte("reward", 0.5);
        let kept: Vec<u32> = store
            .filter_results(&results, &filter)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(kept, [3, 4]);
    }

    #[test]
    fn empty_filter_passes_all() {
        let mut store = MetadataStore::new();
        (0..2).for_each(|_| store.push_empty());

        let results = [(0u32, 0.1f32), (1, 0.2)];
        let passed = store.filter_results(&results, &MetadataFilter::new());
        assert_eq!(passed.len(), 2);
    }

    // ─── Inverted index tests ────────────────────────────────────

    #[test]
    fn inverted_index_built_on_push() {
        let mut store = MetadataStore::new();
        store.push(meta(&[("goal", "clean"), ("room", "kitchen")]));
        store.push(meta(&[("goal", "clean"), ("room", "bedroom")]));
        store.push(meta(&[("goal", "find")]));

        let goals = &store.inverted["goal"];

        let clean = &goals["clean"];
        assert!(clean.contains(0));
        assert!(clean.contains(1));
        assert!(!clean.contains(2));

        let find = &goals["find"];
        assert!(find.contains(2));
        assert_eq!(find.len(), 1);
    }

    #[test]
    fn build_filter_bitmap_equals_uses_inverted() {
        let mut store = MetadataStore::new();
        for i in 0..100u32 {
            let goal = if i % 3 == 0 { "clean" } else { "find" };
            store.push(meta(&[("goal", goal)]));
        }

        let bitmap = store.build_filter_bitmap(&MetadataFilter::new().equals("goal", "clean"));
        // Multiples of 3 in 0..100: 0, 3, ..., 99 -> 34 ids.
        assert_eq!(bitmap.len(), 34);
        assert!(bitmap.iter().all(|id| id % 3 == 0));
    }

    #[test]
    fn build_filter_bitmap_gte_scans() {
        let mut store = MetadataStore::new();
        for i in 0..10u32 {
            let reward = format!("{}", f64::from(i) * 0.1);
            store.push(meta(&[("reward", reward.as_str())]));
        }

        let bitmap = store.build_filter_bitmap(&MetadataFilter::new().gte("reward", 0.5));
        // reward >= 0.5 holds for nodes 5..=9.
        assert_eq!(bitmap.len(), 5);
        assert!(bitmap.iter().all(|id| id >= 5));
    }

    #[test]
    fn build_filter_bitmap_combined_and() {
        let mut store = MetadataStore::new();
        for i in 0..20u32 {
            let goal = if i % 2 == 0 { "clean" } else { "find" };
            let reward = format!("{}", f64::from(i) * 0.05);
            store.push(meta(&[("goal", goal), ("reward", reward.as_str())]));
        }

        let filter = MetadataFilter::new()
            .equals("goal", "clean")
            .gte("reward", 0.5);
        let bitmap = store.build_filter_bitmap(&filter);

        // goal=clean -> even ids; reward >= 0.5 -> ids 10..=19.
        // Intersection: 10, 12, 14, 16, 18.
        assert_eq!(bitmap.len(), 5);
        assert!(bitmap.iter().all(|id| id >= 10 && id % 2 == 0));
    }

    #[test]
    fn build_filter_bitmap_empty_returns_all() {
        let mut store = MetadataStore::new();
        (0..10).for_each(|_| store.push_empty());

        let everything = store.build_filter_bitmap(&MetadataFilter::new());
        assert_eq!(everything.len(), 10);
    }
}