@@ -122,35 +122,6 @@ def _render_shapes(
122
122
shapes = shapes .reset_index ()
123
123
color_source_vector = color_source_vector [mask ]
124
124
color_vector = color_vector [mask ]
125
- shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
126
-
127
- _cax = _get_collection_shape (
128
- shapes = shapes ,
129
- s = render_params .scale ,
130
- c = color_vector ,
131
- render_params = render_params ,
132
- rasterized = sc_settings ._vector_friendly ,
133
- cmap = render_params .cmap_params .cmap ,
134
- norm = norm ,
135
- fill_alpha = render_params .fill_alpha ,
136
- outline_alpha = render_params .outline_alpha ,
137
- zorder = render_params .zorder ,
138
- # **kwargs,
139
- )
140
-
141
- # Sets the limits of the colorbar to the values instead of [0, 1]
142
- if not norm and not values_are_categorical :
143
- _cax .set_clim (min (color_vector ), max (color_vector ))
144
-
145
- cax = ax .add_collection (_cax )
146
-
147
- # Apply the transformation to the PatchCollection's paths
148
- trans = get_transformation (sdata_filt .shapes [e ], get_all = True )[coordinate_system ]
149
- affine_trans = trans .to_affine_matrix (input_axes = ("x" , "y" ), output_axes = ("x" , "y" ))
150
- trans = mtransforms .Affine2D (matrix = affine_trans )
151
-
152
- for path in _cax .get_paths ():
153
- path .vertices = trans .transform (path .vertices )
154
125
155
126
# Using dict.fromkeys here since set returns in arbitrary order
156
127
# remove the color of NaN values, else it might be assigned to a category
@@ -160,6 +131,98 @@ def _render_shapes(
160
131
else :
161
132
palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
162
133
134
+ # Apply the transformation to the PatchCollection's paths
135
+ trans = get_transformation (sdata_filt .shapes [e ], get_all = True )[coordinate_system ]
136
+ affine_trans = trans .to_affine_matrix (input_axes = ("x" , "y" ), output_axes = ("x" , "y" ))
137
+ trans = mtransforms .Affine2D (matrix = affine_trans )
138
+
139
+ shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
140
+
141
+ # Determine which method to use for rendering
142
+ method = render_params .method
143
+ if method is None :
144
+ method = "datashader" if len (shapes ) > 100 else "matplotlib"
145
+ elif method not in ["matplotlib" , "datashader" ]:
146
+ raise ValueError ("Method must be either 'matplotlib' or 'datashader'." )
147
+
148
+ if method == "matplotlib" :
149
+ logger .info (f"Using { method } " )
150
+ _cax = _get_collection_shape (
151
+ shapes = shapes ,
152
+ s = render_params .scale ,
153
+ c = color_vector ,
154
+ render_params = render_params ,
155
+ rasterized = sc_settings ._vector_friendly ,
156
+ cmap = render_params .cmap_params .cmap ,
157
+ norm = norm ,
158
+ fill_alpha = render_params .fill_alpha ,
159
+ outline_alpha = render_params .outline_alpha ,
160
+ zorder = render_params .zorder ,
161
+ # **kwargs,
162
+ )
163
+ cax = ax .add_collection (_cax )
164
+
165
+ # Transform the paths in PatchCollection
166
+ for path in _cax .get_paths ():
167
+ path .vertices = trans .transform (path .vertices )
168
+ cax = ax .add_collection (_cax )
169
+ elif method == "datashader" :
170
+ logger .info (f"Using { method } " )
171
+
172
+ # Where to put this
173
+ trans = mtransforms .Affine2D (matrix = affine_trans ) + ax .transData
174
+
175
+ extent = get_extent (sdata .shapes [e ])
176
+ x_ext = extent ["x" ][1 ]
177
+ y_ext = extent ["y" ][1 ]
178
+ # previous_xlim = fig_params.ax.get_xlim()
179
+ # previous_ylim = fig_params.ax.get_ylim()
180
+ x_range = [0 , x_ext ]
181
+ y_range = [0 , y_ext ]
182
+ # round because we need integers
183
+ plot_width = int (np .round (x_range [1 ] - x_range [0 ]))
184
+ plot_height = int (np .round (y_range [1 ] - y_range [0 ]))
185
+
186
+ cvs = ds .Canvas (plot_width = plot_width , plot_height = plot_height , x_range = x_range , y_range = y_range )
187
+
188
+ _geometry = shapes ["geometry" ]
189
+ is_point = _geometry .type == "Point"
190
+
191
+ # Handle circles encoded as points with radius
192
+ if is_point .any (): # TODO
193
+ scale = shapes [is_point ]["radius" ] * render_params .scale
194
+ shapes .loc [is_point , "geometry" ] = _geometry [is_point ].buffer (scale )
195
+
196
+ agg = cvs .polygons (shapes , geometry = "geometry" , agg = ds .count ())
197
+
198
+ # Render shapes with datashader
199
+ if render_params .col_for_color is not None and (
200
+ render_params .groups is None or len (render_params .groups ) > 1
201
+ ):
202
+ agg = cvs .polygons (shapes , geometry = "geometry" , agg = ds .by (render_params .col_for_color , ds .count ()))
203
+ else :
204
+ agg = cvs .polygons (shapes , geometry = "geometry" , agg = ds .count ())
205
+
206
+ color_key = (
207
+ [x [:- 2 ] for x in color_vector .categories .values ]
208
+ if (type (color_vector ) == pd .core .arrays .categorical .Categorical )
209
+ and (len (color_vector .categories .values ) > 1 )
210
+ else None
211
+ )
212
+ ds_result = ds .tf .shade (
213
+ agg , cmap = color_vector [0 ][:- 2 ], alpha = render_params .fill_alpha * 255 , color_key = color_key , min_alpha = 200
214
+ )
215
+
216
+ # Render image
217
+ rgba_image = np .transpose (ds_result .to_numpy ().base , (0 , 1 , 2 ))
218
+ _cax = ax .imshow (rgba_image , cmap = palette , zorder = render_params .zorder )
219
+ _cax .set_transform (trans )
220
+ cax = ax .add_image (_cax )
221
+
222
+ # Sets the limits of the colorbar to the values instead of [0, 1]
223
+ if not norm and not values_are_categorical :
224
+ _cax .set_clim (min (color_vector ), max (color_vector ))
225
+
163
226
if not (
164
227
len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
165
228
):
@@ -278,9 +341,13 @@ def _render_points(
278
341
279
342
norm = copy (render_params .cmap_params .norm )
280
343
281
- # optionally render points using datashader
282
- # TODO: maybe move this, add heuristic
283
- if len (points ) > 50 :
344
+ method = render_params .method
345
+ if method is None :
346
+ method = "datashader" if len (points .shape [0 ]) > 10000 else "matplotlib"
347
+ elif method not in ["matplotlib" , "datashader" ]:
348
+ raise ValueError ("Method must be either 'matplotlib' or 'datashader'." )
349
+
350
+ if method == "datashader" :
284
351
extent = get_extent (sdata_filt .points [e ], coordinate_system = coordinate_system )
285
352
x_ext = extent ["x" ][1 ]
286
353
y_ext = extent ["y" ][1 ]
@@ -334,7 +401,7 @@ def _render_points(
334
401
rbga_image = np .transpose (ds_result .to_numpy ().base , (0 , 1 , 2 ))
335
402
ax .imshow (rbga_image , zorder = render_params .zorder )
336
403
cax = None
337
- else :
404
+ elif method == "matplotlib" :
338
405
# original way of plotting points
339
406
_cax = ax .scatter (
340
407
adata [:, 0 ].X .flatten (),
0 commit comments