RFC-003: Neural ODE via TorchScript
This RFC proposes adding a PyTorch-based Neural ODE prediction backend. Models are trained in Python (PyTorch) or Rust (tch-rs), exported as TorchScript (.pt) files, and loaded in Rust for inference. The backend is gated behind the optional torch-backend feature flag.
Summary Table
| Phase | Scope | Effort | Status |
|---|---|---|---|
| 1 | Feature flags & dependencies (tch 0.23) | Low | ✅ Done |
| 2 | Configuration (model_path in TOML) | Low | ✅ Done |
| 3 | TorchScript loader + inference (torch_ode.rs) | Medium | ✅ Done |
| 4 | Query engine integration (fallback to linear) | Medium | ✅ Done |
| 5 | Python training script (PyTorch) | Low | ✅ Done |
| 6 | Rust training API (tch-rs) | Medium | ✅ Done |
| 7 | Python bindings update (PyO3) | Low | ✅ Done |
Why tch-rs over burn
| Factor | tch-rs (libtorch) | burn |
|---|---|---|
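| Ecosystem | Bindings to libtorch, PyTorch's C++ backend | Pure Rust framework |
| Pre-trained models | Loads TorchScript .pt exported from PyTorch (e.g. from HuggingFace) | Model must be re-implemented in Rust |
| Training | Python + Rust | Rust only |
| Community | Large (PyTorch ecosystem) | Smaller, growing |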
Architecture
```text
Training (Python/Rust)          Inference (Rust)
──────────────────────          ───────────────────────────────
Train NeuralODEPredictor    →   TorchOdeModel::load()
Export TorchScript .pt      →   model.predict(traj, t)
                                Fallback → linear_extrapolate()
```

Model I/O Contract
| Tensor | Shape | Description |
|---|---|---|
| Input: trajectory | [1, T, D+1] | T observed steps, each the normalized time followed by the D vector components |
| Input: target_t | [1, 1] | Normalized target timestamp |
| Output: predicted | [1, D] | Predicted D-dimensional vector at target_t |
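The Rust sketch below shows one way to pack a (timestamp, vector) history into the trajectory and target_t layouts above as flat row-major buffers. `pack_inputs` and `PackedInputs` are hypothetical names. The time normalization is an assumption: the first observation maps to 0 and the last observation maps to 1, so a future target lands above 1. The contract only says timestamps are normalized. The sketch uses only std, so it does not depend on tch.

```rust
/// Flat, row-major model inputs for a single request (batch size 1).
#[derive(Debug)]
pub struct PackedInputs {
    /// Trajectory buffer with logical shape [1, T, D+1].
    pub trajectory: Vec<f32>,
    pub trajectory_shape: [usize; 3],
    /// Target time buffer with logical shape [1, 1].
    pub target_t: Vec<f32>,
}

/// Hypothetical helper: packs (timestamp, vector) observations into the
/// [1, T, D+1] layout, where each step is [normalized_time, v_1, ..., v_D].
///
/// ASSUMPTION: times are normalized so the first observation maps to 0.0
/// and the last observation maps to 1.0; a future target_t is then > 1.0.
pub fn pack_inputs(
    observations: &[(f64, Vec<f32>)],
    target_ts: f64,
) -> Result<PackedInputs, String> {
    let (t_first, first_vec) = observations.first().ok_or("empty trajectory")?;
    let (t_last, _) = observations.last().ok_or("empty trajectory")?;
    let d = first_vec.len();
    let span = t_last - t_first;
    // Rejects zero-dimensional vectors, single-point histories,
    // last-before-first ordering, and NaN timestamps.
    if d == 0 || !(span > 0.0) {
        return Err("need D > 0 and at least two observations spanning a positive time window".into());
    }

    let mut trajectory = Vec::with_capacity(observations.len() * (d + 1));
    for (ts, v) in observations {
        if v.len() != d {
            return Err(format!("expected dimension {d}, got {}", v.len()));
        }
        trajectory.push(((ts - t_first) / span) as f32); // normalized time first
        trajectory.extend_from_slice(v); // then the D vector components
    }

    Ok(PackedInputs {
        trajectory,
        trajectory_shape: [1, observations.len(), d + 1],
        target_t: vec![((target_ts - t_first) / span) as f32],
    })
}
```

Under this assumption, observations at t = 10 and t = 20 with a target of 30 produce the trajectory [0, 1, 2, 1, 3, 4] with shape [1, 2, 3] and target_t [2.0]. The buffers can then be wrapped as tensors before calling the model, for example with tch's `Tensor::from_slice` and `reshape`.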
Configuration
```toml
[analytics]
neural_ode = true
model_path = "./models/neural_ode_d128.pt"
```

Error Handling
Section titled “Error Handling”Every failure falls back to linear_extrapolate():
| Scenario | Behavior |
|---|---|
| torch-backend feature not compiled | Linear fallback (no tch/libtorch code compiled in, zero overhead) |
| Model file missing | Warning + linear fallback |
| Inference fails (NaN output, wrong tensor shape) | Warning + linear fallback for that request |
| CUDA unavailable | Automatic CPU fallback |
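The following is a minimal sketch of the fallback rules in the table above, written in plain Rust without tch. `TrajectoryPredictor`, `load_model`, and `predict_with_fallback` are hypothetical names. The trait stands in for TorchOdeModel, and the `loader` closure stands in for `TorchOdeModel::load()`. The RFC names `linear_extrapolate()` but does not define it, so its signature and behavior here are assumptions: the sketch extends the line through the last two observations. A model of `None` covers both the "not compiled" and "model file missing" rows. A NaN or wrongly sized output triggers a fallback for that request only. The CUDA-to-CPU row is a device-selection concern and is not modeled.

```rust
use std::path::Path;

/// Hypothetical stand-in for the TorchScript-backed model.
pub trait TrajectoryPredictor {
    fn predict(&self, history: &[(f64, Vec<f32>)], target_ts: f64) -> Result<Vec<f32>, String>;
}

/// ASSUMED behavior of linear_extrapolate(): extends the line through the last
/// two observations to target_ts. With one usable point it repeats the last
/// vector; with no observations it returns an empty vector.
pub fn linear_extrapolate(history: &[(f64, Vec<f32>)], target_ts: f64) -> Vec<f32> {
    match history {
        [] => Vec::new(),
        [.., (t0, v0), (t1, v1)] if t1 > t0 => {
            let alpha = ((target_ts - t1) / (t1 - t0)) as f32;
            v0.iter().zip(v1).map(|(a, b)| b + alpha * (b - a)).collect()
        }
        [.., (_, last)] => last.clone(),
    }
}

/// Load-time decision: None (linear fallback) when neural_ode is disabled or
/// the model cannot be loaded, e.g. because the file is missing.
pub fn load_model(
    neural_ode: bool,
    model_path: &Path,
    loader: impl FnOnce(&Path) -> Result<Box<dyn TrajectoryPredictor>, String>,
) -> Option<Box<dyn TrajectoryPredictor>> {
    if !neural_ode {
        return None;
    }
    match loader(model_path) {
        Ok(model) => Some(model),
        Err(e) => {
            eprintln!("warning: could not load {}: {e}; using linear fallback", model_path.display());
            None
        }
    }
}

/// Per-request decision: use the model output only if it has the expected
/// dimension D and contains no NaN/inf values; otherwise fall back.
pub fn predict_with_fallback(
    model: Option<&dyn TrajectoryPredictor>,
    history: &[(f64, Vec<f32>)],
    target_ts: f64,
) -> Vec<f32> {
    let Some(model) = model else {
        return linear_extrapolate(history, target_ts);
    };
    let expected_dim = history.last().map_or(0, |(_, v)| v.len());
    match model.predict(history, target_ts) {
        Ok(out) if out.len() == expected_dim && out.iter().all(|x| x.is_finite()) => out,
        Ok(out) => {
            eprintln!(
                "warning: unusable Neural ODE output (len {}, expected {expected_dim}); using linear fallback",
                out.len()
            );
            linear_extrapolate(history, target_ts)
        }
        Err(e) => {
            eprintln!("warning: Neural ODE inference failed ({e}); using linear fallback");
            linear_extrapolate(history, target_ts)
        }
    }
}
```

Holding the model as an `Option<Box<dyn TrajectoryPredictor>>` and passing `model.as_deref()` to `predict_with_fallback` splits the two failure modes. A failed load degrades to linear extrapolation once, at startup. A bad output degrades only the affected request and keeps the loaded model.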
Pre-trained Models
No standard pre-trained model exists for high-dimensional trajectory prediction. Strategy:
- Now: Ship without a default model
- Future: Publish manucouto1/cvx-neural-ode-{dim} on HuggingFace
References
- Chen et al. “Neural Ordinary Differential Equations.” NeurIPS 2018.
- Rubanova et al. “Latent Ordinary Differential Equations for Irregularly-Sampled Time Series.” NeurIPS 2019.
- De Brouwer et al. “GRU-ODE-Bayes: Continuous Modeling of Sporadically-Observed Time Series.” NeurIPS 2019.
- Kidger. “On Neural Differential Equations.” PhD Thesis, Oxford, 2022.
- tch-rs: https://github.com/LaurentMazare/tch-rs
See full RFC in design/CVX_RFC_003_NeuralODE_TorchBackend.md.