Skip to content

Commit 18c2f92

Browse files
giovptimtreispre-commit-ci[bot]
authored
Basic structure for comparing generated plots in unit tests
Co-authored-by: Tim Treis <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6be6ca3 commit 18c2f92

File tree

10 files changed

+107
-12
lines changed

10 files changed

+107
-12
lines changed

.github/workflows/test.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ jobs:
5555
DISPLAY: :42
5656
run: |
5757
pytest -v --cov --color=yes --cov-report=xml
58+
# - name: Generate GH action "groundtruth" figures as artifacts, uncomment if needed
59+
# if: always()
60+
# uses: actions/upload-artifact@v3
61+
# with:
62+
# name: groundtruth-figures
63+
# path: /home/runner/work/spatialdata-plot/spatialdata-plot/tests/_images/*
64+
- name: Archive figures generated during testing
65+
if: always()
66+
uses: actions/upload-artifact@v3
67+
with:
68+
name: plotting-results
69+
path: /home/runner/work/spatialdata-plot/spatialdata-plot/tests/figures/*
5870
- name: Upload coverage to Codecov
5971
uses: codecov/[email protected]
6072
with:

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,9 @@ __pycache__/
2929

3030
format.sh
3131

32+
33+
# test
34+
tests/figures/
35+
3236
# jupyter checkpoints
3337
.ipynb_checkpoints

tests/__init__.py

Whitespace-only changes.

tests/_images/Images_images.png

36.5 KB
Loading

tests/_images/Labels_labels.png

14.7 KB
Loading

tests/conftest.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1-
from typing import Optional, Union
1+
from abc import ABC, ABCMeta
2+
from functools import wraps
3+
from pathlib import Path
4+
from typing import Callable, Optional, Union
25

6+
import matplotlib.pyplot as plt
37
import numpy as np
48
import pandas as pd
59
import pyarrow as pa
610
import pytest
711
import spatialdata as sd
812
from anndata import AnnData
913
from geopandas import GeoDataFrame
14+
from matplotlib.testing.compare import compare_images
1015
from multiscale_spatial_image import MultiscaleSpatialImage
1116
from numpy.random import default_rng
1217
from shapely.geometry import MultiPolygon, Polygon
1318
from spatial_image import SpatialImage
1419
from spatialdata import SpatialData
20+
from spatialdata.datasets import blobs
1521
from spatialdata.models import (
1622
Image2DModel,
1723
Image3DModel,
@@ -25,6 +31,13 @@
2531

2632
import spatialdata_plot # noqa: F401
2733

34+
HERE: Path = Path(__file__).parent
35+
36+
EXPECTED = HERE / "_images"
37+
ACTUAL = HERE / "figures"
38+
TOL = 60
39+
DPI = 40
40+
2841
RNG = default_rng()
2942

3043

@@ -39,6 +52,11 @@ def full_sdata() -> SpatialData:
3952
)
4053

4154

55+
@pytest.fixture()
56+
def sdata_blobs() -> SpatialData:
57+
return blobs()
58+
59+
4260
@pytest.fixture
4361
def test_sdata_single_image():
4462
"""Creates a simple sdata object."""
@@ -152,17 +170,6 @@ def table_multiple_annotations() -> SpatialData:
152170
return SpatialData(table=_get_table(region=["sample1", "sample2"]))
153171

154172

155-
# @pytest.fixture()
156-
# def empty_points() -> SpatialData:
157-
# geo_df = GeoDataFrame(
158-
# geometry=[],
159-
# )
160-
# from spatialdata.transformations import Identity
161-
# set_transform(geo_df, Identity())
162-
#
163-
# return SpatialData(points={"empty": geo_df})
164-
165-
166173
@pytest.fixture()
167174
def empty_table() -> SpatialData:
168175
adata = AnnData(shape=(0, 0))
@@ -337,3 +344,48 @@ def _get_table(
337344
return TableModel.parse(adata=adata, region=region, region_key=region_key, instance_key=instance_key)
338345
else:
339346
return TableModel.parse(adata=adata, region=region, region_key=region_key, instance_key=instance_key)
347+
348+
349+
class PlotTesterMeta(ABCMeta):
350+
def __new__(cls, clsname, superclasses, attributedict):
351+
for key, value in attributedict.items():
352+
if callable(value):
353+
attributedict[key] = _decorate(value, clsname, name=key)
354+
return super().__new__(cls, clsname, superclasses, attributedict)
355+
356+
357+
class PlotTester(ABC): # noqa: B024
358+
@classmethod
359+
def compare(cls, basename: str, tolerance: Optional[float] = None):
360+
ACTUAL.mkdir(parents=True, exist_ok=True)
361+
out_path = ACTUAL / f"{basename}.png"
362+
363+
plt.savefig(out_path, dpi=DPI)
364+
plt.close()
365+
366+
if tolerance is None:
367+
# see https://github.com/scverse/squidpy/pull/302
368+
tolerance = 2 * TOL if "Napari" in str(basename) else TOL
369+
370+
res = compare_images(str(EXPECTED / f"{basename}.png"), str(out_path), tolerance)
371+
372+
assert res is None, res
373+
374+
375+
def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable:
376+
@wraps(fn)
377+
def save_and_compare(self, *args, **kwargs):
378+
fn(self, *args, **kwargs)
379+
self.compare(fig_name)
380+
381+
if not callable(fn):
382+
raise TypeError(f"Expected a `callable` for class `{clsname}`, found `{type(fn).__name__}`.")
383+
384+
name = fn.__name__ if name is None else name
385+
386+
if not name.startswith("test_plot_") or not clsname.startswith("Test"):
387+
return fn
388+
389+
fig_name = f"{clsname[4:]}_{name[10:]}"
390+
391+
return save_and_compare

tests/figures/Labels_images.png

-19.8 KB
Binary file not shown.

tests/figures/Labels_labels.png

-5.21 KB
Binary file not shown.

tests/pl/__init__.py

Whitespace-only changes.

tests/pl/test_plot.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import matplotlib
2+
import scanpy as sc
3+
from spatialdata import SpatialData
4+
5+
import spatialdata_plot # noqa: F401
6+
from tests.conftest import PlotTester, PlotTesterMeta
7+
8+
sc.pl.set_rcParams_defaults()
9+
sc.set_figure_params(dpi=40, color_map="viridis")
10+
matplotlib.use("agg") # same as GitHub action runner
11+
12+
# WARNING:
13+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
14+
# 2. tests which produce a plot must be prefixed with `test_plot_`
15+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
16+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
17+
# ".png" is appended to <your_filename>, no need to set it
18+
19+
20+
class TestLabels(PlotTester, metaclass=PlotTesterMeta):
21+
def test_plot_labels(self, sdata_blobs: SpatialData):
22+
sdata_blobs.pl.render_labels(color="channel_2_mean").pl.show()
23+
24+
25+
class TestImages(PlotTester, metaclass=PlotTesterMeta):
26+
def test_plot_images(self, sdata_blobs: SpatialData):
27+
sdata_blobs.pl.render_images().pl.show()

0 commit comments

Comments
 (0)