Skip to content

Commit 3a70a99

Browse files
Adding a few e2e tests (#99)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 422328f commit 3a70a99

12 files changed

+98
-37
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,15 @@ def show(
692692
cs_contents.query(f"cs == '{cs}'")["has_shapes"][0],
693693
]
694694
):
695-
ax.set_xlim(extent[cs][0], extent[cs][1])
696-
ax.set_ylim(extent[cs][3], extent[cs][2]) # (0, 0) is top-left
695+
# If the axis already has limits, only expand them but not overwrite
696+
x_min, x_max = ax.get_xlim()
697+
y_min, y_max = ax.get_ylim()
698+
x_min = min(x_min, extent[cs][0])
699+
x_max = max(x_max, extent[cs][1])
700+
y_min = min(y_min, extent[cs][2])
701+
y_max = max(y_max, extent[cs][3])
702+
ax.set_xlim(x_min, x_max)
703+
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
697704

698705
if fig_params.fig is not None and save is not None:
699706
save_fig(fig_params.fig, path=save)

src/spatialdata_plot/pl/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from types import MappingProxyType
1010
from typing import Any, Literal, Optional, Union
1111

12+
import matplotlib
1213
import matplotlib.pyplot as plt
1314
import multiscale_spatial_image as msi
1415
import numpy as np
@@ -20,7 +21,6 @@
2021
from cycler import Cycler, cycler
2122
from matplotlib import colors, patheffects, rcParams
2223
from matplotlib.axes import Axes
23-
from matplotlib.cm import get_cmap
2424
from matplotlib.collections import PatchCollection
2525
from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap, Normalize, TwoSlopeNorm, to_rgba
2626
from matplotlib.figure import Figure
@@ -271,13 +271,19 @@ def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[in
271271
for images_key in sdata.images:
272272
for e_id in element_ids:
273273
if images_key == e_id:
274-
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.images[e_id], cs_name)
274+
if not isinstance(sdata.images[e_id], msi.multiscale_spatial_image.MultiscaleSpatialImage):
275+
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.images[e_id], cs_name)
276+
else:
277+
pass
275278

276279
if has_labels and cs_contents.query(f"cs == '{cs_name}'")["has_labels"][0]:
277280
for labels_key in sdata.labels:
278281
for e_id in element_ids:
279282
if labels_key == e_id:
280-
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.labels[e_id], cs_name)
283+
if not isinstance(sdata.labels[e_id], msi.multiscale_spatial_image.MultiscaleSpatialImage):
284+
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.labels[e_id], cs_name)
285+
else:
286+
pass
281287

282288
if has_shapes and cs_contents.query(f"cs == '{cs_name}'")["has_shapes"][0]:
283289
for shapes_key in sdata.shapes:
@@ -303,7 +309,9 @@ def get_point_bb(
303309
sdata.shapes[e_id]["geometry"].apply(lambda geom: geom.geom_type == "Point")
304310
]
305311
tmp_polygons = sdata.shapes[e_id][
306-
sdata.shapes[e_id]["geometry"].apply(lambda geom: geom.geom_type == "Polygon")
312+
sdata.shapes[e_id]["geometry"].apply(
313+
lambda geom: geom.geom_type in ["Polygon", "MultiPolygon"]
314+
)
307315
]
308316

309317
if not tmp_points.empty:
@@ -448,7 +456,7 @@ def _prepare_cmap_norm(
448456
vcenter: float | None = None,
449457
**kwargs: Any,
450458
) -> CmapParams:
451-
cmap = copy(get_cmap(cmap))
459+
cmap = copy(matplotlib.colormaps[rcParams["image.cmap"] if cmap is None else cmap])
452460
cmap.set_bad("lightgray" if na_color is None else na_color)
453461

454462
if isinstance(norm, Normalize):
35.8 KB
Loading

tests/_images/Images_images.png

-37.6 KB
Binary file not shown.

tests/_images/Points_points.png

7.44 KB
Loading
6.16 KB
Loading
6.16 KB
Loading

tests/pl/test_plot.py renamed to tests/pl/test_render_images.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,9 @@
1818
# ".png" is appended to <your_filename>, no need to set it
1919

2020

21-
class TestLabels(PlotTester, metaclass=PlotTesterMeta):
22-
def test_plot_labels(self, sdata_blobs: SpatialData):
23-
# TODO: support multiscale labels
24-
if "blobs_multiscale_labels" in sdata_blobs.labels:
25-
del sdata_blobs.labels["blobs_multiscale_labels"]
26-
sdata_blobs.pl.render_labels(color="channel_2_mean").pl.show()
27-
28-
2921
class TestImages(PlotTester, metaclass=PlotTesterMeta):
30-
def test_plot_images(self, sdata_blobs: SpatialData):
22+
def test_plot_can_render_image(self, sdata_blobs: SpatialData):
3123
sdata_blobs.pl.render_images(elements="blobs_image").pl.show()
24+
25+
# def test_plot_can_render_multiscale_image(self, sdata_blobs: SpatialData):
26+
# sdata_blobs.pl.render_images(elements="blobs_multiscale_image").pl.show()

tests/pl/test_render_labels.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import matplotlib
2+
import scanpy as sc
3+
import spatialdata_plot # noqa: F401
4+
from spatialdata import SpatialData
5+
6+
from tests.conftest import PlotTester, PlotTesterMeta
7+
8+
sc.pl.set_rcParams_defaults()
9+
sc.set_figure_params(dpi=40, color_map="viridis")
10+
matplotlib.use("agg") # same as GitHub action runner
11+
_ = spatialdata_plot
12+
13+
# WARNING:
14+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
15+
# 2. tests which produce a plot must be prefixed with `test_plot_`
16+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
17+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
18+
# ".png" is appended to <your_filename>, no need to set it
19+
20+
21+
class TestLabels(PlotTester, metaclass=PlotTesterMeta):
22+
def test_plot_labels(self, sdata_blobs: SpatialData):
23+
sdata_blobs.pl.render_labels(elements="blobs_labels").pl.show()

tests/pl/test_render_points.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import matplotlib
2+
import scanpy as sc
3+
import spatialdata_plot # noqa: F401
4+
from spatialdata import SpatialData
5+
6+
from tests.conftest import PlotTester, PlotTesterMeta
7+
8+
sc.pl.set_rcParams_defaults()
9+
sc.set_figure_params(dpi=40, color_map="viridis")
10+
matplotlib.use("agg") # same as GitHub action runner
11+
_ = spatialdata_plot
12+
13+
# WARNING:
14+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
15+
# 2. tests which produce a plot must be prefixed with `test_plot_`
16+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
17+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
18+
# ".png" is appended to <your_filename>, no need to set it
19+
20+
21+
class TestPoints(PlotTester, metaclass=PlotTesterMeta):
22+
def test_plot_points(self, sdata_blobs: SpatialData):
23+
sdata_blobs.pl.render_points(elements="blobs_points").pl.show()

tests/pl/test_render_shapes.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import matplotlib
2+
import scanpy as sc
3+
import spatialdata_plot # noqa: F401
4+
from spatialdata import SpatialData
5+
6+
from tests.conftest import PlotTester, PlotTesterMeta
7+
8+
sc.pl.set_rcParams_defaults()
9+
sc.set_figure_params(dpi=40, color_map="viridis")
10+
matplotlib.use("agg") # same as GitHub action runner
11+
_ = spatialdata_plot
12+
13+
# WARNING:
14+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
15+
# 2. tests which produce a plot must be prefixed with `test_plot_`
16+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
17+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
18+
# ".png" is appended to <your_filename>, no need to set it
19+
20+
21+
class TestShapes(PlotTester, metaclass=PlotTesterMeta):
22+
def test_plot_can_render_circles(self, sdata_blobs: SpatialData):
23+
sdata_blobs.pl.render_shapes(element="blobs_circles").pl.show()
24+
25+
def test_plot_can_render_polygons(self, sdata_blobs: SpatialData):
26+
sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show()

tests/test_pp.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -150,24 +150,3 @@ def test_get_bb_wrong_input_length(sdata, request):
150150

151151
with pytest.raises(ValueError, match="Parameter 'y' must be of length 2."):
152152
sdata.pp.get_bb((0, 5), (0, 5, 2))
153-
154-
155-
# @pytest.mark.parametrize(
156-
# "sdata, keys, nrows ",
157-
# [
158-
# ("full_sdata", "data1", 3),
159-
# ("full_sdata", ["data1", "data3"], 23),
160-
# ],
161-
# )
162-
# def test_table_gets_subset_when_images_are_subset(sdata, keys, nrows, request):
163-
# """Tests wether the images inside sdata can be clipped to a bounding box."""
164-
165-
# sdata = request.getfixturevalue(sdata)
166-
167-
# assert sdata.table.n_obs == 30
168-
169-
# new_sdata = sdata.pp.get_elements(keys)
170-
171-
# print(new_sdata.table)
172-
173-
# assert len(new_sdata.table.obs) == nrows

0 commit comments

Comments
 (0)