Skip to content

Commit e6ed752

Browse files
Sonja-StockhausSonja Stockhauspre-commit-ci[bot]
authored
Transformations are applied before rendering with datashader (#378)
Co-authored-by: Sonja Stockhaus <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent bdfd7b5 commit e6ed752

11 files changed

+238
-43
lines changed

CHANGELOG.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,17 @@ and this project adheres to [Semantic Versioning][].
88
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
99
[semantic versioning]: https://semver.org/spec/v2.0.0.html
1010

11+
## [0.2.9] - tbd
12+
13+
### Fixed
14+
15+
- Transformations of Points and Shapes are now applied before rendering with datashader (#378)
16+
1117
## [0.2.8] - 2024-11-26
1218

1319
### Changed
14-
- Support for `xarray.DataTree` (which moved from `datatree.DataTree`) #380
20+
21+
- Support for `xarray.DataTree` (which moved from `datatree.DataTree`) (#380)
1522

1623
## [0.2.7] - 2024-10-24
1724

@@ -45,10 +52,6 @@ and this project adheres to [Semantic Versioning][].
4552

4653
## [0.2.5] - 2024-08-23
4754

48-
### Added
49-
50-
-
51-
5255
### Changed
5356

5457
- Replaced `outline` parameter in `render_labels` with alpha-based logic (#323)

src/spatialdata_plot/pl/basic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def render_shapes(
162162
palette: list[str] | str | None = None,
163163
na_color: ColorLike | None = "default",
164164
outline_width: float | int = 1.5,
165-
outline_color: str | list[float] = "#000000ff",
165+
outline_color: str | list[float] = "#000000",
166166
outline_alpha: float | int = 0.0,
167167
cmap: Colormap | str | None = None,
168168
norm: Normalize | None = None,
@@ -208,9 +208,11 @@ def render_shapes(
208208
won't be shown.
209209
outline_width : float | int, default 1.5
210210
Width of the border.
211-
outline_color : str | list[float], default "#000000ff"
212-
Color of the border. Can either be a named color ("red"), a hex representation ("#000000ff") or a list of
213-
floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0).
211+
outline_color : str | list[float], default "#000000"
212+
Color of the border. Can either be a named color ("red"), a hex representation ("#000000") or a list of
213+
floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If the hex representation includes alpha, e.g.
214+
"#000000ff", the last two positions are ignored, since the alpha of the outlines is solely controlled by
215+
`outline_alpha`.
214216
outline_alpha : float | int, default 0.0
215217
Alpha value for the outline of shapes. Invisible by default.
216218
cmap : Colormap | str | None, optional

src/spatialdata_plot/pl/render.py

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
from matplotlib.colors import ListedColormap, Normalize
1919
from scanpy._settings import settings as sc_settings
2020
from spatialdata import get_extent
21-
from spatialdata.models import PointsModel, get_table_keys
22-
from spatialdata.transformations import (
23-
set_transformation,
24-
)
21+
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
22+
from spatialdata.transformations import get_transformation, set_transformation
23+
from spatialdata.transformations.transformations import Identity
2524
from xarray import DataTree
2625

2726
from spatialdata_plot._logging import logger
@@ -44,6 +43,7 @@
4443
_get_colors_for_categorical_obs,
4544
_get_extent_and_range_for_datashader_canvas,
4645
_get_linear_colormap,
46+
_get_transformation_matrix_for_datashader,
4747
_is_coercable_to_float,
4848
_map_color_seg,
4949
_maybe_set_colors,
@@ -148,7 +148,7 @@ def _render_shapes(
148148
colorbar = False if col_for_color is None else legend_params.colorbar
149149

150150
# Apply the transformation to the PatchCollection's paths
151-
trans, _ = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
151+
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
152152

153153
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
154154

@@ -168,14 +168,6 @@ def _render_shapes(
168168
)
169169

170170
if method == "datashader":
171-
trans += ax.transData
172-
173-
plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
174-
sdata_filt.shapes[element], coordinate_system, ax, fig_params
175-
)
176-
177-
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)
178-
179171
_geometry = shapes["geometry"]
180172
is_point = _geometry.type == "Point"
181173

@@ -184,36 +176,48 @@ def _render_shapes(
184176
scale = shapes[is_point]["radius"] * render_params.scale
185177
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
186178

179+
# apply transformations to the individual points
180+
element_trans = get_transformation(sdata_filt.shapes[element])
181+
tm = _get_transformation_matrix_for_datashader(element_trans)
182+
transformed_element = sdata_filt.shapes[element].transform(
183+
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
184+
)
185+
transformed_element = ShapesModel.parse(
186+
gpd.GeoDataFrame(data=sdata_filt.shapes[element].drop("geometry", axis=1), geometry=transformed_element)
187+
)
188+
189+
plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
190+
transformed_element, coordinate_system, ax, fig_params
191+
)
192+
193+
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)
194+
187195
# in case we are coloring by a column in table
188-
if col_for_color is not None and col_for_color not in sdata_filt.shapes[element].columns:
189-
sdata_filt.shapes[element][col_for_color] = (
190-
color_vector if color_source_vector is None else color_source_vector
191-
)
196+
if col_for_color is not None and col_for_color not in transformed_element.columns:
197+
transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
192198
# Render shapes with datashader
193199
color_by_categorical = col_for_color is not None and color_source_vector is not None
194200
aggregate_with_reduction = None
195201
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
196202
if color_by_categorical:
197-
agg = cvs.polygons(
198-
sdata_filt.shapes[element], geometry="geometry", agg=ds.by(col_for_color, ds.count())
199-
)
203+
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count()))
200204
else:
201205
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean"
202206
logger.info(
203207
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
204208
"to the matplotlib result."
205209
)
206210
agg = _datashader_aggregate_with_function(
207-
render_params.ds_reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes"
211+
render_params.ds_reduction, cvs, transformed_element, col_for_color, "shapes"
208212
)
209213
# save min and max values for drawing the colorbar
210214
aggregate_with_reduction = (agg.min(), agg.max())
211215
else:
212-
agg = cvs.polygons(sdata_filt.shapes[element], geometry="geometry", agg=ds.count())
216+
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count())
213217
# render outlines if needed
214218
if (render_outlines := render_params.outline_alpha) > 0:
215219
agg_outlines = cvs.line(
216-
sdata_filt.shapes[element],
220+
transformed_element,
217221
geometry="geometry",
218222
line_width=render_params.outline_params.linewidth,
219223
)
@@ -287,13 +291,23 @@ def _render_shapes(
287291

288292
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
289293
_cax = _ax_show_and_transform(
290-
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.fill_alpha
294+
rgba_image,
295+
trans_data,
296+
ax,
297+
zorder=render_params.zorder,
298+
alpha=render_params.fill_alpha,
299+
extent=x_ext + y_ext,
291300
)
292301
# render outline image if needed
293302
if render_outlines:
294303
rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax)
295304
_ax_show_and_transform(
296-
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.outline_alpha
305+
rgba_image,
306+
trans_data,
307+
ax,
308+
zorder=render_params.zorder,
309+
alpha=render_params.outline_alpha,
310+
extent=x_ext + y_ext,
297311
)
298312

299313
cax = None
@@ -330,7 +344,7 @@ def _render_shapes(
330344

331345
if not values_are_categorical:
332346
# If the user passed a Normalize object with vmin/vmax we'll use those,
333-
# # if not we'll use the min/max of the color_vector
347+
# if not we'll use the min/max of the color_vector
334348
_cax.set_clim(
335349
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
336350
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
@@ -468,7 +482,7 @@ def _render_points(
468482
if color_source_vector is None and render_params.transfunc is not None:
469483
color_vector = render_params.transfunc(color_vector)
470484

471-
_, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax)
485+
trans, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax)
472486

473487
norm = copy(render_params.cmap_params.norm)
474488

@@ -491,8 +505,15 @@ def _render_points(
491505
# use dpi/100 as a factor for cases where dpi!=100
492506
px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))
493507

508+
# apply transformations
509+
transformed_element = PointsModel.parse(
510+
trans.transform(sdata_filt.points[element][["x", "y"]]),
511+
annotation=sdata_filt.points[element][sdata_filt.points[element].columns.drop(["x", "y"])],
512+
transformations={coordinate_system: Identity()},
513+
)
514+
494515
plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
495-
sdata_filt.points[element], coordinate_system, ax, fig_params
516+
transformed_element, coordinate_system, ax, fig_params
496517
)
497518

498519
# use datashader for the visualization of points
@@ -502,20 +523,20 @@ def _render_points(
502523
aggregate_with_reduction = None
503524
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
504525
if color_by_categorical:
505-
agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.by(col_for_color, ds.count()))
526+
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
506527
else:
507528
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
508529
logger.info(
509530
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
510531
"to the matplotlib result."
511532
)
512533
agg = _datashader_aggregate_with_function(
513-
render_params.ds_reduction, cvs, sdata_filt.points[element], col_for_color, "points"
534+
render_params.ds_reduction, cvs, transformed_element, col_for_color, "points"
514535
)
515536
# save min and max values for drawing the colorbar
516537
aggregate_with_reduction = (agg.min(), agg.max())
517538
else:
518-
agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.count())
539+
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())
519540

520541
if norm.vmin is not None or norm.vmax is not None:
521542
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
@@ -573,7 +594,14 @@ def _render_points(
573594
)
574595

575596
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
576-
_ax_show_and_transform(rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.alpha)
597+
_ax_show_and_transform(
598+
rgba_image,
599+
trans_data,
600+
ax,
601+
zorder=render_params.zorder,
602+
alpha=render_params.alpha,
603+
extent=x_ext + y_ext,
604+
)
577605

578606
cax = None
579607
if aggregate_with_reduction is not None:

src/spatialdata_plot/pl/utils.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import matplotlib.transforms as mtransforms
2020
import numpy as np
2121
import numpy.ma as ma
22+
import numpy.typing as npt
2223
import pandas as pd
2324
import shapely
2425
import spatialdata as sd
@@ -58,8 +59,11 @@
5859
from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin
5960
from spatialdata._types import ArrayLike
6061
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, get_model
62+
63+
# from spatialdata.transformations.transformations import Scale
64+
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation
65+
from spatialdata.transformations import Sequence as SDSequence
6166
from spatialdata.transformations.operations import get_transformation
62-
from spatialdata.transformations.transformations import Scale
6367
from xarray import DataArray, DataTree
6468

6569
from spatialdata_plot._logging import logger
@@ -1977,19 +1981,37 @@ def _ax_show_and_transform(
19771981
alpha: float | None = None,
19781982
cmap: ListedColormap | LinearSegmentedColormap | None = None,
19791983
zorder: int = 0,
1984+
extent: list[float] | None = None,
19801985
) -> matplotlib.image.AxesImage:
1986+
# default extent in mpl:
1987+
image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
1988+
if extent is not None:
1989+
# make sure extent is [x_min, x_max, y_min, y_max]
1990+
if extent[3] < extent[2]:
1991+
extent[2], extent[3] = extent[3], extent[2]
1992+
if extent[0] < 0:
1993+
x_factor = array.shape[1] / (extent[1] - extent[0])
1994+
image_extent[0] = image_extent[0] + (extent[0] * x_factor)
1995+
image_extent[1] = image_extent[1] + (extent[0] * x_factor)
1996+
if extent[2] < 0:
1997+
y_factor = array.shape[0] / (extent[3] - extent[2])
1998+
image_extent[2] = image_extent[2] + (extent[2] * y_factor)
1999+
image_extent[3] = image_extent[3] + (extent[2] * y_factor)
2000+
19812001
if not cmap and alpha is not None:
19822002
im = ax.imshow(
19832003
array,
19842004
alpha=alpha,
19852005
zorder=zorder,
2006+
extent=tuple(image_extent),
19862007
)
19872008
im.set_transform(trans_data)
19882009
else:
19892010
im = ax.imshow(
19902011
array,
19912012
cmap=cmap,
19922013
zorder=zorder,
2014+
extent=tuple(image_extent),
19932015
)
19942016
im.set_transform(trans_data)
19952017
return im
@@ -2055,7 +2077,7 @@ def _get_extent_and_range_for_datashader_canvas(
20552077

20562078
def _create_image_from_datashader_result(
20572079
ds_result: ds.transfer_functions.Image, factor: float, ax: Axes
2058-
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.CompositeGenericTransform]:
2080+
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]:
20592081
# create SpatialImage from datashader output to get it back to original size
20602082
rgba_image_data = ds_result.to_numpy().base
20612083
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
@@ -2187,3 +2209,34 @@ def _prepare_transformation(
21872209
trans_data = trans + ax.transData if ax is not None else None
21882210

21892211
return trans, trans_data
2212+
2213+
2214+
def _get_datashader_trans_matrix_of_single_element(
2215+
trans: Identity | Scale | Affine | MapAxis | Translation,
2216+
) -> npt.NDArray[Any]:
2217+
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
2218+
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))
2219+
2220+
if isinstance(trans, Identity):
2221+
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2222+
if isinstance(trans, (Scale | Affine)):
2223+
# idea: "flip the y-axis", apply transformation, flip back
2224+
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
2225+
return flip_and_transform
2226+
if isinstance(trans, MapAxis):
2227+
# no flipping needed
2228+
return tm
2229+
# for a Translation, we need the transposed transformation matrix
2230+
return tm.T
2231+
2232+
2233+
def _get_transformation_matrix_for_datashader(
2234+
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
2235+
) -> npt.NDArray[Any]:
2236+
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
2237+
if isinstance(trans, SDSequence):
2238+
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2239+
for x in trans.transformations:
2240+
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
2241+
return tm
2242+
return _get_datashader_trans_matrix_of_single_element(trans)
Loading
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)