13
13
import multiscale_spatial_image as msi
14
14
import numpy as np
15
15
import pandas as pd
16
+ import spatial_image
16
17
import spatialdata as sd
17
18
import xarray as xr
18
19
from anndata import AnnData
@@ -180,6 +181,7 @@ def _get_extent(
180
181
labels : bool = True ,
181
182
points : bool = True ,
182
183
shapes : bool = True ,
184
+ img_transformations : Optional [dict [str , dict [str , sd .transformations .transformations .BaseTransformation ]]] = None ,
183
185
) -> dict [str , tuple [int , int , int , int ]]:
184
186
"""Return the extent of the elements contained in the SpatialData object.
185
187
@@ -195,6 +197,8 @@ def _get_extent(
195
197
Flag indicating whether to consider points when calculating the extent
196
198
shapes
197
199
Flag indicating whether to consider shaoes when calculating the extent
200
+ img_transformations
201
+ List of transformations already applied to the images
198
202
199
203
Returns
200
204
-------
@@ -218,8 +222,43 @@ def _get_extent(
218
222
for element_id in element_ids :
219
223
if images_key == element_id :
220
224
tmp = sdata .images [element_id ]
221
- y_dims += [(0 , tmp .shape [1 ])] # img is cyx, so we skip 0
222
- x_dims += [(0 , tmp .shape [2 ])]
225
+
226
+ # calculate original image extent
227
+ if img_transformations is not None :
228
+ shifts : dict [str , float ] = {}
229
+ shifts ["c" ] = tmp .shape [0 ]
230
+ shifts ["y" ] = tmp .shape [1 ]
231
+ shifts ["x" ] = tmp .shape [2 ]
232
+
233
+ if isinstance (
234
+ img_transformations [images_key ][cs_name ], sd .transformations .transformations .Sequence
235
+ ):
236
+ transformations = list (img_transformations [images_key ][cs_name ].transformations )
237
+
238
+ else :
239
+ transformations = [img_transformations [images_key ][cs_name ]]
240
+
241
+ # First reverse all scaling
242
+ for transformation in transformations :
243
+ if isinstance (transformation , sd .transformations .transformations .Scale ):
244
+ for idx , ax in enumerate (transformation .axes ):
245
+ shifts ["c" ] /= transformation .scale [idx ] if ax == "c" else 1
246
+ shifts ["x" ] /= transformation .scale [idx ] if ax == "x" else 1
247
+ shifts ["y" ] /= transformation .scale [idx ] if ax == "y" else 1
248
+
249
+ # Then the shift
250
+ for transformation in transformations :
251
+ if isinstance (transformation , sd .transformations .transformations .Translation ):
252
+ for idx , ax in enumerate (transformation .axes ):
253
+ shifts ["c" ] -= transformation .translation [idx ] if ax == "c" else 0
254
+ shifts ["x" ] -= transformation .translation [idx ] if ax == "x" else 0
255
+ shifts ["y" ] -= transformation .translation [idx ] if ax == "y" else 0
256
+
257
+ for ax in ["c" , "x" , "y" ]:
258
+ shifts [ax ] = int (shifts [ax ])
259
+
260
+ y_dims += [(tmp .shape [1 ] - shifts ["y" ], tmp .shape [1 ])] # img is cyx, so we skip 0
261
+ x_dims += [(tmp .shape [2 ] - shifts ["x" ], tmp .shape [2 ])]
223
262
del tmp
224
263
225
264
if labels and cs_contents .query (f"cs == '{ cs_name } '" )["has_labels" ][0 ]:
@@ -929,3 +968,44 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
929
968
colors = [color_dict [k ] for k in sorted_labels ]
930
969
931
970
return ListedColormap (["black" ] + colors , N = len (colors ) + 1 )
971
+
972
+
973
+ def _translate_image (
974
+ image : spatial_image .SpatialImage ,
975
+ translation : sd .transformations .transformations .Translation ,
976
+ ) -> spatial_image .SpatialImage :
977
+ shifts : dict [str , int ] = {}
978
+
979
+ for idx , axis in enumerate (translation .axes ):
980
+ shifts [axis ] = int (translation .translation [idx ])
981
+
982
+ img = image .values .copy ()
983
+ shifted_channels = []
984
+
985
+ # split channels, shift axes individually, them recombine
986
+ if len (image .shape ) == 3 :
987
+ for c in range (image .shape [0 ]):
988
+ channel = img [c , :, :]
989
+
990
+ # iterates over [x, y]
991
+ for axis , shift in shifts .items ():
992
+ pad_x , pad_y = (0 , 0 ), (0 , 0 )
993
+ if axis == "x" and shift > 0 :
994
+ pad_x = (abs (shift ), 0 )
995
+ elif axis == "x" and shift < 0 :
996
+ pad_x = (0 , abs (shift ))
997
+
998
+ if axis == "y" and shift > 0 :
999
+ pad_y = (abs (shift ), 0 )
1000
+ elif axis == "y" and shift < 0 :
1001
+ pad_y = (0 , abs (shift ))
1002
+
1003
+ channel = np .pad (channel , (pad_y , pad_x ), mode = "constant" )
1004
+
1005
+ shifted_channels .append (channel )
1006
+
1007
+ return Image2DModel .parse (
1008
+ np .array (shifted_channels ),
1009
+ dims = ["c" , "y" , "x" ],
1010
+ transformations = image .attrs ["transform" ],
1011
+ )
0 commit comments