Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup scrublet #3044

Merged
merged 6 commits on May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/release-notes/1.10.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@

```{rubric} Performance
```

* `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
* Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
11 changes: 6 additions & 5 deletions scanpy/preprocessing/_scrublet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,11 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
pp.normalize_total(ad_obs)

# HVG process needs log'd data.

logged = pp.log1p(ad_obs, copy=True)
pp.highly_variable_genes(logged)
ad_obs = ad_obs[:, logged.var["highly_variable"]].copy()
ad_obs.layers["log1p"] = ad_obs.X.copy()
pp.log1p(ad_obs, layer="log1p")
pp.highly_variable_genes(ad_obs, layer="log1p")
del ad_obs.layers["log1p"]
ad_obs = ad_obs[:, ad_obs.var["highly_variable"]].copy()

# Simulate the doublets based on the raw expressions from the normalised
# and filtered object.
Expand All @@ -214,7 +215,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling,
random_seed=random_state,
)

del ad_obs.layers["raw"]
if log_transform:
pp.log1p(ad_obs)
pp.log1p(ad_sim)
Expand Down
11 changes: 7 additions & 4 deletions scanpy/preprocessing/_scrublet/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import numpy as np
from scipy import sparse

from .sparse_utils import sparse_multiply, sparse_var, sparse_zscore
from scanpy.preprocessing._utils import _get_mean_var

from .sparse_utils import sparse_multiply, sparse_zscore

if TYPE_CHECKING:
from ..._utils import AnyRandom
Expand All @@ -20,7 +22,8 @@ def mean_center(self: Scrublet) -> None:


def normalize_variance(self: Scrublet) -> None:
gene_stdevs = np.sqrt(sparse_var(self._counts_obs_norm, axis=0))
_, gene_vars = _get_mean_var(self._counts_obs_norm, axis=0)
gene_stdevs = np.sqrt(gene_vars)
self._counts_obs_norm = sparse_multiply(self._counts_obs_norm.T, 1 / gene_stdevs).T
if self._counts_sim_norm is not None:
self._counts_sim_norm = sparse_multiply(
Expand All @@ -29,8 +32,8 @@ def normalize_variance(self: Scrublet) -> None:


def zscore(self: Scrublet) -> None:
gene_means = self._counts_obs_norm.mean(0)
gene_stdevs = np.sqrt(sparse_var(self._counts_obs_norm, axis=0))
gene_means, gene_vars = _get_mean_var(self._counts_obs_norm, axis=0)
gene_stdevs = np.sqrt(gene_vars)
self._counts_obs_norm = sparse_zscore(
self._counts_obs_norm, gene_mean=gene_means, gene_stdev=gene_stdevs
)
Expand Down
30 changes: 8 additions & 22 deletions scanpy/preprocessing/_scrublet/sparse_utils.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,28 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

import numpy as np
from scipy import sparse

from scanpy.preprocessing._utils import _get_mean_var

from ..._utils import AnyRandom, get_random_state

if TYPE_CHECKING:
from numpy.typing import NDArray


def sparse_var(
    E: sparse.csr_matrix | sparse.csc_matrix,
    *,
    axis: Literal[0, 1],
) -> NDArray[np.float64]:
    """Return the variance of ``E`` along ``axis``.

    The result is computed as ``mean(E**2) - mean(E)**2``, i.e. the
    population variance (no degrees-of-freedom correction).

    Parameters
    ----------
    E
        Sparse matrix whose variance is taken. It is not modified.
    axis
        Axis passed to ``mean``: ``0`` or ``1``.

    Returns
    -------
    The squeezed array of variances along ``axis``.
    """
    mean = np.asarray(E.mean(axis=axis)).squeeze()
    # Square a float64 copy of the stored values: squaring in the input's own
    # dtype silently wraps around for small integer dtypes (e.g. uint8).
    squared = E.astype(np.float64)
    squared.data **= 2
    mean_of_squares = np.asarray(squared.mean(axis=axis)).squeeze()
    return mean_of_squares - mean**2


def sparse_multiply(
    E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64],
    a: float | int | NDArray[np.float64],
) -> sparse.csr_matrix | sparse.csc_matrix:
    """Multiply each row of ``E`` by a scalar.

    Row ``i`` of ``E`` is scaled by ``a[i]``; a scalar ``a`` scales every row
    by the same value. This is done by left-multiplying ``E`` with a diagonal
    matrix that has ``a`` on its main diagonal.

    Parameters
    ----------
    E
        Matrix whose rows are scaled.
    a
        One scale factor per row of ``E``, or a single scalar.

    Returns
    -------
    The scaled matrix. A dense product (obtained when ``E`` is a dense array)
    is converted to ``csc_matrix`` so the result is always sparse.
    """
    nrow = E.shape[0]
    scale = np.asarray(a)
    if scale.ndim == 0:
        # A bare scalar has neither ``.dtype`` nor a length: expand it to one
        # factor per row so it can be used as the diagonal.
        scale = np.full(nrow, scale)
    w = sparse.dia_matrix((scale, 0), shape=(nrow, nrow), dtype=scale.dtype)
    r = w @ E
    # ``np.matrix`` subclasses ``np.ndarray``, so one check covers both.
    if isinstance(r, np.ndarray):
        return sparse.csc_matrix(r)
    return r

Expand All @@ -46,11 +34,9 @@
gene_stdev: NDArray[np.float64] | None = None,
) -> sparse.csr_matrix | sparse.csc_matrix:
"""z-score normalize each column of E"""

if gene_mean is None:
gene_mean = E.mean(0)
if gene_stdev is None:
gene_stdev = np.sqrt(sparse_var(E, axis=0))
if gene_mean is None or gene_stdev is None:
gene_means, gene_stdevs = _get_mean_var(E, axis=0)
gene_stdevs = np.sqrt(gene_stdevs)

Check warning on line 39 in scanpy/preprocessing/_scrublet/sparse_utils.py

View check run for this annotation

Codecov / codecov/patch

scanpy/preprocessing/_scrublet/sparse_utils.py#L38-L39

Added lines #L38 - L39 were not covered by tests
return sparse_multiply(np.asarray((E - gene_mean).T), 1 / gene_stdev).T


Expand Down
Loading