Skip to content

Commit

Permalink
black formatting, updated variable name
Browse files Browse the repository at this point in the history
  • Loading branch information
cophus committed Jul 18, 2024
1 parent 9ada2d8 commit 8899487
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 96 deletions.
4 changes: 3 additions & 1 deletion py4DSTEM/process/diffraction/WK_scattering_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ def RI1(BI, BJ, G):
ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ))

sub = np.logical_and(eps <= 0.1, G > 0.0)
temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ))
temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(
BJ / (BI + BJ)
)
temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2
temp -= 0.5 * (BI - BJ) ** 2
ri1[sub] += np.pi * G[sub] ** 2 * temp
Expand Down
1 change: 0 additions & 1 deletion py4DSTEM/process/diffraction/crystal_ACOM.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,6 @@ def match_single_pattern(
sub = dqr < self.orientation_kernel_size

if np.any(sub):

im_polar[ind_radial, :] = np.sum(
np.power(
np.maximum(intensity[sub, None], 0.0),
Expand Down
68 changes: 35 additions & 33 deletions py4DSTEM/process/diffraction/crystal_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern


class Crystal_Phase:
class CrystalPhase:
"""
A class storing multiple crystal structures, and associated diffraction data.
Must be initialized after matching orientations to a pointlistarray???
Expand Down Expand Up @@ -194,8 +194,8 @@ def quantify_single_pattern(
"""

# tolerance
tol2 = 1e-6
# tolerance for separating the origin peak.
tolerance_origin_2 = 1e-6

# calibrations
center = pointlistarray.calstate["center"]
Expand Down Expand Up @@ -231,11 +231,15 @@ def quantify_single_pattern(
)
# bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy()
if k_max is None:
keep = bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2
keep = (
bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2
> tolerance_origin_2
)
else:
keep = np.logical_and.reduce(
(
bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2 > tol2,
bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2
> tolerance_origin_2,
np.abs(bragg_peaks.data["qx"]) < k_max,
np.abs(bragg_peaks.data["qy"]) < k_max,
)
Expand Down Expand Up @@ -303,14 +307,14 @@ def quantify_single_pattern(
if k_max is None:
del_peak = (
bragg_peaks_fit.data["qx"] ** 2 + bragg_peaks_fit.data["qy"] ** 2
< tol2
< tolerance_origin_2
)
else:
del_peak = np.logical_or.reduce(
(
bragg_peaks_fit.data["qx"] ** 2
+ bragg_peaks_fit.data["qy"] ** 2
< tol2,
< tolerance_origin_2,
np.abs(bragg_peaks_fit.data["qx"]) > k_max,
np.abs(bragg_peaks_fit.data["qy"]) > k_max,
)
Expand Down Expand Up @@ -473,7 +477,6 @@ def quantify_single_pattern(
search = True

while search is True:

basis_solve = basis[:, inds_solve]
obs_solve = obs.copy()

Expand Down Expand Up @@ -685,7 +688,6 @@ def quantify_single_pattern(
crystal_inds_plot == None
or np.min(np.abs(c - crystal_inds_plot)) == 0
):

qx_fit = library_peaks[a0].data["qx"]
qy_fit = library_peaks[a0].data["qy"]

Expand All @@ -701,7 +703,6 @@ def quantify_single_pattern(
matches_fit = library_matches[a0]

if plot_only_nonzero_phases is False or phase_weights[a0] > 0:

# if np.mod(m,2) == 0:
ax.scatter(
qy_fit[matches_fit],
Expand Down Expand Up @@ -877,27 +878,30 @@ def quantify_phase(
disable=not progress_bar,
):
# calculate phase weights
phase_weights, phase_residual, phase_reliability, int_peaks = (
self.quantify_single_pattern(
pointlistarray=pointlistarray,
xy_position=(rx, ry),
corr_kernel_size=corr_kernel_size,
sigma_excitation_error=sigma_excitation_error,
precession_angle_degrees=precession_angle_degrees,
power_intensity=power_intensity,
power_intensity_experiment=power_intensity_experiment,
k_max=k_max,
max_number_patterns=max_number_patterns,
single_phase=single_phase,
allow_strain=allow_strain,
strain_iterations=strain_iterations,
strain_max=strain_max,
include_false_positives=include_false_positives,
weight_false_positives=weight_false_positives,
plot_result=False,
verbose=False,
returnfig=False,
)
(
phase_weights,
phase_residual,
phase_reliability,
int_peaks,
) = self.quantify_single_pattern(
pointlistarray=pointlistarray,
xy_position=(rx, ry),
corr_kernel_size=corr_kernel_size,
sigma_excitation_error=sigma_excitation_error,
precession_angle_degrees=precession_angle_degrees,
power_intensity=power_intensity,
power_intensity_experiment=power_intensity_experiment,
k_max=k_max,
max_number_patterns=max_number_patterns,
single_phase=single_phase,
allow_strain=allow_strain,
strain_iterations=strain_iterations,
strain_max=strain_max,
include_false_positives=include_false_positives,
weight_false_positives=weight_false_positives,
plot_result=False,
verbose=False,
returnfig=False,
)
self.phase_weights[rx, ry] = phase_weights
self.phase_residuals[rx, ry] = phase_residual
Expand Down Expand Up @@ -1178,7 +1182,6 @@ def plot_phase_maps(
)

for a0 in range(self.num_crystals):

ax[a0].imshow(
im_rgb_all[a0],
)
Expand Down Expand Up @@ -1315,7 +1318,6 @@ def plot_dominant_phase(
)

else:

# find the second correlation score for each crystal and match index
for a0 in range(self.num_crystals):
corr = phase_sig[a0].copy()
Expand Down
3 changes: 2 additions & 1 deletion py4DSTEM/process/diffraction/crystal_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def plot_scattering_intensity(
int_sf_plot = calc_1D_profile(
k,
self.g_vec_leng,
(self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale),
(self.struct_factors_int**int_power_scale)
* (self.g_vec_leng**k_power_scale),
remove_origin=True,
k_broadening=k_broadening,
int_scale=int_scale,
Expand Down
3 changes: 2 additions & 1 deletion py4DSTEM/process/fit/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def polar_gaussian_2D(
# t2 = np.min(np.vstack([t,1-t]))
t2 = np.square(t - mu_t)
return (
I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C
I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2)))
+ C
)


Expand Down
28 changes: 14 additions & 14 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,20 +1196,20 @@ def reconstruct(

# position correction
if not fix_positions and a0 > 0:
self._positions_px_all[batch_indices] = (
self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
28 changes: 14 additions & 14 deletions py4DSTEM/process/phase/magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,20 +1497,20 @@ def reconstruct(

# position correction
if not fix_positions and a0 > 0:
self._positions_px_all[batch_indices] = (
self._position_correction(
self._object,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
self._object,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
24 changes: 12 additions & 12 deletions py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,16 +2202,16 @@ def score_CTF(coefs):
measured_shifts_sx = xp.zeros(
self._region_of_interest_shape, dtype=xp.float32
)
measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = (
self._xy_shifts_Ang[:, 0]
)
measured_shifts_sx[
self._xy_inds[:, 0], self._xy_inds[:, 1]
] = self._xy_shifts_Ang[:, 0]

measured_shifts_sy = xp.zeros(
self._region_of_interest_shape, dtype=xp.float32
)
measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = (
self._xy_shifts_Ang[:, 1]
)
measured_shifts_sy[
self._xy_inds[:, 0], self._xy_inds[:, 1]
] = self._xy_shifts_Ang[:, 1]

fitted_shifts = (
xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1)
Expand All @@ -2222,16 +2222,16 @@ def score_CTF(coefs):
fitted_shifts_sx = xp.zeros(
self._region_of_interest_shape, dtype=xp.float32
)
fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = (
fitted_shifts[:, 0]
)
fitted_shifts_sx[
self._xy_inds[:, 0], self._xy_inds[:, 1]
] = fitted_shifts[:, 0]

fitted_shifts_sy = xp.zeros(
self._region_of_interest_shape, dtype=xp.float32
)
fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = (
fitted_shifts[:, 1]
)
fitted_shifts_sy[
self._xy_inds[:, 0], self._xy_inds[:, 1]
] = fitted_shifts[:, 1]

max_shift = xp.max(
xp.array(
Expand Down
4 changes: 3 additions & 1 deletion py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ def _precompute_propagator_arrays(
propagators[i] = xp.exp(
1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz))
propagators[i] *= xp.exp(
1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
)

if theta_x is not None:
propagators[i] *= xp.exp(
Expand Down
28 changes: 14 additions & 14 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,20 +1088,20 @@ def reconstruct(

# position correction
if not fix_positions:
self._positions_px_all[batch_indices] = (
self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
4 changes: 3 additions & 1 deletion py4DSTEM/process/phase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def evaluate_gaussian_envelope(
self, alpha: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
xp = self._xp
return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2)
return xp.exp(
-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2
)

def evaluate_spatial_envelope(
self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]
Expand Down
18 changes: 15 additions & 3 deletions py4DSTEM/process/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def electron_wavelength_angstrom(E_eV):
c = 299792458
h = 6.62607 * 10**-34

lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10
lam = (
h
/ ma.sqrt(2 * m * e * E_eV)
/ ma.sqrt(1 + e * E_eV / 2 / m / c**2)
* 10**10
)
return lam


Expand All @@ -102,8 +107,15 @@ def electron_interaction_parameter(E_eV):
e = 1.602177 * 10**-19
c = 299792458
h = 6.62607 * 10**-34
lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10
sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV)
lam = (
h
/ ma.sqrt(2 * m * e * E_eV)
/ ma.sqrt(1 + e * E_eV / 2 / m / c**2)
* 10**10
)
sigma = (
(2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV)
)
return sigma


Expand Down

0 comments on commit 8899487

Please sign in to comment.