@@ -82,7 +82,7 @@ def _render_shapes(
82
82
83
83
for e in elements :
84
84
shapes = sdata .shapes [e ]
85
- n_shapes = sum ([ len (s ) for s in shapes ] )
85
+ n_shapes = sum (len (s ) for s in shapes )
86
86
87
87
if sdata .table is None :
88
88
table = AnnData (None , obs = pd .DataFrame (index = pd .Index (np .arange (n_shapes ), dtype = str )))
@@ -94,11 +94,11 @@ def _render_shapes(
94
94
sdata = sdata_filt ,
95
95
element = sdata_filt .shapes [e ],
96
96
element_name = e ,
97
- value_to_plot = render_params .color ,
97
+ value_to_plot = render_params .col_for_color ,
98
98
layer = render_params .layer ,
99
99
groups = render_params .groups ,
100
100
palette = render_params .palette ,
101
- na_color = render_params .cmap_params .na_color ,
101
+ na_color = render_params .color or render_params . cmap_params .na_color ,
102
102
alpha = render_params .fill_alpha ,
103
103
cmap_params = render_params .cmap_params ,
104
104
)
@@ -162,14 +162,18 @@ def _render_shapes(
162
162
len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
163
163
):
164
164
# necessary in case different shapes elements are annotated with one table
165
- if color_source_vector is not None :
165
+ if color_source_vector is not None and render_params . col_for_color is not None :
166
166
color_source_vector = color_source_vector .remove_unused_categories ()
167
+
168
+ # False if user specified color-like with 'color' parameter
169
+ colorbar = False if render_params .col_for_color is None else legend_params .colorbar
170
+
167
171
_ = _decorate_axs (
168
172
ax = ax ,
169
173
cax = cax ,
170
174
fig_params = fig_params ,
171
175
adata = table ,
172
- value_to_plot = render_params .color ,
176
+ value_to_plot = render_params .col_for_color ,
173
177
color_source_vector = color_source_vector ,
174
178
palette = palette ,
175
179
alpha = render_params .fill_alpha ,
@@ -179,7 +183,7 @@ def _render_shapes(
179
183
legend_loc = legend_params .legend_loc ,
180
184
legend_fontoutline = legend_params .legend_fontoutline ,
181
185
na_in_legend = legend_params .na_in_legend ,
182
- colorbar = legend_params . colorbar ,
186
+ colorbar = colorbar ,
183
187
scalebar_dx = scalebar_params .scalebar_dx ,
184
188
scalebar_units = scalebar_params .scalebar_units ,
185
189
)
@@ -194,12 +198,6 @@ def _render_points(
194
198
scalebar_params : ScalebarParams ,
195
199
legend_params : LegendParams ,
196
200
) -> None :
197
- if render_params .groups is not None :
198
- if isinstance (render_params .groups , str ):
199
- render_params .groups = [render_params .groups ]
200
- if not all (isinstance (g , str ) for g in render_params .groups ):
201
- raise TypeError ("All groups must be strings." )
202
-
203
201
elements = render_params .elements
204
202
205
203
sdata_filt = sdata .filter_by_coordinate_system (
@@ -214,43 +212,56 @@ def _render_points(
214
212
215
213
for e in elements :
216
214
points = sdata .points [e ]
215
+ col_for_color = render_params .col_for_color
216
+
217
217
coords = ["x" , "y" ]
218
- if render_params .color is not None :
219
- color = [render_params .color ] if isinstance (render_params .color , str ) else render_params .color
220
- coords .extend (color )
218
+ if col_for_color is not None :
219
+ if col_for_color not in points .columns :
220
+ # no error in case there are multiple elements, but onyl some have color key
221
+ msg = f"Color key '{ col_for_color } ' for element '{ e } ' not been found, using default colors."
222
+ logger .warning (msg )
223
+ else :
224
+ coords += [col_for_color ]
221
225
222
226
points = points [coords ].compute ()
223
- if render_params .groups is not None :
224
- points = points [points [color ].isin (render_params .groups ).values ]
225
- points [color [0 ]] = points [color [0 ]].cat .set_categories (render_params .groups )
226
- points = dask .dataframe .from_pandas (points , npartitions = 1 )
227
- sdata_filt .points [e ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
228
-
229
- point_df = points [coords ].compute ()
227
+ if render_params .groups is not None and col_for_color is not None :
228
+ points = points [points [col_for_color ].isin (render_params .groups )]
230
229
231
230
# we construct an anndata to hack the plotting functions
232
231
adata = AnnData (
233
- X = point_df [["x" , "y" ]].values , obs = point_df [coords ].reset_index (), dtype = point_df [["x" , "y" ]].values .dtype
232
+ X = points [["x" , "y" ]].values , obs = points [coords ].reset_index (), dtype = points [["x" , "y" ]].values .dtype
234
233
)
235
- if render_params .color is not None :
236
- cols = sc .get .obs_df (adata , render_params .color )
234
+
235
+ # Convert back to dask dataframe to modify sdata
236
+ points = dask .dataframe .from_pandas (points , npartitions = 1 )
237
+ sdata_filt .points [e ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
238
+
239
+ if render_params .col_for_color is not None :
240
+ cols = sc .get .obs_df (adata , render_params .col_for_color )
237
241
# maybe set color based on type
238
242
if is_categorical_dtype (cols ):
239
243
_maybe_set_colors (
240
244
source = adata ,
241
245
target = adata ,
242
- key = render_params .color ,
246
+ key = render_params .col_for_color ,
243
247
palette = render_params .palette ,
244
248
)
245
249
250
+ # when user specified a single color, we overwrite na with it
251
+ default_color = (
252
+ render_params .color
253
+ if render_params .col_for_color is None and render_params .color is not None
254
+ else render_params .cmap_params .na_color
255
+ )
256
+
246
257
color_source_vector , color_vector , _ = _set_color_source_vec (
247
258
sdata = sdata_filt ,
248
259
element = points ,
249
260
element_name = e ,
250
- value_to_plot = render_params .color ,
261
+ value_to_plot = render_params .col_for_color ,
251
262
groups = render_params .groups ,
252
263
palette = render_params .palette ,
253
- na_color = render_params . cmap_params . na_color ,
264
+ na_color = default_color ,
254
265
alpha = render_params .alpha ,
255
266
cmap_params = render_params .cmap_params ,
256
267
)
@@ -278,9 +289,7 @@ def _render_points(
278
289
)
279
290
cax = ax .add_collection (_cax )
280
291
281
- if not (
282
- len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
283
- ):
292
+ if len (set (color_vector )) != 1 or list (set (color_vector ))[0 ] != to_hex (render_params .cmap_params .na_color ):
284
293
if color_source_vector is None :
285
294
palette = ListedColormap (dict .fromkeys (color_vector ))
286
295
else :
@@ -291,7 +300,7 @@ def _render_points(
291
300
cax = cax ,
292
301
fig_params = fig_params ,
293
302
adata = adata ,
294
- value_to_plot = render_params .color ,
303
+ value_to_plot = render_params .col_for_color ,
295
304
color_source_vector = color_source_vector ,
296
305
palette = palette ,
297
306
alpha = render_params .alpha ,
@@ -629,8 +638,8 @@ def _render_labels(
629
638
_cax = ax .imshow (
630
639
labels_infill ,
631
640
rasterized = True ,
632
- cmap = render_params . cmap_params . cmap if not categorical else None ,
633
- norm = render_params . cmap_params . norm if not categorical else None ,
641
+ cmap = None if categorical else render_params . cmap_params . cmap ,
642
+ norm = None if categorical else render_params . cmap_params . norm ,
634
643
alpha = render_params .fill_alpha ,
635
644
origin = "lower" ,
636
645
)
@@ -652,14 +661,11 @@ def _render_labels(
652
661
_cax = ax .imshow (
653
662
labels_contour ,
654
663
rasterized = True ,
655
- cmap = render_params . cmap_params . cmap if not categorical else None ,
656
- norm = render_params . cmap_params . norm if not categorical else None ,
664
+ cmap = None if categorical else render_params . cmap_params . cmap ,
665
+ norm = None if categorical else render_params . cmap_params . norm ,
657
666
alpha = render_params .outline_alpha ,
658
667
origin = "lower" ,
659
668
)
660
- _cax .set_transform (trans_data )
661
- cax = ax .add_image (_cax )
662
-
663
669
else :
664
670
# Default: no alpha, contour = infill
665
671
label = _map_color_seg (
@@ -676,13 +682,13 @@ def _render_labels(
676
682
_cax = ax .imshow (
677
683
label ,
678
684
rasterized = True ,
679
- cmap = render_params . cmap_params . cmap if not categorical else None ,
680
- norm = render_params . cmap_params . norm if not categorical else None ,
685
+ cmap = None if categorical else render_params . cmap_params . cmap ,
686
+ norm = None if categorical else render_params . cmap_params . norm ,
681
687
alpha = render_params .fill_alpha ,
682
688
origin = "lower" ,
683
689
)
684
- _cax .set_transform (trans_data )
685
- cax = ax .add_image (_cax )
690
+ _cax .set_transform (trans_data )
691
+ cax = ax .add_image (_cax )
686
692
687
693
_ = _decorate_axs (
688
694
ax = ax ,
0 commit comments