// cvx_index/hnsw/bayesian_scorer.rs

//! Bayesian retrieval scoring for multi-factor candidate ranking (RFC-013 Part C).
//!
//! Replaces flat cosine scoring with a weighted composite:
//!
//! ```text
//! score = w_sim * similarity
//!       + w_recency * recency_factor
//!       + w_reward * reward
//!       + w_success * success_score
//!       + w_region * region_match
//! ```
//!
//! Weights are configurable and can be learned from online feedback
//! (logistic regression on outcome data).

use serde::{Deserialize, Serialize};
17
/// Scoring weights for Bayesian retrieval ranking.
///
/// Each weight scales the contribution of one factor to the composite
/// score computed by [`score_candidate`]. The convention is distance-like:
/// higher score = less relevant, so every factor is oriented so that
/// "better" candidates contribute smaller values. Factors are normalized
/// to [0, 1] before weighting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringWeights {
    /// Weight for semantic similarity (HNSW distance, normalized).
    pub similarity: f32,
    /// Weight for recency factor (1 - exp(-λ·age); 0 = newest, 1 = oldest).
    pub recency: f32,
    /// Weight for reward (scored as 1 - reward, so higher reward = lower score).
    pub reward: f32,
    /// Weight for typed-edge success score (scored as 1 - P(success)).
    pub success: f32,
    /// Weight for region match (factor is 0 if same region, 1 if different).
    pub region_match: f32,
}
36
37impl Default for ScoringWeights {
38    fn default() -> Self {
39        Self {
40            similarity: 1.0,
41            recency: 0.0,
42            reward: 0.0,
43            success: 0.0,
44            region_match: 0.0,
45        }
46    }
47}
48
49impl ScoringWeights {
50    /// Create weights with all factors active at equal strength.
51    pub fn balanced() -> Self {
52        Self {
53            similarity: 1.0,
54            recency: 0.3,
55            reward: 0.5,
56            success: 0.4,
57            region_match: 0.2,
58        }
59    }
60
61    /// Create weights optimized for agent memory retrieval.
62    ///
63    /// Prioritizes reward and success over recency.
64    pub fn agent_memory() -> Self {
65        Self {
66            similarity: 1.0,
67            recency: 0.1,
68            reward: 0.6,
69            success: 0.5,
70            region_match: 0.2,
71        }
72    }
73}
74
/// Features for a single retrieval candidate.
///
/// All factor fields are pre-normalized to [0, 1] and consumed additively
/// by [`score_candidate`] under its distance-like (lower = better)
/// convention.
#[derive(Debug, Clone)]
pub struct CandidateFeatures {
    /// Node ID in the index.
    pub node_id: u32,
    /// Raw semantic distance from HNSW (unnormalized; not used by the
    /// composite score itself).
    pub raw_distance: f32,
    /// Normalized semantic distance [0, 1]. NOTE(review): despite the
    /// name, this is distance-like — 0 means closest match — which is
    /// the orientation the composite score expects.
    pub similarity: f32,
    /// Recency factor [0, 1] (0 = most recent, 1 = oldest).
    pub recency: f32,
    /// Reward annotation [0, 1] (NaN is treated as an uninformative 0.5).
    pub reward: f32,
    /// Typed-edge success score [0, 1] from Beta prior.
    pub success_score: f32,
    /// Whether candidate is in the same HNSW region as the query.
    pub region_match: bool,
}
93
94/// Compute the Bayesian composite score for a candidate.
95///
96/// Lower score = more relevant (distance-like convention).
97pub fn score_candidate(candidate: &CandidateFeatures, weights: &ScoringWeights) -> f32 {
98    let reward_factor = if candidate.reward.is_nan() {
99        0.5 // uninformative
100    } else {
101        1.0 - candidate.reward // higher reward → lower (better) score
102    };
103
104    let success_factor = 1.0 - candidate.success_score; // higher success → lower score
105    let region_factor = if candidate.region_match { 0.0 } else { 1.0 };
106
107    weights.similarity * candidate.similarity
108        + weights.recency * candidate.recency
109        + weights.reward * reward_factor
110        + weights.success * success_factor
111        + weights.region_match * region_factor
112}
113
114/// Re-rank a list of candidates using Bayesian scoring.
115///
116/// Takes pre-computed features for each candidate, scores them, and
117/// returns the top-k sorted by composite score (ascending = best).
118pub fn rerank(
119    candidates: &[CandidateFeatures],
120    weights: &ScoringWeights,
121    k: usize,
122) -> Vec<(u32, f32)> {
123    let mut scored: Vec<(u32, f32)> = candidates
124        .iter()
125        .map(|c| (c.node_id, score_candidate(c, weights)))
126        .collect();
127
128    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
129    scored.truncate(k);
130    scored
131}
132
133/// Online weight learning from outcome feedback.
134///
135/// Simple gradient update: if the retrieval led to success, decrease the
136/// score (make it more likely to be retrieved again). If failure, increase.
137///
138/// Uses a learning rate to control update speed.
139pub struct WeightLearner {
140    /// Current weights.
141    pub weights: ScoringWeights,
142    /// Learning rate for gradient updates.
143    pub learning_rate: f32,
144    /// Number of updates applied.
145    pub n_updates: usize,
146}
147
148impl WeightLearner {
149    /// Create a new learner with initial weights.
150    pub fn new(weights: ScoringWeights, learning_rate: f32) -> Self {
151        Self {
152            weights,
153            learning_rate,
154            n_updates: 0,
155        }
156    }
157
158    /// Update weights based on outcome feedback.
159    ///
160    /// `candidates`: the features of candidates that were retrieved.
161    /// `outcome`: 1.0 for success, 0.0 for failure.
162    ///
163    /// Adjusts weights to make factors that correlated with success
164    /// stronger, and factors that correlated with failure weaker.
165    pub fn update(&mut self, candidates: &[CandidateFeatures], outcome: f32) {
166        if candidates.is_empty() {
167            return;
168        }
169
170        // Direction: if success (outcome=1), we want to decrease score
171        // (make these candidates rank higher). If failure, increase.
172        let direction = if outcome > 0.5 { -1.0 } else { 1.0 };
173        let lr = self.learning_rate / (1.0 + self.n_updates as f32 * 0.01); // Decay LR
174
175        // Average features across candidates
176        let n = candidates.len() as f32;
177        let avg_sim: f32 = candidates.iter().map(|c| c.similarity).sum::<f32>() / n;
178        let avg_rec: f32 = candidates.iter().map(|c| c.recency).sum::<f32>() / n;
179        let avg_rew: f32 = candidates
180            .iter()
181            .map(|c| {
182                if c.reward.is_nan() {
183                    0.5
184                } else {
185                    1.0 - c.reward
186                }
187            })
188            .sum::<f32>()
189            / n;
190        let avg_suc: f32 = candidates
191            .iter()
192            .map(|c| 1.0 - c.success_score)
193            .sum::<f32>()
194            / n;
195        let avg_reg: f32 = candidates
196            .iter()
197            .map(|c| if c.region_match { 0.0 } else { 1.0 })
198            .sum::<f32>()
199            / n;
200
201        // Gradient step on each weight
202        self.weights.similarity += direction * lr * avg_sim;
203        self.weights.recency += direction * lr * avg_rec;
204        self.weights.reward += direction * lr * avg_rew;
205        self.weights.success += direction * lr * avg_suc;
206        self.weights.region_match += direction * lr * avg_reg;
207
208        // Clamp weights to [0, 2]
209        self.weights.similarity = self.weights.similarity.clamp(0.0, 2.0);
210        self.weights.recency = self.weights.recency.clamp(0.0, 2.0);
211        self.weights.reward = self.weights.reward.clamp(0.0, 2.0);
212        self.weights.success = self.weights.success.clamp(0.0, 2.0);
213        self.weights.region_match = self.weights.region_match.clamp(0.0, 2.0);
214
215        self.n_updates += 1;
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a candidate with the given factor values; node_id is fixed
    /// at 0 and raw_distance derived from the similarity input.
    fn make_candidate(
        sim: f32,
        recency: f32,
        reward: f32,
        success: f32,
        region: bool,
    ) -> CandidateFeatures {
        CandidateFeatures {
            node_id: 0,
            raw_distance: sim * 2.0,
            similarity: sim,
            recency,
            reward,
            success_score: success,
            region_match: region,
        }
    }

    #[test]
    fn default_weights_pure_similarity() {
        let weights = ScoringWeights::default();
        let candidate = make_candidate(0.3, 0.8, 0.9, 0.7, true);
        let score = score_candidate(&candidate, &weights);
        // With default weights every factor but similarity is zeroed out.
        assert!((score - 0.3).abs() < 0.01, "score = {score}");
    }

    #[test]
    fn reward_lowers_score() {
        let weights = ScoringWeights::balanced();
        let high_reward = make_candidate(0.5, 0.5, 0.9, 0.5, true);
        let low_reward = make_candidate(0.5, 0.5, 0.1, 0.5, true);

        // All other factors are held equal, so only reward differs.
        let s_high = score_candidate(&high_reward, &weights);
        let s_low = score_candidate(&low_reward, &weights);
        assert!(
            s_high < s_low,
            "high reward ({s_high}) should score lower (better) than low ({s_low})"
        );
    }

    #[test]
    fn success_score_lowers_score() {
        let weights = ScoringWeights::balanced();
        let high_success = make_candidate(0.5, 0.5, 0.5, 0.9, true);
        let low_success = make_candidate(0.5, 0.5, 0.5, 0.1, true);

        // Higher success probability must produce the better (lower) score.
        assert!(score_candidate(&high_success, &weights) < score_candidate(&low_success, &weights));
    }

    #[test]
    fn region_match_lowers_score() {
        let weights = ScoringWeights::balanced();
        let same_region = make_candidate(0.5, 0.5, 0.5, 0.5, true);
        let diff_region = make_candidate(0.5, 0.5, 0.5, 0.5, false);

        // A same-region candidate pays no region penalty.
        assert!(score_candidate(&same_region, &weights) < score_candidate(&diff_region, &weights));
    }

    #[test]
    fn nan_reward_uses_default() {
        let weights = ScoringWeights::balanced();
        let nan_reward = make_candidate(0.5, 0.5, f32::NAN, 0.5, true);
        let mid_reward = make_candidate(0.5, 0.5, 0.5, 0.5, true);

        let s_nan = score_candidate(&nan_reward, &weights);
        let s_mid = score_candidate(&mid_reward, &weights);
        assert!(
            (s_nan - s_mid).abs() < 0.01,
            "NaN should behave like 0.5 reward"
        );
    }

    #[test]
    fn rerank_sorts_by_composite() {
        let weights = ScoringWeights::agent_memory();
        // Shorthand constructor so the fixture stays readable.
        let feat = |id, dist, sim, rec, rew, suc, reg| CandidateFeatures {
            node_id: id,
            raw_distance: dist,
            similarity: sim,
            recency: rec,
            reward: rew,
            success_score: suc,
            region_match: reg,
        };
        let candidates = vec![
            feat(1, 1.0, 0.5, 0.5, 0.1, 0.2, false),
            feat(2, 0.8, 0.4, 0.3, 0.9, 0.8, true),
            feat(3, 0.6, 0.3, 0.8, 0.5, 0.5, true),
        ];

        let ranked = rerank(&candidates, &weights, 2);
        assert_eq!(ranked.len(), 2);
        // Node 2 should rank first (best reward + success + region match).
        assert_eq!(ranked[0].0, 2);
    }

    #[test]
    fn weight_learner_success_decreases_weights() {
        let mut learner = WeightLearner::new(ScoringWeights::balanced(), 0.1);
        let initial_sim = learner.weights.similarity;

        let batch = vec![make_candidate(0.5, 0.3, 0.8, 0.7, true)];
        learner.update(&batch, 1.0); // success

        // A success nudges the similarity weight down (lower score = better
        // for candidates like these).
        assert!(learner.weights.similarity < initial_sim);
    }

    #[test]
    fn weight_learner_failure_increases_weights() {
        let mut learner = WeightLearner::new(ScoringWeights::balanced(), 0.1);
        let initial_sim = learner.weights.similarity;

        let batch = vec![make_candidate(0.5, 0.3, 0.2, 0.3, false)];
        learner.update(&batch, 0.0); // failure

        // A failure nudges the similarity weight up (push these down).
        assert!(learner.weights.similarity > initial_sim);
    }

    #[test]
    fn weight_learner_clamps() {
        // Deliberately huge learning rate to drive weights to the bound.
        let mut learner = WeightLearner::new(ScoringWeights::default(), 10.0);
        let batch = vec![make_candidate(0.9, 0.9, 0.9, 0.9, true)];

        for _ in 0..100 {
            learner.update(&batch, 0.0);
        }

        assert!(learner.weights.similarity <= 2.0);
        assert!(learner.weights.recency <= 2.0);
    }
}