Skip to content

basic shapes datashader rendering #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def render_shapes(
cmap: Colormap | str | None = None,
norm: bool | Normalize = False,
scale: float | int = 1.0,
method: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -204,6 +205,9 @@ def render_shapes(
Colormap normalization for continuous annotations.
scale : float | int, default 1.0
Value to scale circles, if present.
method : str | None, optional
Whether to use 'matplotlib' and 'datashader'. When None, the method is
chosen based on the size of the data.
**kwargs : Any
Additional arguments to be passed to cmap and norm.

Expand Down Expand Up @@ -317,6 +321,12 @@ def render_shapes(
if scale < 0:
raise ValueError("Parameter 'scale' must be a positive number.")

if method is not None:
if not isinstance(method, str):
raise TypeError("Parameter 'method' must be a string.")
if method not in ["matplotlib", "datashader"]:
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())
Expand All @@ -343,6 +353,7 @@ def render_shapes(
fill_alpha=fill_alpha,
transfunc=kwargs.get("transfunc", None),
zorder=n_steps,
method=method,
)

return sdata
Expand All @@ -358,6 +369,7 @@ def render_points(
cmap: Colormap | str | None = None,
norm: None | Normalize = None,
size: float | int = 1.0,
method: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -392,6 +404,9 @@ def render_points(
Colormap normalization for continuous annotations.
size : float | int, default 1.0
Size of the points
method : str | None, optional
Whether to use 'matplotlib' and 'datashader'. When None, the method is
chosen based on the size of the data.
kwargs
Additional arguments to be passed to cmap and norm.

Expand Down Expand Up @@ -479,6 +494,12 @@ def render_points(
if size < 0:
raise ValueError("Parameter 'size' must be a positive number.")

if method is not None:
if not isinstance(method, str):
raise TypeError("Parameter 'method' must be a string.")
if method not in ["matplotlib", "datashader"]:
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())
Expand All @@ -501,6 +522,7 @@ def render_points(
transfunc=kwargs.get("transfunc", None),
size=size,
zorder=n_steps,
method=method,
)

return sdata
Expand Down
133 changes: 100 additions & 33 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,35 +122,6 @@ def _render_shapes(
shapes = shapes.reset_index()
color_source_vector = color_source_vector[mask]
color_vector = color_vector[mask]
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

_cax = _get_collection_shape(
shapes=shapes,
s=render_params.scale,
c=color_vector,
render_params=render_params,
rasterized=sc_settings._vector_friendly,
cmap=render_params.cmap_params.cmap,
norm=norm,
fill_alpha=render_params.fill_alpha,
outline_alpha=render_params.outline_alpha,
zorder=render_params.zorder,
# **kwargs,
)

# Sets the limits of the colorbar to the values instead of [0, 1]
if not norm and not values_are_categorical:
_cax.set_clim(min(color_vector), max(color_vector))

cax = ax.add_collection(_cax)

# Apply the transformation to the PatchCollection's paths
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)

for path in _cax.get_paths():
path.vertices = trans.transform(path.vertices)

# Using dict.fromkeys here since set returns in arbitrary order
# remove the color of NaN values, else it might be assigned to a category
Expand All @@ -160,6 +131,98 @@ def _render_shapes(
else:
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))

# Apply the transformation to the PatchCollection's paths
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

# Determine which method to use for rendering
method = render_params.method
if method is None:
method = "datashader" if len(shapes) > 100 else "matplotlib"
elif method not in ["matplotlib", "datashader"]:
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")

if method == "matplotlib":
logger.info(f"Using {method}")
_cax = _get_collection_shape(
shapes=shapes,
s=render_params.scale,
c=color_vector,
render_params=render_params,
rasterized=sc_settings._vector_friendly,
cmap=render_params.cmap_params.cmap,
norm=norm,
fill_alpha=render_params.fill_alpha,
outline_alpha=render_params.outline_alpha,
zorder=render_params.zorder,
# **kwargs,
)
cax = ax.add_collection(_cax)

# Transform the paths in PatchCollection
for path in _cax.get_paths():
path.vertices = trans.transform(path.vertices)
cax = ax.add_collection(_cax)
elif method == "datashader":
logger.info(f"Using {method}")

# Where to put this
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData

extent = get_extent(sdata.shapes[e])
x_ext = extent["x"][1]
y_ext = extent["y"][1]
# previous_xlim = fig_params.ax.get_xlim()
# previous_ylim = fig_params.ax.get_ylim()
x_range = [0, x_ext]
y_range = [0, y_ext]
# round because we need integers
plot_width = int(np.round(x_range[1] - x_range[0]))
plot_height = int(np.round(y_range[1] - y_range[0]))

cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_range, y_range=y_range)

_geometry = shapes["geometry"]
is_point = _geometry.type == "Point"

# Handle circles encoded as points with radius
if is_point.any(): # TODO
scale = shapes[is_point]["radius"] * render_params.scale
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale)

agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())

# Render shapes with datashader
if render_params.col_for_color is not None and (
render_params.groups is None or len(render_params.groups) > 1
):
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.by(render_params.col_for_color, ds.count()))
else:
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())

color_key = (
[x[:-2] for x in color_vector.categories.values]
if (type(color_vector) == pd.core.arrays.categorical.Categorical)
and (len(color_vector.categories.values) > 1)
else None
)
ds_result = ds.tf.shade(
agg, cmap=color_vector[0][:-2], alpha=render_params.fill_alpha * 255, color_key=color_key, min_alpha=200
)

# Render image
rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
_cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder)
_cax.set_transform(trans)
cax = ax.add_image(_cax)

# Sets the limits of the colorbar to the values instead of [0, 1]
if not norm and not values_are_categorical:
_cax.set_clim(min(color_vector), max(color_vector))

if not (
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
Expand Down Expand Up @@ -278,9 +341,13 @@ def _render_points(

norm = copy(render_params.cmap_params.norm)

# optionally render points using datashader
# TODO: maybe move this, add heuristic
if len(points) > 50:
method = render_params.method
if method is None:
method = "datashader" if len(points.shape[0]) > 10000 else "matplotlib"
elif method not in ["matplotlib", "datashader"]:
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")

if method == "datashader":
extent = get_extent(sdata_filt.points[e], coordinate_system=coordinate_system)
x_ext = extent["x"][1]
y_ext = extent["y"][1]
Expand Down Expand Up @@ -334,7 +401,7 @@ def _render_points(
rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
ax.imshow(rbga_image, zorder=render_params.zorder)
cax = None
else:
elif method == "matplotlib":
# original way of plotting points
_cax = ax.scatter(
adata[:, 0].X.flatten(),
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ShapesRenderParams:
fill_alpha: float = 0.3
scale: float = 1.0
transfunc: Callable[[float], float] | None = None
method: str | None = None
zorder: int | None = None


Expand All @@ -97,6 +98,7 @@ class PointsRenderParams:
alpha: float = 1.0
size: float = 1.0
transfunc: Callable[[float], float] | None = None
method: str | None = None
zorder: int | None = None


Expand Down