Skip to content

Commit

Permalink
Refactor of colorbar and norm logic (#346)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wouter-Michiel Vierdag <[email protected]>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent 6cef5df commit c6d6153
Show file tree
Hide file tree
Showing 20 changed files with 33 additions and 76 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ and this project adheres to [Semantic Versioning][].

- Lowered RMSE-threshold for plot-based tests from 45 to 15 (#344)
- When subsetting to `groups`, `NA` isn't automatically added to legend (#344)
- When rendering a single image channel, a colorbar is now shown (#346)
- Removed `percentiles_for_norm` parameter (#346)
- Changed `norm` to no longer accept bools, only `mpl.colors.Normalise` or `None` (#346)

### Fixed

- Filtering with `groups` now preserves original cmap (#344)
- Non-selected `groups` are now not shown in `na_color` (#344)
- Several issues associated with `norm` and `colorbar` (#346)

## [0.2.5] - 2024-08-23

Expand Down
16 changes: 3 additions & 13 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def render_shapes(
outline_color: str | list[float] = "#000000ff",
outline_alpha: float | int = 0.0,
cmap: Colormap | str | None = None,
norm: bool | Normalize = False,
norm: Normalize | None = None,
scale: float | int = 1.0,
method: str | None = None,
table_name: str | None = None,
Expand Down Expand Up @@ -301,7 +301,7 @@ def render_points(
palette: list[str] | str | None = None,
na_color: ColorLike | None = "default",
cmap: Colormap | str | None = None,
norm: None | Normalize = None,
norm: Normalize | None = None,
size: float | int = 1.0,
method: str | None = None,
table_name: str | None = None,
Expand Down Expand Up @@ -422,7 +422,6 @@ def render_images(
na_color: ColorLike | None = "default",
palette: list[str] | str | None = None,
alpha: float | int = 1.0,
percentiles_for_norm: tuple[float, float] | None = None,
scale: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
Expand Down Expand Up @@ -457,8 +456,6 @@ def render_images(
Palette to color images. The number of palettes should be equal to the number of channels.
alpha : float | int, default 1.0
Alpha value for the images. Must be a numeric between 0 and 1.
percentiles_for_norm : tuple[float, float] | None
Optional pair of floats (pmin < pmax, 0-100) which will be used for quantile normalization.
scale : str | None
Influences the resolution of the rendering. Possibilities include:
1) `None` (default): The image is rasterized to fit the canvas size. For
Expand Down Expand Up @@ -486,20 +483,14 @@ def render_images(
cmap=cmap,
norm=norm,
scale=scale,
percentiles_for_norm=percentiles_for_norm,
)

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())

for element, param_values in params_dict.items():
# cmap_params = _prepare_cmap_norm(
# cmap=params_dict[element]["cmap"],
# norm=norm,
# na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
# **kwargs,
# )

cmap_params: list[CmapParams] | CmapParams
if isinstance(cmap, list):
cmap_params = [
Expand All @@ -525,7 +516,6 @@ def render_images(
cmap_params=cmap_params,
palette=param_values["palette"],
alpha=param_values["alpha"],
percentiles_for_norm=param_values["percentiles_for_norm"],
scale=param_values["scale"],
zorder=n_steps,
)
Expand Down
30 changes: 13 additions & 17 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import datashader as ds
import geopandas as gpd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -47,7 +48,6 @@
_maybe_set_colors,
_mpl_ax_contains_elements,
_multiscale_to_spatial_image,
_normalize,
_rasterize_if_necessary,
_set_color_source_vec,
to_hex,
Expand Down Expand Up @@ -128,6 +128,7 @@ def _render_shapes(
shapes = shapes.reset_index()
color_source_vector = color_source_vector[mask]
color_vector = color_vector[mask]

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

# Using dict.fromkeys here since set returns in arbitrary order
Expand Down Expand Up @@ -255,9 +256,13 @@ def _render_shapes(
for path in _cax.get_paths():
path.vertices = trans.transform(path.vertices)

# Sets the limits of the colorbar to the values instead of [0, 1]
if not norm and not values_are_categorical:
_cax.set_clim(min(color_vector), max(color_vector))
if not values_are_categorical:
# If the user passed a Normalize object with vmin/vmax we'll use those,
# # if not we'll use the min/max of the color_vector
_cax.set_clim(
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
)

if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
# necessary in case different shapes elements are annotated with one table
Expand Down Expand Up @@ -603,11 +608,6 @@ def _render_images(
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()

if render_params.percentiles_for_norm != (None, None):
layer = _normalize(
layer, pmin=render_params.percentiles_for_norm[0], pmax=render_params.percentiles_for_norm[1], clip=True
)

if render_params.cmap_params.norm: # type: ignore[attr-defined]
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]

Expand All @@ -623,20 +623,16 @@ def _render_images(

_ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)

if legend_params.colorbar:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
fig_params.fig.colorbar(sm, ax=ax)

# 2) Image has any number of channels but 1
else:
layers = {}
for ch_index, c in enumerate(channels):
layers[c] = img.sel(c=c).copy(deep=True).squeeze()

if render_params.percentiles_for_norm != (None, None):
layers[c] = _normalize(
layers[c],
pmin=render_params.percentiles_for_norm[0],
pmax=render_params.percentiles_for_norm[1],
clip=True,
)

if not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.norm is not None:
layers[c] = render_params.cmap_params.norm(layers[c])
Expand Down
29 changes: 1 addition & 28 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def _get_scalebar(

def _prepare_cmap_norm(
cmap: Colormap | str | None = None,
norm: Normalize | bool = False,
norm: Normalize | None = None,
na_color: ColorLike | None = None,
vmin: float | None = None,
vmax: float | None = None,
Expand Down Expand Up @@ -1623,29 +1623,6 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
if scale < 0:
raise ValueError("Parameter 'scale' must be a positive number.")

if (percentiles_for_norm := param_dict.get("percentiles_for_norm")) is None:
percentiles_for_norm = (None, None)
elif not (isinstance(percentiles_for_norm, (list, tuple)) or len(percentiles_for_norm) != 2):
raise TypeError("Parameter 'percentiles_for_norm' must be a list or tuple of exactly two floats or None.")
elif not all(
isinstance(p, (float, int, type(None)))
and isinstance(p, type(percentiles_for_norm[0]))
and (p is None or 0 <= p <= 100)
for p in percentiles_for_norm
):
raise TypeError(
"Each item in 'percentiles_for_norm' must be of the same dtype and must be a float or int within [0, 100], "
"or None"
)
elif (
percentiles_for_norm[0] is not None
and percentiles_for_norm[1] is not None
and percentiles_for_norm[0] > percentiles_for_norm[1]
):
raise ValueError("The first number in 'percentiles_for_norm' must not be smaller than the second.")
if "percentiles_for_norm" in param_dict:
param_dict["percentiles_for_norm"] = percentiles_for_norm

if size := param_dict.get("size"):
if not isinstance(size, (float, int)):
raise TypeError("Parameter 'size' must be numeric.")
Expand Down Expand Up @@ -1886,7 +1863,6 @@ def _validate_image_render_params(
cmap: list[Colormap | str] | Colormap | str | None,
norm: Normalize | None,
scale: str | None,
percentiles_for_norm: tuple[float | None, float | None] | None,
) -> dict[str, dict[str, Any]]:
param_dict: dict[str, Any] = {
"sdata": sdata,
Expand All @@ -1898,7 +1874,6 @@ def _validate_image_render_params(
"cmap": cmap,
"norm": norm,
"scale": scale,
"percentiles_for_norm": percentiles_for_norm,
}
param_dict = _type_check_params(param_dict, "images")

Expand Down Expand Up @@ -1945,8 +1920,6 @@ def _validate_image_render_params(
else:
element_params[el]["scale"] = scale

element_params[el]["percentiles_for_norm"] = param_dict["percentiles_for_norm"]

return element_params


Expand Down
Binary file modified tests/_images/Images_can_pass_cmap_to_single_channel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_color_to_each_channel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_color_to_single_channel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_render_a_single_channel_from_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_stack_render_images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_colorbar_can_be_normalised.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_colorbar_respects_input_limits.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion tests/pl/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def test_render_images_can_plot_one_cyx_image(request):
def test_render_images_can_plot_multiple_cyx_images(share_coordinate_system: str, request):
fun = request.getfixturevalue("get_sdata_with_multiple_images")
sdata = fun(share_coordinate_system)
sdata.pl.render_images().pl.show()
sdata.pl.render_images().pl.show(
colorbar=False, # otherwise we'll get one cbar per image in the same cs
)
axs = plt.gcf().get_axes()

if share_coordinate_system == "all":
Expand Down
22 changes: 6 additions & 16 deletions tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import matplotlib
import numpy as np
import scanpy as sc
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from spatial_image import to_spatial_image
from spatialdata import SpatialData
Expand Down Expand Up @@ -49,9 +48,6 @@ def test_plot_can_render_a_single_channel_from_image(self, sdata_blobs: SpatialD
def test_plot_can_render_a_single_channel_from_multiscale_image(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_multiscale_image", channel=0).pl.show()

def test_plot_can_render_a_single_channel_from_image_no_el(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(channel=0).pl.show()

def test_plot_can_render_a_single_channel_str_from_image(self, sdata_blobs_str: SpatialData):
sdata_blobs_str.pl.render_images(element="blobs_image", channel="c1").pl.show()

Expand All @@ -70,16 +66,13 @@ def test_plot_can_render_two_channels_str_from_image(self, sdata_blobs_str: Spat
def test_plot_can_render_two_channels_str_from_multiscale_image(self, sdata_blobs_str: SpatialData):
sdata_blobs_str.pl.render_images(element="blobs_multiscale_image", channel=["c1", "c2"]).pl.show()

def test_plot_can_pass_vmin_vmax(self, sdata_blobs: SpatialData):
fig, axs = plt.subplots(ncols=2, figsize=(6, 3))
sdata_blobs.pl.render_images(element="blobs_image", channel=1).pl.show(ax=axs[0])
sdata_blobs.pl.render_images(element="blobs_image", channel=1, vmin=0, vmax=0.4).pl.show(ax=axs[1])

def test_plot_can_pass_normalize(self, sdata_blobs: SpatialData):
fig, axs = plt.subplots(ncols=2, figsize=(6, 3))
def test_plot_can_pass_normalize_clip_True(self, sdata_blobs: SpatialData):
norm = Normalize(vmin=0, vmax=0.4, clip=True)
sdata_blobs.pl.render_images(element="blobs_image", channel=1).pl.show(ax=axs[0])
sdata_blobs.pl.render_images(element="blobs_image", channel=1, norm=norm).pl.show(ax=axs[1])
sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show()

def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData):
norm = Normalize(vmin=0, vmax=0.4, clip=False)
sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show()

def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show()
Expand All @@ -97,9 +90,6 @@ def test_plot_can_pass_cmap_to_each_channel(self, sdata_blobs: SpatialData):
element="blobs_image", channel=[0, 1, 2], cmap=["Reds", "Greens", "Blues"]
).pl.show()

def test_plot_can_normalize_image(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", percentiles_for_norm=(5, 90)).pl.show()

def test_plot_can_render_multiscale_image(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images("blobs_multiscale_image").pl.show()

Expand Down
4 changes: 3 additions & 1 deletion tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import scanpy as sc
from anndata import AnnData
from matplotlib.colors import Normalize
from shapely.geometry import MultiPolygon, Point, Polygon
from spatialdata import SpatialData, deepcopy
from spatialdata.models import ShapesModel, TableModel
Expand Down Expand Up @@ -146,7 +147,8 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=True).pl.show()
norm = Normalize(vmin=0, vmax=5, clip=True)
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show()

def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData):
# subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included
Expand Down

0 comments on commit c6d6153

Please sign in to comment.