Skip to content

Commit c1cc849

Browse files
authored
Refactored transformations implementation (#162)
1 parent 4fc2b04 commit c1cc849

14 files changed

+272
-460
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning][].
1010

1111
## [0.1.0] - tbd
1212

13+
### Added
14+
15+
- Pushed `get_extent` functionality upstream to `spatialdata` (#162)
16+
17+
### Fixed
18+
19+
-
20+
1321
## [0.0.5] - 2023-10-02
1422

1523
### Added

src/spatialdata_plot/pl/basic.py

Lines changed: 93 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
2020
from pandas.api.types import is_categorical_dtype
2121
from spatial_image import SpatialImage
22-
from spatialdata._logging import logger as logg
22+
from spatialdata._core.data_extent import get_extent
23+
from spatialdata.transformations.operations import get_transformation
2324

2425
from spatialdata_plot._accessor import register_spatial_data_accessor
2526
from spatialdata_plot.pl.render import (
@@ -40,12 +41,10 @@
4041
)
4142
from spatialdata_plot.pl.utils import (
4243
_get_cs_contents,
43-
_get_extent,
4444
_maybe_set_colors,
4545
_mpl_ax_contains_elements,
4646
_prepare_cmap_norm,
4747
_prepare_params_plot,
48-
_robust_transform,
4948
_set_outline,
5049
save_fig,
5150
)
@@ -216,6 +215,8 @@ def render_shapes(
216215
na_color=na_color, # type: ignore[arg-type]
217216
**kwargs,
218217
)
218+
if isinstance(elements, str):
219+
elements = [elements]
219220
outline_params = _set_outline(outline, outline_width, outline_color)
220221
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
221222
elements=elements,
@@ -285,12 +286,15 @@ def render_points(
285286
sdata = self._copy()
286287
sdata = _verify_plotting_tree(sdata)
287288
n_steps = len(sdata.plotting_tree.keys())
289+
288290
cmap_params = _prepare_cmap_norm(
289291
cmap=cmap,
290292
norm=norm,
291293
na_color=na_color, # type: ignore[arg-type]
292294
**kwargs,
293295
)
296+
if isinstance(elements, str):
297+
elements = [elements]
294298
sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
295299
elements=elements,
296300
color=color,
@@ -370,6 +374,8 @@ def render_images(
370374
**kwargs,
371375
)
372376

377+
if isinstance(elements, str):
378+
elements = [elements]
373379
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
374380
elements=elements,
375381
channel=channel,
@@ -450,6 +456,8 @@ def render_labels(
450456
na_color=na_color, # type: ignore[arg-type]
451457
**kwargs,
452458
)
459+
if isinstance(elements, str):
460+
elements = [elements]
453461
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
454462
elements=elements,
455463
color=color,
@@ -552,12 +560,12 @@ def show(
552560
raise TypeError("All titles must be strings.")
553561

554562
# get original axis extent for later comparison
555-
x_min_orig, x_max_orig = (np.inf, -np.inf)
556-
y_min_orig, y_max_orig = (np.inf, -np.inf)
563+
ax_x_min, ax_x_max = (np.inf, -np.inf)
564+
ax_y_min, ax_y_max = (np.inf, -np.inf)
557565

558566
if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax):
559-
x_min_orig, x_max_orig = ax.get_xlim()
560-
y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left
567+
ax_x_min, ax_x_max = ax.get_xlim()
568+
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left
561569

562570
# handle coordinate system
563571
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
@@ -568,50 +576,6 @@ def show(
568576
if cs not in sdata.coordinate_systems:
569577
raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}")
570578

571-
# Check if user specified only certain elements to be plotted
572-
cs_contents = _get_cs_contents(sdata)
573-
elements_to_be_rendered = []
574-
for cmd, params in render_cmds.items():
575-
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
576-
if params.elements is not None:
577-
elements_to_be_rendered += (
578-
[params.elements] if isinstance(params.elements, str) else params.elements
579-
)
580-
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114
581-
if params.elements is not None:
582-
elements_to_be_rendered += (
583-
[params.elements] if isinstance(params.elements, str) else params.elements
584-
)
585-
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114
586-
if params.elements is not None:
587-
elements_to_be_rendered += (
588-
[params.elements] if isinstance(params.elements, str) else params.elements
589-
)
590-
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102
591-
if params.elements is not None:
592-
elements_to_be_rendered += (
593-
[params.elements] if isinstance(params.elements, str) else params.elements
594-
)
595-
596-
extent = _get_extent(
597-
sdata=sdata,
598-
has_images="render_images" in render_cmds,
599-
has_labels="render_labels" in render_cmds,
600-
has_points="render_points" in render_cmds,
601-
has_shapes="render_shapes" in render_cmds,
602-
elements=elements_to_be_rendered,
603-
coordinate_systems=coordinate_systems,
604-
)
605-
606-
# Use extent to filter out coordinate system without the relevant elements
607-
valid_cs = []
608-
for cs in coordinate_systems:
609-
if cs in extent:
610-
valid_cs.append(cs)
611-
else:
612-
logg.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
613-
coordinate_systems = valid_cs
614-
615579
# set up canvas
616580
fig_params, scalebar_params = _prepare_params_plot(
617581
num_panels=len(coordinate_systems),
@@ -633,32 +597,25 @@ def show(
633597
colorbar=colorbar,
634598
)
635599

600+
cs_contents = _get_cs_contents(sdata)
601+
636602
# go through tree
603+
637604
for i, cs in enumerate(coordinate_systems):
638605
sdata = self._copy()
639-
# properly transform all elements to the current coordinate system
640-
members = cs_contents.query(f"cs == '{cs}'")
641-
642-
if members["has_images"].values[0]:
643-
for key in sdata.images:
644-
sdata.images[key] = _robust_transform(sdata.images[key], cs)
645-
646-
if members["has_labels"].values[0]:
647-
for key in sdata.labels:
648-
sdata.labels[key] = _robust_transform(sdata.labels[key], cs)
649-
650-
if members["has_points"].values[0]:
651-
for key in sdata.points:
652-
sdata.points[key] = _robust_transform(sdata.points[key], cs)
653-
654-
if members["has_shapes"].values[0]:
655-
for key in sdata.shapes:
656-
sdata.shapes[key] = _robust_transform(sdata.shapes[key], cs)
657-
606+
_, has_images, has_labels, has_points, has_shapes = (
607+
cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist()
608+
)
658609
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
659610

611+
wants_images = False
612+
wants_labels = False
613+
wants_points = False
614+
wants_shapes = False
615+
wanted_elements = []
616+
660617
for cmd, params in render_cmds.items():
661-
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]:
618+
if cmd == "render_images" and has_images:
662619
_render_images(
663620
sdata=sdata,
664621
render_params=params,
@@ -667,9 +624,18 @@ def show(
667624
fig_params=fig_params,
668625
scalebar_params=scalebar_params,
669626
legend_params=legend_params,
670-
# extent=extent[cs],
671627
)
672-
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]:
628+
wants_images = True
629+
wanted_images = params.elements if params.elements is not None else list(sdata.images.keys())
630+
wanted_elements.extend(
631+
[
632+
image
633+
for image in wanted_images
634+
if cs in set(get_transformation(sdata.images[image], get_all=True).keys())
635+
]
636+
)
637+
638+
elif cmd == "render_shapes" and has_shapes:
673639
_render_shapes(
674640
sdata=sdata,
675641
render_params=params,
@@ -679,8 +645,17 @@ def show(
679645
scalebar_params=scalebar_params,
680646
legend_params=legend_params,
681647
)
648+
wants_shapes = True
649+
wanted_shapes = params.elements if params.elements is not None else list(sdata.shapes.keys())
650+
wanted_elements.extend(
651+
[
652+
shape
653+
for shape in wanted_shapes
654+
if cs in set(get_transformation(sdata.shapes[shape], get_all=True).keys())
655+
]
656+
)
682657

683-
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]:
658+
elif cmd == "render_points" and has_points:
684659
_render_points(
685660
sdata=sdata,
686661
render_params=params,
@@ -690,8 +665,17 @@ def show(
690665
scalebar_params=scalebar_params,
691666
legend_params=legend_params,
692667
)
668+
wants_points = True
669+
wanted_points = params.elements if params.elements is not None else list(sdata.points.keys())
670+
wanted_elements.extend(
671+
[
672+
point
673+
for point in wanted_points
674+
if cs in set(get_transformation(sdata.points[point], get_all=True).keys())
675+
]
676+
)
693677

694-
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]:
678+
elif cmd == "render_labels" and has_labels:
695679
if sdata.table is not None and isinstance(params.color, str):
696680
colors = sc.get.obs_df(sdata.table, params.color)
697681
if is_categorical_dtype(colors):
@@ -710,33 +694,46 @@ def show(
710694
scalebar_params=scalebar_params,
711695
legend_params=legend_params,
712696
)
697+
wants_labels = True
698+
wanted_labels = params.elements if params.elements is not None else list(sdata.labels.keys())
699+
wanted_elements.extend(
700+
[
701+
label
702+
for label in wanted_labels
703+
if cs in set(get_transformation(sdata.labels[label], get_all=True).keys())
704+
]
705+
)
713706

714-
if title is not None:
715-
if len(title) == 1:
716-
t = title[0]
717-
else:
718-
try:
719-
t = title[i]
720-
except IndexError as e:
721-
raise IndexError("The number of titles must match the number of coordinate systems.") from e
722-
else:
707+
if title is None:
723708
t = cs
709+
elif len(title) == 1:
710+
t = title[0]
711+
else:
712+
try:
713+
t = title[i]
714+
except IndexError as e:
715+
raise IndexError("The number of titles must match the number of coordinate systems.") from e
724716
ax.set_title(t)
725717
ax.set_aspect("equal")
726718

727-
if any(
728-
[
729-
cs_contents.query(f"cs == '{cs}'")["has_images"][0],
730-
cs_contents.query(f"cs == '{cs}'")["has_labels"][0],
731-
cs_contents.query(f"cs == '{cs}'")["has_points"][0],
732-
cs_contents.query(f"cs == '{cs}'")["has_shapes"][0],
733-
]
734-
):
719+
extent = get_extent(
720+
sdata,
721+
coordinate_system=cs,
722+
has_images=has_images and wants_images,
723+
has_labels=has_labels and wants_labels,
724+
has_points=has_points and wants_points,
725+
has_shapes=has_shapes and wants_shapes,
726+
elements=wanted_elements,
727+
)
728+
cs_x_min, cs_x_max = extent["x"]
729+
cs_y_min, cs_y_max = extent["y"]
730+
731+
if any([has_images, has_labels, has_points, has_shapes]):
735732
# If the axis already has limits, only expand them but not overwrite
736-
x_min = min(x_min_orig, extent[cs][0]) - pad_extent
737-
x_max = max(x_max_orig, extent[cs][1]) + pad_extent
738-
y_min = min(y_min_orig, extent[cs][2]) - pad_extent
739-
y_max = max(y_max_orig, extent[cs][3]) + pad_extent
733+
x_min = min(ax_x_min, cs_x_min) - pad_extent
734+
x_max = max(ax_x_max, cs_x_max) + pad_extent
735+
y_min = min(ax_y_min, cs_y_min) - pad_extent
736+
y_max = max(ax_y_max, cs_y_max) + pad_extent
740737
ax.set_xlim(x_min, x_max)
741738
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
742739

@@ -747,5 +744,4 @@ def show(
747744
# https://stackoverflow.com/a/64523765
748745
if not hasattr(sys, "ps1"):
749746
plt.show()
750-
751747
return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff

0 commit comments

Comments
 (0)