From 5180039499cd5d636c7ba6030f26954ab2f6dade Mon Sep 17 00:00:00 2001 From: smribet Date: Mon, 13 Jan 2025 13:44:03 -0800 Subject: [PATCH] trying to fix normalization error --- py4DSTEM/tomography/tomography.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/tomography/tomography.py b/py4DSTEM/tomography/tomography.py index 59925821a..1e67a901b 100644 --- a/py4DSTEM/tomography/tomography.py +++ b/py4DSTEM/tomography/tomography.py @@ -1261,20 +1261,33 @@ def _forward( ) # normalization real space - ind_real_bincount_weight = np.bincount( + ind_real_bincount_weight = xp.bincount( ind_real.ravel(), weights_real.ravel(), minlength=ind_real.max() ) - ind_real_bincount = np.bincount(ind_real.ravel(), minlength=ind_real.max()) + ind_real_bincount = xp.bincount(ind_real.ravel(), minlength=ind_real.max()) ind_real_bincount_weight = ind_real_bincount_weight[ind_real_bincount > 0] ind_real_bincount = ind_real_bincount[ind_real_bincount > 0] ind_real_bincount_weight[ind_real_bincount_weight == 0] = 1 - correction_factor = 1 / ind_real_bincount_weight - correction_factor = np.repeat(correction_factor, ind_real_bincount) - sorted_indicies = np.argsort(np.argsort(ind_real.ravel())) - correction_factor = correction_factor[sorted_indicies].reshape(ind_real.shape) - weights_real = weights_real * correction_factor + correction_factor_real = 1 / ind_real_bincount_weight + correction_factor_real = xp.repeat(correction_factor_real, ind_real_bincount) + sorted_indicies = xp.argsort(xp.argsort(ind_real.ravel())) + correction_factor_real = correction_factor_real[sorted_indicies].reshape(ind_real.shape) + weights_real = weights_real * correction_factor_real # normalization reciprocal space + # ind_diff_bincount_weight = xp.bincount( + # ind_diff.ravel(), weights_diff.ravel(), minlength=ind_diff.max() + # ) + # ind_diff_bincount = xp.bincount(ind_diff.ravel(), minlength=ind_diff.max()) + # ind_diff_bincount_weight = ind_diff_bincount_weight[ind_diff_bincount > 0] + # ind_diff_bincount = ind_diff_bincount[ind_diff_bincount > 0] + # ind_diff_bincount_weight[ind_diff_bincount_weight < 1 ] = 1 + # correction_factor_diff = 1 / ind_diff_bincount_weight + # correction_factor_diff = xp.repeat(correction_factor_diff, ind_diff_bincount) + # sorted_indicies = xp.argsort(xp.argsort(ind_diff.ravel())) + # correction_factor_diff = correction_factor_diff[sorted_indicies].reshape(ind_diff.shape) + # weights_diff = weights_diff * correction_factor_diff + # project bincount_x = (