1
- from typing import Optional , Union
1
+ from abc import ABC , ABCMeta
2
+ from functools import wraps
3
+ from pathlib import Path
4
+ from typing import Callable , Optional , Union
2
5
6
+ import matplotlib .pyplot as plt
3
7
import numpy as np
4
8
import pandas as pd
5
9
import pyarrow as pa
6
10
import pytest
7
11
import spatialdata as sd
8
12
from anndata import AnnData
9
13
from geopandas import GeoDataFrame
14
+ from matplotlib .testing .compare import compare_images
10
15
from multiscale_spatial_image import MultiscaleSpatialImage
11
16
from numpy .random import default_rng
12
17
from shapely .geometry import MultiPolygon , Polygon
13
18
from spatial_image import SpatialImage
14
19
from spatialdata import SpatialData
20
+ from spatialdata .datasets import blobs
15
21
from spatialdata .models import (
16
22
Image2DModel ,
17
23
Image3DModel ,
25
31
26
32
import spatialdata_plot # noqa: F401
27
33
34
+ HERE : Path = Path (__file__ ).parent
35
+
36
+ EXPECTED = HERE / "_images"
37
+ ACTUAL = HERE / "figures"
38
+ TOL = 60
39
+ DPI = 40
40
+
28
41
RNG = default_rng ()
29
42
30
43
@@ -39,6 +52,11 @@ def full_sdata() -> SpatialData:
39
52
)
40
53
41
54
55
+ @pytest .fixture ()
56
+ def sdata_blobs () -> SpatialData :
57
+ return blobs ()
58
+
59
+
42
60
@pytest .fixture
43
61
def test_sdata_single_image ():
44
62
"""Creates a simple sdata object."""
@@ -152,17 +170,6 @@ def table_multiple_annotations() -> SpatialData:
152
170
return SpatialData (table = _get_table (region = ["sample1" , "sample2" ]))
153
171
154
172
155
- # @pytest.fixture()
156
- # def empty_points() -> SpatialData:
157
- # geo_df = GeoDataFrame(
158
- # geometry=[],
159
- # )
160
- # from spatialdata.transformations import Identity
161
- # set_transform(geo_df, Identity())
162
- #
163
- # return SpatialData(points={"empty": geo_df})
164
-
165
-
166
173
@pytest .fixture ()
167
174
def empty_table () -> SpatialData :
168
175
adata = AnnData (shape = (0 , 0 ))
@@ -337,3 +344,48 @@ def _get_table(
337
344
return TableModel .parse (adata = adata , region = region , region_key = region_key , instance_key = instance_key )
338
345
else :
339
346
return TableModel .parse (adata = adata , region = region , region_key = region_key , instance_key = instance_key )
347
+
348
+
349
+ class PlotTesterMeta (ABCMeta ):
350
+ def __new__ (cls , clsname , superclasses , attributedict ):
351
+ for key , value in attributedict .items ():
352
+ if callable (value ):
353
+ attributedict [key ] = _decorate (value , clsname , name = key )
354
+ return super ().__new__ (cls , clsname , superclasses , attributedict )
355
+
356
+
357
+ class PlotTester (ABC ): # noqa: B024
358
+ @classmethod
359
+ def compare (cls , basename : str , tolerance : Optional [float ] = None ):
360
+ ACTUAL .mkdir (parents = True , exist_ok = True )
361
+ out_path = ACTUAL / f"{ basename } .png"
362
+
363
+ plt .savefig (out_path , dpi = DPI )
364
+ plt .close ()
365
+
366
+ if tolerance is None :
367
+ # see https://github.com/scverse/squidpy/pull/302
368
+ tolerance = 2 * TOL if "Napari" in str (basename ) else TOL
369
+
370
+ res = compare_images (str (EXPECTED / f"{ basename } .png" ), str (out_path ), tolerance )
371
+
372
+ assert res is None , res
373
+
374
+
375
+ def _decorate (fn : Callable , clsname : str , name : Optional [str ] = None ) -> Callable :
376
+ @wraps (fn )
377
+ def save_and_compare (self , * args , ** kwargs ):
378
+ fn (self , * args , ** kwargs )
379
+ self .compare (fig_name )
380
+
381
+ if not callable (fn ):
382
+ raise TypeError (f"Expected a `callable` for class `{ clsname } `, found `{ type (fn ).__name__ } `." )
383
+
384
+ name = fn .__name__ if name is None else name
385
+
386
+ if not name .startswith ("test_plot_" ) or not clsname .startswith ("Test" ):
387
+ return fn
388
+
389
+ fig_name = f"{ clsname [4 :]} _{ name [10 :]} "
390
+
391
+ return save_and_compare
0 commit comments