Skip to content

Commit

Permalink
Merge pull request #616 from py4dstem/phase_contrast
Browse files Browse the repository at this point in the history
Quality of Life Improvements and Minor Bugfixes
  • Loading branch information
cophus authored Mar 14, 2024
2 parents 0ac074d + 1137191 commit b326bca
Show file tree
Hide file tree
Showing 17 changed files with 755 additions and 345 deletions.
10 changes: 8 additions & 2 deletions py4DSTEM/datacube/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,9 @@ def pad_Q(self, N=None, output_size=None):
d = pad_data_diffraction(self, pad_factor=N, output_size=output_size)
return d

def resample_Q(self, N=None, output_size=None, method="bilinear"):
def resample_Q(
self, N=None, output_size=None, method="bilinear", conserve_array_sums=False
):
"""
Resamples the data in diffraction space by resampling factor N, or to match output_size,
using either 'fourier' or 'bilinear' interpolation.
Expand All @@ -418,7 +420,11 @@ def resample_Q(self, N=None, output_size=None, method="bilinear"):
from py4DSTEM.preprocess import resample_data_diffraction

d = resample_data_diffraction(
self, resampling_factor=N, output_size=output_size, method=method
self,
resampling_factor=N,
output_size=output_size,
method=method,
conserve_array_sums=conserve_array_sums,
)
return d

Expand Down
41 changes: 34 additions & 7 deletions py4DSTEM/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,11 @@ def datacube_diffraction_shift(


def resample_data_diffraction(
datacube, resampling_factor=None, output_size=None, method="bilinear"
datacube,
resampling_factor=None,
output_size=None,
method="bilinear",
conserve_array_sums=False,
):
"""
Performs diffraction space resampling of data by resampling_factor or to match output_size.
Expand All @@ -594,7 +598,10 @@ def resample_data_diffraction(
old_size = datacube.data.shape

datacube.data = fourier_resample(
datacube.data, scale=resampling_factor, output_size=output_size
datacube.data,
scale=resampling_factor,
output_size=output_size,
conserve_array_sums=conserve_array_sums,
)

if not resampling_factor:
Expand All @@ -617,6 +624,10 @@ def resample_data_diffraction(
if resampling_factor.shape == ():
resampling_factor = np.tile(resampling_factor, 2)

output_size = np.round(
resampling_factor * np.array(datacube.shape[-2:])
).astype("int")

else:
if output_size is None:
raise ValueError(
Expand All @@ -630,12 +641,28 @@ def resample_data_diffraction(

resampling_factor = np.array(output_size) / np.array(datacube.shape[-2:])

resampling_factor = np.concatenate(((1, 1), resampling_factor))
datacube.data = zoom(
datacube.data, resampling_factor, order=1, mode="grid-wrap", grid_mode=True
)
output_data = np.zeros(datacube.Rshape + tuple(output_size))
for Rx, Ry in tqdmnd(
datacube.shape[0],
datacube.shape[1],
desc="Resampling 4D datacube",
unit="DP",
unit_scale=True,
):
output_data[Rx, Ry] = zoom(
datacube.data[Rx, Ry].astype(np.float32),
resampling_factor,
order=1,
mode="nearest",
grid_mode=True,
)

if conserve_array_sums:
output_data = output_data / resampling_factor.prod()

datacube.data = output_data
datacube.calibration.set_Q_pixel_size(
datacube.calibration.get_Q_pixel_size() / resampling_factor[2]
datacube.calibration.get_Q_pixel_size() / resampling_factor[0]
)

else:
Expand Down
2 changes: 2 additions & 0 deletions py4DSTEM/process/phase/dpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ def _get_constructor_args(cls, group):
"name": instance_md["name"],
"verbose": True, # for compatibility
"device": "cpu", # for compatibility
"storage": "cpu", # for compatibility
"clear_fft_cache": True, # for compatibility
}

return kwargs
Expand Down
119 changes: 53 additions & 66 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class MagneticPtychographicTomography(
initial_scan_positions: list of np.ndarray, optional
Probe positions in Å for each diffraction intensity per tilt
If None, initialized to a grid scan centered along tilt axis
positions_offset_ang: list of np.ndarray, optional
Offset of positions in A
verbose: bool, optional
If True, class methods will inherit this and print additional information
object_type: str, optional
Expand Down Expand Up @@ -155,6 +157,7 @@ def __init__(
initial_object_guess: np.ndarray = None,
initial_probe_guess: np.ndarray = None,
initial_scan_positions: Sequence[np.ndarray] = None,
positions_offset_ang: Sequence[np.ndarray] = None,
verbose: bool = True,
device: str = "cpu",
storage: str = None,
Expand Down Expand Up @@ -199,6 +202,7 @@ def __init__(
# Common Metadata
self._vacuum_probe_intensity = vacuum_probe_intensity
self._scan_positions = initial_scan_positions
self._positions_offset_ang = positions_offset_ang
self._energy = energy
self._semiangle_cutoff = semiangle_cutoff
self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
Expand All @@ -217,7 +221,7 @@ def __init__(
def preprocess(
self,
diffraction_intensities_shape: Tuple[int, int] = None,
reshaping_method: str = "fourier",
reshaping_method: str = "bilinear",
padded_diffraction_intensities_shape: Tuple[int, int] = None,
region_of_interest_shape: Tuple[int, int] = None,
dp_mask: np.ndarray = None,
Expand All @@ -227,12 +231,13 @@ def preprocess(
diffraction_patterns_rotate_degrees: float = None,
diffraction_patterns_transpose: bool = None,
force_com_shifts: Sequence[float] = None,
force_com_measured: Sequence[np.ndarray] = None,
vectorized_com_calculation: bool = True,
progress_bar: bool = True,
force_scan_sampling: float = None,
force_angular_sampling: float = None,
force_reciprocal_sampling: float = None,
object_fov_mask: np.ndarray = None,
object_fov_mask: np.ndarray = True,
crop_patterns: bool = False,
device: str = None,
clear_fft_cache: bool = None,
Expand Down Expand Up @@ -275,6 +280,8 @@ def preprocess(
Amplitudes come from diffraction patterns shifted with
the CoM in the upper left corner for each probe unless
shift is overwritten. One tuple per tilt.
force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured)
Force CoM measured shifts
vectorized_com_calculation: bool, optional
If True (default), the memory-intensive CoM calculation is vectorized
force_scan_sampling: float, optional
Expand Down Expand Up @@ -366,6 +373,7 @@ def preprocess(
(self._num_diffraction_patterns,) + roi_shape
)

self._amplitudes_shape = np.array(self._amplitudes.shape[-2:])
if region_of_interest_shape is not None:
self._resample_exit_waves = True
self._region_of_interest_shape = np.array(region_of_interest_shape)
Expand All @@ -377,9 +385,20 @@ def preprocess(
if force_com_shifts is None:
force_com_shifts = [None] * self._num_measurements

if force_com_measured is None:
force_com_measured = [None] * self._num_measurements

if self._positions_offset_ang is None:
self._positions_offset_ang = [None] * self._num_measurements

self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees)
self._rotation_best_transpose = diffraction_patterns_transpose

if progress_bar:
# turn off verbosity to play nice with tqdm
verbose = self._verbose
self._verbose = False

# loop over DPs for preprocessing
for index in tqdmnd(
self._num_measurements,
Expand All @@ -394,6 +413,7 @@ def preprocess(
self._vacuum_probe_intensity,
self._dp_mask,
force_com_shifts[index],
force_com_measured[index],
) = self._preprocess_datacube_and_vacuum_probe(
self._datacube[index],
diffraction_intensities_shape=self._diffraction_intensities_shape,
Expand All @@ -402,6 +422,7 @@ def preprocess(
vacuum_probe_intensity=self._vacuum_probe_intensity,
dp_mask=self._dp_mask,
com_shifts=force_com_shifts[index],
com_measured=force_com_measured[index],
)

else:
Expand All @@ -410,6 +431,7 @@ def preprocess(
_,
_,
force_com_shifts[index],
force_com_measured[index],
) = self._preprocess_datacube_and_vacuum_probe(
self._datacube[index],
diffraction_intensities_shape=self._diffraction_intensities_shape,
Expand All @@ -418,6 +440,7 @@ def preprocess(
vacuum_probe_intensity=None,
dp_mask=None,
com_shifts=force_com_shifts[index],
com_measured=force_com_measured[index],
)

# calibrations
Expand All @@ -443,6 +466,7 @@ def preprocess(
fit_function=fit_function,
com_shifts=force_com_shifts[index],
vectorized_calculation=vectorized_com_calculation,
com_measured=force_com_measured[index],
)

# corner-center amplitudes
Expand Down Expand Up @@ -484,8 +508,13 @@ def preprocess(
self._scan_positions[index],
self._positions_mask[index],
self._object_padding_px,
self._positions_offset_ang[index],
)

if progress_bar:
# reset verbosity
self._verbose = verbose

# handle semiangle specified in pixels
if self._semiangle_cutoff_pixels:
self._semiangle_cutoff = (
Expand Down Expand Up @@ -579,65 +608,16 @@ def preprocess(
self._slice_thicknesses,
)

# overlaps
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

if object_fov_mask is None:
probe_overlap_3D = xp.zeros_like(self._object[0])
old_rot_matrix = np.eye(3) # identity

for index in range(self._num_measurements):
idx_start = self._cum_probes_per_measurement[index]
idx_end = self._cum_probes_per_measurement[index + 1]

rot_matrix = self._tilt_orientation_matrices[index]

probe_overlap_3D = self._rotate_zxy_volume(
probe_overlap_3D,
rot_matrix @ old_rot_matrix.T,
)

probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32)

num_diffraction_patterns = idx_end - idx_start
shuffled_indices = np.arange(idx_start, idx_end)

for start, end in generate_batches(
num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
batch_indices = shuffled_indices[start:end]
positions_px = self._positions_px_all[batch_indices]
positions_px_fractional = positions_px - xp_storage.round(
positions_px
)

shifted_probes = fft_shift(
self._probes_all[index], positions_px_fractional, xp
)
probe_overlap += self._sum_overlapping_patches_bincounts(
xp.abs(shifted_probes) ** 2, positions_px
)

del shifted_probes

probe_overlap_3D += probe_overlap[None]
old_rot_matrix = rot_matrix

probe_overlap_3D = self._rotate_zxy_volume(
probe_overlap_3D,
old_rot_matrix.T,
)

gaussian_filter = self._scipy.ndimage.gaussian_filter
probe_overlap_3D_blurred = gaussian_filter(probe_overlap_3D, 1.0)
self._object_fov_mask = asnumpy(
probe_overlap_3D_blurred > 0.25 * probe_overlap_3D_blurred.max()
)

if object_fov_mask is not True:
raise NotImplementedError()
else:
self._object_fov_mask = np.asarray(object_fov_mask)
self._object_fov_mask = np.full(self._object_shape, True)
self._object_fov_mask_inverse = np.invert(self._object_fov_mask)

# plot probe overlaps
if plot_probe_overlaps:
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32)

Expand All @@ -656,13 +636,8 @@ def preprocess(
)

del shifted_probes
probe_overlap = asnumpy(probe_overlap)

self._object_fov_mask_inverse = np.invert(self._object_fov_mask)

probe_overlap = asnumpy(probe_overlap)

# plot probe overlaps
if plot_probe_overlaps:
figsize = kwargs.pop("figsize", (13, 4))
chroma_boost = kwargs.pop("chroma_boost", 1)
power = kwargs.pop("power", 2)
Expand Down Expand Up @@ -877,6 +852,7 @@ def reconstruct(
object_positivity: bool = True,
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
tv_denoise: bool = True,
tv_denoise_weights=None,
tv_denoise_inner_iter=40,
Expand Down Expand Up @@ -987,6 +963,11 @@ def reconstruct(
if True perform collective tilt updates
shrinkage_rad: float
Phase shift in radians to be subtracted from the potential at each iteration
fix_potential_baseline: bool
If true, the potential mean outside the FOV is forced to zero at each iteration
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
store_iterations: bool, optional
If True, reconstructed objects and probes are stored at each iteration
progress_bar: bool, optional
Expand Down Expand Up @@ -1070,6 +1051,11 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

# initialization
self._reset_reconstruction(store_iterations, reset, use_projection_scheme)

Expand Down Expand Up @@ -1179,6 +1165,7 @@ def reconstruct(
positions_px_fractional,
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
Loading

0 comments on commit b326bca

Please sign in to comment.