1use serde::{Deserialize, Serialize};
17
/// Per-feature multipliers that [`score_candidate`] uses to build a composite score.
///
/// Each weight scales one term of a sum. [`rerank`] orders candidates by
/// ascending total, so a larger weight makes its term count more strongly
/// against a candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringWeights {
    /// Multiplier applied to `CandidateFeatures::similarity`.
    pub similarity: f32,
    /// Multiplier applied to `CandidateFeatures::recency`.
    pub recency: f32,
    /// Multiplier applied to `1.0 - reward` (or to `0.5` when the reward is NaN).
    pub reward: f32,
    /// Multiplier applied to `1.0 - success_score`.
    pub success: f32,
    /// Amount added to the score when the candidate's `region_match` flag is `false`.
    pub region_match: f32,
}
36
37impl Default for ScoringWeights {
38 fn default() -> Self {
39 Self {
40 similarity: 1.0,
41 recency: 0.0,
42 reward: 0.0,
43 success: 0.0,
44 region_match: 0.0,
45 }
46 }
47}
48
49impl ScoringWeights {
50 pub fn balanced() -> Self {
52 Self {
53 similarity: 1.0,
54 recency: 0.3,
55 reward: 0.5,
56 success: 0.4,
57 region_match: 0.2,
58 }
59 }
60
61 pub fn agent_memory() -> Self {
65 Self {
66 similarity: 1.0,
67 recency: 0.1,
68 reward: 0.6,
69 success: 0.5,
70 region_match: 0.2,
71 }
72 }
73}
74
/// Feature values for one retrieval candidate, consumed by [`score_candidate`].
#[derive(Debug, Clone)]
pub struct CandidateFeatures {
    /// Identifier that [`rerank`] returns alongside the composite score.
    pub node_id: u32,
    /// NOTE(review): not read by any scoring code in this file; presumably the
    /// unnormalised distance reported by the index — confirm against the producer.
    pub raw_distance: f32,
    /// Scaled by the similarity weight and added to the composite score, so a
    /// smaller value ranks earlier (presumably a distance-like quantity — confirm).
    pub similarity: f32,
    /// Scaled by the recency weight and added to the composite score as-is.
    pub recency: f32,
    /// Contributes `1.0 - reward` to the score; a NaN reward contributes `0.5` instead.
    pub reward: f32,
    /// Contributes `1.0 - success_score` to the score.
    pub success_score: f32,
    /// When `false`, the full region-match weight is added to the score.
    pub region_match: bool,
}
93
94pub fn score_candidate(candidate: &CandidateFeatures, weights: &ScoringWeights) -> f32 {
98 let reward_factor = if candidate.reward.is_nan() {
99 0.5 } else {
101 1.0 - candidate.reward };
103
104 let success_factor = 1.0 - candidate.success_score; let region_factor = if candidate.region_match { 0.0 } else { 1.0 };
106
107 weights.similarity * candidate.similarity
108 + weights.recency * candidate.recency
109 + weights.reward * reward_factor
110 + weights.success * success_factor
111 + weights.region_match * region_factor
112}
113
114pub fn rerank(
119 candidates: &[CandidateFeatures],
120 weights: &ScoringWeights,
121 k: usize,
122) -> Vec<(u32, f32)> {
123 let mut scored: Vec<(u32, f32)> = candidates
124 .iter()
125 .map(|c| (c.node_id, score_candidate(c, weights)))
126 .collect();
127
128 scored.sort_by(|a, b| a.1.total_cmp(&b.1));
129 scored.truncate(k);
130 scored
131}
132
/// Online learner that nudges a set of [`ScoringWeights`] after each observed outcome.
pub struct WeightLearner {
    /// Current weights; each one is clamped to `[0.0, 2.0]` at the end of every update.
    pub weights: ScoringWeights,
    /// Base step size; the effective step shrinks as `n_updates` grows.
    pub learning_rate: f32,
    /// Number of updates applied so far (calls with no candidates are not counted).
    pub n_updates: usize,
}
147
148impl WeightLearner {
149 pub fn new(weights: ScoringWeights, learning_rate: f32) -> Self {
151 Self {
152 weights,
153 learning_rate,
154 n_updates: 0,
155 }
156 }
157
158 pub fn update(&mut self, candidates: &[CandidateFeatures], outcome: f32) {
166 if candidates.is_empty() {
167 return;
168 }
169
170 let direction = if outcome > 0.5 { -1.0 } else { 1.0 };
173 let lr = self.learning_rate / (1.0 + self.n_updates as f32 * 0.01); let n = candidates.len() as f32;
177 let avg_sim: f32 = candidates.iter().map(|c| c.similarity).sum::<f32>() / n;
178 let avg_rec: f32 = candidates.iter().map(|c| c.recency).sum::<f32>() / n;
179 let avg_rew: f32 = candidates
180 .iter()
181 .map(|c| {
182 if c.reward.is_nan() {
183 0.5
184 } else {
185 1.0 - c.reward
186 }
187 })
188 .sum::<f32>()
189 / n;
190 let avg_suc: f32 = candidates
191 .iter()
192 .map(|c| 1.0 - c.success_score)
193 .sum::<f32>()
194 / n;
195 let avg_reg: f32 = candidates
196 .iter()
197 .map(|c| if c.region_match { 0.0 } else { 1.0 })
198 .sum::<f32>()
199 / n;
200
201 self.weights.similarity += direction * lr * avg_sim;
203 self.weights.recency += direction * lr * avg_rec;
204 self.weights.reward += direction * lr * avg_rew;
205 self.weights.success += direction * lr * avg_suc;
206 self.weights.region_match += direction * lr * avg_reg;
207
208 self.weights.similarity = self.weights.similarity.clamp(0.0, 2.0);
210 self.weights.recency = self.weights.recency.clamp(0.0, 2.0);
211 self.weights.reward = self.weights.reward.clamp(0.0, 2.0);
212 self.weights.success = self.weights.success.clamp(0.0, 2.0);
213 self.weights.region_match = self.weights.region_match.clamp(0.0, 2.0);
214
215 self.n_updates += 1;
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with `node_id` 0 and `raw_distance` set to twice the similarity.
    fn features(
        similarity: f32,
        recency: f32,
        reward: f32,
        success_score: f32,
        region_match: bool,
    ) -> CandidateFeatures {
        CandidateFeatures {
            node_id: 0,
            raw_distance: similarity * 2.0,
            similarity,
            recency,
            reward,
            success_score,
            region_match,
        }
    }

    /// Same as `features`, but with an explicit node id.
    fn features_with_id(
        node_id: u32,
        similarity: f32,
        recency: f32,
        reward: f32,
        success_score: f32,
        region_match: bool,
    ) -> CandidateFeatures {
        CandidateFeatures {
            node_id,
            ..features(similarity, recency, reward, success_score, region_match)
        }
    }

    /// Scores a candidate under the balanced preset.
    fn balanced_score(candidate: &CandidateFeatures) -> f32 {
        score_candidate(candidate, &ScoringWeights::balanced())
    }

    /// With default weights only the similarity term contributes.
    #[test]
    fn default_weights_pure_similarity() {
        let score = score_candidate(
            &features(0.3, 0.8, 0.9, 0.7, true),
            &ScoringWeights::default(),
        );
        assert!((score - 0.3).abs() < 0.01, "score = {score}");
    }

    /// A higher reward must produce a lower composite score.
    #[test]
    fn reward_lowers_score() {
        let s_high = balanced_score(&features(0.5, 0.5, 0.9, 0.5, true));
        let s_low = balanced_score(&features(0.5, 0.5, 0.1, 0.5, true));
        assert!(
            s_high < s_low,
            "high reward ({s_high}) should score lower (better) than low ({s_low})"
        );
    }

    /// A higher success score must produce a lower composite score.
    #[test]
    fn success_score_lowers_score() {
        let s_high = balanced_score(&features(0.5, 0.5, 0.5, 0.9, true));
        let s_low = balanced_score(&features(0.5, 0.5, 0.5, 0.1, true));
        assert!(s_high < s_low);
    }

    /// A matching region must produce a lower composite score than a mismatch.
    #[test]
    fn region_match_lowers_score() {
        let s_same = balanced_score(&features(0.5, 0.5, 0.5, 0.5, true));
        let s_diff = balanced_score(&features(0.5, 0.5, 0.5, 0.5, false));
        assert!(s_same < s_diff);
    }

    /// A NaN reward is scored like a reward of 0.5.
    #[test]
    fn nan_reward_uses_default() {
        let s_nan = balanced_score(&features(0.5, 0.5, f32::NAN, 0.5, true));
        let s_mid = balanced_score(&features(0.5, 0.5, 0.5, 0.5, true));
        assert!(
            (s_nan - s_mid).abs() < 0.01,
            "NaN should behave like 0.5 reward"
        );
    }

    /// `rerank` truncates to `k` and puts the lowest composite score first.
    #[test]
    fn rerank_sorts_by_composite() {
        let pool = [
            features_with_id(1, 0.5, 0.5, 0.1, 0.2, false),
            features_with_id(2, 0.4, 0.3, 0.9, 0.8, true),
            features_with_id(3, 0.3, 0.8, 0.5, 0.5, true),
        ];

        let ranked = rerank(&pool, &ScoringWeights::agent_memory(), 2);
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].0, 2);
    }

    /// A successful outcome pulls the similarity weight down.
    #[test]
    fn weight_learner_success_decreases_weights() {
        let mut learner = WeightLearner::new(ScoringWeights::balanced(), 0.1);
        let before = learner.weights.similarity;

        learner.update(&[features(0.5, 0.3, 0.8, 0.7, true)], 1.0);
        assert!(learner.weights.similarity < before);
    }

    /// A failed outcome pushes the similarity weight up.
    #[test]
    fn weight_learner_failure_increases_weights() {
        let mut learner = WeightLearner::new(ScoringWeights::balanced(), 0.1);
        let before = learner.weights.similarity;

        learner.update(&[features(0.5, 0.3, 0.2, 0.3, false)], 0.0);
        assert!(learner.weights.similarity > before);
    }

    /// Even with a huge learning rate the weights never exceed the upper clamp.
    #[test]
    fn weight_learner_clamps() {
        let mut learner = WeightLearner::new(ScoringWeights::default(), 10.0);
        let batch = [features(0.9, 0.9, 0.9, 0.9, true)];

        (0..100).for_each(|_| learner.update(&batch, 0.0));

        assert!(learner.weights.similarity <= 2.0);
        assert!(learner.weights.recency <= 2.0);
    }
}