7
7
from functools import partial
8
8
from pathlib import Path
9
9
from types import MappingProxyType
10
- from typing import Any , Literal , Optional , Union
10
+ from typing import Any , Literal
11
11
12
12
import matplotlib
13
13
import matplotlib .pyplot as plt
46
46
47
47
from spatialdata_plot .pp .utils import _get_coordinate_system_mapping
48
48
49
- Palette_t = Optional [Union [str , ListedColormap ]]
50
- _Normalize = Union [Normalize , Sequence [Normalize ]]
51
- _SeqStr = Union [str , Sequence [str ]]
52
49
_FontWeight = Literal ["light" , "normal" , "medium" , "semibold" , "bold" , "heavy" , "black" ]
53
50
_FontSize = Literal ["xx-small" , "x-small" , "small" , "medium" , "large" , "x-large" , "xx-large" ]
54
51
@@ -90,7 +87,7 @@ def _prepare_params_plot(
90
87
frameon : bool | None = None ,
91
88
# this is passed at `render_*`
92
89
cmap : Colormap | str | None = None ,
93
- norm : _Normalize | None = None ,
90
+ norm : Normalize | Sequence [ Normalize ] | None = None ,
94
91
na_color : str | tuple [float , ...] | None = (0.0 , 0.0 , 0.0 , 0.0 ),
95
92
vmin : float | None = None ,
96
93
vmax : float | None = None ,
@@ -178,11 +175,12 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
178
175
179
176
def _get_extent (
180
177
sdata : sd .SpatialData ,
181
- coordinate_systems : None | str | Sequence [ str ] = None ,
178
+ coordinate_systems : Sequence [ str ] | str | None = None ,
182
179
has_images : bool = True ,
183
180
has_labels : bool = True ,
184
181
has_points : bool = True ,
185
182
has_shapes : bool = True ,
183
+ elements : Iterable [Any ] | None = None ,
186
184
share_extent : bool = False ,
187
185
) -> dict [str , tuple [int , int , int , int ]]:
188
186
"""Return the extent of all elements in their respective coordinate systems.
@@ -191,16 +189,18 @@ def _get_extent(
191
189
----------
192
190
sdata
193
191
The sd.SpatialData object to retrieve the extent from
194
- images
192
+ has_images
195
193
Flag indicating whether to consider images when calculating the extent
196
- labels
194
+ has_labels
197
195
Flag indicating whether to consider labels when calculating the extent
198
- points
196
+ has_points
199
197
Flag indicating whether to consider points when calculating the extent
200
- shapes
201
- Flag indicating whether to consider shaoes when calculating the extent
202
- img_transformations
203
- List of transformations already applied to the images
198
+ has_shapes
199
+ Flag indicating whether to consider shapes when calculating the extent
200
+ elements
201
+ Optional list of element names to be considered. When None, all are used.
202
+ share_extent
203
+ Flag indicating whether to use the same extent for all coordinate systems
204
204
205
205
Returns
206
206
-------
@@ -212,6 +212,12 @@ def _get_extent(
212
212
cs_mapping = _get_coordinate_system_mapping (sdata )
213
213
cs_contents = _get_cs_contents (sdata )
214
214
215
+ if elements is None : # to shut up ruff
216
+ elements = []
217
+
218
+ if not isinstance (elements , list ):
219
+ raise ValueError (f"Invalid type of `elements`: { type (elements )} , expected `list`." )
220
+
215
221
if coordinate_systems is not None :
216
222
if isinstance (coordinate_systems , str ):
217
223
coordinate_systems = [coordinate_systems ]
@@ -220,6 +226,8 @@ def _get_extent(
220
226
221
227
for cs_name , element_ids in cs_mapping .items ():
222
228
extent [cs_name ] = {}
229
+ if len (elements ) > 0 :
230
+ element_ids = [e for e in element_ids if e in elements ]
223
231
224
232
def _get_extent_after_transformations (element : Any , cs_name : str ) -> Sequence [int ]:
225
233
tmp = element .copy ()
@@ -449,7 +457,7 @@ class CmapParams:
449
457
450
458
def _prepare_cmap_norm (
451
459
cmap : Colormap | str | None = None ,
452
- norm : _Normalize | None = None ,
460
+ norm : Normalize | Sequence [ Normalize ] | None = None ,
453
461
na_color : str | tuple [float , ...] = (0.0 , 0.0 , 0.0 , 0.0 ),
454
462
vmin : float | None = None ,
455
463
vmax : float | None = None ,
@@ -627,7 +635,11 @@ def _normalize(
627
635
return norm
628
636
629
637
630
- def _get_colors_for_categorical_obs (categories : Sequence [str | int ], palette : Palette_t = None ) -> list [str ]:
638
+ def _get_colors_for_categorical_obs (
639
+ categories : Sequence [str | int ],
640
+ palette : ListedColormap | str | None = None ,
641
+ alpha : float = 1.0 ,
642
+ ) -> list [str ]:
631
643
"""
632
644
Return a list of colors for a categorical observation.
633
645
@@ -644,27 +656,40 @@ def _get_colors_for_categorical_obs(categories: Sequence[str | int], palette: Pa
644
656
-------
645
657
None
646
658
"""
647
- length = len (categories )
659
+ len_cat = len (categories )
648
660
649
661
# check if default matplotlib palette has enough colors
650
662
if palette is None :
651
- if len (rcParams ["axes.prop_cycle" ].by_key ()["color" ]) >= length :
663
+ if len (rcParams ["axes.prop_cycle" ].by_key ()["color" ]) >= len_cat :
652
664
cc = rcParams ["axes.prop_cycle" ]()
653
- palette = [next (cc )["color" ] for _ in range (length )]
665
+ palette = [next (cc )["color" ] for _ in range (len_cat )]
654
666
else :
655
- if length <= 20 :
667
+ if len_cat <= 20 :
656
668
palette = default_20
657
- elif length <= 28 :
669
+ elif len_cat <= 28 :
658
670
palette = default_28
659
- elif length <= len (default_102 ): # 103 colors
671
+ elif len_cat <= len (default_102 ): # 103 colors
660
672
palette = default_102
661
673
else :
662
- palette = ["grey" for _ in range (length )]
674
+ palette = ["grey" for _ in range (len_cat )]
663
675
logging .info (
664
676
"input has more than 103 categories. Uniform " "'grey' color will be used for all categories."
665
677
)
666
678
667
- return palette [:length ] # type: ignore[return-value]
679
+ # otherwise, single chanels turn out grey
680
+ color_idx = np .linspace (0 , 1 , len_cat ) if len_cat > 1 else [0.7 ]
681
+
682
+ if isinstance (palette , str ):
683
+ cmap = plt .get_cmap (palette )
684
+ palette = [to_hex (x ) for x in cmap (color_idx , alpha = alpha )]
685
+ elif isinstance (palette , ListedColormap ):
686
+ palette = [to_hex (x ) for x in palette (color_idx , alpha = alpha )]
687
+ elif isinstance (palette , LinearSegmentedColormap ):
688
+ palette = [to_hex (palette (x , alpha = alpha )) for x in [color_idx ]]
689
+ else :
690
+ raise TypeError (f"Palette is { type (palette )} but should be string or `ListedColormap`." )
691
+
692
+ return palette [:len_cat ] # type: ignore[return-value]
668
693
669
694
670
695
def _set_color_source_vec (
@@ -673,8 +698,8 @@ def _set_color_source_vec(
673
698
use_raw : bool | None = None ,
674
699
alt_var : str | None = None ,
675
700
layer : str | None = None ,
676
- groups : _SeqStr | None = None ,
677
- palette : Palette_t = None ,
701
+ groups : Sequence [ str ] | str | None = None ,
702
+ palette : ListedColormap | str | None = None ,
678
703
na_color : str | tuple [float , ...] | None = None ,
679
704
alpha : float = 1.0 ,
680
705
) -> tuple [ArrayLike | pd .Series | None , ArrayLike , bool ]:
@@ -769,7 +794,7 @@ def _get_palette(
769
794
categories : Sequence [Any ],
770
795
adata : AnnData | None = None ,
771
796
cluster_key : None | str = None ,
772
- palette : Palette_t = None ,
797
+ palette : ListedColormap | str | None = None ,
773
798
alpha : float = 1.0 ,
774
799
) -> Mapping [str , str ] | None :
775
800
if adata is not None and palette is None :
@@ -845,7 +870,7 @@ def _decorate_axs(
845
870
adata : AnnData ,
846
871
value_to_plot : str | None ,
847
872
color_source_vector : pd .Series [CategoricalDtype ],
848
- palette : Palette_t = None ,
873
+ palette : ListedColormap | str | None = None ,
849
874
alpha : float = 1.0 ,
850
875
na_color : str | tuple [float , ...] = (0.0 , 0.0 , 0.0 , 0.0 ),
851
876
legend_fontsize : int | float | _FontSize | None = None ,
@@ -1127,3 +1152,17 @@ def _robust_transform(element: Any, cs: str) -> Any:
1127
1152
raise ValueError ("Unable to transform element." ) from e
1128
1153
1129
1154
return element
1155
+
1156
+
1157
+ def _mpl_ax_contains_elements (ax : Axes ) -> bool :
1158
+ """Check if any objects have been plotted on the axes object.
1159
+
1160
+ While extracting the extent, we need to know if the axes object has just been
1161
+ initialised and therefore has extent (0, 1), (0,1) or if it has been plotted on
1162
+ and therefore has a different extent.
1163
+
1164
+ Based on: https://stackoverflow.com/a/71966295
1165
+ """
1166
+ return (
1167
+ len (ax .lines ) > 0 or len (ax .collections ) > 0 or len (ax .images ) > 0 or len (ax .patches ) > 0 or len (ax .tables ) > 0
1168
+ )
0 commit comments