Skip to content

Commit bf44326

Browse files
committed
mvp
1 parent 5e21804 commit bf44326

File tree

2 files changed

+58
-23
lines changed

2 files changed

+58
-23
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_decorate_axs,
3232
_get_colors_for_categorical_obs,
3333
_get_linear_colormap,
34+
_make_patch_from_multipolygon,
3435
_map_color_seg,
3536
_maybe_set_colors,
3637
_normalize,
@@ -110,11 +111,19 @@ def _get_collection_shape(
110111
outline_alpha: Optional[float] = None,
111112
**kwargs: Any,
112113
) -> PatchCollection:
113-
"""Get collection of shapes."""
114-
if shapes["geometry"].iloc[0].geom_type == "Polygon":
115-
patches = [Polygon(p.exterior.coords, closed=True) for p in shapes["geometry"]]
116-
elif shapes["geometry"].iloc[0].geom_type == "Point":
117-
patches = [Circle((circ.x, circ.y), radius=r * s) for circ, r in zip(shapes["geometry"], shapes["radius"])]
114+
polygon_df = shapes[shapes["geometry"].apply(lambda geom: geom.geom_type == "Polygon")]
115+
multipolygon_df = shapes[shapes["geometry"].apply(lambda geom: geom.geom_type == "MultiPolygon")]
116+
circle_df = shapes[shapes["geometry"].apply(lambda geom: geom.geom_type == "Point")]
117+
118+
patches = []
119+
if len(polygon_df) > 0:
120+
patches += [Polygon(p.exterior.coords, closed=True) for p in polygon_df["geometry"]]
121+
if len(circle_df) > 0:
122+
patches += [
123+
Circle((circ.x, circ.y), radius=r * s) for circ, r in zip(circle_df["geometry"], circle_df["radius"])
124+
]
125+
if len(multipolygon_df) > 0:
126+
patches += [_make_patch_from_multipolygon(mp) for mp in multipolygon_df["geometry"]]
118127

119128
cmap = kwargs["cmap"]
120129

@@ -160,6 +169,7 @@ def _get_collection_shape(
160169
outline_alpha=render_params.outline_alpha
161170
# **kwargs,
162171
)
172+
163173
cax = ax.add_collection(_cax)
164174

165175
palette = ListedColormap(set(color_vector)) if render_params.palette is None else render_params.palette

src/spatialdata_plot/pl/utils.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99
from types import MappingProxyType
1010
from typing import Any, Literal, Optional, Union
1111

12+
import matplotlib.patches as mpatches
13+
import matplotlib.path as mpath
1214
import matplotlib.pyplot as plt
1315
import multiscale_spatial_image as msi
1416
import numpy as np
1517
import pandas as pd
18+
import shapely
1619
import spatial_image
1720
import spatialdata as sd
1821
import xarray as xr
@@ -305,7 +308,9 @@ def get_point_bb(
305308
sdata.shapes[e_id]["geometry"].apply(lambda geom: geom.geom_type == "Point")
306309
]
307310
tmp_polygons = sdata.shapes[e_id][
308-
sdata.shapes[e_id]["geometry"].apply(lambda geom: geom.geom_type == "Polygon")
311+
sdata.shapes[e_id]["geometry"].apply(
312+
lambda geom: geom.geom_type in ["Polygon", "MultiPolygon"]
313+
)
309314
]
310315

311316
if not tmp_points.empty:
@@ -321,24 +326,17 @@ def get_point_bb(
321326
xmin_br, ymin_br, xmax_br, ymax_br = tmp_points["point_bottomright"].total_bounds
322327
y_dims += [min(ymin_tl, ymin_br), max(ymax_tl, ymax_br)]
323328
x_dims += [min(xmin_tl, xmin_br), max(xmax_tl, xmax_br)]
324-
y_dims += [min(ymin_tl, ymin_br), max(ymax_tl, ymax_br)]
325-
x_dims += [min(xmin_tl, xmin_br), max(xmax_tl, xmax_br)]
326329

327330
if not tmp_polygons.empty:
328331
xmin, ymin, xmax, ymax = tmp_polygons.total_bounds
329332
y_dims += [ymin, ymax]
330333
x_dims += [xmin, xmax]
331-
y_dims += [ymin, ymax]
332-
x_dims += [xmin, xmax]
333334

334335
del tmp_points
335336
del tmp_polygons
336337

337-
extent[cs_name][e_id] = x_dims + y_dims
338338
extent[cs_name][e_id] = x_dims + y_dims
339339

340-
transformations = get_transformation(sdata.shapes[e_id], to_coordinate_system=cs_name)
341-
transformations = _flatten_transformation_sequence(transformations)
342340
transformations = get_transformation(sdata.shapes[e_id], to_coordinate_system=cs_name)
343341
transformations = _flatten_transformation_sequence(transformations)
344342

@@ -358,16 +356,6 @@ def get_point_bb(
358356

359357
elif isinstance(t, sd.transformations.transformations.Affine):
360358
pass
361-
if has_points and cs_contents.query(f"cs == '{cs_name}'")["has_points"][0]:
362-
for points_key in sdata.points:
363-
for e_id in element_ids:
364-
if points_key == e_id:
365-
tmp = sdata.points[points_key]
366-
xmin = tmp["x"].min().compute()
367-
xmax = tmp["x"].max().compute()
368-
ymin = tmp["y"].min().compute()
369-
ymax = tmp["y"].max().compute()
370-
extent[cs_name][e_id] = [xmin, xmax, ymin, ymax]
371359

372360
if has_points and cs_contents.query(f"cs == '{cs_name}'")["has_points"][0]:
373361
for points_key in sdata.points:
@@ -1137,3 +1125,40 @@ def _robust_transform(element: Any, cs: str) -> Any:
11371125
raise ValueError("Unable to transform element.") from e
11381126

11391127
return element
1128+
1129+
1130+
def _split_multipolygon_into_outer_and_inner(mp: shapely.MultiPolygon): # type: ignore
1131+
# https://stackoverflow.com/a/21922058
1132+
if len(mp.geoms) > 1:
1133+
raise NotImplementedError("Currently, lists of Polygons are not supported. Only Polygons with holes.")
1134+
1135+
geom = mp.geoms[0]
1136+
if geom.type == "Polygon":
1137+
exterior_coords = geom.exterior.coords[:]
1138+
interior_coords = []
1139+
for interior in geom.interiors:
1140+
interior_coords += interior.coords[:]
1141+
elif geom.type == "MultiPolygon":
1142+
exterior_coords = []
1143+
interior_coords = []
1144+
for part in geom:
1145+
epc = _split_multipolygon_into_outer_and_inner(part) # Recursive call
1146+
exterior_coords += epc["exterior_coords"]
1147+
interior_coords += epc["interior_coords"]
1148+
else:
1149+
raise ValueError("Unhandled geometry type: " + repr(geom.type))
1150+
1151+
return interior_coords, exterior_coords
1152+
1153+
1154+
def _make_patch_from_multipolygon(mp: shapely.MultiPolygon) -> mpatches.PathPatch:
1155+
# https://matplotlib.org/stable/gallery/shapes_and_collections/donut.html
1156+
1157+
inside, outside = _split_multipolygon_into_outer_and_inner(mp)
1158+
codes = np.ones(len(inside), dtype=mpath.Path.code_type) * mpath.Path.LINETO
1159+
codes[0] = mpath.Path.MOVETO
1160+
vertices = np.concatenate((outside, inside[::-1]))
1161+
all_codes = np.concatenate((codes, codes))
1162+
path = mpath.Path(vertices, all_codes)
1163+
1164+
return mpatches.PathPatch(path)

0 commit comments

Comments
 (0)