sc.pp.scale() doesn't support adata.X being a dask array #2491

Closed
2 of 3 tasks
elyall opened this issue May 18, 2023 · 1 comment · Fixed by #2942

Comments

elyall commented May 18, 2023

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of scanpy.
  • (optional) I have confirmed this bug exists on the master branch of scanpy.

In short, sc.pp.scale() throws an error when adata.X is a dask array, solely because sc.pp._simple.scale doesn't have the sc.pp._simple.scale_array function registered for the dask.array.Array type. Adding @scale.register(da.Array) there would fix that, but since dask.array is only an optional dependency, the decoration would have to be wrapped in a conditional.
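
To make the suggestion concrete, here is a rough sketch of what that conditional registration could look like, written as a user-side registration for illustration (the same decoration could live inside scanpy/preprocessing/_simple.py). The import path and the scale_array signature are assumptions on my part, based on the traceback below, not scanpy's actual code.

# Rough sketch of the conditional registration described above, written as a
# user-side workaround; the same decoration could live inside scanpy itself.
# The scale/scale_array names come from the issue text; their exact
# signatures are assumptions.
try:
    import dask.array as da
except ImportError:
    da = None

if da is not None:
    from scanpy.preprocessing._simple import scale, scale_array

    @scale.register(da.Array)
    def _scale_dask_array(
        X, *, zero_center=True, max_value=None, copy=False, return_mean_std=False
    ):
        # Reuse the dense implementation, on the assumption that it only needs
        # NumPy-style operations that dask arrays provide (evaluated lazily).
        return scale_array(
            X,
            zero_center=zero_center,
            max_value=max_value,
            copy=copy,
            return_mean_std=return_mean_std,
        )

With something like this registered, scale() would dispatch dask arrays to the dense implementation instead of falling through to the base function.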

This brings me to the larger question: is scanpy intended to, or working toward, fully supporting dask arrays? I'm new to scanpy and not sure whether this is a bug report or an enhancement request. I could submit a pull request for this one issue, but I'm curious whether I'll run into many similar issues as I dive in further, and I'm trying to figure out whether there's existing momentum in this direction or whether I should follow a different parallelization strategy.

I got here by following AnnData's guide on using dask and zarr to try to parallelize processing of a large scRNA-seq file. I come from a microscopy data analysis and ML background, where my image data is stored in S3-hosted zarr arrays (in an OME-NGFF schema), handled as dask arrays wrapped in xarray DataArrays, and parallelized across compute with ray. It would be nice to use a similar stack (swapping anndata in for xarray) for sc-seq analyses.

Minimal code sample (that we can copy&paste without having any data)

import zarr
import anndata as ad
import dask.array as da
import scanpy as sc

# write data to zarr file
rel_zarr_path = 'data/pbmc3k_processed.zarr'
adata = sc.datasets.pbmc3k_processed()
adata.write_zarr(f'./{rel_zarr_path}', chunks=[adata.shape[0], 5])
zarr.consolidate_metadata(f'./{rel_zarr_path}')

# read data from zarr file with X as a dask array
def read_dask(store):
    f = zarr.open(store, mode="r")

    def callback(func, elem_name: str, elem, iospec):
        if iospec.encoding_type in (
            "dataframe",
            "csr_matrix",
            "csc_matrix",
            "awkward-array",
        ):
            # Prevent recursing into these types
            return ad.experimental.read_elem(elem)
        elif iospec.encoding_type == "array":
            return da.from_zarr(elem)
        else:
            return func(elem)

    adata = ad.experimental.read_dispatched(f, callback=callback)

    return adata

adata_dask = read_dask(f'./{rel_zarr_path}')

# perform preprocessing
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
sc.pp.scale(adata_dask, max_value=10)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 39
     37 sc.pp.log1p(adata)
     38 # sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
---> 39 sc.pp.scale(adata_dask, max_value=10)

File ~/miniconda3/envs/omics/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/omics/lib/python3.10/site-packages/scanpy/preprocessing/_simple.py:844, in scale_anndata(adata, zero_center, max_value, copy, layer, obsm)
    842 view_to_actual(adata)
    843 X = _get_obs_rep(adata, layer=layer, obsm=obsm)
--> 844 X, adata.var["mean"], adata.var["std"] = scale(
    845     X,
    846     zero_center=zero_center,
    847     max_value=max_value,
    848     copy=False,  # because a copy has already been made, if it were to be made
    849     return_mean_std=True,
    850 )
    851 _set_obs_rep(adata, X, layer=layer, obsm=obsm)
    852 if copy:

File ~/miniconda3/envs/omics/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

TypeError: scale() got an unexpected keyword argument 'return_mean_std'
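
As far as I can tell, the error surfaces as an "unexpected keyword argument" rather than an "unsupported type" because functools.singledispatch falls back to the base scale() implementation for any unregistered class, and that base signature doesn't accept return_mean_std. A minimal, self-contained illustration of the dispatch behaviour (not scanpy's code; the signatures are simplified approximations):

from functools import singledispatch

import numpy as np

@singledispatch
def scale(X, *, zero_center=True, max_value=None, copy=False):
    # Base implementation, reached for any unregistered type; note that it
    # has no return_mean_std parameter.
    raise NotImplementedError(f"no scale() implementation for {type(X)}")

@scale.register(np.ndarray)
def scale_array(X, *, zero_center=True, max_value=None, copy=False, return_mean_std=False):
    ...  # dense implementation would go here

class NotRegistered:
    # Stands in for dask.array.Array, which has no registered overload
    pass

scale(NotRegistered(), max_value=10, return_mean_std=True)
# TypeError: scale() got an unexpected keyword argument 'return_mean_std'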

Versions


anndata 0.9.1
scanpy 1.9.3

PIL 9.5.0
appnope 0.1.3
asciitree NA
asttokens NA
attr 23.1.0
backcall 0.2.0
brotli NA
certifi 2023.05.07
cffi 1.15.1
charset_normalizer 3.1.0
click 8.1.3
cloudpickle 2.2.1
colorama 0.4.6
colorful 0.5.5
colorful_orig 0.5.5
comm 0.1.3
cycler 0.10.0
cython_runtime NA
cytoolz 0.12.0
dask 2023.5.0
dateutil 2.8.2
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
distributed 2023.5.0
entrypoints 0.4
executing 1.2.0
fasteners 0.17.3
filelock 3.12.0
fsspec 2023.5.0
google NA
grpc 1.43.0
h5py 3.8.0
idna 3.4
igraph 0.10.4
ipykernel 6.23.1
jedi 0.18.2
jinja2 3.1.2
joblib 1.2.0
jsonschema 4.17.3
kiwisolver 1.4.4
leidenalg 0.9.1
llvmlite 0.40.0
locket NA
lz4 4.3.2
markupsafe 2.1.2
matplotlib 3.6.3
matplotlib_inline 0.1.6
mpl_toolkits NA
msgpack 1.0.5
natsort 8.3.1
numba 0.57.0
numcodecs 0.11.0
numpy 1.24.3
packaging 23.1
pandas 2.0.1
parso 0.8.3
patsy 0.5.3
pexpect 4.8.0
pickleshare 0.7.5
pkg_resources NA
platformdirs 3.5.1
prometheus_client NA
prompt_toolkit 3.0.38
psutil 5.9.5
ptyprocess 0.7.0
pure_eval 0.2.2
pvectorc NA
pyarrow 9.0.0
pycparser 2.21
pydev_ipython NA
pydevconsole NA
pydevd 2.9.5
pydevd_file_utils NA
pydevd_plugins NA
pydevd_tracing NA
pygments 2.15.1
pynndescent 0.5.10
pynvml NA
pyparsing 3.0.9
pyrsistent NA
pytz 2023.3
ray 2.3.0
rb_analysis NA
requests 2.29.0
scipy 1.10.1
seaborn 0.12.2
session_info 1.0.0
setproctitle 1.2.2
setuptools 67.7.2
six 1.16.0
sklearn 1.2.2
socks 1.7.1
sortedcontainers 2.4.0
stack_data 0.6.2
statsmodels 0.14.0
tblib 1.7.0
texttable 1.6.7
threadpoolctl 3.1.0
tlz 0.12.0
toolz 0.12.0
tornado 6.3.2
tqdm 4.65.0
traitlets 5.9.0
typing_extensions NA
umap 0.5.3
urllib3 1.26.15
wcwidth 0.2.6
yaml 6.0
zarr 2.14.2
zict 3.0.0
zipp NA
zmq 25.0.2
zoneinfo NA

IPython 8.13.2
jupyter_client 8.2.0
jupyter_core 5.3.0

Python 3.10.11 | packaged by conda-forge | (main, May 10 2023, 19:01:19) [Clang 14.0.6 ]
macOS-13.3.1-arm64-arm-64bit

Session information updated at 2023-05-18 14:00

elyall changed the title from "sc.pp.normalize_total() doesn't support adata.X being a dask array" to "sc.pp.scale() doesn't support adata.X being a dask array" on May 19, 2023
@flying-sheep
Member

Fixed by #2942
