Skip to content

Commit

Permalink
Test branch for PR 292
Browse files Browse the repository at this point in the history
  • Loading branch information
trunk-io[bot] authored Jan 4, 2024
2 parents aeb07ad + 6ef2bde commit 92c8d43
Show file tree
Hide file tree
Showing 17 changed files with 101 additions and 67 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-20.04, macos-latest]
python-version: [3.8, 3.9]
python-version: [3.8, 3.9, 3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
Expand All @@ -32,7 +32,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest matplotlib
pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" jaxlib "jax-md>=0.2.7" jaxopt pytest matplotlib
- name: Install pysages
run: pip install .
Expand Down Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest pylint flake8
pip install dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" jaxlib "jax-md>=0.2.7" jaxopt pytest pylint flake8
pip install -r docs/requirements.txt
- name: Install pysages
run: pip install .
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ RUN python -m pip install ase gsd matplotlib "pyparsing<3"

# Install JAX and JAX-MD
RUN python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN python -m pip install --upgrade "jax-md>=0.2.7" jaxopt
RUN python -m pip install --upgrade "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax-md>=0.2.7" jaxopt

COPY . /PySAGES
RUN pip install /PySAGES/
1 change: 1 addition & 0 deletions docs/source/pysages_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ metapotential
monocyclic
multi
nanometer
ncalls
nn
numpyfying
Penrose
Expand Down
49 changes: 37 additions & 12 deletions examples/ase/abf/water.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
#!/usr/bin/env python3

# %%
import ase.units as units
import argparse
import sys

import numpy as np
from ase import Atoms
from ase import Atoms, units
from ase.calculators.tip3p import TIP3P, angleHOH, rOH
from ase.constraints import FixBondLengths
from ase.io.trajectory import Trajectory
from ase.md import Langevin

import pysages
from pysages.colvars import Distance
from pysages.colvars import Angle
from pysages.grids import Grid
from pysages.methods import ABF


# %%
def generate_simulation(tag="tip3p"):
def generate_simulation(tag="tip3p", write_output=True):
x = angleHOH * np.pi / 180 / 2
pos = [
[0, 0, 0], # rOH is the distance between oxygen and hydrogen atoms in water
Expand All @@ -37,24 +39,47 @@ def generate_simulation(tag="tip3p"):
)

T = 300 * units.kB
logfile = tag + ".log"
atoms.calc = TIP3P(rc=4.5)
logfile = tag + ".log" if write_output else None
md = Langevin(atoms, 1 * units.fs, temperature_K=T, friction=0.01, logfile=logfile)

traj = Trajectory(tag + ".traj", "w", atoms)
md.attach(traj.write, interval=1)
if write_output:
traj = Trajectory(tag + ".traj", "w", atoms)
md.attach(traj.write, interval=1)

return md


# %%
def main():
cvs = [Distance([0, 3])]
grid = Grid(lower=0.1, upper=9.0, shape=64)
def process_args(argv):
print(repr(argv))
available_args = [
("timesteps", "t", int, 100, "Number of simulation steps"),
("write-output", "o", bool, 1, "Write log and trajectory of the ASE run"),
]
parser = argparse.ArgumentParser(description="Example script to run pysages with ASE")

for name, short, T, val, doc in available_args:
parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc)

return parser.parse_args(argv)


# %%
def run_simulation(timesteps, write_output):
cvs = [Angle([1, 0, 2])]
grid = Grid(lower=0.1, upper=9.0, shape=64, periodic=True)
method = ABF(cvs, grid)
pysages.run(method, generate_simulation, 100)
context_args = dict(write_output=write_output)
return pysages.run(method, generate_simulation, timesteps, context_args=context_args)


# %%
def main(argv=None):
args = process_args([] if argv is None else argv)
run_simulation(args.timesteps, args.write_output)


# %%
if __name__ == "__main__":
main()
main(sys.argv[1:])
8 changes: 6 additions & 2 deletions pysages/methods/abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class ABFState(NamedTuple):
Wp_: JaxArray (CV shape)
Product of W matrix and momenta matrix for the previous step.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -66,6 +69,7 @@ class ABFState(NamedTuple):
force: JaxArray
Wp: JaxArray
Wp_: JaxArray
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -174,7 +178,7 @@ def initialize():
force = np.zeros(dims)
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_)
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_, 0)

def update(state, data):
"""
Expand Down Expand Up @@ -213,7 +217,7 @@ def update(state, data):
force = estimate_force(xi, I_xi, Fsum, hist).reshape(dims)
bias = np.reshape(-Jxi.T @ force, state.bias.shape)

return ABFState(xi, bias, hist, Fsum, force, Wp, state.Wp)
return ABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)

Expand Down
16 changes: 8 additions & 8 deletions pysages/methods/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class ANNState(NamedTuple):
nn: NNDada
Bundle of the neural network parameters, and output scaling coefficients.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -70,7 +70,7 @@ class ANNState(NamedTuple):
phi: JaxArray
prob: JaxArray
nn: NNData
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -148,13 +148,13 @@ def initialize():
phi = np.zeros(shape)
prob = np.ones(shape)
nn = NNData(ps, np.array(0.0), np.array(1.0))
return ANNState(xi, bias, hist, phi, prob, nn, 1)
return ANNState(xi, bias, hist, phi, prob, nn, 0)

def update(state, data):
nstep = state.nstep
in_training_regime = nstep > train_freq
ncalls = state.ncalls + 1
in_training_regime = ncalls > train_freq
# We only train every `train_freq` timesteps
in_training_step = in_training_regime & (nstep % train_freq == 1)
in_training_step = in_training_regime & (ncalls % train_freq == 1)
hist, phi, prob, nn = learn_free_energy(state, in_training_step)
# Compute the collective variable and its jacobian
xi, Jxi = cv(data)
Expand All @@ -163,7 +163,7 @@ def update(state, data):
F = estimate_force(xi, I_xi, nn, in_training_regime)
bias = np.reshape(-Jxi.T @ F, state.bias.shape)
#
return ANNState(xi, bias, hist, phi, prob, nn, nstep + 1)
return ANNState(xi, bias, hist, phi, prob, nn, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
18 changes: 8 additions & 10 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class CFFState(NamedTuple):
nn: NNDada
Bundle of the neural network parameters, and output scaling coefficients.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -93,7 +93,7 @@ class CFFState(NamedTuple):
Wp_: JaxArray
nn: NNData
fnn: NNData
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -209,13 +209,13 @@ def initialize():
nn = NNData(ps, np.array(0.0), np.array(1.0))
fnn = NNData(fps, np.zeros(dims), np.array(1.0))

return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn, 1)
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn, 0)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
nstep = state.nstep
in_training_regime = nstep > 1 * train_freq
in_training_step = in_training_regime & (nstep % train_freq == 1)
ncalls = state.ncalls + 1
in_training_regime = ncalls > train_freq
in_training_step = in_training_regime & (ncalls % train_freq == 1)
histp, fe, prob, nn, fnn = learn_free_energy(state, in_training_step)
# Compute the collective variable and its jacobian
xi, Jxi = cv(data)
Expand All @@ -232,9 +232,7 @@ def update(state, data):
force = estimate_force(PartialCFFState(xi, hist, Fsum, I_xi, fnn, in_training_regime))
bias = (-Jxi.T @ force).reshape(state.bias.shape)
#
return CFFState(
xi, bias, hist, histp, prob, fe, Fsum, force, Wp, state.Wp, nn, fnn, nstep + 1
)
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, state.Wp, nn, fnn, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
5 changes: 3 additions & 2 deletions pysages/methods/ffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class FFSState(NamedTuple):
xi: JaxArray
bias: Optional[JaxArray]
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -210,11 +211,11 @@ def _ffs(method, snapshot, helpers):
# initialize method
def initialize():
xi = cv(helpers.query(snapshot))
return FFSState(xi, None)
return FFSState(xi, None, 0)

def update(state, data):
xi = cv(data)
return FFSState(xi, None)
return FFSState(xi, None, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)

Expand Down
16 changes: 8 additions & 8 deletions pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class FUNNState(NamedTuple):
nn: NNData
Bundle of the neural network parameters, and output scaling coefficients.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -78,7 +78,7 @@ class FUNNState(NamedTuple):
Wp: JaxArray
Wp_: JaxArray
nn: NNData
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -173,13 +173,13 @@ def initialize():
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
nn = NNData(ps, F, F)
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn, 1)
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn, 0)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
nstep = state.nstep
in_training_regime = nstep > 2 * train_freq
in_training_step = in_training_regime & (nstep % train_freq == 1)
ncalls = state.ncalls + 1
in_training_regime = ncalls > 2 * train_freq
in_training_step = in_training_regime & (ncalls % train_freq == 1)
# NN training
nn = learn_free_energy_grad(state, in_training_step)
# Compute the collective variable and its jacobian
Expand All @@ -198,7 +198,7 @@ def update(state, data):
)
bias = (-Jxi.T @ F).reshape(state.bias.shape)
#
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.nstep + 1)
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
5 changes: 3 additions & 2 deletions pysages/methods/harmonic_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class HarmonicBiasState(NamedTuple):

xi: JaxArray
bias: JaxArray
ncalls: int

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -118,14 +119,14 @@ def _harmonic_bias(method, snapshot, helpers):
def initialize():
xi, _ = cv(helpers.query(snapshot))
bias = np.zeros((natoms, helpers.dimensionality()))
return HarmonicBiasState(xi, bias)
return HarmonicBiasState(xi, bias, 0)

def update(state, data):
xi, Jxi = cv(data)
forces = kspring @ (xi - center).flatten()
bias = -Jxi.T @ forces.flatten()
bias = bias.reshape(state.bias.shape)

return HarmonicBiasState(xi, bias)
return HarmonicBiasState(xi, bias, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)
Loading

0 comments on commit 92c8d43

Please sign in to comment.