Skip to content

Adding a few e2e tests #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,15 @@ def show(
cs_contents.query(f"cs == '{cs}'")["has_shapes"][0],
]
):
ax.set_xlim(extent[cs][0], extent[cs][1])
ax.set_ylim(extent[cs][3], extent[cs][2]) # (0, 0) is top-left
# If the axis already has limits, only expand them but not overwrite
x_min, x_max = ax.get_xlim()
y_min, y_max = ax.get_ylim()
x_min = min(x_min, extent[cs][0])
x_max = max(x_max, extent[cs][1])
y_min = min(y_min, extent[cs][2])
y_max = max(y_max, extent[cs][3])
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_max, y_min) # (0, 0) is top-left

if fig_params.fig is not None and save is not None:
save_fig(fig_params.fig, path=save)
Expand Down
18 changes: 13 additions & 5 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from types import MappingProxyType
from typing import Any, Literal, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
import multiscale_spatial_image as msi
import numpy as np
Expand All @@ -20,7 +21,6 @@
from cycler import Cycler, cycler
from matplotlib import colors, patheffects, rcParams
from matplotlib.axes import Axes
from matplotlib.cm import get_cmap
from matplotlib.collections import PatchCollection
from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap, Normalize, TwoSlopeNorm, to_rgba
from matplotlib.figure import Figure
Expand Down Expand Up @@ -271,13 +271,19 @@ def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[in
for images_key in sdata.images:
for e_id in element_ids:
if images_key == e_id:
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.images[e_id], cs_name)
if not isinstance(sdata.images[e_id], msi.multiscale_spatial_image.MultiscaleSpatialImage):
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.images[e_id], cs_name)
else:
pass

if has_labels and cs_contents.query(f"cs == '{cs_name}'")["has_labels"][0]:
for labels_key in sdata.labels:
for e_id in element_ids:
if labels_key == e_id:
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.labels[e_id], cs_name)
if not isinstance(sdata.labels[e_id], msi.multiscale_spatial_image.MultiscaleSpatialImage):
extent[cs_name][e_id] = _get_extent_after_transformations(sdata.labels[e_id], cs_name)
else:
pass

if has_shapes and cs_contents.query(f"cs == '{cs_name}'")["has_shapes"][0]:
for shapes_key in sdata.shapes:
Expand All @@ -303,7 +309,9 @@ def get_point_bb(
sdata.shapes[e_id]["geometry"].apply(lambda geom: geom.geom_type == "Point")
]
tmp_polygons = sdata.shapes[e_id][
sdata.shapes[e_id]["geometry"].apply(lambda geom: geom.geom_type == "Polygon")
sdata.shapes[e_id]["geometry"].apply(
lambda geom: geom.geom_type in ["Polygon", "MultiPolygon"]
)
]

if not tmp_points.empty:
Expand Down Expand Up @@ -448,7 +456,7 @@ def _prepare_cmap_norm(
vcenter: float | None = None,
**kwargs: Any,
) -> CmapParams:
cmap = copy(get_cmap(cmap))
cmap = copy(matplotlib.colormaps[rcParams["image.cmap"] if cmap is None else cmap])
cmap.set_bad("lightgray" if na_color is None else na_color)

if isinstance(norm, Normalize):
Expand Down
Binary file added tests/_images/Images_can_render_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed tests/_images/Images_images.png
Binary file not shown.
Binary file added tests/_images/Points_points.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Shapes_can_render_circles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Shapes_can_render_polygons.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 4 additions & 9 deletions tests/pl/test_plot.py → tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@
# ".png" is appended to <your_filename>, no need to set it


class TestLabels(PlotTester, metaclass=PlotTesterMeta):
def test_plot_labels(self, sdata_blobs: SpatialData):
# TODO: support multiscale labels
if "blobs_multiscale_labels" in sdata_blobs.labels:
del sdata_blobs.labels["blobs_multiscale_labels"]
sdata_blobs.pl.render_labels(color="channel_2_mean").pl.show()


class TestImages(PlotTester, metaclass=PlotTesterMeta):
def test_plot_images(self, sdata_blobs: SpatialData):
def test_plot_can_render_image(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(elements="blobs_image").pl.show()

# def test_plot_can_render_multiscale_image(self, sdata_blobs: SpatialData):
# sdata_blobs.pl.render_images(elements="blobs_multiscale_image").pl.show()
23 changes: 23 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import matplotlib
import scanpy as sc
import spatialdata_plot # noqa: F401
from spatialdata import SpatialData

from tests.conftest import PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=40, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestLabels(PlotTester, metaclass=PlotTesterMeta):
def test_plot_labels(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels(elements="blobs_labels").pl.show()
23 changes: 23 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import matplotlib
import scanpy as sc
import spatialdata_plot # noqa: F401
from spatialdata import SpatialData

from tests.conftest import PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=40, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestPoints(PlotTester, metaclass=PlotTesterMeta):
def test_plot_points(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(elements="blobs_points").pl.show()
26 changes: 26 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import matplotlib
import scanpy as sc
import spatialdata_plot # noqa: F401
from spatialdata import SpatialData

from tests.conftest import PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=40, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestShapes(PlotTester, metaclass=PlotTesterMeta):
def test_plot_can_render_circles(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles").pl.show()

def test_plot_can_render_polygons(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show()
21 changes: 0 additions & 21 deletions tests/test_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,3 @@ def test_get_bb_wrong_input_length(sdata, request):

with pytest.raises(ValueError, match="Parameter 'y' must be of length 2."):
sdata.pp.get_bb((0, 5), (0, 5, 2))


# @pytest.mark.parametrize(
# "sdata, keys, nrows ",
# [
# ("full_sdata", "data1", 3),
# ("full_sdata", ["data1", "data3"], 23),
# ],
# )
# def test_table_gets_subset_when_images_are_subset(sdata, keys, nrows, request):
# """Tests wether the images inside sdata can be clipped to a bounding box."""

# sdata = request.getfixturevalue(sdata)

# assert sdata.table.n_obs == 30

# new_sdata = sdata.pp.get_elements(keys)

# print(new_sdata.table)

# assert len(new_sdata.table.obs) == nrows