Skip to content

Commit f130032

Browse files
authored
Merge branch 'main' into bugfix/issue108-outline_color-doesnt-work
2 parents 2f04d04 + bde6827 commit f130032

24 files changed

+198
-45
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from typing import Any
88

99
import matplotlib.pyplot as plt
10+
import numpy as np
1011
import scanpy as sc
1112
import spatialdata as sd
1213
from anndata import AnnData
1314
from dask.dataframe.core import DataFrame as DaskDataFrame
1415
from geopandas import GeoDataFrame
1516
from matplotlib.axes import Axes
16-
from matplotlib.colors import Colormap, Normalize
17+
from matplotlib.colors import Colormap, ListedColormap, Normalize
1718
from matplotlib.figure import Figure
1819
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
1920
from pandas.api.types import is_categorical_dtype
@@ -33,12 +34,12 @@
3334
)
3435
from spatialdata_plot.pl.utils import (
3536
LegendParams,
36-
Palette_t,
3737
_FontSize,
3838
_FontWeight,
3939
_get_cs_contents,
4040
_get_extent,
4141
_maybe_set_colors,
42+
_mpl_ax_contains_elements,
4243
_multiscale_to_image,
4344
_prepare_cmap_norm,
4445
_prepare_params_plot,
@@ -148,7 +149,7 @@ def render_shapes(
148149
outline_color: tuple[str, str] = ("#000000ff", "#ffffffff"), # black, white
149150
alt_var: str | None = None,
150151
layer: str | None = None,
151-
palette: Palette_t = None,
152+
palette: ListedColormap | str | None = None,
152153
cmap: Colormap | str | None = None,
153154
norm: None | Normalize = None,
154155
na_color: str | tuple[float, ...] | None = "lightgrey",
@@ -230,7 +231,7 @@ def render_points(
230231
color: str | None = None,
231232
groups: str | Sequence[str] | None = None,
232233
size: float = 1.0,
233-
palette: Palette_t = None,
234+
palette: ListedColormap | str | None = None,
234235
cmap: Colormap | str | None = None,
235236
norm: None | Normalize = None,
236237
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
@@ -298,7 +299,7 @@ def render_images(
298299
cmap: Colormap | str | None = None,
299300
norm: None | Normalize = None,
300301
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
301-
palette: Palette_t = None,
302+
palette: ListedColormap | str | None = None,
302303
alpha: float = 1.0,
303304
**kwargs: Any,
304305
) -> sd.SpatialData:
@@ -356,7 +357,7 @@ def render_labels(
356357
outline: bool = False,
357358
alt_var: str | None = None,
358359
layer: str | None = None,
359-
palette: Palette_t = None,
360+
palette: ListedColormap | str | None = None,
360361
cmap: Colormap | str | None = None,
361362
norm: None | Normalize = None,
362363
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
@@ -454,6 +455,7 @@ def show(
454455
fig: Figure | None = None,
455456
title: None | str | Sequence[str] = None,
456457
share_extent: bool = True,
458+
pad_extent: int = 0,
457459
ax: Axes | Sequence[Axes] | None = None,
458460
return_ax: bool = False,
459461
save: None | str | Path = None,
@@ -523,6 +525,14 @@ def show(
523525
# Simplicstic solution: If the images are multiscale, just use the first
524526
sdata = _multiscale_to_image(sdata)
525527

528+
# get original axis extent for later comparison
529+
x_min_orig, x_max_orig = (np.inf, -np.inf)
530+
y_min_orig, y_max_orig = (np.inf, -np.inf)
531+
532+
if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax):
533+
x_min_orig, x_max_orig = ax.get_xlim()
534+
y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left
535+
526536
# handle coordinate system
527537
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
528538
if isinstance(coordinate_systems, str):
@@ -532,12 +542,38 @@ def show(
532542
if cs not in sdata.coordinate_systems:
533543
raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}")
534544

545+
# Check if user specified only certain elements to be plotted
546+
cs_contents = _get_cs_contents(sdata)
547+
elements_to_be_rendered = []
548+
for cmd, params in render_cmds.items():
549+
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
550+
if params.elements is not None:
551+
elements_to_be_rendered += (
552+
[params.elements] if isinstance(params.elements, str) else params.elements
553+
)
554+
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114
555+
if params.elements is not None:
556+
elements_to_be_rendered += (
557+
[params.elements] if isinstance(params.elements, str) else params.elements
558+
)
559+
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114
560+
if params.elements is not None:
561+
elements_to_be_rendered += (
562+
[params.elements] if isinstance(params.elements, str) else params.elements
563+
)
564+
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102
565+
if params.elements is not None:
566+
elements_to_be_rendered += (
567+
[params.elements] if isinstance(params.elements, str) else params.elements
568+
)
569+
535570
extent = _get_extent(
536571
sdata=sdata,
537572
has_images="render_images" in render_cmds,
538573
has_labels="render_labels" in render_cmds,
539574
has_points="render_points" in render_cmds,
540575
has_shapes="render_shapes" in render_cmds,
576+
elements=elements_to_be_rendered,
541577
coordinate_systems=coordinate_systems,
542578
)
543579

@@ -585,7 +621,6 @@ def show(
585621
)
586622

587623
# go through tree
588-
cs_contents = _get_cs_contents(sdata)
589624
for i, cs in enumerate(coordinate_systems):
590625
sdata = self._copy()
591626
# properly transform all elements to the current coordinate system
@@ -693,12 +728,10 @@ def show(
693728
]
694729
):
695730
# If the axis already has limits, only expand them but not overwrite
696-
x_min, x_max = ax.get_xlim()
697-
y_min, y_max = ax.get_ylim()
698-
x_min = min(x_min, extent[cs][0])
699-
x_max = max(x_max, extent[cs][1])
700-
y_min = min(y_min, extent[cs][2])
701-
y_max = max(y_max, extent[cs][3])
731+
x_min = min(x_min_orig, extent[cs][0]) - pad_extent
732+
x_max = max(x_max_orig, extent[cs][1]) + pad_extent
733+
y_min = min(y_min_orig, extent[cs][2]) - pad_extent
734+
y_max = max(y_max_orig, extent[cs][3]) + pad_extent
702735
ax.set_xlim(x_min, x_max)
703736
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
704737

src/spatialdata_plot/pl/render.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,9 @@ def _render_images(
363363
color = render_params.palette
364364

365365
else:
366-
color = _get_colors_for_categorical_obs(img.coords["c"].values.tolist())
366+
color = _get_colors_for_categorical_obs(
367+
img.coords["c"].values.tolist(), palette=render_params.cmap_params.cmap
368+
)
367369

368370
cmaps = _get_linear_colormap([str(c) for c in color[:num_channels]], "k")
369371
img = _normalize(img, clip=True)

src/spatialdata_plot/pl/utils.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial
88
from pathlib import Path
99
from types import MappingProxyType
10-
from typing import Any, Literal, Optional, Union
10+
from typing import Any, Literal
1111

1212
import matplotlib
1313
import matplotlib.pyplot as plt
@@ -46,9 +46,6 @@
4646

4747
from spatialdata_plot.pp.utils import _get_coordinate_system_mapping
4848

49-
Palette_t = Optional[Union[str, ListedColormap]]
50-
_Normalize = Union[Normalize, Sequence[Normalize]]
51-
_SeqStr = Union[str, Sequence[str]]
5249
_FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"]
5350
_FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
5451

@@ -90,7 +87,7 @@ def _prepare_params_plot(
9087
frameon: bool | None = None,
9188
# this is passed at `render_*`
9289
cmap: Colormap | str | None = None,
93-
norm: _Normalize | None = None,
90+
norm: Normalize | Sequence[Normalize] | None = None,
9491
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
9592
vmin: float | None = None,
9693
vmax: float | None = None,
@@ -178,11 +175,12 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
178175

179176
def _get_extent(
180177
sdata: sd.SpatialData,
181-
coordinate_systems: None | str | Sequence[str] = None,
178+
coordinate_systems: Sequence[str] | str | None = None,
182179
has_images: bool = True,
183180
has_labels: bool = True,
184181
has_points: bool = True,
185182
has_shapes: bool = True,
183+
elements: Iterable[Any] | None = None,
186184
share_extent: bool = False,
187185
) -> dict[str, tuple[int, int, int, int]]:
188186
"""Return the extent of all elements in their respective coordinate systems.
@@ -191,16 +189,18 @@ def _get_extent(
191189
----------
192190
sdata
193191
The sd.SpatialData object to retrieve the extent from
194-
images
192+
has_images
195193
Flag indicating whether to consider images when calculating the extent
196-
labels
194+
has_labels
197195
Flag indicating whether to consider labels when calculating the extent
198-
points
196+
has_points
199197
Flag indicating whether to consider points when calculating the extent
200-
shapes
201-
Flag indicating whether to consider shaoes when calculating the extent
202-
img_transformations
203-
List of transformations already applied to the images
198+
has_shapes
199+
Flag indicating whether to consider shapes when calculating the extent
200+
elements
201+
Optional list of element names to be considered. When None, all are used.
202+
share_extent
203+
Flag indicating whether to use the same extent for all coordinate systems
204204
205205
Returns
206206
-------
@@ -212,6 +212,12 @@ def _get_extent(
212212
cs_mapping = _get_coordinate_system_mapping(sdata)
213213
cs_contents = _get_cs_contents(sdata)
214214

215+
if elements is None: # to shut up ruff
216+
elements = []
217+
218+
if not isinstance(elements, list):
219+
raise ValueError(f"Invalid type of `elements`: {type(elements)}, expected `list`.")
220+
215221
if coordinate_systems is not None:
216222
if isinstance(coordinate_systems, str):
217223
coordinate_systems = [coordinate_systems]
@@ -220,6 +226,8 @@ def _get_extent(
220226

221227
for cs_name, element_ids in cs_mapping.items():
222228
extent[cs_name] = {}
229+
if len(elements) > 0:
230+
element_ids = [e for e in element_ids if e in elements]
223231

224232
def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[int]:
225233
tmp = element.copy()
@@ -449,7 +457,7 @@ class CmapParams:
449457

450458
def _prepare_cmap_norm(
451459
cmap: Colormap | str | None = None,
452-
norm: _Normalize | None = None,
460+
norm: Normalize | Sequence[Normalize] | None = None,
453461
na_color: str | tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
454462
vmin: float | None = None,
455463
vmax: float | None = None,
@@ -627,7 +635,11 @@ def _normalize(
627635
return norm
628636

629637

630-
def _get_colors_for_categorical_obs(categories: Sequence[str | int], palette: Palette_t = None) -> list[str]:
638+
def _get_colors_for_categorical_obs(
639+
categories: Sequence[str | int],
640+
palette: ListedColormap | str | None = None,
641+
alpha: float = 1.0,
642+
) -> list[str]:
631643
"""
632644
Return a list of colors for a categorical observation.
633645
@@ -644,27 +656,40 @@ def _get_colors_for_categorical_obs(categories: Sequence[str | int], palette: Pa
644656
-------
645657
None
646658
"""
647-
length = len(categories)
659+
len_cat = len(categories)
648660

649661
# check if default matplotlib palette has enough colors
650662
if palette is None:
651-
if len(rcParams["axes.prop_cycle"].by_key()["color"]) >= length:
663+
if len(rcParams["axes.prop_cycle"].by_key()["color"]) >= len_cat:
652664
cc = rcParams["axes.prop_cycle"]()
653-
palette = [next(cc)["color"] for _ in range(length)]
665+
palette = [next(cc)["color"] for _ in range(len_cat)]
654666
else:
655-
if length <= 20:
667+
if len_cat <= 20:
656668
palette = default_20
657-
elif length <= 28:
669+
elif len_cat <= 28:
658670
palette = default_28
659-
elif length <= len(default_102): # 103 colors
671+
elif len_cat <= len(default_102): # 103 colors
660672
palette = default_102
661673
else:
662-
palette = ["grey" for _ in range(length)]
674+
palette = ["grey" for _ in range(len_cat)]
663675
logging.info(
664676
"input has more than 103 categories. Uniform " "'grey' color will be used for all categories."
665677
)
666678

667-
return palette[:length] # type: ignore[return-value]
679+
# otherwise, single chanels turn out grey
680+
color_idx = np.linspace(0, 1, len_cat) if len_cat > 1 else [0.7]
681+
682+
if isinstance(palette, str):
683+
cmap = plt.get_cmap(palette)
684+
palette = [to_hex(x) for x in cmap(color_idx, alpha=alpha)]
685+
elif isinstance(palette, ListedColormap):
686+
palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)]
687+
elif isinstance(palette, LinearSegmentedColormap):
688+
palette = [to_hex(palette(x, alpha=alpha)) for x in [color_idx]]
689+
else:
690+
raise TypeError(f"Palette is {type(palette)} but should be string or `ListedColormap`.")
691+
692+
return palette[:len_cat] # type: ignore[return-value]
668693

669694

670695
def _set_color_source_vec(
@@ -673,8 +698,8 @@ def _set_color_source_vec(
673698
use_raw: bool | None = None,
674699
alt_var: str | None = None,
675700
layer: str | None = None,
676-
groups: _SeqStr | None = None,
677-
palette: Palette_t = None,
701+
groups: Sequence[str] | str | None = None,
702+
palette: ListedColormap | str | None = None,
678703
na_color: str | tuple[float, ...] | None = None,
679704
alpha: float = 1.0,
680705
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
@@ -769,7 +794,7 @@ def _get_palette(
769794
categories: Sequence[Any],
770795
adata: AnnData | None = None,
771796
cluster_key: None | str = None,
772-
palette: Palette_t = None,
797+
palette: ListedColormap | str | None = None,
773798
alpha: float = 1.0,
774799
) -> Mapping[str, str] | None:
775800
if adata is not None and palette is None:
@@ -845,7 +870,7 @@ def _decorate_axs(
845870
adata: AnnData,
846871
value_to_plot: str | None,
847872
color_source_vector: pd.Series[CategoricalDtype],
848-
palette: Palette_t = None,
873+
palette: ListedColormap | str | None = None,
849874
alpha: float = 1.0,
850875
na_color: str | tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
851876
legend_fontsize: int | float | _FontSize | None = None,
@@ -1127,3 +1152,17 @@ def _robust_transform(element: Any, cs: str) -> Any:
11271152
raise ValueError("Unable to transform element.") from e
11281153

11291154
return element
1155+
1156+
1157+
def _mpl_ax_contains_elements(ax: Axes) -> bool:
1158+
"""Check if any objects have been plotted on the axes object.
1159+
1160+
While extracting the extent, we need to know if the axes object has just been
1161+
initialised and therefore has extent (0, 1), (0,1) or if it has been plotted on
1162+
and therefore has a different extent.
1163+
1164+
Based on: https://stackoverflow.com/a/71966295
1165+
"""
1166+
return (
1167+
len(ax.lines) > 0 or len(ax.collections) > 0 or len(ax.images) > 0 or len(ax.patches) > 0 or len(ax.tables) > 0
1168+
)
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
14.1 KB
Loading

tests/_images/Labels_labels.png

-14.2 KB
Binary file not shown.
7.44 KB
Loading
-467 Bytes
Loading
-671 Bytes
Loading
21.4 KB
Loading

0 commit comments

Comments
 (0)