Skip to content

Commit

Permalink
more plotting, minor tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 8, 2023
1 parent 0abd641 commit 2efda6b
Showing 1 changed file with 79 additions and 21 deletions.
100 changes: 79 additions & 21 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from py4DSTEM.process.phase.utils import AffineTransform
from py4DSTEM.process.utils.cross_correlate import align_images_fourier
from py4DSTEM.process.utils.utils import electron_wavelength_angstrom
from py4DSTEM.visualize import show
from py4DSTEM.visualize import return_scaled_histogram_ordering, show
from scipy.linalg import polar
from scipy.ndimage import distance_transform_edt
from scipy.optimize import minimize
Expand Down Expand Up @@ -1302,8 +1302,8 @@ def subpixel_alignment(
pixel_output_shape = np.round(BF_size * self._kde_upsample_factor).astype("int")

# shifted coordinates
x = xp.arange(BF_size[0], dtype="float")
y = xp.arange(BF_size[1], dtype="float")
x = xp.arange(BF_size[0], dtype=xp.float32)
y = xp.arange(BF_size[1], dtype=xp.float32)
xa_init, ya_init = xp.meshgrid(x, y, indexing="ij")

# kernel density output the upsampled BF image
Expand Down Expand Up @@ -1468,7 +1468,6 @@ def subpixel_alignment(
)

update = fixed_step_scores < scores

self._probe_dx[update] += dx[update]
self._probe_dy[update] += dy[update]

Expand Down Expand Up @@ -1631,16 +1630,24 @@ def subpixel_alignment(
-0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
]

nx, ny = self._recon_BF_subpixel_aligned.shape
kx = xp.fft.fftfreq(nx, d=1)
ky = xp.fft.fftfreq(ny, d=1)
k = xp.fft.fftshift(xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2))

recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF)))
sx, sy = recon_fft.shape

pad_x_post = (nx - sx) // 2
pad_x_pre = nx - sx - pad_x_post
pad_y_post = (ny - sy) // 2
pad_y_pre = ny - sy - pad_y_post

pad_x = np.round(
BF_size[0] * (self._kde_upsample_factor - 1) / 2
).astype("int")
pad_y = np.round(
BF_size[1] * (self._kde_upsample_factor - 1) / 2
).astype("int")
pad_recon_fft = asnumpy(
xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y)))
xp.pad(
recon_fft, ((pad_x_pre, pad_x_post), (pad_y_pre, pad_y_post))
)
* k
)

if ncols == 3:
Expand All @@ -1649,54 +1656,63 @@ def subpixel_alignment(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(recon_BF_subpixel_aligned_reference))
)
* k
)

upsampled_fft = asnumpy(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
)
* k
)
axs = [ax1, ax2, ax3]
else:
upsampled_fft_reference = asnumpy(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
)
* k
)
axs = [ax1, ax2]

show(
_, vmin, vmax = return_scaled_histogram_ordering(
upsampled_fft_reference
)

axs[0].imshow(
pad_recon_fft,
figax=(fig, axs[0]),
extent=reciprocal_extent,
vmin=vmin,
vmax=vmax,
cmap="gray",
title="Aligned Bright Field FFT",
**kwargs,
)
axs[0].set_title("Aligned Bright Field FFT")

show(
axs[1].imshow(
upsampled_fft_reference,
figax=(fig, axs[1]),
extent=reciprocal_extent,
vmin=vmin,
vmax=vmax,
cmap="gray",
title="Upsampled Bright Field FFT",
**kwargs,
)
axs[1].set_title("Upsampled Bright Field FFT")

if ncols == 3:
show(
axs[2].imshow(
upsampled_fft,
figax=(fig, axs[2]),
extent=reciprocal_extent,
vmin=vmin,
vmax=vmax,
cmap="gray",
title="Probe-Corrected Bright Field FFT",
**kwargs,
)
axs[2].set_title("Probe-Corrected Bright Field FFT")

for ax in axs:
ax.set_ylabel(r"$k_x$ [$A^{-1}$]")
ax.set_xlabel(r"$k_y$ [$A^{-1}$]")
ax.xaxis.set_ticks_position("bottom")

row_index += 1

Expand Down Expand Up @@ -3039,6 +3055,48 @@ def show_shifts(

fig.tight_layout()

def show_probe_position_shifts(
self,
**kwargs,
):
"""
Utility function to visualize probe-position shifts.
"""
probe_dx = self._crop_padded_object(self._probe_dx)
probe_dy = self._crop_padded_object(self._probe_dy)
max_shift = np.abs(np.dstack((probe_dx, probe_dy))).max()

figsize = kwargs.pop("figsize", (9, 4))
vmin = kwargs.pop("vmin", -max_shift)
vmax = kwargs.pop("vmax", max_shift)
cmap = kwargs.pop("cmap", "PuOr")

extent = [
0,
self._scan_sampling[1] * probe_dx.shape[1],
self._scan_sampling[0] * probe_dx.shape[0],
0,
]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
im1 = ax1.imshow(probe_dx, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap)
im2 = ax2.imshow(probe_dy, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap)

for ax, im in zip([ax1, ax2], [im1, im2]):
divider = make_axes_locatable(ax)
ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
fig.add_axes(ax_cb)
cb = fig.colorbar(im, cax=ax_cb)
cb.set_label("pix", rotation=0, ha="center", va="bottom")
cb.ax.yaxis.set_label_coords(0.5, 1.01)
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")

ax1.set_title("Probe Position Vertical Shifts")
ax2.set_title("Probe Position Horizontal Shifts")

fig.tight_layout()

def visualize(
self,
**kwargs,
Expand Down

0 comments on commit 2efda6b

Please sign in to comment.