Skip to content

Commit

Permalink
Test branch for PR 300
Browse files Browse the repository at this point in the history
  • Loading branch information
trunk-io[bot] authored Feb 19, 2024
2 parents 802cc16 + 3c2f60e commit 414e3ed
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 103 deletions.
7 changes: 5 additions & 2 deletions pysages/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# flake8: noqa F401
# pylint: disable=unused-import,relative-beyond-top-level

import jaxlib
from importlib import import_module
from platform import python_version

from plum._version import __version_tuple__ as _plum_version_tuple

# Compatibility utils
Expand All @@ -14,4 +16,5 @@ def _version_as_tuple(ver_str):
return tuple(int(i) for i in ver_str.split(".") if i.isdigit())


_jax_version_tuple = _version_as_tuple(jaxlib.__version__)
_jax_version_tuple = _version_as_tuple(import_module("jaxlib").__version__)
_python_version_tuple = _version_as_tuple(python_version())
45 changes: 22 additions & 23 deletions pysages/approxfun/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,40 +104,39 @@ def scale(x, grid: Grid):
return (x - grid.lower) * 2 / grid.size - 1


@dispatch
def compute_mesh(grid: Grid):
def compute_mesh(grid):
"""
Returns a dense mesh with the same shape as `grid`, but on the hypercube
[-1, 1]ⁿ, where `n` is the dimensionality of `grid`.
Returns a dense mesh with the same shape as `grid`, but on the hypercube [-1, 1]ⁿ,
where `n` is the dimensionality of `grid`. The resulting mesh is Chebyshev-distributed
if `grid: Grid[Chebyshev]`, or uniformly-distributed otherwise.
"""
h = 2 / grid.shape
o = -1 + h / 2

nodes = o + h * np.hstack([np.arange(i).reshape(-1, 1) for i in grid.shape])
def generate_axis(n):
transform = _generate_transform(grid, n)
return transform(np.arange(n))

return _compute_mesh(nodes)
return cartesian_product(*(generate_axis(i) for i in grid.shape))


@dispatch
def compute_mesh(grid: Grid[Chebyshev]): # noqa: F811 # pylint: disable=C0116,E0102
"""
Returns a Chebyshev-distributed dense mesh with the same shape as `grid`,
but on the hypercube [-1, 1]ⁿ, where n is the dimensionality of `grid`.
"""

def transform(n):
return vmap(lambda k: -np.cos((k + 1 / 2) * np.pi / n))
def _generate_transform(_: Grid, n):
return vmap(lambda k: -1 + (2 * k + 1) / n)

nodes = np.hstack([transform(i)(np.arange(i).reshape(-1, 1)) for i in grid.shape])

return _compute_mesh(nodes)
@dispatch
def _generate_transform(_: Grid[Chebyshev], n): # noqa: F811 # pylint: disable=C0116,E0102
return vmap(lambda k: -np.cos((k + 1 / 2) * np.pi / n))


def _compute_mesh(nodes):
components = np.meshgrid(
*nodes.T,
)
return np.hstack([v.reshape(-1, 1) for v in components])
def cartesian_product(*collections):
"""
Given a set of `collections`, returns an array with their [Cartesian
Product](https://en.wikipedia.org/wiki/Cartesian_product).
"""
n = len(collections)
coordinates = np.array(np.meshgrid(*collections, indexing="ij"))
permutation = np.roll(np.arange(n + 1), -1)
return np.transpose(coordinates, permutation).reshape(-1, n)


def vander_builder(grid, exponents):
Expand Down
36 changes: 28 additions & 8 deletions pysages/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from plum import Union, parametric

from pysages.typing import JaxArray
from pysages.utils import dispatch, is_generic_subclass
from pysages.utils import dispatch, is_generic_subclass, prod


class GridType:
Expand Down Expand Up @@ -36,7 +36,7 @@ class Grid:
size: JaxArray

@classmethod
def __infer_type_parameter__(cls, *args, **kwargs):
def __infer_type_parameter__(cls, *_, **kwargs):
return Periodic if kwargs.get("periodic", False) else Regular

def __init__(self, lower, upper, shape, **kwargs):
Expand All @@ -49,7 +49,7 @@ def __init__(self, lower, upper, shape, **kwargs):
self.size = self.upper - self.lower

def __check_init_invariants__(self, **kwargs):
T = type(self).type_parameter
T = type(self).type_parameter # pylint: disable=E1101
if not (issubclass(type(T), type) and issubclass(T, GridType)):
raise TypeError("Type parameter must be a subclass of GridType.")
if len(kwargs) > 1 or (len(kwargs) == 1 and "periodic" not in kwargs):
Expand All @@ -64,13 +64,13 @@ def __check_init_invariants__(self, **kwargs):
raise ValueError("Incompatible type parameter and keyword argument")

def __repr__(self):
T = type(self).type_parameter
T = type(self).type_parameter # pylint: disable=E1101
P = "" if T is Regular else f"[{T.__name__}]"
return f"Grid{P} ({' x '.join(map(str, self.shape))})"

@property
def is_periodic(self):
return type(self).type_parameter is Periodic
return type(self).type_parameter is Periodic # pylint: disable=E1101


@dispatch
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_index(x):
h = grid.size / grid.shape
idx = (x.flatten() - grid.lower) // h
idx = np.where((idx < 0) | (idx > grid.shape), grid.shape, idx)
return (*np.flip(np.uint32(idx)),)
return (*np.uint32(idx),)

return jit(get_index)

Expand All @@ -135,7 +135,7 @@ def get_index(x):
h = grid.size / grid.shape
idx = (x.flatten() - grid.lower) // h
idx = idx % grid.shape
return (*np.flip(np.uint32(idx)),)
return (*np.uint32(idx),)

return jit(get_index)

Expand All @@ -153,6 +153,26 @@ def get_index(x):
x = 2 * (grid.lower - x.flatten()) / grid.size + 1
idx = (grid.shape * np.arccos(x)) // np.pi
idx = np.nan_to_num(idx, nan=grid.shape)
return (*np.flip(np.uint32(idx)),)
return (*np.uint32(idx),)

return jit(get_index)


def grid_transposer(grid):
"""
Returns a function that transposes arrays mapped to a `Grid`.
The result function takes an array, reshapes it to match the grid dimensions,
transposes it along the first axes. The first axes are assumed to correspond to the
axes of the grid.
"""
d = len(grid.shape)
shape = (*grid.shape,)
axes = (*reversed(range(d)),)
n = grid.shape.prod().item()

def transpose(array: JaxArray):
m = prod(array.shape) // n
return array.reshape(*shape, m).transpose(*axes, d).squeeze()

return transpose
22 changes: 13 additions & 9 deletions pysages/methods/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

from pysages.approxfun import compute_mesh
from pysages.approxfun import scale as _scale
from pysages.grids import grid_transposer
from pysages.methods.core import Result
from pysages.ml.models import MLP
from pysages.ml.objectives import GradientsSSE, L2Regularization
from pysages.ml.optimizers import LevenbergMarquardt
from pysages.ml.training import NNData, build_fitting_function, convolve
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.utils import dispatch, only_or_identity
from pysages.utils import dispatch, first_or_all


class AnalysisStrategy:
Expand Down Expand Up @@ -154,17 +155,20 @@ def average_forces(hist, Fsum):
free_energies = []
fes_fns = []

# We transpose the data for convenience when plotting
transpose = grid_transposer(grid)

for state in states:
fes_fn = build_fes_fn(state)
hists.append(state.hist)
mean_forces.append(average_forces(state.hist, state.Fsum))
free_energies.append(fes_fn(mesh).reshape(grid.shape))
hists.append(transpose(state.hist))
mean_forces.append(transpose(average_forces(state.hist, state.Fsum)))
free_energies.append(transpose(fes_fn(mesh)))
fes_fns.append(fes_fn)

return {
"histogram": only_or_identity(hists),
"mean_force": only_or_identity(mean_forces),
"free_energy": only_or_identity(free_energies),
"fes_fn": only_or_identity(fes_fns),
"mesh": mesh,
"histogram": first_or_all(hists),
"mean_force": first_or_all(mean_forces),
"free_energy": first_or_all(free_energies),
"fes_fn": first_or_all(fes_fns),
"mesh": transpose(mesh),
}
29 changes: 15 additions & 14 deletions pysages/methods/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from pysages.approxfun import compute_mesh
from pysages.approxfun import scale as _scale
from pysages.grids import build_indexer
from pysages.grids import build_indexer, grid_transposer
from pysages.methods.core import NNSamplingMethod, Result, generalize
from pysages.methods.utils import numpyfy_vals
from pysages.ml.models import MLP
Expand All @@ -32,7 +32,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple
from pysages.utils import dispatch
from pysages.utils import dispatch, first_or_all


class ANNState(NamedTuple):
Expand Down Expand Up @@ -297,25 +297,26 @@ def fes_fn(x):

return jit(fes_fn)

def first_or_all(seq):
return seq[0] if len(seq) == 1 else seq

histograms = []
free_energies = []
nns = []
fes_fns = []

# We transpose the data for convenience when plotting
transpose = grid_transposer(grid)

for s in states:
histograms.append(s.hist)
free_energies.append(s.phi.max() - s.phi)
histograms.append(transpose(s.hist))
free_energies.append(transpose(s.phi.max() - s.phi))
nns.append(s.nn)
fes_fns.append(build_fes_fn(s.nn))

ana_result = dict(
histogram=first_or_all(histograms),
free_energy=first_or_all(free_energies),
mesh=mesh,
nn=first_or_all(nns),
fes_fn=first_or_all(fes_fns),
)
ana_result = {
"histogram": first_or_all(histograms),
"free_energy": first_or_all(free_energies),
"mesh": transpose(mesh),
"nn": first_or_all(nns),
"fes_fn": first_or_all(fes_fns),
}

return numpyfy_vals(ana_result)
35 changes: 18 additions & 17 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pysages.approxfun import compute_mesh
from pysages.approxfun import scale as _scale
from pysages.grids import build_indexer
from pysages.grids import build_indexer, grid_transposer
from pysages.methods.core import NNSamplingMethod, Result, generalize
from pysages.methods.restraints import apply_restraints
from pysages.methods.utils import numpyfy_vals
Expand All @@ -31,7 +31,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, solve_pos_def
from pysages.utils import dispatch, first_or_all, solve_pos_def

# Aliases
f32 = np.float32
Expand Down Expand Up @@ -388,31 +388,32 @@ def fes_fn(x):

return jit(fes_fn)

def first_or_all(seq):
return seq[0] if len(seq) == 1 else seq

histograms = []
mean_forces = []
free_energies = []
nns = []
fnns = []
fes_fns = []

# We transpose the data for convenience when plotting
transpose = grid_transposer(grid)

for s in states:
histograms.append(s.hist)
mean_forces.append(average_forces(s.hist, s.Fsum))
free_energies.append(s.fe.max() - s.fe)
histograms.append(transpose(s.hist))
mean_forces.append(transpose(average_forces(s.hist, s.Fsum)))
free_energies.append(transpose(s.fe.max() - s.fe))
nns.append(s.nn)
fnns.append(s.fnn)
fes_fns.append(build_fes_fn(s.nn))

ana_result = dict(
histogram=first_or_all(histograms),
mean_force=first_or_all(mean_forces),
free_energy=first_or_all(free_energies),
mesh=mesh,
nn=first_or_all(nns),
fnn=first_or_all(fnns),
fes_fn=first_or_all(fes_fns),
)
ana_result = {
"histogram": first_or_all(histograms),
"mean_force": first_or_all(mean_forces),
"free_energy": first_or_all(free_energies),
"mesh": transpose(mesh),
"nn": first_or_all(nns),
"fnn": first_or_all(fnns),
"fes_fn": first_or_all(fes_fns),
}

return numpyfy_vals(ana_result)
4 changes: 2 additions & 2 deletions pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, only_or_identity, solve_pos_def
from pysages.utils import dispatch, first_or_all, solve_pos_def


class FUNNState(NamedTuple):
Expand Down Expand Up @@ -332,5 +332,5 @@ def analyze(result: Result[FUNN], **kwargs):
"""
topology = kwargs.get("topology", result.method.topology)
_result = _analyze(result, GradientLearning(), topology)
_result["nn"] = only_or_identity([state.nn for state in result.states])
_result["nn"] = first_or_all([state.nn for state in result.states])
return numpyfy_vals(_result)
Loading

0 comments on commit 414e3ed

Please sign in to comment.