From 2efda6bc4e3a4c118b5720ceef71b848fd5b8606 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 7 Dec 2023 23:56:53 -0800 Subject: [PATCH] more plotting, minor tweaks --- py4DSTEM/process/phase/iterative_parallax.py | 100 +++++++++++++++---- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index dbb459ac0..5641a8d56 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -18,7 +18,7 @@ from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from py4DSTEM.visualize import show +from py4DSTEM.visualize import return_scaled_histogram_ordering, show from scipy.linalg import polar from scipy.ndimage import distance_transform_edt from scipy.optimize import minimize @@ -1302,8 +1302,8 @@ def subpixel_alignment( pixel_output_shape = np.round(BF_size * self._kde_upsample_factor).astype("int") # shifted coordinates - x = xp.arange(BF_size[0], dtype="float") - y = xp.arange(BF_size[1], dtype="float") + x = xp.arange(BF_size[0], dtype=xp.float32) + y = xp.arange(BF_size[1], dtype=xp.float32) xa_init, ya_init = xp.meshgrid(x, y, indexing="ij") # kernel density output the upsampled BF image @@ -1468,7 +1468,6 @@ def subpixel_alignment( ) update = fixed_step_scores < scores - self._probe_dx[update] += dx[update] self._probe_dy[update] += dy[update] @@ -1631,16 +1630,24 @@ def subpixel_alignment( -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), ] + nx, ny = self._recon_BF_subpixel_aligned.shape + kx = xp.fft.fftfreq(nx, d=1) + ky = xp.fft.fftfreq(ny, d=1) + k = xp.fft.fftshift(xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2)) + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + sx, sy = recon_fft.shape + + pad_x_post = (nx - sx) // 2 + pad_x_pre = nx - sx - pad_x_post + pad_y_post = (ny - sy) // 2 + pad_y_pre = ny - sy - pad_y_post - pad_x = np.round( - BF_size[0] * (self._kde_upsample_factor - 1) / 2 - ).astype("int") - pad_y = np.round( - BF_size[1] * (self._kde_upsample_factor - 1) / 2 - ).astype("int") pad_recon_fft = asnumpy( - xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + xp.pad( + recon_fft, ((pad_x_pre, pad_x_post), (pad_y_pre, pad_y_post)) + ) + * k ) if ncols == 3: @@ -1649,12 +1656,14 @@ def subpixel_alignment( xp.fft.fftshift( xp.abs(xp.fft.fft2(recon_BF_subpixel_aligned_reference)) ) + * k ) upsampled_fft = asnumpy( xp.fft.fftshift( xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) ) + * k ) axs = [ax1, ax2, ax3] else: @@ -1662,41 +1671,48 @@ def subpixel_alignment( xp.fft.fftshift( xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) ) + * k ) axs = [ax1, ax2] - show( + _, vmin, vmax = return_scaled_histogram_ordering( + upsampled_fft_reference + ) + + axs[0].imshow( pad_recon_fft, - figax=(fig, axs[0]), extent=reciprocal_extent, + vmin=vmin, + vmax=vmax, cmap="gray", - title="Aligned Bright Field FFT", **kwargs, ) + axs[0].set_title("Aligned Bright Field FFT") - show( + axs[1].imshow( upsampled_fft_reference, - figax=(fig, axs[1]), extent=reciprocal_extent, + vmin=vmin, + vmax=vmax, cmap="gray", - title="Upsampled Bright Field FFT", **kwargs, ) + axs[1].set_title("Upsampled Bright Field FFT") if ncols == 3: - show( + axs[2].imshow( upsampled_fft, - figax=(fig, axs[2]), extent=reciprocal_extent, + vmin=vmin, + vmax=vmax, cmap="gray", - title="Probe-Corrected Bright Field FFT", **kwargs, ) + axs[2].set_title("Probe-Corrected Bright Field FFT") for ax in axs: ax.set_ylabel(r"$k_x$ [$A^{-1}$]") ax.set_xlabel(r"$k_y$ [$A^{-1}$]") - ax.xaxis.set_ticks_position("bottom") row_index += 1 @@ -3039,6 +3055,48 @@ def show_shifts( fig.tight_layout() + def show_probe_position_shifts( + self, + **kwargs, + ): + """ + Utility function to visualize probe-position shifts. + """ + probe_dx = self._crop_padded_object(self._probe_dx) + probe_dy = self._crop_padded_object(self._probe_dy) + max_shift = np.abs(np.dstack((probe_dx, probe_dy))).max() + + figsize = kwargs.pop("figsize", (9, 4)) + vmin = kwargs.pop("vmin", -max_shift) + vmax = kwargs.pop("vmax", max_shift) + cmap = kwargs.pop("cmap", "PuOr") + + extent = [ + 0, + self._scan_sampling[1] * probe_dx.shape[1], + self._scan_sampling[0] * probe_dx.shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + im1 = ax1.imshow(probe_dx, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap) + im2 = ax2.imshow(probe_dy, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap) + + for ax, im in zip([ax1, ax2], [im1, im2]): + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + cb = fig.colorbar(im, cax=ax_cb) + cb.set_label("pix", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + ax1.set_title("Probe Position Vertical Shifts") + ax2.set_title("Probe Position Horizontal Shifts") + + fig.tight_layout() + def visualize( self, **kwargs,