7
7
from typing import Any
8
8
9
9
import matplotlib .pyplot as plt
10
+ import numpy as np
10
11
import scanpy as sc
11
12
import spatialdata as sd
12
13
from anndata import AnnData
13
14
from dask .dataframe .core import DataFrame as DaskDataFrame
14
15
from geopandas import GeoDataFrame
15
16
from matplotlib .axes import Axes
16
- from matplotlib .colors import Colormap , Normalize
17
+ from matplotlib .colors import Colormap , ListedColormap , Normalize
17
18
from matplotlib .figure import Figure
18
19
from multiscale_spatial_image .multiscale_spatial_image import MultiscaleSpatialImage
19
20
from pandas .api .types import is_categorical_dtype
32
33
_render_shapes ,
33
34
)
34
35
from spatialdata_plot .pl .utils import (
36
+ CmapParams ,
35
37
LegendParams ,
36
- Palette_t ,
37
38
_FontSize ,
38
39
_FontWeight ,
39
40
_get_cs_contents ,
40
41
_get_extent ,
41
42
_maybe_set_colors ,
42
- _multiscale_to_image ,
43
+ _mpl_ax_contains_elements ,
43
44
_prepare_cmap_norm ,
44
45
_prepare_params_plot ,
45
46
_robust_transform ,
@@ -144,11 +145,11 @@ def render_shapes(
144
145
groups : str | Sequence [str ] | None = None ,
145
146
size : float = 1.0 ,
146
147
outline : bool = False ,
147
- outline_width : tuple [ float , float ] = ( 0.3 , 0.05 ) ,
148
- outline_color : tuple [ str , str ] = ( "#000000ff" , "#ffffffff" ), # black, white
148
+ outline_width : float = 1.5 ,
149
+ outline_color : str | list [ float ] = "#000000ff" ,
149
150
alt_var : str | None = None ,
150
151
layer : str | None = None ,
151
- palette : Palette_t = None ,
152
+ palette : ListedColormap | str | None = None ,
152
153
cmap : Colormap | str | None = None ,
153
154
norm : None | Normalize = None ,
154
155
na_color : str | tuple [float , ...] | None = "lightgrey" ,
@@ -194,6 +195,11 @@ def render_shapes(
194
195
kwargs
195
196
Additional arguments to be passed to cmap and norm.
196
197
198
+ Notes
199
+ -----
200
+ Empty geometries will be removed at the time of plotting.
201
+ An ``outline_width`` of 0.0 leads to no border being plotted.
202
+
197
203
Returns
198
204
-------
199
205
None
@@ -230,7 +236,7 @@ def render_points(
230
236
color : str | None = None ,
231
237
groups : str | Sequence [str ] | None = None ,
232
238
size : float = 1.0 ,
233
- palette : Palette_t = None ,
239
+ palette : ListedColormap | str | None = None ,
234
240
cmap : Colormap | str | None = None ,
235
241
norm : None | Normalize = None ,
236
242
na_color : str | tuple [float , ...] | None = (0.0 , 0.0 , 0.0 , 0.0 ),
@@ -295,11 +301,12 @@ def render_images(
295
301
self ,
296
302
elements : str | list [str ] | None = None ,
297
303
channel : list [str ] | list [int ] | int | str | None = None ,
298
- cmap : Colormap | str | None = None ,
304
+ cmap : list [ Colormap ] | list [ str ] | Colormap | str | None = None ,
299
305
norm : None | Normalize = None ,
300
306
na_color : str | tuple [float , ...] | None = (0.0 , 0.0 , 0.0 , 0.0 ),
301
- palette : Palette_t = None ,
307
+ palette : ListedColormap | str | None = None ,
302
308
alpha : float = 1.0 ,
309
+ quantiles_for_norm : tuple [float | None , float | None ] = (3.0 , 99.8 ), # defaults from CSBDeep
303
310
** kwargs : Any ,
304
311
) -> sd .SpatialData :
305
312
"""
@@ -320,6 +327,8 @@ def render_images(
320
327
Color to be used for NAs values, if present.
321
328
alpha
322
329
Alpha value for the shapes.
330
+ quantiles_for_norm
331
+ Tuple of (pmin, pmax) which will be used for quantile normalization.
323
332
kwargs
324
333
Additional arguments to be passed to cmap and norm.
325
334
@@ -331,18 +340,36 @@ def render_images(
331
340
sdata = _verify_plotting_tree (sdata )
332
341
n_steps = len (sdata .plotting_tree .keys ())
333
342
334
- cmap_params = _prepare_cmap_norm (
335
- cmap = cmap ,
336
- norm = norm ,
337
- na_color = na_color , # type: ignore[arg-type]
338
- ** kwargs ,
339
- )
343
+ if channel is None and cmap is None :
344
+ cmap = "brg"
345
+
346
+ cmap_params : list [CmapParams ] | CmapParams
347
+ if isinstance (cmap , list ):
348
+ cmap_params = [
349
+ _prepare_cmap_norm (
350
+ cmap = c ,
351
+ norm = norm ,
352
+ na_color = na_color , # type: ignore[arg-type]
353
+ ** kwargs ,
354
+ )
355
+ for c in cmap
356
+ ]
357
+
358
+ else :
359
+ cmap_params = _prepare_cmap_norm (
360
+ cmap = cmap ,
361
+ norm = norm ,
362
+ na_color = na_color , # type: ignore[arg-type]
363
+ ** kwargs ,
364
+ )
365
+
340
366
sdata .plotting_tree [f"{ n_steps + 1 } _render_images" ] = ImageRenderParams (
341
367
elements = elements ,
342
368
channel = channel ,
343
369
cmap_params = cmap_params ,
344
370
palette = palette ,
345
371
alpha = alpha ,
372
+ quantiles_for_norm = quantiles_for_norm ,
346
373
)
347
374
348
375
return sdata
@@ -356,7 +383,7 @@ def render_labels(
356
383
outline : bool = False ,
357
384
alt_var : str | None = None ,
358
385
layer : str | None = None ,
359
- palette : Palette_t = None ,
386
+ palette : ListedColormap | str | None = None ,
360
387
cmap : Colormap | str | None = None ,
361
388
norm : None | Normalize = None ,
362
389
na_color : str | tuple [float , ...] | None = (0.0 , 0.0 , 0.0 , 0.0 ),
@@ -454,6 +481,7 @@ def show(
454
481
fig : Figure | None = None ,
455
482
title : None | str | Sequence [str ] = None ,
456
483
share_extent : bool = True ,
484
+ pad_extent : int = 0 ,
457
485
ax : Axes | Sequence [Axes ] | None = None ,
458
486
return_ax : bool = False ,
459
487
save : None | str | Path = None ,
@@ -511,7 +539,7 @@ def show(
511
539
render_cmds [cmd ] = params
512
540
513
541
if len (render_cmds .keys ()) == 0 :
514
- raise TypeError ("Please specify what to plot using the 'render_*' functions before calling 'imshow()." )
542
+ raise TypeError ("Please specify what to plot using the 'render_*' functions before calling 'imshow()' ." )
515
543
516
544
if title is not None :
517
545
if isinstance (title , str ):
@@ -520,8 +548,13 @@ def show(
520
548
if not all (isinstance (t , str ) for t in title ):
521
549
raise TypeError ("All titles must be strings." )
522
550
523
- # Simplicstic solution: If the images are multiscale, just use the first
524
- sdata = _multiscale_to_image (sdata )
551
+ # get original axis extent for later comparison
552
+ x_min_orig , x_max_orig = (np .inf , - np .inf )
553
+ y_min_orig , y_max_orig = (np .inf , - np .inf )
554
+
555
+ if isinstance (ax , Axes ) and _mpl_ax_contains_elements (ax ):
556
+ x_min_orig , x_max_orig = ax .get_xlim ()
557
+ y_max_orig , y_min_orig = ax .get_ylim () # (0, 0) is top-left
525
558
526
559
# handle coordinate system
527
560
coordinate_systems = sdata .coordinate_systems if coordinate_systems is None else coordinate_systems
@@ -532,12 +565,38 @@ def show(
532
565
if cs not in sdata .coordinate_systems :
533
566
raise ValueError (f"Unknown coordinate system '{ cs } ', valid choices are: { sdata .coordinate_systems } " )
534
567
568
+ # Check if user specified only certain elements to be plotted
569
+ cs_contents = _get_cs_contents (sdata )
570
+ elements_to_be_rendered = []
571
+ for cmd , params in render_cmds .items ():
572
+ if cmd == "render_images" and cs_contents .query (f"cs == '{ cs } '" )["has_images" ][0 ]: # noqa: SIM114
573
+ if params .elements is not None :
574
+ elements_to_be_rendered += (
575
+ [params .elements ] if isinstance (params .elements , str ) else params .elements
576
+ )
577
+ elif cmd == "render_shapes" and cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ]: # noqa: SIM114
578
+ if params .elements is not None :
579
+ elements_to_be_rendered += (
580
+ [params .elements ] if isinstance (params .elements , str ) else params .elements
581
+ )
582
+ elif cmd == "render_points" and cs_contents .query (f"cs == '{ cs } '" )["has_points" ][0 ]: # noqa: SIM114
583
+ if params .elements is not None :
584
+ elements_to_be_rendered += (
585
+ [params .elements ] if isinstance (params .elements , str ) else params .elements
586
+ )
587
+ elif cmd == "render_labels" and cs_contents .query (f"cs == '{ cs } '" )["has_labels" ][0 ]: # noqa: SIM102
588
+ if params .elements is not None :
589
+ elements_to_be_rendered += (
590
+ [params .elements ] if isinstance (params .elements , str ) else params .elements
591
+ )
592
+
535
593
extent = _get_extent (
536
594
sdata = sdata ,
537
595
has_images = "render_images" in render_cmds ,
538
596
has_labels = "render_labels" in render_cmds ,
539
597
has_points = "render_points" in render_cmds ,
540
598
has_shapes = "render_shapes" in render_cmds ,
599
+ elements = elements_to_be_rendered ,
541
600
coordinate_systems = coordinate_systems ,
542
601
)
543
602
@@ -550,19 +609,6 @@ def show(
550
609
logg .info (f"Dropping coordinate system '{ cs } ' since it doesn't have relevant elements." )
551
610
coordinate_systems = valid_cs
552
611
553
- # print(coordinate_systems)
554
- # cs_mapping = _get_coordinate_system_mapping(sdata)
555
- # print(cs_mapping)
556
-
557
- # check that coordinate system and elements to be rendered match
558
- # for cmd, params in render_cmds.items():
559
- # if params.elements is not None and len([params.elements]) != len(coordinate_systems):
560
- # print(params.elements)
561
- # raise ValueError(
562
- # f"Number of coordinate systems ({len(coordinate_systems)}) does not match number of elements "
563
- # f"({len(params.elements)}) in command {cmd}."
564
- # )
565
-
566
612
# set up canvas
567
613
fig_params , scalebar_params = _prepare_params_plot (
568
614
num_panels = len (coordinate_systems ),
@@ -585,7 +631,6 @@ def show(
585
631
)
586
632
587
633
# go through tree
588
- cs_contents = _get_cs_contents (sdata )
589
634
for i , cs in enumerate (coordinate_systems ):
590
635
sdata = self ._copy ()
591
636
# properly transform all elements to the current coordinate system
@@ -693,12 +738,10 @@ def show(
693
738
]
694
739
):
695
740
# If the axis already has limits, only expand them but not overwrite
696
- x_min , x_max = ax .get_xlim ()
697
- y_min , y_max = ax .get_ylim ()
698
- x_min = min (x_min , extent [cs ][0 ])
699
- x_max = max (x_max , extent [cs ][1 ])
700
- y_min = min (y_min , extent [cs ][2 ])
701
- y_max = max (y_max , extent [cs ][3 ])
741
+ x_min = min (x_min_orig , extent [cs ][0 ]) - pad_extent
742
+ x_max = max (x_max_orig , extent [cs ][1 ]) + pad_extent
743
+ y_min = min (y_min_orig , extent [cs ][2 ]) - pad_extent
744
+ y_max = max (y_max_orig , extent [cs ][3 ]) + pad_extent
702
745
ax .set_xlim (x_min , x_max )
703
746
ax .set_ylim (y_max , y_min ) # (0, 0) is top-left
704
747
0 commit comments