Friedel origin finder, various FEM tools (#607)
* bug fix

* Revert "bug fix"

This reverts commit cbce25c.

* num_probes_fit_aberrations

* flip 180 option for read arina

* ptycho positions offset

* remove 180 flip

* median_filter_masked_pixels

* save polar aberrations instead

* read-write compatibility

* uncertainty viz bug

* adding force_com_measured functionality to ptycho

* clean up force_com_measured

* adding DP normalization progress bar

* moving fov_mask calc to as needed

* adding detector_fourier_mask

* Auto center finding works!

* Adding center finding without a mask

* Adding plot range to show_origin_fit, fixing bug

* Adding local mean / variance functions for FEM

* Adding symmetry analysis and plotting

* Fix divide by zero in correlation origin finding function

* bug fix for filtering

* adding initial commit for friedel correlation origin finder

* Correlation working, but mask still buggy

* Fixing the masked CC

* Simplifying the expression

* cleaning up - still need subpixel shifts

* Update origin fitting visualization

* adding device arg to upsample function

* First attempt to add Friedel origin finding to ptycho

GPU not yet working

* Adding GPU implementation warning

* parabolic subpixel fitting

* minor updates

* Going back to dev version of phase contrast

* Changing np to xp for GPU compatibility

* Fixing xp = np device options

* Reverting phase contrast options back to dev

* Cleaning up code, fixing GPU support

* black formatting

* black updates

* Adding annular symmetry plotting function

* black formatting

* cleaning typos and dead code and such

* Update py4DSTEM/process/polar/polar_analysis.py

Co-authored-by: Steve Zeltmann <[email protected]>

---------

Co-authored-by: Stephanie Ribet <[email protected]>
Co-authored-by: Georgios Varnavides <[email protected]>
Co-authored-by: gvarnavi <[email protected]>
Co-authored-by: SE Zeltmann <[email protected]>
Co-authored-by: Steve Zeltmann <[email protected]>
6 people authored Feb 26, 2024
1 parent a1e3662 commit 2eae7d8
Showing 27 changed files with 701 additions and 213 deletions.
8 changes: 5 additions & 3 deletions py4DSTEM/braggvectors/diskdetection_cuda.py
@@ -156,9 +156,11 @@ def find_Bragg_disks_CUDA(
             patt_idx = batch_idx * batch_size + subbatch_idx
             rx, ry = np.unravel_index(patt_idx, (datacube.R_Nx, datacube.R_Ny))
             batched_subcube[subbatch_idx, :, :] = cp.array(
-                datacube.data[rx, ry, :, :]
-                if filter_function is None
-                else filter_function(datacube.data[rx, ry, :, :]),
+                (
+                    datacube.data[rx, ry, :, :]
+                    if filter_function is None
+                    else filter_function(datacube.data[rx, ry, :, :])
+                ),
                 dtype=cp.float32,
             )

20 changes: 18 additions & 2 deletions py4DSTEM/datacube/datacube.py
@@ -479,13 +479,29 @@ def filter_hot_pixels(self, thresh, ind_compare=1, return_mask=False):
         """
         from py4DSTEM.preprocess import filter_hot_pixels

-        d = filter_hot_pixels(
+        datacube = filter_hot_pixels(
             self,
             thresh,
             ind_compare,
             return_mask,
         )
-        return d
+        return datacube
+
+    def median_filter_masked_pixels(self, mask, kernel_width: int = 3):
+        """
+        This function fixes a datacube where the same pixels are consistently
+        bad. It requires a mask that identifies all the bad pixels in the dataset.
+        Then for each diffraction pattern, a median kernel is applied around each
+        bad pixel with the specified width.
+        """
+        from py4DSTEM.preprocess import median_filter_masked_pixels
+
+        datacube = median_filter_masked_pixels(
+            self,
+            mask,
+            kernel_width,
+        )
+        return datacube

     # Probe

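For reference, a minimal usage sketch of the new DataCube.median_filter_masked_pixels method. The file name and the thresholding criterion used to build the mask are illustrative assumptions, not part of this commit:

    import numpy as np
    import py4DSTEM

    # load a 4D-STEM dataset (hypothetical file)
    datacube = py4DSTEM.import_file("scan.mib")

    # flag consistently-bad detector pixels, e.g. by thresholding the mean
    # diffraction pattern (illustrative criterion only)
    dp_mean = datacube.data.mean(axis=(0, 1))
    mask = dp_mean > 10 * np.median(dp_mean)

    # replace each masked pixel in every pattern with the median of its
    # 3x3 neighborhood
    datacube = datacube.median_filter_masked_pixels(mask, kernel_width=3)
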
2 changes: 1 addition & 1 deletion py4DSTEM/io/importfile.py
@@ -75,7 +75,7 @@ def import_file(
         "gatan_K2_bin",
         "mib",
         "arina",
-        "abTEM"
+        "abTEM",
         # "kitware_counted",
     ], "Error: filetype not recognized"

2 changes: 1 addition & 1 deletion py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py
@@ -361,7 +361,7 @@ def Array_to_h5(array, group):
     data = grp.create_dataset(
         "data",
         shape=array.data.shape,
-        data=array.data
+        data=array.data,
         # dtype = type(array.data)
     )
     data.attrs.create(

55 changes: 55 additions & 0 deletions py4DSTEM/preprocess/preprocess.py
@@ -470,6 +470,61 @@ def filter_hot_pixels(datacube, thresh, ind_compare=1, return_mask=False):
     return datacube


+def median_filter_masked_pixels(datacube, mask, kernel_width: int = 3):
+    """
+    This function fixes a datacube where the same pixels are consistently
+    bad. It requires a mask that identifies all the bad pixels in the dataset.
+    Then for each diffraction pattern, a median kernel is applied around each
+    bad pixel with the specified width.
+
+    Parameters
+    ----------
+    datacube:
+        Datacube to be filtered
+    mask:
+        a boolean mask that specifies the bad pixels in the datacube
+    kernel_width (optional):
+        specifies the width of the median kernel
+
+    Returns
+    ----------
+    filtered datacube
+    """
+    if kernel_width % 2 == 0:
+        width_max = kernel_width // 2
+        width_min = kernel_width // 2
+
+    else:
+        width_max = int(kernel_width / 2 + 0.5)
+        width_min = int(kernel_width / 2 - 0.5)
+
+    num_bad_pixels_indicies = np.array(np.where(mask))
+    for a0 in range(num_bad_pixels_indicies.shape[1]):
+        index_x = num_bad_pixels_indicies[0, a0]
+        index_y = num_bad_pixels_indicies[1, a0]
+
+        x_min = index_x - width_min
+        y_min = index_y - width_min
+
+        x_max = index_x + width_max
+        y_max = index_y + width_max
+
+        if x_min < 0:
+            x_min = 0
+        if y_min < 0:
+            y_min = 0
+
+        if x_max > datacube.Qshape[0]:
+            x_max = datacube.Qshape[0]
+        if y_max > datacube.Qshape[1]:
+            y_max = datacube.Qshape[1]
+
+        datacube.data[:, :, index_x, index_y] = np.median(
+            datacube.data[:, :, x_min:x_max, y_min:y_max], axis=(2, 3)
+        )
+    return datacube
+
+
 def datacube_diffraction_shift(
     datacube,
     xshifts,

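As a sanity check on the windowing logic above: for the default kernel_width = 3 (the odd branch), width_min = 1 and width_max = 2, so the slice x_min:x_max spans exactly the 3 pixels centered on the bad pixel. A self-contained sketch of the same logic on toy arrays (not the py4DSTEM API):

    import numpy as np

    # toy 4D data: 2x2 scan, 8x8 detector, one consistently-hot pixel at (3, 4)
    data = np.random.default_rng(0).poisson(10.0, size=(2, 2, 8, 8)).astype(float)
    data[:, :, 3, 4] = 1e6
    mask = np.zeros((8, 8), dtype=bool)
    mask[3, 4] = True

    kernel_width = 3
    width_max = int(kernel_width / 2 + 0.5)  # -> 2
    width_min = int(kernel_width / 2 - 0.5)  # -> 1

    for ix, iy in zip(*np.where(mask)):
        x0, x1 = max(ix - width_min, 0), min(ix + width_max, data.shape[2])
        y0, y1 = max(iy - width_min, 0), min(iy + width_max, data.shape[3])
        # the median window includes the bad pixel itself, as in the function
        # above; harmless when bad pixels are isolated
        data[:, :, ix, iy] = np.median(data[:, :, x0:x1, y0:y1], axis=(2, 3))

    assert data[:, :, 3, 4].max() < 1e3  # hot pixel replaced by local medians
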
165 changes: 121 additions & 44 deletions py4DSTEM/process/calibration/origin.py
@@ -2,14 +2,27 @@

 import functools
 import numpy as np
-import matplotlib.pyplot as plt
 from scipy.ndimage import gaussian_filter
 from scipy.optimize import leastsq
+import matplotlib.pyplot as plt

 from emdfile import tqdmnd, PointListArray
 from py4DSTEM.datacube import DataCube
 from py4DSTEM.process.calibration.probe import get_probe_size
 from py4DSTEM.process.fit import plane, parabola, bezier_two, fit_2D
-from py4DSTEM.process.utils import get_CoM, add_to_2D_array_from_floats, get_maxima_2D
+from py4DSTEM.process.utils import (
+    get_CoM,
+    add_to_2D_array_from_floats,
+    get_maxima_2D,
+    upsampled_correlation,
+)
+from py4DSTEM.process.phase.utils import copy_to_device
+
+try:
+    import cupy as cp
+except (ImportError, ModuleNotFoundError):
+    cp = np


 #
@@ -309,58 +322,122 @@ def get_origin(
     return qx0, qy0, mask


-def get_origin_single_dp_beamstop(DP: np.ndarray, mask: np.ndarray, **kwargs):
+def get_origin_friedel(
+    datacube: DataCube,
+    mask=None,
+    upsample_factor=1,
+    device="cpu",
+    return_cpu=True,
+):
     """
-    Find the origin for a single diffraction pattern, assuming there is a beam stop.
-
-    Args:
-        DP (np array): diffraction pattern
-        mask (np array): boolean mask which is False under the beamstop and True
-            in the diffraction pattern. One approach to generating this mask
-            is to apply a suitable threshold on the average diffraction pattern
-            and use binary opening/closing to remove and holes
-
-    Returns:
-        qx0, qy0 (tuple) measured center position of diffraction pattern
+    Fit the origin for each diffraction pattern, with or without a beam stop.
+    The method we have developed here is a heavily modified version of masked
+    cross correlation, where we use Friedel symmetry of the diffraction pattern
+    to find the common center.
+
+    More details about the correlation step can be found in:
+    https://doi.org/10.1109/TIP.2011.2181402
+
+    Parameters
+    ----------
+    datacube: (DataCube)
+        The 4D dataset.
+    mask: (np array, optional)
+        Boolean mask which is False under the beamstop and True
+        in the diffraction pattern. One approach to generating this mask
+        is to apply a suitable threshold on the average diffraction pattern
+        and use binary opening/closing to remove any holes.
+        If no mask is provided, this method will likely not work with a beamstop.
+    upsample_factor: (int)
+        Upsample factor for subpixel fitting of the image shifts.
+    device: string
+        'cpu' or 'gpu' to select device
+    return_cpu: bool
+        Return arrays on cpu.
+
+    Returns
+    -------
+    qx0, qy0
+        (tuple of np arrays) measured center position of each diffraction pattern
     """

-    imCorr = np.real(
-        np.fft.ifft2(
-            np.fft.fft2(DP * mask)
-            * np.conj(np.fft.fft2(np.rot90(DP, 2) * np.rot90(mask, 2)))
-        )
-    )
-
-    xp, yp = np.unravel_index(np.argmax(imCorr), imCorr.shape)
-
-    dx = ((xp + DP.shape[0] / 2) % DP.shape[0]) - DP.shape[0] / 2
-    dy = ((yp + DP.shape[1] / 2) % DP.shape[1]) - DP.shape[1] / 2
-
-    return (DP.shape[0] + dx) / 2, (DP.shape[1] + dy) / 2
-
-
-def get_origin_beamstop(datacube: DataCube, mask: np.ndarray, **kwargs):
-    """
-    Find the origin for each diffraction pattern, assuming there is a beam stop.
-
-    Args:
-        datacube (DataCube)
-        mask (np array): boolean mask which is False under the beamstop and True
-            in the diffraction pattern. One approach to generating this mask
-            is to apply a suitable threshold on the average diffraction pattern
-            and use binary opening/closing to remove any holes
-
-    Returns:
-        qx0, qy0 (tuple of np arrays) measured center position of each diffraction pattern
-    """
-
-    qx0 = np.zeros(datacube.data.shape[:2])
-    qy0 = np.zeros_like(qx0)
-
-    for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny):
-        x, y = get_origin_single_dp_beamstop(datacube.data[rx, ry, :, :], mask)
-
-        qx0[rx, ry] = x
-        qy0[rx, ry] = y
-
-    return qx0, qy0
+    # Select device
+    if device == "cpu":
+        xp = np
+    elif device == "gpu":
+        xp = cp
+
+    # init measurement arrays
+    qx0 = xp.zeros(datacube.data.shape[:2])
+    qy0 = xp.zeros_like(qx0)
+
+    # pad the mask
+    if mask is not None:
+        mask = xp.asarray(mask).astype("float")
+        mask_pad = xp.pad(
+            mask,
+            ((0, datacube.data.shape[2]), (0, datacube.data.shape[3])),
+            constant_values=(1.0, 1.0),
+        )
+        M = xp.fft.fft2(mask_pad)
+
+    # main loop over all probe positions
+    for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny):
+        if mask is None:
+            # pad image
+            im_xp = xp.asarray(datacube.data[rx, ry])
+            im = xp.pad(
+                im_xp,
+                ((0, datacube.data.shape[2]), (0, datacube.data.shape[3])),
+            )
+            G = xp.fft.fft2(im)
+
+            # Cross correlation of masked image with its inverse
+            cc = xp.real(xp.fft.ifft2(G**2))
+
+        else:
+            im_xp = xp.asarray(datacube.data[rx, ry, :, :])
+            im = xp.pad(
+                im_xp,
+                ((0, datacube.data.shape[2]), (0, datacube.data.shape[3])),
+            )
+
+            # Masked cross correlation of masked image with its inverse
+            term1 = xp.real(xp.fft.ifft2(xp.fft.fft2(im) ** 2) * xp.fft.ifft2(M**2))
+            term2 = xp.real(xp.fft.ifft2(xp.fft.fft2(im**2) * M))
+            term3 = xp.real(xp.fft.ifft2(xp.fft.fft2(im * mask_pad)))
+            cc = (term1 - term3) / (term2 - term3)
+
+        # get correlation peak
+        x, y = xp.unravel_index(xp.argmax(cc), im.shape)
+
+        # half pixel upsampling / parabola subpixel fitting
+        dx = (cc[x + 1, y] - cc[x - 1, y]) / (
+            4.0 * cc[x, y] - 2.0 * cc[x + 1, y] - 2.0 * cc[x - 1, y]
+        )
+        dy = (cc[x, y + 1] - cc[x, y - 1]) / (
+            4.0 * cc[x, y] - 2.0 * cc[x, y + 1] - 2.0 * cc[x, y - 1]
+        )
+        x = x.astype("float") + dx
+        y = y.astype("float") + dy
+
+        # upsample peak if needed
+        if upsample_factor > 1:
+            x, y = upsampled_correlation(
+                xp.fft.fft2(cc),
+                upsampleFactor=upsample_factor,
+                xyShift=xp.array((x, y)),
+                device=device,
+            )
+
+        # Correlation peak, moved to image center shift
+        qx0[rx, ry] = (x / 2) % datacube.data.shape[2]
+        qy0[rx, ry] = (y / 2) % datacube.data.shape[3]
+
+    if return_cpu:
+        return copy_to_device(qx0), copy_to_device(qy0)
+    else:
+        return qx0, qy0
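
The core of get_origin_friedel is compact enough to verify in isolation. For a real pattern, the Fourier transform of its 180-degree rotation is the complex conjugate of G = fft2(im), so cross-correlating a pattern with its inversion reduces to ifft2(G**2); by Friedel symmetry the correlation peak then sits at twice the center position, which is why the loop above ends with (x / 2) % shape. The masked branch normalizes the same correlation with the Padfield-style terms term1 - term3. Below is a hedged, self-contained sketch of the unmasked branch, including the same three-point parabolic subpixel refinement; the function name and synthetic test are illustrative, not the py4DSTEM API:

    import numpy as np

    def friedel_center(dp):
        # zero-pad to twice the detector size so the correlation does not wrap
        nx, ny = dp.shape
        im = np.pad(dp, ((0, nx), (0, ny)))

        # Friedel correlation: correlating the real pattern with its inversion
        # is, in Fourier space, just ifft2 of G squared
        cc = np.real(np.fft.ifft2(np.fft.fft2(im) ** 2))

        # integer correlation peak (assumed away from the border for brevity)
        x, y = np.unravel_index(np.argmax(cc), cc.shape)

        # three-point parabolic refinement: vertex offset of the parabola
        # through cc[x-1], cc[x], cc[x+1] -- the same formula as the loop above
        dx = (cc[x + 1, y] - cc[x - 1, y]) / (
            4.0 * cc[x, y] - 2.0 * cc[x + 1, y] - 2.0 * cc[x - 1, y]
        )
        dy = (cc[x, y + 1] - cc[x, y - 1]) / (
            4.0 * cc[x, y] - 2.0 * cc[x, y + 1] - 2.0 * cc[x, y - 1]
        )

        # the correlation peak sits at twice the pattern center
        return ((x + dx) / 2) % nx, ((y + dy) / 2) % ny

    # quick self-test on a synthetic pattern with known center (20, 37)
    xx, yy = np.meshgrid(np.arange(64), np.arange(64), indexing="ij")
    dp = np.exp(-((xx - 20.0) ** 2 + (yy - 37.0) ** 2) / 50.0)
    print(friedel_center(dp))  # ~ (20.0, 37.0)

Assuming get_origin_friedel is re-exported from py4DSTEM.process.calibration like its siblings, a real call would look something like: qx0, qy0 = get_origin_friedel(datacube, mask=beamstop_mask, upsample_factor=4, device="gpu").
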
4 changes: 1 addition & 3 deletions py4DSTEM/process/diffraction/WK_scattering_factors.py
@@ -221,9 +221,7 @@ def RI1(BI, BJ, G):
     ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ))

     sub = np.logical_and(eps <= 0.1, G > 0.0)
-    temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(
-        BJ / (BI + BJ)
-    )
+    temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ))
     temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2
     temp -= 0.5 * (BI - BJ) ** 2
     ri1[sub] += np.pi * G[sub] ** 2 * temp

3 changes: 1 addition & 2 deletions py4DSTEM/process/diffraction/crystal_viz.py
@@ -454,8 +454,7 @@ def plot_scattering_intensity(
         int_sf_plot = calc_1D_profile(
             k,
             self.g_vec_leng,
-            (self.struct_factors_int**int_power_scale)
-            * (self.g_vec_leng**k_power_scale),
+            (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale),
             remove_origin=True,
             k_broadening=k_broadening,
             int_scale=int_scale,

3 changes: 1 addition & 2 deletions py4DSTEM/process/fit/fit.py
@@ -169,8 +169,7 @@ def polar_gaussian_2D(
     # t2 = np.min(np.vstack([t,1-t]))
     t2 = np.square(t - mu_t)
     return (
-        I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2)))
-        + C
+        I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C
     )