diff --git a/FURTHER_DOCUMENTATION.md b/FURTHER_DOCUMENTATION.md
index 4e7604a25..5c5933339 100644
--- a/FURTHER_DOCUMENTATION.md
+++ b/FURTHER_DOCUMENTATION.md
@@ -57,3 +57,21 @@ For this solver, `rtol` and `atol` correspond to the tolerance for convergence o
- `adjoint_params`: The parameters to compute gradients with respect to in the backward pass. Should be a tuple of tensors. Defaults to `tuple(func.parameters())`.
- If passed then `func` does not have to be a `torch.nn.Module`.
- If `func` has no parameters, `adjoint_params=()` must be specified.
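+
+For instance, a minimal sketch of passing `adjoint_params` explicitly for a vector field that is not an `nn.Module` (the names `net` and `func` here are just illustrative):
+
+```
+import torch
+from torchdiffeq import odeint_adjoint
+
+net = torch.nn.Linear(2, 2)
+
+
+def func(t, y):  # a plain function rather than an nn.Module
+    return net(y)
+
+
+y0 = torch.randn(2)
+t = torch.linspace(0., 1., 5)
+ys = odeint_adjoint(func, y0, t, adjoint_params=tuple(net.parameters()))
+ys.sum().backward()  # populates net.weight.grad and net.bias.grad
+```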
+
+
+## Callbacks
+
+Callbacks can be triggered during the solve. Callbacks should be specified as methods of the `func` argument to `odeint` and `odeint_adjoint`.
+
+At the moment, support for this is minimal: let us know if you'd find additional callbacks useful.
+
+**`callback_step(self, t0, y0, dt)`:**
+This is called immediately before taking a step of size `dt`, at time `t0`, with current solution value `y0`. This is supported by every solver except `scipy_solver`.
+
+**`callback_accept_step(self, t0, y0, dt)`:**
+This is called when accepting a step of size `dt` at time `t0`, with current solution value `y0`. This is supported by the adaptive solvers (`dopri8`, `dopri5`, `tsit5`, `fehlberg2`, `bosh3`, `adaptive_heun`).
+
+**`callback_reject_step(self, t0, y0, dt)`:**
+As `callback_accept_step`, except called when rejecting steps.
+
+In addition, callbacks can be triggered during the adjoint pass by adding `_adjoint` to the name of any one of the supported callbacks, e.g. `callback_step_adjoint`.
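+
+As a rough sketch of how these callbacks fit together (the `Dynamics` module below is illustrative, not part of the library):
+
+```
+import torch
+from torchdiffeq import odeint
+
+
+class Dynamics(torch.nn.Module):
+    def forward(self, t, y):
+        return -y
+
+    def callback_step(self, t0, y0, dt):
+        # Called immediately before each step is attempted.
+        print(f"step from t={t0.item():.3f} with dt={dt.item():.3f}")
+
+    def callback_reject_step(self, t0, y0, dt):
+        # Called by adaptive solvers whenever a step is rejected and retried.
+        print(f"rejected dt={dt.item():.3f} at t={t0.item():.3f}")
+
+
+y0 = torch.tensor([1.0])
+t = torch.linspace(0., 1., 10)
+odeint(Dynamics(), y0, t, method="dopri5")
+```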
\ No newline at end of file
diff --git a/README.md b/README.md
index 3e250dc84..8f911bba6 100644
--- a/README.md
+++ b/README.md
@@ -70,11 +70,11 @@ odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=ode
The solve is terminated at an event time `t` and state `y` when an element of `event_fn(t, y)` is equal to zero. Multiple outputs from `event_fn` can be used to specify multiple event functions, of which the first to trigger will terminate the solve.
-Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function.
+Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function. **NOTE**: parameters of the event function must themselves be part of the state in order to receive gradients.
The numerical precision for the event time is determined by the `atol` argument.
-See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py).
+See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py). See example code for learning a simple event function in [`examples/learn_physics.py`](./examples/learn_physics.py).
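+
+As a minimal sketch of this pattern (the dynamics `f`, `event_fn`, and `threshold` below are illustrative; see the examples above for complete code):
+
+```
+import torch
+from torchdiffeq import odeint, odeint_event
+
+threshold = torch.nn.Parameter(torch.tensor(2.0))
+
+
+def f(t, state):
+    y, thresh = state
+    return torch.ones_like(y), torch.zeros_like(thresh)  # dy/dt = 1; the threshold stays constant
+
+
+def event_fn(t, state):
+    y, thresh = state
+    return y - thresh  # zero when y reaches the (learnable) threshold
+
+
+t0 = torch.tensor([0.0])
+y0 = (torch.tensor([0.0]), threshold)  # the event parameter is carried in the state
+event_t, solution = odeint_event(f, y0, t0, event_fn=event_fn, odeint_interface=odeint)
+event_t.sum().backward()  # gradients reach `threshold` because it is part of the state
+```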
@@ -118,13 +118,8 @@ For details of the adjoint-specific and solver-specific options, check out the [
Applications of differentiable ODE solvers and event handling are discussed in these two papers:
-[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366)
+Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366)
-[2] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." *International Conference on Learning Representations.* 2021. [[arxiv]](https://arxiv.org/abs/2011.03902)
-
----
-
-If you found this library useful in your research, please consider citing.
```
@article{chen2018neuralode,
title={Neural Ordinary Differential Equations},
@@ -132,7 +127,11 @@ If you found this library useful in your research, please consider citing.
journal={Advances in Neural Information Processing Systems},
year={2018}
}
+```
+Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." *International Conference on Learning Representations.* 2021. [[arxiv]](https://arxiv.org/abs/2011.03902)
+
+```
@article{chen2021eventfn,
title={Learning Neural Event Functions for Ordinary Differential Equations},
author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian},
@@ -140,3 +139,28 @@ If you found this library useful in your research, please consider citing.
year={2021}
}
```
+
+The seminorm option for computing adjoints is discussed in
+
+Patrick Kidger, Ricky T. Q. Chen, Terry Lyons. "'Hey, that’s not an ODE': Faster ODE Adjoints via Seminorms." *International Conference on Machine Learning.* 2021. [[arxiv]](https://arxiv.org/abs/2009.09457)
+
+```
+@article{kidger2021hey,
+ title={"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms.},
+ author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry J.},
+ journal={International Conference on Machine Learning},
+ year={2021}
+}
+```
+
+---
+
+If you found this library useful in your research, please consider citing.
+```
+@misc{torchdiffeq,
+ author={Chen, Ricky T. Q.},
+ title={torchdiffeq},
+ year={2018},
+ url={https://github.com/rtqichen/torchdiffeq},
+}
+```
diff --git a/examples/bouncing_ball.py b/examples/bouncing_ball.py
index f9cc88ba8..714a29eec 100644
--- a/examples/bouncing_ball.py
+++ b/examples/bouncing_ball.py
@@ -12,7 +12,6 @@
class BouncingBallExample(nn.Module):
-
def __init__(self, radius=0.2, gravity=9.8, adjoint=False):
super().__init__()
self.gravity = nn.Parameter(torch.as_tensor([gravity]))
@@ -39,9 +38,11 @@ def get_initial_state(self):
return self.t0, state
def state_update(self, state):
- """ Updates state based on an event (collision)."""
+ """Updates state based on an event (collision)."""
pos, vel, log_radius = state
- pos = pos + 1e-7 # need to add a small eps so as not to trigger the event function immediately.
+ pos = (
+ pos + 1e-7
+ ) # need to add a small eps so as not to trigger the event function immediately.
vel = -vel * (1 - self.absorption)
return (pos, vel, log_radius)
@@ -52,7 +53,16 @@ def get_collision_times(self, nbounces=1):
t0, state = self.get_initial_state()
for i in range(nbounces):
- event_t, solution = odeint_event(self, state, t0, event_fn=self.event_fn, reverse_time=False, atol=1e-8, rtol=1e-8, odeint_interface=self.odeint)
+ event_t, solution = odeint_event(
+ self,
+ state,
+ t0,
+ event_fn=self.event_fn,
+ reverse_time=False,
+ atol=1e-8,
+ rtol=1e-8,
+ odeint_interface=self.odeint,
+ )
event_times.append(event_t)
state = self.state_update(tuple(s[-1] for s in solution))
@@ -69,18 +79,25 @@ def simulate(self, nbounces=1):
velocity = [state[1][None]]
times = [t0.reshape(-1)]
for event_t in event_times:
- tt = torch.linspace(float(t0), float(event_t), int((float(event_t) - float(t0)) * 50))[1:-1]
+ tt = torch.linspace(
+ float(t0), float(event_t), int((float(event_t) - float(t0)) * 50)
+ )[1:-1]
tt = torch.cat([t0.reshape(-1), tt, event_t.reshape(-1)])
solution = odeint(self, state, tt, atol=1e-8, rtol=1e-8)
- trajectory.append(solution[0])
- velocity.append(solution[1])
- times.append(tt)
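+            # Drop the first point of each segment, which duplicates the end of the previous segment.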
+ trajectory.append(solution[0][1:])
+ velocity.append(solution[1][1:])
+ times.append(tt[1:])
state = self.state_update(tuple(s[-1] for s in solution))
t0 = event_t
- return torch.cat(times), torch.cat(trajectory, dim=0).reshape(-1), torch.cat(velocity, dim=0).reshape(-1), event_times
+ return (
+ torch.cat(times),
+ torch.cat(trajectory, dim=0).reshape(-1),
+ torch.cat(velocity, dim=0).reshape(-1),
+ event_times,
+ )
def gradcheck(nbounces):
@@ -124,7 +141,9 @@ def gradcheck(nbounces):
fd = fd_grads[var]
if torch.norm(analytical - fd) > 1e-4:
success = False
- print(f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}")
+ print(
+ f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}"
+ )
if not success:
raise Exception("Gradient check failed.")
@@ -152,10 +171,20 @@ def gradcheck(nbounces):
# Event locations.
for event_t in event_times:
- plt.plot(event_t, 0.0, color="C0", marker="o", markersize=7, fillstyle='none', linestyle="")
-
- vel, = plt.plot(times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0)
- pos, = plt.plot(times, trajectory, color="C0", linewidth=2.0)
+ plt.plot(
+ event_t,
+ 0.0,
+ color="C0",
+ marker="o",
+ markersize=7,
+ fillstyle="none",
+ linestyle="",
+ )
+
+ (vel,) = plt.plot(
+ times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0
+ )
+ (pos,) = plt.plot(times, trajectory, color="C0", linewidth=2.0)
plt.hlines(0, 0, 100)
plt.xlim([times[0], times[-1]])
@@ -164,16 +193,20 @@ def gradcheck(nbounces):
plt.xlabel("Time", fontsize=13)
plt.legend([pos, vel], ["Position", "Velocity"], fontsize=16)
- plt.gca().xaxis.set_tick_params(direction='in', which='both') # The bottom will maintain the default of 'out'
- plt.gca().yaxis.set_tick_params(direction='in', which='both') # The bottom will maintain the default of 'out'
+ plt.gca().xaxis.set_tick_params(
+ direction="in", which="both"
+ ) # The bottom will maintain the default of 'out'
+ plt.gca().yaxis.set_tick_params(
+ direction="in", which="both"
+ ) # The bottom will maintain the default of 'out'
# Hide the right and top spines
- plt.gca().spines['right'].set_visible(False)
- plt.gca().spines['top'].set_visible(False)
+ plt.gca().spines["right"].set_visible(False)
+ plt.gca().spines["top"].set_visible(False)
# Only show ticks on the left and bottom spines
- plt.gca().yaxis.set_ticks_position('left')
- plt.gca().xaxis.set_ticks_position('bottom')
+ plt.gca().yaxis.set_ticks_position("left")
+ plt.gca().xaxis.set_ticks_position("bottom")
plt.tight_layout()
plt.savefig("bouncing_ball.png")
diff --git a/examples/learn_physics.py b/examples/learn_physics.py
new file mode 100644
index 000000000..824852e4b
--- /dev/null
+++ b/examples/learn_physics.py
@@ -0,0 +1,285 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import math
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torchdiffeq import odeint, odeint_event
+
+from bouncing_ball import BouncingBallExample
+
+
+class HamiltonianDynamics(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.dvel = nn.Linear(1, 1)
+ self.scale = nn.Parameter(torch.tensor(10.0))
+
+ def forward(self, t, state):
+ pos, vel, *rest = state
+ dpos = vel
+ dvel = torch.tanh(self.dvel(torch.zeros_like(vel))) * self.scale
+ return (dpos, dvel, *[torch.zeros_like(r) for r in rest])
+
+
+class EventFn(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.radius = nn.Parameter(torch.rand(1))
+
+ def parameters(self):
+ return [self.radius]
+
+ def forward(self, t, state):
+ # IMPORTANT: event computation must use variables from the state.
+ pos, _, radius = state
+ return pos - radius.reshape_as(pos) ** 2
+
+
+class InstantaneousStateChange(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.net = nn.Linear(1, 1)
+
+ def forward(self, t, state):
+ pos, vel, *rest = state
+ vel = -torch.sigmoid(self.net(torch.ones_like(vel))) * vel
+ return (pos, vel, *rest)
+
+
+class NeuralPhysics(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.initial_pos = nn.Parameter(torch.tensor([10.0]))
+ self.initial_vel = nn.Parameter(torch.tensor([0.0]))
+ self.dynamics_fn = HamiltonianDynamics()
+ self.event_fn = EventFn()
+ self.inst_update = InstantaneousStateChange()
+
+ def simulate(self, times):
+
+ t0 = torch.tensor([0.0]).to(times)
+
+ # Add a terminal time to the event function.
+ def event_fn(t, state):
+ if t > times[-1] + 1e-7:
+ return torch.zeros_like(t)
+ event_fval = self.event_fn(t, state)
+ return event_fval
+
+ # IMPORTANT: for gradients of odeint_event to be computed, parameters of the event function
+ # must appear in the state in the current implementation.
+ state = (self.initial_pos, self.initial_vel, *self.event_fn.parameters())
+
+ event_times = []
+
+ trajectory = [state[0][None]]
+
+ n_events = 0
+ max_events = 20
+
+ while t0 < times[-1] and n_events < max_events:
+ last = n_events == max_events - 1
+
+ if not last:
+ event_t, solution = odeint_event(
+ self.dynamics_fn,
+ state,
+ t0,
+ event_fn=event_fn,
+ atol=1e-8,
+ rtol=1e-8,
+ method="dopri5",
+ )
+ else:
+ event_t = times[-1]
+
+ interval_ts = times[times > t0]
+ interval_ts = interval_ts[interval_ts <= event_t]
+ interval_ts = torch.cat([t0.reshape(-1), interval_ts.reshape(-1)])
+
+ solution_ = odeint(
+ self.dynamics_fn, state, interval_ts, atol=1e-8, rtol=1e-8
+ )
+            traj_ = solution_[0][1:]  # [0] for position; [1:] to remove initial state.
+ trajectory.append(traj_)
+
+ if event_t < times[-1]:
+ state = tuple(s[-1] for s in solution)
+
+ # update velocity instantaneously.
+ state = self.inst_update(event_t, state)
+
+ # advance the position a little bit to avoid re-triggering the event fn.
+ pos, *rest = state
+ pos = pos + 1e-7 * self.dynamics_fn(event_t, state)[0]
+ state = pos, *rest
+
+ event_times.append(event_t)
+ t0 = event_t
+
+ n_events += 1
+
+ # print(event_t.item(), state[0].item(), state[1].item(), self.event_fn.mod(pos).item())
+
+ trajectory = torch.cat(trajectory, dim=0).reshape(-1)
+ return trajectory, event_times
+
+
+class Sine(nn.Module):
+ def forward(self, x):
+ return torch.sin(x)
+
+
+class NeuralODE(nn.Module):
+ def __init__(self, aug_dim=2):
+ super().__init__()
+ self.initial_pos = nn.Parameter(torch.tensor([10.0]))
+ self.initial_aug = nn.Parameter(torch.zeros(aug_dim))
+ self.odefunc = mlp(
+ input_dim=1 + aug_dim,
+ hidden_dim=64,
+ output_dim=1 + aug_dim,
+ hidden_depth=2,
+ act=Sine,
+ )
+
+ def init(m):
+ if isinstance(m, nn.Linear):
+ std = 1.0 / math.sqrt(m.weight.size(1))
+ m.weight.data.uniform_(-2.0 * std, 2.0 * std)
+ m.bias.data.zero_()
+
+ self.odefunc.apply(init)
+
+ def forward(self, t, state):
+ return self.odefunc(state)
+
+ def simulate(self, times):
+ x0 = torch.cat([self.initial_pos, self.initial_aug]).reshape(-1)
+ solution = odeint(self, x0, times, atol=1e-8, rtol=1e-8, method="dopri5")
+ trajectory = solution[:, 0]
+ return trajectory, []
+
+
+def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None, act=nn.ReLU):
+ if hidden_depth == 0:
+ mods = [nn.Linear(input_dim, output_dim)]
+ else:
+ mods = [nn.Linear(input_dim, hidden_dim), act()]
+ for i in range(hidden_depth - 1):
+ mods += [nn.Linear(hidden_dim, hidden_dim), act()]
+ mods.append(nn.Linear(hidden_dim, output_dim))
+ if output_mod is not None:
+ mods.append(output_mod)
+ trunk = nn.Sequential(*mods)
+ return trunk
+
+
+def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0):
+ global_step = min(global_step, decay_steps)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / decay_steps))
+ decayed = (1 - alpha) * cosine_decay + alpha
+ return learning_rate * decayed
+
+
+def learning_rate_schedule(
+ global_step, warmup_steps, base_learning_rate, lr_scaling, train_steps
+):
+ warmup_steps = int(round(warmup_steps))
+ scaled_lr = base_learning_rate * lr_scaling
+ if warmup_steps:
+ learning_rate = global_step / warmup_steps * scaled_lr
+ else:
+ learning_rate = scaled_lr
+
+ if global_step < warmup_steps:
+ learning_rate = learning_rate
+ else:
+ learning_rate = cosine_decay(
+ scaled_lr, global_step - warmup_steps, train_steps - warmup_steps
+ )
+ return learning_rate
+
+
+def set_learning_rate(optimizer, lr):
+ for group in optimizer.param_groups:
+ group["lr"] = lr
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base_lr", type=float, default=0.1)
+ parser.add_argument("--num_iterations", type=int, default=1000)
+ parser.add_argument("--no_events", action="http://github.com/store_true")
+ parser.add_argument("--save", type=str, default="figs")
+ args = parser.parse_args()
+
+ torch.manual_seed(0)
+
+ torch.set_default_dtype(torch.float64)
+
+ with torch.no_grad():
+ system = BouncingBallExample()
+ obs_times, gt_trajectory, _, _ = system.simulate(nbounces=4)
+
+ obs_times = obs_times[:300]
+ gt_trajectory = gt_trajectory[:300]
+
+ if args.no_events:
+ model = NeuralODE()
+ else:
+ model = NeuralPhysics()
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr)
+
+ decay = 1.0
+
+ model.train()
+ for itr in range(args.num_iterations):
+ optimizer.zero_grad()
+ trajectory, event_times = model.simulate(obs_times)
+ weights = decay**obs_times
+ loss = (
+ ((trajectory - gt_trajectory) / (gt_trajectory + 1e-3))
+ .abs()
+ .mul(weights)
+ .mean()
+ )
+ loss.backward()
+
+ lr = learning_rate_schedule(itr, 0, args.base_lr, 1.0, args.num_iterations)
+ set_learning_rate(optimizer, lr)
+ optimizer.step()
+
+ if itr % 10 == 0:
+ print(itr, loss.item(), len(event_times))
+
+ if itr % 10 == 0:
+ plt.figure()
+ plt.plot(
+ obs_times.detach().cpu().numpy(),
+ gt_trajectory.detach().cpu().numpy(),
+ label="Target",
+ )
+ plt.plot(
+ obs_times.detach().cpu().numpy(),
+ trajectory.detach().cpu().numpy(),
+ label="Learned",
+ )
+ plt.tight_layout()
+ os.makedirs(args.save, exist_ok=True)
+ plt.savefig(f"{args.save}/{itr:05d}.png")
+ plt.close()
+
+ if (itr + 1) % 100 == 0:
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ },
+ f"{args.save}/model.pt",
+ )
+
+ del trajectory, loss
diff --git a/setup.py b/setup.py
index e3d21e5a5..c671a064c 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@
description="ODE solvers and adjoint sensitivity analysis in PyTorch.",
url="https://github.com/rtqichen/torchdiffeq",
packages=setuptools.find_packages(),
- install_requires=['torch>=1.3.0', 'scipy>=1.4.0'],
+ install_requires=['torch>=1.5.0', 'scipy>=1.4.0'],
python_requires='~=3.6',
classifiers=[
"Programming Language :: Python :: 3",
diff --git a/tests/api_tests.py b/tests/api_tests.py
index 99c7fa6fe..3dc208586 100644
--- a/tests/api_tests.py
+++ b/tests/api_tests.py
@@ -5,7 +5,7 @@
from problems import construct_problem, DTYPES, DEVICES, ADAPTIVE_METHODS
-EPS = {torch.float32: 1e-5, torch.float64: 1e-12}
+EPS = {torch.float32: 1e-4, torch.float64: 1e-12, torch.complex64: 1e-4}
class TestCollectionState(unittest.TestCase):
@@ -20,8 +20,8 @@ def test_forward(self):
with self.subTest(dtype=dtype, device=device, method=method):
tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method=method)
- max_error0 = (sol - tuple_y[0]).max()
- max_error1 = (sol - tuple_y[1]).max()
+ max_error0 = (sol - tuple_y[0]).abs().max()
+ max_error1 = (sol - tuple_y[1]).abs().max()
self.assertLess(max_error0, eps)
self.assertLess(max_error1, eps)
diff --git a/tests/event_tests.py b/tests/event_tests.py
index daa9e3e3d..176e3fc91 100644
--- a/tests/event_tests.py
+++ b/tests/event_tests.py
@@ -25,8 +25,10 @@ def test_odeint(self):
with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method):
if method == "explicit_adams":
tol = 7e-2
- elif method == "euler":
+ elif method == "euler" or method == "implicit_euler":
tol = 5e-3
+ elif method == "gl6":
+ tol = 2e-3
else:
tol = 1e-4
@@ -34,7 +36,7 @@ def test_odeint(self):
reverse=reverse)
def event_fn(t, y):
- return torch.sum(y - sol[2])
+ return torch.sum(y - sol[2]).real
if method in FIXED_METHODS:
options = {"step_size": 0.01, "interp": "cubic"}
diff --git a/tests/gradient_tests.py b/tests/gradient_tests.py
index 08aff899d..bbb746486 100644
--- a/tests/gradient_tests.py
+++ b/tests/gradient_tests.py
@@ -44,6 +44,8 @@ def test_adjoint_against_odeint(self):
eps = 1e-5
elif ode == 'sine':
eps = 5e-3
+ elif ode == 'exp':
+ eps = 1e-2
else:
raise RuntimeError
@@ -102,7 +104,6 @@ def forward(self, t, y):
return func, y0, t_points
def test_against_dopri5(self):
- # TODO: add in adaptive adams if/when it's fixed.
method_eps = {
'dopri5': (3e-4, 1e-4, 2e-3),
'scipy_solver': (3e-4, 1e-4, 2e-3),
diff --git a/tests/norm_tests.py b/tests/norm_tests.py
index 70b444eef..30abbda25 100644
--- a/tests/norm_tests.py
+++ b/tests/norm_tests.py
@@ -257,7 +257,7 @@ def large_norm(tensor):
with self.subTest(dtype=dtype, device=device, method=method):
x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
- t = torch.tensor([0., 1.0], device=device, dtype=dtype)
+ t = torch.tensor([0., 1.0], device=device, dtype=torch.float64)
norm_f = _NeuralF(width=10, oscillate=True).to(device, dtype)
torchdiffeq.odeint(norm_f, x0, t, method=method, options=dict(norm=norm))
@@ -273,16 +273,22 @@ def test_seminorm(self):
for dtype in DTYPES:
for device in DEVICES:
for method in ADAPTIVE_METHODS:
+ # Tests with known failures
+ if (
+ dtype in [torch.float32] and
+ method in ['tsit5']
+ ):
+ continue
with self.subTest(dtype=dtype, device=device, method=method):
- if dtype == torch.float32:
- tol = 1e-6
- else:
+ if dtype == torch.float64:
tol = 1e-8
+ else:
+ tol = 1e-6
x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
- t = torch.tensor([0., 1.0], device=device, dtype=dtype)
+ t = torch.tensor([0., 1.0], device=device, dtype=torch.float64)
ode_f = _NeuralF(width=1024, oscillate=True).to(device, dtype)
diff --git a/tests/odeint_tests.py b/tests/odeint_tests.py
index 5ebf07378..620839eef 100644
--- a/tests/odeint_tests.py
+++ b/tests/odeint_tests.py
@@ -5,7 +5,7 @@
import torch
import torchdiffeq
-from problems import (construct_problem, PROBLEMS, DTYPES, DEVICES, METHODS, ADAPTIVE_METHODS, FIXED_METHODS)
+from problems import (construct_problem, PROBLEMS, DTYPES, DEVICES, METHODS, ADAPTIVE_METHODS, FIXED_METHODS, SCIPY_METHODS, IMPLICIT_METHODS)
def rel_error(true, estimate):
@@ -13,12 +13,17 @@ def rel_error(true, estimate):
class TestSolverError(unittest.TestCase):
+
def test_odeint(self):
for reverse in (False, True):
for dtype in DTYPES:
for device in DEVICES:
for method in METHODS:
+ if method in SCIPY_METHODS and dtype == torch.complex64:
+ # scipy solvers don't support complex types.
+ continue
+
kwargs = dict()
# Have to increase tolerance for dopri8.
if method == 'dopri8' and dtype == torch.float64:
@@ -26,12 +31,23 @@ def test_odeint(self):
if method == 'dopri8' and dtype == torch.float32:
kwargs = dict(rtol=1e-7, atol=1e-7)
- problems = PROBLEMS if method in ADAPTIVE_METHODS else ('constant',)
+ if method in ADAPTIVE_METHODS:
+ if method in IMPLICIT_METHODS:
+ problems = PROBLEMS
+ else:
+ problems = tuple(problem for problem in PROBLEMS)
+ elif method in IMPLICIT_METHODS:
+ problems = ('constant', 'exp')
+ else:
+ problems = ('constant',)
+
for ode in problems:
if method in ['adaptive_heun', 'bosh3']:
eps = 4e-3
elif ode == 'linear':
eps = 2e-3
+ elif ode == 'exp':
+ eps = 5e-2
else:
eps = 3e-4
@@ -71,6 +87,9 @@ def test_odeint(self):
with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, solver=solver):
f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode,
reverse=reverse)
+ if torch.is_complex(y0) and solver in ["Radau", "LSODA"]:
+ # scipy solvers don't support complex types.
+ continue
y = torchdiffeq.odeint(f, y0, t_points, method='scipy_solver', options={"solver": solver})
self.assertTrue(sol.shape == y.shape)
self.assertLess(rel_error(sol, y), eps)
@@ -87,6 +106,7 @@ def test_odeint(self):
with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method):
f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode,
reverse=reverse)
+
y = torchdiffeq.odeint(f, y0, t_points[0:1], method=method)
self.assertLess((sol[0] - y).abs().max(), 1e-12)
@@ -112,6 +132,10 @@ def test_odeint_jump_t(self):
with self.subTest(adjoint=adjoint, dtype=dtype, device=device, method=method):
+ if method == "dopri8":
+ # Doesn't seem to work for some reason.
+ continue
+
x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype, requires_grad=True)
t = torch.tensor([0., 1.0], device=device)
@@ -142,6 +166,11 @@ def test_odeint_perturb(self):
for dtype in DTYPES:
for device in DEVICES:
for method in FIXED_METHODS:
+
+                    # Singular matrix error with float32 and implicit_euler
+ if dtype == torch.float32 and method == 'implicit_euler':
+ continue
+
for perturb in (True, False):
with self.subTest(adjoint=adjoint, dtype=dtype, device=device, method=method,
perturb=perturb):
@@ -219,5 +248,143 @@ def grid_constructor(f, y0, t):
self.assertLess((x0.grad - true_x0_grad).abs().max(), 1e-6)
+class TestMinMaxStep(unittest.TestCase):
+ def test_min_max_step(self):
+ # LSODA will complain about convergence otherwise
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ for device in DEVICES:
+ for min_step in (0, 2):
+ for max_step in (float('inf'), 5):
+ for method, options in [('dopri5', {}), ('scipy_solver', {"solver": "LSODA"})]:
+ options['min_step'] = min_step
+ options['max_step'] = max_step
+ f, y0, t_points, sol = construct_problem(device=device, ode="linear")
+ torchdiffeq.odeint(f, y0, t_points, method=method, options=options)
+ # Check min step produces far fewer evaluations
+ if min_step > 0:
+ self.assertLess(f.nfe, 50)
+ else:
+ self.assertGreater(f.nfe, 100)
+
+
+class _NeuralF(torch.nn.Module):
+ def __init__(self, width, oscillate):
+ super(_NeuralF, self).__init__()
+ self.linears = torch.nn.Sequential(torch.nn.Linear(2, width),
+ torch.nn.Tanh(),
+ torch.nn.Linear(width, 2),
+ torch.nn.Tanh())
+ self.nfe = 0
+ self.oscillate = oscillate
+
+ def forward(self, t, x):
+ self.nfe += 1
+ out = self.linears(x)
+ if self.oscillate:
+ out = out * t.mul(20).sin()
+ return out
+
+
+class TestCallbacks(unittest.TestCase):
+ def test_wrong_callback(self):
+ x0 = torch.tensor([1.0, 2.0])
+ t = torch.tensor([0., 1.0])
+
+ for method in FIXED_METHODS:
+ for callback_name in ('callback_accept_step', 'callback_reject_step'):
+ with self.subTest(method=method):
+ f = _NeuralF(width=10, oscillate=False)
+ setattr(f, callback_name, lambda t0, y0, dt: None)
+ with self.assertWarns(Warning):
+ torchdiffeq.odeint(f, x0, t, method=method)
+
+ for method in SCIPY_METHODS:
+ for callback_name in ('callback_step', 'callback_accept_step', 'callback_reject_step'):
+ with self.subTest(method=method):
+ f = _NeuralF(width=10, oscillate=False)
+ setattr(f, callback_name, lambda t0, y0, dt: None)
+ with self.assertWarns(Warning):
+ torchdiffeq.odeint(f, x0, t, method=method)
+
+ def test_steps(self):
+ for forward, adjoint in ((False, True), (True, False), (True, True)):
+ for method in FIXED_METHODS + ADAPTIVE_METHODS:
+ if method == 'dopri8': # using torch.float32
+ continue
+ with self.subTest(forward=forward, adjoint=adjoint, method=method):
+
+ f = _NeuralF(width=10, oscillate=False)
+
+ if forward:
+ forward_counter = 0
+ forward_accept_counter = 0
+ forward_reject_counter = 0
+
+ def callback_step(t0, y0, dt):
+ nonlocal forward_counter
+ forward_counter += 1
+
+ def callback_accept_step(t0, y0, dt):
+ nonlocal forward_accept_counter
+ forward_accept_counter += 1
+
+ def callback_reject_step(t0, y0, dt):
+ nonlocal forward_reject_counter
+ forward_reject_counter += 1
+
+ f.callback_step = callback_step
+ if method in ADAPTIVE_METHODS:
+ f.callback_accept_step = callback_accept_step
+ f.callback_reject_step = callback_reject_step
+
+ if adjoint:
+ adjoint_counter = 0
+ adjoint_accept_counter = 0
+ adjoint_reject_counter = 0
+
+ def callback_step_adjoint(t0, y0, dt):
+ nonlocal adjoint_counter
+ adjoint_counter += 1
+
+ def callback_accept_step_adjoint(t0, y0, dt):
+ nonlocal adjoint_accept_counter
+ adjoint_accept_counter += 1
+
+ def callback_reject_step_adjoint(t0, y0, dt):
+ nonlocal adjoint_reject_counter
+ adjoint_reject_counter += 1
+
+ f.callback_step_adjoint = callback_step_adjoint
+ if method in ADAPTIVE_METHODS:
+ f.callback_accept_step_adjoint = callback_accept_step_adjoint
+ f.callback_reject_step_adjoint = callback_reject_step_adjoint
+
+ x0 = torch.tensor([1.0, 2.0])
+ t = torch.tensor([0., 1.0])
+
+ if method in FIXED_METHODS:
+ kwargs = dict(options=dict(step_size=0.1))
+ elif method == 'implicit_adams':
+ kwargs = dict(rtol=1e-3, atol=1e-4)
+ else:
+ kwargs = {}
+ xs = torchdiffeq.odeint_adjoint(f, x0, t, method=method, **kwargs)
+
+ if forward:
+ if method in FIXED_METHODS:
+ self.assertEqual(forward_counter, 10)
+ if method in ADAPTIVE_METHODS:
+ self.assertGreater(forward_counter, 0)
+ self.assertEqual(forward_accept_counter + forward_reject_counter, forward_counter)
+ if adjoint:
+ xs.sum().backward()
+ if method in FIXED_METHODS:
+ self.assertEqual(adjoint_counter, 10)
+ if method in ADAPTIVE_METHODS:
+ self.assertGreater(adjoint_counter, 0)
+ self.assertEqual(adjoint_accept_counter + adjoint_reject_counter, adjoint_counter)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/problems.py b/tests/problems.py
index 96d8a0988..da945d4bf 100644
--- a/tests/problems.py
+++ b/tests/problems.py
@@ -38,8 +38,10 @@ def __init__(self, dim=10):
A = 2 * U - (U + U.transpose(0, 1))
self.A = torch.nn.Parameter(A)
self.initial_val = np.ones((dim, 1))
+ self.nfe = 0
def forward(self, t, y):
+ self.nfe += 1
return torch.mm(self.A, y.reshape(self.dim, 1)).reshape(-1)
def y_exact(self, t):
@@ -51,15 +53,26 @@ def y_exact(self, t):
return torch.stack([torch.tensor(ans_) for ans_ in ans]).reshape(len(t_numpy), self.dim).to(t)
-PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE}
+class ExpODE(torch.nn.Module):
+ def forward(self, t, y):
+ return -0.1 * self.y_exact(t)
+
+ def y_exact(self, t):
+ return torch.exp(-0.1 * t)
+
+
+PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE, 'exp': ExpODE}
DTYPES = (torch.float32, torch.float64)
DEVICES = ['cpu']
if torch.cuda.is_available():
DEVICES.append('cuda')
-FIXED_METHODS = ('euler', 'midpoint', 'rk4', 'explicit_adams', 'implicit_adams')
+FIXED_EXPLICIT_METHODS = ('euler', 'midpoint', 'heun2', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams')
+FIXED_IMPLICIT_METHODS = ('implicit_euler', 'implicit_midpoint', 'trapezoid', 'radauIIA3', 'gl4', 'radauIIA5', 'gl6', 'sdirk2', 'trbdf2')
+FIXED_METHODS = FIXED_EXPLICIT_METHODS + FIXED_IMPLICIT_METHODS
ADAMS_METHODS = ('explicit_adams', 'implicit_adams')
-ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'dopri5', 'dopri8')
+ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'tsit5', 'dopri5', 'dopri8')
SCIPY_METHODS = ('scipy_solver',)
+IMPLICIT_METHODS = FIXED_IMPLICIT_METHODS
METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS
@@ -67,8 +80,8 @@ def construct_problem(device, npts=10, ode='constant', reverse=False, dtype=torc
f = PROBLEMS[ode]().to(dtype=dtype, device=device)
- t_points = torch.linspace(1, 8, npts, dtype=dtype, device=device, requires_grad=True)
- sol = f.y_exact(t_points)
+ t_points = torch.linspace(1, 8, npts, dtype=torch.float64, device=device, requires_grad=True)
+ sol = f.y_exact(t_points).to(dtype)
def _flip(x, dim):
indices = [slice(None)] * x.dim()
diff --git a/torchdiffeq/__init__.py b/torchdiffeq/__init__.py
index 4eff75dbf..b966c7d02 100644
--- a/torchdiffeq/__init__.py
+++ b/torchdiffeq/__init__.py
@@ -1,4 +1,5 @@
from ._impl import odeint
from ._impl import odeint_adjoint
from ._impl import odeint_event
-__version__ = "0.2.3"
+from ._impl import odeint_dense
+__version__ = "0.2.5"
diff --git a/torchdiffeq/_impl/__init__.py b/torchdiffeq/_impl/__init__.py
index 05b671e9c..60d7e0eb1 100644
--- a/torchdiffeq/_impl/__init__.py
+++ b/torchdiffeq/_impl/__init__.py
@@ -1,2 +1,2 @@
-from .odeint import odeint, odeint_event
+from .odeint import odeint, odeint_dense, odeint_event
from .adjoint import odeint_adjoint
diff --git a/torchdiffeq/_impl/adjoint.py b/torchdiffeq/_impl/adjoint.py
index ca9ba7f6d..72e4a685b 100644
--- a/torchdiffeq/_impl/adjoint.py
+++ b/torchdiffeq/_impl/adjoint.py
@@ -2,8 +2,7 @@
import torch
import torch.nn as nn
from .odeint import SOLVERS, odeint
-from .misc import _check_inputs, _flat_to_shape
-from .misc import _mixed_norm
+from .misc import _check_inputs, _flat_to_shape, _mixed_norm, _all_callback_names, _all_adjoint_callback_names
class OdeintAdjointMethod(torch.autograd.Function):
@@ -105,6 +104,15 @@ def augmented_dynamics(t, y_aug):
return (vjp_t, func_eval, vjp_y, *vjp_params)
+ # Add adjoint callbacks
+ for callback_name, adjoint_callback_name in zip(_all_callback_names, _all_adjoint_callback_names):
+ try:
+ callback = getattr(func, adjoint_callback_name)
+ except AttributeError:
+ pass
+ else:
+ setattr(augmented_dynamics, callback_name, callback)
+
##################################
# Solve adjoint ODE #
##################################
diff --git a/torchdiffeq/_impl/fixed_grid.py b/torchdiffeq/_impl/fixed_grid.py
index 7578627c4..29a020f7b 100644
--- a/torchdiffeq/_impl/fixed_grid.py
+++ b/torchdiffeq/_impl/fixed_grid.py
@@ -1,5 +1,5 @@
from .solvers import FixedGridODESolver
-from .rk_common import rk4_alt_step_func
+from .rk_common import rk4_alt_step_func, rk3_step_func, rk2_step_func
from .misc import Perturb
@@ -27,3 +27,34 @@ class RK4(FixedGridODESolver):
def _step_func(self, func, t0, dt, t1, y0):
f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0
+
+
+class Heun3(FixedGridODESolver):
+ order = 3
+
+ def _step_func(self, func, t0, dt, t1, y0):
+ f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+
+ butcher_tableu = [
+ [0.0, 0.0, 0.0, 0.0],
+ [1/3, 1/3, 0.0, 0.0],
+ [2/3, 0.0, 2/3, 0.0],
+ [0.0, 1/4, 0.0, 3/4],
+ ]
+
+ return rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0
+
+
+class Heun2(FixedGridODESolver):
+ order = 2
+
+ def _step_func(self, func, t0, dt, t1, y0):
+ f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+
+ butcher_tableu = [
+ [0.0, 0.0, 0.0],
+ [1.0, 1.0, 0.0],
+ [0.0, 1/2, 1/2],
+ ]
+
+ return rk2_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0
diff --git a/torchdiffeq/_impl/fixed_grid_implicit.py b/torchdiffeq/_impl/fixed_grid_implicit.py
new file mode 100644
index 000000000..7519efc8f
--- /dev/null
+++ b/torchdiffeq/_impl/fixed_grid_implicit.py
@@ -0,0 +1,140 @@
+import torch
+from .rk_common import FixedGridFIRKODESolver, FixedGridDIRKODESolver
+from .rk_common import _ButcherTableau
+
+_sqrt_2 = torch.sqrt(torch.tensor(2, dtype=torch.float64)).item()
+_sqrt_3 = torch.sqrt(torch.tensor(3, dtype=torch.float64)).item()
+_sqrt_6 = torch.sqrt(torch.tensor(6, dtype=torch.float64)).item()
+_sqrt_15 = torch.sqrt(torch.tensor(15, dtype=torch.float64)).item()
+
+_IMPLICIT_EULER_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([1], dtype=torch.float64),
+ beta=[
+ torch.tensor([1], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([1], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64),
+)
+
+class ImplicitEuler(FixedGridFIRKODESolver):
+ order = 1
+ tableau = _IMPLICIT_EULER_TABLEAU
+
+_IMPLICIT_MIDPOINT_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([1 / 2], dtype=torch.float64),
+ beta=[
+ torch.tensor([1 / 2], dtype=torch.float64),
+
+ ],
+ c_sol=torch.tensor([1], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64),
+)
+
+class ImplicitMidpoint(FixedGridFIRKODESolver):
+ order = 2
+ tableau = _IMPLICIT_MIDPOINT_TABLEAU
+
+_GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau(
+    alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 + _sqrt_3 / 6], dtype=torch.float64),
+ beta=[
+ torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64),
+ torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64),
+)
+
+_TRAPEZOID_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([0, 1], dtype=torch.float64),
+ beta=[
+ torch.tensor([0, 0], dtype=torch.float64),
+        torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64),
+)
+
+class Trapezoid(FixedGridFIRKODESolver):
+ order = 2
+ tableau = _TRAPEZOID_TABLEAU
+
+
+class GaussLegendre4(FixedGridFIRKODESolver):
+ order = 4
+ tableau = _GAUSS_LEGENDRE_4_TABLEAU
+
+_GAUSS_LEGENDRE_6_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([1 / 2 - _sqrt_15 / 10, 1 / 2, 1 / 2 + _sqrt_15 / 10], dtype=torch.float64),
+ beta=[
+ torch.tensor([5 / 36 , 2 / 9 - _sqrt_15 / 15, 5 / 36 - _sqrt_15 / 30], dtype=torch.float64),
+ torch.tensor([5 / 36 + _sqrt_15 / 24, 2 / 9 , 5 / 36 - _sqrt_15 / 24], dtype=torch.float64),
+ torch.tensor([5 / 36 + _sqrt_15 / 30, 2 / 9 + _sqrt_15 / 15, 5 / 36 ], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([5 / 18, 4 / 9, 5 / 18], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64),
+)
+
+class GaussLegendre6(FixedGridFIRKODESolver):
+ order = 6
+ tableau = _GAUSS_LEGENDRE_6_TABLEAU
+
+_RADAU_IIA_3_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([1 / 3, 1], dtype=torch.float64),
+ beta=[
+ torch.tensor([5 / 12, -1 / 12], dtype=torch.float64),
+ torch.tensor([3 / 4, 1 / 4], dtype=torch.float64)
+ ],
+ c_sol=torch.tensor([3 / 4, 1 / 4], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64)
+)
+
+class RadauIIA3(FixedGridFIRKODESolver):
+ order = 3
+ tableau = _RADAU_IIA_3_TABLEAU
+
+_RADAU_IIA_5_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([2 / 5 - _sqrt_6 / 10, 2 / 5 + _sqrt_6 / 10, 1], dtype=torch.float64),
+ beta=[
+ torch.tensor([11 / 45 - 7 * _sqrt_6 / 360 , 37 / 225 - 169 * _sqrt_6 / 1800, -2 / 225 + _sqrt_6 / 75], dtype=torch.float64),
+ torch.tensor([37 / 225 + 169 * _sqrt_6 / 1800, 11 / 45 + 7 * _sqrt_6 / 360 , -2 / 225 - _sqrt_6 / 75], dtype=torch.float64),
+ torch.tensor([4 / 9 - _sqrt_6 / 36 , 4 / 9 + _sqrt_6 / 36 , 1 / 9], dtype=torch.float64)
+ ],
+ c_sol=torch.tensor([4 / 9 - _sqrt_6 / 36, 4 / 9 + _sqrt_6 / 36, 1 / 9], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64)
+)
+
+class RadauIIA5(FixedGridFIRKODESolver):
+ order = 5
+ tableau = _RADAU_IIA_5_TABLEAU
+
+gamma = (2. - _sqrt_2) / 2.
+_SDIRK_2_TABLEAU = _ButcherTableau(
+ alpha = torch.tensor([gamma, 1], dtype=torch.float64),
+ beta=[
+ torch.tensor([gamma], dtype=torch.float64),
+ torch.tensor([1 - gamma, gamma], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([1 - gamma, gamma], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64)
+)
+
+class SDIRK2(FixedGridDIRKODESolver):
+ order = 2
+ tableau = _SDIRK_2_TABLEAU
+
+gamma = 1. - _sqrt_2 / 2.
+beta = _sqrt_2 / 4.
+_TRBDF_2_TABLEAU = _ButcherTableau(
+ alpha = torch.tensor([0, 2 * gamma, 1], dtype=torch.float64),
+ beta=[
+ torch.tensor([0], dtype=torch.float64),
+ torch.tensor([gamma, gamma], dtype=torch.float64),
+ torch.tensor([beta, beta, gamma], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([beta, beta, gamma], dtype=torch.float64),
+ c_error=torch.tensor([], dtype=torch.float64)
+)
+
+class TRBDF2(FixedGridDIRKODESolver):
+ order = 2
+ tableau = _TRBDF_2_TABLEAU
diff --git a/torchdiffeq/_impl/misc.py b/torchdiffeq/_impl/misc.py
index 07cbb14dc..685cc8da2 100644
--- a/torchdiffeq/_impl/misc.py
+++ b/torchdiffeq/_impl/misc.py
@@ -6,17 +6,21 @@
from .event_handling import combine_event_functions
+_all_callback_names = ['callback_step', 'callback_accept_step', 'callback_reject_step']
+_all_adjoint_callback_names = [name + '_adjoint' for name in _all_callback_names]
+_null_callback = lambda *args, **kwargs: None
+
def _handle_unused_kwargs(solver, unused_kwargs):
if len(unused_kwargs) > 0:
warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))
def _linf_norm(tensor):
- return tensor.max()
+ return tensor.abs().max()
def _rms_norm(tensor):
- return tensor.pow(2).mean().sqrt()
+ return tensor.abs().pow(2).mean().sqrt()
def _zero_norm(tensor):
@@ -43,37 +47,39 @@ def _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0=None):
dtype = y0.dtype
device = y0.device
t_dtype = t0.dtype
- t0 = t0.to(dtype)
+ t0 = t0.to(t_dtype)
if f0 is None:
f0 = func(t0, y0)
scale = atol + torch.abs(y0) * rtol
- d0 = norm(y0 / scale)
- d1 = norm(f0 / scale)
+ d0 = norm(y0 / scale).abs()
+ d1 = norm(f0 / scale).abs()
if d0 < 1e-5 or d1 < 1e-5:
h0 = torch.tensor(1e-6, dtype=dtype, device=device)
else:
h0 = 0.01 * d0 / d1
+ h0 = h0.abs()
y1 = y0 + h0 * f0
f1 = func(t0 + h0, y1)
- d2 = norm((f1 - f0) / scale) / h0
+ d2 = torch.abs(norm((f1 - f0) / scale) / h0)
if d1 <= 1e-15 and d2 <= 1e-15:
h1 = torch.max(torch.tensor(1e-6, dtype=dtype, device=device), h0 * 1e-3)
else:
h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1))
+ h1 = h1.abs()
return torch.min(100 * h0, h1).to(t_dtype)
def _compute_error_ratio(error_estimate, rtol, atol, y0, y1, norm):
error_tol = atol + rtol * torch.max(y0.abs(), y1.abs())
- return norm(error_estimate / error_tol)
+ return norm(error_estimate / error_tol).abs()
@torch.no_grad()
@@ -176,7 +182,9 @@ def forward(self, t, y, *, perturb=Perturb.NONE):
# This dtype change here might be buggy.
# The exact time value should be determined inside the solver,
# but this can slightly change it due to numerical differences during casting.
- t = t.to(y.dtype)
+ if torch.is_complex(t):
+ t = t.real
+ t = t.to(y.abs().dtype)
if perturb is Perturb.NEXT:
# Replace with next smallest representable value.
t = _nextafter(t, t + 1)
@@ -198,6 +206,9 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS):
# Combine event functions if the output is multivariate.
event_fn = combine_event_functions(event_fn, t[0], y0)
+ # Keep reference to original func as passed in
+ original_func = func
+
# Normalise to tensor (non-tupled) input
shapes = None
is_tuple = not isinstance(y0, torch.Tensor)
@@ -210,7 +221,6 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS):
func = _TupleFunc(func, shapes)
if event_fn is not None:
event_fn = _TupleInputOnlyFunc(event_fn, shapes)
- _assert_floating('y0', y0)
# Normalise method and options
if options is None:
@@ -300,6 +310,38 @@ def _norm(tensor):
# Add perturb argument to func.
func = _PerturbFunc(func)
+ # Add callbacks to wrapped_func
+ callback_names = set()
+ for callback_name in _all_callback_names:
+ try:
+ callback = getattr(original_func, callback_name)
+ except AttributeError:
+ setattr(func, callback_name, _null_callback)
+ else:
+ if callback is not _null_callback:
+ callback_names.add(callback_name)
+ # At the moment all callbacks have the arguments (t0, y0, dt).
+ # These will need adjusting on a per-callback basis if that changes in the future.
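+                # `_callback=callback` captures the current value at definition time, since `callback` is reassigned within this loop.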
+ if is_tuple:
+ def callback(t0, y0, dt, _callback=callback):
+ y0 = _flat_to_shape(y0, (), shapes)
+ return _callback(t0, y0, dt)
+ if t_is_reversed:
+ def callback(t0, y0, dt, _callback=callback):
+ return _callback(-t0, y0, dt)
+ setattr(func, callback_name, callback)
+ for callback_name in _all_adjoint_callback_names:
+ try:
+ callback = getattr(original_func, callback_name)
+ except AttributeError:
+ pass
+ else:
+ setattr(func, callback_name, callback)
+
+ invalid_callbacks = callback_names - SOLVERS[method].valid_callbacks()
+ if len(invalid_callbacks) > 0:
+ warnings.warn("Solver '{}' does not support callbacks {}".format(method, invalid_callbacks))
+
return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py
index a174219ad..14a01efee 100644
--- a/torchdiffeq/_impl/odeint.py
+++ b/torchdiffeq/_impl/odeint.py
@@ -4,23 +4,41 @@
from .bosh3 import Bosh3Solver
from .adaptive_heun import AdaptiveHeunSolver
from .fehlberg2 import Fehlberg2
-from .fixed_grid import Euler, Midpoint, RK4
+from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4
+from .fixed_grid_implicit import ImplicitEuler, ImplicitMidpoint, Trapezoid
+from .fixed_grid_implicit import GaussLegendre4, GaussLegendre6
+from .fixed_grid_implicit import RadauIIA3, RadauIIA5
+from .fixed_grid_implicit import SDIRK2, TRBDF2
from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from .dopri8 import Dopri8Solver
+from .tsit5 import Tsit5Solver
from .scipy_wrapper import ScipyWrapperODESolver
from .misc import _check_inputs, _flat_to_shape
+from .interp import _interp_evaluate
SOLVERS = {
'dopri8': Dopri8Solver,
'dopri5': Dopri5Solver,
+ 'tsit5': Tsit5Solver,
'bosh3': Bosh3Solver,
'fehlberg2': Fehlberg2,
'adaptive_heun': AdaptiveHeunSolver,
'euler': Euler,
'midpoint': Midpoint,
+ 'heun2': Heun2,
+ 'heun3': Heun3,
'rk4': RK4,
'explicit_adams': AdamsBashforth,
'implicit_adams': AdamsBashforthMoulton,
+ 'implicit_euler': ImplicitEuler,
+ 'implicit_midpoint': ImplicitMidpoint,
+ 'trapezoid': Trapezoid,
+ 'radauIIA3': RadauIIA3,
+ 'gl4': GaussLegendre4,
+ 'radauIIA5': RadauIIA5,
+ 'gl6': GaussLegendre6,
+ 'sdirk2': SDIRK2,
+ 'trbdf2': TRBDF2,
# Backward compatibility: use the same name as before
'fixed_adams': AdamsBashforthMoulton,
# ~Backwards compatibility
@@ -90,6 +108,55 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even
return event_t, solution
+def odeint_dense(func, y0, t0, t1, *, rtol=1e-7, atol=1e-9, method=None, options=None):
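+    """Solve the ODE over [t0, t1] and return a function mapping times in [t0, t1] to interpolated solution values.
+
+    Currently only supports a single tensor (non-tuple) `y0` and `method="dopri5"`.
+    """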
+
+ assert torch.is_tensor(y0) # TODO: handle tuple of tensors
+
+ t = torch.tensor([t0, t1]).to(t0)
+
+ shapes, func, y0, t, rtol, atol, method, options, _, _ = _check_inputs(func, y0, t, rtol, atol, method, options, None, SOLVERS)
+
+ assert method == "dopri5"
+
+ solver = Dopri5Solver(func=func, y0=y0, rtol=rtol, atol=atol, **options)
+
+ # The integration loop
+ solution = torch.empty(len(t), *solver.y0.shape, dtype=solver.y0.dtype, device=solver.y0.device)
+ solution[0] = solver.y0
+ t = t.to(solver.dtype)
+ solver._before_integrate(t)
+ t0 = solver.rk_state.t0
+
+ times = [t0]
+ interp_coeffs = []
+
+ for i in range(1, len(t)):
+ next_t = t[i]
+ while next_t > solver.rk_state.t1:
+ solver.rk_state = solver._adaptive_step(solver.rk_state)
+ t1 = solver.rk_state.t1
+
+ if t1 != t0:
+ # Step accepted.
+ t0 = t1
+ times.append(t1)
+ interp_coeffs.append(torch.stack(solver.rk_state.interp_coeff))
+
+ solution[i] = _interp_evaluate(solver.rk_state.interp_coeff, solver.rk_state.t0, solver.rk_state.t1, next_t)
+
+ times = torch.stack(times).reshape(-1).cpu()
+ interp_coeffs = torch.stack(interp_coeffs)
+
+ def dense_output_fn(t_eval):
+ idx = torch.searchsorted(times, t_eval, side="right")
+ t0 = times[idx - 1]
+ t1 = times[idx]
+ coef = [interp_coeffs[idx - 1][i] for i in range(interp_coeffs.shape[1])]
+ return _interp_evaluate(coef, t0, t1, t_eval)
+
+ return dense_output_fn
+
+
def odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs):
"""Automatically links up the gradient from the event time."""
diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py
index 808fcc690..f0050dbfd 100644
--- a/torchdiffeq/_impl/rk_common.py
+++ b/torchdiffeq/_impl/rk_common.py
@@ -5,9 +5,11 @@
from .interp import _interp_evaluate, _interp_fit
from .misc import (_compute_error_ratio,
_select_initial_step,
- _optimal_step_size)
+ _optimal_step_size,
+ _handle_unused_kwargs)
from .misc import Perturb
-from .solvers import AdaptiveStepsizeEventODESolver
+from .solvers import AdaptiveStepsizeEventODESolver, FixedGridODESolver
+import warnings
_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha, beta, c_sol, c_error')
@@ -56,9 +58,11 @@ def _runge_kutta_step(func, y0, f0, t0, dt, t1, tableau):
calculating these terms.
"""
- t0 = t0.to(y0.dtype)
- dt = dt.to(y0.dtype)
- t1 = t1.to(y0.dtype)
+ t_dtype = y0.abs().dtype
+
+ t0 = t0.to(t_dtype)
+ dt = dt.to(t_dtype)
+ t1 = t1.to(t_dtype)
# We use an unchecked assign to put data into k without incrementing its _version counter, so that the backward
# doesn't throw an (overzealous) error about in-place correctness. We know that it's actually correct.
@@ -114,12 +118,54 @@ def rk4_alt_step_func(func, t0, dt, t1, y0, f0=None, perturb=False):
return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125
+def rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=None, f0=None, perturb=False):
+ """butcher_tableu should be of the form
+
+ [
+ [0 , 0 , 0 , 0],
+ [c_2, a_{21}, 0 , 0],
+ [c_3, a_{31}, a_{32}, 0],
+ [0 , b_1 , b_2 , b_3],
+ ]
+
+ https://en.wikipedia.org/wiki/List_of_Runge-Kutta_methods
+ """
+ k1 = f0
+ if k1 is None:
+ k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE)
+
+ k2 = func(t0 + dt * butcher_tableu[1][0], y0 + dt * k1 * butcher_tableu[1][1])
+ k3 = func(t0 + dt * butcher_tableu[2][0], y0 + dt * (k1 * butcher_tableu[2][1] + k2 * butcher_tableu[2][2]))
+ return dt * (k1 * butcher_tableu[3][1] + k2 * butcher_tableu[3][2] + k3 * butcher_tableu[3][3])
+
+
+def rk2_step_func(func, t0, dt, t1, y0, butcher_tableu=None, f0=None, perturb=False):
+ """butcher_tableu should be of the form
+
+ [
+ [0 , 0 , 0 ],
+ [c_2, a_{21}, 0 ],
+ [0 , b_1 , b_2 ],
+ ]
+
+ https://en.wikipedia.org/wiki/List_of_Runge-Kutta_methods
+ """
+ k1 = f0
+ if k1 is None:
+ k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE)
+
+ k2 = func(t0 + dt * butcher_tableu[1][0], y0 + dt * k1 * butcher_tableu[1][1], perturb=Perturb.PREV if perturb else Perturb.NONE)
+ return dt * (k1 * butcher_tableu[2][1] + k2 * butcher_tableu[2][2])
+
+
class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeEventODESolver):
order: int
tableau: _ButcherTableau
mid: torch.Tensor
def __init__(self, func, y0, rtol, atol,
+ min_step=0,
+ max_step=float('inf'),
first_step=None,
step_t=None,
jump_t=None,
@@ -133,12 +179,14 @@ def __init__(self, func, y0, rtol, atol,
# We use mixed precision. y has its original dtype (probably float32), whilst all 'time'-like objects use
# `dtype` (defaulting to float64).
- dtype = torch.promote_types(dtype, y0.dtype)
+ dtype = torch.promote_types(dtype, y0.abs().dtype)
device = y0.device
self.func = func
self.rtol = torch.as_tensor(rtol, dtype=dtype, device=device)
self.atol = torch.as_tensor(atol, dtype=dtype, device=device)
+ self.min_step = torch.as_tensor(min_step, dtype=dtype, device=device)
+ self.max_step = torch.as_tensor(max_step, dtype=dtype, device=device)
self.first_step = None if first_step is None else torch.as_tensor(first_step, dtype=dtype, device=device)
self.safety = torch.as_tensor(safety, dtype=dtype, device=device)
self.ifactor = torch.as_tensor(ifactor, dtype=dtype, device=device)
@@ -156,6 +204,12 @@ def __init__(self, func, y0, rtol, atol,
c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype))
self.mid = self.mid.to(device=device, dtype=y0.dtype)
+ @classmethod
+ def valid_callbacks(cls):
+ return super(RKAdaptiveStepsizeODESolver, cls).valid_callbacks() | {'callback_step',
+ 'callback_accept_step',
+ 'callback_reject_step'}
+
def _before_integrate(self, t):
t0 = t[0]
f0 = self.func(t[0], self.y0)
@@ -212,6 +266,10 @@ def _advance_until_event(self, event_fn):
def _adaptive_step(self, rk_state):
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
y0, f0, _, t0, dt, interp_coeff = rk_state
+ if not torch.isfinite(dt):
+ dt = self.min_step
+ dt = dt.clamp(self.min_step, self.max_step)
+ self.func.callback_step(t0, y0, dt)
t1 = t0 + dt
# dtypes: self.y0.dtype (probably float32); self.dtype (probably float64)
# used for state and timelike objects respectively.
@@ -264,6 +322,13 @@ def _adaptive_step(self, rk_state):
########################################################
error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm)
accept_step = error_ratio <= 1
+
+ # Handle min max stepping
+ if dt > self.max_step:
+ accept_step = False
+ if dt <= self.min_step:
+ accept_step = True
+
# dtypes:
# error_ratio.dtype == self.dtype
@@ -271,6 +336,7 @@ def _adaptive_step(self, rk_state):
# Update RK State #
########################################################
if accept_step:
+ self.func.callback_accept_step(t0, y0, dt)
t_next = t1
y_next = y1
interp_coeff = self._interp_fit(y0, y_next, k, dt)
@@ -285,10 +351,12 @@ def _adaptive_step(self, rk_state):
f1 = self.func(t_next, y_next, perturb=Perturb.NEXT)
f_next = f1
else:
+ self.func.callback_reject_step(t0, y0, dt)
t_next = t0
y_next = y0
f_next = f0
dt_next = _optimal_step_size(dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order)
+ dt_next = dt_next.clamp(self.min_step, self.max_step)
rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
return rk_state
@@ -305,3 +373,186 @@ def _sort_tvals(tvals, t0):
# TODO: add warning if tvals come before t0?
tvals = tvals[tvals >= t0]
return torch.sort(tvals).values
+
+
+class FixedGridFIRKODESolver(FixedGridODESolver):
+ order: int
+ tableau: _ButcherTableau
+
+ def __init__(self, func, y0, step_size=None, grid_constructor=None, interp='linear', perturb=False, max_iters=100, **unused_kwargs):
+
+ self.max_iters = max_iters
+ self.atol = unused_kwargs.pop('atol')
+ unused_kwargs.pop('rtol', None)
+ unused_kwargs.pop('norm', None)
+ _handle_unused_kwargs(self, unused_kwargs)
+ del unused_kwargs
+
+ self.func = func
+ self.y0 = y0
+ self.dtype = y0.dtype
+ self.device = y0.device
+ self.step_size = step_size
+ self.interp = interp
+ self.perturb = perturb
+
+ if step_size is None:
+ if grid_constructor is None:
+ self.grid_constructor = lambda f, y0, t: t
+ else:
+ self.grid_constructor = grid_constructor
+ else:
+ if grid_constructor is None:
+ self.grid_constructor = self._grid_constructor_from_step_size(step_size)
+ else:
+ raise ValueError("step_size and grid_constructor are mutually exclusive arguments.")
+
+ self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=self.device, dtype=y0.dtype),
+ beta=[b.to(device=self.device, dtype=y0.dtype) for b in self.tableau.beta],
+ c_sol=self.tableau.c_sol.to(device=self.device, dtype=y0.dtype),
+ c_error=self.tableau.c_error.to(device=self.device, dtype=y0.dtype))
+
+ def _step_func(self, func, t0, dt, t1, y0):
+ if not isinstance(t0, torch.Tensor):
+ t0 = torch.tensor(t0)
+ if not isinstance(dt, torch.Tensor):
+ dt = torch.tensor(dt)
+ if not isinstance(t1, torch.Tensor):
+ t1 = torch.tensor(t1)
+ f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+
+ t_dtype = y0.abs().dtype
+ tol = 1e-8
+ if t_dtype == torch.float64:
+ tol = 1e-8
+ if t_dtype == torch.float32:
+ tol = 1e-6
+
+ t0 = t0.to(t_dtype)
+ dt = dt.to(t_dtype)
+ t1 = t1.to(t_dtype)
+
+ k = f0.clone().unsqueeze(-1).tile(len(self.tableau.alpha))
+ beta = torch.stack(self.tableau.beta, -1)
+
+ # Broyden's Method to solve the system of nonlinear equations
+ y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0)
+ f = self._residual(func, k, y, t0, dt, t1)
+ J = torch.ones_like(f).diag()
+ converged = False
+ for _ in range(self.max_iters):
+ if torch.linalg.norm(f, 2) < tol:
+ converged = True
+ break
+
+ # If the matrix becomes singular, just stop and return the last value
+ try:
+ s = -torch.linalg.solve(J, f)
+ except torch._C._LinAlgError:
+ break
+
+ k = k + s.reshape_as(k)
+ y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0)
+ newf = self._residual(func, k, y, t0, dt, t1)
+ z = newf - f
+ f = newf
+            J = J + torch.outer(z - torch.linalg.vecdot(J, s), s) / torch.dot(s, s)
+
+ if not converged:
+            warnings.warn('Broyden iteration did not converge. Solution may be incorrect.')
+
+ dy = torch.matmul(k, dt * self.tableau.c_sol)
+
+ return dy, f0
+
+ def _residual(self, func, K, y, t0, dt, t1):
+ res = torch.zeros_like(K)
+ for i, (y_i, alpha_i) in enumerate(zip(y, self.tableau.alpha)):
+ perturb = Perturb.NONE
+ if alpha_i == 1.:
+ ti = t1
+ perturb = Perturb.PREV
+ elif alpha_i == 0.:
+ if not torch.all(self.tableau.beta[i]):
+ # Same slope as stored so skip
+ continue
+ ti = t0
+ else:
+ ti = t0 + alpha_i * dt
+ res[...,i] = K[...,i] - func(ti, y_i, perturb=perturb)
+ return res.flatten()
+
+
+class FixedGridDIRKODESolver(FixedGridFIRKODESolver):
+
+ def _step_func(self, func, t0, dt, t1, y0):
+ if not isinstance(t0, torch.Tensor):
+ t0 = torch.tensor(t0)
+ if not isinstance(dt, torch.Tensor):
+ dt = torch.tensor(dt)
+ if not isinstance(t1, torch.Tensor):
+ t1 = torch.tensor(t1)
+ f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+
+ t_dtype = y0.abs().dtype
+ tol = 1e-8
+ if t_dtype == torch.float64:
+ tol = 1e-8
+ if t_dtype == torch.float32:
+ tol = 1e-6
+
+ t0 = t0.to(t_dtype)
+ dt = dt.to(t_dtype)
+ t1 = t1.to(t_dtype)
+
+ k = [f0.clone()] * len(self.tableau.alpha)
+
+ for i, (alpha_i, beta_i) in enumerate(zip(self.tableau.alpha, self.tableau.beta)):
+ perturb = Perturb.NONE
+ if alpha_i == 1.:
+ ti = t1
+ perturb = Perturb.PREV
+ elif alpha_i == 0.:
+ if not torch.all(self.tableau.beta[i]):
+ # Same slope as stored so skip
+ continue
+ ti = t0
+ else:
+ ti = t0 + alpha_i * dt
+
+ k_i = torch.stack(k[:i+1], -1)
+
+ # Broyden's Method to solve the system of nonlinear equations
+ y_i = torch.matmul(k_i, beta_i * dt).add(y0)
+ f = self._residual(func, k_i, y_i, ti, perturb)
+ J = torch.ones_like(f).diag()
+ converged = False
+ for _ in range(self.max_iters):
+ if torch.linalg.norm(f, 2) < tol:
+ converged = True
+ break
+
+ # If the Jacobian approximation becomes singular, stop iterating and keep the current stage values
+ try:
+ s = -torch.linalg.solve(J, f)
+ except torch.linalg.LinAlgError:
+ break
+
+ k[i] = k[i] + s.reshape_as(k[i])
+ k_i = torch.stack(k[:i+1], -1)
+ y_i = torch.matmul(k_i, beta_i * dt).add(y0)
+ newf = self._residual(func, k_i, y_i, ti, perturb)
+ z = newf - f
+ f = newf
+ J = J + torch.outer(z - torch.linalg.vecdot(J, s), s) / torch.dot(s, s)
+
+ if not converged:
+ warnings.warn("Broyden's method did not converge. Solution may be incorrect.")
+
+ dy = torch.matmul(torch.stack(k, -1), dt * self.tableau.c_sol)
+
+ return dy, f0
+
+ def _residual(self, func, K, y, t, perturb):
+ res = K[...,-1] - func(t, y, perturb=perturb)
+ return res.flatten()
diff --git a/torchdiffeq/_impl/scipy_wrapper.py b/torchdiffeq/_impl/scipy_wrapper.py
index 06f93273e..41fe90149 100644
--- a/torchdiffeq/_impl/scipy_wrapper.py
+++ b/torchdiffeq/_impl/scipy_wrapper.py
@@ -6,7 +6,7 @@
class ScipyWrapperODESolver(metaclass=abc.ABCMeta):
- def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs):
+ def __init__(self, func, y0, rtol, atol, min_step=0, max_step=float('inf'), solver="LSODA", **unused_kwargs):
unused_kwargs.pop('norm', None)
unused_kwargs.pop('grid_points', None)
unused_kwargs.pop('eps', None)
@@ -19,6 +19,8 @@ def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs):
self.y0 = y0.detach().cpu().numpy().reshape(-1)
self.rtol = rtol
self.atol = atol
+ self.min_step = min_step
+ self.max_step = max_step
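+ # Passed through to scipy.integrate.solve_ivp in integrate(); note that 'min_step' is an option accepted only by the LSODA method.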
self.solver = solver
self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype)
@@ -34,10 +36,16 @@ def integrate(self, t):
method=self.solver,
rtol=self.rtol,
atol=self.atol,
+ min_step=self.min_step,
+ max_step=self.max_step
)
sol = torch.tensor(sol.y).T.to(self.device, self.dtype)
sol = sol.reshape(-1, *self.shape)
return sol
+
+ @classmethod
+ def valid_callbacks(cls):
+ return set()
def convert_func_to_numpy(func, shape, device, dtype):
diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py
index 6915f2bd9..cc64218b1 100644
--- a/torchdiffeq/_impl/solvers.py
+++ b/torchdiffeq/_impl/solvers.py
@@ -21,6 +21,10 @@ def _before_integrate(self, t):
def _advance(self, next_t):
raise NotImplementedError
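+ # Callback hooks this solver will invoke on `func`; the base solver supports none.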
+ @classmethod
+ def valid_callbacks(cls):
+ return set()
+
def integrate(self, t):
solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
solution[0] = self.y0
@@ -74,6 +78,10 @@ def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="line
else:
raise ValueError("step_size and grid_constructor are mutually exclusive arguments.")
+ @classmethod
+ def valid_callbacks(cls):
+ return {'callback_step'}
+
@staticmethod
def _grid_constructor_from_step_size(step_size):
def _grid_constructor(func, y0, t):
@@ -102,6 +110,7 @@ def integrate(self, t):
y0 = self.y0
for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
dt = t1 - t0
+ self.func.callback_step(t0, y0, dt)
dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
y1 = y0 + dy
@@ -121,7 +130,7 @@ def integrate(self, t):
def integrate_until_event(self, t0, event_fn):
assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options."
- t0 = t0.type_as(self.y0)
+ t0 = t0.type_as(self.y0.abs())
y0 = self.y0
dt = self.step_size
diff --git a/torchdiffeq/_impl/tsit5.py b/torchdiffeq/_impl/tsit5.py
new file mode 100644
index 000000000..4f4a22186
--- /dev/null
+++ b/torchdiffeq/_impl/tsit5.py
@@ -0,0 +1,82 @@
+import torch
+from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
+# https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqTsit5/src/tsit_tableaus.jl
+# https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_solver/tsit5.py#L158
+
+_TSITOURAS_TABLEAU = _ButcherTableau(
+ alpha=torch.tensor([
+ 161 / 1000,
+ 327 / 1000,
+ 9 / 10,
+ .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604,
+ 1,
+ 1
+ ], dtype=torch.float64),
+ beta=[
+ torch.tensor([161 / 1000], dtype=torch.float64),
+ torch.tensor([
+ -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2,
+ .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569
+ ], dtype=torch.float64),
+ torch.tensor([
+ 2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289,
+ -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366,
+ 4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077,
+ ], dtype=torch.float64),
+ torch.tensor([
+ 5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791,
+ -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589,
+ 7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753,
+ -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1
+ ], dtype=torch.float64),
+ torch.tensor([
+ 5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748,
+ -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452,
+ 8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986,
+ -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1,
+ -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1
+ ], dtype=torch.float64),
+ torch.tensor([
+ .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1,
+ 1 / 100,
+ .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509,
+ 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331,
+ -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677,
+ 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841
+ ], dtype=torch.float64),
+ ],
+ c_sol=torch.tensor([
+ .9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1,
+ .9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2,
+ .4877705284247615707855642599631228241516691959761363774365216240304071651579571959813,
+ 1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761,
+ -2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702,
+ 1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089,
+ 1 / 66
+ ], dtype=torch.float64),
+ c_error=torch.tensor([
+ -1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03,
+ -8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04,
+ 7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03,
+ -1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01,
+ 5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01,
+ -4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01,
+ 1 / 66
+ ], dtype=torch.float64),
+)
+
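+# Interpolation coefficients for the Tsit5 step, evaluated at the midpoint x = 1/2;
+# consumed as `mid` by RKAdaptiveStepsizeODESolver to interpolate within accepted steps.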
+x = 1 / 2
+TSIT_C_MID = torch.tensor([
+ -1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209),
+ 0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631),
+ 2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486),
+ -16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x,
+ 47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x,
+ -34.87065786149660974*(x-1.2)*(x-2/3)*x*x,
+ 2.5*(x-1)*(x-0.6)*x*x
+], dtype=torch.float64)
+
+class Tsit5Solver(RKAdaptiveStepsizeODESolver):
+ order = 5
+ tableau = _TSITOURAS_TABLEAU
+ mid = TSIT_C_MID
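+
+# Usage sketch (assuming this solver is registered under the method name "tsit5"
+# elsewhere in this change):
+#   from torchdiffeq import odeint
+#   sol = odeint(func, y0, t, method="tsit5")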