Skip to content

Commit d98ae44

Browse files
committed
update
1 parent 582c837 commit d98ae44

File tree

5 files changed

+420
-173
lines changed

5 files changed

+420
-173
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 66 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from __future__ import annotations
12
from collections import OrderedDict
2-
from typing import Callable, Optional, Union
3+
from typing import Callable, Optional, Union, Any, Sequence
34

45
import geopandas as gpd
56
import matplotlib
@@ -16,28 +17,49 @@
1617
from pandas.api.types import is_categorical_dtype
1718
from spatial_image import SpatialImage
1819
from spatialdata import transform
19-
from spatialdata.models import Image2DModel
20+
from spatialdata.models import Image2DModel, TableModel
2021
from spatialdata.transformations import get_transformation
21-
22+
from matplotlib.colors import Colormap
2223
from spatialdata_plot._accessor import register_spatial_data_accessor
2324
from spatialdata_plot.pp.utils import (
2425
_get_instance_key,
2526
_get_region_key,
26-
_verify_plotting_tree_exists,
27+
_verify_plotting_tree,
2728
)
28-
from spatialdata_plot.render import (
29+
from spatialdata_plot.pl.render import (
2930
_render_channels,
3031
_render_images,
3132
_render_labels,
3233
_render_points,
3334
_render_shapes,
3435
)
35-
from spatialdata_plot.utils import (
36+
from spatialdata_plot.pl.utils import (
3637
_get_hex_colors_for_continous_values,
3738
_get_random_hex_colors,
3839
_get_subplots,
3940
_maybe_set_colors,
41+
Palette_t,
42+
CmapParams,
43+
_prepare_cmap_norm,
4044
)
45+
from matplotlib.colors import ListedColormap, Normalize, to_rgb
46+
from dataclasses import dataclass
47+
48+
49+
@dataclass
50+
class LabelsRenderParams:
51+
"""Labels render parameters.."""
52+
53+
region: str | None = None
54+
color: str | None = None
55+
groups: str | Sequence[str] | None = None
56+
contour_px: int | None = None
57+
outline: bool = False
58+
alt_var: str | None = None
59+
layer: str | None = None
60+
cmap_params: CmapParams = None
61+
palette: Palette_t = None
62+
alpha: float = 1.0
4163

4264

4365
@register_spatial_data_accessor("pl")
@@ -196,7 +218,7 @@ def render_shapes(
196218
raise ValueError(f"Column '{color_key}' not found in data.")
197219

198220
sdata = self._copy()
199-
sdata = _verify_plotting_tree_exists(sdata)
221+
sdata = _verify_plotting_tree(sdata)
200222
n_steps = len(sdata.plotting_tree.keys())
201223
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = {
202224
"palette": palette,
@@ -246,7 +268,7 @@ def render_points(
246268
raise TypeError("When giving a 'color_key', it must be of type 'str'.")
247269

248270
sdata = self._copy()
249-
sdata = _verify_plotting_tree_exists(sdata)
271+
sdata = _verify_plotting_tree(sdata)
250272
n_steps = len(sdata.plotting_tree.keys())
251273
sdata.plotting_tree[f"{n_steps+1}_render_points"] = {
252274
"palette": palette,
@@ -280,7 +302,7 @@ def render_images(
280302
281303
"""
282304
sdata = self._copy()
283-
sdata = _verify_plotting_tree_exists(sdata)
305+
sdata = _verify_plotting_tree(sdata)
284306
n_steps = len(sdata.plotting_tree.keys())
285307
sdata.plotting_tree[f"{n_steps+1}_render_images"] = {
286308
"palette": palette,
@@ -364,7 +386,7 @@ def render_channels(
364386
raise ValueError("Percentile parameters must satisfy pmin < pmax.")
365387

366388
sdata = self._copy()
367-
sdata = _verify_plotting_tree_exists(sdata)
389+
sdata = _verify_plotting_tree(sdata)
368390
n_steps = len(sdata.plotting_tree.keys())
369391

370392
sdata.plotting_tree[f"{n_steps+1}_render_channels"] = {
@@ -381,15 +403,19 @@ def render_channels(
381403

382404
def render_labels(
383405
self,
384-
instance_key: Optional[Union[str, None]] = None,
385-
color_key: Optional[Union[str, None]] = None,
386-
border_alpha: float = 1.0,
387-
border_color: Optional[Union[str, None]] = None,
388-
fill_alpha: float = 0.5,
389-
fill_color: Optional[Union[str, None]] = None,
390-
mode: str = "thick",
391-
palette: Optional[Union[str, list[str]]] = None,
392-
add_legend: bool = True,
406+
region: str | Sequence[str] | None = None,
407+
color: str | None = None,
408+
groups: str | Sequence[str] | None = None,
409+
contour_px: int | None = None,
410+
outline: bool = False,
411+
alt_var: str | None = None,
412+
layer: str | None = None,
413+
palette: Palette_t = None,
414+
cmap: Colormap | str | None = None,
415+
norm: Optional[Normalize] = None,
416+
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
417+
alpha: float = 1.0,
418+
**kwargs: Any,
393419
) -> sd.SpatialData:
394420
"""Render the labels contained in the given sd.SpatialData object
395421
@@ -399,24 +425,6 @@ def render_labels(
399425
sd.SpatialData
400426
instance_key : str
401427
The name of the column in the table that identifies individual labels
402-
color_key : str or None, optional (default: None)
403-
The name of the column in the table to use for coloring labels.
404-
border_alpha : float, optional (default: 1.0)
405-
The alpha value of the label border. Must be between 0 and 1.
406-
border_color : str or None, optional (default: None)
407-
The color of the border of the labels.
408-
fill_alpha : float, optional (default: 0.5)
409-
The alpha value of the fill of the labels. Must be between 0 and 1.
410-
fill_color : str or None, optional (default: None)
411-
The color of the fill of the labels.
412-
mode : str, optional (default: 'thick')
413-
The rendering mode of the labels. Must be one of 'thick', 'inner',
414-
'outer', or 'subpixel'.
415-
palette : str, list or None, optional (default: None)
416-
The color palette to use when coloring cells. If None, a default
417-
palette will be used.
418-
add_legend : bool, optional (default: True)
419-
Whether to add a legend to the plot.
420428
421429
Returns
422430
-------
@@ -441,66 +449,30 @@ def render_labels(
441449
alpha, color, and rendering mode of the labels, as well as whether to add a
442450
legend to the plot.
443451
"""
444-
if instance_key is not None:
445-
if not isinstance(instance_key, str):
446-
raise TypeError("Parameter 'instance_key' must be a string.")
447-
448-
if instance_key not in self._sdata.table.obs:
449-
raise ValueError(f"The provided instance_key '{instance_key}' is not a valid table column.")
450-
else:
451-
instance_key = self._sdata.table.uns["spatialdata_attrs"]["instance_key"]
452-
453-
if color_key is not None:
454-
if not isinstance(color_key, (str, type(None))):
455-
raise TypeError("Parameter 'color_key' must be a string.")
456-
457-
if color_key not in self._sdata.table.obs.columns and color_key not in self._sdata.table.var_names:
458-
raise ValueError(f"The provided color_key '{color_key}' is not a valid table column.")
459-
460-
if not isinstance(border_alpha, (int, float)):
461-
raise TypeError("Parameter 'border_alpha' must be a float.")
462-
463-
if not (border_alpha <= 1 and border_alpha >= 0):
464-
raise ValueError("Parameter 'border_alpha' must be between 0 and 1.")
465452

466-
if border_color is not None:
467-
if not isinstance(color_key, (str, type(None))):
468-
raise TypeError("If specified, parameter 'border_color' must be a string.")
469-
470-
if not isinstance(fill_alpha, (int, float)):
471-
raise TypeError("Parameter 'fill_alpha' must be a float.")
472-
473-
if not (fill_alpha <= 1 and fill_alpha >= 0):
474-
raise ValueError("Parameter 'fill_alpha' must be between 0 and 1.")
475-
476-
if fill_color is not None:
477-
if not isinstance(fill_color, (str, type(None))):
478-
raise TypeError("If specified, parameter 'fill_color' must be a string.")
479-
480-
valid_modes = ["thick", "inner", "outer", "subpixel"]
481-
if not isinstance(mode, str):
482-
raise TypeError("Parameter 'mode' must be a string.")
483-
484-
if mode not in valid_modes:
485-
raise ValueError("Parameter 'mode' must be one of 'thick', 'inner', 'outer', 'subpixel'.")
486-
487-
if not isinstance(add_legend, bool):
488-
raise TypeError("Parameter 'add_legend' must be a boolean.")
453+
if (
454+
color is not None
455+
and color not in self._sdata.table.obs.columns
456+
and color not in self._sdata.table.var_names
457+
):
458+
raise ValueError(f"'{color}' is not a valid table column.")
489459

490460
sdata = self._copy()
491-
sdata = _verify_plotting_tree_exists(sdata)
461+
sdata = _verify_plotting_tree(sdata)
492462
n_steps = len(sdata.plotting_tree.keys())
493-
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = {
494-
"instance_key": instance_key,
495-
"color_key": color_key,
496-
"border_alpha": border_alpha,
497-
"border_color": border_color,
498-
"fill_alpha": fill_alpha,
499-
"fill_color": fill_color,
500-
"mode": mode,
501-
"palette": palette,
502-
"add_legend": add_legend,
503-
}
463+
cmap_params = _prepare_cmap_norm(cmap=cmap, norm=norm, na_color=na_color, **kwargs)
464+
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
465+
region=region,
466+
color=color,
467+
groups=groups,
468+
contour_px=contour_px,
469+
outline=outline,
470+
alt_var=alt_var,
471+
layer=layer,
472+
cmap_params=cmap_params,
473+
palette=palette,
474+
alpha=alpha,
475+
)
504476

505477
return sdata
506478

@@ -860,7 +832,7 @@ def show(
860832

861833
for idx, ax in enumerate(axs):
862834
key = list(sdata.labels.keys())[idx]
863-
_render_labels(sdata=sdata, params=params, key=key, ax=ax, extent=extent)
835+
_render_labels(sdata=sdata, render_params=params, key=key, ax=ax, extent=extent)
864836

865837
else:
866838
raise NotImplementedError(f"Command '{cmd}' is not supported.")

0 commit comments

Comments
 (0)