Skip to content

Commit

Permalink
mpire windows fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 18, 2025
1 parent fc21a18 commit e82f3d0
Showing 1 changed file with 120 additions and 48 deletions.
168 changes: 120 additions & 48 deletions py4DSTEM/process/phase/direct_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namely single-sideband and Wigner-distribution deconvolution.
"""

import itertools
import platform
import warnings
from typing import Mapping, Sequence, Union

Expand Down Expand Up @@ -1474,15 +1474,14 @@ def __init__(self, *args, **kwargs):
def _reconstruct_single_frequency(
self,
shared_objects,
ind_x,
ind_y,
inds,
intensities_FFT,
phase_compensation=True,
virtual_detector_masks=None,
xp=np,
):
""" """
(
input_array,
output_array,
Qx_array,
Qy_array,
Expand All @@ -1497,10 +1496,10 @@ def _reconstruct_single_frequency(
sx, sy = Qx_array.shape

# 2 stride is for complex values
ind_x, ind_y = inds
ind_real = ind_x * sy * 2 + ind_y * 2 + 0
ind_imag = ind_x * sy * 2 + ind_y * 2 + 1

intensities_FFT = input_array[ind_x, ind_y]
Qx = Qx_array[ind_x, ind_y]
Qy = Qy_array[ind_x, ind_y]

Expand Down Expand Up @@ -1591,10 +1590,9 @@ def reconstruct(
phase_compensation: bool, optional
If True, the measured phase is compensated using a complex virtual detector. Recommended.
num_jobs: int, optional
Number of processes to use. Default is None, which spawns as many processes as CPUs on
the system.
threads_per_job: int, optional
Number of threads to use to avoid over-subscribing when using multiple processors.
Number of processes to use. Default is None, which has the following behavior:
- For Windows systems, it uses one process, since multiprocessing requires extra deps.
- For Unix systems, it uses as many processes as CPUs on the system.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model trotters,
to allow comparison with arbitrary geometry detector datasets. TO-DO
Expand Down Expand Up @@ -1639,15 +1637,31 @@ def reconstruct(
if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks).astype(xp.bool_)

# check if windows
is_windows = platform.system() == "Windows"
if is_windows and num_jobs is None:
num_jobs = 1

# set output_arrays
if num_jobs == 1:
output_array = xp.empty(sx * sy * 2, dtype=xp.float32)
else:
from multiprocessing import Array as mp_Array
if is_windows:
try:
from multiprocess import RawArray
except (ImportError, ModuleNotFoundError) as exc:
raise Exception(
(
'On Windows, num_jobs>1 requires the additional "multiprocess" package. '
"Please install it and try again. Note this is probably inadvisable."
)
) from exc
else:
from multiprocessing import RawArray

output_array = mp_Array("f", sx * sy * 2, lock=False)
output_array = RawArray("f", sx * sy * 2)

shared_objects = (
self._intensities_FFT,
output_array,
Qx_array,
Qy_array,
Expand All @@ -1666,20 +1680,20 @@ def reconstruct(
sx,
sy,
desc="Reconstructing object",
unit="freq.",
disable=not progress_bar,
):
self._reconstruct_single_frequency(
shared_objects,
ind_x,
ind_y,
(ind_x, ind_y),
self._intensities_FFT[ind_x, ind_y],
phase_compensation=phase_compensation,
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)

psi = output_array.view(xp.complex64).reshape((sx, sy))
else:

if self._device == "gpu":
raise NotImplementedError()

Expand All @@ -1695,12 +1709,20 @@ def wrapper_function(*args):
xp=xp,
)

def generator_over_real_space_dimensions(data):
    """Yield ``(index, value)`` pairs over the first two (real-space) axes of *data*."""
    real_space_shape = data.shape[:2]
    for index_pair in np.ndindex(real_space_shape):
        yield index_pair, data[index_pair]

data_generator = generator_over_real_space_dimensions(self._intensities_FFT)

with WorkerPool(
n_jobs=num_jobs, shared_objects=shared_objects, use_dill=True
n_jobs=num_jobs,
shared_objects=shared_objects,
use_dill=is_windows,
) as pool:
pool.map(
wrapper_function,
itertools.product(range(sx), range(sy)),
data_generator,
iterable_len=sx * sy,
n_splits=num_jobs,
progress_bar=progress_bar,
Expand Down Expand Up @@ -1732,14 +1754,13 @@ def __init__(self, *args, **kwargs):
def _reconstruct_single_frequency(
self,
shared_objects,
ind_x,
ind_y,
inds,
intensities_FFT,
virtual_detector_masks: Sequence[np.ndarray] = None,
xp=np,
):
""" """
(
input_array,
output_array,
Qx_array,
Qy_array,
Expand All @@ -1754,10 +1775,10 @@ def _reconstruct_single_frequency(
sx, sy = Qx_array.shape

# 2 stride is for complex values
ind_x, ind_y = inds
ind_real = ind_x * sy * 2 + ind_y * 2 + 0
ind_imag = ind_x * sy * 2 + ind_y * 2 + 1

intensities_FFT = input_array[ind_x, ind_y]
Qx = Qx_array[ind_x, ind_y]
Qy = Qy_array[ind_x, ind_y]

Expand Down Expand Up @@ -1821,8 +1842,9 @@ def reconstruct(
Parameters
--------
num_jobs: int, optional
Number of processes to use. Default is None, which spawns as many processes as CPUs on
the system.
Number of processes to use. Default is None, which has the following behavior:
- For Windows systems, it uses one process, since multiprocessing requires extra deps.
- For Unix systems, it uses as many processes as CPUs on the system.
threads_per_job: int, optional
Number of threads to use to avoid over-subscribing when using multiple processors.
virtual_detector_masks: np.ndarray
Expand Down Expand Up @@ -1874,15 +1896,31 @@ def reconstruct(
probe_normalization, virtual_detector_masks, in_place=True
)

# check if windows
is_windows = platform.system() == "Windows"
if is_windows and num_jobs is None:
num_jobs = 1

# set output_arrays
if num_jobs == 1:
output_array = xp.empty(sx * sy * 2, dtype=xp.float32)
else:
from multiprocessing import Array as mp_Array
if is_windows:
try:
from multiprocess import RawArray
except (ImportError, ModuleNotFoundError) as exc:
raise Exception(
(
'On Windows, num_jobs>1 requires the additional "multiprocess" package. '
"Please install it and try again. Note this is probably inadvisable."
)
) from exc
else:
from multiprocessing import RawArray

output_array = mp_Array("f", sx * sy * 2, lock=False)
output_array = RawArray("f", sx * sy * 2)

shared_objects = (
self._intensities_FFT,
output_array,
Qx_array,
Qy_array,
Expand All @@ -1901,19 +1939,19 @@ def reconstruct(
sx,
sy,
desc="Reconstructing object",
unit="freq.",
disable=not progress_bar,
):
self._reconstruct_single_frequency(
shared_objects,
ind_x,
ind_y,
(ind_x, ind_y),
self._intensities_FFT[ind_x, ind_y],
virtual_detector_masks=virtual_detector_masks,
xp=xp,
)

psi = output_array.view(xp.complex64).reshape((sx, sy))
else:

if self._device == "gpu":
raise NotImplementedError()

Expand All @@ -1928,19 +1966,28 @@ def wrapper_function(*args):
xp=xp,
)

def generator_over_real_space_dimensions(data):
    """Iterate lazily over the two leading (real-space) axes, yielding ``(index, slice)``."""
    for scan_index in np.ndindex(*data.shape[:2]):
        yield scan_index, data[scan_index]

data_generator = generator_over_real_space_dimensions(self._intensities_FFT)

with WorkerPool(
n_jobs=num_jobs, shared_objects=shared_objects, use_dill=True
n_jobs=num_jobs,
shared_objects=shared_objects,
use_dill=is_windows,
) as pool:
pool.map(
wrapper_function,
itertools.product(range(sx), range(sy)),
data_generator,
iterable_len=sx * sy,
n_splits=num_jobs,
progress_bar=progress_bar,
)
psi = xp.frombuffer(output_array, dtype=xp.complex64).reshape((sx, sy))

self._object = xp.fft.ifft2(psi) / self._mean_diffraction_intensity

# no idea why this is necessary..
self._object = (2 - xp.abs(self._object)) * xp.exp(1j * xp.angle(self._object))

Expand All @@ -1961,13 +2008,12 @@ def __init__(self, *args, **kwargs):
def _reconstruct_single_frequency(
self,
shared_objects,
ind_x,
ind_y,
inds,
intensities_FFT,
xp=np,
):
""" """
(
input_array,
output_array,
Qx_array,
Qy_array,
Expand All @@ -1981,10 +2027,10 @@ def _reconstruct_single_frequency(
sx, sy = Qx_array.shape

# 2 stride is for complex values
ind_x, ind_y = inds
ind_real = ind_x * sy * 2 + ind_y * 2 + 0
ind_imag = ind_x * sy * 2 + ind_y * 2 + 1

intensities_FFT = input_array[ind_x, ind_y]
Qx = Qx_array[ind_x, ind_y]
Qy = Qy_array[ind_x, ind_y]

Expand Down Expand Up @@ -2024,11 +2070,13 @@ def reconstruct(
Parameters
--------
worker_pool: WorkerPool
If not None, reconstruction is dispatched to mpire WorkerPool instance.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model trotters,
to allow comparison with arbitrary geometry detector datasets. TO-DO
relative_wiener_epsilon: float
Value to scale probe Wigner distribution in avoiding division by zero in Wiener filtering.
~0.01 is a reasonable value.
num_jobs: int, optional
Number of processes to use. Default is None, which has the following behavior:
- For Windows systems, it uses one process, since multiprocessing requires extra deps.
- For Unix systems, it uses as many processes as CPUs on the system.
progress_bar: bool, optional
If True, reconstruction progress is displayed
Expand Down Expand Up @@ -2068,15 +2116,31 @@ def reconstruct(
Kx, Ky = self._spatial_frequencies
Qx_array, Qy_array = self._scan_frequencies

# check if windows
is_windows = platform.system() == "Windows"
if is_windows and num_jobs is None:
num_jobs = 1

# set output_arrays
if num_jobs == 1:
output_array = xp.empty(sx * sy * 2, dtype=xp.float32)
else:
from multiprocessing import Array as mp_Array
if is_windows:
try:
from multiprocess import RawArray
except (ImportError, ModuleNotFoundError) as exc:
raise Exception(
(
'On Windows, num_jobs>1 requires the additional "multiprocess" package. '
"Please install it and try again. Note this is probably inadvisable."
)
) from exc
else:
from multiprocessing import RawArray

output_array = mp_Array("f", sx * sy * 2, lock=False)
output_array = RawArray("f", sx * sy * 2)

shared_objects = (
self._intensities_FFT,
output_array,
Qx_array,
Qy_array,
Expand All @@ -2094,18 +2158,18 @@ def reconstruct(
sx,
sy,
desc="Reconstructing object",
unit="freq.",
disable=not progress_bar,
):
self._reconstruct_single_frequency(
shared_objects,
ind_x,
ind_y,
(ind_x, ind_y),
self._intensities_FFT[ind_x, ind_y],
xp=xp,
)

psi = output_array.view(xp.complex64).reshape((sx, sy))
else:

if self._device == "gpu":
raise NotImplementedError()

Expand All @@ -2119,12 +2183,20 @@ def wrapper_function(*args):
xp=xp,
)

def generator_over_real_space_dimensions(data):
    """Generator over the flattened real-space grid: yields each 2D index with its data slice."""
    nx, ny = data.shape[:2]
    for ix in range(nx):
        for iy in range(ny):
            yield (ix, iy), data[ix, iy]

data_generator = generator_over_real_space_dimensions(self._intensities_FFT)

with WorkerPool(
n_jobs=num_jobs, shared_objects=shared_objects, use_dill=True
n_jobs=num_jobs,
shared_objects=shared_objects,
use_dill=is_windows,
) as pool:
pool.map(
wrapper_function,
itertools.product(range(sx), range(sy)),
data_generator,
iterable_len=sx * sy,
n_splits=num_jobs,
progress_bar=progress_bar,
Expand Down

0 comments on commit e82f3d0

Please sign in to comment.