/// Product-quantization (PQ) codebook.
///
/// A `dim`-dimensional vector is split into `m` contiguous sub-vectors of
/// `dim / m` components each. Every sub-space owns `k` centroids, and a vector
/// is represented by the index of the nearest centroid in each sub-space
/// (one `u8` code per sub-space).
#[derive(Debug, Clone)]
pub struct PqCodebook {
    /// Number of sub-quantizers (sub-spaces).
    pub m: usize,
    /// Number of centroids per sub-quantizer.
    pub k: usize,
    /// Dimensionality of the full vectors.
    pub dim: usize,
    /// Flat centroid storage: centroid `c` of sub-quantizer `sub` occupies
    /// `sub_dim` values starting at index `(sub * k + c) * sub_dim`.
    pub centroids: Vec<f32>,
}

impl PqCodebook {
    /// Trains a codebook with k-means++ seeding followed by `iterations`
    /// rounds of Lloyd's algorithm, independently for each sub-space.
    ///
    /// Training is deterministic: the first centroid of every sub-space is the
    /// first training vector, and the remaining seeds are drawn with a fixed
    /// LCG seeded with `42 + sub`.
    ///
    /// # Panics
    ///
    /// Panics if `vectors` is empty, if `m` is zero, if `k` is outside
    /// `1..=256` (codes are stored as `u8`), if the vector dimension is not
    /// divisible by `m`, or if the training vectors differ in length.
    pub fn train(vectors: &[&[f32]], m: usize, k: usize, iterations: usize) -> Self {
        assert!(!vectors.is_empty(), "need training data");
        assert!(m > 0, "m must be greater than zero");
        assert!(
            (1..=256).contains(&k),
            "k must be in 1..=256 so that centroid indices fit in a u8 code"
        );
        let dim = vectors[0].len();
        assert!(dim % m == 0, "dim must be divisible by m");
        assert!(
            vectors.iter().all(|v| v.len() == dim),
            "all training vectors must have the same dimension"
        );
        let sub_dim = dim / m;

        let mut centroids = vec![0.0f32; m * k * sub_dim];

        for sub in 0..m {
            let offset = sub * sub_dim;
            // All centroids of this sub-quantizer, as one contiguous slice.
            let block = &mut centroids[sub * k * sub_dim..(sub + 1) * k * sub_dim];

            Self::seed_subspace(vectors, offset, sub_dim, k, 42 + sub as u64, block);
            for _ in 0..iterations {
                Self::lloyd_step(vectors, offset, sub_dim, k, block);
            }
        }

        PqCodebook {
            m,
            k,
            dim,
            centroids,
        }
    }

    /// k-means++ seeding of one sub-space: fills `block` (`k * sub_dim`
    /// values) with `k` initial centroids taken from the training data.
    fn seed_subspace(
        vectors: &[&[f32]],
        offset: usize,
        sub_dim: usize,
        k: usize,
        seed: u64,
        block: &mut [f32],
    ) {
        const LCG_MULTIPLIER: u64 = 6364136223846793005;
        // `state >> 33` keeps 31 bits, so a draw lies in [0, 2^31). Dividing by
        // 2^31 maps it onto [0, 1); dividing by `u32::MAX` would only cover
        // [0, 0.5) and make the upper half of the distribution unreachable.
        const DRAW_RANGE: f64 = 2_147_483_648.0;

        // The first centroid is always the first training vector.
        block[..sub_dim].copy_from_slice(&vectors[0][offset..offset + sub_dim]);

        let mut rng_state = seed;
        // Squared distance from each training sub-vector to its nearest
        // already-chosen centroid, refreshed with only the newest centroid.
        let mut nearest = vec![f64::INFINITY; vectors.len()];

        for c in 1..k {
            {
                let newest = &block[(c - 1) * sub_dim..c * sub_dim];
                for (best, v) in nearest.iter_mut().zip(vectors) {
                    let dist: f64 = v[offset..offset + sub_dim]
                        .iter()
                        .zip(newest)
                        .map(|(x, y)| {
                            let diff = x - y;
                            (diff * diff) as f64
                        })
                        .sum();
                    *best = f64::min(*best, dist);
                }
            }

            let total: f64 = nearest.iter().sum();
            let selected = if total <= 0.0 {
                // Every point already coincides with a centroid: fall back to
                // cycling through the training data.
                c % vectors.len()
            } else {
                rng_state = rng_state.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1);
                let threshold = ((rng_state >> 33) as f64 / DRAW_RANGE) * total;

                // Strict `>` so that zero-weight points (existing centroids)
                // are never picked, even when the draw is exactly zero.
                let mut cumulative = 0.0f64;
                nearest
                    .iter()
                    .position(|&w| {
                        cumulative += w;
                        cumulative > threshold
                    })
                    .unwrap_or(vectors.len() - 1)
            };

            let src = &vectors[selected][offset..offset + sub_dim];
            block[c * sub_dim..(c + 1) * sub_dim].copy_from_slice(src);
        }
    }

    /// One Lloyd iteration for one sub-space: assigns every training
    /// sub-vector to its nearest centroid and moves each centroid to the mean
    /// of its members. Centroids with no members are left where they are.
    fn lloyd_step(
        vectors: &[&[f32]],
        offset: usize,
        sub_dim: usize,
        k: usize,
        block: &mut [f32],
    ) {
        let mut sums = vec![0.0f64; k * sub_dim];
        let mut counts = vec![0usize; k];

        for v in vectors {
            let sub_vec = &v[offset..offset + sub_dim];
            // `block` holds only this sub-space, hence sub-quantizer index 0.
            let closest = find_closest_centroid(sub_vec, block, 0, k, sub_dim);
            counts[closest] += 1;
            let acc = &mut sums[closest * sub_dim..(closest + 1) * sub_dim];
            for (sum, &x) in acc.iter_mut().zip(sub_vec) {
                *sum += x as f64;
            }
        }

        for (c, &count) in counts.iter().enumerate() {
            if count == 0 {
                continue;
            }
            let dst = &mut block[c * sub_dim..(c + 1) * sub_dim];
            let src = &sums[c * sub_dim..(c + 1) * sub_dim];
            for (out, &sum) in dst.iter_mut().zip(src) {
                *out = (sum / count as f64) as f32;
            }
        }
    }

    /// Number of components in each sub-vector (`dim / m`).
    pub fn sub_dim(&self) -> usize {
        self.dim / self.m
    }

    /// Quantizes `vector` into `m` codes, one nearest-centroid index per
    /// sub-space.
    ///
    /// # Panics
    ///
    /// Panics if `vector.len() != dim`, or if a selected centroid index does
    /// not fit in a `u8` (only possible for a hand-built codebook with
    /// `k > 256`).
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        assert_eq!(vector.len(), self.dim);
        let sub_dim = self.sub_dim();

        (0..self.m)
            .map(|sub| {
                let offset = sub * sub_dim;
                let sub_vec = &vector[offset..offset + sub_dim];
                let closest =
                    find_closest_centroid(sub_vec, &self.centroids, sub, self.k, sub_dim);
                assert!(
                    closest <= u8::MAX as usize,
                    "centroid index {closest} does not fit in a u8 code"
                );
                closest as u8
            })
            .collect()
    }

    /// Reconstructs an approximate vector by concatenating the centroids that
    /// `codes` refer to.
    ///
    /// # Panics
    ///
    /// Panics if `codes.len() != m` or if any code is not below `k`.
    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
        assert_eq!(codes.len(), self.m);
        let sub_dim = self.sub_dim();
        let mut vector = Vec::with_capacity(self.dim);

        for (sub, &code) in codes.iter().enumerate() {
            let code = code as usize;
            // Without this check an oversized code would silently read a
            // centroid belonging to a later sub-quantizer.
            assert!(
                code < self.k,
                "code {code} out of range for k = {}",
                self.k
            );
            let base = (sub * self.k + code) * sub_dim;
            vector.extend_from_slice(&self.centroids[base..base + sub_dim]);
        }

        vector
    }

    /// Builds the `m x k` table of squared L2 distances between each query
    /// sub-vector and every centroid of the matching sub-space.
    ///
    /// # Panics
    ///
    /// Panics if `query.len() != dim`.
    pub fn build_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
        assert_eq!(query.len(), self.dim);
        let sub_dim = self.sub_dim();

        (0..self.m)
            .map(|sub| {
                let q_sub = &query[sub * sub_dim..(sub + 1) * sub_dim];
                (0..self.k)
                    .map(|c| {
                        let base = (sub * self.k + c) * sub_dim;
                        q_sub
                            .iter()
                            .zip(&self.centroids[base..base + sub_dim])
                            .map(|(q, x)| {
                                let diff = q - x;
                                diff * diff
                            })
                            .sum::<f32>()
                    })
                    .collect()
            })
            .collect()
    }

    /// Sums the table entries selected by `codes`: `table[sub][codes[sub]]`
    /// for every position in `codes`.
    ///
    /// # Panics
    ///
    /// Panics if `codes` is longer than `table` or a code indexes past the end
    /// of its table row.
    pub fn asymmetric_distance(table: &[Vec<f32>], codes: &[u8]) -> f32 {
        codes
            .iter()
            .enumerate()
            .map(|(sub, &code)| table[sub][code as usize])
            .sum()
    }
}

/// Returns the index (in `0..k`) of the centroid of sub-quantizer `sub` with
/// the smallest squared L2 distance to `sub_vec`.
///
/// `centroids` uses the flat layout where centroid `c` of sub-quantizer `sub`
/// starts at `(sub * k + c) * sub_dim`. Ties go to the lowest index; `0` is
/// returned when no centroid yields a distance below infinity.
fn find_closest_centroid(
    sub_vec: &[f32],
    centroids: &[f32],
    sub: usize,
    k: usize,
    sub_dim: usize,
) -> usize {
    let mut best_idx = 0;
    let mut best_dist = f32::INFINITY;

    for c in 0..k {
        let base = (sub * k + c) * sub_dim;
        let dist: f32 = sub_vec[..sub_dim]
            .iter()
            .zip(&centroids[base..base + sub_dim])
            .map(|(x, y)| {
                let diff = x - y;
                diff * diff
            })
            .sum();
        if dist < best_dist {
            best_dist = dist;
            best_idx = c;
        }
    }

    best_idx
}
246
247pub struct ColdStore {
249 codebook: PqCodebook,
250 entries: Vec<ColdEntry>,
252}
253
254#[derive(Debug, Clone)]
256struct ColdEntry {
257 entity_id: u64,
258 space_id: u32,
259 timestamp: i64,
260 codes: Vec<u8>,
261}
262
263impl ColdStore {
264 pub fn new(codebook: PqCodebook) -> Self {
266 Self {
267 codebook,
268 entries: Vec::new(),
269 }
270 }
271
272 pub fn put(&mut self, entity_id: u64, space_id: u32, timestamp: i64, vector: &[f32]) {
274 let codes = self.codebook.encode(vector);
275 self.entries.push(ColdEntry {
276 entity_id,
277 space_id,
278 timestamp,
279 codes,
280 });
281 }
282
283 pub fn get(&self, entity_id: u64, space_id: u32, timestamp: i64) -> Option<Vec<f32>> {
285 self.entries
286 .iter()
287 .find(|e| {
288 e.entity_id == entity_id && e.space_id == space_id && e.timestamp == timestamp
289 })
290 .map(|e| self.codebook.decode(&e.codes))
291 }
292
293 pub fn len(&self) -> usize {
295 self.entries.len()
296 }
297
298 pub fn is_empty(&self) -> bool {
300 self.entries.is_empty()
301 }
302
303 pub fn storage_bytes(&self) -> usize {
305 self.entries.iter().map(|e| e.codes.len()).sum()
306 }
307
308 pub fn codebook(&self) -> &PqCodebook {
310 &self.codebook
311 }
312
313 pub fn search_adc(&self, query: &[f32], k: usize) -> Vec<(u64, i64, f32)> {
317 let table = self.codebook.build_distance_table(query);
318 let mut scored: Vec<(u64, i64, f32)> = self
319 .entries
320 .iter()
321 .map(|e| {
322 let dist = PqCodebook::asymmetric_distance(&table, &e.codes);
323 (e.entity_id, e.timestamp, dist)
324 })
325 .collect();
326 scored.sort_by(|a, b| a.2.total_cmp(&b.2));
327 scored.truncate(k);
328 scored
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vectors from a 64-bit LCG.
    ///
    /// Each draw keeps 31 bits of state and divides by `u32::MAX`, so the
    /// generated components lie in roughly `[-0.5, 0.0]`.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut state = seed;
        let mut rows = Vec::with_capacity(n);
        for _ in 0..n {
            let mut row = Vec::with_capacity(dim);
            for _ in 0..dim {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                row.push(((state >> 33) as f32) / (u32::MAX as f32) - 0.5);
            }
            rows.push(row);
        }
        rows
    }

    /// Borrows every owned vector as a slice, the input shape `train` takes.
    fn as_slices(vectors: &[Vec<f32>]) -> Vec<&[f32]> {
        vectors.iter().map(Vec::as_slice).collect()
    }

    /// Squared L2 distance between a vector and its reconstruction, in `f64`.
    fn squared_error(original: &[f32], decoded: &[f32]) -> f64 {
        original
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| ((*a - *b) as f64).powi(2))
            .sum()
    }

    #[test]
    fn train_and_encode_decode() {
        let vectors = random_vectors(100, 32, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), 4, 16, 10);

        assert_eq!(codebook.m, 4);
        assert_eq!(codebook.k, 16);
        assert_eq!(codebook.dim, 32);
        assert_eq!(codebook.sub_dim(), 8);

        // One code per sub-quantizer.
        let codes = codebook.encode(&vectors[0]);
        assert_eq!(codes.len(), 4);

        let decoded = codebook.decode(&codes);
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn decode_approximates_original() {
        let vectors = random_vectors(500, 64, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), 8, 256, 20);

        let mut total_error = 0.0f64;
        for v in &vectors {
            let decoded = codebook.decode(&codebook.encode(v));
            total_error += squared_error(v, &decoded);
        }
        let avg_error = total_error / vectors.len() as f64;

        assert!(
            avg_error < 10.0,
            "avg reconstruction error too high: {avg_error:.4}"
        );
    }

    #[test]
    fn compression_ratio() {
        let dim = 768;
        let m = 8;
        let n = 100;
        let vectors = random_vectors(n, dim, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), m, 256, 5);

        let mut total_code_bytes = 0usize;
        let mut total_reconstruction_error = 0.0f64;

        for v in &vectors {
            let codes = codebook.encode(v);
            total_code_bytes += codes.len();

            let decoded = codebook.decode(&codes);
            total_reconstruction_error += squared_error(v, &decoded);
        }

        // Raw storage is four bytes per f32 component.
        let original_bytes = n * dim * 4;
        let ratio = original_bytes as f64 / total_code_bytes as f64;
        let avg_error = total_reconstruction_error / n as f64;

        assert!(
            ratio >= 300.0,
            "compression ratio = {ratio:.0}x, expected >= 300x for D={dim} M={m}"
        );

        assert_eq!(total_code_bytes, n * m);

        assert!(
            avg_error < 50.0,
            "avg reconstruction error = {avg_error:.2}, expected < 50 for D={dim}"
        );
    }

    #[test]
    fn cold_store_put_get() {
        let vectors = random_vectors(50, 32, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), 4, 16, 10);
        let mut store = ColdStore::new(codebook);

        store.put(1, 0, 1000, &vectors[0]);
        store.put(1, 0, 2000, &vectors[1]);

        assert_eq!(store.len(), 2);

        let decoded = store.get(1, 0, 1000).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn cold_store_get_nonexistent() {
        let vectors = random_vectors(10, 16, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), 2, 8, 5);
        let store = ColdStore::new(codebook);

        assert!(store.get(999, 0, 0).is_none());
    }

    #[test]
    fn adc_search() {
        let dim = 32;
        let vectors = random_vectors(200, dim, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), 4, 32, 10);
        let mut store = ColdStore::new(codebook);

        for (i, v) in vectors.iter().enumerate() {
            store.put(i as u64, 0, (i as i64) * 1000, v);
        }

        let results = store.search_adc(&vectors[0], 5);
        assert_eq!(results.len(), 5);

        assert_eq!(results[0].0, 0, "closest should be the query vector itself");
    }

    #[test]
    fn storage_bytes_compact() {
        let vectors = random_vectors(1000, 768, 42);
        let codebook = PqCodebook::train(&as_slices(&vectors), 8, 256, 5);
        let mut store = ColdStore::new(codebook);

        for (i, v) in vectors.iter().enumerate() {
            store.put(i as u64, 0, (i as i64) * 1000, v);
        }

        let original_bytes = 1000 * 768 * 4;
        let cold_bytes = store.storage_bytes();
        let ratio = original_bytes as f64 / cold_bytes as f64;

        assert!(
            ratio > 100.0,
            "cold storage ratio = {ratio:.0}x, expected > 100x"
        );
    }
}