18
18
from matplotlib .colors import ListedColormap , Normalize
19
19
from scanpy ._settings import settings as sc_settings
20
20
from spatialdata import get_extent
21
- from spatialdata .models import PointsModel , get_table_keys
22
- from spatialdata .transformations import (
23
- set_transformation ,
24
- )
21
+ from spatialdata .models import PointsModel , ShapesModel , get_table_keys
22
+ from spatialdata .transformations import get_transformation , set_transformation
23
+ from spatialdata .transformations .transformations import Identity
25
24
from xarray import DataTree
26
25
27
26
from spatialdata_plot ._logging import logger
44
43
_get_colors_for_categorical_obs ,
45
44
_get_extent_and_range_for_datashader_canvas ,
46
45
_get_linear_colormap ,
46
+ _get_transformation_matrix_for_datashader ,
47
47
_is_coercable_to_float ,
48
48
_map_color_seg ,
49
49
_maybe_set_colors ,
@@ -148,7 +148,7 @@ def _render_shapes(
148
148
colorbar = False if col_for_color is None else legend_params .colorbar
149
149
150
150
# Apply the transformation to the PatchCollection's paths
151
- trans , _ = _prepare_transformation (sdata_filt .shapes [element ], coordinate_system )
151
+ trans , trans_data = _prepare_transformation (sdata_filt .shapes [element ], coordinate_system )
152
152
153
153
shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
154
154
@@ -168,14 +168,6 @@ def _render_shapes(
168
168
)
169
169
170
170
if method == "datashader" :
171
- trans += ax .transData
172
-
173
- plot_width , plot_height , x_ext , y_ext , factor = _get_extent_and_range_for_datashader_canvas (
174
- sdata_filt .shapes [element ], coordinate_system , ax , fig_params
175
- )
176
-
177
- cvs = ds .Canvas (plot_width = plot_width , plot_height = plot_height , x_range = x_ext , y_range = y_ext )
178
-
179
171
_geometry = shapes ["geometry" ]
180
172
is_point = _geometry .type == "Point"
181
173
@@ -184,36 +176,48 @@ def _render_shapes(
184
176
scale = shapes [is_point ]["radius" ] * render_params .scale
185
177
sdata_filt .shapes [element ].loc [is_point , "geometry" ] = _geometry [is_point ].buffer (scale .to_numpy ())
186
178
179
+ # apply transformations to the individual points
180
+ element_trans = get_transformation (sdata_filt .shapes [element ])
181
+ tm = _get_transformation_matrix_for_datashader (element_trans )
182
+ transformed_element = sdata_filt .shapes [element ].transform (
183
+ lambda x : (np .hstack ([x , np .ones ((x .shape [0 ], 1 ))]) @ tm )[:, :2 ]
184
+ )
185
+ transformed_element = ShapesModel .parse (
186
+ gpd .GeoDataFrame (data = sdata_filt .shapes [element ].drop ("geometry" , axis = 1 ), geometry = transformed_element )
187
+ )
188
+
189
+ plot_width , plot_height , x_ext , y_ext , factor = _get_extent_and_range_for_datashader_canvas (
190
+ transformed_element , coordinate_system , ax , fig_params
191
+ )
192
+
193
+ cvs = ds .Canvas (plot_width = plot_width , plot_height = plot_height , x_range = x_ext , y_range = y_ext )
194
+
187
195
# in case we are coloring by a column in table
188
- if col_for_color is not None and col_for_color not in sdata_filt .shapes [element ].columns :
189
- sdata_filt .shapes [element ][col_for_color ] = (
190
- color_vector if color_source_vector is None else color_source_vector
191
- )
196
+ if col_for_color is not None and col_for_color not in transformed_element .columns :
197
+ transformed_element [col_for_color ] = color_vector if color_source_vector is None else color_source_vector
192
198
# Render shapes with datashader
193
199
color_by_categorical = col_for_color is not None and color_source_vector is not None
194
200
aggregate_with_reduction = None
195
201
if col_for_color is not None and (render_params .groups is None or len (render_params .groups ) > 1 ):
196
202
if color_by_categorical :
197
- agg = cvs .polygons (
198
- sdata_filt .shapes [element ], geometry = "geometry" , agg = ds .by (col_for_color , ds .count ())
199
- )
203
+ agg = cvs .polygons (transformed_element , geometry = "geometry" , agg = ds .by (col_for_color , ds .count ()))
200
204
else :
201
205
reduction_name = render_params .ds_reduction if render_params .ds_reduction is not None else "mean"
202
206
logger .info (
203
207
f'Using the datashader reduction "{ reduction_name } ". "max" will give an output very close '
204
208
"to the matplotlib result."
205
209
)
206
210
agg = _datashader_aggregate_with_function (
207
- render_params .ds_reduction , cvs , sdata_filt . shapes [ element ] , col_for_color , "shapes"
211
+ render_params .ds_reduction , cvs , transformed_element , col_for_color , "shapes"
208
212
)
209
213
# save min and max values for drawing the colorbar
210
214
aggregate_with_reduction = (agg .min (), agg .max ())
211
215
else :
212
- agg = cvs .polygons (sdata_filt . shapes [ element ] , geometry = "geometry" , agg = ds .count ())
216
+ agg = cvs .polygons (transformed_element , geometry = "geometry" , agg = ds .count ())
213
217
# render outlines if needed
214
218
if (render_outlines := render_params .outline_alpha ) > 0 :
215
219
agg_outlines = cvs .line (
216
- sdata_filt . shapes [ element ] ,
220
+ transformed_element ,
217
221
geometry = "geometry" ,
218
222
line_width = render_params .outline_params .linewidth ,
219
223
)
@@ -287,13 +291,23 @@ def _render_shapes(
287
291
288
292
rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
289
293
_cax = _ax_show_and_transform (
290
- rgba_image , trans_data , ax , zorder = render_params .zorder , alpha = render_params .fill_alpha
294
+ rgba_image ,
295
+ trans_data ,
296
+ ax ,
297
+ zorder = render_params .zorder ,
298
+ alpha = render_params .fill_alpha ,
299
+ extent = x_ext + y_ext ,
291
300
)
292
301
# render outline image if needed
293
302
if render_outlines :
294
303
rgba_image , trans_data = _create_image_from_datashader_result (ds_outlines , factor , ax )
295
304
_ax_show_and_transform (
296
- rgba_image , trans_data , ax , zorder = render_params .zorder , alpha = render_params .outline_alpha
305
+ rgba_image ,
306
+ trans_data ,
307
+ ax ,
308
+ zorder = render_params .zorder ,
309
+ alpha = render_params .outline_alpha ,
310
+ extent = x_ext + y_ext ,
297
311
)
298
312
299
313
cax = None
@@ -330,7 +344,7 @@ def _render_shapes(
330
344
331
345
if not values_are_categorical :
332
346
# If the user passed a Normalize object with vmin/vmax we'll use those,
333
- # # if not we'll use the min/max of the color_vector
347
+ # if not we'll use the min/max of the color_vector
334
348
_cax .set_clim (
335
349
vmin = render_params .cmap_params .norm .vmin or min (color_vector ),
336
350
vmax = render_params .cmap_params .norm .vmax or max (color_vector ),
@@ -468,7 +482,7 @@ def _render_points(
468
482
if color_source_vector is None and render_params .transfunc is not None :
469
483
color_vector = render_params .transfunc (color_vector )
470
484
471
- _ , trans_data = _prepare_transformation (sdata .points [element ], coordinate_system , ax )
485
+ trans , trans_data = _prepare_transformation (sdata .points [element ], coordinate_system , ax )
472
486
473
487
norm = copy (render_params .cmap_params .norm )
474
488
@@ -491,8 +505,15 @@ def _render_points(
491
505
# use dpi/100 as a factor for cases where dpi!=100
492
506
px = int (np .round (np .sqrt (render_params .size ) * (fig_params .fig .dpi / 100 )))
493
507
508
+ # apply transformations
509
+ transformed_element = PointsModel .parse (
510
+ trans .transform (sdata_filt .points [element ][["x" , "y" ]]),
511
+ annotation = sdata_filt .points [element ][sdata_filt .points [element ].columns .drop (["x" , "y" ])],
512
+ transformations = {coordinate_system : Identity ()},
513
+ )
514
+
494
515
plot_width , plot_height , x_ext , y_ext , factor = _get_extent_and_range_for_datashader_canvas (
495
- sdata_filt . points [ element ] , coordinate_system , ax , fig_params
516
+ transformed_element , coordinate_system , ax , fig_params
496
517
)
497
518
498
519
# use datashader for the visualization of points
@@ -502,20 +523,20 @@ def _render_points(
502
523
aggregate_with_reduction = None
503
524
if col_for_color is not None and (render_params .groups is None or len (render_params .groups ) > 1 ):
504
525
if color_by_categorical :
505
- agg = cvs .points (sdata_filt . points [ element ] , "x" , "y" , agg = ds .by (col_for_color , ds .count ()))
526
+ agg = cvs .points (transformed_element , "x" , "y" , agg = ds .by (col_for_color , ds .count ()))
506
527
else :
507
528
reduction_name = render_params .ds_reduction if render_params .ds_reduction is not None else "sum"
508
529
logger .info (
509
530
f'Using the datashader reduction "{ reduction_name } ". "max" will give an output very close '
510
531
"to the matplotlib result."
511
532
)
512
533
agg = _datashader_aggregate_with_function (
513
- render_params .ds_reduction , cvs , sdata_filt . points [ element ] , col_for_color , "points"
534
+ render_params .ds_reduction , cvs , transformed_element , col_for_color , "points"
514
535
)
515
536
# save min and max values for drawing the colorbar
516
537
aggregate_with_reduction = (agg .min (), agg .max ())
517
538
else :
518
- agg = cvs .points (sdata_filt . points [ element ] , "x" , "y" , agg = ds .count ())
539
+ agg = cvs .points (transformed_element , "x" , "y" , agg = ds .count ())
519
540
520
541
if norm .vmin is not None or norm .vmax is not None :
521
542
norm .vmin = np .min (agg ) if norm .vmin is None else norm .vmin
@@ -573,7 +594,14 @@ def _render_points(
573
594
)
574
595
575
596
rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
576
- _ax_show_and_transform (rgba_image , trans_data , ax , zorder = render_params .zorder , alpha = render_params .alpha )
597
+ _ax_show_and_transform (
598
+ rgba_image ,
599
+ trans_data ,
600
+ ax ,
601
+ zorder = render_params .zorder ,
602
+ alpha = render_params .alpha ,
603
+ extent = x_ext + y_ext ,
604
+ )
577
605
578
606
cax = None
579
607
if aggregate_with_reduction is not None :
0 commit comments