16
16
import spatialdata as sd
17
17
from anndata import AnnData
18
18
from datatree import DataTree
19
-
20
- # from datatree.datatree import DataTree
21
19
from matplotlib .cm import ScalarMappable
22
20
from matplotlib .colors import ListedColormap , Normalize
23
21
from scanpy ._settings import settings as sc_settings
@@ -163,27 +161,7 @@ def _render_shapes(
163
161
raise ValueError ("Method must be either 'matplotlib' or 'datashader'." )
164
162
logger .info (f"Using { method } " )
165
163
166
- if method == "matplotlib" :
167
- _cax = _get_collection_shape (
168
- shapes = shapes ,
169
- s = render_params .scale ,
170
- c = color_vector ,
171
- render_params = render_params ,
172
- rasterized = sc_settings ._vector_friendly ,
173
- cmap = render_params .cmap_params .cmap ,
174
- norm = norm ,
175
- fill_alpha = render_params .fill_alpha ,
176
- outline_alpha = render_params .outline_alpha ,
177
- zorder = render_params .zorder ,
178
- # **kwargs,
179
- )
180
- cax = ax .add_collection (_cax )
181
-
182
- # Transform the paths in PatchCollection
183
- for path in _cax .get_paths ():
184
- path .vertices = trans .transform (path .vertices )
185
- cax = ax .add_collection (_cax )
186
- elif method == "datashader" :
164
+ if method == "datashader" :
187
165
# TODO: Where to put this
188
166
trans = mtransforms .Affine2D (matrix = affine_trans ) + ax .transData
189
167
@@ -209,11 +187,9 @@ def _render_shapes(
209
187
# in case we are coloring by a column in table
210
188
if col_for_color is not None and col_for_color not in sdata_filt .shapes [element ].columns :
211
189
# numerical
212
- if color_source_vector is None :
213
- sdata_filt .shapes [element ][col_for_color ] = color_vector
214
- else : # categorical
215
- sdata_filt .shapes [element ][col_for_color ] = color_source_vector
216
-
190
+ sdata_filt .shapes [element ][col_for_color ] = (
191
+ color_vector if color_source_vector is None else color_source_vector
192
+ )
217
193
# Render shapes with datashader
218
194
color_by_categorical = col_for_color is not None and color_source_vector is not None
219
195
aggregate_with_sum = None
@@ -232,24 +208,24 @@ def _render_shapes(
232
208
233
209
color_key = (
234
210
[x [:- 2 ] for x in color_vector .categories .values ]
235
- if (type (color_vector ) == pd .core .arrays .categorical .Categorical )
211
+ if (type (color_vector ) is pd .core .arrays .categorical .Categorical )
236
212
and (len (color_vector .categories .values ) > 1 )
237
213
else None
238
214
)
239
215
240
- if color_by_categorical or col_for_color is None :
241
- ds_result = ds .tf .shade (
216
+ ds_result = (
217
+ ds .tf .shade (
242
218
agg ,
243
219
cmap = color_vector [0 ][:- 2 ],
244
220
color_key = color_key ,
245
221
min_alpha = np .min ([150 , render_params .fill_alpha * 255 ]),
246
- ) # TODO: choose other value than 150 for min_alpha (here and below)?
247
- else :
248
- ds_result = ds .tf .shade (
222
+ )
223
+ if color_by_categorical or col_for_color is None
224
+ else ds .tf .shade (
249
225
agg ,
250
226
cmap = render_params .cmap_params .cmap ,
251
227
)
252
-
228
+ )
253
229
# Render image
254
230
rgba_image = np .transpose (ds_result .to_numpy ().base , (0 , 1 , 2 ))
255
231
_cax = ax .imshow (rgba_image , cmap = palette , zorder = render_params .zorder )
@@ -261,6 +237,27 @@ def _render_shapes(
261
237
cmap = render_params .cmap_params .cmap ,
262
238
)
263
239
240
+ elif method == "matplotlib" :
241
+ _cax = _get_collection_shape (
242
+ shapes = shapes ,
243
+ s = render_params .scale ,
244
+ c = color_vector ,
245
+ render_params = render_params ,
246
+ rasterized = sc_settings ._vector_friendly ,
247
+ cmap = render_params .cmap_params .cmap ,
248
+ norm = norm ,
249
+ fill_alpha = render_params .fill_alpha ,
250
+ outline_alpha = render_params .outline_alpha ,
251
+ zorder = render_params .zorder ,
252
+ # **kwargs,
253
+ )
254
+ cax = ax .add_collection (_cax )
255
+
256
+ # Transform the paths in PatchCollection
257
+ for path in _cax .get_paths ():
258
+ path .vertices = trans .transform (path .vertices )
259
+ cax = ax .add_collection (_cax )
260
+
264
261
# Sets the limits of the colorbar to the values instead of [0, 1]
265
262
if not norm and not values_are_categorical :
266
263
_cax .set_clim (min (color_vector ), max (color_vector ))
@@ -356,7 +353,7 @@ def _render_points(
356
353
)
357
354
sdata_filt [table_name ] = adata
358
355
359
- # we can do this because of dealing with a copy
356
+ # we can modify the sdata because of dealing with a copy
360
357
361
358
# Convert back to dask dataframe to modify sdata
362
359
transformation_in_cs = sdata_filt .points [element ].attrs ["transform" ][coordinate_system ]
@@ -456,7 +453,7 @@ def _render_points(
456
453
457
454
color_key = (
458
455
[x [:- 2 ] for x in color_vector .categories .values ]
459
- if (type (color_vector ) == pd .core .arrays .categorical .Categorical )
456
+ if (type (color_vector ) is pd .core .arrays .categorical .Categorical )
460
457
and (len (color_vector .categories .values ) > 1 )
461
458
else None
462
459
)
@@ -473,9 +470,8 @@ def _render_points(
473
470
ds .tf .spread (agg , px = px ),
474
471
rescale_discrete_levels = True ,
475
472
cmap = render_params .cmap_params .cmap ,
476
- # color_key=color_key,
477
473
)
478
- # render image
474
+
479
475
rbga_image = np .transpose (ds_result .to_numpy ().base , (0 , 1 , 2 ))
480
476
cax = ax .imshow (rbga_image , zorder = render_params .zorder , alpha = render_params .alpha )
481
477
if aggregate_with_sum is not None :
@@ -498,7 +494,6 @@ def _render_points(
498
494
alpha = render_params .alpha ,
499
495
transform = trans ,
500
496
zorder = render_params .zorder ,
501
- # **kwargs,
502
497
)
503
498
cax = ax .add_collection (_cax )
504
499
if update_parameters :
0 commit comments