Skip to content

Commit

Permalink
pr suggestions 1
Browse files Browse the repository at this point in the history
  • Loading branch information
shoubhikraj committed Nov 27, 2023
1 parent 35d5e4d commit 396bdc8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
11 changes: 3 additions & 8 deletions autode/opt/coordinates/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,10 @@ def __init__(self, *args: Any):
f"from {args}. Must be primitive internals"
)

def append(self, item) -> None:
def append(self, item: Primitive) -> None:
"""Append an item to this set of primitives"""
if isinstance(item, Primitive):
super().append(item)
else:
raise TypeError(
f"Can only append Primitive type but"
f" {type(item)} was provided"
)
assert isinstance(item, Primitive), "Must be a Primitive type!"
super().append(item)

@property
def B(self) -> np.ndarray:
Expand Down
8 changes: 4 additions & 4 deletions autode/opt/coordinates/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _cross_vec3(


def _get_vars_from_atom_idxs(
*args,
*args: int,
x: "CartesianCoordinates",
deriv_order: int,
) -> List["VectorHyperDual"]:
Expand All @@ -65,13 +65,13 @@ def _get_vars_from_atom_idxs(
Returns:
(list[VectorHyperDual]): A list of differentiable variables
"""
assert all(isinstance(atom, int) and atom >= 0 for atom in args)
assert all(isinstance(idx, int) and idx >= 0 for idx in args)
# get positions in the flat Cartesian array
_x = x.ravel()
cart_idxs = []
for atom in args:
for atom_idx in args:
for k in range(3):
cart_idxs.append(3 * atom + k)
cart_idxs.append(3 * atom_idx + k)
return get_differentiable_vars(
values=[_x[idx] for idx in cart_idxs],
symbols=[str(idx) for idx in cart_idxs],
Expand Down

0 comments on commit 396bdc8

Please sign in to comment.