From 07ca1e1fcc09b98148ad69303c3a7bcb8537f939 Mon Sep 17 00:00:00 2001
From: Michael Plainer <6825443+plainerman@users.noreply.github.com>
Date: Wed, 26 Mar 2025 10:47:44 +0100
Subject: [PATCH] Change everything to float64

---
 README.md                   |  8 --------
 main.py                     | 14 +++++++-------
 prepare_molecule.py         |  2 +-
 systems.py                  |  8 ++++----
 tps/plot.py                 |  2 +-
 training/qsetup.py          |  4 ++--
 training/setups/diagonal.py | 14 +++++++-------
 training/setups/full.py     | 30 +++++++++++++++---------------
 utils/plot.py               |  4 ++--
 utils/splines.py            |  4 ++--
 10 files changed, 41 insertions(+), 49 deletions(-)

diff --git a/README.md b/README.md
index 236054d..78a049c 100644
--- a/README.md
+++ b/README.md
@@ -29,14 +29,6 @@ A novel variational approach to transition path sampling (TPS) based on the Doob
 Running the deterministic and stochastic simulations using our algorithm for 2D potential.
 
-# FAQ
-## I am getting NaN values when running experiments on alanine dipeptide!
-This is an issue on certain devices, and, so far, we haven't figured out the underlying reason. However, we have found out that:
-
-1. Changing your floats to 64-bit precision prevents this problem from happening (at least on our machines), albeit at ~2x slower performance. To change to float64, simply search for all instances of `jnp.float32` (as can be seen [here](https://github.com/search?q=repo%3Aplainerman%2FVariational-Doob%20jnp.float32&type=code)) and change it to `jnp.float64`.
-
-2. First-order systems usually do not exhibit this behavior. So you can also change your `ode` in the config (e.g., [here](https://github.com/plainerman/Variational-Doob/blob/b3836998080569af5deaaa5bd1ef6ad0993e0bd9/configs/aldp_diagonal_single_gaussian.yaml#L7)) to `first_order` and see if this resolves the issue. In our tests, first-order ODE was sufficient for most setups.
-
 # Getting started
 The best way to understand our method is to look at [the google colab notebook](https://colab.research.google.com/drive/1FcmEbec06cH4yk0t8vOIt8r1Gm-VjQZ0?usp=sharing) which contains the necessary code for 2D potentials in one place.
 
diff --git a/main.py b/main.py
index 43b67ce..80fee13 100644
--- a/main.py
+++ b/main.py
@@ -123,13 +123,13 @@ def main():
         B = system.B
     elif args.ode == 'second_order':
         # We pad the A and B matrices with zeros to account for the velocity
-        A = jnp.hstack([system.A, jnp.zeros_like(system.A)], dtype=jnp.float32)
-        B = jnp.hstack([system.B, jnp.zeros_like(system.B)], dtype=jnp.float32)
+        A = jnp.hstack([system.A, jnp.zeros_like(system.A)], dtype=jnp.float64)
+        B = jnp.hstack([system.B, jnp.zeros_like(system.B)], dtype=jnp.float64)
 
         xi_velocity = jnp.ones_like(system.A) * xi
         xi_pos = jnp.zeros_like(xi_velocity) + args.xi_pos_noise
-        xi = jnp.concatenate((xi_pos, xi_velocity), axis=-1, dtype=jnp.float32)
+        xi = jnp.concatenate((xi_pos, xi_velocity), axis=-1, dtype=jnp.float64)
     else:
         raise ValueError(f"Unknown ODE: {args.ode}")
 
@@ -144,7 +144,7 @@ def main():
     key = jax.random.PRNGKey(args.seed)
     key, init_key = jax.random.split(key)
 
-    params_q = setup.model_q.init(init_key, jnp.zeros([args.BS, 1], dtype=jnp.float32))
+    params_q = setup.model_q.init(init_key, jnp.zeros([args.BS, 1], dtype=jnp.float64))
 
     optimizer_q = optax.adam(learning_rate=args.lr)
     state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
@@ -186,15 +186,15 @@ def main():
             log_scale(args.log_plots, False, True)
             show_or_save_fig(args.save_dir, 'loss_plot', args.extension)
 
-    t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
+    t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float64).reshape((-1, 1))
     key, path_key = jax.random.split(key)
     mu_t, _, w_logits = state_q.apply_fn(state_q.params, t)
     w = jax.nn.softmax(w_logits)
     print('Weights of mixtures:', w)
 
     key, init_key = jax.random.split(key)
-    x_0 = jnp.ones((args.num_paths, A.shape[0]), dtype=jnp.float32) * A
-    eps = jax.random.normal(key, shape=x_0.shape, dtype=jnp.float32)
+    x_0 = jnp.ones((args.num_paths, A.shape[0]), dtype=jnp.float64) * A
+    eps = jax.random.normal(key, shape=x_0.shape, dtype=jnp.float64)
     x_0 += args.base_sigma * eps
 
     x_t_det = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, None)
diff --git a/prepare_molecule.py b/prepare_molecule.py
index 707b8eb..013df87 100644
--- a/prepare_molecule.py
+++ b/prepare_molecule.py
@@ -34,7 +34,7 @@ def minimize(pdb, out, steps):
     for mass_ in mass:
         for _ in range(3):
            new_mass.append(mass_)
-    mass = jnp.array(new_mass, dtype=jnp.float32)
+    mass = jnp.array(new_mass, dtype=jnp.float64)
 
     # Initialize the potential energy with amber forcefields
     ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
diff --git a/systems.py b/systems.py
index 057daaf..4e90ad6 100644
--- a/systems.py
+++ b/systems.py
@@ -55,7 +55,7 @@ def from_name(cls, name: str, force_clip: float) -> Self:
         plot = partial(toy_plot_energy_surface,
                       U=U, states=list(zip(['A', 'B'], [A, B])), xlim=xlim, ylim=ylim, alpha=1.0
                       )
-        mass = jnp.array([1.0, 1.0], dtype=jnp.float32)
+        mass = jnp.array([1.0, 1.0], dtype=jnp.float64)
         return cls(U, A, B, mass, plot, force_clip)
 
     @classmethod
@@ -64,10 +64,10 @@ def from_pdb(cls, A: str, B: str, forcefield: [str], cv: Optional[str], force_cl
         assert_same_molecule(A_pdb, B_pdb)
 
         mass = [a.element.mass.value_in_unit(unit.dalton) for a in A_pdb.topology.atoms()]
-        mass = jnp.broadcast_to(jnp.array(mass, dtype=jnp.float32).reshape(-1, 1), (len(mass), 3)).reshape(-1)
+        mass = jnp.broadcast_to(jnp.array(mass, dtype=jnp.float64).reshape(-1, 1), (len(mass), 3)).reshape(-1)
 
-        A = jnp.array(A_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer), dtype=jnp.float32)
-        B = jnp.array(B_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer), dtype=jnp.float32)
+        A = jnp.array(A_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer), dtype=jnp.float64)
+        B = jnp.array(B_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer), dtype=jnp.float64)
         num_atoms = A.shape[0]
         A, B = kabsch_align(A, B)
         A, B = A.reshape(-1), B.reshape(-1)
diff --git a/tps/plot.py b/tps/plot.py
index 38059c4..b7a80da 100644
--- a/tps/plot.py
+++ b/tps/plot.py
@@ -9,7 +9,7 @@ def __init__(self, bins=250, interpolate=True, scale=jnp.pi):
         self.bins = bins
         self.interpolate = interpolate
         self.scale = scale
-        self.hist = jnp.zeros((bins, bins), dtype=jnp.float32)
+        self.hist = jnp.zeros((bins, bins), dtype=jnp.float64)
 
     def add_paths(self, paths: list[jnp.ndarray], factors: list[float] = None):
         for path, factor in tqdm(zip(paths, factors or [1] * len(paths)), total=len(paths)):
diff --git a/training/qsetup.py b/training/qsetup.py
index 0b57d6d..2e7bc84 100644
--- a/training/qsetup.py
+++ b/training/qsetup.py
@@ -36,10 +36,10 @@ def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float,
         num_paths = x_0.shape[0]
         ndim = x_0.shape[1]
 
-        x_t = jnp.zeros((num_paths, N, ndim), dtype=jnp.float32)
+        x_t = jnp.zeros((num_paths, N, ndim), dtype=jnp.float64)
         x_t = x_t.at[:, 0, :].set(x_0)
 
-        t = jnp.zeros((BS, 1), dtype=jnp.float32)
+        t = jnp.zeros((BS, 1), dtype=jnp.float64)
         u = jax.jit(lambda _t, _x: self.u_t(state_q, _t, _x, key is None, *args, **kwargs))
 
         for i in trange(N):
diff --git a/training/setups/diagonal.py b/training/setups/diagonal.py
index f282226..42d7ab7 100644
--- a/training/setups/diagonal.py
+++ b/training/setups/diagonal.py
@@ -29,7 +29,7 @@ def __call__(self, t):
         ndim = self.A.shape[0]
         BS = t.shape[0]
         t = t / self.T
-        t_grid = jnp.linspace(0, 1, self.n_points, dtype=jnp.float32)
+        t_grid = jnp.linspace(0, 1, self.n_points, dtype=jnp.float64)
 
         A = (jnp.ones((self.num_mixtures, ndim), dtype=self.A.dtype) * self.A).reshape(-1)
         B = (jnp.ones((self.num_mixtures, ndim), dtype=self.A.dtype) * self.B).reshape(-1)
@@ -56,9 +56,9 @@ def __call__(self, t):
         sigma = jnp.exp(sigma.reshape(BS, self.num_mixtures, ndim))
 
         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (self.num_mixtures,), dtype=jnp.float32)
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (self.num_mixtures,), dtype=jnp.float64)
         else:
-            w_logits = jnp.zeros(self.num_mixtures, dtype=jnp.float32)
+            w_logits = jnp.zeros(self.num_mixtures, dtype=jnp.float64)
 
         out = (mu, sigma, w_logits)
         if self.transform:
@@ -89,9 +89,9 @@ def _post_process(self, h: ArrayLike, t: ArrayLike):
         )
 
         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float32)
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float64)
         else:
-            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float32)
+            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float64)
 
         return mu, sigma, w_logits
 
@@ -104,8 +104,8 @@ def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike)
             ndim = self.model_q.A.shape[-1]
             key = jax.random.split(key)
 
-            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
-            eps = jax.random.normal(key[1], [BS, 1, ndim], dtype=jnp.float32)
+            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float64)
+            eps = jax.random.normal(key[1], [BS, 1, ndim], dtype=jnp.float64)
 
             def v_t(_eps, _t):
                 """This function is equal to v_t * xi ** 2."""
diff --git a/training/setups/full.py b/training/setups/full.py
index f5aa1ef..39ff783 100644
--- a/training/setups/full.py
+++ b/training/setups/full.py
@@ -42,8 +42,8 @@ def __call__(self, t):
         ndim = self.A.shape[0]
         t = t / self.T
 
-        t_grid = jnp.linspace(0, 1, self.n_points, dtype=jnp.float32)
-        S_0 = jnp.log(self.base_sigma) * jnp.eye(ndim, dtype=jnp.float32)
+        t_grid = jnp.linspace(0, 1, self.n_points, dtype=jnp.float64)
+        S_0 = jnp.log(self.base_sigma) * jnp.eye(ndim, dtype=jnp.float64)
         S_0_vec = S_0[jnp.tril_indices(ndim)]
         mu_params = self.param('mu_params', lambda rng: jnp.linspace(self.A, self.B, self.n_points)[1:-1])
         S_params = self.param('S_params', lambda rng: jnp.linspace(S_0_vec, S_0_vec, self.n_points)[1:-1])
@@ -52,7 +52,7 @@ def __call__(self, t):
 
         @jax.vmap
         def get_tril(v):
-            a = jnp.zeros((ndim, ndim), dtype=jnp.float32)
+            a = jnp.zeros((ndim, ndim), dtype=jnp.float64)
             a = a.at[jnp.tril_indices(ndim)].set(v)
             return a
 
@@ -66,12 +66,12 @@ def get_tril(v):
             raise ValueError(f"Interpolation method {self.interpolation} not recognized.")
 
         S = get_tril(S)
-        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
+        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float64)[None, ...] * jnp.exp(S)
 
         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (self.num_mixtures,), dtype=jnp.float32)
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (self.num_mixtures,), dtype=jnp.float64)
         else:
-            w_logits = jnp.zeros(self.num_mixtures, dtype=jnp.float32)
+            w_logits = jnp.zeros(self.num_mixtures, dtype=jnp.float64)
 
         out = (mu, S, w_logits)
         if self.transform:
@@ -94,9 +94,9 @@ def _post_process(self, h: ArrayLike, t: ArrayLike):
         num_mixtures = self.num_mixtures
 
         h_mu = (1 - t) * self.A + t * self.B
-        S_0 = jnp.eye(ndim, dtype=jnp.float32)
-        S_0 = S_0 * jnp.vstack([self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float32),
-                                self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float32)])
+        S_0 = jnp.eye(ndim, dtype=jnp.float64)
+        S_0 = S_0 * jnp.vstack([self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float64),
+                                self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float64)])
         S_0 = S_0[None, ...]
 
         h_S = (1 - 2 * t * (1 - t))[..., None] * S_0
@@ -109,19 +109,19 @@ def _post_process(self, h: ArrayLike, t: ArrayLike):
         @jax.vmap  # once for num_mixtures
         @jax.vmap  # once for batch
         def get_tril(v):
-            a = jnp.zeros((ndim, ndim), dtype=jnp.float32)
+            a = jnp.zeros((ndim, ndim), dtype=jnp.float64)
             a = a.at[jnp.tril_indices(ndim)].set(v)
             return a
 
         S = h[:, self.num_mixtures * ndim:].reshape(BS, self.num_mixtures, ndim * (ndim + 1) // 2)
         S = get_tril(S)
-        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
+        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float64)[None, ...] * jnp.exp(S)
 
         S = h_S[:, None, ...] + 2 * ((1 - t) * t)[..., None, None] * S
 
         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float32)
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float64)
         else:
-            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float32)
+            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float64)
 
         return mu, S, w_logits
 
@@ -133,8 +133,8 @@ def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike)
             ndim = self.model_q.A.shape[-1]
             key = jax.random.split(key)
 
-            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
-            eps = jax.random.normal(key[1], [BS, ndim, 1], dtype=jnp.float32)
+            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float64)
+            eps = jax.random.normal(key[1], [BS, ndim, 1], dtype=jnp.float64)
 
             def v_t(_eps, _t):
                 _mu_t, _S_t_val, _w_logits, _dmudt, _dSdt_val = forward_and_derivatives(state_q, _t, params_q)
diff --git a/utils/plot.py b/utils/plot.py
index 2ee2f24..41954de 100644
--- a/utils/plot.py
+++ b/utils/plot.py
@@ -179,7 +179,7 @@ def _plot_trajectories(trajectories: ArrayLike, bins: int, xlim: ArrayLike, ylim
 
 def plot_u_t(system: System, setup: QSetup, state_q: TrainState, T: float, save_dir: str, name: str,
              frames: int = 100, fps: int = 10):
-    t = T * jnp.linspace(0, 1, frames, dtype=jnp.float32).reshape((-1, 1))
+    t = T * jnp.linspace(0, 1, frames, dtype=jnp.float64).reshape((-1, 1))
     mu_t, sigma_t, _ = state_q.apply_fn(state_q.params, t)
 
     _u_t_func = jax.jit(lambda _t, _points: setup.u_t(state_q, _t * jnp.ones((len(_points), 1)), _points, True))
@@ -192,7 +192,7 @@ def get_lim():
 
     x_lim, y_lim = get_lim()
     x, y = jnp.meshgrid(jnp.linspace(x_lim[0], x_lim[1], 10), jnp.linspace(y_lim[0], y_lim[1], 10))
-    points = jnp.vstack([x.ravel(), y.ravel()], dtype=jnp.float32).T
+    points = jnp.vstack([x.ravel(), y.ravel()], dtype=jnp.float64).T
 
     x_all, y_all = [], []
     u_all, v_all = [], []
diff --git a/utils/splines.py b/utils/splines.py
index 6276988..f7a5f7b 100644
--- a/utils/splines.py
+++ b/utils/splines.py
@@ -7,8 +7,8 @@ def compute_spline_coefficients(x_knots, y_knots):
     h = jnp.diff(x_knots)
     b = (jnp.diff(y_knots, axis=0).T / h).T
 
-    u = jnp.zeros(n + 1, dtype=jnp.float32)
-    v = jnp.zeros((n + 1,) + y_knots.shape[1:], dtype=jnp.float32)
+    u = jnp.zeros(n + 1, dtype=jnp.float64)
+    v = jnp.zeros((n + 1,) + y_knots.shape[1:], dtype=jnp.float64)
 
     u = u.at[1:n].set(2 * (h[:-1] + h[1:]))
     v = v.at[1:n].set(6 * (b[1:] - b[:-1]))
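-- 
Note on 64-bit precision in JAX (a sketch, not part of the applied diff): unless the repository
already enables x64 mode somewhere, requesting `dtype=jnp.float64` is downcast to float32 with a
warning, so the dtype changes above only take full effect once the flag is set before any arrays
are created, e.g. at the top of the entry point:

    import jax
    # Enable 64-bit mode; without this, dtype=jnp.float64 still yields float32 arrays.
    jax.config.update("jax_enable_x64", True)

    import jax.numpy as jnp
    assert jnp.zeros(3, dtype=jnp.float64).dtype == jnp.float64  # holds only with x64 enabled

As the removed FAQ entry notes, 64-bit precision avoids the NaN issue on affected devices at
roughly 2x slower performance; the lighter workaround of setting `ode: first_order` in the config
remains available.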