diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 0967a1b46..dbb459ac0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -10,6 +10,7 @@ import numpy as np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec +from matplotlib.ticker import PercentFormatter from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar @@ -1320,6 +1321,8 @@ def subpixel_alignment( # Perform probe position correction if needed if position_correction_num_iter is not None: + recon_BF_subpixel_aligned_reference = pix_output.copy() + # init position shift array self._probe_dx = xp.zeros_like(xa_init) self._probe_dy = xp.zeros_like(xa_init) @@ -1523,66 +1526,113 @@ def subpixel_alignment( ) position_correction_stats[a0 + 1] = scores.mean() - - if plot_position_correction_convergence: - fig, ax = plt.subplots(figsize=(8, 2)) - ax.plot( - np.arange(position_correction_num_iter + 1), - position_correction_stats, - color=(1, 0, 0), - ) - ax.set_xlabel("iterations") - ax.set_ylabel("position error") + else: + plot_position_correction_convergence = False self._recon_BF_subpixel_aligned = pix_output self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) # plotting - if plot_upsampled_BF_comparison: - if plot_upsampled_FFT_comparison: - figsize = kwargs.pop("figsize", (8, 8)) - fig, axs = plt.subplots(2, 2, figsize=figsize) - else: - figsize = kwargs.pop("figsize", (8, 4)) - fig, axs = plt.subplots(1, 2, figsize=figsize) + nrows = np.count_nonzero( + np.array( + [ + plot_upsampled_BF_comparison, + plot_upsampled_FFT_comparison, + plot_position_correction_convergence, + ] + ) + ) + if nrows > 0: + ncols = 3 if position_correction_num_iter is not None else 2 + height_ratios = ( + [4, 4, 2][-nrows:] + if plot_position_correction_convergence + else [4, 4, 2][:nrows] + ) + spec = GridSpec( + ncols=ncols, nrows=nrows, height_ratios=height_ratios, hspace=0.15 + ) - axs = axs.flat + figsize = kwargs.pop("figsize", (4 * ncols, sum(height_ratios))) cmap = kwargs.pop("cmap", "magma") + fig = plt.figure(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_BF) - cropped_object_aligned = self._crop_padded_object( - self._recon_BF_subpixel_aligned, upsampled=True - ) + row_index = 0 - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + if plot_upsampled_BF_comparison: + ax1 = fig.add_subplot(spec[row_index, 0]) + ax2 = fig.add_subplot(spec[row_index, 1]) - axs[0].imshow( - cropped_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - axs[0].set_title("Aligned Bright Field") + cropped_object = self._crop_padded_object(self._recon_BF) - axs[1].imshow( - cropped_object_aligned, - extent=extent, - cmap=cmap, - **kwargs, - ) - axs[1].set_title("Upsampled Bright Field") + if ncols == 3: + ax3 = fig.add_subplot(spec[row_index, 2]) - for ax in axs[:2]: - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + cropped_object_reference_aligned = self._crop_padded_object( + recon_BF_subpixel_aligned_reference, upsampled=True + ) + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + axs = [ax1, ax2, ax3] + + else: + cropped_object_reference_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + axs = [ax1, ax2] + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_reference_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + if ncols == 3: + axs[2].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[2].set_title("Probe-Corrected Bright Field") + + for ax in axs: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + row_index += 1 if plot_upsampled_FFT_comparison: + ax1 = fig.add_subplot(spec[row_index, 0]) + ax2 = fig.add_subplot(spec[row_index, 1]) + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = np.round( BF_size[0] * (self._kde_upsample_factor - 1) / 2 ).astype("int") @@ -1593,22 +1643,31 @@ def subpixel_alignment( xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) ) - upsampled_fft = asnumpy( - xp.fft.fftshift( - xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + if ncols == 3: + ax3 = fig.add_subplot(spec[row_index, 2]) + upsampled_fft_reference = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(recon_BF_subpixel_aligned_reference)) + ) ) - ) - reciprocal_extent = [ - -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), - 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), - 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), - -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), - ] + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + axs = [ax1, ax2, ax3] + else: + upsampled_fft_reference = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + axs = [ax1, ax2] show( pad_recon_fft, - figax=(fig, axs[2]), + figax=(fig, axs[0]), extent=reciprocal_extent, cmap="gray", title="Aligned Bright Field FFT", @@ -1616,20 +1675,50 @@ def subpixel_alignment( ) show( - upsampled_fft, - figax=(fig, axs[3]), + upsampled_fft_reference, + figax=(fig, axs[1]), extent=reciprocal_extent, cmap="gray", title="Upsampled Bright Field FFT", **kwargs, ) - for ax in axs[2:]: + if ncols == 3: + show( + upsampled_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Probe-Corrected Bright Field FFT", + **kwargs, + ) + + for ax in axs: ax.set_ylabel(r"$k_x$ [$A^{-1}$]") ax.set_xlabel(r"$k_y$ [$A^{-1}$]") ax.xaxis.set_ticks_position("bottom") - fig.tight_layout() + row_index += 1 + + if plot_position_correction_convergence: + axs = fig.add_subplot(spec[row_index, :]) + + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + color = kwargs.pop("color", (1, 0, 0)) + + axs.semilogy( + np.arange(position_correction_num_iter + 1), + position_correction_stats / position_correction_stats[0], + color=color, + **kwargs, + ) + axs.set_xlabel("Iteration number") + axs.set_ylabel("NMSE") + axs.yaxis.set_major_formatter(PercentFormatter(1.0, decimals=0)) + axs.yaxis.set_minor_formatter(PercentFormatter(1.0, decimals=0)) + + spec.tight_layout(fig) def _bilinearly_sample_array( self,