Skip to content

Commit

Permalink
Merge pull request #658 from py4dstem/phase_contrast
Browse files Browse the repository at this point in the history
The Tortured Phase Contrast Department
  • Loading branch information
gvarnavi authored May 24, 2024
2 parents 22f8859 + 87d44d5 commit a49e030
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 39 deletions.
41 changes: 36 additions & 5 deletions py4DSTEM/io/filereaders/read_arina.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def read_arina(
binfactor: int = 1,
dtype_bin: float = None,
flatfield: np.ndarray = None,
median_filter_masked_pixels_array: np.ndarray = None,
median_filter_masked_pixels_kernel: int = 4,
):
"""
File reader for arina 4D-STEM datasets
Expand Down Expand Up @@ -79,6 +81,8 @@ def read_arina(
array_3D,
binfactor,
correction_factors,
median_filter_masked_pixels_array,
median_filter_masked_pixels_kernel,
)

if f.__bool__():
Expand All @@ -95,23 +99,50 @@ def read_arina(
)
)

if median_filter_masked_pixels_array is not None and binfactor == 1:
datacube = datacube.median_filter_masked_pixels(
median_filter_masked_pixels_array, median_filter_masked_pixels_kernel
)

return datacube


def _processDataSet(dset, start_index, array_3D, binfactor, correction_factors):
def _processDataSet(
dset,
start_index,
array_3D,
binfactor,
correction_factors,
median_filter_masked_pixels_array,
median_filter_masked_pixels_kernel,
):
image_index = start_index
nimages_dset = dset.shape[0]

if median_filter_masked_pixels_array is not None and binfactor != 1:
from py4DSTEM.preprocess import median_filter_masked_pixels_2D

for i in range(nimages_dset):
if binfactor == 1:
array_3D[image_index] = np.multiply(
dset[i].astype(array_3D.dtype), correction_factors
)

else:
array_3D[image_index] = bin2D(
np.multiply(dset[i].astype(array_3D.dtype), correction_factors),
binfactor,
)
if median_filter_masked_pixels_array is not None:
array_3D[image_index] = bin2D(
median_filter_masked_pixels_2D(
np.multiply(dset[i].astype(array_3D.dtype), correction_factors),
median_filter_masked_pixels_array,
median_filter_masked_pixels_kernel,
),
binfactor,
)
else:
array_3D[image_index] = bin2D(
np.multiply(dset[i].astype(array_3D.dtype), correction_factors),
binfactor,
)

image_index = image_index + 1
return image_index
55 changes: 53 additions & 2 deletions py4DSTEM/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ def median_filter_masked_pixels(datacube, mask, kernel_width: int = 3):
width_min = kernel_width // 2

else:
width_max = int(kernel_width / 2 + 0.5)
width_min = int(kernel_width / 2 - 0.5)
width_max = int(np.ceil(kernel_width / 2))
width_min = int(np.floor(kernel_width / 2))

num_bad_pixels_indicies = np.array(np.where(mask))
for a0 in range(num_bad_pixels_indicies.shape[1]):
Expand Down Expand Up @@ -525,6 +525,57 @@ def median_filter_masked_pixels(datacube, mask, kernel_width: int = 3):
return datacube


def median_filter_masked_pixels_2D(array, mask, kernel_width: int = 3):
"""
Median filters a 2D array
Parameters
----------
array:
array to be filtered
mask:
a boolean mask that specifies the bad pixels in the datacube
kernel_width (optional):
specifies the width of the median kernel
Returns
----------
filtered datacube
"""
if kernel_width % 2 == 0:
width_max = kernel_width // 2
width_min = kernel_width // 2

else:
width_max = int(np.ceil(kernel_width / 2))
width_min = int(np.floor(kernel_width / 2))

num_bad_pixels_indicies = np.array(np.where(mask))
for a0 in range(num_bad_pixels_indicies.shape[1]):
index_x = num_bad_pixels_indicies[0, a0]
index_y = num_bad_pixels_indicies[1, a0]

x_min = index_x - width_min
y_min = index_y - width_min

x_max = index_x + width_max
y_max = index_y + width_max

if x_min < 0:
x_min = 0
if y_min < 0:
y_min = 0

if x_max > array.shape[0]:
x_max = array.shape[0]
if y_max > array.shape[1]:
y_max = array.shape[1]

array[index_x, index_y] = np.median(array[x_min:x_max, y_min:y_max])

return array


def datacube_diffraction_shift(
datacube,
xshifts,
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,9 +1064,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if gaussian_filter_sigma_m is None:
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,9 +1388,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if gaussian_filter_sigma_m is None:
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,9 +956,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

# main loop
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/mixedstate_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

# main loop
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,9 +932,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

# main loop
Expand Down
2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def reconstruct(
alignment_bin_values: list = None,
cross_correlation_upsample_factor: int = 8,
regularizer_matrix_size: Tuple[int, int] = (1, 1),
regularize_shifts: bool = True,
regularize_shifts: bool = False,
running_average: bool = True,
progress_bar: bool = True,
plot_aligned_bf: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/parameter_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def f(ptycho):
)
return np.log(ptycho.error) if converged else 0.0

elif error_metric == "log-linear":
elif error_metric == "linear-converged":

def f(ptycho):
converged = ptycho.error_iterations[-1] <= np.min(
Expand Down
19 changes: 11 additions & 8 deletions py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,14 +1578,16 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

fourier_overlap *= fourier_mask
if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)
fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap))

fourier_modified_overlap = (
fourier_modified_overlap - fourier_overlap
) * fourier_mask
fourier_modified_overlap = fourier_modified_overlap - fourier_overlap
if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

# resample back to region_of_interest_shape, note: this needs to happen in reciprocal-space
if self._resample_exit_waves:
Expand Down Expand Up @@ -2742,7 +2744,8 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

fourier_overlap *= fourier_mask
if fourier_mask is not None:
fourier_overlap *= fourier_mask
farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)

Expand All @@ -2751,9 +2754,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask

fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap

fourier_modified_overlap = (
fourier_modified_overlap - fourier_overlap
) * fourier_mask
fourier_modified_overlap = fourier_modified_overlap - fourier_overlap
if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

# resample back to region_of_interest_shape, note: this needs to happen in reciprocal-space
if self._resample_exit_waves:
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,9 +971,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

# main loop
Expand Down
4 changes: 1 addition & 3 deletions py4DSTEM/process/phase/singleslice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,9 +821,7 @@ def reconstruct(
else:
max_batch_size = self._num_diffraction_patterns

if detector_fourier_mask is None:
detector_fourier_mask = xp.ones(self._amplitudes[0].shape)
else:
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

# main loop
Expand Down
4 changes: 4 additions & 0 deletions py4DSTEM/process/polar/polar_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,10 @@ def plot_radial_peaks(
minlength=q_num,
)

# storing arrays for further plotting
self.q_bins = q_bins
self.int_peaks = int_peaks

# plotting
fig, ax = plt.subplots(figsize=figsize)
ax.plot(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"ncempy >= 1.8.1",
"matplotlib >= 3.2.2",
"scikit-image >= 0.17.2",
"scikit-learn >= 0.23.2",
"scikit-learn >= 0.23.2, < 1.5",
"scikit-optimize >= 0.9.0",
"tqdm >= 4.46.1",
"dill >= 0.3.3",
Expand Down

0 comments on commit a49e030

Please sign in to comment.