1use std::path::Path;
15
16use cvx_core::StorageBackend;
17use cvx_core::error::StorageError;
18use cvx_core::types::{EntityTimeline, TemporalPoint};
19use rocksdb::{
20 ColumnFamilyDescriptor, DBWithThreadMode, IteratorMode, Options, SingleThreaded, SliceTransform,
21};
22
23use crate::keys;
24
// Column family holding raw vector payloads, keyed by (entity, space, timestamp).
const CF_VECTORS: &str = "vectors";
// Column family holding postcard-encoded `EntityTimeline` records, keyed by (entity, space).
const CF_TIMELINES: &str = "timelines";
// RocksDB's mandatory default column family; used for system/bookkeeping data.
const CF_SYSTEM: &str = "default";
28
/// RocksDB-backed hot-tier store for temporal vector data.
///
/// Owns a single RocksDB instance with dedicated column families for vector
/// payloads and per-entity timeline metadata (see `CF_VECTORS` / `CF_TIMELINES`).
pub struct HotStore {
    // Underlying RocksDB handle, opened in single-threaded mode.
    db: DBWithThreadMode<SingleThreaded>,
}
49
50impl HotStore {
51 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, StorageError> {
53 let mut db_opts = Options::default();
54 db_opts.create_if_missing(true);
55 db_opts.create_missing_column_families(true);
56
57 let cf_descriptors = vec![
58 ColumnFamilyDescriptor::new(CF_SYSTEM, Options::default()),
59 ColumnFamilyDescriptor::new(CF_VECTORS, Self::vectors_cf_options()),
60 ColumnFamilyDescriptor::new(CF_TIMELINES, Self::timelines_cf_options()),
61 ];
62
63 let db =
64 DBWithThreadMode::<SingleThreaded>::open_cf_descriptors(&db_opts, path, cf_descriptors)
65 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
66
67 Ok(Self { db })
68 }
69
70 fn vectors_cf_options() -> Options {
72 let mut opts = Options::default();
73 opts.set_prefix_extractor(SliceTransform::create_fixed_prefix(keys::PREFIX_SIZE));
75 opts.set_compression_type(rocksdb::DBCompressionType::None);
77 opts
78 }
79
80 fn timelines_cf_options() -> Options {
82 let mut opts = Options::default();
83 opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
84 opts
85 }
86
87 fn vectors_cf(&self) -> &rocksdb::ColumnFamily {
89 self.db.cf_handle(CF_VECTORS).expect("vectors CF missing")
90 }
91
92 fn timelines_cf(&self) -> &rocksdb::ColumnFamily {
94 self.db
95 .cf_handle(CF_TIMELINES)
96 .expect("timelines CF missing")
97 }
98
99 fn serialize_vector(vector: &[f32]) -> Vec<u8> {
101 vector.iter().flat_map(|f| f.to_le_bytes()).collect()
102 }
103
104 fn deserialize_vector(bytes: &[u8]) -> Vec<f32> {
106 bytes
107 .chunks_exact(4)
108 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
109 .collect()
110 }
111
112 fn update_timeline(
114 &self,
115 entity_id: u64,
116 space_id: u32,
117 timestamp: i64,
118 ) -> Result<(), StorageError> {
119 let tl_key = keys::encode_prefix(entity_id, space_id);
120
121 let timeline = match self
122 .db
123 .get_cf(self.timelines_cf(), tl_key)
124 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?
125 {
126 Some(bytes) => {
127 let existing: EntityTimeline = postcard::from_bytes(&bytes)
128 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
129 EntityTimeline::new(
130 entity_id,
131 space_id,
132 existing.first_seen().min(timestamp),
133 existing.last_seen().max(timestamp),
134 existing.point_count() + 1,
135 existing.keyframe_interval(),
136 )
137 }
138 None => EntityTimeline::new(entity_id, space_id, timestamp, timestamp, 1, 10),
139 };
140
141 let tl_bytes = postcard::to_allocvec(&timeline)
142 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
143 self.db
144 .put_cf(self.timelines_cf(), tl_key, tl_bytes)
145 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
146
147 Ok(())
148 }
149
150 pub fn get_timeline(
152 &self,
153 entity_id: u64,
154 space_id: u32,
155 ) -> Result<Option<EntityTimeline>, StorageError> {
156 let tl_key = keys::encode_prefix(entity_id, space_id);
157 match self
158 .db
159 .get_cf(self.timelines_cf(), tl_key)
160 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?
161 {
162 Some(bytes) => {
163 let tl: EntityTimeline = postcard::from_bytes(&bytes)
164 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
165 Ok(Some(tl))
166 }
167 None => Ok(None),
168 }
169 }
170}
171
172impl StorageBackend for HotStore {
173 fn get(
174 &self,
175 entity_id: u64,
176 space_id: u32,
177 timestamp: i64,
178 ) -> Result<Option<TemporalPoint>, StorageError> {
179 let key = keys::encode_key(entity_id, space_id, timestamp);
180 match self
181 .db
182 .get_cf(self.vectors_cf(), key)
183 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?
184 {
185 Some(bytes) => {
186 let vector = Self::deserialize_vector(&bytes);
187 Ok(Some(TemporalPoint::new(entity_id, timestamp, vector)))
188 }
189 None => Ok(None),
190 }
191 }
192
193 fn put(&self, space_id: u32, point: &TemporalPoint) -> Result<(), StorageError> {
194 let key = keys::encode_key(point.entity_id(), space_id, point.timestamp());
195 let value = Self::serialize_vector(point.vector());
196
197 self.db
198 .put_cf(self.vectors_cf(), key, value)
199 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
200
201 self.update_timeline(point.entity_id(), space_id, point.timestamp())?;
202
203 Ok(())
204 }
205
206 fn range(
207 &self,
208 entity_id: u64,
209 space_id: u32,
210 start: i64,
211 end: i64,
212 ) -> Result<Vec<TemporalPoint>, StorageError> {
213 let start_key = keys::encode_key(entity_id, space_id, start);
214 let end_key = keys::encode_key(entity_id, space_id, end);
215 let prefix = keys::encode_prefix(entity_id, space_id);
216
217 let iter = self.db.iterator_cf(
218 self.vectors_cf(),
219 IteratorMode::From(&start_key, rocksdb::Direction::Forward),
220 );
221
222 let mut results = Vec::new();
223 for item in iter {
224 let (key, value) =
225 item.map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
226
227 if key.len() < keys::PREFIX_SIZE || key[..keys::PREFIX_SIZE] != prefix[..] {
229 break;
230 }
231 if key[..] > end_key[..] {
233 break;
234 }
235
236 let (_, _, timestamp) = keys::decode_key(&key);
237 let vector = Self::deserialize_vector(&value);
238 results.push(TemporalPoint::new(entity_id, timestamp, vector));
239 }
240
241 Ok(results)
242 }
243
244 fn delete(&self, entity_id: u64, space_id: u32, timestamp: i64) -> Result<(), StorageError> {
245 let key = keys::encode_key(entity_id, space_id, timestamp);
246 self.db
247 .delete_cf(self.vectors_cf(), key)
248 .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string())))?;
249 Ok(())
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store in a fresh temporary directory. The `TempDir` guard is
    /// returned so callers keep the directory alive as long as the store.
    fn open_temp() -> (tempfile::TempDir, HotStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = HotStore::open(dir.path()).unwrap();
        (dir, store)
    }

    /// Builds a small fixed-vector point for the given entity and timestamp.
    fn point_at(entity_id: u64, timestamp: i64) -> TemporalPoint {
        TemporalPoint::new(entity_id, timestamp, vec![0.1, 0.2, 0.3])
    }

    #[test]
    fn put_and_get() {
        let (_guard, store) = open_temp();
        let point = point_at(42, 1000);
        store.put(0, &point).unwrap();
        assert_eq!(store.get(42, 0, 1000).unwrap(), Some(point));
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let (_guard, store) = open_temp();
        assert!(store.get(999, 0, 0).unwrap().is_none());
    }

    #[test]
    fn range_returns_ordered() {
        let (_guard, store) = open_temp();
        for ts in [100, 200, 300, 400, 500] {
            store.put(0, &point_at(1, ts)).unwrap();
        }

        // Inclusive bounds: 200 and 400 are both returned, in key order.
        let hits = store.range(1, 0, 200, 400).unwrap();
        let stamps: Vec<i64> = hits.iter().map(|p| p.timestamp()).collect();
        assert_eq!(stamps, vec![200, 300, 400]);
    }

    #[test]
    fn range_does_not_cross_entities() {
        let (_guard, store) = open_temp();
        store.put(0, &point_at(1, 100)).unwrap();
        store.put(0, &point_at(2, 100)).unwrap();

        let hits = store.range(1, 0, 0, i64::MAX).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id(), 1);
    }

    #[test]
    fn delete_removes_point() {
        let (_guard, store) = open_temp();
        store.put(0, &point_at(1, 1000)).unwrap();
        store.delete(1, 0, 1000).unwrap();
        assert!(store.get(1, 0, 1000).unwrap().is_none());
    }

    #[test]
    fn timeline_tracks_metadata() {
        let (_guard, store) = open_temp();
        for ts in [1000, 2000, 3000] {
            store.put(0, &point_at(42, ts)).unwrap();
        }

        let tl = store.get_timeline(42, 0).unwrap().unwrap();
        assert_eq!(tl.entity_id(), 42);
        assert_eq!(tl.first_seen(), 1000);
        assert_eq!(tl.last_seen(), 3000);
        assert_eq!(tl.point_count(), 3);
    }

    #[test]
    fn negative_timestamps_work() {
        let (_guard, store) = open_temp();
        for ts in [-5000, -1000, 0] {
            store.put(0, &point_at(1, ts)).unwrap();
        }

        // Key encoding must keep negative timestamps sorted below positive ones.
        let hits = store.range(1, 0, -3000, 0).unwrap();
        let stamps: Vec<i64> = hits.iter().map(|p| p.timestamp()).collect();
        assert_eq!(stamps, vec![-1000, 0]);
    }

    #[test]
    fn data_survives_reopen() {
        let dir = tempfile::tempdir().unwrap();

        // First session: write two points, then drop the store (closes the DB).
        {
            let store = HotStore::open(dir.path()).unwrap();
            store.put(0, &point_at(42, 1000)).unwrap();
            store.put(0, &point_at(42, 2000)).unwrap();
        }

        // Second session: everything written before is still readable.
        {
            let store = HotStore::open(dir.path()).unwrap();

            let first = store.get(42, 0, 1000).unwrap().expect("first point missing");
            assert_eq!(first.timestamp(), 1000);
            assert!(store.get(42, 0, 2000).unwrap().is_some());

            let tl = store.get_timeline(42, 0).unwrap().unwrap();
            assert_eq!(tl.point_count(), 2);
        }
    }

    #[test]
    fn d768_vectors_roundtrip() {
        let (_guard, store) = open_temp();
        let vector: Vec<f32> = (0..768).map(|i| i as f32 * 0.001).collect();
        store.put(0, &TemporalPoint::new(1, 1000, vector.clone())).unwrap();

        let retrieved = store.get(1, 0, 1000).unwrap().unwrap();
        assert_eq!(retrieved.vector(), vector.as_slice());
    }

    #[test]
    fn insert_100k_and_retrieve() {
        let (_guard, store) = open_temp();
        const DIM: usize = 8;

        // 1000 entities x 100 timestamps each.
        for i in 0..100_000u64 {
            let point = TemporalPoint::new(i / 100, (i % 100) as i64 * 1000, vec![i as f32; DIM]);
            store.put(0, &point).unwrap();
        }

        let hits = store.range(42, 0, 0, 100_000).unwrap();
        assert_eq!(hits.len(), 100);
        assert!(hits.windows(2).all(|w| w[0].timestamp() < w[1].timestamp()));

        let tl = store.get_timeline(42, 0).unwrap().unwrap();
        assert_eq!(tl.point_count(), 100);
        assert_eq!(tl.first_seen(), 0);
        assert_eq!(tl.last_seen(), 99_000);
    }
}