Skip to content

Commit

Permalink
whoops
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Jan 27, 2025
1 parent 1ab7c20 commit e05ac7a
Showing 1 changed file with 38 additions and 7 deletions.
45 changes: 38 additions & 7 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def _solve_for_indicies(
# solve for real space coordinates
y = np.arange(s[1])
z = np.arange(s[2])
yy, zz = np.meshgrid(y, z)
yy, zz = np.meshgrid(y, z, indexing="ij")
sin = np.sin(tilt)
cos = np.cos(tilt)
r = [[cos, sin], [-sin, cos]]
Expand Down Expand Up @@ -1083,6 +1083,28 @@ def _solve_for_indicies(
"clip",
)

# normalization real space
bincount_real_max = s[0] * s[1] * s[2]

ind_real_bincount_weight = np.bincount(
ind_real.ravel(), weights_real.ravel(), minlength=bincount_real_max
)
ind_real_bincount = np.bincount(ind_real.ravel(), minlength=bincount_real_max)

ind_real_bincount_weight = ind_real_bincount_weight[ind_real_bincount > 0]
ind_real_bincount = ind_real_bincount[ind_real_bincount > 0]

ind_real_bincount_weight[ind_real_bincount_weight == 0] = 1

correction_factor_real = 1 / ind_real_bincount_weight

correction_factor_real = np.repeat(correction_factor_real, ind_real_bincount)
sorted_indicies = np.argsort(np.argsort(ind_real.ravel()))
correction_factor_real = correction_factor_real[sorted_indicies].reshape(
ind_real.shape
)
weights_real = weights_real * correction_factor_real

if datacube_number == 0:
self._ind_real = []
self._weights_real = []
Expand Down Expand Up @@ -1327,21 +1349,27 @@ def _forward(
obj = copy_to_device(self._object[x_index], device)

ind_real = self._ind_real[datacube_number].reshape((4, s[1], s[2]))
ind_diff = self._ind_diff[datacube_number]
ind_diff = self._ind_diff[datacube_number].reshape((4, s[-1], s[-1]))
weights_real = self._weights_real[datacube_number].reshape((4, s[1], s[2]))
weights_diff = self._weights_diff[datacube_number]
weights_diff = self._weights_diff[datacube_number].reshape((4, s[-1], s[-1]))

xp = np

obj_q_summed = (obj[:, ind_diff] * weights_diff).sum((1))
bincount_diff = (
xp.tile(
(xp.tile(self._ind_diffraction_ravel, 4)),
self._ind_diffraction_ravel,
(s[1] * s[2]),
)
+ xp.repeat(xp.arange(s[1] * s[2]), ind_diff.shape[0]) * self._q_length
+ xp.repeat(
xp.arange(s[1] * s[2]), obj_q_summed.shape[1] * obj_q_summed.shape[2]
)
* self._q_length
)

obj_q_summed = xp.bincount(
bincount_diff,
(obj[:, ind_diff] * weights_diff[None, :]).ravel(),
obj_q_summed.ravel(),
minlength=s[1] * s[2] * self._q_length,
).reshape((-1, self._q_length))[:, self._circular_mask_bincount]

Expand Down Expand Up @@ -1531,7 +1559,10 @@ def _back(
minlength=((diff_max) * s[1]),
).reshape((s[1], -1))[:, ind_diff_bincount > 0]

update_q_summed = xp.tile(update_q_summed, (s[2] * 4, 1)) / (s[2])
# update_q_summed = xp.tile(update_q_summed, (s[2] * 4, 1)) / (s[2])
update_q_summed = xp.tile(xp.repeat(update_q_summed, s[2], axis=0), (4, 1)) / (
s[2]
)

diff_shape_bin = update_q_summed.shape[-1]

Expand Down

0 comments on commit e05ac7a

Please sign in to comment.