19
19
from multiscale_spatial_image .multiscale_spatial_image import MultiscaleSpatialImage
20
20
from pandas .api .types import is_categorical_dtype
21
21
from spatial_image import SpatialImage
22
- from spatialdata ._logging import logger as logg
22
+ from spatialdata ._core .data_extent import get_extent
23
+ from spatialdata .transformations .operations import get_transformation
23
24
24
25
from spatialdata_plot ._accessor import register_spatial_data_accessor
25
26
from spatialdata_plot .pl .render import (
40
41
)
41
42
from spatialdata_plot .pl .utils import (
42
43
_get_cs_contents ,
43
- _get_extent ,
44
44
_maybe_set_colors ,
45
45
_mpl_ax_contains_elements ,
46
46
_prepare_cmap_norm ,
47
47
_prepare_params_plot ,
48
- _robust_transform ,
49
48
_set_outline ,
50
49
save_fig ,
51
50
)
@@ -216,6 +215,8 @@ def render_shapes(
216
215
na_color = na_color , # type: ignore[arg-type]
217
216
** kwargs ,
218
217
)
218
+ if isinstance (elements , str ):
219
+ elements = [elements ]
219
220
outline_params = _set_outline (outline , outline_width , outline_color )
220
221
sdata .plotting_tree [f"{ n_steps + 1 } _render_shapes" ] = ShapesRenderParams (
221
222
elements = elements ,
@@ -285,12 +286,15 @@ def render_points(
285
286
sdata = self ._copy ()
286
287
sdata = _verify_plotting_tree (sdata )
287
288
n_steps = len (sdata .plotting_tree .keys ())
289
+
288
290
cmap_params = _prepare_cmap_norm (
289
291
cmap = cmap ,
290
292
norm = norm ,
291
293
na_color = na_color , # type: ignore[arg-type]
292
294
** kwargs ,
293
295
)
296
+ if isinstance (elements , str ):
297
+ elements = [elements ]
294
298
sdata .plotting_tree [f"{ n_steps + 1 } _render_points" ] = PointsRenderParams (
295
299
elements = elements ,
296
300
color = color ,
@@ -370,6 +374,8 @@ def render_images(
370
374
** kwargs ,
371
375
)
372
376
377
+ if isinstance (elements , str ):
378
+ elements = [elements ]
373
379
sdata .plotting_tree [f"{ n_steps + 1 } _render_images" ] = ImageRenderParams (
374
380
elements = elements ,
375
381
channel = channel ,
@@ -450,6 +456,8 @@ def render_labels(
450
456
na_color = na_color , # type: ignore[arg-type]
451
457
** kwargs ,
452
458
)
459
+ if isinstance (elements , str ):
460
+ elements = [elements ]
453
461
sdata .plotting_tree [f"{ n_steps + 1 } _render_labels" ] = LabelsRenderParams (
454
462
elements = elements ,
455
463
color = color ,
@@ -552,12 +560,12 @@ def show(
552
560
raise TypeError ("All titles must be strings." )
553
561
554
562
# get original axis extent for later comparison
555
- x_min_orig , x_max_orig = (np .inf , - np .inf )
556
- y_min_orig , y_max_orig = (np .inf , - np .inf )
563
+ ax_x_min , ax_x_max = (np .inf , - np .inf )
564
+ ax_y_min , ax_y_max = (np .inf , - np .inf )
557
565
558
566
if isinstance (ax , Axes ) and _mpl_ax_contains_elements (ax ):
559
- x_min_orig , x_max_orig = ax .get_xlim ()
560
- y_max_orig , y_min_orig = ax .get_ylim () # (0, 0) is top-left
567
+ ax_x_min , ax_x_max = ax .get_xlim ()
568
+ ax_y_max , ax_y_min = ax .get_ylim () # (0, 0) is top-left
561
569
562
570
# handle coordinate system
563
571
coordinate_systems = sdata .coordinate_systems if coordinate_systems is None else coordinate_systems
@@ -568,50 +576,6 @@ def show(
568
576
if cs not in sdata .coordinate_systems :
569
577
raise ValueError (f"Unknown coordinate system '{ cs } ', valid choices are: { sdata .coordinate_systems } " )
570
578
571
- # Check if user specified only certain elements to be plotted
572
- cs_contents = _get_cs_contents (sdata )
573
- elements_to_be_rendered = []
574
- for cmd , params in render_cmds .items ():
575
- if cmd == "render_images" and cs_contents .query (f"cs == '{ cs } '" )["has_images" ][0 ]: # noqa: SIM114
576
- if params .elements is not None :
577
- elements_to_be_rendered += (
578
- [params .elements ] if isinstance (params .elements , str ) else params .elements
579
- )
580
- elif cmd == "render_shapes" and cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ]: # noqa: SIM114
581
- if params .elements is not None :
582
- elements_to_be_rendered += (
583
- [params .elements ] if isinstance (params .elements , str ) else params .elements
584
- )
585
- elif cmd == "render_points" and cs_contents .query (f"cs == '{ cs } '" )["has_points" ][0 ]: # noqa: SIM114
586
- if params .elements is not None :
587
- elements_to_be_rendered += (
588
- [params .elements ] if isinstance (params .elements , str ) else params .elements
589
- )
590
- elif cmd == "render_labels" and cs_contents .query (f"cs == '{ cs } '" )["has_labels" ][0 ]: # noqa: SIM102
591
- if params .elements is not None :
592
- elements_to_be_rendered += (
593
- [params .elements ] if isinstance (params .elements , str ) else params .elements
594
- )
595
-
596
- extent = _get_extent (
597
- sdata = sdata ,
598
- has_images = "render_images" in render_cmds ,
599
- has_labels = "render_labels" in render_cmds ,
600
- has_points = "render_points" in render_cmds ,
601
- has_shapes = "render_shapes" in render_cmds ,
602
- elements = elements_to_be_rendered ,
603
- coordinate_systems = coordinate_systems ,
604
- )
605
-
606
- # Use extent to filter out coordinate system without the relevant elements
607
- valid_cs = []
608
- for cs in coordinate_systems :
609
- if cs in extent :
610
- valid_cs .append (cs )
611
- else :
612
- logg .info (f"Dropping coordinate system '{ cs } ' since it doesn't have relevant elements." )
613
- coordinate_systems = valid_cs
614
-
615
579
# set up canvas
616
580
fig_params , scalebar_params = _prepare_params_plot (
617
581
num_panels = len (coordinate_systems ),
@@ -633,32 +597,25 @@ def show(
633
597
colorbar = colorbar ,
634
598
)
635
599
600
+ cs_contents = _get_cs_contents (sdata )
601
+
636
602
# go through tree
603
+
637
604
for i , cs in enumerate (coordinate_systems ):
638
605
sdata = self ._copy ()
639
- # properly transform all elements to the current coordinate system
640
- members = cs_contents .query (f"cs == '{ cs } '" )
641
-
642
- if members ["has_images" ].values [0 ]:
643
- for key in sdata .images :
644
- sdata .images [key ] = _robust_transform (sdata .images [key ], cs )
645
-
646
- if members ["has_labels" ].values [0 ]:
647
- for key in sdata .labels :
648
- sdata .labels [key ] = _robust_transform (sdata .labels [key ], cs )
649
-
650
- if members ["has_points" ].values [0 ]:
651
- for key in sdata .points :
652
- sdata .points [key ] = _robust_transform (sdata .points [key ], cs )
653
-
654
- if members ["has_shapes" ].values [0 ]:
655
- for key in sdata .shapes :
656
- sdata .shapes [key ] = _robust_transform (sdata .shapes [key ], cs )
657
-
606
+ _ , has_images , has_labels , has_points , has_shapes = (
607
+ cs_contents .query (f"cs == '{ cs } '" ).iloc [0 , :].values .tolist ()
608
+ )
658
609
ax = fig_params .ax if fig_params .axs is None else fig_params .axs [i ]
659
610
611
+ wants_images = False
612
+ wants_labels = False
613
+ wants_points = False
614
+ wants_shapes = False
615
+ wanted_elements = []
616
+
660
617
for cmd , params in render_cmds .items ():
661
- if cmd == "render_images" and cs_contents . query ( f"cs == ' { cs } '" )[ " has_images" ][ 0 ] :
618
+ if cmd == "render_images" and has_images :
662
619
_render_images (
663
620
sdata = sdata ,
664
621
render_params = params ,
@@ -667,9 +624,18 @@ def show(
667
624
fig_params = fig_params ,
668
625
scalebar_params = scalebar_params ,
669
626
legend_params = legend_params ,
670
- # extent=extent[cs],
671
627
)
672
- elif cmd == "render_shapes" and cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ]:
628
+ wants_images = True
629
+ wanted_images = params .elements if params .elements is not None else list (sdata .images .keys ())
630
+ wanted_elements .extend (
631
+ [
632
+ image
633
+ for image in wanted_images
634
+ if cs in set (get_transformation (sdata .images [image ], get_all = True ).keys ())
635
+ ]
636
+ )
637
+
638
+ elif cmd == "render_shapes" and has_shapes :
673
639
_render_shapes (
674
640
sdata = sdata ,
675
641
render_params = params ,
@@ -679,8 +645,17 @@ def show(
679
645
scalebar_params = scalebar_params ,
680
646
legend_params = legend_params ,
681
647
)
648
+ wants_shapes = True
649
+ wanted_shapes = params .elements if params .elements is not None else list (sdata .shapes .keys ())
650
+ wanted_elements .extend (
651
+ [
652
+ shape
653
+ for shape in wanted_shapes
654
+ if cs in set (get_transformation (sdata .shapes [shape ], get_all = True ).keys ())
655
+ ]
656
+ )
682
657
683
- elif cmd == "render_points" and cs_contents . query ( f"cs == ' { cs } '" )[ " has_points" ][ 0 ] :
658
+ elif cmd == "render_points" and has_points :
684
659
_render_points (
685
660
sdata = sdata ,
686
661
render_params = params ,
@@ -690,8 +665,17 @@ def show(
690
665
scalebar_params = scalebar_params ,
691
666
legend_params = legend_params ,
692
667
)
668
+ wants_points = True
669
+ wanted_points = params .elements if params .elements is not None else list (sdata .points .keys ())
670
+ wanted_elements .extend (
671
+ [
672
+ point
673
+ for point in wanted_points
674
+ if cs in set (get_transformation (sdata .points [point ], get_all = True ).keys ())
675
+ ]
676
+ )
693
677
694
- elif cmd == "render_labels" and cs_contents . query ( f"cs == ' { cs } '" )[ " has_labels" ][ 0 ] :
678
+ elif cmd == "render_labels" and has_labels :
695
679
if sdata .table is not None and isinstance (params .color , str ):
696
680
colors = sc .get .obs_df (sdata .table , params .color )
697
681
if is_categorical_dtype (colors ):
@@ -710,33 +694,46 @@ def show(
710
694
scalebar_params = scalebar_params ,
711
695
legend_params = legend_params ,
712
696
)
697
+ wants_labels = True
698
+ wanted_labels = params .elements if params .elements is not None else list (sdata .labels .keys ())
699
+ wanted_elements .extend (
700
+ [
701
+ label
702
+ for label in wanted_labels
703
+ if cs in set (get_transformation (sdata .labels [label ], get_all = True ).keys ())
704
+ ]
705
+ )
713
706
714
- if title is not None :
715
- if len (title ) == 1 :
716
- t = title [0 ]
717
- else :
718
- try :
719
- t = title [i ]
720
- except IndexError as e :
721
- raise IndexError ("The number of titles must match the number of coordinate systems." ) from e
722
- else :
707
+ if title is None :
723
708
t = cs
709
+ elif len (title ) == 1 :
710
+ t = title [0 ]
711
+ else :
712
+ try :
713
+ t = title [i ]
714
+ except IndexError as e :
715
+ raise IndexError ("The number of titles must match the number of coordinate systems." ) from e
724
716
ax .set_title (t )
725
717
ax .set_aspect ("equal" )
726
718
727
- if any (
728
- [
729
- cs_contents .query (f"cs == '{ cs } '" )["has_images" ][0 ],
730
- cs_contents .query (f"cs == '{ cs } '" )["has_labels" ][0 ],
731
- cs_contents .query (f"cs == '{ cs } '" )["has_points" ][0 ],
732
- cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ],
733
- ]
734
- ):
719
+ extent = get_extent (
720
+ sdata ,
721
+ coordinate_system = cs ,
722
+ has_images = has_images and wants_images ,
723
+ has_labels = has_labels and wants_labels ,
724
+ has_points = has_points and wants_points ,
725
+ has_shapes = has_shapes and wants_shapes ,
726
+ elements = wanted_elements ,
727
+ )
728
+ cs_x_min , cs_x_max = extent ["x" ]
729
+ cs_y_min , cs_y_max = extent ["y" ]
730
+
731
+ if any ([has_images , has_labels , has_points , has_shapes ]):
735
732
# If the axis already has limits, only expand them but not overwrite
736
- x_min = min (x_min_orig , extent [ cs ][ 0 ] ) - pad_extent
737
- x_max = max (x_max_orig , extent [ cs ][ 1 ] ) + pad_extent
738
- y_min = min (y_min_orig , extent [ cs ][ 2 ] ) - pad_extent
739
- y_max = max (y_max_orig , extent [ cs ][ 3 ] ) + pad_extent
733
+ x_min = min (ax_x_min , cs_x_min ) - pad_extent
734
+ x_max = max (ax_x_max , cs_x_max ) + pad_extent
735
+ y_min = min (ax_y_min , cs_y_min ) - pad_extent
736
+ y_max = max (ax_y_max , cs_y_max ) + pad_extent
740
737
ax .set_xlim (x_min , x_max )
741
738
ax .set_ylim (y_max , y_min ) # (0, 0) is top-left
742
739
@@ -747,5 +744,4 @@ def show(
747
744
# https://stackoverflow.com/a/64523765
748
745
if not hasattr (sys , "ps1" ):
749
746
plt .show ()
750
-
751
747
return (fig_params .ax if fig_params .axs is None else fig_params .axs ) if return_ax else None # shuts up ruff
0 commit comments