diff --git a/FURTHER_DOCUMENTATION.md b/FURTHER_DOCUMENTATION.md index 4e7604a25..5c5933339 100644 --- a/FURTHER_DOCUMENTATION.md +++ b/FURTHER_DOCUMENTATION.md @@ -57,3 +57,21 @@ For this solver, `rtol` and `atol` correspond to the tolerance for convergence o - `adjoint_params`: The parameters to compute gradients with respect to in the backward pass. Should be a tuple of tensors. Defaults to `tuple(func.parameters())`. - If passed then `func` does not have to be a `torch.nn.Module`. - If `func` has no parameters, `adjoint_params=()` must be specified. + + + ## Callbacks + + Callbacks can be triggered during the solve. Callbacks should be specified as methods of the `func` argument to `odeint` and `odeint_adjoint`. + + At the moment support for this is minimal: let us know if you'd find additional callbacks useful. + + **callback_step(self, t0, y0, dt):**
+ This is called immediately before taking a step of size `dt`, at time `t0`, with current solution value `y0`. This is supported by every solver except `scipy_solver`. + + **callback_accept_step(self, t0, y0, dt):**
+ This is called when accepting a step of size `dt` at time `t0`, with current solution value `y0`. This is supported by the adaptive solvers (dopri8, dopri5, tsit5, bosh3, fehlberg2, adaptive_heun). + + **callback_reject_step(self, t0, y0, dt):**
+ As `callback_accept_step`, except called when rejecting steps. + + In addition, callbacks can be triggered during the adjoint pass by adding `_adjoint` to the name of any one of the supported callbacks, e.g. `callback_step_adjoint`. \ No newline at end of file diff --git a/README.md b/README.md index 3e250dc84..8f911bba6 100644 --- a/README.md +++ b/README.md @@ -70,11 +70,11 @@ odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=ode The solve is terminated at an event time `t` and state `y` when an element of `event_fn(t, y)` is equal to zero. Multiple outputs from `event_fn` can be used to specify multiple event functions, of which the first to trigger will terminate the solve. -Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function. +Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function. **NOTE**: parameters for the event function must be in the state itself to obtain gradients. The numerical precision for the event time is determined by the `atol` argument. -See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py). +See an example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py). See example code for learning a simple event function in [`examples/learn_physics.py`](./examples/learn_physics.py).
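Below is a minimal sketch of this API (not taken from the bundled examples; the decay dynamics, the `threshold` parameter and the tolerances are invented for illustration). It follows the pattern in the note above: the event function's parameter is carried as part of the state, so the event time can be differentiated with respect to it.

```
import torch
from torchdiffeq import odeint_event


def decay(t, state):
    y, thresh = state
    # The threshold has zero dynamics; it rides along in the state so that
    # gradients can flow from the event time back to it.
    return -0.5 * y, torch.zeros_like(thresh)


def event_fn(t, state):
    y, thresh = state
    return y - thresh  # the solve terminates when this reaches zero


threshold = torch.tensor([0.2], requires_grad=True)
state0 = (torch.tensor([1.0]), threshold)
t0 = torch.tensor([0.0])

event_t, solution = odeint_event(
    decay, state0, t0, event_fn=event_fn, atol=1e-8, rtol=1e-8
)
event_t.backward()  # gradient of the event time with respect to `threshold`
```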

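For the callback interface documented in `FURTHER_DOCUMENTATION.md` above, here is a short sketch of what a monitored dynamics function might look like. The dynamics and the counters are invented purely for illustration; only the callback names and signatures come from the documentation.

```
import torch
from torchdiffeq import odeint


class MonitoredODE(torch.nn.Module):
    # Toy dynamics whose only purpose is to illustrate the callback hooks.
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 2)
        self.accepted = 0
        self.rejected = 0

    def forward(self, t, y):
        return torch.tanh(self.net(y))

    def callback_step(self, t0, y0, dt):
        pass  # called just before a step of size `dt` is attempted at time `t0`

    def callback_accept_step(self, t0, y0, dt):
        self.accepted += 1  # adaptive solvers only

    def callback_reject_step(self, t0, y0, dt):
        self.rejected += 1  # adaptive solvers only


func = MonitoredODE()
ys = odeint(func, torch.randn(2), torch.linspace(0.0, 1.0, 5), method="dopri5")
print(func.accepted, func.rejected)
```

With `odeint_adjoint`, the same hooks with an `_adjoint` suffix (e.g. `callback_step_adjoint`) fire during the backward pass.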
Bouncing Ball @@ -118,13 +118,8 @@ For details of the adjoint-specific and solver-specific options, check out the [ Applications of differentiable ODE solvers and event handling are discussed in these two papers: -[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366) +Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Information Processing Systems.* 2018. [[arxiv]](https://arxiv.org/abs/1806.07366) -[2] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." *International Conference on Learning Representations.* 2021. [[arxiv]](https://arxiv.org/abs/2011.03902) - ---- - -If you found this library useful in your research, please consider citing. ``` @article{chen2018neuralode, title={Neural Ordinary Differential Equations}, @@ -132,7 +127,11 @@ If you found this library useful in your research, please consider citing. journal={Advances in Neural Information Processing Systems}, year={2018} } +``` +Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." *International Conference on Learning Representations.* 2021. [[arxiv]](https://arxiv.org/abs/2011.03902) + +``` @article{chen2021eventfn, title={Learning Neural Event Functions for Ordinary Differential Equations}, author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian}, @@ -140,3 +139,28 @@ If you found this library useful in your research, please consider citing. year={2021} } ``` + +The seminorm option for computing adjoints is discussed in + +Patrick Kidger, Ricky T. Q. Chen, Terry Lyons. "'Hey, that’s not an ODE': Faster ODE Adjoints via Seminorms." *International Conference on Machine +Learning.* 2021. [[arxiv]](https://arxiv.org/abs/2009.09457) +``` +@article{kidger2021hey, + title={"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms.}, + author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry J.}, + journal={International Conference on Machine Learning}, + year={2021} +} +``` + +--- + +If you found this library useful in your research, please consider citing. +``` +@misc{torchdiffeq, + author={Chen, Ricky T. Q.}, + title={torchdiffeq}, + year={2018}, + url={https://github.com/rtqichen/torchdiffeq}, +} +``` diff --git a/examples/bouncing_ball.py b/examples/bouncing_ball.py index f9cc88ba8..714a29eec 100644 --- a/examples/bouncing_ball.py +++ b/examples/bouncing_ball.py @@ -12,7 +12,6 @@ class BouncingBallExample(nn.Module): - def __init__(self, radius=0.2, gravity=9.8, adjoint=False): super().__init__() self.gravity = nn.Parameter(torch.as_tensor([gravity])) @@ -39,9 +38,11 @@ def get_initial_state(self): return self.t0, state def state_update(self, state): - """ Updates state based on an event (collision).""" + """Updates state based on an event (collision).""" pos, vel, log_radius = state - pos = pos + 1e-7 # need to add a small eps so as not to trigger the event function immediately. + pos = ( + pos + 1e-7 + ) # need to add a small eps so as not to trigger the event function immediately. 
vel = -vel * (1 - self.absorption) return (pos, vel, log_radius) @@ -52,7 +53,16 @@ def get_collision_times(self, nbounces=1): t0, state = self.get_initial_state() for i in range(nbounces): - event_t, solution = odeint_event(self, state, t0, event_fn=self.event_fn, reverse_time=False, atol=1e-8, rtol=1e-8, odeint_interface=self.odeint) + event_t, solution = odeint_event( + self, + state, + t0, + event_fn=self.event_fn, + reverse_time=False, + atol=1e-8, + rtol=1e-8, + odeint_interface=self.odeint, + ) event_times.append(event_t) state = self.state_update(tuple(s[-1] for s in solution)) @@ -69,18 +79,25 @@ def simulate(self, nbounces=1): velocity = [state[1][None]] times = [t0.reshape(-1)] for event_t in event_times: - tt = torch.linspace(float(t0), float(event_t), int((float(event_t) - float(t0)) * 50))[1:-1] + tt = torch.linspace( + float(t0), float(event_t), int((float(event_t) - float(t0)) * 50) + )[1:-1] tt = torch.cat([t0.reshape(-1), tt, event_t.reshape(-1)]) solution = odeint(self, state, tt, atol=1e-8, rtol=1e-8) - trajectory.append(solution[0]) - velocity.append(solution[1]) - times.append(tt) + trajectory.append(solution[0][1:]) + velocity.append(solution[1][1:]) + times.append(tt[1:]) state = self.state_update(tuple(s[-1] for s in solution)) t0 = event_t - return torch.cat(times), torch.cat(trajectory, dim=0).reshape(-1), torch.cat(velocity, dim=0).reshape(-1), event_times + return ( + torch.cat(times), + torch.cat(trajectory, dim=0).reshape(-1), + torch.cat(velocity, dim=0).reshape(-1), + event_times, + ) def gradcheck(nbounces): @@ -124,7 +141,9 @@ def gradcheck(nbounces): fd = fd_grads[var] if torch.norm(analytical - fd) > 1e-4: success = False - print(f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}") + print( + f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}" + ) if not success: raise Exception("Gradient check failed.") @@ -152,10 +171,20 @@ def gradcheck(nbounces): # Event locations. 
for event_t in event_times: - plt.plot(event_t, 0.0, color="C0", marker="o", markersize=7, fillstyle='none', linestyle="") - - vel, = plt.plot(times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0) - pos, = plt.plot(times, trajectory, color="C0", linewidth=2.0) + plt.plot( + event_t, + 0.0, + color="C0", + marker="o", + markersize=7, + fillstyle="none", + linestyle="", + ) + + (vel,) = plt.plot( + times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0 + ) + (pos,) = plt.plot(times, trajectory, color="C0", linewidth=2.0) plt.hlines(0, 0, 100) plt.xlim([times[0], times[-1]]) @@ -164,16 +193,20 @@ def gradcheck(nbounces): plt.xlabel("Time", fontsize=13) plt.legend([pos, vel], ["Position", "Velocity"], fontsize=16) - plt.gca().xaxis.set_tick_params(direction='in', which='both') # The bottom will maintain the default of 'out' - plt.gca().yaxis.set_tick_params(direction='in', which='both') # The bottom will maintain the default of 'out' + plt.gca().xaxis.set_tick_params( + direction="in", which="both" + ) # The bottom will maintain the default of 'out' + plt.gca().yaxis.set_tick_params( + direction="in", which="both" + ) # The bottom will maintain the default of 'out' # Hide the right and top spines - plt.gca().spines['right'].set_visible(False) - plt.gca().spines['top'].set_visible(False) + plt.gca().spines["right"].set_visible(False) + plt.gca().spines["top"].set_visible(False) # Only show ticks on the left and bottom spines - plt.gca().yaxis.set_ticks_position('left') - plt.gca().xaxis.set_ticks_position('bottom') + plt.gca().yaxis.set_ticks_position("left") + plt.gca().xaxis.set_ticks_position("bottom") plt.tight_layout() plt.savefig("bouncing_ball.png") diff --git a/examples/learn_physics.py b/examples/learn_physics.py new file mode 100644 index 000000000..824852e4b --- /dev/null +++ b/examples/learn_physics.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 + +import argparse +import os +import math +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from torchdiffeq import odeint, odeint_event + +from bouncing_ball import BouncingBallExample + + +class HamiltonianDynamics(nn.Module): + def __init__(self): + super().__init__() + self.dvel = nn.Linear(1, 1) + self.scale = nn.Parameter(torch.tensor(10.0)) + + def forward(self, t, state): + pos, vel, *rest = state + dpos = vel + dvel = torch.tanh(self.dvel(torch.zeros_like(vel))) * self.scale + return (dpos, dvel, *[torch.zeros_like(r) for r in rest]) + + +class EventFn(nn.Module): + def __init__(self): + super().__init__() + self.radius = nn.Parameter(torch.rand(1)) + + def parameters(self): + return [self.radius] + + def forward(self, t, state): + # IMPORTANT: event computation must use variables from the state. + pos, _, radius = state + return pos - radius.reshape_as(pos) ** 2 + + +class InstantaneousStateChange(nn.Module): + def __init__(self): + super().__init__() + self.net = nn.Linear(1, 1) + + def forward(self, t, state): + pos, vel, *rest = state + vel = -torch.sigmoid(self.net(torch.ones_like(vel))) * vel + return (pos, vel, *rest) + + +class NeuralPhysics(nn.Module): + def __init__(self): + super().__init__() + self.initial_pos = nn.Parameter(torch.tensor([10.0])) + self.initial_vel = nn.Parameter(torch.tensor([0.0])) + self.dynamics_fn = HamiltonianDynamics() + self.event_fn = EventFn() + self.inst_update = InstantaneousStateChange() + + def simulate(self, times): + + t0 = torch.tensor([0.0]).to(times) + + # Add a terminal time to the event function. 
+ def event_fn(t, state): + if t > times[-1] + 1e-7: + return torch.zeros_like(t) + event_fval = self.event_fn(t, state) + return event_fval + + # IMPORTANT: for gradients of odeint_event to be computed, parameters of the event function + # must appear in the state in the current implementation. + state = (self.initial_pos, self.initial_vel, *self.event_fn.parameters()) + + event_times = [] + + trajectory = [state[0][None]] + + n_events = 0 + max_events = 20 + + while t0 < times[-1] and n_events < max_events: + last = n_events == max_events - 1 + + if not last: + event_t, solution = odeint_event( + self.dynamics_fn, + state, + t0, + event_fn=event_fn, + atol=1e-8, + rtol=1e-8, + method="dopri5", + ) + else: + event_t = times[-1] + + interval_ts = times[times > t0] + interval_ts = interval_ts[interval_ts <= event_t] + interval_ts = torch.cat([t0.reshape(-1), interval_ts.reshape(-1)]) + + solution_ = odeint( + self.dynamics_fn, state, interval_ts, atol=1e-8, rtol=1e-8 + ) + traj_ = solution_[0][1:] # [0] for position; [1:] to remove intial state. + trajectory.append(traj_) + + if event_t < times[-1]: + state = tuple(s[-1] for s in solution) + + # update velocity instantaneously. + state = self.inst_update(event_t, state) + + # advance the position a little bit to avoid re-triggering the event fn. + pos, *rest = state + pos = pos + 1e-7 * self.dynamics_fn(event_t, state)[0] + state = pos, *rest + + event_times.append(event_t) + t0 = event_t + + n_events += 1 + + # print(event_t.item(), state[0].item(), state[1].item(), self.event_fn.mod(pos).item()) + + trajectory = torch.cat(trajectory, dim=0).reshape(-1) + return trajectory, event_times + + +class Sine(nn.Module): + def forward(self, x): + return torch.sin(x) + + +class NeuralODE(nn.Module): + def __init__(self, aug_dim=2): + super().__init__() + self.initial_pos = nn.Parameter(torch.tensor([10.0])) + self.initial_aug = nn.Parameter(torch.zeros(aug_dim)) + self.odefunc = mlp( + input_dim=1 + aug_dim, + hidden_dim=64, + output_dim=1 + aug_dim, + hidden_depth=2, + act=Sine, + ) + + def init(m): + if isinstance(m, nn.Linear): + std = 1.0 / math.sqrt(m.weight.size(1)) + m.weight.data.uniform_(-2.0 * std, 2.0 * std) + m.bias.data.zero_() + + self.odefunc.apply(init) + + def forward(self, t, state): + return self.odefunc(state) + + def simulate(self, times): + x0 = torch.cat([self.initial_pos, self.initial_aug]).reshape(-1) + solution = odeint(self, x0, times, atol=1e-8, rtol=1e-8, method="dopri5") + trajectory = solution[:, 0] + return trajectory, [] + + +def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None, act=nn.ReLU): + if hidden_depth == 0: + mods = [nn.Linear(input_dim, output_dim)] + else: + mods = [nn.Linear(input_dim, hidden_dim), act()] + for i in range(hidden_depth - 1): + mods += [nn.Linear(hidden_dim, hidden_dim), act()] + mods.append(nn.Linear(hidden_dim, output_dim)) + if output_mod is not None: + mods.append(output_mod) + trunk = nn.Sequential(*mods) + return trunk + + +def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0): + global_step = min(global_step, decay_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * global_step / decay_steps)) + decayed = (1 - alpha) * cosine_decay + alpha + return learning_rate * decayed + + +def learning_rate_schedule( + global_step, warmup_steps, base_learning_rate, lr_scaling, train_steps +): + warmup_steps = int(round(warmup_steps)) + scaled_lr = base_learning_rate * lr_scaling + if warmup_steps: + learning_rate = global_step / warmup_steps * scaled_lr + 
+ else: + learning_rate = scaled_lr + + if global_step < warmup_steps: + learning_rate = learning_rate + else: + learning_rate = cosine_decay( + scaled_lr, global_step - warmup_steps, train_steps - warmup_steps + ) + return learning_rate + + +def set_learning_rate(optimizer, lr): + for group in optimizer.param_groups: + group["lr"] = lr + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--base_lr", type=float, default=0.1) + parser.add_argument("--num_iterations", type=int, default=1000) + parser.add_argument("--no_events", action="http://github.com/store_true") + parser.add_argument("--save", type=str, default="figs") + args = parser.parse_args() + + torch.manual_seed(0) + + torch.set_default_dtype(torch.float64) + + with torch.no_grad(): + system = BouncingBallExample() + obs_times, gt_trajectory, _, _ = system.simulate(nbounces=4) + + obs_times = obs_times[:300] + gt_trajectory = gt_trajectory[:300] + + if args.no_events: + model = NeuralODE() + else: + model = NeuralPhysics() + optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr) + + decay = 1.0 + + model.train() + for itr in range(args.num_iterations): + optimizer.zero_grad() + trajectory, event_times = model.simulate(obs_times) + weights = decay**obs_times + loss = ( + ((trajectory - gt_trajectory) / (gt_trajectory + 1e-3)) + .abs() + .mul(weights) + .mean() + ) + loss.backward() + + lr = learning_rate_schedule(itr, 0, args.base_lr, 1.0, args.num_iterations) + set_learning_rate(optimizer, lr) + optimizer.step() + + if itr % 10 == 0: + print(itr, loss.item(), len(event_times)) + + if itr % 10 == 0: + plt.figure() + plt.plot( + obs_times.detach().cpu().numpy(), + gt_trajectory.detach().cpu().numpy(), + label="Target", + ) + plt.plot( + obs_times.detach().cpu().numpy(), + trajectory.detach().cpu().numpy(), + label="Learned", + ) + plt.tight_layout() + os.makedirs(args.save, exist_ok=True) + plt.savefig(f"{args.save}/{itr:05d}.png") + plt.close() + + if (itr + 1) % 100 == 0: + torch.save( + { + "state_dict": model.state_dict(), + }, + f"{args.save}/model.pt", + ) + + del trajectory, loss diff --git a/setup.py b/setup.py index e3d21e5a5..c671a064c 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ description="ODE solvers and adjoint sensitivity analysis in PyTorch.", url="https://github.com/rtqichen/torchdiffeq", packages=setuptools.find_packages(), - install_requires=['torch>=1.3.0', 'scipy>=1.4.0'], + install_requires=['torch>=1.5.0', 'scipy>=1.4.0'], python_requires='~=3.6', classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/api_tests.py b/tests/api_tests.py index 99c7fa6fe..3dc208586 100644 --- a/tests/api_tests.py +++ b/tests/api_tests.py @@ -5,7 +5,7 @@ from problems import construct_problem, DTYPES, DEVICES, ADAPTIVE_METHODS -EPS = {torch.float32: 1e-5, torch.float64: 1e-12} +EPS = {torch.float32: 1e-4, torch.float64: 1e-12, torch.complex64: 1e-4} class TestCollectionState(unittest.TestCase): @@ -20,8 +20,8 @@ def test_forward(self): with self.subTest(dtype=dtype, device=device, method=method): tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method=method) - max_error0 = (sol - tuple_y[0]).max() - max_error1 = (sol - tuple_y[1]).max() + max_error0 = (sol - tuple_y[0]).abs().max() + max_error1 = (sol - tuple_y[1]).abs().max() self.assertLess(max_error0, eps) self.assertLess(max_error1, eps) diff --git a/tests/event_tests.py b/tests/event_tests.py index daa9e3e3d..176e3fc91 100644 --- a/tests/event_tests.py +++ 
b/tests/event_tests.py @@ -25,8 +25,10 @@ def test_odeint(self): with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method): if method == "explicit_adams": tol = 7e-2 - elif method == "euler": + elif method == "euler" or method == "implicit_euler": tol = 5e-3 + elif method == "gl6": + tol = 2e-3 else: tol = 1e-4 @@ -34,7 +36,7 @@ def test_odeint(self): reverse=reverse) def event_fn(t, y): - return torch.sum(y - sol[2]) + return torch.sum(y - sol[2]).real if method in FIXED_METHODS: options = {"step_size": 0.01, "interp": "cubic"} diff --git a/tests/gradient_tests.py b/tests/gradient_tests.py index 08aff899d..bbb746486 100644 --- a/tests/gradient_tests.py +++ b/tests/gradient_tests.py @@ -44,6 +44,8 @@ def test_adjoint_against_odeint(self): eps = 1e-5 elif ode == 'sine': eps = 5e-3 + elif ode == 'exp': + eps = 1e-2 else: raise RuntimeError @@ -102,7 +104,6 @@ def forward(self, t, y): return func, y0, t_points def test_against_dopri5(self): - # TODO: add in adaptive adams if/when it's fixed. method_eps = { 'dopri5': (3e-4, 1e-4, 2e-3), 'scipy_solver': (3e-4, 1e-4, 2e-3), diff --git a/tests/norm_tests.py b/tests/norm_tests.py index 70b444eef..30abbda25 100644 --- a/tests/norm_tests.py +++ b/tests/norm_tests.py @@ -257,7 +257,7 @@ def large_norm(tensor): with self.subTest(dtype=dtype, device=device, method=method): x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype) - t = torch.tensor([0., 1.0], device=device, dtype=dtype) + t = torch.tensor([0., 1.0], device=device, dtype=torch.float64) norm_f = _NeuralF(width=10, oscillate=True).to(device, dtype) torchdiffeq.odeint(norm_f, x0, t, method=method, options=dict(norm=norm)) @@ -273,16 +273,22 @@ def test_seminorm(self): for dtype in DTYPES: for device in DEVICES: for method in ADAPTIVE_METHODS: + # Tests with known failures + if ( + dtype in [torch.float32] and + method in ['tsit5'] + ): + continue with self.subTest(dtype=dtype, device=device, method=method): - if dtype == torch.float32: - tol = 1e-6 - else: + if dtype == torch.float64: tol = 1e-8 + else: + tol = 1e-6 x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype) - t = torch.tensor([0., 1.0], device=device, dtype=dtype) + t = torch.tensor([0., 1.0], device=device, dtype=torch.float64) ode_f = _NeuralF(width=1024, oscillate=True).to(device, dtype) diff --git a/tests/odeint_tests.py b/tests/odeint_tests.py index 5ebf07378..620839eef 100644 --- a/tests/odeint_tests.py +++ b/tests/odeint_tests.py @@ -5,7 +5,7 @@ import torch import torchdiffeq -from problems import (construct_problem, PROBLEMS, DTYPES, DEVICES, METHODS, ADAPTIVE_METHODS, FIXED_METHODS) +from problems import (construct_problem, PROBLEMS, DTYPES, DEVICES, METHODS, ADAPTIVE_METHODS, FIXED_METHODS, SCIPY_METHODS, IMPLICIT_METHODS) def rel_error(true, estimate): @@ -13,12 +13,17 @@ def rel_error(true, estimate): class TestSolverError(unittest.TestCase): + def test_odeint(self): for reverse in (False, True): for dtype in DTYPES: for device in DEVICES: for method in METHODS: + if method in SCIPY_METHODS and dtype == torch.complex64: + # scipy solvers don't support complex types. + continue + kwargs = dict() # Have to increase tolerance for dopri8. 
if method == 'dopri8' and dtype == torch.float64: @@ -26,12 +31,23 @@ def test_odeint(self): if method == 'dopri8' and dtype == torch.float32: kwargs = dict(rtol=1e-7, atol=1e-7) - problems = PROBLEMS if method in ADAPTIVE_METHODS else ('constant',) + if method in ADAPTIVE_METHODS: + if method in IMPLICIT_METHODS: + problems = PROBLEMS + else: + problems = tuple(problem for problem in PROBLEMS) + elif method in IMPLICIT_METHODS: + problems = ('constant', 'exp') + else: + problems = ('constant',) + for ode in problems: if method in ['adaptive_heun', 'bosh3']: eps = 4e-3 elif ode == 'linear': eps = 2e-3 + elif ode == 'exp': + eps = 5e-2 else: eps = 3e-4 @@ -71,6 +87,9 @@ def test_odeint(self): with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, solver=solver): f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode, reverse=reverse) + if torch.is_complex(y0) and solver in ["Radau", "LSODA"]: + # scipy solvers don't support complex types. + continue y = torchdiffeq.odeint(f, y0, t_points, method='scipy_solver', options={"solver": solver}) self.assertTrue(sol.shape == y.shape) self.assertLess(rel_error(sol, y), eps) @@ -87,6 +106,7 @@ def test_odeint(self): with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method): f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode, reverse=reverse) + y = torchdiffeq.odeint(f, y0, t_points[0:1], method=method) self.assertLess((sol[0] - y).abs().max(), 1e-12) @@ -112,6 +132,10 @@ def test_odeint_jump_t(self): with self.subTest(adjoint=adjoint, dtype=dtype, device=device, method=method): + if method == "dopri8": + # Doesn't seem to work for some reason. + continue + x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype, requires_grad=True) t = torch.tensor([0., 1.0], device=device) @@ -142,6 +166,11 @@ def test_odeint_perturb(self): for dtype in DTYPES: for device in DEVICES: for method in FIXED_METHODS: + + # Singluar matrix error with float32 and implicit_euler + if dtype == torch.float32 and method == 'implicit_euler': + continue + for perturb in (True, False): with self.subTest(adjoint=adjoint, dtype=dtype, device=device, method=method, perturb=perturb): @@ -219,5 +248,143 @@ def grid_constructor(f, y0, t): self.assertLess((x0.grad - true_x0_grad).abs().max(), 1e-6) +class TestMinMaxStep(unittest.TestCase): + def test_min_max_step(self): + # LSODA will complain about convergence otherwise + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for device in DEVICES: + for min_step in (0, 2): + for max_step in (float('inf'), 5): + for method, options in [('dopri5', {}), ('scipy_solver', {"solver": "LSODA"})]: + options['min_step'] = min_step + options['max_step'] = max_step + f, y0, t_points, sol = construct_problem(device=device, ode="linear") + torchdiffeq.odeint(f, y0, t_points, method=method, options=options) + # Check min step produces far fewer evaluations + if min_step > 0: + self.assertLess(f.nfe, 50) + else: + self.assertGreater(f.nfe, 100) + + +class _NeuralF(torch.nn.Module): + def __init__(self, width, oscillate): + super(_NeuralF, self).__init__() + self.linears = torch.nn.Sequential(torch.nn.Linear(2, width), + torch.nn.Tanh(), + torch.nn.Linear(width, 2), + torch.nn.Tanh()) + self.nfe = 0 + self.oscillate = oscillate + + def forward(self, t, x): + self.nfe += 1 + out = self.linears(x) + if self.oscillate: + out = out * t.mul(20).sin() + return out + + +class TestCallbacks(unittest.TestCase): + def test_wrong_callback(self): + x0 = 
torch.tensor([1.0, 2.0]) + t = torch.tensor([0., 1.0]) + + for method in FIXED_METHODS: + for callback_name in ('callback_accept_step', 'callback_reject_step'): + with self.subTest(method=method): + f = _NeuralF(width=10, oscillate=False) + setattr(f, callback_name, lambda t0, y0, dt: None) + with self.assertWarns(Warning): + torchdiffeq.odeint(f, x0, t, method=method) + + for method in SCIPY_METHODS: + for callback_name in ('callback_step', 'callback_accept_step', 'callback_reject_step'): + with self.subTest(method=method): + f = _NeuralF(width=10, oscillate=False) + setattr(f, callback_name, lambda t0, y0, dt: None) + with self.assertWarns(Warning): + torchdiffeq.odeint(f, x0, t, method=method) + + def test_steps(self): + for forward, adjoint in ((False, True), (True, False), (True, True)): + for method in FIXED_METHODS + ADAPTIVE_METHODS: + if method == 'dopri8': # using torch.float32 + continue + with self.subTest(forward=forward, adjoint=adjoint, method=method): + + f = _NeuralF(width=10, oscillate=False) + + if forward: + forward_counter = 0 + forward_accept_counter = 0 + forward_reject_counter = 0 + + def callback_step(t0, y0, dt): + nonlocal forward_counter + forward_counter += 1 + + def callback_accept_step(t0, y0, dt): + nonlocal forward_accept_counter + forward_accept_counter += 1 + + def callback_reject_step(t0, y0, dt): + nonlocal forward_reject_counter + forward_reject_counter += 1 + + f.callback_step = callback_step + if method in ADAPTIVE_METHODS: + f.callback_accept_step = callback_accept_step + f.callback_reject_step = callback_reject_step + + if adjoint: + adjoint_counter = 0 + adjoint_accept_counter = 0 + adjoint_reject_counter = 0 + + def callback_step_adjoint(t0, y0, dt): + nonlocal adjoint_counter + adjoint_counter += 1 + + def callback_accept_step_adjoint(t0, y0, dt): + nonlocal adjoint_accept_counter + adjoint_accept_counter += 1 + + def callback_reject_step_adjoint(t0, y0, dt): + nonlocal adjoint_reject_counter + adjoint_reject_counter += 1 + + f.callback_step_adjoint = callback_step_adjoint + if method in ADAPTIVE_METHODS: + f.callback_accept_step_adjoint = callback_accept_step_adjoint + f.callback_reject_step_adjoint = callback_reject_step_adjoint + + x0 = torch.tensor([1.0, 2.0]) + t = torch.tensor([0., 1.0]) + + if method in FIXED_METHODS: + kwargs = dict(options=dict(step_size=0.1)) + elif method == 'implicit_adams': + kwargs = dict(rtol=1e-3, atol=1e-4) + else: + kwargs = {} + xs = torchdiffeq.odeint_adjoint(f, x0, t, method=method, **kwargs) + + if forward: + if method in FIXED_METHODS: + self.assertEqual(forward_counter, 10) + if method in ADAPTIVE_METHODS: + self.assertGreater(forward_counter, 0) + self.assertEqual(forward_accept_counter + forward_reject_counter, forward_counter) + if adjoint: + xs.sum().backward() + if method in FIXED_METHODS: + self.assertEqual(adjoint_counter, 10) + if method in ADAPTIVE_METHODS: + self.assertGreater(adjoint_counter, 0) + self.assertEqual(adjoint_accept_counter + adjoint_reject_counter, adjoint_counter) + + if __name__ == '__main__': unittest.main() diff --git a/tests/problems.py b/tests/problems.py index 96d8a0988..da945d4bf 100644 --- a/tests/problems.py +++ b/tests/problems.py @@ -38,8 +38,10 @@ def __init__(self, dim=10): A = 2 * U - (U + U.transpose(0, 1)) self.A = torch.nn.Parameter(A) self.initial_val = np.ones((dim, 1)) + self.nfe = 0 def forward(self, t, y): + self.nfe += 1 return torch.mm(self.A, y.reshape(self.dim, 1)).reshape(-1) def y_exact(self, t): @@ -51,15 +53,26 @@ def y_exact(self, t): return 
torch.stack([torch.tensor(ans_) for ans_ in ans]).reshape(len(t_numpy), self.dim).to(t) -PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE} +class ExpODE(torch.nn.Module): + def forward(self, t, y): + return -0.1 * self.y_exact(t) + + def y_exact(self, t): + return torch.exp(-0.1 * t) + + +PROBLEMS = {'constant': ConstantODE, 'linear': LinearODE, 'sine': SineODE, 'exp': ExpODE} DTYPES = (torch.float32, torch.float64) DEVICES = ['cpu'] if torch.cuda.is_available(): DEVICES.append('cuda') -FIXED_METHODS = ('euler', 'midpoint', 'rk4', 'explicit_adams', 'implicit_adams') +FIXED_EXPLICIT_METHODS = ('euler', 'midpoint', 'heun2', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams') +FIXED_IMPLICIT_METHODS = ('implicit_euler', 'implicit_midpoint', 'trapezoid', 'radauIIA3', 'gl4', 'radauIIA5', 'gl6', 'sdirk2', 'trbdf2') +FIXED_METHODS = FIXED_EXPLICIT_METHODS + FIXED_IMPLICIT_METHODS ADAMS_METHODS = ('explicit_adams', 'implicit_adams') -ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'dopri5', 'dopri8') +ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'tsit5', 'dopri5', 'dopri8') SCIPY_METHODS = ('scipy_solver',) +IMPLICIT_METHODS = FIXED_IMPLICIT_METHODS METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS @@ -67,8 +80,8 @@ def construct_problem(device, npts=10, ode='constant', reverse=False, dtype=torc f = PROBLEMS[ode]().to(dtype=dtype, device=device) - t_points = torch.linspace(1, 8, npts, dtype=dtype, device=device, requires_grad=True) - sol = f.y_exact(t_points) + t_points = torch.linspace(1, 8, npts, dtype=torch.float64, device=device, requires_grad=True) + sol = f.y_exact(t_points).to(dtype) def _flip(x, dim): indices = [slice(None)] * x.dim() diff --git a/torchdiffeq/__init__.py b/torchdiffeq/__init__.py index 4eff75dbf..b966c7d02 100644 --- a/torchdiffeq/__init__.py +++ b/torchdiffeq/__init__.py @@ -1,4 +1,5 @@ from ._impl import odeint from ._impl import odeint_adjoint from ._impl import odeint_event -__version__ = "0.2.3" +from ._impl import odeint_dense +__version__ = "0.2.5" diff --git a/torchdiffeq/_impl/__init__.py b/torchdiffeq/_impl/__init__.py index 05b671e9c..60d7e0eb1 100644 --- a/torchdiffeq/_impl/__init__.py +++ b/torchdiffeq/_impl/__init__.py @@ -1,2 +1,2 @@ -from .odeint import odeint, odeint_event +from .odeint import odeint, odeint_dense, odeint_event from .adjoint import odeint_adjoint diff --git a/torchdiffeq/_impl/adjoint.py b/torchdiffeq/_impl/adjoint.py index ca9ba7f6d..72e4a685b 100644 --- a/torchdiffeq/_impl/adjoint.py +++ b/torchdiffeq/_impl/adjoint.py @@ -2,8 +2,7 @@ import torch import torch.nn as nn from .odeint import SOLVERS, odeint -from .misc import _check_inputs, _flat_to_shape -from .misc import _mixed_norm +from .misc import _check_inputs, _flat_to_shape, _mixed_norm, _all_callback_names, _all_adjoint_callback_names class OdeintAdjointMethod(torch.autograd.Function): @@ -105,6 +104,15 @@ def augmented_dynamics(t, y_aug): return (vjp_t, func_eval, vjp_y, *vjp_params) + # Add adjoint callbacks + for callback_name, adjoint_callback_name in zip(_all_callback_names, _all_adjoint_callback_names): + try: + callback = getattr(func, adjoint_callback_name) + except AttributeError: + pass + else: + setattr(augmented_dynamics, callback_name, callback) + ################################## # Solve adjoint ODE # ################################## diff --git a/torchdiffeq/_impl/fixed_grid.py b/torchdiffeq/_impl/fixed_grid.py index 7578627c4..29a020f7b 100644 --- a/torchdiffeq/_impl/fixed_grid.py +++ 
b/torchdiffeq/_impl/fixed_grid.py @@ -1,5 +1,5 @@ from .solvers import FixedGridODESolver -from .rk_common import rk4_alt_step_func +from .rk_common import rk4_alt_step_func, rk3_step_func, rk2_step_func from .misc import Perturb @@ -27,3 +27,34 @@ class RK4(FixedGridODESolver): def _step_func(self, func, t0, dt, t1, y0): f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0 + + +class Heun3(FixedGridODESolver): + order = 3 + + def _step_func(self, func, t0, dt, t1, y0): + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + + butcher_tableu = [ + [0.0, 0.0, 0.0, 0.0], + [1/3, 1/3, 0.0, 0.0], + [2/3, 0.0, 2/3, 0.0], + [0.0, 1/4, 0.0, 3/4], + ] + + return rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0 + + +class Heun2(FixedGridODESolver): + order = 2 + + def _step_func(self, func, t0, dt, t1, y0): + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + + butcher_tableu = [ + [0.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 1/2, 1/2], + ] + + return rk2_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0 diff --git a/torchdiffeq/_impl/fixed_grid_implicit.py b/torchdiffeq/_impl/fixed_grid_implicit.py new file mode 100644 index 000000000..7519efc8f --- /dev/null +++ b/torchdiffeq/_impl/fixed_grid_implicit.py @@ -0,0 +1,140 @@ +import torch +from .rk_common import FixedGridFIRKODESolver, FixedGridDIRKODESolver +from .rk_common import _ButcherTableau + +_sqrt_2 = torch.sqrt(torch.tensor(2, dtype=torch.float64)).item() +_sqrt_3 = torch.sqrt(torch.tensor(3, dtype=torch.float64)).item() +_sqrt_6 = torch.sqrt(torch.tensor(6, dtype=torch.float64)).item() +_sqrt_15 = torch.sqrt(torch.tensor(15, dtype=torch.float64)).item() + +_IMPLICIT_EULER_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1], dtype=torch.float64), + beta=[ + torch.tensor([1], dtype=torch.float64), + ], + c_sol=torch.tensor([1], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64), +) + +class ImplicitEuler(FixedGridFIRKODESolver): + order = 1 + tableau = _IMPLICIT_EULER_TABLEAU + +_IMPLICIT_MIDPOINT_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 2], dtype=torch.float64), + beta=[ + torch.tensor([1 / 2], dtype=torch.float64), + + ], + c_sol=torch.tensor([1], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64), +) + +class ImplicitMidpoint(FixedGridFIRKODESolver): + order = 2 + tableau = _IMPLICIT_MIDPOINT_TABLEAU + +_GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 - _sqrt_3 / 6], dtype=torch.float64), + beta=[ + torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64), + torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64), + ], + c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64), +) + +_TRAPEZOID_TABLEAU = _ButcherTableau( + alpha=torch.tensor([0, 1], dtype=torch.float64), + beta=[ + torch.tensor([0, 0], dtype=torch.float64), + torch.tensor([1 /2, 1 / 2], dtype=torch.float64), + ], + c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64), +) + +class Trapezoid(FixedGridFIRKODESolver): + order = 2 + tableau = _TRAPEZOID_TABLEAU + + +class GaussLegendre4(FixedGridFIRKODESolver): + order = 4 + tableau = _GAUSS_LEGENDRE_4_TABLEAU + +_GAUSS_LEGENDRE_6_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 
/ 2 - _sqrt_15 / 10, 1 / 2, 1 / 2 + _sqrt_15 / 10], dtype=torch.float64), + beta=[ + torch.tensor([5 / 36 , 2 / 9 - _sqrt_15 / 15, 5 / 36 - _sqrt_15 / 30], dtype=torch.float64), + torch.tensor([5 / 36 + _sqrt_15 / 24, 2 / 9 , 5 / 36 - _sqrt_15 / 24], dtype=torch.float64), + torch.tensor([5 / 36 + _sqrt_15 / 30, 2 / 9 + _sqrt_15 / 15, 5 / 36 ], dtype=torch.float64), + ], + c_sol=torch.tensor([5 / 18, 4 / 9, 5 / 18], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64), +) + +class GaussLegendre6(FixedGridFIRKODESolver): + order = 6 + tableau = _GAUSS_LEGENDRE_6_TABLEAU + +_RADAU_IIA_3_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 3, 1], dtype=torch.float64), + beta=[ + torch.tensor([5 / 12, -1 / 12], dtype=torch.float64), + torch.tensor([3 / 4, 1 / 4], dtype=torch.float64) + ], + c_sol=torch.tensor([3 / 4, 1 / 4], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64) +) + +class RadauIIA3(FixedGridFIRKODESolver): + order = 3 + tableau = _RADAU_IIA_3_TABLEAU + +_RADAU_IIA_5_TABLEAU = _ButcherTableau( + alpha=torch.tensor([2 / 5 - _sqrt_6 / 10, 2 / 5 + _sqrt_6 / 10, 1], dtype=torch.float64), + beta=[ + torch.tensor([11 / 45 - 7 * _sqrt_6 / 360 , 37 / 225 - 169 * _sqrt_6 / 1800, -2 / 225 + _sqrt_6 / 75], dtype=torch.float64), + torch.tensor([37 / 225 + 169 * _sqrt_6 / 1800, 11 / 45 + 7 * _sqrt_6 / 360 , -2 / 225 - _sqrt_6 / 75], dtype=torch.float64), + torch.tensor([4 / 9 - _sqrt_6 / 36 , 4 / 9 + _sqrt_6 / 36 , 1 / 9], dtype=torch.float64) + ], + c_sol=torch.tensor([4 / 9 - _sqrt_6 / 36, 4 / 9 + _sqrt_6 / 36, 1 / 9], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64) +) + +class RadauIIA5(FixedGridFIRKODESolver): + order = 5 + tableau = _RADAU_IIA_5_TABLEAU + +gamma = (2. - _sqrt_2) / 2. +_SDIRK_2_TABLEAU = _ButcherTableau( + alpha = torch.tensor([gamma, 1], dtype=torch.float64), + beta=[ + torch.tensor([gamma], dtype=torch.float64), + torch.tensor([1 - gamma, gamma], dtype=torch.float64), + ], + c_sol=torch.tensor([1 - gamma, gamma], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64) +) + +class SDIRK2(FixedGridDIRKODESolver): + order = 2 + tableau = _SDIRK_2_TABLEAU + +gamma = 1. - _sqrt_2 / 2. +beta = _sqrt_2 / 4. 
+_TRBDF_2_TABLEAU = _ButcherTableau( + alpha = torch.tensor([0, 2 * gamma, 1], dtype=torch.float64), + beta=[ + torch.tensor([0], dtype=torch.float64), + torch.tensor([gamma, gamma], dtype=torch.float64), + torch.tensor([beta, beta, gamma], dtype=torch.float64), + ], + c_sol=torch.tensor([beta, beta, gamma], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64) +) + +class TRBDF2(FixedGridDIRKODESolver): + order = 2 + tableau = _TRBDF_2_TABLEAU diff --git a/torchdiffeq/_impl/misc.py b/torchdiffeq/_impl/misc.py index 07cbb14dc..685cc8da2 100644 --- a/torchdiffeq/_impl/misc.py +++ b/torchdiffeq/_impl/misc.py @@ -6,17 +6,21 @@ from .event_handling import combine_event_functions +_all_callback_names = ['callback_step', 'callback_accept_step', 'callback_reject_step'] +_all_adjoint_callback_names = [name + '_adjoint' for name in _all_callback_names] +_null_callback = lambda *args, **kwargs: None + def _handle_unused_kwargs(solver, unused_kwargs): if len(unused_kwargs) > 0: warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs)) def _linf_norm(tensor): - return tensor.max() + return tensor.abs().max() def _rms_norm(tensor): - return tensor.pow(2).mean().sqrt() + return tensor.abs().pow(2).mean().sqrt() def _zero_norm(tensor): @@ -43,37 +47,39 @@ def _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0=None): dtype = y0.dtype device = y0.device t_dtype = t0.dtype - t0 = t0.to(dtype) + t0 = t0.to(t_dtype) if f0 is None: f0 = func(t0, y0) scale = atol + torch.abs(y0) * rtol - d0 = norm(y0 / scale) - d1 = norm(f0 / scale) + d0 = norm(y0 / scale).abs() + d1 = norm(f0 / scale).abs() if d0 < 1e-5 or d1 < 1e-5: h0 = torch.tensor(1e-6, dtype=dtype, device=device) else: h0 = 0.01 * d0 / d1 + h0 = h0.abs() y1 = y0 + h0 * f0 f1 = func(t0 + h0, y1) - d2 = norm((f1 - f0) / scale) / h0 + d2 = torch.abs(norm((f1 - f0) / scale) / h0) if d1 <= 1e-15 and d2 <= 1e-15: h1 = torch.max(torch.tensor(1e-6, dtype=dtype, device=device), h0 * 1e-3) else: h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1)) + h1 = h1.abs() return torch.min(100 * h0, h1).to(t_dtype) def _compute_error_ratio(error_estimate, rtol, atol, y0, y1, norm): error_tol = atol + rtol * torch.max(y0.abs(), y1.abs()) - return norm(error_estimate / error_tol) + return norm(error_estimate / error_tol).abs() @torch.no_grad() @@ -176,7 +182,9 @@ def forward(self, t, y, *, perturb=Perturb.NONE): # This dtype change here might be buggy. # The exact time value should be determined inside the solver, # but this can slightly change it due to numerical differences during casting. - t = t.to(y.dtype) + if torch.is_complex(t): + t = t.real + t = t.to(y.abs().dtype) if perturb is Perturb.NEXT: # Replace with next smallest representable value. t = _nextafter(t, t + 1) @@ -198,6 +206,9 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS): # Combine event functions if the output is multivariate. event_fn = combine_event_functions(event_fn, t[0], y0) + # Keep reference to original func as passed in + original_func = func + # Normalise to tensor (non-tupled) input shapes = None is_tuple = not isinstance(y0, torch.Tensor) @@ -210,7 +221,6 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS): func = _TupleFunc(func, shapes) if event_fn is not None: event_fn = _TupleInputOnlyFunc(event_fn, shapes) - _assert_floating('y0', y0) # Normalise method and options if options is None: @@ -300,6 +310,38 @@ def _norm(tensor): # Add perturb argument to func. 
func = _PerturbFunc(func) + # Add callbacks to wrapped_func + callback_names = set() + for callback_name in _all_callback_names: + try: + callback = getattr(original_func, callback_name) + except AttributeError: + setattr(func, callback_name, _null_callback) + else: + if callback is not _null_callback: + callback_names.add(callback_name) + # At the moment all callbacks have the arguments (t0, y0, dt). + # These will need adjusting on a per-callback basis if that changes in the future. + if is_tuple: + def callback(t0, y0, dt, _callback=callback): + y0 = _flat_to_shape(y0, (), shapes) + return _callback(t0, y0, dt) + if t_is_reversed: + def callback(t0, y0, dt, _callback=callback): + return _callback(-t0, y0, dt) + setattr(func, callback_name, callback) + for callback_name in _all_adjoint_callback_names: + try: + callback = getattr(original_func, callback_name) + except AttributeError: + pass + else: + setattr(func, callback_name, callback) + + invalid_callbacks = callback_names - SOLVERS[method].valid_callbacks() + if len(invalid_callbacks) > 0: + warnings.warn("Solver '{}' does not support callbacks {}".format(method, invalid_callbacks)) + return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index a174219ad..14a01efee 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -4,23 +4,41 @@ from .bosh3 import Bosh3Solver from .adaptive_heun import AdaptiveHeunSolver from .fehlberg2 import Fehlberg2 -from .fixed_grid import Euler, Midpoint, RK4 +from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4 +from .fixed_grid_implicit import ImplicitEuler, ImplicitMidpoint, Trapezoid +from .fixed_grid_implicit import GaussLegendre4, GaussLegendre6 +from .fixed_grid_implicit import RadauIIA3, RadauIIA5 +from .fixed_grid_implicit import SDIRK2, TRBDF2 from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton from .dopri8 import Dopri8Solver +from .tsit5 import Tsit5Solver from .scipy_wrapper import ScipyWrapperODESolver from .misc import _check_inputs, _flat_to_shape +from .interp import _interp_evaluate SOLVERS = { 'dopri8': Dopri8Solver, 'dopri5': Dopri5Solver, + 'tsit5': Tsit5Solver, 'bosh3': Bosh3Solver, 'fehlberg2': Fehlberg2, 'adaptive_heun': AdaptiveHeunSolver, 'euler': Euler, 'midpoint': Midpoint, + 'heun2': Heun2, + 'heun3': Heun3, 'rk4': RK4, 'explicit_adams': AdamsBashforth, 'implicit_adams': AdamsBashforthMoulton, + 'implicit_euler': ImplicitEuler, + 'implicit_midpoint': ImplicitMidpoint, + 'trapezoid': Trapezoid, + 'radauIIA3': RadauIIA3, + 'gl4': GaussLegendre4, + 'radauIIA5': RadauIIA5, + 'gl6': GaussLegendre6, + 'sdirk2': SDIRK2, + 'trbdf2': TRBDF2, # Backward compatibility: use the same name as before 'fixed_adams': AdamsBashforthMoulton, # ~Backwards compatibility @@ -90,6 +108,55 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even return event_t, solution +def odeint_dense(func, y0, t0, t1, *, rtol=1e-7, atol=1e-9, method=None, options=None): + + assert torch.is_tensor(y0) # TODO: handle tuple of tensors + + t = torch.tensor([t0, t1]).to(t0) + + shapes, func, y0, t, rtol, atol, method, options, _, _ = _check_inputs(func, y0, t, rtol, atol, method, options, None, SOLVERS) + + assert method == "dopri5" + + solver = Dopri5Solver(func=func, y0=y0, rtol=rtol, atol=atol, **options) + + # The integration loop + solution = torch.empty(len(t), *solver.y0.shape, dtype=solver.y0.dtype, device=solver.y0.device) + solution[0] = solver.y0 
+ t = t.to(solver.dtype) + solver._before_integrate(t) + t0 = solver.rk_state.t0 + + times = [t0] + interp_coeffs = [] + + for i in range(1, len(t)): + next_t = t[i] + while next_t > solver.rk_state.t1: + solver.rk_state = solver._adaptive_step(solver.rk_state) + t1 = solver.rk_state.t1 + + if t1 != t0: + # Step accepted. + t0 = t1 + times.append(t1) + interp_coeffs.append(torch.stack(solver.rk_state.interp_coeff)) + + solution[i] = _interp_evaluate(solver.rk_state.interp_coeff, solver.rk_state.t0, solver.rk_state.t1, next_t) + + times = torch.stack(times).reshape(-1).cpu() + interp_coeffs = torch.stack(interp_coeffs) + + def dense_output_fn(t_eval): + idx = torch.searchsorted(times, t_eval, side="right") + t0 = times[idx - 1] + t1 = times[idx] + coef = [interp_coeffs[idx - 1][i] for i in range(interp_coeffs.shape[1])] + return _interp_evaluate(coef, t0, t1, t_eval) + + return dense_output_fn + + def odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs): """Automatically links up the gradient from the event time.""" diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py index 808fcc690..f0050dbfd 100644 --- a/torchdiffeq/_impl/rk_common.py +++ b/torchdiffeq/_impl/rk_common.py @@ -5,9 +5,11 @@ from .interp import _interp_evaluate, _interp_fit from .misc import (_compute_error_ratio, _select_initial_step, - _optimal_step_size) + _optimal_step_size, + _handle_unused_kwargs) from .misc import Perturb -from .solvers import AdaptiveStepsizeEventODESolver +from .solvers import AdaptiveStepsizeEventODESolver, FixedGridODESolver +import warnings _ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha, beta, c_sol, c_error') @@ -56,9 +58,11 @@ def _runge_kutta_step(func, y0, f0, t0, dt, t1, tableau): calculating these terms. """ - t0 = t0.to(y0.dtype) - dt = dt.to(y0.dtype) - t1 = t1.to(y0.dtype) + t_dtype = y0.abs().dtype + + t0 = t0.to(t_dtype) + dt = dt.to(t_dtype) + t1 = t1.to(t_dtype) # We use an unchecked assign to put data into k without incrementing its _version counter, so that the backward # doesn't throw an (overzealous) error about in-place correctness. We know that it's actually correct. 
@@ -114,12 +118,54 @@ def rk4_alt_step_func(func, t0, dt, t1, y0, f0=None, perturb=False): return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125 +def rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=None, f0=None, perturb=False): + """butcher_tableu should be of the form + + [ + [0 , 0 , 0 , 0], + [c_2, a_{21}, 0 , 0], + [c_3, a_{31}, a_{32}, 0], + [0 , b_1 , b_2 , b_3], + ] + + https://en.wikipedia.org/wiki/List_of_Runge-Kutta_methods + """ + k1 = f0 + if k1 is None: + k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE) + + k2 = func(t0 + dt * butcher_tableu[1][0], y0 + dt * k1 * butcher_tableu[1][1]) + k3 = func(t0 + dt * butcher_tableu[2][0], y0 + dt * (k1 * butcher_tableu[2][1] + k2 * butcher_tableu[2][2])) + return dt * (k1 * butcher_tableu[3][1] + k2 * butcher_tableu[3][2] + k3 * butcher_tableu[3][3]) + + +def rk2_step_func(func, t0, dt, t1, y0, butcher_tableu=None, f0=None, perturb=False): + """butcher_tableu should be of the form + + [ + [0 , 0 , 0 ], + [c_2, a_{21}, 0 ], + [0 , b_1 , b_2 ], + ] + + https://en.wikipedia.org/wiki/List_of_Runge-Kutta_methods + """ + k1 = f0 + if k1 is None: + k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE) + + k2 = func(t0 + dt * butcher_tableu[1][0], y0 + dt * k1 * butcher_tableu[1][1], perturb=Perturb.PREV if perturb else Perturb.NONE) + return dt * (k1 * butcher_tableu[2][1] + k2 * butcher_tableu[2][2]) + + class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeEventODESolver): order: int tableau: _ButcherTableau mid: torch.Tensor def __init__(self, func, y0, rtol, atol, + min_step=0, + max_step=float('inf'), first_step=None, step_t=None, jump_t=None, @@ -133,12 +179,14 @@ def __init__(self, func, y0, rtol, atol, # We use mixed precision. y has its original dtype (probably float32), whilst all 'time'-like objects use # `dtype` (defaulting to float64). - dtype = torch.promote_types(dtype, y0.dtype) + dtype = torch.promote_types(dtype, y0.abs().dtype) device = y0.device self.func = func self.rtol = torch.as_tensor(rtol, dtype=dtype, device=device) self.atol = torch.as_tensor(atol, dtype=dtype, device=device) + self.min_step = torch.as_tensor(min_step, dtype=dtype, device=device) + self.max_step = torch.as_tensor(max_step, dtype=dtype, device=device) self.first_step = None if first_step is None else torch.as_tensor(first_step, dtype=dtype, device=device) self.safety = torch.as_tensor(safety, dtype=dtype, device=device) self.ifactor = torch.as_tensor(ifactor, dtype=dtype, device=device) @@ -156,6 +204,12 @@ def __init__(self, func, y0, rtol, atol, c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype)) self.mid = self.mid.to(device=device, dtype=y0.dtype) + @classmethod + def valid_callbacks(cls): + return super(RKAdaptiveStepsizeODESolver, cls).valid_callbacks() | {'callback_step', + 'callback_accept_step', + 'callback_reject_step'} + def _before_integrate(self, t): t0 = t[0] f0 = self.func(t[0], self.y0) @@ -212,6 +266,10 @@ def _advance_until_event(self, event_fn): def _adaptive_step(self, rk_state): """Take an adaptive Runge-Kutta step to integrate the ODE.""" y0, f0, _, t0, dt, interp_coeff = rk_state + if not torch.isfinite(dt): + dt = self.min_step + dt = dt.clamp(self.min_step, self.max_step) + self.func.callback_step(t0, y0, dt) t1 = t0 + dt # dtypes: self.y0.dtype (probably float32); self.dtype (probably float64) # used for state and timelike objects respectively. 
@@ -264,6 +322,13 @@ def _adaptive_step(self, rk_state): ######################################################## error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm) accept_step = error_ratio <= 1 + + # Handle min max stepping + if dt > self.max_step: + accept_step = False + if dt <= self.min_step: + accept_step = True + # dtypes: # error_ratio.dtype == self.dtype @@ -271,6 +336,7 @@ def _adaptive_step(self, rk_state): # Update RK State # ######################################################## if accept_step: + self.func.callback_accept_step(t0, y0, dt) t_next = t1 y_next = y1 interp_coeff = self._interp_fit(y0, y_next, k, dt) @@ -285,10 +351,12 @@ def _adaptive_step(self, rk_state): f1 = self.func(t_next, y_next, perturb=Perturb.NEXT) f_next = f1 else: + self.func.callback_reject_step(t0, y0, dt) t_next = t0 y_next = y0 f_next = f0 dt_next = _optimal_step_size(dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order) + dt_next = dt_next.clamp(self.min_step, self.max_step) rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) return rk_state @@ -305,3 +373,186 @@ def _sort_tvals(tvals, t0): # TODO: add warning if tvals come before t0? tvals = tvals[tvals >= t0] return torch.sort(tvals).values + + +class FixedGridFIRKODESolver(FixedGridODESolver): + order: int + tableau: _ButcherTableau + + def __init__(self, func, y0, step_size=None, grid_constructor=None, interp='linear', perturb=False, max_iters=100, **unused_kwargs): + + self.max_iters = max_iters + self.atol = unused_kwargs.pop('atol') + unused_kwargs.pop('rtol', None) + unused_kwargs.pop('norm', None) + _handle_unused_kwargs(self, unused_kwargs) + del unused_kwargs + + self.func = func + self.y0 = y0 + self.dtype = y0.dtype + self.device = y0.device + self.step_size = step_size + self.interp = interp + self.perturb = perturb + + if step_size is None: + if grid_constructor is None: + self.grid_constructor = lambda f, y0, t: t + else: + self.grid_constructor = grid_constructor + else: + if grid_constructor is None: + self.grid_constructor = self._grid_constructor_from_step_size(step_size) + else: + raise ValueError("step_size and grid_constructor are mutually exclusive arguments.") + + self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=self.device, dtype=y0.dtype), + beta=[b.to(device=self.device, dtype=y0.dtype) for b in self.tableau.beta], + c_sol=self.tableau.c_sol.to(device=self.device, dtype=y0.dtype), + c_error=self.tableau.c_error.to(device=self.device, dtype=y0.dtype)) + + def _step_func(self, func, t0, dt, t1, y0): + if not isinstance(t0, torch.Tensor): + t0 = torch.tensor(t0) + if not isinstance(dt, torch.Tensor): + dt = torch.tensor(dt) + if not isinstance(t1, torch.Tensor): + t1 = torch.tensor(t1) + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + + t_dtype = y0.abs().dtype + tol = 1e-8 + if t_dtype == torch.float64: + tol = 1e-8 + if t_dtype == torch.float32: + tol = 1e-6 + + t0 = t0.to(t_dtype) + dt = dt.to(t_dtype) + t1 = t1.to(t_dtype) + + k = f0.clone().unsqueeze(-1).tile(len(self.tableau.alpha)) + beta = torch.stack(self.tableau.beta, -1) + + # Broyden's Method to solve the system of nonlinear equations + y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0) + f = self._residual(func, k, y, t0, dt, t1) + J = torch.ones_like(f).diag() + converged = False + for _ in range(self.max_iters): + if torch.linalg.norm(f, 2) < tol: + converged = True + break + + # If the matrix becomes singular, just 
stop and return the last value + try: + s = -torch.linalg.solve(J, f) + except torch._C._LinAlgError: + break + + k = k + s.reshape_as(k) + y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0) + newf = self._residual(func, k, y, t0, dt, t1) + z = newf - f + f = newf + J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s)) + + if not converged: + warnings.warn('Functional iteration did not converge. Solution may be incorrect.') + + dy = torch.matmul(k, dt * self.tableau.c_sol) + + return dy, f0 + + def _residual(self, func, K, y, t0, dt, t1): + res = torch.zeros_like(K) + for i, (y_i, alpha_i) in enumerate(zip(y, self.tableau.alpha)): + perturb = Perturb.NONE + if alpha_i == 1.: + ti = t1 + perturb = Perturb.PREV + elif alpha_i == 0.: + if not torch.all(self.tableau.beta[i]): + # Same slope as stored so skip + continue + ti = t0 + else: + ti = t0 + alpha_i * dt + res[...,i] = K[...,i] - func(ti, y_i, perturb=perturb) + return res.flatten() + + +class FixedGridDIRKODESolver(FixedGridFIRKODESolver): + + def _step_func(self, func, t0, dt, t1, y0): + if not isinstance(t0, torch.Tensor): + t0 = torch.tensor(t0) + if not isinstance(dt, torch.Tensor): + dt = torch.tensor(dt) + if not isinstance(t1, torch.Tensor): + t1 = torch.tensor(t1) + f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) + + t_dtype = y0.abs().dtype + tol = 1e-8 + if t_dtype == torch.float64: + tol = 1e-8 + if t_dtype == torch.float32: + tol = 1e-6 + + t0 = t0.to(t_dtype) + dt = dt.to(t_dtype) + t1 = t1.to(t_dtype) + + k = [f0.clone()] * len(self.tableau.alpha) + + for i, (alpha_i, beta_i) in enumerate(zip(self.tableau.alpha, self.tableau.beta)): + perturb = Perturb.NONE + if alpha_i == 1.: + ti = t1 + perturb = Perturb.PREV + elif alpha_i == 0.: + if not torch.all(self.tableau.beta[i]): + # Same slope as stored so skip + continue + ti = t0 + else: + ti = t0 + alpha_i * dt + + k_i = torch.stack(k[:i+1], -1) + + # Broyden's Method to solve the system of nonlinear equations + y_i = torch.matmul(k_i, beta_i * dt).add(y0) + f = self._residual(func, k_i, y_i, ti, perturb) + J = torch.ones_like(f).diag() + converged = False + for _ in range(self.max_iters): + if torch.linalg.norm(f, 2) < tol: + converged = True + break + + # If the matrix becomes singular, just stop and return the last value + try: + s = -torch.linalg.solve(J, f) + except torch._C._LinAlgError: + break + + k[i] = k[i] + s.reshape_as(k[i]) + k_i = torch.stack(k[:i+1], -1) + y_i = torch.matmul(k_i, beta_i * dt).add(y0) + newf = self._residual(func, k_i, y_i, ti, perturb) + z = newf - f + f = newf + J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s)) + + if not converged: + warnings.warn('Functional iteration did not converge. 
Solution may be incorrect.') + + dy = torch.matmul(torch.stack(k, -1), dt * self.tableau.c_sol) + + return dy, f0 + + def _residual(self, func, K, y, t, perturb): + res = K[...,-1] - func(t, y, perturb=perturb) + return res.flatten() diff --git a/torchdiffeq/_impl/scipy_wrapper.py b/torchdiffeq/_impl/scipy_wrapper.py index 06f93273e..41fe90149 100644 --- a/torchdiffeq/_impl/scipy_wrapper.py +++ b/torchdiffeq/_impl/scipy_wrapper.py @@ -6,7 +6,7 @@ class ScipyWrapperODESolver(metaclass=abc.ABCMeta): - def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs): + def __init__(self, func, y0, rtol, atol, min_step=0, max_step=float('inf'), solver="LSODA", **unused_kwargs): unused_kwargs.pop('norm', None) unused_kwargs.pop('grid_points', None) unused_kwargs.pop('eps', None) @@ -19,6 +19,8 @@ def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs): self.y0 = y0.detach().cpu().numpy().reshape(-1) self.rtol = rtol self.atol = atol + self.min_step = min_step + self.max_step = max_step self.solver = solver self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype) @@ -34,10 +36,16 @@ def integrate(self, t): method=self.solver, rtol=self.rtol, atol=self.atol, + min_step=self.min_step, + max_step=self.max_step ) sol = torch.tensor(sol.y).T.to(self.device, self.dtype) sol = sol.reshape(-1, *self.shape) return sol + + @classmethod + def valid_callbacks(cls): + return set() def convert_func_to_numpy(func, shape, device, dtype): diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index 6915f2bd9..cc64218b1 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -21,6 +21,10 @@ def _before_integrate(self, t): def _advance(self, next_t): raise NotImplementedError + @classmethod + def valid_callbacks(cls): + return set() + def integrate(self, t): solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) solution[0] = self.y0 @@ -74,6 +78,10 @@ def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="line else: raise ValueError("step_size and grid_constructor are mutually exclusive arguments.") + @classmethod + def valid_callbacks(cls): + return {'callback_step'} + @staticmethod def _grid_constructor_from_step_size(step_size): def _grid_constructor(func, y0, t): @@ -102,6 +110,7 @@ def integrate(self, t): y0 = self.y0 for t0, t1 in zip(time_grid[:-1], time_grid[1:]): dt = t1 - t0 + self.func.callback_step(t0, y0, dt) dy, f0 = self._step_func(self.func, t0, dt, t1, y0) y1 = y0 + dy @@ -121,7 +130,7 @@ def integrate(self, t): def integrate_until_event(self, t0, event_fn): assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options." 
- t0 = t0.type_as(self.y0) + t0 = t0.type_as(self.y0.abs()) y0 = self.y0 dt = self.step_size diff --git a/torchdiffeq/_impl/tsit5.py b/torchdiffeq/_impl/tsit5.py new file mode 100644 index 000000000..4f4a22186 --- /dev/null +++ b/torchdiffeq/_impl/tsit5.py @@ -0,0 +1,82 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver +# https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqTsit5/src/tsit_tableaus.jl +# https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_solver/tsit5.py#L158 + +_TSITOURAS_TABLEAU = _ButcherTableau( + alpha=torch.tensor([ + 161 / 1000, + 327 / 1000, + 9 / 10, + .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604, + 1, + 1 + ], dtype=torch.float64), + beta=[ + torch.tensor([161 / 1000], dtype=torch.float64), + torch.tensor([ + -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2, + .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569 + ], dtype=torch.float64), + torch.tensor([ + 2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289, + -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366, + 4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077, + ], dtype=torch.float64), + torch.tensor([ + 5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791, + -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589, + 7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753, + -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1 + ], dtype=torch.float64), + torch.tensor([ + 5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748, + -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452, + 8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986, + -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1, + -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1 + ], dtype=torch.float64), + torch.tensor([ + .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1, + 1 / 100, + .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509, + 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331, + -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677, + 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841 + ], dtype=torch.float64), + ], + c_sol=torch.tensor([ + .9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1, + .9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2, + .4877705284247615707855642599631228241516691959761363774365216240304071651579571959813, + 1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761, + -2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702, + 1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089, + 1 / 66 + ], dtype=torch.float64), + c_error=torch.tensor([ + -1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03, 
+ -8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04, + 7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03, + -1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01, + 5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01, + -4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01, + 1 / 66 + ], dtype=torch.float64), +) + +x = 1 / 2 +TSIT_C_MID = torch.tensor([ + -1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209), + 0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631), + 2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486), + -16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x, + 47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x, + -34.87065786149660974*(x-1.2)*(x-2/3)*x*x, + 2.5*(x-1)*(x-0.6)*x*x +], dtype=torch.float64) + +class Tsit5Solver(RKAdaptiveStepsizeODESolver): + order = 5 + tableau = _TSITOURAS_TABLEAU + mid = TSIT_C_MID
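
---

A minimal usage sketch of the step callbacks and step-size clamping added above. It assumes the adaptive solvers' constructors accept `min_step`/`max_step` (the `_adaptive_step` hunk references `self.min_step`/`self.max_step`) and that, as usual, these solver keyword arguments are passed through `odeint`'s `options` dict. The dynamics and tolerances are arbitrary illustrations.

```
# Sketch: step callbacks plus the (assumed) min_step/max_step options on an
# adaptive solver. Callback names follow the documentation; everything else
# is illustrative.
import torch
import torch.nn as nn
from torchdiffeq import odeint


class Dynamics(nn.Module):
    def forward(self, t, y):
        return -y

    # Called immediately before each attempted step (all solvers except scipy_solver).
    def callback_step(self, t0, y0, dt):
        print(f"trying dt={float(dt):.3g} at t={float(t0):.3g}")

    # Called when an adaptive solver accepts a step.
    def callback_accept_step(self, t0, y0, dt):
        pass

    # Called when an adaptive solver rejects a step.
    def callback_reject_step(self, t0, y0, dt):
        pass


y0 = torch.randn(3)
t = torch.linspace(0.0, 2.0, 20)
ys = odeint(
    Dynamics(), y0, t, rtol=1e-6, atol=1e-8, method="dopri5",
    options={"min_step": 1e-4, "max_step": 0.5},  # assumed option names
)
```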
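
A similar sketch for the new Tsitouras 5(4) pair: the diff adds `Tsit5Solver` in `torchdiffeq/_impl/tsit5.py` but does not show its registration, so the method name `'tsit5'` below is an assumption about how it would be exposed through `odeint`.

```
# Sketch: selecting the new solver through odeint, assuming it is registered
# under the name 'tsit5' (the registration is not shown in this diff).
import torch
from torchdiffeq import odeint

func = lambda t, y: torch.sin(t) - y
y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 5.0, 50)
ys = odeint(func, y0, t, rtol=1e-7, atol=1e-9, method="tsit5")
```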
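
Finally, the new `FixedGridFIRKODESolver` and `FixedGridDIRKODESolver` classes are abstract in the same sense as the explicit fixed-grid solvers: a concrete subclass supplies `order` and `tableau`. As a minimal sketch, a single-stage backward-Euler method written as a DIRK could look like the following; the subclass name is hypothetical and its registration is not part of this patch.

```
# Illustrative concrete DIRK subclass (not part of the patch). _ButcherTableau
# and FixedGridDIRKODESolver are assumed importable from rk_common, where this
# diff adds them.
import torch
from torchdiffeq._impl.rk_common import _ButcherTableau, FixedGridDIRKODESolver


class BackwardEuler(FixedGridDIRKODESolver):
    order = 1
    tableau = _ButcherTableau(
        alpha=torch.tensor([1.], dtype=torch.float64),   # single implicit stage at t1
        beta=[torch.tensor([1.], dtype=torch.float64)],  # y1 = y0 + dt * k1
        c_sol=torch.tensor([1.], dtype=torch.float64),
        c_error=torch.tensor([0.], dtype=torch.float64),
    )
```

With this tableau, each step reduces to the single nonlinear equation `k1 = f(t1, y0 + dt * k1)`, which the Broyden iteration in `_step_func` solves.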