diff --git a/examples/openmm/abf/alanine-dipeptide_openmm.py b/examples/openmm/abf/alanine-dipeptide_openmm.py index f913b23f..810c362d 100644 --- a/examples/openmm/abf/alanine-dipeptide_openmm.py +++ b/examples/openmm/abf/alanine-dipeptide_openmm.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 - import matplotlib.pyplot as plt import numpy @@ -115,7 +114,7 @@ def post_run_action(**kwargs): def main(): cvs = [DihedralAngle((4, 6, 8, 14)), DihedralAngle((6, 8, 14, 16))] grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(32, 32), periodic=True) - method = ABF(cvs, grid) + method = ABF(cvs, grid, use_pinv=True) raw_result = pysages.run(method, generate_simulation, 25, post_run_action=post_run_action) result = pysages.analyze(raw_result, topology=(14,)) diff --git a/pysages/methods/abf.py b/pysages/methods/abf.py index ba713d15..dbba8888 100644 --- a/pysages/methods/abf.py +++ b/pysages/methods/abf.py @@ -27,7 +27,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple -from pysages.utils import dispatch, solve_pos_def +from pysages.utils import dispatch, linear_solver class ABFState(NamedTuple): @@ -103,6 +103,11 @@ class ABF(GriddedSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -110,6 +115,7 @@ class ABF(GriddedSamplingMethod): def __init__(self, cvs, grid, **kwargs): super().__init__(cvs, grid, **kwargs) self.N = np.asarray(self.kwargs.get("N", 500)) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers, *args, **kwargs): """ @@ -158,6 +164,7 @@ def _abf(method, snapshot, helpers): dt = snapshot.dt dims = grid.shape.size natoms = np.size(snapshot.positions, 0) + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) estimate_force = build_force_estimator(method) @@ -201,11 +208,7 @@ def update(state, data): xi, Jxi = cv(data) p = data.momenta - # The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p` - # (both seem to have the same performance). - # Another option to benchmark against is - # Wp = linalg.tensorsolve(Jxi @ Jxi.T, Jxi @ p) - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt diff --git a/pysages/methods/cff.py b/pysages/methods/cff.py index 57b5fa96..3a3306c1 100644 --- a/pysages/methods/cff.py +++ b/pysages/methods/cff.py @@ -31,7 +31,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver # Aliases f32 = np.float32 @@ -148,6 +148,11 @@ class CFF(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -171,6 +176,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs): self.fmodel = MLP(dims, dims, topology, transform=scale) self.optimizer = kwargs.get("optimizer", default_optimizer) self.foptimizer = kwargs.get("foptimizer", default_foptimizer) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers): return _cff(self, snapshot, helpers) @@ -187,6 +193,7 @@ def _cff(method: CFF, snapshot, helpers): fps, _ = unpack(method.fmodel.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy = build_free_energy_learner(method) estimate_force = build_force_estimator(method) @@ -221,7 +228,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/funn.py b/pysages/methods/funn.py index 6130d396..593dda5a 100644 --- a/pysages/methods/funn.py +++ b/pysages/methods/funn.py @@ -34,7 +34,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class FUNNState(NamedTuple): @@ -126,6 +126,11 @@ class FUNN(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -142,6 +147,7 @@ def __init__(self, cvs, grid, topology, **kwargs): self.model = MLP(dims, dims, topology, transform=scale) default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6)) self.optimizer = kwargs.get("optimizer", default_optimizer) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers): return _funn(self, snapshot, helpers) @@ -160,6 +166,7 @@ def _funn(method, snapshot, helpers): ps, _ = unpack(method.model.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy_grad = build_free_energy_grad_learner(method) estimate_free_energy_grad = build_force_estimator(method) @@ -186,7 +193,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/sirens.py b/pysages/methods/sirens.py index b1342f9b..836e31b6 100644 --- a/pysages/methods/sirens.py +++ b/pysages/methods/sirens.py @@ -32,7 +32,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class SirensState(NamedTuple): # pylint: disable=R0903 @@ -146,6 +146,11 @@ class Sirens(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -172,6 +177,7 @@ def __init__(self, cvs, grid, topology, **kwargs): scale = partial(_scale, grid=grid) self.model = Siren(dims, 1, topology, transform=scale) self.optimizer = optimizer + self.use_pinv = self.kwargs.get("use_pinv", False) def __check_init_invariants__(self, mode, kT, optimizer): if mode not in ("abf", "cff"): @@ -202,6 +208,7 @@ def _sirens(method: Sirens, snapshot, helpers): ps, _ = unpack(method.model.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy = build_free_energy_learner(method) estimate_force = build_force_estimator(method) @@ -244,7 +251,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index 63b47955..c50d2f11 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -30,7 +30,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class SpectralABFState(NamedTuple): @@ -124,6 +124,11 @@ class SpectralABF(GriddedSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -135,6 +140,7 @@ def __init__(self, cvs, grid, **kwargs): self.fit_threshold = self.kwargs.get("fit_threshold", 500) self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev]) self.model = SpectralGradientFit(self.grid) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers, *_args, **_kwargs): """ @@ -154,6 +160,7 @@ def _spectral_abf(method, snapshot, helpers): natoms = np.size(snapshot.positions, 0) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) fit = build_fitter(method.model) fit_forces = build_free_energy_fitter(method, fit) @@ -181,7 +188,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 3351d48e..ffde6466 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -18,5 +18,14 @@ solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity +from .core import ( + ToCPU, + copy, + dispatch, + eps, + first_or_all, + gaussian, + identity, + linear_solver, +) from .transformations import quaternion_from_euler, quaternion_matrix diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 06afdbc8..20fe0f3e 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -8,6 +8,7 @@ from plum import Dispatcher from pysages.typing import JaxArray, Scalar +from pysages.utils.compat import solve_pos_def # PySAGES main dispatcher dispatch = Dispatcher() @@ -70,3 +71,22 @@ def gaussian(a, sigma, x): N-dimensional origin-centered gaussian with height `a` and standard deviation `sigma`. """ return a * np.exp(-row_sum((x / sigma) ** 2) / 2) + + +def linear_solver(use_pinv: bool): + """ + Returns a function that solves the linear system `A.T @ X = B` for `X`. + When `use_pinv == True`, `numpy.linalg.pinv` is used rather than `scipy.linalg.solve` + (this is computationally more expensive but numerically more stable). + """ + if use_pinv: + # This is numerically more robust + def tsolve(A, B): + return np.linalg.pinv(A.T) @ B + + else: + # Another option to benchmark against is `linalg.tensorsolve(A @ A.T, A @ B)` + def tsolve(A, B): + return solve_pos_def(A @ A.T, A @ B) + + return tsolve