@@ -134,12 +134,21 @@ def _render_shapes(
134
134
# Apply the transformation to the PatchCollection's paths
135
135
trans = get_transformation (sdata_filt .shapes [e ], get_all = True )[coordinate_system ]
136
136
affine_trans = trans .to_affine_matrix (input_axes = ("x" , "y" ), output_axes = ("x" , "y" ))
137
- trans = mtransforms .Affine2D (matrix = affine_trans ) + ax . transData
137
+ trans = mtransforms .Affine2D (matrix = affine_trans )
138
138
139
139
shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
140
140
141
- if len (shapes ) < 1 :
142
- logger .info ("Using matplotlib" )
141
+ # Determine which method to use for rendering. Default is matplotlib for under 100 shapes and datashader for more
142
+ # User can also specify the method to use
143
+ method = render_params .method
144
+
145
+ if method is None :
146
+ method = "datashader" if len (shapes ) > 100 else "matplotlib"
147
+ elif method not in ["matplotlib" , "datashader" ]:
148
+ raise ValueError ("Method must be either 'matplotlib' or 'datashader'." )
149
+
150
+ if method == "matplotlib" :
151
+ logger .info (f"Using { method } " )
143
152
_cax = _get_collection_shape (
144
153
shapes = shapes ,
145
154
s = render_params .scale ,
@@ -159,8 +168,12 @@ def _render_shapes(
159
168
for path in _cax .get_paths ():
160
169
path .vertices = trans .transform (path .vertices )
161
170
cax = ax .add_collection (_cax )
162
- else :
163
- logger .info ("Using datashader" )
171
+ elif method == "datashader" :
172
+ logger .info (f"Using { method } " )
173
+
174
+ # Where to put this
175
+ trans = mtransforms .Affine2D (matrix = affine_trans ) + ax .transData
176
+
164
177
extent = get_extent (sdata .shapes [e ])
165
178
x_ext = extent ["x" ][1 ]
166
179
y_ext = extent ["y" ][1 ]
@@ -179,10 +192,28 @@ def _render_shapes(
179
192
180
193
# Handle circles encoded as points with radius
181
194
if is_point .any (): # TODO
182
- shapes .loc [is_point , "geometry" ] = _geometry [is_point ].buffer (shapes [is_point ]["radius" ])
195
+ scale = shapes [is_point ]["radius" ] * render_params .scale
196
+ shapes .loc [is_point , "geometry" ] = _geometry [is_point ].buffer (scale )
183
197
184
198
agg = cvs .polygons (shapes , geometry = "geometry" , agg = ds .count ())
185
- ds_result = ds .tf .shade (agg )
199
+
200
+ # Render shapes with datashader
201
+ if render_params .col_for_color is not None and (
202
+ render_params .groups is None or len (render_params .groups ) > 1
203
+ ):
204
+ agg = cvs .polygons (shapes , geometry = "geometry" , agg = ds .by (render_params .col_for_color , ds .count ()))
205
+ else :
206
+ agg = cvs .polygons (shapes , geometry = "geometry" , agg = ds .count ())
207
+
208
+ color_key = (
209
+ [x [:- 2 ] for x in color_vector .categories .values ]
210
+ if (type (color_vector ) == pd .core .arrays .categorical .Categorical )
211
+ and (len (color_vector .categories .values ) > 1 )
212
+ else None
213
+ )
214
+ ds_result = ds .tf .shade (
215
+ agg , cmap = color_vector [0 ][:- 2 ], alpha = render_params .fill_alpha * 255 , color_key = color_key , min_alpha = 200
216
+ )
186
217
187
218
# Render image
188
219
rgba_image = np .transpose (ds_result .to_numpy ().base , (0 , 1 , 2 ))
@@ -312,9 +343,13 @@ def _render_points(
312
343
313
344
norm = copy (render_params .cmap_params .norm )
314
345
315
- # optionally render points using datashader
316
- # TODO: maybe move this, add heuristic
317
- if len (points ) > 50 :
346
+ method = render_params .method
347
+ if method is None :
348
+ method = "datashader" if len (points .shape [0 ]) > 10000 else "matplotlib"
349
+ elif method not in ["matplotlib" , "datashader" ]:
350
+ raise ValueError ("Method must be either 'matplotlib' or 'datashader'." )
351
+
352
+ if method == "datashader" :
318
353
extent = get_extent (sdata_filt .points [e ], coordinate_system = coordinate_system )
319
354
x_ext = extent ["x" ][1 ]
320
355
y_ext = extent ["y" ][1 ]
@@ -368,7 +403,7 @@ def _render_points(
368
403
rbga_image = np .transpose (ds_result .to_numpy ().base , (0 , 1 , 2 ))
369
404
ax .imshow (rbga_image , zorder = render_params .zorder )
370
405
cax = None
371
- else :
406
+ elif method == "matplotlib" :
372
407
# original way of plotting points
373
408
_cax = ax .scatter (
374
409
adata [:, 0 ].X .flatten (),
0 commit comments