cvx_core/traits/
quantizer.rs

1//! Distance acceleration through vector quantization.
2//!
3//! The [`Quantizer`] trait abstracts over different strategies for accelerating
4//! distance computations in high-dimensional spaces. The HNSW graph stores both
5//! full-precision vectors (for exact operations like trajectories and signatures)
6//! and compact codes (for fast approximate distance during graph construction/search).
7//!
8//! # Two-Phase Distance Computation
9//!
10//! 1. **Candidate selection**: Use fast approximate distance on compact codes
11//! 2. **Final ranking**: Use exact distance on full vectors
12//!
13//! This mirrors production systems like Qdrant (Scalar Quantization),
14//! Faiss (Product Quantization), and Weaviate (Binary Quantization).
15//!
16//! # Available Strategies
17//!
18//! | Strategy | Code size (D=768) | Speedup | Recall impact |
19//! |----------|-------------------|---------|---------------|
20//! | [`NoQuantizer`] | 0 bytes | 1× | None |
21//! | Scalar (SQ8) | 768 bytes | ~4× | < 1% |
22//! | Product (PQ96) | 96 bytes | ~8× | 1-3% |
23//! | Binary (BQ) | 96 bytes | ~32× | 5-10% |
24//!
25//! # Example
26//!
27//! ```
28//! use cvx_core::traits::quantizer::{Quantizer, NoQuantizer, L2Fn};
29//!
30//! // No acceleration (default, exact distances)
31//! let q = NoQuantizer::new(L2Fn);
32//! let code = q.encode(&[1.0, 2.0, 3.0]);
33//! ```
34
35use super::DistanceMetric;
36
/// Pluggable strategy for speeding up distance computations.
///
/// An implementation converts a full-precision vector into a compact
/// [`Quantizer::Code`] and can estimate the distance between two codes
/// cheaply, while still offering an exact distance on the raw vectors.
///
/// Intended usage by the graph: `distance_approx` while exploring candidates
/// (hot path), `distance_exact` for the final neighbor selection
/// (quality-critical).
pub trait Quantizer: Send + Sync {
    /// Compact per-vector representation produced by `encode`.
    ///
    /// Typical shapes:
    /// - `NoQuantizer`: `()` (zero overhead)
    /// - Scalar: `Vec<u8>` (D bytes)
    /// - Product: `Vec<u8>` (M bytes, where M = D/subvector_dim)
    /// - Binary: `Vec<u64>` (D/64 words)
    type Code: Clone + Send + Sync;

    /// Turn a full-precision `vector` into its compact code.
    ///
    /// Meant to run once per insert; the resulting code is kept next to the
    /// original vector.
    fn encode(&self, vector: &[f32]) -> Self::Code;

    /// Cheap, approximate distance between two previously encoded vectors.
    ///
    /// This is the hot path, invoked on the order of
    /// O(ef_construction × log N) times per insert, so it only pays off when
    /// it is substantially cheaper than `distance_exact`.
    fn distance_approx(&self, a: &Self::Code, b: &Self::Code) -> f32;

    /// Exact distance computed on the full-precision vectors `a` and `b`.
    ///
    /// Reserved for final neighbor selection (heuristic pruning), where the
    /// quality of the distance outweighs its cost.
    fn distance_exact(&self, a: &[f32], b: &[f32]) -> f32;

    /// Reports whether `distance_approx` actually accelerates anything.
    ///
    /// A `false` result tells the graph to skip `distance_approx` entirely,
    /// call `distance_exact` directly, and not bother storing codes.
    fn is_accelerated(&self) -> bool;

    /// Reports whether `train` must be run on a data sample before use.
    ///
    /// Defaults to `false`. Product Quantization needs a trained codebook;
    /// Scalar and Binary quantization do not.
    fn needs_training(&self) -> bool {
        false
    }

    /// Fit the quantizer to a sample of vectors.
    ///
    /// The default implementation does nothing. It is only expected to be
    /// invoked when `needs_training()` returns `true` (e.g. PQ training its
    /// codebook with k-means over subvectors).
    fn train(&mut self, _sample: &[&[f32]]) {}

    /// Short, human-readable identifier of the strategy.
    fn name(&self) -> &str;
}
93
94// ─── NoQuantizer (default: exact distances) ─────────────────────
95
96/// Identity quantizer — no acceleration, exact distances only.
97///
98/// This is the default. Codes are zero-sized (no storage overhead).
99/// All distance computations use the underlying [`DistanceMetric`].
100#[derive(Clone)]
101pub struct NoQuantizer<D: DistanceMetric> {
102    metric: D,
103}
104
105impl<D: DistanceMetric> NoQuantizer<D> {
106    /// Create a no-acceleration quantizer wrapping the given metric.
107    pub fn new(metric: D) -> Self {
108        Self { metric }
109    }
110
111    /// Access the underlying metric.
112    pub fn metric(&self) -> &D {
113        &self.metric
114    }
115}
116
117impl<D: DistanceMetric> Quantizer for NoQuantizer<D> {
118    type Code = ();
119
120    fn encode(&self, _vector: &[f32]) -> Self::Code {}
121
122    fn distance_approx(&self, _a: &Self::Code, _b: &Self::Code) -> f32 {
123        // Never called when is_accelerated() returns false
124        0.0
125    }
126
127    fn distance_exact(&self, a: &[f32], b: &[f32]) -> f32 {
128        self.metric.distance(a, b)
129    }
130
131    fn is_accelerated(&self) -> bool {
132        false
133    }
134
135    fn name(&self) -> &str {
136        "none"
137    }
138}
139
140// ─── ScalarQuantizer (uint8, ~4× speedup) ───────────────────────
141
142/// Scalar Quantization: compress each float32 dimension to uint8.
143///
144/// For each dimension, maps the value range [min, max] → [0, 255].
145/// Distances are computed on uint8 values using integer arithmetic.
146///
147/// **Pros**: Simple, no training needed, ~4× distance speedup, <1% recall loss.
148/// **Cons**: Requires knowing the value range (uses [-1, 1] for normalized vectors).
149///
150/// This is what Qdrant uses by default for HNSW construction.
151#[derive(Clone)]
152pub struct ScalarQuantizer<D: DistanceMetric> {
153    metric: D,
154    /// Min value per dimension (for denormalization). Default: -1.0
155    min_val: f32,
156    /// Scale factor: 255.0 / (max_val - min_val)
157    scale: f32,
158}
159
160impl<D: DistanceMetric> ScalarQuantizer<D> {
161    /// Create a scalar quantizer for L2-normalized vectors (range [-1, 1]).
162    pub fn new(metric: D) -> Self {
163        Self {
164            metric,
165            min_val: -1.0,
166            scale: 255.0 / 2.0, // maps [-1, 1] → [0, 255]
167        }
168    }
169
170    /// Create a scalar quantizer with custom value range.
171    pub fn with_range(metric: D, min_val: f32, max_val: f32) -> Self {
172        let range = max_val - min_val;
173        Self {
174            metric,
175            min_val,
176            scale: if range > 0.0 { 255.0 / range } else { 1.0 },
177        }
178    }
179}
180
181impl<D: DistanceMetric> Quantizer for ScalarQuantizer<D> {
182    type Code = Vec<u8>;
183
184    fn encode(&self, vector: &[f32]) -> Self::Code {
185        vector
186            .iter()
187            .map(|&v| {
188                let normalized = (v - self.min_val) * self.scale;
189                normalized.clamp(0.0, 255.0) as u8
190            })
191            .collect()
192    }
193
194    fn distance_approx(&self, a: &Self::Code, b: &Self::Code) -> f32 {
195        // L2 distance on uint8 values (integer arithmetic, auto-vectorized by LLVM)
196        let mut sum: u32 = 0;
197        for i in 0..a.len() {
198            let diff = a[i] as i32 - b[i] as i32;
199            sum += (diff * diff) as u32;
200        }
201        // Scale back to float distance (approximate)
202        (sum as f32).sqrt() / self.scale
203    }
204
205    fn distance_exact(&self, a: &[f32], b: &[f32]) -> f32 {
206        self.metric.distance(a, b)
207    }
208
209    fn is_accelerated(&self) -> bool {
210        true
211    }
212
213    fn name(&self) -> &str {
214        "scalar_u8"
215    }
216}
217
218// ─── Helper for L2 distance function (used in NoQuantizer default) ──
219
220/// Simple L2 distance function for use with quantizers.
221#[derive(Clone, Copy)]
222pub struct L2Fn;
223
224impl DistanceMetric for L2Fn {
225    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
226        a.iter()
227            .zip(b.iter())
228            .map(|(x, y)| (x - y) * (x - y))
229            .sum::<f32>()
230            .sqrt()
231    }
232
233    fn name(&self) -> &str {
234        "l2"
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// `NoQuantizer` reports no acceleration and defers to the exact metric.
    #[test]
    fn no_quantizer_exact() {
        let quantizer = NoQuantizer::new(L2Fn);
        assert!(!quantizer.is_accelerated());

        let unit_x: [f32; 2] = [1.0, 0.0];
        let unit_y: [f32; 2] = [0.0, 1.0];
        let dist = quantizer.distance_exact(&unit_x, &unit_y);
        assert!((dist - std::f32::consts::SQRT_2).abs() < 1e-6);
    }

    /// Range endpoints land on the extreme byte values and a code is at
    /// distance ~0 from itself.
    #[test]
    fn scalar_quantizer_encode_decode() {
        let quantizer = ScalarQuantizer::new(L2Fn);

        // Normalized vector whose last two components sit on the range bounds.
        let input: [f32; 5] = [0.5, -0.3, 0.0, 1.0, -1.0];
        let encoded = quantizer.encode(&input);
        assert_eq!(encoded.len(), 5);
        assert_eq!(encoded[4], 0); // -1.0 → 0
        assert_eq!(encoded[3], 255); // 1.0 → 255

        // Distance between identical codes should be ~0
        let self_distance = quantizer.distance_approx(&encoded, &encoded);
        assert!(self_distance < 1e-6);
    }

    /// The approximate distance ranks a near neighbor before a far one, in
    /// agreement with the exact distance.
    #[test]
    fn scalar_quantizer_preserves_order() {
        let quantizer = ScalarQuantizer::new(L2Fn);

        let a: [f32; 3] = [0.5, 0.3, 0.0];
        let b: [f32; 3] = [0.6, 0.3, 0.0]; // close to a
        let c: [f32; 3] = [-0.5, -0.3, 0.9]; // far from a

        let (code_a, code_b, code_c) = (
            quantizer.encode(&a),
            quantizer.encode(&b),
            quantizer.encode(&c),
        );

        // Approximate distance should preserve ordering
        let d_ab = quantizer.distance_approx(&code_a, &code_b);
        let d_ac = quantizer.distance_approx(&code_a, &code_c);
        assert!(d_ab < d_ac, "d(a,b)={d_ab} should be < d(a,c)={d_ac}");

        // Exact distances for comparison
        let exact_near = quantizer.distance_exact(&a, &b);
        let exact_far = quantizer.distance_exact(&a, &c);
        assert!(exact_near < exact_far);
    }
}
289}