Skip to content

Commit

Permalink
Implement missing plot functionality for pl.rank_genes_groups (#3428)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jan 21, 2025
1 parent 6c89e1d commit 8ce811a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/3428.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix {func}`scanpy.pl.rank_genes_groups`’s `ax` parameter {smaller}`P Angerer`
83 changes: 50 additions & 33 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from matplotlib.figure import Figure

from ..._utils import Empty
from .._baseplot_class import BasePlot
from .._utils import DensityNorm

# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -346,7 +347,7 @@ def rank_genes_groups(
save: bool | None = None,
ax: Axes | None = None,
**kwds,
):
) -> list[Axes] | None:
"""\
Plot ranking of genes.
Expand All @@ -370,6 +371,9 @@ def rank_genes_groups(
`sharey=False`, each panel has its own y-axis range.
{show_save_ax}
Returns
-------
List of each group’s matplotlib axis or `None` if `show=True`.
Examples
--------
Expand Down Expand Up @@ -413,15 +417,26 @@ def rank_genes_groups(

from matplotlib import gridspec

fig = plt.figure(
figsize=(
n_panels_x * rcParams["figure.figsize"][0],
n_panels_y * rcParams["figure.figsize"][1],
if ax is None or (sps := ax.get_subplotspec()) is None:
fig = (
plt.figure(
figsize=(
n_panels_x * rcParams["figure.figsize"][0],
n_panels_y * rcParams["figure.figsize"][1],
)
)
if ax is None
else ax.get_figure()
)
)
gs = gridspec.GridSpec(nrows=n_panels_y, ncols=n_panels_x, wspace=0.22, hspace=0.3)
gs = gridspec.GridSpec(n_panels_y, n_panels_x, fig, wspace=0.22, hspace=0.3)
else:
fig = ax.get_figure()
gs = sps.subgridspec(n_panels_y, n_panels_x)
if fig is None:
msg = "passed ax has no associated figure"
raise RuntimeError(msg)

ax0 = None
axs: list[Axes] = []
ymin = np.inf
ymax = -np.inf
for count, group_name in enumerate(group_names):
Expand All @@ -433,20 +448,16 @@ def rank_genes_groups(
ymin = min(ymin, np.min(scores))
ymax = max(ymax, np.max(scores))

if ax0 is None:
ax = fig.add_subplot(gs[count])
ax0 = ax
else:
ax = fig.add_subplot(gs[count], sharey=ax0)
axs.append(fig.add_subplot(gs[count], sharey=axs[0] if axs else None))
else:
ymin = np.min(scores)
ymax = np.max(scores)
ymax += 0.3 * (ymax - ymin)

ax = fig.add_subplot(gs[count])
ax.set_ylim(ymin, ymax)
axs.append(fig.add_subplot(gs[count]))
axs[-1].set_ylim(ymin, ymax)

ax.set_xlim(-0.9, n_genes - 0.1)
axs[-1].set_xlim(-0.9, n_genes - 0.1)

# Mapping to gene_symbols
if gene_symbols is not None:
Expand All @@ -457,7 +468,7 @@ def rank_genes_groups(

# Making labels
for ig, gene_name in enumerate(gene_names):
ax.text(
axs[-1].text(
ig,
scores[ig],
gene_name,
Expand All @@ -467,23 +478,29 @@ def rank_genes_groups(
fontsize=fontsize,
)

ax.set_title(f"{group_name} vs. {reference}")
axs[-1].set_title(f"{group_name} vs. {reference}")
if count >= n_panels_x * (n_panels_y - 1):
ax.set_xlabel("ranking")
axs[-1].set_xlabel("ranking")

# print the 'score' label only on the first panel per row.
if count % n_panels_x == 0:
ax.set_ylabel("score")
axs[-1].set_ylabel("score")

if sharey is True:
if sharey is True and axs:
ymax += 0.3 * (ymax - ymin)
ax.set_ylim(ymin, ymax)
axs[0].set_ylim(ymin, ymax)

writekey = f"rank_genes_groups_{adata.uns[key]['params']['groupby']}"
savefig_or_show(writekey, show=show, save=save)
show = settings.autoshow if show is None else show
if show:
return None
return axs


def _fig_show_save_or_axes(plot_obj, return_fig, show, save):
def _fig_show_save_or_axes(
plot_obj: BasePlot, *, return_fig: bool, show: bool | None, save: bool | None
):
"""
Decides what to return
"""
Expand All @@ -510,7 +527,7 @@ def _rank_genes_groups_plot(
key: str | None = None,
show: bool | None = None,
save: bool | None = None,
return_fig: bool | None = False,
return_fig: bool = False,
gene_symbols: str | None = None,
**kwds,
):
Expand All @@ -524,10 +541,6 @@ def _rank_genes_groups_plot(
)
raise ValueError(msg)

if var_names is None and n_genes is None:
# set n_genes = 10 as default when none of the options is given
n_genes = 10

if key is None:
key = "rank_genes_groups"

Expand All @@ -544,6 +557,10 @@ def _rank_genes_groups_plot(
else:
var_names_list = var_names
else:
# set n_genes = 10 as default when none of the options is given
if n_genes is None:
n_genes = 10

# dict in which each group is the key and the n_genes are the values
var_names = {}
var_names_list = []
Expand Down Expand Up @@ -621,7 +638,7 @@ def _rank_genes_groups_plot(
if title is not None and "colorbar_title" not in kwds:
_pl.legend(title=title)

return _fig_show_save_or_axes(_pl, return_fig, show, save)
return _fig_show_save_or_axes(_pl, return_fig=return_fig, show=show, save=save)

elif plot_type == "stacked_violin":
from .._stacked_violin import stacked_violin
Expand All @@ -634,7 +651,7 @@ def _rank_genes_groups_plot(
gene_symbols=gene_symbols,
**kwds,
)
return _fig_show_save_or_axes(_pl, return_fig, show, save)
return _fig_show_save_or_axes(_pl, return_fig=return_fig, show=show, save=save)
elif plot_type == "heatmap":
from .._anndata import heatmap

Expand Down Expand Up @@ -846,7 +863,7 @@ def rank_genes_groups_dotplot(
key: str | None = None,
show: bool | None = None,
save: bool | None = None,
return_fig: bool | None = False,
return_fig: bool = False,
**kwds,
):
"""\
Expand Down Expand Up @@ -985,7 +1002,7 @@ def rank_genes_groups_stacked_violin(
key: str | None = None,
show: bool | None = None,
save: bool | None = None,
return_fig: bool | None = False,
return_fig: bool = False,
**kwds,
):
"""\
Expand Down Expand Up @@ -1073,7 +1090,7 @@ def rank_genes_groups_matrixplot(
key: str | None = None,
show: bool | None = None,
save: bool | None = None,
return_fig: bool | None = False,
return_fig: bool = False,
**kwds,
):
"""\
Expand Down
30 changes: 27 additions & 3 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
if TYPE_CHECKING:
from collections.abc import Callable

from matplotlib.axes import Axes


HERE: Path = Path(__file__).parent
ROOT = HERE / "_images"

Expand Down Expand Up @@ -842,9 +845,30 @@ def test_rank_genes_groups(image_comparer, name, fn):

with plt.rc_context({"axes.grid": True, "figure.figsize": (4, 4)}):
fn(pbmc)
key = "ranked_genes" if name == "basic" else f"ranked_genes_{name}"
save_and_compare_images(key)
plt.close()
key = "ranked_genes" if name == "basic" else f"ranked_genes_{name}"
save_and_compare_images(key)
plt.close()


def test_rank_genes_group_axes(image_comparer):
fn = next(fn for name, fn in _RANK_GENES_GROUPS_PARAMS if name == "basic")

save_and_compare_images = partial(image_comparer, ROOT, tol=23)

pbmc = pbmc68k_reduced()
sc.tl.rank_genes_groups(pbmc, "louvain", n_genes=pbmc.raw.shape[1])

pbmc.var["symbol"] = pbmc.var.index + "__"

fig, ax = plt.subplots(figsize=(12, 16))
ax.set_axis_off()
with plt.rc_context({"axes.grid": True}):
axes: list[Axes] = fn(pbmc, ax=ax, show=False)

assert len(axes) == 11
fig.show()
save_and_compare_images("ranked_genes")
plt.close()


@pytest.fixture(scope="session")
Expand Down

0 comments on commit 8ce811a

Please sign in to comment.