cvx_storage/cold/
mod.rs

1//! Cold storage tier with Product Quantization (PQ) compression.
2//!
3//! Vectors are encoded into compact PQ codes for massive storage reduction.
4//! A codebook of centroids is trained via k-means on representative data,
5//! then each vector is encoded as a sequence of centroid indices.
6//!
7//! ## Compression
8//!
9//! With M=8 subspaces and K=256 centroids per subspace:
10//! - Original D=768 vector: 768 × 4 bytes = 3,072 bytes
11//! - PQ code: 8 × 1 byte = 8 bytes
12//! - **Compression ratio: 384×**
13//!
14//! ## Asymmetric Distance Computation (ADC)
15//!
16//! Query-to-code distance is computed without decoding:
17//! precompute query-to-centroid distances, then sum lookup table entries.
18
/// Product Quantization codebook.
///
/// Holds `m * k` centroids (one set of `k` per subspace), trained once and
/// then shared by encode/decode and ADC distance computations.
#[derive(Debug, Clone)]
pub struct PqCodebook {
    /// Number of subspaces. Each vector is split into `m` contiguous chunks.
    pub m: usize,
    /// Number of centroids per subspace.
    /// NOTE: `encode` emits one `u8` per subspace, so values above 256
    /// cannot be represented as codes.
    pub k: usize,
    /// Original vector dimensionality. Must be divisible by `m`.
    pub dim: usize,
    /// Centroids: `[subspace][centroid][sub_dim]`.
    /// Flattened: length = m * k * (dim / m).
    pub centroids: Vec<f32>,
}
32
33impl PqCodebook {
34    /// Train a codebook from a set of vectors using k-means.
35    ///
36    /// - `vectors`: training data (each of length `dim`)
37    /// - `m`: number of subspaces
38    /// - `k`: centroids per subspace
39    /// - `iterations`: k-means iterations
40    pub fn train(vectors: &[&[f32]], m: usize, k: usize, iterations: usize) -> Self {
41        assert!(!vectors.is_empty(), "need training data");
42        let dim = vectors[0].len();
43        assert!(dim % m == 0, "dim must be divisible by m");
44        let sub_dim = dim / m;
45
46        let mut centroids = vec![0.0f32; m * k * sub_dim];
47
48        for sub in 0..m {
49            let offset = sub * sub_dim;
50
51            // k-means++ initialization (Arthur & Vassilvitskii, SODA 2007)
52            // Sample centroids proportional to D²(x) for O(log k) approximation.
53            // See RFC-002-09.
54            {
55                // First centroid: pick the first vector's subvector
56                let src = vectors[0];
57                for d in 0..sub_dim {
58                    centroids[sub * k * sub_dim + d] = src[offset + d];
59                }
60                let mut rng_state: u64 = 42 + sub as u64;
61
62                for c in 1..k {
63                    // Compute D²(x): min squared distance to any existing centroid
64                    let weights: Vec<f64> = vectors
65                        .iter()
66                        .map(|v| {
67                            let sub_vec = &v[offset..offset + sub_dim];
68                            (0..c)
69                                .map(|ci| {
70                                    let base = sub * k * sub_dim + ci * sub_dim;
71                                    (0..sub_dim)
72                                        .map(|d| {
73                                            let diff = sub_vec[d] - centroids[base + d];
74                                            (diff * diff) as f64
75                                        })
76                                        .sum::<f64>()
77                                })
78                                .fold(f64::INFINITY, f64::min)
79                        })
80                        .collect();
81
82                    // Cumulative sum for weighted sampling
83                    let total: f64 = weights.iter().sum();
84                    if total <= 0.0 {
85                        // All points coincide with existing centroids; just cycle
86                        let src = vectors[c % vectors.len()];
87                        for d in 0..sub_dim {
88                            centroids[sub * k * sub_dim + c * sub_dim + d] = src[offset + d];
89                        }
90                        continue;
91                    }
92
93                    // Simple LCG for deterministic sampling
94                    rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
95                    let threshold = ((rng_state >> 33) as f64 / u32::MAX as f64) * total;
96
97                    let mut cumulative = 0.0;
98                    let mut selected = vectors.len() - 1;
99                    for (i, w) in weights.iter().enumerate() {
100                        cumulative += w;
101                        if cumulative >= threshold {
102                            selected = i;
103                            break;
104                        }
105                    }
106
107                    let src = vectors[selected];
108                    for d in 0..sub_dim {
109                        centroids[sub * k * sub_dim + c * sub_dim + d] = src[offset + d];
110                    }
111                }
112            }
113
114            // K-means iterations
115            for _ in 0..iterations {
116                let mut sums = vec![0.0f64; k * sub_dim];
117                let mut counts = vec![0usize; k];
118
119                // Assign
120                for &v in vectors {
121                    let sub_vec = &v[offset..offset + sub_dim];
122                    let closest = find_closest_centroid(sub_vec, &centroids, sub, k, sub_dim);
123                    counts[closest] += 1;
124                    for d in 0..sub_dim {
125                        sums[closest * sub_dim + d] += sub_vec[d] as f64;
126                    }
127                }
128
129                // Update
130                for c in 0..k {
131                    if counts[c] > 0 {
132                        for d in 0..sub_dim {
133                            centroids[sub * k * sub_dim + c * sub_dim + d] =
134                                (sums[c * sub_dim + d] / counts[c] as f64) as f32;
135                        }
136                    }
137                }
138            }
139        }
140
141        PqCodebook {
142            m,
143            k,
144            dim,
145            centroids,
146        }
147    }
148
149    /// Subdimension size.
150    pub fn sub_dim(&self) -> usize {
151        self.dim / self.m
152    }
153
154    /// Encode a vector into PQ codes.
155    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
156        assert_eq!(vector.len(), self.dim);
157        let sub_dim = self.sub_dim();
158        let mut codes = Vec::with_capacity(self.m);
159
160        for sub in 0..self.m {
161            let offset = sub * sub_dim;
162            let sub_vec = &vector[offset..offset + sub_dim];
163            let closest = find_closest_centroid(sub_vec, &self.centroids, sub, self.k, sub_dim);
164            codes.push(closest as u8);
165        }
166
167        codes
168    }
169
170    /// Decode PQ codes back to an approximate vector.
171    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
172        assert_eq!(codes.len(), self.m);
173        let sub_dim = self.sub_dim();
174        let mut vector = Vec::with_capacity(self.dim);
175
176        for (sub, &code) in codes.iter().enumerate() {
177            let base = sub * self.k * sub_dim + (code as usize) * sub_dim;
178            vector.extend_from_slice(&self.centroids[base..base + sub_dim]);
179        }
180
181        vector
182    }
183
184    /// Build asymmetric distance table for a query vector.
185    ///
186    /// Returns a table of size `[m][k]` where `table[sub][code]` is the
187    /// squared distance from the query subvector to that centroid.
188    pub fn build_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
189        assert_eq!(query.len(), self.dim);
190        let sub_dim = self.sub_dim();
191
192        (0..self.m)
193            .map(|sub| {
194                let q_offset = sub * sub_dim;
195                (0..self.k)
196                    .map(|c| {
197                        let c_base = sub * self.k * sub_dim + c * sub_dim;
198                        (0..sub_dim)
199                            .map(|d| {
200                                let diff = query[q_offset + d] - self.centroids[c_base + d];
201                                diff * diff
202                            })
203                            .sum()
204                    })
205                    .collect()
206            })
207            .collect()
208    }
209
210    /// Compute asymmetric distance from query to a PQ code using precomputed table.
211    pub fn asymmetric_distance(table: &[Vec<f32>], codes: &[u8]) -> f32 {
212        codes
213            .iter()
214            .enumerate()
215            .map(|(sub, &code)| table[sub][code as usize])
216            .sum()
217    }
218}
219
/// Index of the centroid within subspace `sub` that is nearest to `sub_vec`
/// under squared Euclidean distance. Ties resolve to the lower index;
/// returns 0 when `k == 0`.
fn find_closest_centroid(
    sub_vec: &[f32],
    centroids: &[f32],
    sub: usize,
    k: usize,
    sub_dim: usize,
) -> usize {
    // Start of this subspace's centroid block in the flattened layout.
    let subspace = sub * k * sub_dim;

    (0..k)
        .map(|c| {
            let start = subspace + c * sub_dim;
            let centroid = &centroids[start..start + sub_dim];
            let dist: f32 = sub_vec[..sub_dim]
                .iter()
                .zip(centroid)
                .map(|(x, y)| (x - y) * (x - y))
                .sum();
            (c, dist)
        })
        // `min_by` keeps the first of equal minima, matching the strict `<`
        // comparison of the original loop.
        .min_by(|(_, a), (_, b)| a.total_cmp(b))
        .map_or(0, |(c, _)| c)
}
246
/// Cold store using PQ-encoded vectors.
///
/// Vectors are PQ-encoded on insert and decoded on read; lookups and
/// searches are linear scans over the entry list.
pub struct ColdStore {
    /// Codebook used to encode/decode every stored vector.
    codebook: PqCodebook,
    /// Encoded vectors, one `ColdEntry` per stored vector
    /// (entity_id, space_id, timestamp, pq_codes).
    entries: Vec<ColdEntry>,
}
253
/// A single entry in cold storage.
#[derive(Debug, Clone)]
struct ColdEntry {
    /// Identifier of the entity this vector belongs to.
    entity_id: u64,
    /// Embedding-space identifier.
    space_id: u32,
    /// Timestamp of the snapshot (units/epoch defined by the caller).
    timestamp: i64,
    /// PQ codes: one byte per subspace.
    codes: Vec<u8>,
}
262
263impl ColdStore {
264    /// Create a cold store with a trained codebook.
265    pub fn new(codebook: PqCodebook) -> Self {
266        Self {
267            codebook,
268            entries: Vec::new(),
269        }
270    }
271
272    /// Store a vector (encodes it with PQ).
273    pub fn put(&mut self, entity_id: u64, space_id: u32, timestamp: i64, vector: &[f32]) {
274        let codes = self.codebook.encode(vector);
275        self.entries.push(ColdEntry {
276            entity_id,
277            space_id,
278            timestamp,
279            codes,
280        });
281    }
282
283    /// Retrieve and decode a vector.
284    pub fn get(&self, entity_id: u64, space_id: u32, timestamp: i64) -> Option<Vec<f32>> {
285        self.entries
286            .iter()
287            .find(|e| {
288                e.entity_id == entity_id && e.space_id == space_id && e.timestamp == timestamp
289            })
290            .map(|e| self.codebook.decode(&e.codes))
291    }
292
293    /// Number of stored entries.
294    pub fn len(&self) -> usize {
295        self.entries.len()
296    }
297
298    /// Whether the store is empty.
299    pub fn is_empty(&self) -> bool {
300        self.entries.is_empty()
301    }
302
303    /// Storage size in bytes (codes only, no overhead).
304    pub fn storage_bytes(&self) -> usize {
305        self.entries.iter().map(|e| e.codes.len()).sum()
306    }
307
308    /// Access the codebook.
309    pub fn codebook(&self) -> &PqCodebook {
310        &self.codebook
311    }
312
313    /// Search using asymmetric distance computation.
314    ///
315    /// Returns `(entity_id, timestamp, distance)` sorted by distance.
316    pub fn search_adc(&self, query: &[f32], k: usize) -> Vec<(u64, i64, f32)> {
317        let table = self.codebook.build_distance_table(query);
318        let mut scored: Vec<(u64, i64, f32)> = self
319            .entries
320            .iter()
321            .map(|e| {
322                let dist = PqCodebook::asymmetric_distance(&table, &e.codes);
323                (e.entity_id, e.timestamp, dist)
324            })
325            .collect();
326        scored.sort_by(|a, b| a.2.total_cmp(&b.2));
327        scored.truncate(k);
328        scored
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vectors with components in ~[-0.5, 0.5),
    /// generated by an LCG.
    ///
    /// Fix: the previous `((state >> 33) as f32) / (u32::MAX as f32)` divided
    /// a 31-bit value by 2^32 - 1, spanning only ~[0, 0.5); after `- 0.5`
    /// every component landed in [-0.5, 0.0). Scaling by 2^31 restores the
    /// symmetric range the centering implies.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut state = seed;
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                        ((state >> 33) as f64 / (1u64 << 31) as f64 - 0.5) as f32
                    })
                    .collect()
            })
            .collect()
    }

    #[test]
    fn train_and_encode_decode() {
        let vectors = random_vectors(100, 32, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, 4, 16, 10);

        assert_eq!(codebook.m, 4);
        assert_eq!(codebook.k, 16);
        assert_eq!(codebook.dim, 32);
        assert_eq!(codebook.sub_dim(), 8);

        // Encode and decode
        let codes = codebook.encode(&vectors[0]);
        assert_eq!(codes.len(), 4);

        let decoded = codebook.decode(&codes);
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn decode_approximates_original() {
        let vectors = random_vectors(500, 64, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, 8, 256, 20);

        // Measure reconstruction error over the training set.
        let mut total_error = 0.0f64;
        for v in &vectors {
            let codes = codebook.encode(v);
            let decoded = codebook.decode(&codes);
            let error: f64 = v
                .iter()
                .zip(decoded.iter())
                .map(|(a, b)| ((*a - *b) as f64).powi(2))
                .sum();
            total_error += error;
        }
        let avg_error = total_error / vectors.len() as f64;

        // PQ should have reasonable reconstruction error
        assert!(
            avg_error < 10.0,
            "avg reconstruction error too high: {avg_error:.4}"
        );
    }

    #[test]
    fn compression_ratio() {
        let dim = 768;
        let m = 8;
        let n = 100;
        let vectors = random_vectors(n, dim, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, m, 256, 5);

        // Actually encode all vectors and measure real sizes
        let mut total_code_bytes = 0usize;
        let mut total_reconstruction_error = 0.0f64;

        for v in &vectors {
            let codes = codebook.encode(v);
            total_code_bytes += codes.len();

            let decoded = codebook.decode(&codes);
            let error: f64 = v
                .iter()
                .zip(decoded.iter())
                .map(|(a, b)| ((*a - *b) as f64).powi(2))
                .sum();
            total_reconstruction_error += error;
        }

        let original_bytes = n * dim * 4;
        let ratio = original_bytes as f64 / total_code_bytes as f64;
        let avg_error = total_reconstruction_error / n as f64;

        assert!(
            ratio >= 300.0,
            "compression ratio = {ratio:.0}x, expected >= 300x for D={dim} M={m}"
        );

        // Verify each code is M bytes (1 byte per subspace with K=256)
        assert_eq!(total_code_bytes, n * m);

        // Reconstruction error should be bounded (PQ is lossy but usable)
        assert!(
            avg_error < 50.0,
            "avg reconstruction error = {avg_error:.2}, expected < 50 for D={dim}"
        );
    }

    #[test]
    fn cold_store_put_get() {
        let vectors = random_vectors(50, 32, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, 4, 16, 10);
        let mut store = ColdStore::new(codebook);

        store.put(1, 0, 1000, &vectors[0]);
        store.put(1, 0, 2000, &vectors[1]);

        assert_eq!(store.len(), 2);

        let decoded = store.get(1, 0, 1000).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn cold_store_get_nonexistent() {
        let vectors = random_vectors(10, 16, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, 2, 8, 5);
        let store = ColdStore::new(codebook);

        assert!(store.get(999, 0, 0).is_none());
    }

    #[test]
    fn adc_search() {
        let dim = 32;
        let vectors = random_vectors(200, dim, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, 4, 32, 10);
        let mut store = ColdStore::new(codebook);

        for (i, v) in vectors.iter().enumerate() {
            store.put(i as u64, 0, (i as i64) * 1000, v);
        }

        let results = store.search_adc(&vectors[0], 5);
        assert_eq!(results.len(), 5);

        // First result should be the query itself: its own codes minimize
        // every per-subspace table lookup, and the sort is stable.
        assert_eq!(results[0].0, 0, "closest should be the query vector itself");
    }

    #[test]
    fn storage_bytes_compact() {
        let vectors = random_vectors(1000, 768, 42);
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let codebook = PqCodebook::train(&refs, 8, 256, 5);
        let mut store = ColdStore::new(codebook);

        for (i, v) in vectors.iter().enumerate() {
            store.put(i as u64, 0, (i as i64) * 1000, v);
        }

        let original_bytes = 1000 * 768 * 4;
        let cold_bytes = store.storage_bytes();
        let ratio = original_bytes as f64 / cold_bytes as f64;

        assert!(
            ratio > 100.0,
            "cold storage ratio = {ratio:.0}x, expected > 100x"
        );
    }
}