Skip to content

Commit 3ab73ea

Browse files
authored
Merge branch 'main' into bugfix/sd-notebooks-spatial-query
2 parents 872bf02 + 0e0ecb1 commit 3ab73ea

13 files changed

+131
-27
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning][].
1010

1111
## [0.1.0] - tbd
1212

13+
## [0.0.5] - 2023-10-02
14+
1315
### Added
1416

1517
- Can now scale shapes (#152)
@@ -20,6 +22,7 @@ and this project adheres to [Semantic Versioning][].
2022
- Multipolygons are now handled correctly (#93)
2123
- Legend order is now deterministic (#143)
2224
- Images no longer normalised by default (#150)
25+
- Filtering of shapes and points using the `groups` argument is now possible, coloring by palette and cmap arguments works for shapes and points (#153)
2326
- Colorbar no longer autoscales to [0, 1] (#155)
2427
- Plotting shapes after a spatial query is now possible (#163)
2528

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
[![Documentation][badge-docs]][link-docs]
77
[![Codecov][badge-codecov]][link-codecov]
88
[![Documentation][badge-pypi]][link-pypi]
9+
[![DOI](https://zenodo.org/badge/588223127.svg)](https://zenodo.org/badge/latestdoi/588223127)
910

1011
[badge-tests]: https://img.shields.io/github/actions/workflow/status/scverse/spatialdata-plot/test_and_deploy.yaml?branch=main
1112
[link-tests]: https://github.com/scverse/spatialdata-plot/actions/workflows/test.yml

src/spatialdata_plot/pl/basic.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dask.dataframe.core import DataFrame as DaskDataFrame
1515
from geopandas import GeoDataFrame
1616
from matplotlib.axes import Axes
17-
from matplotlib.colors import Colormap, ListedColormap, Normalize
17+
from matplotlib.colors import Colormap, Normalize
1818
from matplotlib.figure import Figure
1919
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
2020
from pandas.api.types import is_categorical_dtype
@@ -150,7 +150,7 @@ def render_shapes(
150150
outline_width: float = 1.5,
151151
outline_color: str | list[float] = "#000000ff",
152152
layer: str | None = None,
153-
palette: ListedColormap | str | None = None,
153+
palette: str | list[str] | None = None,
154154
cmap: Colormap | str | None = None,
155155
norm: bool | Normalize = False,
156156
na_color: str | tuple[float, ...] | None = "lightgrey",
@@ -182,9 +182,13 @@ def render_shapes(
182182
layer
183183
Key in :attr:`anndata.AnnData.layers` or `None` for :attr:`anndata.AnnData.X`.
184184
palette
185-
Palette for discrete annotations, see :class:`matplotlib.colors.Colormap`.
185+
Palette for discrete annotations. List of valid color names that should be used
186+
for the categories (all or as specified by `groups`). For a single category,
187+
a valid color name can be given as string.
186188
cmap
187189
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`.
190+
If no palette is given and `color` refers to a categorical, the colors are
191+
sampled from this colormap.
188192
norm
189193
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
190194
na_color
@@ -235,7 +239,7 @@ def render_points(
235239
color: str | None = None,
236240
groups: str | Sequence[str] | None = None,
237241
size: float = 1.0,
238-
palette: ListedColormap | str | None = None,
242+
palette: str | list[str] | None = None,
239243
cmap: Colormap | str | None = None,
240244
norm: None | Normalize = None,
241245
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
@@ -258,9 +262,13 @@ def render_points(
258262
size
259263
Value to scale points.
260264
palette
261-
Palette for discrete annotations, see :class:`matplotlib.colors.Colormap`.
265+
Palette for discrete annotations. List of valid color names that should be used
266+
for the categories (all or as specified by `groups`). For a single category,
267+
a valid color name can be given as string.
262268
cmap
263269
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`.
270+
If no palette is given and `color` refers to a categorical, the colors are
271+
sampled from this colormap.
264272
norm
265273
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
266274
na_color
@@ -303,7 +311,7 @@ def render_images(
303311
cmap: list[Colormap] | list[str] | Colormap | str | None = None,
304312
norm: None | Normalize = None,
305313
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
306-
palette: ListedColormap | str | None = None,
314+
palette: str | list[str] | None = None,
307315
alpha: float = 1.0,
308316
quantiles_for_norm: tuple[float | None, float | None] = (None, None),
309317
**kwargs: Any,
@@ -381,7 +389,7 @@ def render_labels(
381389
contour_px: int = 3,
382390
outline: bool = False,
383391
layer: str | None = None,
384-
palette: ListedColormap | str | None = None,
392+
palette: str | list[str] | None = None,
385393
cmap: Colormap | str | None = None,
386394
norm: None | Normalize = None,
387395
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),

src/spatialdata_plot/pl/render.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from copy import copy
55
from typing import Union
66

7+
import dask
78
import geopandas as gpd
89
import matplotlib
910
import numpy as np
@@ -18,6 +19,7 @@
1819
from spatialdata.models import (
1920
Image2DModel,
2021
Labels2DModel,
22+
PointsModel,
2123
)
2224

2325
from spatialdata_plot._logging import logger
@@ -57,6 +59,12 @@ def _render_shapes(
5759
) -> None:
5860
elements = render_params.elements
5961

62+
if render_params.groups is not None:
63+
if isinstance(render_params.groups, str):
64+
render_params.groups = [render_params.groups]
65+
if not all(isinstance(g, str) for g in render_params.groups):
66+
raise TypeError("All groups must be strings.")
67+
6068
sdata_filt = sdata.filter_by_coordinate_system(
6169
coordinate_system=coordinate_system,
6270
filter_table=sdata.table is not None,
@@ -68,7 +76,6 @@ def _render_shapes(
6876
elements = list(sdata_filt.shapes.keys())
6977

7078
for e in elements:
71-
# shapes = [sdata.shapes[e] for e in elements]
7279
shapes = sdata.shapes[e]
7380
n_shapes = sum([len(s) for s in shapes])
7481

@@ -88,6 +95,7 @@ def _render_shapes(
8895
palette=render_params.palette,
8996
na_color=render_params.cmap_params.na_color,
9097
alpha=render_params.fill_alpha,
98+
cmap_params=render_params.cmap_params,
9199
)
92100

93101
values_are_categorical = color_source_vector is not None
@@ -101,7 +109,15 @@ def _render_shapes(
101109
if len(color_vector) == 0:
102110
color_vector = [render_params.cmap_params.na_color]
103111

112+
# filter by `groups`
113+
if render_params.groups is not None and color_source_vector is not None:
114+
mask = color_source_vector.isin(render_params.groups)
115+
shapes = shapes[mask]
116+
shapes = shapes.reset_index()
117+
color_source_vector = color_source_vector[mask]
118+
color_vector = color_vector[mask]
104119
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
120+
105121
_cax = _get_collection_shape(
106122
shapes=shapes,
107123
s=render_params.scale,
@@ -122,9 +138,12 @@ def _render_shapes(
122138
cax = ax.add_collection(_cax)
123139

124140
# Using dict.fromkeys here since set returns in arbitrary order
125-
palette = (
126-
ListedColormap(dict.fromkeys(color_vector)) if render_params.palette is None else render_params.palette
127-
)
141+
# remove the color of NaN values, else it might be assigned to a category
142+
# order of color in the palette should agree to order of occurence
143+
if color_source_vector is None:
144+
palette = ListedColormap(dict.fromkeys(color_vector))
145+
else:
146+
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))
128147

129148
if not (
130149
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
@@ -159,6 +178,12 @@ def _render_points(
159178
scalebar_params: ScalebarParams,
160179
legend_params: LegendParams,
161180
) -> None:
181+
if render_params.groups is not None:
182+
if isinstance(render_params.groups, str):
183+
render_params.groups = [render_params.groups]
184+
if not all(isinstance(g, str) for g in render_params.groups):
185+
raise TypeError("All groups must be strings.")
186+
162187
elements = render_params.elements
163188

164189
sdata_filt = sdata.filter_by_coordinate_system(
@@ -178,6 +203,14 @@ def _render_points(
178203
color = [render_params.color] if isinstance(render_params.color, str) else render_params.color
179204
coords.extend(color)
180205

206+
points = points[coords].compute()
207+
# points[color[0]].cat.set_categories(render_params.groups, inplace=True)
208+
if render_params.groups is not None:
209+
points = points[points[color].isin(render_params.groups).values]
210+
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
211+
points = dask.dataframe.from_pandas(points, npartitions=1)
212+
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})
213+
181214
point_df = points[coords].compute()
182215

183216
# we construct an anndata to hack the plotting functions
@@ -204,6 +237,7 @@ def _render_points(
204237
palette=render_params.palette,
205238
na_color=render_params.cmap_params.na_color,
206239
alpha=render_params.alpha,
240+
cmap_params=render_params.cmap_params,
207241
)
208242

209243
# color_source_vector is None when the values aren't categorical
@@ -226,14 +260,19 @@ def _render_points(
226260
if not (
227261
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
228262
):
263+
if color_source_vector is None:
264+
palette = ListedColormap(dict.fromkeys(color_vector))
265+
else:
266+
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))
267+
229268
_ = _decorate_axs(
230269
ax=ax,
231270
cax=cax,
232271
fig_params=fig_params,
233272
adata=adata,
234273
value_to_plot=render_params.color,
235274
color_source_vector=color_source_vector,
236-
palette=render_params.palette,
275+
palette=palette,
237276
alpha=render_params.alpha,
238277
na_color=render_params.cmap_params.na_color,
239278
legend_fontsize=legend_params.legend_fontsize,
@@ -415,6 +454,12 @@ def _render_labels(
415454
) -> None:
416455
elements = render_params.elements
417456

457+
if render_params.groups is not None:
458+
if isinstance(render_params.groups, str):
459+
render_params.groups = [render_params.groups]
460+
if not all(isinstance(g, str) for g in render_params.groups):
461+
raise TypeError("All groups must be strings.")
462+
418463
sdata_filt = sdata.filter_by_coordinate_system(
419464
coordinate_system=coordinate_system,
420465
filter_table=sdata.table is not None,
@@ -441,7 +486,7 @@ def _render_labels(
441486

442487
table = sdata.table[sdata.table.obs[region_key].isin([label_key])]
443488

444-
# get isntance id based on subsetted table
489+
# get instance id based on subsetted table
445490
instance_id = table.obs[instance_key].values
446491

447492
# get color vector (categorical or continuous)
@@ -455,6 +500,7 @@ def _render_labels(
455500
palette=render_params.palette,
456501
na_color=render_params.cmap_params.na_color,
457502
alpha=render_params.fill_alpha,
503+
cmap_params=render_params.cmap_params,
458504
)
459505

460506
if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:

src/spatialdata_plot/pl/render_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class CmapParams:
1919
cmap: Colormap
2020
norm: Normalize
2121
na_color: str | tuple[float, ...] = (0.0, 0.0, 0.0, 0.0)
22+
is_default: bool = True
2223

2324

2425
@dataclass

0 commit comments

Comments
 (0)