From 7c272806dfae022cb8f94bb8671cf133426c6d41 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:26:29 -0600 Subject: [PATCH 1/4] Use math prod when available --- pysages/_compat.py | 7 +++++-- pysages/ml/utils.py | 8 +------- pysages/utils/__init__.py | 3 ++- pysages/utils/compat.py | 21 ++++++++++++++++++++- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pysages/_compat.py b/pysages/_compat.py index 339966a8..ca87ba01 100644 --- a/pysages/_compat.py +++ b/pysages/_compat.py @@ -4,7 +4,9 @@ # flake8: noqa F401 # pylint: disable=unused-import,relative-beyond-top-level -import jaxlib +from importlib import import_module +from platform import python_version + from plum._version import __version_tuple__ as _plum_version_tuple # Compatibility utils @@ -14,4 +16,5 @@ def _version_as_tuple(ver_str): return tuple(int(i) for i in ver_str.split(".") if i.isdigit()) -_jax_version_tuple = _version_as_tuple(jaxlib.__version__) +_jax_version_tuple = _version_as_tuple(import_module("jaxlib").__version__) +_python_version_tuple = _version_as_tuple(python_version()) diff --git a/pysages/ml/utils.py b/pysages/ml/utils.py index 753ffe9d..74466073 100644 --- a/pysages/ml/utils.py +++ b/pysages/ml/utils.py @@ -9,6 +9,7 @@ from plum import Dispatcher from pysages.typing import NamedTuple +from pysages.utils import prod # Dispatcher for the `ml` submodule dispatch = Dispatcher() @@ -36,13 +37,6 @@ def rng_key(seed=0, n=2): return key -def prod(xs): - y = 1 - for x in xs: - y *= x - return y - - # %% Models def unpack(params): """ diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 02e3d1be..00279a4c 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -13,8 +13,9 @@ dispatch_table, has_method, is_generic_subclass, + prod, solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, gaussian, identity, only_or_identity +from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity from .transformations import quaternion_from_euler, quaternion_matrix diff --git a/pysages/utils/compat.py b/pysages/utils/compat.py index bca62859..d72094e0 100644 --- a/pysages/utils/compat.py +++ b/pysages/utils/compat.py @@ -5,7 +5,11 @@ from jax.scipy import linalg -from pysages._compat import _jax_version_tuple, _plum_version_tuple +from pysages._compat import ( + _jax_version_tuple, + _plum_version_tuple, + _python_version_tuple, +) # Compatibility utils @@ -17,6 +21,21 @@ def try_import(new_name, old_name): return import_module(old_name) +if _python_version_tuple >= (3, 8): + prod = import_module("math").prod +else: + + def prod(iterable, start=1): + """ + Calculate the product of all the elements in the input iterable. + When the iterable is empty, return the start value (1 by default). + """ + result = start + for x in iterable: + result *= x + return result + + # Compatibility for jax >=0.4.1 # https://github.com/google/jax/releases/tag/jax-v0.4.1 From 5c0a09e116a8521314c728c95c5c45fb2d95a583 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:30:02 -0600 Subject: [PATCH 2/4] Rename only_or_identity to first_or_all --- pysages/methods/analysis.py | 10 +++++----- pysages/methods/ann.py | 5 +---- pysages/methods/cff.py | 5 +---- pysages/methods/funn.py | 4 ++-- pysages/methods/spectral_abf.py | 5 +---- pysages/utils/core.py | 2 +- 6 files changed, 11 insertions(+), 20 deletions(-) diff --git a/pysages/methods/analysis.py b/pysages/methods/analysis.py index 41dc5658..f02677f1 100644 --- a/pysages/methods/analysis.py +++ b/pysages/methods/analysis.py @@ -20,7 +20,7 @@ from pysages.ml.optimizers import LevenbergMarquardt from pysages.ml.training import NNData, build_fitting_function, convolve from pysages.ml.utils import blackman_kernel, pack, unpack -from pysages.utils import dispatch, only_or_identity +from pysages.utils import dispatch, first_or_all class AnalysisStrategy: @@ -162,9 +162,9 @@ def average_forces(hist, Fsum): fes_fns.append(fes_fn) return { - "histogram": only_or_identity(hists), - "mean_force": only_or_identity(mean_forces), - "free_energy": only_or_identity(free_energies), - "fes_fn": only_or_identity(fes_fns), + "histogram": first_or_all(hists), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "fes_fn": first_or_all(fes_fns), "mesh": mesh, } diff --git a/pysages/methods/ann.py b/pysages/methods/ann.py index 98952acd..1abc1dc5 100644 --- a/pysages/methods/ann.py +++ b/pysages/methods/ann.py @@ -32,7 +32,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple -from pysages.utils import dispatch +from pysages.utils import dispatch, first_or_all class ANNState(NamedTuple): @@ -297,9 +297,6 @@ def fes_fn(x): return jit(fes_fn) - def first_or_all(seq): - return seq[0] if len(seq) == 1 else seq - histograms = [] free_energies = [] nns = [] diff --git a/pysages/methods/cff.py b/pysages/methods/cff.py index c5b4506e..a127ae9c 100644 --- a/pysages/methods/cff.py +++ b/pysages/methods/cff.py @@ -31,7 +31,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, solve_pos_def +from pysages.utils import dispatch, first_or_all, solve_pos_def # Aliases f32 = np.float32 @@ -388,9 +388,6 @@ def fes_fn(x): return jit(fes_fn) - def first_or_all(seq): - return seq[0] if len(seq) == 1 else seq - histograms = [] mean_forces = [] free_energies = [] diff --git a/pysages/methods/funn.py b/pysages/methods/funn.py index ad268d54..6130d396 100644 --- a/pysages/methods/funn.py +++ b/pysages/methods/funn.py @@ -34,7 +34,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, only_or_identity, solve_pos_def +from pysages.utils import dispatch, first_or_all, solve_pos_def class FUNNState(NamedTuple): @@ -332,5 +332,5 @@ def analyze(result: Result[FUNN], **kwargs): """ topology = kwargs.get("topology", result.method.topology) _result = _analyze(result, GradientLearning(), topology) - _result["nn"] = only_or_identity([state.nn for state in result.states]) + _result["nn"] = first_or_all([state.nn for state in result.states]) return numpyfy_vals(_result) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index deccce1e..6e9d9d94 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -30,7 +30,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, solve_pos_def +from pysages.utils import dispatch, first_or_all, solve_pos_def class SpectralABFState(NamedTuple): @@ -310,9 +310,6 @@ def fes_fn(x): return jit(fes_fn) - def first_or_all(seq): - return seq[0] if len(seq) == 1 else seq - hists = [] mean_forces = [] free_energies = [] diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 9125d09a..06afdbc8 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -46,7 +46,7 @@ def identity(x): return x -def only_or_identity(seq): +def first_or_all(seq): """ Returns the only element of a sequence `seq` if its length is one, otherwise returns `seq` itself. From ac57b6fbc1701e83b25f4e14540077fd37e4443c Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:20:27 -0600 Subject: [PATCH 3/4] Add helper function for transposing grid data --- pysages/grids.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pysages/grids.py b/pysages/grids.py index d8fdb63e..3b5f5975 100644 --- a/pysages/grids.py +++ b/pysages/grids.py @@ -8,7 +8,7 @@ from plum import Union, parametric from pysages.typing import JaxArray -from pysages.utils import dispatch, is_generic_subclass +from pysages.utils import dispatch, is_generic_subclass, prod class GridType: @@ -156,3 +156,23 @@ def get_index(x): return (*np.flip(np.uint32(idx)),) return jit(get_index) + + +def grid_transposer(grid): + """ + Returns a function that transposes arrays mapped to a `Grid`. + + The result function takes an array, reshapes it to match the grid dimensions, + transposes it along the first axes. The first axes are assumed to correspond to the + axes of the grid. + """ + d = len(grid.shape) + shape = (*grid.shape,) + axes = (*reversed(range(d)),) + n = grid.shape.prod().item() + + def transpose(array: JaxArray): + m = prod(array.shape) // n + return array.reshape(*shape, m).transpose(*axes, d).squeeze() + + return transpose From 3c2f60eb51851209849e9b3d13599f8b7c7a264a Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:02:57 -0600 Subject: [PATCH 4/4] Allow different number of grid points along each axis --- pysages/approxfun/core.py | 45 ++++++++++++++++----------------- pysages/grids.py | 14 +++++----- pysages/methods/analysis.py | 12 ++++++--- pysages/methods/ann.py | 24 ++++++++++-------- pysages/methods/cff.py | 30 ++++++++++++---------- pysages/methods/spectral_abf.py | 27 +++++++++++--------- tests/test_grids.py | 4 +-- 7 files changed, 85 insertions(+), 71 deletions(-) diff --git a/pysages/approxfun/core.py b/pysages/approxfun/core.py index fd76ba79..3658cecc 100644 --- a/pysages/approxfun/core.py +++ b/pysages/approxfun/core.py @@ -104,40 +104,39 @@ def scale(x, grid: Grid): return (x - grid.lower) * 2 / grid.size - 1 -@dispatch -def compute_mesh(grid: Grid): +def compute_mesh(grid): """ - Returns a dense mesh with the same shape as `grid`, but on the hypercube - [-1, 1]ⁿ, where `n` is the dimensionality of `grid`. + Returns a dense mesh with the same shape as `grid`, but on the hypercube [-1, 1]ⁿ, + where `n` is the dimensionality of `grid`. The resulting mesh is Chebyshev-distributed + if `grid: Grid[Chebyshev]`, or uniformly-distributed otherwise. """ - h = 2 / grid.shape - o = -1 + h / 2 - nodes = o + h * np.hstack([np.arange(i).reshape(-1, 1) for i in grid.shape]) + def generate_axis(n): + transform = _generate_transform(grid, n) + return transform(np.arange(n)) - return _compute_mesh(nodes) + return cartesian_product(*(generate_axis(i) for i in grid.shape)) @dispatch -def compute_mesh(grid: Grid[Chebyshev]): # noqa: F811 # pylint: disable=C0116,E0102 - """ - Returns a Chebyshev-distributed dense mesh with the same shape as `grid`, - but on the hypercube [-1, 1]ⁿ, where n is the dimensionality of `grid`. - """ - - def transform(n): - return vmap(lambda k: -np.cos((k + 1 / 2) * np.pi / n)) +def _generate_transform(_: Grid, n): + return vmap(lambda k: -1 + (2 * k + 1) / n) - nodes = np.hstack([transform(i)(np.arange(i).reshape(-1, 1)) for i in grid.shape]) - return _compute_mesh(nodes) +@dispatch +def _generate_transform(_: Grid[Chebyshev], n): # noqa: F811 # pylint: disable=C0116,E0102 + return vmap(lambda k: -np.cos((k + 1 / 2) * np.pi / n)) -def _compute_mesh(nodes): - components = np.meshgrid( - *nodes.T, - ) - return np.hstack([v.reshape(-1, 1) for v in components]) +def cartesian_product(*collections): + """ + Given a set of `collections`, returns an array with their [Cartesian + Product](https://en.wikipedia.org/wiki/Cartesian_product). + """ + n = len(collections) + coordinates = np.array(np.meshgrid(*collections, indexing="ij")) + permutation = np.roll(np.arange(n + 1), -1) + return np.transpose(coordinates, permutation).reshape(-1, n) def vander_builder(grid, exponents): diff --git a/pysages/grids.py b/pysages/grids.py index 3b5f5975..6c2039ee 100644 --- a/pysages/grids.py +++ b/pysages/grids.py @@ -36,7 +36,7 @@ class Grid: size: JaxArray @classmethod - def __infer_type_parameter__(cls, *args, **kwargs): + def __infer_type_parameter__(cls, *_, **kwargs): return Periodic if kwargs.get("periodic", False) else Regular def __init__(self, lower, upper, shape, **kwargs): @@ -49,7 +49,7 @@ def __init__(self, lower, upper, shape, **kwargs): self.size = self.upper - self.lower def __check_init_invariants__(self, **kwargs): - T = type(self).type_parameter + T = type(self).type_parameter # pylint: disable=E1101 if not (issubclass(type(T), type) and issubclass(T, GridType)): raise TypeError("Type parameter must be a subclass of GridType.") if len(kwargs) > 1 or (len(kwargs) == 1 and "periodic" not in kwargs): @@ -64,13 +64,13 @@ def __check_init_invariants__(self, **kwargs): raise ValueError("Incompatible type parameter and keyword argument") def __repr__(self): - T = type(self).type_parameter + T = type(self).type_parameter # pylint: disable=E1101 P = "" if T is Regular else f"[{T.__name__}]" return f"Grid{P} ({' x '.join(map(str, self.shape))})" @property def is_periodic(self): - return type(self).type_parameter is Periodic + return type(self).type_parameter is Periodic # pylint: disable=E1101 @dispatch @@ -118,7 +118,7 @@ def get_index(x): h = grid.size / grid.shape idx = (x.flatten() - grid.lower) // h idx = np.where((idx < 0) | (idx > grid.shape), grid.shape, idx) - return (*np.flip(np.uint32(idx)),) + return (*np.uint32(idx),) return jit(get_index) @@ -135,7 +135,7 @@ def get_index(x): h = grid.size / grid.shape idx = (x.flatten() - grid.lower) // h idx = idx % grid.shape - return (*np.flip(np.uint32(idx)),) + return (*np.uint32(idx),) return jit(get_index) @@ -153,7 +153,7 @@ def get_index(x): x = 2 * (grid.lower - x.flatten()) / grid.size + 1 idx = (grid.shape * np.arccos(x)) // np.pi idx = np.nan_to_num(idx, nan=grid.shape) - return (*np.flip(np.uint32(idx)),) + return (*np.uint32(idx),) return jit(get_index) diff --git a/pysages/methods/analysis.py b/pysages/methods/analysis.py index f02677f1..27d6554a 100644 --- a/pysages/methods/analysis.py +++ b/pysages/methods/analysis.py @@ -14,6 +14,7 @@ from pysages.approxfun import compute_mesh from pysages.approxfun import scale as _scale +from pysages.grids import grid_transposer from pysages.methods.core import Result from pysages.ml.models import MLP from pysages.ml.objectives import GradientsSSE, L2Regularization @@ -154,11 +155,14 @@ def average_forces(hist, Fsum): free_energies = [] fes_fns = [] + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + for state in states: fes_fn = build_fes_fn(state) - hists.append(state.hist) - mean_forces.append(average_forces(state.hist, state.Fsum)) - free_energies.append(fes_fn(mesh).reshape(grid.shape)) + hists.append(transpose(state.hist)) + mean_forces.append(transpose(average_forces(state.hist, state.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) fes_fns.append(fes_fn) return { @@ -166,5 +170,5 @@ def average_forces(hist, Fsum): "mean_force": first_or_all(mean_forces), "free_energy": first_or_all(free_energies), "fes_fn": first_or_all(fes_fns), - "mesh": mesh, + "mesh": transpose(mesh), } diff --git a/pysages/methods/ann.py b/pysages/methods/ann.py index 1abc1dc5..fec88f3a 100644 --- a/pysages/methods/ann.py +++ b/pysages/methods/ann.py @@ -23,7 +23,7 @@ from pysages.approxfun import compute_mesh from pysages.approxfun import scale as _scale -from pysages.grids import build_indexer +from pysages.grids import build_indexer, grid_transposer from pysages.methods.core import NNSamplingMethod, Result, generalize from pysages.methods.utils import numpyfy_vals from pysages.ml.models import MLP @@ -302,17 +302,21 @@ def fes_fn(x): nns = [] fes_fns = [] + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + for s in states: - histograms.append(s.hist) - free_energies.append(s.phi.max() - s.phi) + histograms.append(transpose(s.hist)) + free_energies.append(transpose(s.phi.max() - s.phi)) nns.append(s.nn) fes_fns.append(build_fes_fn(s.nn)) - ana_result = dict( - histogram=first_or_all(histograms), - free_energy=first_or_all(free_energies), - mesh=mesh, - nn=first_or_all(nns), - fes_fn=first_or_all(fes_fns), - ) + ana_result = { + "histogram": first_or_all(histograms), + "free_energy": first_or_all(free_energies), + "mesh": transpose(mesh), + "nn": first_or_all(nns), + "fes_fn": first_or_all(fes_fns), + } + return numpyfy_vals(ana_result) diff --git a/pysages/methods/cff.py b/pysages/methods/cff.py index a127ae9c..9d650a4f 100644 --- a/pysages/methods/cff.py +++ b/pysages/methods/cff.py @@ -21,7 +21,7 @@ from pysages.approxfun import compute_mesh from pysages.approxfun import scale as _scale -from pysages.grids import build_indexer +from pysages.grids import build_indexer, grid_transposer from pysages.methods.core import NNSamplingMethod, Result, generalize from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals @@ -395,21 +395,25 @@ def fes_fn(x): fnns = [] fes_fns = [] + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + for s in states: - histograms.append(s.hist) - mean_forces.append(average_forces(s.hist, s.Fsum)) - free_energies.append(s.fe.max() - s.fe) + histograms.append(transpose(s.hist)) + mean_forces.append(transpose(average_forces(s.hist, s.Fsum))) + free_energies.append(transpose(s.fe.max() - s.fe)) nns.append(s.nn) fnns.append(s.fnn) fes_fns.append(build_fes_fn(s.nn)) - ana_result = dict( - histogram=first_or_all(histograms), - mean_force=first_or_all(mean_forces), - free_energy=first_or_all(free_energies), - mesh=mesh, - nn=first_or_all(nns), - fnn=first_or_all(fnns), - fes_fn=first_or_all(fes_fns), - ) + ana_result = { + "histogram": first_or_all(histograms), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "mesh": transpose(mesh), + "nn": first_or_all(nns), + "fnn": first_or_all(fnns), + "fes_fn": first_or_all(fes_fns), + } + return numpyfy_vals(ana_result) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index 6e9d9d94..05cc147f 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -25,7 +25,7 @@ build_grad_evaluator, compute_mesh, ) -from pysages.grids import Chebyshev, Grid, build_indexer, convert +from pysages.grids import Chebyshev, Grid, build_indexer, convert, grid_transposer from pysages.methods.core import GriddedSamplingMethod, Result, generalize from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals @@ -316,21 +316,24 @@ def fes_fn(x): funs = [] fes_fns = [] + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + for s in states: fes_fn = build_fes_fn(s.fun) - hists.append(s.hist) - mean_forces.append(average_forces(s.hist, s.Fsum)) - free_energies.append(fes_fn(mesh).reshape(grid.shape)) + hists.append(transpose(s.hist)) + mean_forces.append(transpose(average_forces(s.hist, s.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) funs.append(s.fun) fes_fns.append(fes_fn) - ana_result = dict( - histogram=first_or_all(hists), - mean_force=first_or_all(mean_forces), - free_energy=first_or_all(free_energies), - mesh=mesh, - fun=first_or_all(funs), - fes_fn=first_or_all(fes_fns), - ) + ana_result = { + "histogram": first_or_all(hists), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "mesh": transpose(mesh), + "fun": first_or_all(funs), + "fes_fn": first_or_all(fes_fns), + } return numpyfy_vals(ana_result) diff --git a/tests/test_grids.py b/tests/test_grids.py index 74e5460a..2d3aef27 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -88,5 +88,5 @@ def test_grid_indexing(): # Indexing 2D x_lo_up = np.array([-pi, 1]) x_up_lo_out = np.array([pi, -2]) - assert get_index_2d(x_lo_up) == (UInt32(32), UInt32(0)) - assert get_index_2d(x_up_lo_out) == (UInt32(32), UInt32(64)) + assert get_index_2d(x_lo_up) == (UInt32(0), UInt32(32)) + assert get_index_2d(x_up_lo_out) == (UInt32(64), UInt32(32))