1
+ from __future__ import annotations
1
2
from collections import OrderedDict
2
- from typing import Callable , Optional , Union
3
+ from typing import Callable , Optional , Union , Any , Sequence
3
4
4
5
import geopandas as gpd
5
6
import matplotlib
16
17
from pandas .api .types import is_categorical_dtype
17
18
from spatial_image import SpatialImage
18
19
from spatialdata import transform
19
- from spatialdata .models import Image2DModel
20
+ from spatialdata .models import Image2DModel , TableModel
20
21
from spatialdata .transformations import get_transformation
21
-
22
+ from matplotlib . colors import Colormap
22
23
from spatialdata_plot ._accessor import register_spatial_data_accessor
23
24
from spatialdata_plot .pp .utils import (
24
25
_get_instance_key ,
25
26
_get_region_key ,
26
- _verify_plotting_tree_exists ,
27
+ _verify_plotting_tree ,
27
28
)
28
- from spatialdata_plot .render import (
29
+ from spatialdata_plot .pl . render import (
29
30
_render_channels ,
30
31
_render_images ,
31
32
_render_labels ,
32
33
_render_points ,
33
34
_render_shapes ,
34
35
)
35
- from spatialdata_plot .utils import (
36
+ from spatialdata_plot .pl . utils import (
36
37
_get_hex_colors_for_continous_values ,
37
38
_get_random_hex_colors ,
38
39
_get_subplots ,
39
40
_maybe_set_colors ,
41
+ Palette_t ,
42
+ CmapParams ,
43
+ _prepare_cmap_norm ,
40
44
)
45
+ from matplotlib .colors import ListedColormap , Normalize , to_rgb
46
+ from dataclasses import dataclass
47
+
48
+
49
+ @dataclass
50
+ class LabelsRenderParams :
51
+ """Labels render parameters.."""
52
+
53
+ region : str | None = None
54
+ color : str | None = None
55
+ groups : str | Sequence [str ] | None = None
56
+ contour_px : int | None = None
57
+ outline : bool = False
58
+ alt_var : str | None = None
59
+ layer : str | None = None
60
+ cmap_params : CmapParams = None
61
+ palette : Palette_t = None
62
+ alpha : float = 1.0
41
63
42
64
43
65
@register_spatial_data_accessor ("pl" )
@@ -196,7 +218,7 @@ def render_shapes(
196
218
raise ValueError (f"Column '{ color_key } ' not found in data." )
197
219
198
220
sdata = self ._copy ()
199
- sdata = _verify_plotting_tree_exists (sdata )
221
+ sdata = _verify_plotting_tree (sdata )
200
222
n_steps = len (sdata .plotting_tree .keys ())
201
223
sdata .plotting_tree [f"{ n_steps + 1 } _render_shapes" ] = {
202
224
"palette" : palette ,
@@ -246,7 +268,7 @@ def render_points(
246
268
raise TypeError ("When giving a 'color_key', it must be of type 'str'." )
247
269
248
270
sdata = self ._copy ()
249
- sdata = _verify_plotting_tree_exists (sdata )
271
+ sdata = _verify_plotting_tree (sdata )
250
272
n_steps = len (sdata .plotting_tree .keys ())
251
273
sdata .plotting_tree [f"{ n_steps + 1 } _render_points" ] = {
252
274
"palette" : palette ,
@@ -280,7 +302,7 @@ def render_images(
280
302
281
303
"""
282
304
sdata = self ._copy ()
283
- sdata = _verify_plotting_tree_exists (sdata )
305
+ sdata = _verify_plotting_tree (sdata )
284
306
n_steps = len (sdata .plotting_tree .keys ())
285
307
sdata .plotting_tree [f"{ n_steps + 1 } _render_images" ] = {
286
308
"palette" : palette ,
@@ -364,7 +386,7 @@ def render_channels(
364
386
raise ValueError ("Percentile parameters must satisfy pmin < pmax." )
365
387
366
388
sdata = self ._copy ()
367
- sdata = _verify_plotting_tree_exists (sdata )
389
+ sdata = _verify_plotting_tree (sdata )
368
390
n_steps = len (sdata .plotting_tree .keys ())
369
391
370
392
sdata .plotting_tree [f"{ n_steps + 1 } _render_channels" ] = {
@@ -381,15 +403,19 @@ def render_channels(
381
403
382
404
def render_labels (
383
405
self ,
384
- instance_key : Optional [Union [str , None ]] = None ,
385
- color_key : Optional [Union [str , None ]] = None ,
386
- border_alpha : float = 1.0 ,
387
- border_color : Optional [Union [str , None ]] = None ,
388
- fill_alpha : float = 0.5 ,
389
- fill_color : Optional [Union [str , None ]] = None ,
390
- mode : str = "thick" ,
391
- palette : Optional [Union [str , list [str ]]] = None ,
392
- add_legend : bool = True ,
406
+ region : str | Sequence [str ] | None = None ,
407
+ color : str | None = None ,
408
+ groups : str | Sequence [str ] | None = None ,
409
+ contour_px : int | None = None ,
410
+ outline : bool = False ,
411
+ alt_var : str | None = None ,
412
+ layer : str | None = None ,
413
+ palette : Palette_t = None ,
414
+ cmap : Colormap | str | None = None ,
415
+ norm : Optional [Normalize ] = None ,
416
+ na_color : str | tuple [float , ...] | None = (0.0 , 0.0 , 0.0 , 0.0 ),
417
+ alpha : float = 1.0 ,
418
+ ** kwargs : Any ,
393
419
) -> sd .SpatialData :
394
420
"""Render the labels contained in the given sd.SpatialData object
395
421
@@ -399,24 +425,6 @@ def render_labels(
399
425
sd.SpatialData
400
426
instance_key : str
401
427
The name of the column in the table that identifies individual labels
402
- color_key : str or None, optional (default: None)
403
- The name of the column in the table to use for coloring labels.
404
- border_alpha : float, optional (default: 1.0)
405
- The alpha value of the label border. Must be between 0 and 1.
406
- border_color : str or None, optional (default: None)
407
- The color of the border of the labels.
408
- fill_alpha : float, optional (default: 0.5)
409
- The alpha value of the fill of the labels. Must be between 0 and 1.
410
- fill_color : str or None, optional (default: None)
411
- The color of the fill of the labels.
412
- mode : str, optional (default: 'thick')
413
- The rendering mode of the labels. Must be one of 'thick', 'inner',
414
- 'outer', or 'subpixel'.
415
- palette : str, list or None, optional (default: None)
416
- The color palette to use when coloring cells. If None, a default
417
- palette will be used.
418
- add_legend : bool, optional (default: True)
419
- Whether to add a legend to the plot.
420
428
421
429
Returns
422
430
-------
@@ -441,66 +449,30 @@ def render_labels(
441
449
alpha, color, and rendering mode of the labels, as well as whether to add a
442
450
legend to the plot.
443
451
"""
444
- if instance_key is not None :
445
- if not isinstance (instance_key , str ):
446
- raise TypeError ("Parameter 'instance_key' must be a string." )
447
-
448
- if instance_key not in self ._sdata .table .obs :
449
- raise ValueError (f"The provided instance_key '{ instance_key } ' is not a valid table column." )
450
- else :
451
- instance_key = self ._sdata .table .uns ["spatialdata_attrs" ]["instance_key" ]
452
-
453
- if color_key is not None :
454
- if not isinstance (color_key , (str , type (None ))):
455
- raise TypeError ("Parameter 'color_key' must be a string." )
456
-
457
- if color_key not in self ._sdata .table .obs .columns and color_key not in self ._sdata .table .var_names :
458
- raise ValueError (f"The provided color_key '{ color_key } ' is not a valid table column." )
459
-
460
- if not isinstance (border_alpha , (int , float )):
461
- raise TypeError ("Parameter 'border_alpha' must be a float." )
462
-
463
- if not (border_alpha <= 1 and border_alpha >= 0 ):
464
- raise ValueError ("Parameter 'border_alpha' must be between 0 and 1." )
465
452
466
- if border_color is not None :
467
- if not isinstance (color_key , (str , type (None ))):
468
- raise TypeError ("If specified, parameter 'border_color' must be a string." )
469
-
470
- if not isinstance (fill_alpha , (int , float )):
471
- raise TypeError ("Parameter 'fill_alpha' must be a float." )
472
-
473
- if not (fill_alpha <= 1 and fill_alpha >= 0 ):
474
- raise ValueError ("Parameter 'fill_alpha' must be between 0 and 1." )
475
-
476
- if fill_color is not None :
477
- if not isinstance (fill_color , (str , type (None ))):
478
- raise TypeError ("If specified, parameter 'fill_color' must be a string." )
479
-
480
- valid_modes = ["thick" , "inner" , "outer" , "subpixel" ]
481
- if not isinstance (mode , str ):
482
- raise TypeError ("Parameter 'mode' must be a string." )
483
-
484
- if mode not in valid_modes :
485
- raise ValueError ("Parameter 'mode' must be one of 'thick', 'inner', 'outer', 'subpixel'." )
486
-
487
- if not isinstance (add_legend , bool ):
488
- raise TypeError ("Parameter 'add_legend' must be a boolean." )
453
+ if (
454
+ color is not None
455
+ and color not in self ._sdata .table .obs .columns
456
+ and color not in self ._sdata .table .var_names
457
+ ):
458
+ raise ValueError (f"'{ color } ' is not a valid table column." )
489
459
490
460
sdata = self ._copy ()
491
- sdata = _verify_plotting_tree_exists (sdata )
461
+ sdata = _verify_plotting_tree (sdata )
492
462
n_steps = len (sdata .plotting_tree .keys ())
493
- sdata .plotting_tree [f"{ n_steps + 1 } _render_labels" ] = {
494
- "instance_key" : instance_key ,
495
- "color_key" : color_key ,
496
- "border_alpha" : border_alpha ,
497
- "border_color" : border_color ,
498
- "fill_alpha" : fill_alpha ,
499
- "fill_color" : fill_color ,
500
- "mode" : mode ,
501
- "palette" : palette ,
502
- "add_legend" : add_legend ,
503
- }
463
+ cmap_params = _prepare_cmap_norm (cmap = cmap , norm = norm , na_color = na_color , ** kwargs )
464
+ sdata .plotting_tree [f"{ n_steps + 1 } _render_labels" ] = LabelsRenderParams (
465
+ region = region ,
466
+ color = color ,
467
+ groups = groups ,
468
+ contour_px = contour_px ,
469
+ outline = outline ,
470
+ alt_var = alt_var ,
471
+ layer = layer ,
472
+ cmap_params = cmap_params ,
473
+ palette = palette ,
474
+ alpha = alpha ,
475
+ )
504
476
505
477
return sdata
506
478
@@ -860,7 +832,7 @@ def show(
860
832
861
833
for idx , ax in enumerate (axs ):
862
834
key = list (sdata .labels .keys ())[idx ]
863
- _render_labels (sdata = sdata , params = params , key = key , ax = ax , extent = extent )
835
+ _render_labels (sdata = sdata , render_params = params , key = key , ax = ax , extent = extent )
864
836
865
837
else :
866
838
raise NotImplementedError (f"Command '{ cmd } ' is not supported." )
0 commit comments