cvx_index/hnsw/region_mdp.rs

//! Region MDP: Markov Decision Process over HNSW semantic regions (RFC-013 Part A).
//!
//! HNSW regions define a discrete state space. Episode trajectories define
//! transitions. Rewards define outcomes. This module learns P(success | region,
//! action_type) from observed trajectories.
//!
//! States are HNSW hub node IDs at a given level. Actions are abstract types
//! represented as string labels (e.g., "navigate", "take", "place").
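//!
//! A minimal usage sketch (hypothetical region IDs and action labels):
//!
//! ```ignore
//! let mut mdp = RegionMdp::new();
//! mdp.learn_trajectory(&[0, 1], &["navigate".into(), "take".into()], 1.0);
//! assert!(mdp.action_success_rate(0, "navigate") > 0.5);
//! ```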

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A transition model over HNSW regions.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RegionMdp {
    /// (region, action_type) → {next_region: count}.
    transitions: HashMap<(u32, String), HashMap<u32, u32>>,
    /// (region, action_type) → [rewards].
    rewards: HashMap<(u32, String), Vec<f32>>,
    /// region → [episode_rewards].
    region_quality: HashMap<u32, Vec<f32>>,
    /// Total transitions learned.
    n_transitions: usize,
}

impl RegionMdp {
    /// Create an empty MDP.
    pub fn new() -> Self {
        Self::default()
    }

    /// Learn from a single episode trajectory.
    ///
    /// `regions`: sequence of region IDs the episode traversed.
    /// `actions`: corresponding action types (same length as regions).
    /// `reward`: episode outcome (1.0 = success, 0.0 = failure).
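    ///
    /// Note: the trailing action has no successor region, so it is not recorded
    /// in transitions or rewards; every region still counts toward region quality.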
    pub fn learn_trajectory(&mut self, regions: &[u32], actions: &[String], reward: f32) {
        let n = regions.len().min(actions.len());
        for i in 0..n.saturating_sub(1) {
            let s = regions[i];
            let a = actions[i].clone();
            let s_next = regions[i + 1];

            *self
                .transitions
                .entry((s, a.clone()))
                .or_default()
                .entry(s_next)
                .or_default() += 1;

            self.rewards.entry((s, a)).or_default().push(reward);
            self.n_transitions += 1;
        }

        // Track per-region quality
        for &s in regions {
            self.region_quality.entry(s).or_default().push(reward);
        }
    }

    /// P(success | region, action_type) under a Beta(1, 1) prior.
    ///
    /// Returns (1 + n_successes) / (2 + n_total), i.e. Laplace smoothing.
    /// Falls back to region-level quality if no data exists for this (region, action).
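    ///
    /// For example, 3 successes out of 4 observations gives (1 + 3) / (2 + 4)
    /// ≈ 0.67 rather than the raw 0.75, pulling small samples toward 0.5.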
    pub fn action_success_rate(&self, region: u32, action_type: &str) -> f32 {
        let key = (region, action_type.to_string());
        if let Some(rewards) = self.rewards.get(&key) {
            if !rewards.is_empty() {
                let successes: f32 = rewards.iter().filter(|&&r| r > 0.5).count() as f32;
                return (1.0 + successes) / (2.0 + rewards.len() as f32);
            }
        }
        // Fallback: region-level quality
        self.region_quality_score(region)
    }

    /// Overall quality of a region (mean reward of episodes passing through);
    /// defaults to 0.5 for regions with no data.
    pub fn region_quality_score(&self, region: u32) -> f32 {
        self.region_quality
            .get(&region)
            .filter(|r| !r.is_empty())
            .map(|r| r.iter().sum::<f32>() / r.len() as f32)
            .unwrap_or(0.5)
    }

    /// Rank action types by success rate in a given region.
    ///
    /// Returns `(action_type, success_rate)` sorted descending.
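    ///
    /// For example, 3 navigate successes and 1 open failure in a region yields
    /// `[("navigate", 0.8), ("open", 0.33)]` after smoothing.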
    pub fn best_actions(&self, region: u32) -> Vec<(String, f32)> {
        let mut scores: HashMap<String, f32> = HashMap::new();
        for (key, rewards) in &self.rewards {
            if key.0 == region && !rewards.is_empty() {
                let successes = rewards.iter().filter(|&&r| r > 0.5).count() as f32;
                let rate = (1.0 + successes) / (2.0 + rewards.len() as f32);
                scores.insert(key.1.clone(), rate);
            }
        }
        let mut sorted: Vec<_> = scores.into_iter().collect();
        sorted.sort_by(|a, b| b.1.total_cmp(&a.1));
        sorted
    }

    /// P(s' | s, a) — transition probability.
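    ///
    /// Computed as the count of (region, action) → next_region transitions
    /// divided by all transitions observed from (region, action); 0.0 if unseen.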
    pub fn transition_probability(&self, region: u32, action: &str, next_region: u32) -> f32 {
        let key = (region, action.to_string());
        let counts = match self.transitions.get(&key) {
            Some(c) => c,
            None => return 0.0,
        };
        let total: u32 = counts.values().sum();
        if total == 0 {
            return 0.0;
        }
        *counts.get(&next_region).unwrap_or(&0) as f32 / total as f32
    }

    /// Context-aware decay factor based on region quality.
    ///
    /// High-quality region → factor near 0.95 (decays slowly).
    /// Low-quality region → factor near 0.70 (decays quickly).
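    ///
    /// For an unseen region (default quality 0.5) the factor is
    /// 0.70 + 0.25 * 0.5 = 0.825.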
    pub fn decay_factor(&self, region: u32) -> f32 {
        let q = self.region_quality_score(region);
        0.70 + 0.25 * q
    }

    /// Format action hints for a region as a string.
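    ///
    /// Example output (smoothed rates): `Action success rates: navigate(86%), open(33%)`.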
    pub fn format_hints(&self, region: u32, top_n: usize) -> String {
        let best = self.best_actions(region);
        if best.is_empty() {
            return String::new();
        }
        let hints: Vec<String> = best
            .iter()
            .take(top_n)
            .map(|(a, s)| format!("{a}({:.0}%)", s * 100.0))
            .collect();
        format!("Action success rates: {}", hints.join(", "))
    }

    /// Total transitions learned.
    pub fn n_transitions(&self) -> usize {
        self.n_transitions
    }

    /// Number of distinct (region, action) pairs observed.
    pub fn n_state_actions(&self) -> usize {
        self.rewards.len()
    }

    /// Number of distinct regions observed.
    pub fn n_regions(&self) -> usize {
        self.region_quality.len()
    }

    /// Summary statistics.
    pub fn stats(&self) -> String {
        format!(
            "{} transitions, {} state-action pairs, {} regions",
            self.n_transitions,
            self.n_state_actions(),
            self.n_regions()
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn learn_and_query() {
        let mut mdp = RegionMdp::new();
        // Episode 1: regions [0, 1, 2], actions [navigate, take, place], success
        mdp.learn_trajectory(
            &[0, 1, 2],
            &["navigate".into(), "take".into(), "place".into()],
            1.0,
        );
        // Episode 2: shares regions 0 and 1, different actions, failure
        mdp.learn_trajectory(
            &[0, 1, 3],
            &["navigate".into(), "open".into(), "take".into()],
            0.0,
        );

        assert_eq!(mdp.n_transitions(), 4);

        // navigate from region 0: 1 success + 1 failure → P = (1+1)/(2+2) = 0.5
        let rate = mdp.action_success_rate(0, "navigate");
        assert!((rate - 0.5).abs() < 0.01, "rate = {rate}");

        // take from region 1: 1 success → P = (1+1)/(2+1) = 0.67
        let rate = mdp.action_success_rate(1, "take");
        assert!((rate - 0.667).abs() < 0.02, "rate = {rate}");
    }

    #[test]
    fn best_actions_sorted() {
        let mut mdp = RegionMdp::new();
        // 3 successes with navigate, 1 failure with open
        for _ in 0..3 {
            mdp.learn_trajectory(&[0, 1], &["navigate".into(), "done".into()], 1.0);
        }
        mdp.learn_trajectory(&[0, 2], &["open".into(), "fail".into()], 0.0);

        let best = mdp.best_actions(0);
        assert!(!best.is_empty());
        assert_eq!(best[0].0, "navigate");
        assert!(best[0].1 > best.last().unwrap().1);
    }

    #[test]
    fn transition_probability() {
        let mut mdp = RegionMdp::new();
        mdp.learn_trajectory(&[0, 1], &["go".into(), "x".into()], 1.0);
        mdp.learn_trajectory(&[0, 1], &["go".into(), "x".into()], 1.0);
        mdp.learn_trajectory(&[0, 2], &["go".into(), "x".into()], 0.0);

        let p1 = mdp.transition_probability(0, "go", 1);
        let p2 = mdp.transition_probability(0, "go", 2);
        assert!((p1 - 0.667).abs() < 0.02, "p1 = {p1}");
        assert!((p2 - 0.333).abs() < 0.02, "p2 = {p2}");
        assert!((p1 + p2 - 1.0).abs() < 0.01);
    }

    #[test]
    fn region_quality_and_decay() {
        let mut mdp = RegionMdp::new();
        // Region 0: 3 successes, 1 failure → quality ~0.75
        for _ in 0..3 {
            mdp.learn_trajectory(&[0, 1], &["a".into(), "b".into()], 1.0);
        }
        mdp.learn_trajectory(&[0, 1], &["a".into(), "b".into()], 0.0);

        let q = mdp.region_quality_score(0);
        assert!((q - 0.75).abs() < 0.01, "quality = {q}");

        let d = mdp.decay_factor(0);
        // 0.70 + 0.25 * 0.75 = 0.8875
        assert!((d - 0.8875).abs() < 0.01, "decay = {d}");
    }

    #[test]
    fn unknown_region_defaults() {
        let mdp = RegionMdp::new();
        assert!((mdp.action_success_rate(99, "x") - 0.5).abs() < 0.01);
        assert!((mdp.region_quality_score(99) - 0.5).abs() < 0.01);
        assert!((mdp.decay_factor(99) - 0.825).abs() < 0.01);
    }

    #[test]
    fn format_hints() {
        let mut mdp = RegionMdp::new();
        for _ in 0..5 {
            mdp.learn_trajectory(&[0, 1], &["navigate".into(), "done".into()], 1.0);
        }
        mdp.learn_trajectory(&[0, 2], &["open".into(), "fail".into()], 0.0);

        let hints = mdp.format_hints(0, 3);
        assert!(hints.contains("navigate"), "hints = {hints}");
        assert!(hints.contains("%"));
    }

    #[test]
    fn serialization_roundtrip() {
        let mut mdp = RegionMdp::new();
        mdp.learn_trajectory(&[0, 1, 2], &["a".into(), "b".into(), "c".into()], 1.0);

        let bytes = postcard::to_allocvec(&mdp).unwrap();
        let restored: RegionMdp = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(restored.n_transitions(), 2);
        assert!(
            (restored.action_success_rate(0, "a") - mdp.action_success_rate(0, "a")).abs() < 0.01
        );
    }
}