Skip to content

Commit cc02f62

Browse files
committed
handle render shape colors; add parameter for render method
1 parent bcecb0e commit cc02f62

File tree

3 files changed

+70
-11
lines changed

3 files changed

+70
-11
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def render_shapes(
160160
cmap: Colormap | str | None = None,
161161
norm: bool | Normalize = False,
162162
scale: float | int = 1.0,
163+
method: str | None = None,
163164
**kwargs: Any,
164165
) -> sd.SpatialData:
165166
"""
@@ -204,6 +205,9 @@ def render_shapes(
204205
Colormap normalization for continuous annotations.
205206
scale : float | int, default 1.0
206207
Value to scale circles, if present.
208+
method : str | None, optional
209+
Whether to use 'matplotlib' and 'datashader'. When None, the method is
210+
chosen based on the size of the data.
207211
**kwargs : Any
208212
Additional arguments to be passed to cmap and norm.
209213
@@ -317,6 +321,12 @@ def render_shapes(
317321
if scale < 0:
318322
raise ValueError("Parameter 'scale' must be a positive number.")
319323

324+
if method is not None:
325+
if not isinstance(method, str):
326+
raise TypeError("Parameter 'method' must be a string.")
327+
if method not in ["matplotlib", "datashader"]:
328+
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
329+
320330
sdata = self._copy()
321331
sdata = _verify_plotting_tree(sdata)
322332
n_steps = len(sdata.plotting_tree.keys())
@@ -343,6 +353,7 @@ def render_shapes(
343353
fill_alpha=fill_alpha,
344354
transfunc=kwargs.get("transfunc", None),
345355
zorder=n_steps,
356+
method=method,
346357
)
347358

348359
return sdata
@@ -358,6 +369,7 @@ def render_points(
358369
cmap: Colormap | str | None = None,
359370
norm: None | Normalize = None,
360371
size: float | int = 1.0,
372+
method: str | None = None,
361373
**kwargs: Any,
362374
) -> sd.SpatialData:
363375
"""
@@ -392,6 +404,9 @@ def render_points(
392404
Colormap normalization for continuous annotations.
393405
size : float | int, default 1.0
394406
Size of the points
407+
method : str | None, optional
408+
Whether to use 'matplotlib' and 'datashader'. When None, the method is
409+
chosen based on the size of the data.
395410
kwargs
396411
Additional arguments to be passed to cmap and norm.
397412
@@ -479,6 +494,12 @@ def render_points(
479494
if size < 0:
480495
raise ValueError("Parameter 'size' must be a positive number.")
481496

497+
if method is not None:
498+
if not isinstance(method, str):
499+
raise TypeError("Parameter 'method' must be a string.")
500+
if method not in ["matplotlib", "datashader"]:
501+
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
502+
482503
sdata = self._copy()
483504
sdata = _verify_plotting_tree(sdata)
484505
n_steps = len(sdata.plotting_tree.keys())
@@ -501,6 +522,7 @@ def render_points(
501522
transfunc=kwargs.get("transfunc", None),
502523
size=size,
503524
zorder=n_steps,
525+
method=method,
504526
)
505527

506528
return sdata

src/spatialdata_plot/pl/render.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,21 @@ def _render_shapes(
134134
# Apply the transformation to the PatchCollection's paths
135135
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
136136
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
137-
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
137+
trans = mtransforms.Affine2D(matrix=affine_trans)
138138

139139
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
140140

141-
if len(shapes) < 1:
142-
logger.info("Using matplotlib")
141+
# Determine which method to use for rendering. Default is matplotlib for under 100 shapes and datashader for more
142+
# User can also specify the method to use
143+
method = render_params.method
144+
145+
if method is None:
146+
method = "datashader" if len(shapes) > 100 else "matplotlib"
147+
elif method not in ["matplotlib", "datashader"]:
148+
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")
149+
150+
if method == "matplotlib":
151+
logger.info(f"Using {method}")
143152
_cax = _get_collection_shape(
144153
shapes=shapes,
145154
s=render_params.scale,
@@ -159,8 +168,12 @@ def _render_shapes(
159168
for path in _cax.get_paths():
160169
path.vertices = trans.transform(path.vertices)
161170
cax = ax.add_collection(_cax)
162-
else:
163-
logger.info("Using datashader")
171+
elif method == "datashader":
172+
logger.info(f"Using {method}")
173+
174+
# Where to put this
175+
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
176+
164177
extent = get_extent(sdata.shapes[e])
165178
x_ext = extent["x"][1]
166179
y_ext = extent["y"][1]
@@ -179,10 +192,28 @@ def _render_shapes(
179192

180193
# Handle circles encoded as points with radius
181194
if is_point.any(): # TODO
182-
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(shapes[is_point]["radius"])
195+
scale = shapes[is_point]["radius"] * render_params.scale
196+
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale)
183197

184198
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())
185-
ds_result = ds.tf.shade(agg)
199+
200+
# Render shapes with datashader
201+
if render_params.col_for_color is not None and (
202+
render_params.groups is None or len(render_params.groups) > 1
203+
):
204+
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.by(render_params.col_for_color, ds.count()))
205+
else:
206+
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())
207+
208+
color_key = (
209+
[x[:-2] for x in color_vector.categories.values]
210+
if (type(color_vector) == pd.core.arrays.categorical.Categorical)
211+
and (len(color_vector.categories.values) > 1)
212+
else None
213+
)
214+
ds_result = ds.tf.shade(
215+
agg, cmap=color_vector[0][:-2], alpha=render_params.fill_alpha * 255, color_key=color_key, min_alpha=200
216+
)
186217

187218
# Render image
188219
rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
@@ -312,9 +343,13 @@ def _render_points(
312343

313344
norm = copy(render_params.cmap_params.norm)
314345

315-
# optionally render points using datashader
316-
# TODO: maybe move this, add heuristic
317-
if len(points) > 50:
346+
method = render_params.method
347+
if method is None:
348+
method = "datashader" if len(points.shape[0]) > 10000 else "matplotlib"
349+
elif method not in ["matplotlib", "datashader"]:
350+
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")
351+
352+
if method == "datashader":
318353
extent = get_extent(sdata_filt.points[e], coordinate_system=coordinate_system)
319354
x_ext = extent["x"][1]
320355
y_ext = extent["y"][1]
@@ -368,7 +403,7 @@ def _render_points(
368403
rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
369404
ax.imshow(rbga_image, zorder=render_params.zorder)
370405
cax = None
371-
else:
406+
elif method == "matplotlib":
372407
# original way of plotting points
373408
_cax = ax.scatter(
374409
adata[:, 0].X.flatten(),

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class ShapesRenderParams:
8181
fill_alpha: float = 0.3
8282
scale: float = 1.0
8383
transfunc: Callable[[float], float] | None = None
84+
method: str | None = None
8485
zorder: int | None = None
8586

8687

@@ -97,6 +98,7 @@ class PointsRenderParams:
9798
alpha: float = 1.0
9899
size: float = 1.0
99100
transfunc: Callable[[float], float] | None = None
101+
method: str | None = None
100102
zorder: int | None = None
101103

102104

0 commit comments

Comments
 (0)