Skip to content

Commit

Permalink
alignment and quality of life improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Sep 30, 2024
1 parent e58b6e7 commit 5ad901f
Showing 1 changed file with 75 additions and 38 deletions.
113 changes: 75 additions & 38 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def preprocess(
robust_thresh: int = 2,
force_q_to_r_rotation_deg=None,
force_q_to_r_transpose=False,
dp_shift_method="subpixel",
):
"""
Preprocessing for nanobeam tomography
Expand Down Expand Up @@ -178,6 +179,8 @@ def preprocess(
robust_thresh: int
threshold for including points, in units of root-mean-square (standard deviations) error
of the predicted values after fitting.
dp_shift_method: float
method to shift diffraction patterns "subpixel" or "pixel"
"""
xp_storage = self._xp_storage
storage = self._storage
Expand Down Expand Up @@ -272,6 +275,7 @@ def preprocess(
qx0_fit=qx0_fit,
qy0_fit=qy0_fit,
q_max_inv_A=q_max_inv_A,
dp_shift_method=dp_shift_method,
)

return self
Expand Down Expand Up @@ -694,6 +698,7 @@ def _reshape_diffraction_patterns(
qx0_fit,
qy0_fit,
q_max_inv_A,
dp_shift_method,
):
"""
Reshapes diffraction data into a 2 column array
Expand All @@ -712,6 +717,8 @@ def _reshape_diffraction_patterns(
qy shifts
q_max_inv_A: int
maximum q in inverse angstroms
dp_shift_method: float
method to shift diffraction patterns "subpixel" or "pixel"
"""
# calculate bincount array
if datacube_number == 0:
Expand All @@ -722,6 +729,7 @@ def _reshape_diffraction_patterns(
qx0_fit=qx0_fit,
qy0_fit=qy0_fit,
ind_diffraction_ravel=self._ind_diffraction_rotate_transpose_ravel,
dp_shift_method=dp_shift_method,
)

del datacube
Expand Down Expand Up @@ -811,7 +819,12 @@ def _make_diffraction_masks(self, q_max_inv_A):
)

def _reshape_4D_array_to_2D(
self, data, qx0_fit=None, qy0_fit=None, ind_diffraction_ravel=None
self,
data,
qx0_fit=None,
qy0_fit=None,
ind_diffraction_ravel=None,
dp_shift_method="subpixel",
):
"""
reshape diffraction 4D-data to 2D ravelled patterns
Expand All @@ -826,6 +839,8 @@ def _reshape_4D_array_to_2D(
qy shifts
ind_diffraction: np.ndarray
1D array (length of number of pixels in diffraciton space to project 4D array into)
dp_shift_method: float
method to shift diffraction patterns "subpixel" or "pixel"
Returns
Expand All @@ -850,55 +865,66 @@ def _reshape_4D_array_to_2D(
qx0 = center[0] - qx0_fit[a0, a1]
qy0 = center[1] - qy0_fit[a0, a1]

xF = int(np.floor(qx0))
yF = int(np.floor(qy0))
if dp_shift_method == "subpixel":
xF = int(np.floor(qx0))
yF = int(np.floor(qy0))

wx = qx0 - xF
wy = qy0 - yF
wx = qx0 - xF
wy = qy0 - yF

dp_projected = (
(
dp_projected = (
(
(1 - wx)
(
(1 - wx)
* (1 - wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF, yF), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
)
)
+ (
(wx)
* (1 - wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF, yF), axis=(0, 1)).ravel(),
np.roll(dp, (xF + 1, yF), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
)
)
+ (
(wx)
* (1 - wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF + 1, yF), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
)
+ (
(1 - wx)
* (wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF, yF + 1), axis=(0, 1)).ravel(),
minlength=self._q_length,
+ (
(1 - wx)
* (wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF, yF + 1), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
)
)
+ (
(wx)
* (wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF + 1, yF + 1), axis=(0, 1)).ravel(),
minlength=self._q_length,
+ (
(wx)
* (wy)
* np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF + 1, yF + 1), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
)
)
)

diffraction_patterns_reshaped[index] = dp_projected
diffraction_patterns_reshaped[index] = dp_projected

elif dp_shift_method == "pixel":
xF = int(qx0)
yF = int(qy0)

dp_projected = np.bincount(
ind_diffraction_ravel,
np.roll(dp, (xF, yF), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
diffraction_patterns_reshaped[index] = dp_projected
else:
diffraction_patterns_reshaped[index] = np.bincount(
ind_diffraction_ravel,
Expand Down Expand Up @@ -1159,7 +1185,6 @@ def _forward(
device = self._device
obj = copy_to_device(self._object[x_index], device)

# TODO check sign
tilt = -xp.deg2rad(tilt_deg)

# solve for real space coordinates
Expand Down Expand Up @@ -1529,7 +1554,9 @@ def _constraints(self, zero_edges: bool, baseline_thresh: float):
self._object[:, ind_zero] = 0

if baseline_thresh is not None:
_, vmin, _ = return_scaled_histogram_ordering(self._object, vmin = baseline_thresh)
_, vmin, _ = return_scaled_histogram_ordering(
self._object, vmin=baseline_thresh
)
xp = self._xp_storage
self._object = xp.clip(self._object - vmin, 0, np.inf)

Expand Down Expand Up @@ -1657,6 +1684,16 @@ def object_6D(self):

return self._object.reshape(self._object_shape_6D)

def recovered_4D_scan(self, index):
"""recovered 4D-STEM scan from projected patterns"""

scan = self._reshape_2D_array_to_4D(
self._diffraction_patterns_projected[index],
positions=self._positions_vox_F[index],
)

return scan

#### Code for sims, To be removed later
# def _make_test_object(
# self,
Expand Down

0 comments on commit 5ad901f

Please sign in to comment.