Skip to content

coloring shapes by categorical variable #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 4, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning][].
- Multipolygons are now handled correctly (#93)
- Legend order is now deterministic (#143)
- Images no longer normalised by default (#150)
- Filtering of shapes and points using the `groups` argument is now possible, coloring by palette and cmap arguments works for shapes and points (#153)
- Colorbar no longer autoscales to [0, 1] (#155)

## [0.0.4] - 2023-08-11
Expand Down
22 changes: 15 additions & 7 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dask.dataframe.core import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, ListedColormap, Normalize
from matplotlib.colors import Colormap, Normalize
from matplotlib.figure import Figure
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from pandas.api.types import is_categorical_dtype
Expand Down Expand Up @@ -150,7 +150,7 @@ def render_shapes(
outline_width: float = 1.5,
outline_color: str | list[float] = "#000000ff",
layer: str | None = None,
palette: ListedColormap | str | None = None,
palette: str | list[str] | None = None,
cmap: Colormap | str | None = None,
norm: bool | Normalize = False,
na_color: str | tuple[float, ...] | None = "lightgrey",
Expand Down Expand Up @@ -182,9 +182,13 @@ def render_shapes(
layer
Key in :attr:`anndata.AnnData.layers` or `None` for :attr:`anndata.AnnData.X`.
palette
Palette for discrete annotations, see :class:`matplotlib.colors.Colormap`.
Palette for discrete annotations. List of valid color names that should be used
for the categories (all or as specified by `groups`). For a single category,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good descriptions 👌

a valid color name can be given as string.
cmap
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`.
If no palette is given and `color` refers to a categorical, the colors are
sampled from this colormap.
norm
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
na_color
Expand Down Expand Up @@ -235,7 +239,7 @@ def render_points(
color: str | None = None,
groups: str | Sequence[str] | None = None,
size: float = 1.0,
palette: ListedColormap | str | None = None,
palette: str | list[str] | None = None,
cmap: Colormap | str | None = None,
norm: None | Normalize = None,
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
Expand All @@ -258,9 +262,13 @@ def render_points(
size
Value to scale points.
palette
Palette for discrete annotations, see :class:`matplotlib.colors.Colormap`.
Palette for discrete annotations. List of valid color names that should be used
for the categories (all or as specified by `groups`). For a single category,
a valid color name can be given as string.
cmap
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`.
If no palette is given and `color` refers to a categorical, the colors are
sampled from this colormap.
norm
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
na_color
Expand Down Expand Up @@ -303,7 +311,7 @@ def render_images(
cmap: list[Colormap] | list[str] | Colormap | str | None = None,
norm: None | Normalize = None,
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
palette: ListedColormap | str | None = None,
palette: str | list[str] | None = None,
alpha: float = 1.0,
quantiles_for_norm: tuple[float | None, float | None] = (None, None),
**kwargs: Any,
Expand Down Expand Up @@ -381,7 +389,7 @@ def render_labels(
contour_px: int = 3,
outline: bool = False,
layer: str | None = None,
palette: ListedColormap | str | None = None,
palette: str | list[str] | None = None,
cmap: Colormap | str | None = None,
norm: None | Normalize = None,
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
Expand Down
58 changes: 52 additions & 6 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import copy
from typing import Union

import dask
import geopandas as gpd
import matplotlib
import numpy as np
Expand All @@ -18,6 +19,7 @@
from spatialdata.models import (
Image2DModel,
Labels2DModel,
PointsModel,
)

from spatialdata_plot._logging import logger
Expand Down Expand Up @@ -57,6 +59,12 @@ def _render_shapes(
) -> None:
elements = render_params.elements

if render_params.groups is not None:
if isinstance(render_params.groups, str):
render_params.groups = [render_params.groups]
if not all(isinstance(g, str) for g in render_params.groups):
raise TypeError("All groups must be strings.")

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_table=sdata.table is not None,
Expand All @@ -68,7 +76,6 @@ def _render_shapes(
elements = list(sdata_filt.shapes.keys())

for e in elements:
# shapes = [sdata.shapes[e] for e in elements]
shapes = sdata.shapes[e]
n_shapes = sum([len(s) for s in shapes])

Expand All @@ -88,6 +95,7 @@ def _render_shapes(
palette=render_params.palette,
na_color=render_params.cmap_params.na_color,
alpha=render_params.fill_alpha,
cmap_params=render_params.cmap_params,
)

values_are_categorical = color_source_vector is not None
Expand All @@ -101,7 +109,15 @@ def _render_shapes(
if len(color_vector) == 0:
color_vector = [render_params.cmap_params.na_color]

# filter by `groups`
if render_params.groups is not None and color_source_vector is not None:
mask = color_source_vector.isin(render_params.groups)
shapes = shapes[mask]
shapes = shapes.reset_index()
color_source_vector = color_source_vector[mask]
color_vector = color_vector[mask]
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

_cax = _get_collection_shape(
shapes=shapes,
s=render_params.scale,
Expand All @@ -122,9 +138,12 @@ def _render_shapes(
cax = ax.add_collection(_cax)

# Using dict.fromkeys here since set returns in arbitrary order
palette = (
ListedColormap(dict.fromkeys(color_vector)) if render_params.palette is None else render_params.palette
)
# remove the color of NaN values, else it might be assigned to a category
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to have this comment directly next to the dict.from_keys

# order of color in the palette should agree to order of occurence
if color_source_vector is None:
palette = ListedColormap(dict.fromkeys(color_vector))
else:
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))

if not (
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
Expand Down Expand Up @@ -159,6 +178,12 @@ def _render_points(
scalebar_params: ScalebarParams,
legend_params: LegendParams,
) -> None:
if render_params.groups is not None:
if isinstance(render_params.groups, str):
render_params.groups = [render_params.groups]
if not all(isinstance(g, str) for g in render_params.groups):
raise TypeError("All groups must be strings.")

elements = render_params.elements

sdata_filt = sdata.filter_by_coordinate_system(
Expand All @@ -178,6 +203,14 @@ def _render_points(
color = [render_params.color] if isinstance(render_params.color, str) else render_params.color
coords.extend(color)

points = points[coords].compute()
# points[color[0]].cat.set_categories(render_params.groups, inplace=True)
if render_params.groups is not None:
points = points[points[color].isin(render_params.groups).values]
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
points = dask.dataframe.from_pandas(points, npartitions=1)
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})

point_df = points[coords].compute()

# we construct an anndata to hack the plotting functions
Expand All @@ -204,6 +237,7 @@ def _render_points(
palette=render_params.palette,
na_color=render_params.cmap_params.na_color,
alpha=render_params.alpha,
cmap_params=render_params.cmap_params,
)

# color_source_vector is None when the values aren't categorical
Expand All @@ -226,14 +260,19 @@ def _render_points(
if not (
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
if color_source_vector is None:
palette = ListedColormap(dict.fromkeys(color_vector))
else:
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=adata,
value_to_plot=render_params.color,
color_source_vector=color_source_vector,
palette=render_params.palette,
palette=palette,
alpha=render_params.alpha,
na_color=render_params.cmap_params.na_color,
legend_fontsize=legend_params.legend_fontsize,
Expand Down Expand Up @@ -415,6 +454,12 @@ def _render_labels(
) -> None:
elements = render_params.elements

if render_params.groups is not None:
if isinstance(render_params.groups, str):
render_params.groups = [render_params.groups]
if not all(isinstance(g, str) for g in render_params.groups):
raise TypeError("All groups must be strings.")

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_table=sdata.table is not None,
Expand All @@ -441,7 +486,7 @@ def _render_labels(

table = sdata.table[sdata.table.obs[region_key].isin([label_key])]

# get isntance id based on subsetted table
# get instance id based on subsetted table
instance_id = table.obs[instance_key].values

# get color vector (categorical or continuous)
Expand All @@ -455,6 +500,7 @@ def _render_labels(
palette=render_params.palette,
na_color=render_params.cmap_params.na_color,
alpha=render_params.fill_alpha,
cmap_params=render_params.cmap_params,
)

if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:
Expand Down
1 change: 1 addition & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class CmapParams:
cmap: Colormap
norm: Normalize
na_color: str | tuple[float, ...] = (0.0, 0.0, 0.0, 0.0)
is_default: bool = True


@dataclass
Expand Down
Loading