
RFC-003: Neural ODE via TorchScript

This RFC proposes adding a PyTorch-based Neural ODE prediction backend. Models are trained in Python (PyTorch) or Rust (tch-rs), exported as TorchScript (`.pt`), and loaded for inference. The backend is gated behind an optional `torch-backend` feature flag.
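The feature gating can be sketched as follows. This assumes a Cargo feature named `torch-backend` that enables an optional `tch` dependency. The manifest lines in the comments are an assumption and are not copied from the project. When the feature is off, only a stub is compiled, so no libtorch code is linked.

```rust
// Assumed Cargo.toml entries (illustrative, not taken from the RFC):
//   [dependencies]
//   tch = { version = "0.23", optional = true }
//   [features]
//   torch-backend = ["dep:tch"]

/// Name of the prediction backend compiled into this build.
pub fn compiled_backend() -> &'static str {
    if cfg!(feature = "torch-backend") {
        "torch"
    } else {
        "linear"
    }
}

/// Hypothetical stub compiled only when `torch-backend` is disabled:
/// every load attempt reports that the Neural ODE backend is unavailable,
/// so callers always take the linear path.
#[cfg(not(feature = "torch-backend"))]
pub mod torch_ode {
    pub fn load(_path: &str) -> Result<(), String> {
        Err("built without the torch-backend feature".to_string())
    }
}
```

With the feature disabled, `compiled_backend()` returns `"linear"` and `torch_ode::load` always fails, so the query engine never touches libtorch.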

Implementation phases:

| Phase | Scope | Effort | Status |
|-------|-------|--------|--------|
| 1 | Feature flags & dependencies (tch 0.23) | Low | ✅ Done |
| 2 | Configuration (`model_path` in TOML) | Low | ✅ Done |
| 3 | TorchScript loader + inference (`torch_ode.rs`) | Medium | ✅ Done |
| 4 | Query engine integration (fallback to linear) | Medium | ✅ Done |
| 5 | Python training script (PyTorch) | Low | ✅ Done |
| 6 | Rust training API (tch-rs) | Medium | ✅ Done |
| 7 | Python bindings update (PyO3) | Low | ✅ Done |
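
Backend choice, tch-rs (libtorch) versus burn:

| Factor | tch-rs (libtorch) | burn |
|--------|-------------------|------|
| Ecosystem | Full PyTorch compatibility | Rust-native |
| Pre-trained models | Any `.pt` from HuggingFace | Require re-implementation |
| Training | Python + Rust | Rust only |
| Community | Millions of users | Growing |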

Pipeline overview:

```
Training (Python/Rust)            Inference (Rust)
─────────────────────             ────────────────────
Train NeuralODEPredictor     →    TorchOdeModel::load()
Export TorchScript .pt       →    model.predict(traj, t)
                                  Fallback → linear_extrapolate()
```
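The fallback path in the diagram can be sketched as a linear extrapolation from the last two observations. The signature used here is hypothetical: timestamps and vectors are passed as plain slices. The RFC names `linear_extrapolate()` but does not specify its inputs or how many points it uses.

```rust
/// Hypothetical sketch of `linear_extrapolate`: extends the line through the
/// last two observations to `target_t`. `traj` holds (timestamp, vector)
/// pairs in ascending time order.
pub fn linear_extrapolate(traj: &[(f64, Vec<f32>)], target_t: f64) -> Option<Vec<f32>> {
    match traj {
        [] => None,
        // A single observation: no slope available, so hold it constant.
        [(_, v)] => Some(v.clone()),
        [.., (t0, v0), (t1, v1)] => {
            let dt = t1 - t0;
            if dt <= 0.0 || v0.len() != v1.len() {
                // Duplicate timestamps or mismatched dims: return the latest vector.
                return Some(v1.clone());
            }
            // How many step intervals to move past the last observation.
            let alpha = ((target_t - t1) / dt) as f32;
            Some(v0.iter().zip(v1).map(|(a, b)| b + alpha * (b - a)).collect())
        }
    }
}
```

Returning the latest vector for degenerate inputs is one possible choice. The actual implementation may handle those cases differently.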
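
Model I/O contract (batch size 1; T = number of trajectory steps, D = vector dimension):

| Tensor | Shape | Description |
|--------|-------|-------------|
| Input `trajectory` | `[1, T, D+1]` | Normalized time plus the vector at each step |
| Input `target_t` | `[1, 1]` | Normalized target timestamp |
| Output `predicted` | `[1, D]` | Predicted vector |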
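The input contract can be sketched as a flat, row-major buffer that the tensor library (tch-rs here) would reshape to `[1, T, D+1]`. The normalization scheme is an assumption. The RFC only says that timestamps are normalized, so this sketch maps the first observation to 0 and the last to 1. It puts `target_t` on the same scale, so future targets land above 1. `OdeInputs` and `pack_inputs` are hypothetical names.

```rust
/// Packed model inputs (hypothetical helper, not the RFC's API).
pub struct OdeInputs {
    /// Row-major data for the `[1, T, D+1]` trajectory tensor.
    pub trajectory: Vec<f32>,
    pub trajectory_shape: [i64; 3],
    /// Value for the `[1, 1]` target_t tensor.
    pub target_t: f32,
}

/// Builds inputs from `(timestamp, vector)` pairs sorted by time.
/// Returns `None` for an empty trajectory or inconsistent vector dims.
pub fn pack_inputs(traj: &[(f64, Vec<f32>)], target_t: f64) -> Option<OdeInputs> {
    let (t_first, v_first) = traj.first()?;
    let (t_last, _) = traj.last()?;
    let dim = v_first.len();
    // Assumed min-max time scaling; a single-point trajectory uses span 1 to avoid 0/0.
    let span = if t_last > t_first { t_last - t_first } else { 1.0 };
    let norm = |t: f64| ((t - t_first) / span) as f32;

    let mut data = Vec::with_capacity(traj.len() * (dim + 1));
    for (t, v) in traj {
        if v.len() != dim {
            return None;
        }
        data.push(norm(*t)); // channel 0: normalized time
        data.extend_from_slice(v); // channels 1..=D: the vector
    }
    Some(OdeInputs {
        trajectory: data,
        trajectory_shape: [1, traj.len() as i64, dim as i64 + 1],
        target_t: norm(target_t),
    })
}
```

Whatever normalization is used at inference must match the one applied when the TorchScript model was trained. Otherwise the model sees inputs on a different time scale.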

Configuration (TOML):

```toml
[analytics]
neural_ode = true
model_path = "./models/neural_ode_d128.pt"
```
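The two settings might combine as follows. The sketch assumes that `model_path` is optional and that a missing file is detected before loading. The struct and method names are hypothetical. In practice the section would be deserialized by a TOML library rather than built by hand.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical mirror of the `[analytics]` TOML section.
pub struct AnalyticsConfig {
    pub neural_ode: bool,
    pub model_path: Option<PathBuf>,
}

impl AnalyticsConfig {
    /// Returns the model file to load, or `None` if the linear fallback
    /// should be used (disabled, no path configured, or file missing).
    pub fn model_to_load(&self) -> Option<&Path> {
        if !self.neural_ode {
            return None;
        }
        let path = self.model_path.as_deref()?;
        if !path.exists() {
            // Matches the "Model file missing" row: warn, then fall back.
            eprintln!("warning: model {} not found, using linear fallback", path.display());
            return None;
        }
        Some(path)
    }
}
```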

Every failure falls back to `linear_extrapolate()`, except a missing CUDA device, which only moves inference to the CPU:

| Scenario | Behavior |
|----------|----------|
| `torch-backend` not compiled | Linear fallback (zero overhead) |
| Model file missing | Warning + linear fallback |
| Inference fails (NaN, wrong dimensions) | Warning + linear fallback for that request |
| CUDA unavailable | Automatic CPU fallback |
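The per-request rule (reject NaN or wrong-dimension outputs, then fall back) can be sketched with closures. One closure stands in for the TorchScript call and another for `linear_extrapolate()`. The function name and the closure-based interface are hypothetical. Real code would invoke the loaded module instead.

```rust
/// Hypothetical per-request wrapper: runs the model if one is loaded,
/// validates its output, and otherwise returns the linear prediction.
pub fn predict_with_fallback<M, F>(model: Option<M>, dim: usize, fallback: F) -> Vec<f32>
where
    M: FnOnce() -> Result<Vec<f32>, String>,
    F: FnOnce() -> Vec<f32>,
{
    // No model: the build lacks torch-backend or loading failed.
    let Some(run) = model else {
        return fallback();
    };
    match run() {
        // Accept only outputs with exactly D finite values.
        Ok(out) if out.len() == dim && out.iter().all(|x| x.is_finite()) => out,
        Ok(out) => {
            eprintln!(
                "warning: model returned {} values (expected {dim}) or non-finite output; using linear fallback",
                out.len()
            );
            fallback()
        }
        Err(e) => {
            eprintln!("warning: inference failed ({e}); using linear fallback");
            fallback()
        }
    }
}
```

The CUDA row is not modeled here. In tch-rs, that choice is typically made with `Device::cuda_if_available()` when the module is loaded.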

No standard pre-trained model exists for high-dimensional trajectory prediction, so the strategy is:

  1. Now: ship without a default model.
  2. Future: publish `manucouto1/cvx-neural-ode-{dim}` on HuggingFace.

See the full RFC in design/CVX_RFC_003_NeuralODE_TorchBackend.md.