1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Empirical, tabular model of region-to-region transitions and action outcomes,
/// learned incrementally from rewarded trajectories (see `learn_trajectory`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RegionMdp {
    /// Transition counts: (region, action) -> { next_region -> times observed }.
    transitions: HashMap<(u32, String), HashMap<u32, u32>>,
    /// Trajectory-level rewards recorded for each (region, action) pair tried.
    rewards: HashMap<(u32, String), Vec<f32>>,
    /// Trajectory-level rewards recorded for every region visited.
    region_quality: HashMap<u32, Vec<f32>>,
    /// Running total of learned (state, action, next-state) transitions.
    n_transitions: usize,
}
25
26impl RegionMdp {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn learn_trajectory(&mut self, regions: &[u32], actions: &[String], reward: f32) {
38 let n = regions.len().min(actions.len());
39 for i in 0..n.saturating_sub(1) {
40 let s = regions[i];
41 let a = actions[i].clone();
42 let s_next = regions[i + 1];
43
44 *self
45 .transitions
46 .entry((s, a.clone()))
47 .or_default()
48 .entry(s_next)
49 .or_default() += 1;
50
51 self.rewards.entry((s, a)).or_default().push(reward);
52 self.n_transitions += 1;
53 }
54
55 for &s in regions {
57 self.region_quality.entry(s).or_default().push(reward);
58 }
59 }
60
61 pub fn action_success_rate(&self, region: u32, action_type: &str) -> f32 {
66 let key = (region, action_type.to_string());
67 if let Some(rewards) = self.rewards.get(&key) {
68 if !rewards.is_empty() {
69 let successes: f32 = rewards.iter().filter(|&&r| r > 0.5).count() as f32;
70 return (1.0 + successes) / (2.0 + rewards.len() as f32);
71 }
72 }
73 self.region_quality_score(region)
75 }
76
77 pub fn region_quality_score(&self, region: u32) -> f32 {
79 self.region_quality
80 .get(®ion)
81 .filter(|r| !r.is_empty())
82 .map(|r| r.iter().sum::<f32>() / r.len() as f32)
83 .unwrap_or(0.5)
84 }
85
86 pub fn best_actions(&self, region: u32) -> Vec<(String, f32)> {
90 let mut scores: HashMap<String, f32> = HashMap::new();
91 for (key, rewards) in &self.rewards {
92 if key.0 == region && !rewards.is_empty() {
93 let successes = rewards.iter().filter(|&&r| r > 0.5).count() as f32;
94 let rate = (1.0 + successes) / (2.0 + rewards.len() as f32);
95 scores.insert(key.1.clone(), rate);
96 }
97 }
98 let mut sorted: Vec<_> = scores.into_iter().collect();
99 sorted.sort_by(|a, b| b.1.total_cmp(&a.1));
100 sorted
101 }
102
103 pub fn transition_probability(&self, region: u32, action: &str, next_region: u32) -> f32 {
105 let key = (region, action.to_string());
106 let counts = match self.transitions.get(&key) {
107 Some(c) => c,
108 None => return 0.0,
109 };
110 let total: u32 = counts.values().sum();
111 if total == 0 {
112 return 0.0;
113 }
114 *counts.get(&next_region).unwrap_or(&0) as f32 / total as f32
115 }
116
117 pub fn decay_factor(&self, region: u32) -> f32 {
122 let q = self.region_quality_score(region);
123 0.70 + 0.25 * q
124 }
125
126 pub fn format_hints(&self, region: u32, top_n: usize) -> String {
128 let best = self.best_actions(region);
129 if best.is_empty() {
130 return String::new();
131 }
132 let hints: Vec<String> = best
133 .iter()
134 .take(top_n)
135 .map(|(a, s)| format!("{a}({:.0}%)", s * 100.0))
136 .collect();
137 format!("Action success rates: {}", hints.join(", "))
138 }
139
140 pub fn n_transitions(&self) -> usize {
142 self.n_transitions
143 }
144
145 pub fn n_state_actions(&self) -> usize {
147 self.rewards.len()
148 }
149
150 pub fn n_regions(&self) -> usize {
152 self.region_quality.len()
153 }
154
155 pub fn stats(&self) -> String {
157 format!(
158 "{} transitions, {} state-action pairs, {} regions",
159 self.n_transitions,
160 self.n_state_actions(),
161 self.n_regions()
162 )
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn learn_and_query() {
        let mut mdp = RegionMdp::new();
        // Two 3-region trajectories => 2 transitions recorded from each.
        mdp.learn_trajectory(
            &[0, 1, 2],
            &["navigate".into(), "take".into(), "place".into()],
            1.0,
        );
        mdp.learn_trajectory(
            &[0, 1, 3],
            &["navigate".into(), "open".into(), "take".into()],
            0.0,
        );

        assert_eq!(mdp.n_transitions(), 4);

        // (0, "navigate") saw rewards [1.0, 0.0]: (1 + 1) / (2 + 2) = 0.5.
        let rate = mdp.action_success_rate(0, "navigate");
        assert!((rate - 0.5).abs() < 0.01, "rate = {rate}");

        // (1, "take") saw rewards [1.0]: (1 + 1) / (2 + 1) = 0.667.
        let rate = mdp.action_success_rate(1, "take");
        assert!((rate - 0.667).abs() < 0.02, "rate = {rate}");
    }

    #[test]
    fn best_actions_sorted() {
        let mut mdp = RegionMdp::new();
        // (0, "navigate"): 3 successes => (1 + 3) / (2 + 3) = 0.8.
        for _ in 0..3 {
            mdp.learn_trajectory(&[0, 1], &["navigate".into(), "done".into()], 1.0);
        }
        // (0, "open"): 1 failure => (1 + 0) / (2 + 1) = 0.333.
        mdp.learn_trajectory(&[0, 2], &["open".into(), "fail".into()], 0.0);

        // "navigate" should rank strictly above "open".
        let best = mdp.best_actions(0);
        assert!(!best.is_empty());
        assert_eq!(best[0].0, "navigate");
        assert!(best[0].1 > best.last().unwrap().1);
    }

    #[test]
    fn transition_probability() {
        let mut mdp = RegionMdp::new();
        // (0, "go") observed outcomes: region 1 twice, region 2 once.
        mdp.learn_trajectory(&[0, 1], &["go".into(), "x".into()], 1.0);
        mdp.learn_trajectory(&[0, 1], &["go".into(), "x".into()], 1.0);
        mdp.learn_trajectory(&[0, 2], &["go".into(), "x".into()], 0.0);

        // Empirical probabilities 2/3 and 1/3 must sum to 1.
        let p1 = mdp.transition_probability(0, "go", 1);
        let p2 = mdp.transition_probability(0, "go", 2);
        assert!((p1 - 0.667).abs() < 0.02, "p1 = {p1}");
        assert!((p2 - 0.333).abs() < 0.02, "p2 = {p2}");
        assert!((p1 + p2 - 1.0).abs() < 0.01);
    }

    #[test]
    fn region_quality_and_decay() {
        let mut mdp = RegionMdp::new();
        // Region 0 rewards: [1, 1, 1, 0] => mean quality 0.75.
        for _ in 0..3 {
            mdp.learn_trajectory(&[0, 1], &["a".into(), "b".into()], 1.0);
        }
        mdp.learn_trajectory(&[0, 1], &["a".into(), "b".into()], 0.0);

        let q = mdp.region_quality_score(0);
        assert!((q - 0.75).abs() < 0.01, "quality = {q}");

        // decay = 0.70 + 0.25 * 0.75 = 0.8875.
        let d = mdp.decay_factor(0);
        assert!((d - 0.8875).abs() < 0.01, "decay = {d}");
    }

    #[test]
    fn unknown_region_defaults() {
        // Nothing learned: everything falls back to the 0.5 prior,
        // and decay = 0.70 + 0.25 * 0.5 = 0.825.
        let mdp = RegionMdp::new();
        assert!((mdp.action_success_rate(99, "x") - 0.5).abs() < 0.01);
        assert!((mdp.region_quality_score(99) - 0.5).abs() < 0.01);
        assert!((mdp.decay_factor(99) - 0.825).abs() < 0.01);
    }

    #[test]
    fn format_hints() {
        let mut mdp = RegionMdp::new();
        for _ in 0..5 {
            mdp.learn_trajectory(&[0, 1], &["navigate".into(), "done".into()], 1.0);
        }
        mdp.learn_trajectory(&[0, 2], &["open".into(), "fail".into()], 0.0);

        // Hint line should name the best action and show percentages.
        let hints = mdp.format_hints(0, 3);
        assert!(hints.contains("navigate"), "hints = {hints}");
        assert!(hints.contains("%"));
    }

    #[test]
    fn serialization_roundtrip() {
        let mut mdp = RegionMdp::new();
        mdp.learn_trajectory(&[0, 1, 2], &["a".into(), "b".into(), "c".into()], 1.0);

        // Binary roundtrip via postcard must preserve counts and derived rates.
        let bytes = postcard::to_allocvec(&mdp).unwrap();
        let restored: RegionMdp = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(restored.n_transitions(), 2);
        assert!(
            (restored.action_success_rate(0, "a") - mdp.action_success_rate(0, "a")).abs() < 0.01
        );
    }
}