1
1
from __future__ import annotations
2
2
3
+ import warnings
3
4
from collections import abc
4
5
from copy import copy
5
6
from typing import Union , cast
37
38
_get_collection_shape ,
38
39
_get_colors_for_categorical_obs ,
39
40
_get_linear_colormap ,
41
+ _is_coercable_to_float ,
40
42
_map_color_seg ,
41
43
_maybe_set_colors ,
42
44
_multiscale_to_spatial_image ,
@@ -70,6 +72,7 @@ def _render_shapes(
70
72
elements = list (sdata_filt .shapes .keys ())
71
73
72
74
for index , e in enumerate (elements ):
75
+ col_for_color = render_params .col_for_color [index ]
73
76
shapes = sdata .shapes [e ]
74
77
75
78
table_name = element_table_mapping .get (e )
@@ -79,13 +82,28 @@ def _render_shapes(
79
82
_ , region_key , _ = get_table_keys (sdata [table_name ])
80
83
table = sdata [table_name ][sdata [table_name ].obs [region_key ].isin ([e ])]
81
84
85
+ if (
86
+ col_for_color is not None
87
+ and table_name is not None
88
+ and col_for_color in sdata_filt [table_name ].obs .columns
89
+ and (color_col := sdata_filt [table_name ].obs [col_for_color ]).dtype == "O"
90
+ and not _is_coercable_to_float (color_col )
91
+ ):
92
+ warnings .warn (
93
+ f"Converting copy of '{ col_for_color } ' column to categorical dtype for categorical plotting. "
94
+ f"Consider converting before plotting." ,
95
+ UserWarning ,
96
+ stacklevel = 2 ,
97
+ )
98
+ sdata_filt [table_name ].obs [col_for_color ] = sdata_filt [table_name ].obs [col_for_color ].astype ("category" )
99
+
82
100
# get color vector (categorical or continuous)
83
101
color_source_vector , color_vector , _ = _set_color_source_vec (
84
102
sdata = sdata_filt ,
85
103
element = sdata_filt .shapes [e ],
86
104
element_index = index ,
87
105
element_name = e ,
88
- value_to_plot = render_params . col_for_color [ index ] ,
106
+ value_to_plot = col_for_color ,
89
107
groups = render_params .groups [index ] if render_params .groups [index ][0 ] is not None else None ,
90
108
palette = (
91
109
render_params .palette [index ] if render_params .palette is not None else None
@@ -170,7 +188,7 @@ def _render_shapes(
170
188
cax = cax ,
171
189
fig_params = fig_params ,
172
190
adata = table ,
173
- value_to_plot = render_params . col_for_color [ index ] ,
191
+ value_to_plot = col_for_color ,
174
192
color_source_vector = color_source_vector ,
175
193
palette = palette ,
176
194
alpha = render_params .fill_alpha ,
@@ -212,22 +230,48 @@ def _render_points(
212
230
table_name = element_table_mapping .get (e )
213
231
214
232
coords = ["x" , "y" ]
215
- if col_for_color is not None :
216
- if col_for_color not in points .columns :
217
- # no error in case there are multiple elements, but onyl some have color key
218
- msg = f"Color key '{ col_for_color } ' for element '{ e } ' not been found, using default colors."
219
- logger .warning (msg )
220
- else :
221
- coords += [col_for_color ]
233
+ # if col_for_color is not None:
234
+ if (
235
+ col_for_color is not None
236
+ and col_for_color not in points .columns
237
+ and col_for_color not in sdata_filt [table_name ].obs .columns
238
+ ):
239
+ # no error in case there are multiple elements, but onyl some have color key
240
+ msg = f"Color key '{ col_for_color } ' for element '{ e } ' not been found, using default colors."
241
+ logger .warning (msg )
242
+ elif col_for_color is None or (table_name is not None and col_for_color in sdata_filt [table_name ].obs .columns ):
243
+ points = points [coords ].compute ()
244
+ if (
245
+ col_for_color
246
+ and (color_col := sdata_filt [table_name ].obs [col_for_color ]).dtype == "O"
247
+ and not _is_coercable_to_float (color_col )
248
+ ):
249
+ warnings .warn (
250
+ f"Converting copy of '{ col_for_color } ' column to categorical dtype for categorical "
251
+ f"plotting. Consider converting before plotting." ,
252
+ UserWarning ,
253
+ stacklevel = 2 ,
254
+ )
255
+ sdata_filt [table_name ].obs [col_for_color ] = sdata_filt [table_name ].obs [col_for_color ].astype ("category" )
256
+ else :
257
+ coords += [col_for_color ]
258
+ points = points [coords ].compute ()
222
259
223
- points = points [coords ].compute ()
224
260
if render_params .groups [index ][0 ] is not None and col_for_color is not None :
225
261
points = points [points [col_for_color ].isin (render_params .groups [index ])]
226
262
227
263
# we construct an anndata to hack the plotting functions
228
- adata = AnnData (
229
- X = points [["x" , "y" ]].values , obs = points [coords ].reset_index (), dtype = points [["x" , "y" ]].values .dtype
230
- )
264
+ if table_name is None :
265
+ adata = AnnData (
266
+ X = points [["x" , "y" ]].values , obs = points [coords ].reset_index (), dtype = points [["x" , "y" ]].values .dtype
267
+ )
268
+ else :
269
+ adata = AnnData (
270
+ X = points [["x" , "y" ]].values , obs = sdata_filt [table_name ].obs , dtype = points [["x" , "y" ]].values .dtype
271
+ )
272
+ sdata_filt [table_name ] = adata
273
+
274
+ # we can do this because of dealing with a copy
231
275
232
276
# Convert back to dask dataframe to modify sdata
233
277
points = dask .dataframe .from_pandas (points , npartitions = 1 )
@@ -559,6 +603,7 @@ def _render_labels(
559
603
label = sdata_filt .labels [e ]
560
604
extent = get_extent (label , coordinate_system = coordinate_system )
561
605
scale = render_params .scale [i ] if isinstance (render_params .scale , list ) else render_params .scale
606
+ color = render_params .color [i ]
562
607
563
608
# get best scale out of multiscale label
564
609
if isinstance (label , MultiscaleSpatialImage ):
@@ -603,7 +648,7 @@ def _render_labels(
603
648
element = label ,
604
649
element_index = i ,
605
650
element_name = e ,
606
- value_to_plot = cast ( list [ str ], render_params . color )[ i ] ,
651
+ value_to_plot = color ,
607
652
groups = render_params .groups [i ],
608
653
palette = render_params .palette [i ],
609
654
na_color = render_params .cmap_params .na_color ,
@@ -684,7 +729,7 @@ def _render_labels(
684
729
cax = cax ,
685
730
fig_params = fig_params ,
686
731
adata = table ,
687
- value_to_plot = cast ( list [ str ], render_params . color )[ i ] ,
732
+ value_to_plot = color ,
688
733
color_source_vector = color_source_vector ,
689
734
palette = render_params .palette [i ],
690
735
alpha = render_params .fill_alpha ,
0 commit comments