Skip to content

Commit c3e5bee

Browse files
fix translation bug from issue #53
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cb2528a commit c3e5bee

File tree

2 files changed

+105
-37
lines changed

2 files changed

+105
-37
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3+
import sys
34
from collections import OrderedDict
45
from collections.abc import Sequence
56
from pathlib import Path
67
from typing import Any, Optional, Union
78

8-
import numpy as np
9+
import matplotlib.pyplot as plt
910
import scanpy as sc
1011
import spatialdata as sd
1112
from anndata import AnnData
@@ -19,7 +20,6 @@
1920
from spatial_image import SpatialImage
2021
from spatialdata import transform
2122
from spatialdata._logging import logger as logg
22-
from spatialdata.models import Image2DModel
2323
from spatialdata.transformations import get_transformation
2424

2525
from spatialdata_plot._accessor import register_spatial_data_accessor
@@ -45,6 +45,7 @@
4545
_prepare_cmap_norm,
4646
_prepare_params_plot,
4747
_set_outline,
48+
_translate_image,
4849
save_fig,
4950
)
5051
from spatialdata_plot.pp.utils import _verify_plotting_tree
@@ -509,47 +510,28 @@ def show(
509510
# Simplistic solution: If the images are multiscale, just use the first
510511
sdata = _multiscale_to_image(sdata)
511512

513+
img_transformations: dict[str, dict[str, sd.transformations.transformations.BaseTransformation]] = {}
512514
# transform all elements
513515
for cmd, _ in render_cmds.items():
514516
if cmd == "render_images" or cmd == "render_channels":
515517
for key in sdata.images:
516-
img_transformation = get_transformation(sdata.images[key], get_all=True)
517-
img_transformation = list(img_transformation.values())[0]
518+
img_transformations[key] = {}
519+
all_transformations = get_transformation(sdata.images[key], get_all=True)
518520

519-
if isinstance(img_transformation, sd.transformations.transformations.Translation):
520-
shifts: dict[str, int] = {}
521-
for idx, axis in enumerate(img_transformation.axes):
522-
shifts[axis] = int(img_transformation.translation[idx])
521+
for cs, transformation in all_transformations.items():
522+
img_transformations[key][cs] = transformation
523523

524-
img = sdata.images[key].values.copy()
525-
shifted_channels = []
524+
if isinstance(transformation, sd.transformations.transformations.Translation):
525+
sdata.images[key] = _translate_image(image=sdata.images[key], translation=transformation)
526526

527-
# split channels, shift axes individually, then recombine
528-
if len(sdata.images[key].shape) == 3:
529-
for c in range(sdata.images[key].shape[0]):
530-
channel = img[c, :, :]
527+
elif isinstance(transformation, sd.transformations.transformations.Sequence):
528+
# we have more than one transformation, let's find the translation(s)
529+
for t in list(transformation.transformations):
530+
if isinstance(t, sd.transformations.transformations.Translation):
531+
sdata.images[key] = _translate_image(image=sdata.images[key], translation=t)
531532

532-
# iterates over [x, y]
533-
for axis, shift in shifts.items():
534-
pad_x, pad_y = (0, 0), (0, 0)
535-
if axis == "x" and shift > 0:
536-
pad_x = (abs(shift), 0)
537-
elif axis == "x" and shift < 0:
538-
pad_x = (0, abs(shift))
539-
540-
if axis == "y" and shift > 0:
541-
pad_y = (abs(shift), 0)
542-
elif axis == "y" and shift < 0:
543-
pad_y = (0, abs(shift))
544-
545-
channel = np.pad(channel, (pad_y, pad_x), mode="constant")
546-
547-
shifted_channels.append(channel)
548-
549-
sdata.images[key] = Image2DModel.parse(np.array(shifted_channels), dims=["c", "y", "x"])
550-
551-
else:
552-
sdata.images[key] = transform(sdata.images[key], img_transformation)
533+
else:
534+
sdata.images[key] = transform(sdata.images[key], t)
553535

554536
elif cmd == "render_shapes":
555537
for key in sdata.shapes:
@@ -569,6 +551,7 @@ def show(
569551
labels="render_labels" in render_cmds,
570552
points="render_points" in render_cmds,
571553
shapes="render_shapes" in render_cmds,
554+
img_transformations=img_transformations if len(img_transformations) > 0 else None,
572555
)
573556

574557
# handle coordinate system
@@ -699,4 +682,9 @@ def show(
699682
if fig_params.fig is not None and save is not None:
700683
save_fig(fig_params.fig, path=save)
701684

685+
# Manually show plot if we're not in interactive mode
686+
# https://stackoverflow.com/a/64523765
687+
if hasattr(sys, "ps1"):
688+
plt.show()
689+
702690
return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff

src/spatialdata_plot/pl/utils.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import multiscale_spatial_image as msi
1414
import numpy as np
1515
import pandas as pd
16+
import spatial_image
1617
import spatialdata as sd
1718
import xarray as xr
1819
from anndata import AnnData
@@ -180,6 +181,7 @@ def _get_extent(
180181
labels: bool = True,
181182
points: bool = True,
182183
shapes: bool = True,
184+
img_transformations: Optional[dict[str, dict[str, sd.transformations.transformations.BaseTransformation]]] = None,
183185
) -> dict[str, tuple[int, int, int, int]]:
184186
"""Return the extent of the elements contained in the SpatialData object.
185187
@@ -195,6 +197,8 @@ def _get_extent(
195197
Flag indicating whether to consider points when calculating the extent
196198
shapes
197199
Flag indicating whether to consider shapes when calculating the extent
200+
img_transformations
201+
List of transformations already applied to the images
198202
199203
Returns
200204
-------
@@ -218,8 +222,43 @@ def _get_extent(
218222
for element_id in element_ids:
219223
if images_key == element_id:
220224
tmp = sdata.images[element_id]
221-
y_dims += [(0, tmp.shape[1])] # img is cyx, so we skip 0
222-
x_dims += [(0, tmp.shape[2])]
225+
226+
# calculate original image extent
227+
if img_transformations is not None:
228+
shifts: dict[str, float] = {}
229+
shifts["c"] = tmp.shape[0]
230+
shifts["y"] = tmp.shape[1]
231+
shifts["x"] = tmp.shape[2]
232+
233+
if isinstance(
234+
img_transformations[images_key][cs_name], sd.transformations.transformations.Sequence
235+
):
236+
transformations = list(img_transformations[images_key][cs_name].transformations)
237+
238+
else:
239+
transformations = [img_transformations[images_key][cs_name]]
240+
241+
# First reverse all scaling
242+
for transformation in transformations:
243+
if isinstance(transformation, sd.transformations.transformations.Scale):
244+
for idx, ax in enumerate(transformation.axes):
245+
shifts["c"] /= transformation.scale[idx] if ax == "c" else 1
246+
shifts["x"] /= transformation.scale[idx] if ax == "x" else 1
247+
shifts["y"] /= transformation.scale[idx] if ax == "y" else 1
248+
249+
# Then the shift
250+
for transformation in transformations:
251+
if isinstance(transformation, sd.transformations.transformations.Translation):
252+
for idx, ax in enumerate(transformation.axes):
253+
shifts["c"] -= transformation.translation[idx] if ax == "c" else 0
254+
shifts["x"] -= transformation.translation[idx] if ax == "x" else 0
255+
shifts["y"] -= transformation.translation[idx] if ax == "y" else 0
256+
257+
for ax in ["c", "x", "y"]:
258+
shifts[ax] = int(shifts[ax])
259+
260+
y_dims += [(tmp.shape[1] - shifts["y"], tmp.shape[1])] # img is cyx, so we skip 0
261+
x_dims += [(tmp.shape[2] - shifts["x"], tmp.shape[2])]
223262
del tmp
224263

225264
if labels and cs_contents.query(f"cs == '{cs_name}'")["has_labels"][0]:
@@ -929,3 +968,44 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
929968
colors = [color_dict[k] for k in sorted_labels]
930969

931970
return ListedColormap(["black"] + colors, N=len(colors) + 1)
971+
972+
973+
def _translate_image(
974+
image: spatial_image.SpatialImage,
975+
translation: sd.transformations.transformations.Translation,
976+
) -> spatial_image.SpatialImage:
977+
shifts: dict[str, int] = {}
978+
979+
for idx, axis in enumerate(translation.axes):
980+
shifts[axis] = int(translation.translation[idx])
981+
982+
img = image.values.copy()
983+
shifted_channels = []
984+
985+
# split channels, shift axes individually, then recombine
986+
if len(image.shape) == 3:
987+
for c in range(image.shape[0]):
988+
channel = img[c, :, :]
989+
990+
# iterates over [x, y]
991+
for axis, shift in shifts.items():
992+
pad_x, pad_y = (0, 0), (0, 0)
993+
if axis == "x" and shift > 0:
994+
pad_x = (abs(shift), 0)
995+
elif axis == "x" and shift < 0:
996+
pad_x = (0, abs(shift))
997+
998+
if axis == "y" and shift > 0:
999+
pad_y = (abs(shift), 0)
1000+
elif axis == "y" and shift < 0:
1001+
pad_y = (0, abs(shift))
1002+
1003+
channel = np.pad(channel, (pad_y, pad_x), mode="constant")
1004+
1005+
shifted_channels.append(channel)
1006+
1007+
return Image2DModel.parse(
1008+
np.array(shifted_channels),
1009+
dims=["c", "y", "x"],
1010+
transformations=image.attrs["transform"],
1011+
)

0 commit comments

Comments
 (0)