forked from LeifSeute/grappa-data-creation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsingle_points.py
227 lines (157 loc) · 8.26 KB
/
single_points.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
from ase import Atoms
from ase.calculators.psi4 import Psi4
import numpy as np
from time import time
from pathlib import Path
from utils import Logger
###################
import sys
import os
###################
def calc_state(pdb_folder, memory=32, num_threads=4):
    """
    Calculate the energy and forces of one not-yet-calculated state and store them.

    The states are defined by positions.npy, atomic_numbers.npy and charge.npy
    (and optionally multiplicity.npy, default 1) in the given folder. Results are
    written to psi4_energies.npy and psi4_forces.npy in the same folder, converted
    from the eV / eV/Angstrom returned by ASE to kcal/mol / kcal/mol/Angstrom.

    Bookkeeping convention of the result files (indexed by state along axis 0):
        nan    -> state has not been touched yet
        inf    -> some call has started working on the state
        finite -> state is done
    States marked nan are picked before states marked inf, so that several of
    these calls running in parallel do not pick a state that is already being
    worked on while untouched states remain. States left at inf (e.g. by a
    crashed run) are retried once no nan state is left.
    NOTE: the claim is a plain load/modify/save of the .npy files without any
    locking, so two calls starting at the very same moment can still collide.

    Args:
        pdb_folder: folder containing the .npy files described above.
        memory: memory handed to Psi4 in GB. None or a value <= 0 leaves the
            calculator's own default untouched.
        num_threads: number of threads handed to Psi4. None or a value <= 0
            leaves the calculator's own default untouched.

    Returns:
        None. Returns without doing anything if positions.npy does not exist or
        if all energies and forces are already finite.

    Raises:
        ValueError: if charge.npy or multiplicity.npy does not have shape (1,),
            if the charge is not an integer, or if the stored energies and
            forces are inconsistent with each other.
    """
    log = Logger(Path(pdb_folder).parent, print_to_screen=True)

    METHOD = 'bmk'
    BASIS = '6-311+G(2df,p)'
    EV_IN_KCAL_PM = 23.0609

    # Psi4 scratch files go next to this script.
    scratch = Path(__file__).parent/"psi_scratch"
    scratch.mkdir(parents=True, exist_ok=True)
    os.environ['PSI_SCRATCH'] = str(scratch)

    MEMORY = f'{int(memory)}GB' if memory is not None and memory > 0 else None
    NUM_THREADS = num_threads if num_threads is not None and num_threads > 0 else None

    pdb_folder = Path(pdb_folder)
    energies_path = pdb_folder/"psi4_energies.npy"
    forces_path = pdb_folder/"psi4_forces.npy"

    if not (pdb_folder/"positions.npy").exists():
        return

    positions = np.load(str(pdb_folder/"positions.npy"))
    atomic_numbers = np.load(str(pdb_folder/"atomic_numbers.npy"))

    total_charge = np.load(str(pdb_folder/"charge.npy"))
    if not total_charge.shape == (1,):
        raise ValueError(f"total_charge.shape must be (1,), is: {total_charge.shape}")
    # Validate BEFORE casting: int() would silently truncate a non-integer charge.
    raw_charge = float(total_charge[0])
    if not np.isclose(raw_charge, round(raw_charge), atol=1e-5):
        raise ValueError(f"total_charge is no integer: {raw_charge}")
    total_charge = int(round(raw_charge))

    multiplicity = 1
    if (pdb_folder/"multiplicity.npy").exists():
        multiplicity = np.load(str(pdb_folder/"multiplicity.npy"))
        if not multiplicity.shape == (1,):
            raise ValueError(f"multiplicity.shape must be (1,), is: {multiplicity.shape}")
        multiplicity = int(multiplicity[0])

    # Load previous results if present, otherwise start with everything at nan.
    if energies_path.exists():
        psi4_energies = np.load(str(energies_path))
    else:
        psi4_energies = np.zeros_like(positions[:,0,0])*np.nan

    if forces_path.exists():
        psi4_forces = np.load(str(forces_path))
    else:
        psi4_forces = np.zeros_like(positions)*np.nan

    if np.all(np.isfinite(psi4_energies)) and np.all(np.isfinite(psi4_forces)):
        log(f"all states have been calculated for {pdb_folder.stem}")
        return

    # Pick a state: untouched (nan) states first, then states that were started
    # but never finished (inf).
    untouched = np.where(np.isnan(psi4_energies))[0]
    started = np.where(np.isinf(psi4_energies))[0]
    if untouched.size > 0:
        state_index = untouched[0]
    elif started.size > 0:
        state_index = started[0]
    else:
        raise ValueError(
            f"inconsistent data in {pdb_folder}: all energies are finite but some forces are not"
        )

    # The force entries of the picked state must be all nan or all inf, too.
    state_forces = psi4_forces[state_index]
    if not (np.all(np.isnan(state_forces)) or np.all(np.isinf(state_forces))):
        raise ValueError(
            f"inconsistent data in {pdb_folder}: forces of uncalculated state {state_index} are neither all nan nor all inf"
        )

    # Store inf at the state index to signal that the state is being calculated.
    psi4_energies[state_index] = np.inf
    psi4_forces[state_index] = np.ones_like(psi4_forces[state_index])*np.inf
    np.save(str(energies_path), psi4_energies)
    np.save(str(forces_path), psi4_forces)

    log(f"calculating state using the config\n")
    log(f"\tMETHOD: {METHOD}")
    log(f"\tBASIS: {BASIS}")
    log(f"\tMEMORY: {MEMORY}")
    log(f"\tNUM_THREADS: {NUM_THREADS}")
    log(f"\ttotal_charge: {total_charge}")
    log(f"\tmultiplicity: {multiplicity}")
    log(f"calculating state number {state_index}...")

    start = time()

    atoms = Atoms(numbers=atomic_numbers, positions=positions[state_index])

    # Set up the calculator. memory / num_threads are only passed when they are
    # set, because passing None explicitly would override the calculator's own
    # defaults. No convergence threshold is passed, so Psi4's defaults apply.
    calc_kwargs = {
        "atoms": atoms,
        "method": METHOD,
        "basis": BASIS,
        "charge": total_charge,
        "multiplicity": multiplicity,
    }
    if MEMORY is not None:
        calc_kwargs["memory"] = MEMORY
    if NUM_THREADS is not None:
        calc_kwargs["num_threads"] = NUM_THREADS

    atoms.set_calculator(Psi4(**calc_kwargs))

    energy = atoms.get_potential_energy(apply_constraint=False) # units: eV
    forces = atoms.get_forces(apply_constraint=False) # units: eV/Angstrom

    energy = energy * EV_IN_KCAL_PM
    forces = forces * EV_IN_KCAL_PM

    print(f"time elapsed: {round((time() - start)/60., 2)} min")

    # Load the energies and forces again (another process might have written to
    # them in the meantime) and only overwrite the entries of this state.
    psi4_energies = np.load(str(energies_path))
    psi4_forces = np.load(str(forces_path))

    psi4_energies[state_index] = energy
    psi4_forces[state_index] = forces

    np.save(str(energies_path), psi4_energies)
    np.save(str(forces_path), psi4_forces)
def has_uncalculated_states(pdb_folder):
    """
    Tell whether the given folder still contains states without a finished result.

    A state counts as uncalculated if its entry in psi4_energies.npy is nan or
    inf. If psi4_energies.npy does not exist at all, True is returned as well.
    """
    energies_file = Path(pdb_folder)/"psi4_energies.npy"

    if energies_file.exists():
        stored_energies = np.load(str(energies_file))
        # "not finite" is exactly "nan or inf"
        return np.any(~np.isfinite(stored_energies))

    return True
def calc_all_states(folder, skip_errs=False, memory=32, num_threads=8, permute_seed=None):
    """
    Call calc_state until no uncalculated states are left.

    If the given folder itself contains positions.npy, only this folder is
    processed. Otherwise every sub folder of the given folder is processed, in
    shuffled order. The sub folders are iterated over two times, since there
    might have been some errors in the first pass. Sub folders without
    positions.npy are skipped.

    Args:
        folder: folder containing positions.npy, or folder whose sub folders do.
        skip_errs: if True, an exception raised while calculating a sub folder
            is logged and the next sub folder is processed; if False it is
            re-raised. KeyboardInterrupt is always re-raised. Only applies to
            the sub folder mode.
        memory: forwarded to calc_state.
        num_threads: forwarded to calc_state.
        permute_seed: if not None, seeds the `random` module before shuffling
            the sub folders.
    """
    import random

    folder = Path(folder)

    if permute_seed is not None:
        random.seed(permute_seed)

    log = Logger(folder, print_to_screen=True)

    # If the folder itself contains the files, calculate them.
    if (folder/"positions.npy").exists():
        log(f"calculating states for {folder.stem}...")
        while has_uncalculated_states(folder):
            calc_state(folder, memory=memory, num_threads=num_threads)
        return

    num_passes = 2
    for pass_index in range(num_passes):
        pdb_folders = [f for f in folder.iterdir() if f.is_dir()]
        random.shuffle(pdb_folders)
        log(f"calculating states for {len(pdb_folders)} folders (pass {pass_index + 1} of {num_passes}).")

        for folder_index, pdb_folder in enumerate(pdb_folders):
            # Without positions.npy, calc_state returns immediately and never
            # creates psi4_energies.npy, so has_uncalculated_states would stay
            # True forever and the while loop below would never terminate.
            if not (pdb_folder/"positions.npy").exists():
                log(f"skipping {pdb_folder.stem}: no positions.npy found")
                continue

            log("")
            log(f"calculating states for {folder_index}, {pdb_folder.stem}...")
            # Repeat until there are no uncalculated states left in this folder.
            while has_uncalculated_states(pdb_folder):
                try:
                    calc_state(pdb_folder, memory=memory, num_threads=num_threads)
                except KeyboardInterrupt:
                    raise
                except Exception as e:
                    if not skip_errs:
                        raise
                    log(f"failed to calculate states for {folder_index} ({pdb_folder.stem}): {type(e)}\n: {e}")
                    # Give up on this folder for the current pass.
                    break

            log(f"finished calculating states for {folder_index}, {pdb_folder.stem}...")
if __name__ == "__main__":
    # Command line entry point: parse the options and forward them to calc_all_states.
    import argparse

    cli = argparse.ArgumentParser(description='Calculates states for a given folder.')

    cli.add_argument(
        'folder',
        type=str,
        help='The folder containing the PDB files.',
    )
    cli.add_argument(
        '--skip_errs', '-s',
        action='store_true',
        default=False,
        help='Skip errors.',
    )
    cli.add_argument(
        '--permute_seed', '-p',
        type=int,
        default=None,
        help='The seed to use for shuffling the folders.',
    )
    cli.add_argument(
        '--memory', '-m',
        type=int,
        default=32,
        help='The amount of memory to use.',
    )
    cli.add_argument(
        '--num_threads', '-t',
        type=int,
        default=4,
        help='The number of threads to use.',
    )

    options = cli.parse_args()

    calc_all_states(
        folder=options.folder,
        skip_errs=options.skip_errs,
        memory=options.memory,
        num_threads=options.num_threads,
        permute_seed=options.permute_seed,
    )