Skip to content

Commit bde6827

Browse files
authored
Fixed bug because of which extent of previous axs was not respected (#101)
1 parent f1026ca commit bde6827

23 files changed

+157
-19
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
import matplotlib.pyplot as plt
10+
import numpy as np
1011
import scanpy as sc
1112
import spatialdata as sd
1213
from anndata import AnnData
@@ -38,6 +39,7 @@
3839
_get_cs_contents,
3940
_get_extent,
4041
_maybe_set_colors,
42+
_mpl_ax_contains_elements,
4143
_multiscale_to_image,
4244
_prepare_cmap_norm,
4345
_prepare_params_plot,
@@ -453,6 +455,7 @@ def show(
453455
fig: Figure | None = None,
454456
title: None | str | Sequence[str] = None,
455457
share_extent: bool = True,
458+
pad_extent: int = 0,
456459
ax: Axes | Sequence[Axes] | None = None,
457460
return_ax: bool = False,
458461
save: None | str | Path = None,
@@ -522,6 +525,14 @@ def show(
522525
# Simplicstic solution: If the images are multiscale, just use the first
523526
sdata = _multiscale_to_image(sdata)
524527

528+
# get original axis extent for later comparison
529+
x_min_orig, x_max_orig = (np.inf, -np.inf)
530+
y_min_orig, y_max_orig = (np.inf, -np.inf)
531+
532+
if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax):
533+
x_min_orig, x_max_orig = ax.get_xlim()
534+
y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left
535+
525536
# handle coordinate system
526537
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
527538
if isinstance(coordinate_systems, str):
@@ -531,12 +542,38 @@ def show(
531542
if cs not in sdata.coordinate_systems:
532543
raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}")
533544

545+
# Check if user specified only certain elements to be plotted
546+
cs_contents = _get_cs_contents(sdata)
547+
elements_to_be_rendered = []
548+
for cmd, params in render_cmds.items():
549+
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
550+
if params.elements is not None:
551+
elements_to_be_rendered += (
552+
[params.elements] if isinstance(params.elements, str) else params.elements
553+
)
554+
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114
555+
if params.elements is not None:
556+
elements_to_be_rendered += (
557+
[params.elements] if isinstance(params.elements, str) else params.elements
558+
)
559+
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114
560+
if params.elements is not None:
561+
elements_to_be_rendered += (
562+
[params.elements] if isinstance(params.elements, str) else params.elements
563+
)
564+
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102
565+
if params.elements is not None:
566+
elements_to_be_rendered += (
567+
[params.elements] if isinstance(params.elements, str) else params.elements
568+
)
569+
534570
extent = _get_extent(
535571
sdata=sdata,
536572
has_images="render_images" in render_cmds,
537573
has_labels="render_labels" in render_cmds,
538574
has_points="render_points" in render_cmds,
539575
has_shapes="render_shapes" in render_cmds,
576+
elements=elements_to_be_rendered,
540577
coordinate_systems=coordinate_systems,
541578
)
542579

@@ -584,7 +621,6 @@ def show(
584621
)
585622

586623
# go through tree
587-
cs_contents = _get_cs_contents(sdata)
588624
for i, cs in enumerate(coordinate_systems):
589625
sdata = self._copy()
590626
# properly transform all elements to the current coordinate system
@@ -692,12 +728,10 @@ def show(
692728
]
693729
):
694730
# If the axis already has limits, only expand them but not overwrite
695-
x_min, x_max = ax.get_xlim()
696-
y_min, y_max = ax.get_ylim()
697-
x_min = min(x_min, extent[cs][0])
698-
x_max = max(x_max, extent[cs][1])
699-
y_min = min(y_min, extent[cs][2])
700-
y_max = max(y_max, extent[cs][3])
731+
x_min = min(x_min_orig, extent[cs][0]) - pad_extent
732+
x_max = max(x_max_orig, extent[cs][1]) + pad_extent
733+
y_min = min(y_min_orig, extent[cs][2]) - pad_extent
734+
y_max = max(y_max_orig, extent[cs][3]) + pad_extent
701735
ax.set_xlim(x_min, x_max)
702736
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
703737

src/spatialdata_plot/pl/utils.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,12 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
175175

176176
def _get_extent(
177177
sdata: sd.SpatialData,
178-
coordinate_systems: None | str | Sequence[str] = None,
178+
coordinate_systems: Sequence[str] | str | None = None,
179179
has_images: bool = True,
180180
has_labels: bool = True,
181181
has_points: bool = True,
182182
has_shapes: bool = True,
183+
elements: Iterable[Any] | None = None,
183184
share_extent: bool = False,
184185
) -> dict[str, tuple[int, int, int, int]]:
185186
"""Return the extent of all elements in their respective coordinate systems.
@@ -188,16 +189,18 @@ def _get_extent(
188189
----------
189190
sdata
190191
The sd.SpatialData object to retrieve the extent from
191-
images
192+
has_images
192193
Flag indicating whether to consider images when calculating the extent
193-
labels
194+
has_labels
194195
Flag indicating whether to consider labels when calculating the extent
195-
points
196+
has_points
196197
Flag indicating whether to consider points when calculating the extent
197-
shapes
198-
Flag indicating whether to consider shaoes when calculating the extent
199-
img_transformations
200-
List of transformations already applied to the images
198+
has_shapes
199+
Flag indicating whether to consider shapes when calculating the extent
200+
elements
201+
Optional list of element names to be considered. When None, all are used.
202+
share_extent
203+
Flag indicating whether to use the same extent for all coordinate systems
201204
202205
Returns
203206
-------
@@ -209,6 +212,12 @@ def _get_extent(
209212
cs_mapping = _get_coordinate_system_mapping(sdata)
210213
cs_contents = _get_cs_contents(sdata)
211214

215+
if elements is None: # to shut up ruff
216+
elements = []
217+
218+
if not isinstance(elements, list):
219+
raise ValueError(f"Invalid type of `elements`: {type(elements)}, expected `list`.")
220+
212221
if coordinate_systems is not None:
213222
if isinstance(coordinate_systems, str):
214223
coordinate_systems = [coordinate_systems]
@@ -217,6 +226,8 @@ def _get_extent(
217226

218227
for cs_name, element_ids in cs_mapping.items():
219228
extent[cs_name] = {}
229+
if len(elements) > 0:
230+
element_ids = [e for e in element_ids if e in elements]
220231

221232
def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[int]:
222233
tmp = element.copy()
@@ -1141,3 +1152,17 @@ def _robust_transform(element: Any, cs: str) -> Any:
11411152
raise ValueError("Unable to transform element.") from e
11421153

11431154
return element
1155+
1156+
1157+
def _mpl_ax_contains_elements(ax: Axes) -> bool:
1158+
"""Check if any objects have been plotted on the axes object.
1159+
1160+
While extracting the extent, we need to know if the axes object has just been
1161+
initialised and therefore has extent (0, 1), (0,1) or if it has been plotted on
1162+
and therefore has a different extent.
1163+
1164+
Based on: https://stackoverflow.com/a/71966295
1165+
"""
1166+
return (
1167+
len(ax.lines) > 0 or len(ax.collections) > 0 or len(ax.images) > 0 or len(ax.patches) > 0 or len(ax.tables) > 0
1168+
)
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
14.1 KB
Loading

tests/_images/Labels_labels.png

-14.2 KB
Binary file not shown.
7.44 KB
Loading
-467 Bytes
Loading
-671 Bytes
Loading
21.4 KB
Loading

tests/pl/test_get_extent.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import matplotlib
2+
import scanpy as sc
3+
import spatialdata_plot # noqa: F401
4+
from spatialdata import SpatialData
5+
6+
from tests.conftest import PlotTester, PlotTesterMeta
7+
8+
sc.pl.set_rcParams_defaults()
9+
sc.set_figure_params(dpi=40, color_map="viridis")
10+
matplotlib.use("agg") # same as GitHub action runner
11+
_ = spatialdata_plot
12+
13+
# WARNING:
14+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
15+
# 2. tests which produce a plot must be prefixed with `test_plot_`
16+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
17+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
18+
# ".png" is appended to <your_filename>, no need to set it
19+
20+
21+
class TestExtent(PlotTester, metaclass=PlotTesterMeta):
22+
def test_plot_extent_of_img_full_canvas(self, sdata_blobs: SpatialData):
23+
sdata_blobs.pl.render_images(elements="blobs_image").pl.show()
24+
25+
def test_plot_extent_of_points_partial_canvas(self, sdata_blobs: SpatialData):
26+
sdata_blobs.pl.render_points().pl.show()
27+
28+
def test_plot_extent_of_partial_canvas_on_full_canvas(self, sdata_blobs: SpatialData):
29+
sdata_blobs.pl.render_images(elements="blobs_image").pl.render_points().pl.show()
30+
31+
def test_plot_extent_calculation_respects_element_selection_circles(self, sdata_blobs: SpatialData):
32+
sdata_blobs.pl.render_shapes(elements="blobs_circles").pl.show()
33+
34+
def test_plot_extent_calculation_respects_element_selection_polygons(self, sdata_blobs: SpatialData):
35+
sdata_blobs.pl.render_shapes(elements="blobs_polygons").pl.show()
36+
37+
def test_plot_extent_calculation_respects_element_selection_circles_and_polygons(self, sdata_blobs: SpatialData):
38+
sdata_blobs.pl.render_shapes(elements=["blobs_circles", "blobs_polygons"]).pl.show()
39+
40+
def test_plot_extent_of_img_is_correct_after_spatial_query(self, sdata_blobs: SpatialData):
41+
cropped_blobs = sdata_blobs.pp.get_elements(["blobs_image"]).query.bounding_box(
42+
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
43+
)
44+
cropped_blobs.pl.render_images().pl.show()
45+
46+
def test_plot_extent_of_polygons_is_correct_after_spatial_query(self, sdata_blobs: SpatialData):
47+
cropped_blobs = sdata_blobs.pp.get_elements(["blobs_polygons"]).query.bounding_box(
48+
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
49+
)
50+
cropped_blobs.pl.render_shapes().pl.show()
51+
52+
def test_plot_extent_of_polygons_on_img_is_correct_after_spatial_query(self, sdata_blobs: SpatialData):
53+
cropped_blobs = sdata_blobs.pp.get_elements(["blobs_image", "blobs_polygons"]).query.bounding_box(
54+
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
55+
)
56+
cropped_blobs.pl.render_images().pl.render_shapes().pl.show()

tests/pl/test_render_labels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@
1919

2020

2121
class TestLabels(PlotTester, metaclass=PlotTesterMeta):
22-
def test_plot_labels(self, sdata_blobs: SpatialData):
22+
def test_plot_can_render_labels(self, sdata_blobs: SpatialData):
2323
sdata_blobs.pl.render_labels(elements="blobs_labels").pl.show()

tests/pl/test_render_points.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@
1919

2020

2121
class TestPoints(PlotTester, metaclass=PlotTesterMeta):
22-
def test_plot_points(self, sdata_blobs: SpatialData):
22+
def test_plot_can_render_points(self, sdata_blobs: SpatialData):
2323
sdata_blobs.pl.render_points(elements="blobs_points").pl.show()

tests/pl/test_render_shapes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
class TestShapes(PlotTester, metaclass=PlotTesterMeta):
2222
def test_plot_can_render_circles(self, sdata_blobs: SpatialData):
23-
sdata_blobs.pl.render_shapes(element="blobs_circles").pl.show()
23+
sdata_blobs.pl.render_shapes(elements="blobs_circles").pl.show()
2424

2525
def test_plot_can_render_polygons(self, sdata_blobs: SpatialData):
26-
sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show()
26+
sdata_blobs.pl.render_shapes(elements="blobs_polygons").pl.show()

tests/pl/test_show.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import matplotlib
2+
import scanpy as sc
3+
import spatialdata_plot # noqa: F401
4+
from spatialdata import SpatialData
5+
6+
from tests.conftest import PlotTester, PlotTesterMeta
7+
8+
sc.pl.set_rcParams_defaults()
9+
sc.set_figure_params(dpi=40, color_map="viridis")
10+
matplotlib.use("agg") # same as GitHub action runner
11+
_ = spatialdata_plot
12+
13+
# WARNING:
14+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
15+
# 2. tests which produce a plot must be prefixed with `test_plot_`
16+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
17+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
18+
# ".png" is appended to <your_filename>, no need to set it
19+
20+
21+
class TestShow(PlotTester, metaclass=PlotTesterMeta):
22+
def test_plot_pad_extent_adds_padding(self, sdata_blobs: SpatialData):
23+
sdata_blobs.pl.render_images(elements="blobs_image").pl.show(pad_extent=100)

0 commit comments

Comments
 (0)