Skip to content

Commit c66e6cd

Browse files
Merge pull request #243 from ckmah/feat/issue209_datashader_shapes
basic shapes datashader rendering
2 parents 3bfab39 + 8e809f5 commit c66e6cd

File tree

3 files changed

+124
-33
lines changed

3 files changed

+124
-33
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def render_shapes(
160160
cmap: Colormap | str | None = None,
161161
norm: bool | Normalize = False,
162162
scale: float | int = 1.0,
163+
method: str | None = None,
163164
**kwargs: Any,
164165
) -> sd.SpatialData:
165166
"""
@@ -204,6 +205,9 @@ def render_shapes(
204205
Colormap normalization for continuous annotations.
205206
scale : float | int, default 1.0
206207
Value to scale circles, if present.
208+
method : str | None, optional
209+
Whether to use 'matplotlib' and 'datashader'. When None, the method is
210+
chosen based on the size of the data.
207211
**kwargs : Any
208212
Additional arguments to be passed to cmap and norm.
209213
@@ -317,6 +321,12 @@ def render_shapes(
317321
if scale < 0:
318322
raise ValueError("Parameter 'scale' must be a positive number.")
319323

324+
if method is not None:
325+
if not isinstance(method, str):
326+
raise TypeError("Parameter 'method' must be a string.")
327+
if method not in ["matplotlib", "datashader"]:
328+
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
329+
320330
sdata = self._copy()
321331
sdata = _verify_plotting_tree(sdata)
322332
n_steps = len(sdata.plotting_tree.keys())
@@ -343,6 +353,7 @@ def render_shapes(
343353
fill_alpha=fill_alpha,
344354
transfunc=kwargs.get("transfunc", None),
345355
zorder=n_steps,
356+
method=method,
346357
)
347358

348359
return sdata
@@ -358,6 +369,7 @@ def render_points(
358369
cmap: Colormap | str | None = None,
359370
norm: None | Normalize = None,
360371
size: float | int = 1.0,
372+
method: str | None = None,
361373
**kwargs: Any,
362374
) -> sd.SpatialData:
363375
"""
@@ -392,6 +404,9 @@ def render_points(
392404
Colormap normalization for continuous annotations.
393405
size : float | int, default 1.0
394406
Size of the points
407+
method : str | None, optional
408+
Whether to use 'matplotlib' and 'datashader'. When None, the method is
409+
chosen based on the size of the data.
395410
kwargs
396411
Additional arguments to be passed to cmap and norm.
397412
@@ -479,6 +494,12 @@ def render_points(
479494
if size < 0:
480495
raise ValueError("Parameter 'size' must be a positive number.")
481496

497+
if method is not None:
498+
if not isinstance(method, str):
499+
raise TypeError("Parameter 'method' must be a string.")
500+
if method not in ["matplotlib", "datashader"]:
501+
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
502+
482503
sdata = self._copy()
483504
sdata = _verify_plotting_tree(sdata)
484505
n_steps = len(sdata.plotting_tree.keys())
@@ -501,6 +522,7 @@ def render_points(
501522
transfunc=kwargs.get("transfunc", None),
502523
size=size,
503524
zorder=n_steps,
525+
method=method,
504526
)
505527

506528
return sdata

src/spatialdata_plot/pl/render.py

Lines changed: 100 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -122,35 +122,6 @@ def _render_shapes(
122122
shapes = shapes.reset_index()
123123
color_source_vector = color_source_vector[mask]
124124
color_vector = color_vector[mask]
125-
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
126-
127-
_cax = _get_collection_shape(
128-
shapes=shapes,
129-
s=render_params.scale,
130-
c=color_vector,
131-
render_params=render_params,
132-
rasterized=sc_settings._vector_friendly,
133-
cmap=render_params.cmap_params.cmap,
134-
norm=norm,
135-
fill_alpha=render_params.fill_alpha,
136-
outline_alpha=render_params.outline_alpha,
137-
zorder=render_params.zorder,
138-
# **kwargs,
139-
)
140-
141-
# Sets the limits of the colorbar to the values instead of [0, 1]
142-
if not norm and not values_are_categorical:
143-
_cax.set_clim(min(color_vector), max(color_vector))
144-
145-
cax = ax.add_collection(_cax)
146-
147-
# Apply the transformation to the PatchCollection's paths
148-
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
149-
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
150-
trans = mtransforms.Affine2D(matrix=affine_trans)
151-
152-
for path in _cax.get_paths():
153-
path.vertices = trans.transform(path.vertices)
154125

155126
# Using dict.fromkeys here since set returns in arbitrary order
156127
# remove the color of NaN values, else it might be assigned to a category
@@ -160,6 +131,98 @@ def _render_shapes(
160131
else:
161132
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))
162133

134+
# Apply the transformation to the PatchCollection's paths
135+
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
136+
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
137+
trans = mtransforms.Affine2D(matrix=affine_trans)
138+
139+
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
140+
141+
# Determine which method to use for rendering
142+
method = render_params.method
143+
if method is None:
144+
method = "datashader" if len(shapes) > 100 else "matplotlib"
145+
elif method not in ["matplotlib", "datashader"]:
146+
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")
147+
148+
if method == "matplotlib":
149+
logger.info(f"Using {method}")
150+
_cax = _get_collection_shape(
151+
shapes=shapes,
152+
s=render_params.scale,
153+
c=color_vector,
154+
render_params=render_params,
155+
rasterized=sc_settings._vector_friendly,
156+
cmap=render_params.cmap_params.cmap,
157+
norm=norm,
158+
fill_alpha=render_params.fill_alpha,
159+
outline_alpha=render_params.outline_alpha,
160+
zorder=render_params.zorder,
161+
# **kwargs,
162+
)
163+
cax = ax.add_collection(_cax)
164+
165+
# Transform the paths in PatchCollection
166+
for path in _cax.get_paths():
167+
path.vertices = trans.transform(path.vertices)
168+
cax = ax.add_collection(_cax)
169+
elif method == "datashader":
170+
logger.info(f"Using {method}")
171+
172+
# Where to put this
173+
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
174+
175+
extent = get_extent(sdata.shapes[e])
176+
x_ext = extent["x"][1]
177+
y_ext = extent["y"][1]
178+
# previous_xlim = fig_params.ax.get_xlim()
179+
# previous_ylim = fig_params.ax.get_ylim()
180+
x_range = [0, x_ext]
181+
y_range = [0, y_ext]
182+
# round because we need integers
183+
plot_width = int(np.round(x_range[1] - x_range[0]))
184+
plot_height = int(np.round(y_range[1] - y_range[0]))
185+
186+
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_range, y_range=y_range)
187+
188+
_geometry = shapes["geometry"]
189+
is_point = _geometry.type == "Point"
190+
191+
# Handle circles encoded as points with radius
192+
if is_point.any(): # TODO
193+
scale = shapes[is_point]["radius"] * render_params.scale
194+
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale)
195+
196+
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())
197+
198+
# Render shapes with datashader
199+
if render_params.col_for_color is not None and (
200+
render_params.groups is None or len(render_params.groups) > 1
201+
):
202+
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.by(render_params.col_for_color, ds.count()))
203+
else:
204+
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())
205+
206+
color_key = (
207+
[x[:-2] for x in color_vector.categories.values]
208+
if (type(color_vector) == pd.core.arrays.categorical.Categorical)
209+
and (len(color_vector.categories.values) > 1)
210+
else None
211+
)
212+
ds_result = ds.tf.shade(
213+
agg, cmap=color_vector[0][:-2], alpha=render_params.fill_alpha * 255, color_key=color_key, min_alpha=200
214+
)
215+
216+
# Render image
217+
rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
218+
_cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder)
219+
_cax.set_transform(trans)
220+
cax = ax.add_image(_cax)
221+
222+
# Sets the limits of the colorbar to the values instead of [0, 1]
223+
if not norm and not values_are_categorical:
224+
_cax.set_clim(min(color_vector), max(color_vector))
225+
163226
if not (
164227
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
165228
):
@@ -278,9 +341,13 @@ def _render_points(
278341

279342
norm = copy(render_params.cmap_params.norm)
280343

281-
# optionally render points using datashader
282-
# TODO: maybe move this, add heuristic
283-
if len(points) > 50:
344+
method = render_params.method
345+
if method is None:
346+
method = "datashader" if len(points.shape[0]) > 10000 else "matplotlib"
347+
elif method not in ["matplotlib", "datashader"]:
348+
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")
349+
350+
if method == "datashader":
284351
extent = get_extent(sdata_filt.points[e], coordinate_system=coordinate_system)
285352
x_ext = extent["x"][1]
286353
y_ext = extent["y"][1]
@@ -334,7 +401,7 @@ def _render_points(
334401
rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
335402
ax.imshow(rbga_image, zorder=render_params.zorder)
336403
cax = None
337-
else:
404+
elif method == "matplotlib":
338405
# original way of plotting points
339406
_cax = ax.scatter(
340407
adata[:, 0].X.flatten(),

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class ShapesRenderParams:
8181
fill_alpha: float = 0.3
8282
scale: float = 1.0
8383
transfunc: Callable[[float], float] | None = None
84+
method: str | None = None
8485
zorder: int | None = None
8586

8687

@@ -97,6 +98,7 @@ class PointsRenderParams:
9798
alpha: float = 1.0
9899
size: float = 1.0
99100
transfunc: Callable[[float], float] | None = None
101+
method: str | None = None
100102
zorder: int | None = None
101103

102104

0 commit comments

Comments
 (0)