Skip to content

Complex testing for correct plotting after applying transformation #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ and this project adheres to [Semantic Versioning][].
- Multiscale image handling: user can specify a scale, else the best scale is selected automatically given the figure size and dpi (#164)
- Large images are automatically rasterized to speed up performance (#164)
- Added better error message for mismatch in cs and ax number (#185)
- Beter test coverage for correct plotting of elements after transformation (#198)
- Can now stack render commands (#190, #192)

### Fixed

- Now dropping index when plotting shapes after spatial query (#177)
- Points are now being correctly rotated (#198)
- User can now pass Colormap objects to the cmap argument in render_images. When only one cmap is given for 3 channels, it is now applied to each channel (#188, #194)

## [0.0.6] - 2023-11-06
Expand Down
13 changes: 5 additions & 8 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def _render_points(
coords.extend(color)

points = points[coords].compute()
# points[color[0]].cat.set_categories(render_params.groups, inplace=True)
if render_params.groups is not None:
points = points[points[color].isin(render_params.groups).values]
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
Expand Down Expand Up @@ -260,6 +259,10 @@ def _render_points(
if color_source_vector is None and render_params.transfunc is not None:
color_vector = render_params.transfunc(color_vector)

trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData

norm = copy(render_params.cmap_params.norm)
_cax = ax.scatter(
adata[:, 0].X.flatten(),
Expand All @@ -270,17 +273,11 @@ def _render_points(
cmap=render_params.cmap_params.cmap,
norm=norm,
alpha=render_params.alpha,
transform=trans
# **kwargs,
)
cax = ax.add_collection(_cax)

trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)

for path in _cax.get_paths():
path.vertices = trans.transform(path.vertices)

if not (
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
98 changes: 98 additions & 0 deletions tests/pl/test_get_extent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import math

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import spatialdata_plot # noqa: F401
from geopandas import GeoDataFrame
from shapely.geometry import MultiPolygon, Point, Polygon
from spatialdata import SpatialData
from spatialdata.models import PointsModel, ShapesModel
from spatialdata.transformations import Affine, set_transformation

from tests.conftest import PlotTester, PlotTesterMeta

Expand Down Expand Up @@ -42,3 +50,93 @@ def test_plot_extent_of_img_is_correct_after_spatial_query(self, sdata_blobs: Sp
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
)
cropped_blobs.pl.render_images().pl.show()

def test_plot_correct_plot_after_transformations(self):
# inspired by https://github.com/scverse/spatialdata/blob/ef0a2dc7f9af8d4c84f15eec503177f1d08c3d46/tests/core/test_data_extent.py#L125

circles = [Point(p) for p in [[0.5, 0.1], [0.9, 0.5], [0.5, 0.9], [0.1, 0.5]]]
circles_gdf = GeoDataFrame(geometry=circles)
circles_gdf["radius"] = 0.1
circles_gdf = ShapesModel.parse(circles_gdf)

polygons = [Polygon([(0.5, 0.5), (0.5, 0), (0.6, 0.1), (0.5, 0.5)])]
polygons.append(Polygon([(0.5, 0.5), (1, 0.5), (0.9, 0.6), (0.5, 0.5)]))
polygons.append(Polygon([(0.5, 0.5), (0.5, 1), (0.4, 0.9), (0.5, 0.5)]))
polygons.append(Polygon([(0.5, 0.5), (0, 0.5), (0.1, 0.4), (0.5, 0.5)]))
polygons_gdf = GeoDataFrame(geometry=polygons)
polygons_gdf = ShapesModel.parse(polygons_gdf)

multipolygons = [
MultiPolygon(
[
polygons[0],
Polygon([(0.7, 0.1), (0.9, 0.1), (0.9, 0.3), (0.7, 0.1)]),
]
)
]
multipolygons.append(MultiPolygon([polygons[1], Polygon([(0.9, 0.7), (0.9, 0.9), (0.7, 0.9), (0.9, 0.7)])]))
multipolygons.append(MultiPolygon([polygons[2], Polygon([(0.3, 0.9), (0.1, 0.9), (0.1, 0.7), (0.3, 0.9)])]))
multipolygons.append(MultiPolygon([polygons[3], Polygon([(0.1, 0.3), (0.1, 0.1), (0.3, 0.1), (0.1, 0.3)])]))
multipolygons_gdf = GeoDataFrame(geometry=multipolygons)
multipolygons_gdf = ShapesModel.parse(multipolygons_gdf)

points_df = PointsModel.parse(np.array([[0.5, 0], [1, 0.5], [0.5, 1], [0, 0.5]]))

sdata = SpatialData(
shapes={
"circles": circles_gdf,
"polygons": polygons_gdf,
"multipolygons": multipolygons_gdf,
"circles_pi3": circles_gdf,
"polygons_pi3": polygons_gdf,
"multipolygons_pi3": multipolygons_gdf,
"circles_pi4": circles_gdf,
"polygons_pi4": polygons_gdf,
"multipolygons_pi4": multipolygons_gdf,
},
points={"points": points_df, "points_pi3": points_df, "points_pi4": points_df},
)

for i in [3, 4]:
theta = math.pi / i
rotation = Affine(
[
[math.cos(theta), -math.sin(theta), 0],
[math.sin(theta), math.cos(theta), 0],
[0, 0, 1],
],
input_axes=("x", "y"),
output_axes=("x", "y"),
)
for element_name in [f"circles_pi{i}", f"polygons_pi{i}", f"multipolygons_pi{i}", f"points_pi{i}"]:
set_transformation(element=sdata[element_name], transformation=rotation, to_coordinate_system=f"pi{i}")

_, axs = plt.subplots(ncols=3, nrows=4, figsize=(7, 9))

for cs_idx, cs in enumerate(["global", "pi3", "pi4"]):
if cs == "global":
circles_name = "circles"
polygons_name = "polygons"
multipolygons_name = "multipolygons"
points_name = "points"
elif cs == "pi3":
circles_name = "circles_pi3"
polygons_name = "polygons_pi3"
multipolygons_name = "multipolygons_pi3"
points_name = "points_pi3"
else:
circles_name = "circles_pi4"
polygons_name = "polygons_pi4"
multipolygons_name = "multipolygons_pi4"
points_name = "points_pi4"

sdata.pl.render_shapes(elements=circles_name).pl.show(coordinate_systems=cs, ax=axs[0, cs_idx], title="")
sdata.pl.render_shapes(elements=polygons_name).pl.show(coordinate_systems=cs, ax=axs[1, cs_idx], title="")
sdata.pl.render_shapes(elements=multipolygons_name).pl.show(
coordinate_systems=cs, ax=axs[2, cs_idx], title=""
)
sdata.pl.render_points(elements=points_name, size=10).pl.show(
coordinate_systems=cs, ax=axs[3, cs_idx], title="", pad_extent=0.02
)

plt.tight_layout()