cvx_analytics/
fisher_rao.rs1pub fn fisher_rao_distance(p: &[f64], q: &[f64]) -> f64 {
43 assert_eq!(p.len(), q.len(), "distributions must have equal length");
44
45 let bhattacharyya_coeff: f64 = p
46 .iter()
47 .zip(q.iter())
48 .map(|(&pi, &qi)| (pi.max(0.0) * qi.max(0.0)).sqrt())
49 .sum();
50
51 2.0 * bhattacharyya_coeff.clamp(0.0, 1.0).acos()
53}
54
55pub fn fisher_rao_distance_f32(p: &[f32], q: &[f32]) -> f64 {
57 let p64: Vec<f64> = p.iter().map(|&v| v as f64).collect();
58 let q64: Vec<f64> = q.iter().map(|&v| v as f64).collect();
59 fisher_rao_distance(&p64, &q64)
60}
61
/// Bhattacharyya coefficient `BC(p, q) = Σᵢ √(pᵢ qᵢ)` of two discrete distributions.
///
/// Negative entries are treated as zero, and so are NaN entries, because `f64::max`
/// returns the non-NaN operand. For normalized inputs the result lies in `[0, 1]`:
/// `1` for identical and `0` for disjoint distributions. The inputs are *not*
/// normalized here, so weights that sum to more than `1` can produce values above `1`.
///
/// # Panics
///
/// Panics if `p` and `q` have different lengths.
pub fn bhattacharyya_coefficient(p: &[f64], q: &[f64]) -> f64 {
    // Same precondition and message as `fisher_rao_distance`, so callers see one error text.
    assert_eq!(p.len(), q.len(), "distributions must have equal length");
    p.iter()
        .zip(q)
        .map(|(&pi, &qi)| (pi.max(0.0) * qi.max(0.0)).sqrt())
        .sum()
}

74pub fn hellinger_distance(p: &[f64], q: &[f64]) -> f64 {
79 let bc = bhattacharyya_coefficient(p, q);
80 ((1.0 - bc).max(0.0) / 2.0).sqrt()
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Absolute tolerance for the floating-point comparisons in this module.
    const TOL: f64 = 1e-10;

    #[test]
    fn identical_distributions_zero() {
        // A distribution is at distance zero from itself.
        let p = [0.3, 0.5, 0.2];
        assert!(fisher_rao_distance(&p, &p) < TOL);
    }

    #[test]
    fn disjoint_distributions_pi() {
        // Disjoint supports sit at the maximal distance π.
        let d = fisher_rao_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!((d - PI).abs() < TOL, "disjoint should be π, got {d}");
    }

    #[test]
    fn symmetric() {
        let p = [0.4, 0.3, 0.3];
        let q = [0.1, 0.6, 0.3];
        let asymmetry = fisher_rao_distance(&p, &q) - fisher_rao_distance(&q, &p);
        assert!(asymmetry.abs() < TOL);
    }

    #[test]
    fn triangle_inequality() {
        let (p, q, r) = ([0.5, 0.3, 0.2], [0.1, 0.4, 0.5], [0.3, 0.3, 0.4]);
        let via_q = fisher_rao_distance(&p, &q) + fisher_rao_distance(&q, &r);
        let d_pr = fisher_rao_distance(&p, &r);
        assert!(
            d_pr <= via_q + TOL,
            "triangle: d(p,r)={d_pr} > d(p,q)+d(q,r)={}",
            via_q
        );
    }

    #[test]
    fn bounded_zero_to_pi() {
        let cases: [(&[f64], &[f64]); 4] = [
            (&[0.5, 0.5], &[0.5, 0.5]),
            (&[1.0, 0.0], &[0.0, 1.0]),
            (&[0.9, 0.1], &[0.1, 0.9]),
            (&[0.25, 0.25, 0.25, 0.25], &[0.7, 0.1, 0.1, 0.1]),
        ];
        let allowed = 0.0..=PI + TOL;
        for &(p, q) in &cases {
            let d = fisher_rao_distance(p, q);
            assert!(allowed.contains(&d), "d={d} out of [0, π]");
        }
    }

    #[test]
    fn bhattacharyya_coefficient_range() {
        let bc = bhattacharyya_coefficient(&[0.3, 0.7], &[0.6, 0.4]);
        assert!((0.0..=1.0).contains(&bc), "BC={bc} out of [0, 1]");
    }

    #[test]
    fn hellinger_range() {
        let h = hellinger_distance(&[0.3, 0.7], &[0.6, 0.4]);
        assert!((0.0..=1.0).contains(&h), "H={h} out of [0, 1]");
    }

    #[test]
    fn f32_wrapper_matches() {
        let p32: &[f32] = &[0.3, 0.5, 0.2];
        let q32: &[f32] = &[0.1, 0.6, 0.3];
        let widen = |xs: &[f32]| -> Vec<f64> { xs.iter().copied().map(f64::from).collect() };
        let d64 = fisher_rao_distance(&widen(p32), &widen(q32));
        let d32 = fisher_rao_distance_f32(p32, q32);
        assert!((d32 - d64).abs() < 1e-6);
    }
}