Skip to content

Commit 187f75b

Browse files
Merge pull request #93 from scverse/feature/202306_support_multipolygons
2 parents a77e2f6 + 08d2137 commit 187f75b

12 files changed

+159
-19
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning][].
1212

1313
### Added
1414

15+
- Multipolygons are now handled correctly (#93)
16+
1517
### Fixed
1618

1719
- Legend order is now deterministic (#143)

src/spatialdata_plot/pl/render.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from functools import partial
77
from typing import Any, Callable, Union
88

9+
import geopandas as gpd
910
import matplotlib
1011
import numpy as np
1112
import pandas as pd
@@ -35,6 +36,7 @@
3536
_decorate_axs,
3637
_get_colors_for_categorical_obs,
3738
_get_linear_colormap,
39+
_make_patch_from_multipolygon,
3840
_map_color_seg,
3941
_maybe_set_colors,
4042
_normalize,
@@ -119,18 +121,22 @@ def _get_collection_shape(
119121
outline_alpha: None | float = None,
120122
**kwargs: Any,
121123
) -> PatchCollection:
122-
patches = []
123-
for shape in shapes:
124-
# remove empty points/polygons
125-
shape = shape[shape["geometry"].apply(lambda geom: not geom.is_empty)]
126-
# We assume that all elements in one collection are of the same type
127-
if shape["geometry"].iloc[0].geom_type == "Polygon":
128-
patches += [Polygon(p.exterior.coords, closed=True) for p in shape["geometry"]]
129-
elif shape["geometry"].iloc[0].geom_type == "Point":
130-
patches += [
131-
Circle((circ.x, circ.y), radius=r * s) for circ, r in zip(shape["geometry"], shape["radius"])
132-
]
133-
124+
"""
125+
Get a PatchCollection for rendering given geometries with specified colors and outlines.
126+
127+
Args:
128+
- shapes (list[GeoDataFrame]): List of geometrical shapes.
129+
- c: Color parameter.
130+
- s (float): Size of the shape.
131+
- norm: Normalization for the color map.
132+
- fill_alpha (float, optional): Opacity for the fill color.
133+
- outline_alpha (float, optional): Opacity for the outline.
134+
- **kwargs: Additional keyword arguments.
135+
136+
Returns
137+
-------
138+
- PatchCollection: Collection of patches for rendering.
139+
"""
134140
cmap = kwargs["cmap"]
135141

136142
try:
@@ -149,16 +155,60 @@ def _get_collection_shape(
149155
if render_params.outline_params.outline:
150156
outline_c = ColorConverter().to_rgba_array(render_params.outline_params.outline_color)
151157
outline_c[..., -1] = render_params.outline_alpha
158+
outline_c = outline_c.tolist()
152159
else:
153-
outline_c = None
160+
outline_c = [None]
161+
outline_c = outline_c * fill_c.shape[0]
162+
163+
shapes_df = pd.DataFrame(shapes, copy=True)
164+
165+
# remove empty points/polygons
166+
shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)]
167+
168+
rows = []
169+
170+
def assign_fill_and_outline_to_row(
171+
shapes: list[GeoDataFrame], fill_c: list[Any], outline_c: list[Any], row: pd.Series, idx: int
172+
) -> None:
173+
if len(shapes) > 1 and len(fill_c) == 1:
174+
row["fill_c"] = fill_c
175+
row["outline_c"] = outline_c
176+
else:
177+
row["fill_c"] = fill_c[idx]
178+
row["outline_c"] = outline_c[idx]
179+
180+
# Match colors to the geometry, potentially expanding the row in case of
181+
# multipolygons
182+
for idx, row in shapes_df.iterrows():
183+
geom = row["geometry"]
184+
if geom.geom_type == "Polygon":
185+
row = row.to_dict()
186+
row["geometry"] = Polygon(geom.exterior.coords, closed=True)
187+
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
188+
rows.append(row)
189+
190+
elif geom.geom_type == "MultiPolygon":
191+
mp = _make_patch_from_multipolygon(geom)
192+
for _, m in enumerate(mp):
193+
mp_copy = row.to_dict()
194+
mp_copy["geometry"] = m
195+
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, mp_copy, idx)
196+
rows.append(mp_copy)
197+
198+
elif geom.geom_type == "Point":
199+
row = row.to_dict()
200+
row["geometry"] = Circle((geom.x, geom.y), radius=row["radius"])
201+
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
202+
rows.append(row)
203+
204+
patches = pd.DataFrame(rows)
154205

155206
return PatchCollection(
156-
patches,
207+
patches["geometry"].values.tolist(),
157208
snap=False,
158-
# zorder=4,
159209
lw=render_params.outline_params.linewidth,
160-
facecolor=fill_c,
161-
edgecolor=outline_c,
210+
facecolor=patches["fill_c"],
211+
edgecolor=None if all(outline is None for outline in outline_c) else outline_c,
162212
**kwargs,
163213
)
164214

@@ -167,6 +217,8 @@ def _get_collection_shape(
167217
if len(color_vector) == 0:
168218
color_vector = [render_params.cmap_params.na_color]
169219

220+
shapes = pd.concat(shapes, ignore_index=True)
221+
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
170222
_cax = _get_collection_shape(
171223
shapes=shapes,
172224
s=render_params.size,
@@ -178,6 +230,7 @@ def _get_collection_shape(
178230
outline_alpha=render_params.outline_alpha
179231
# **kwargs,
180232
)
233+
181234
cax = ax.add_collection(_cax)
182235

183236
# Using dict.fromkeys here since set returns in arbitrary order

src/spatialdata_plot/pl/utils.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
from typing import Any, Literal
1111

1212
import matplotlib
13+
import matplotlib.patches as mpatches
14+
import matplotlib.path as mpath
1315
import matplotlib.pyplot as plt
1416
import multiscale_spatial_image as msi
1517
import numpy as np
1618
import pandas as pd
19+
import shapely
1720
import spatial_image
1821
import spatialdata as sd
1922
import xarray as xr
@@ -299,7 +302,7 @@ def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[in
299302
if shapes_key == e_id:
300303

301304
def get_point_bb(
302-
point: Point, radius: int, method: Literal["topleft", "bottomright"], buffer: int = 1
305+
point: Point, radius: int, method: Literal["topleft", "bottomright"], buffer: int = 0
303306
) -> Point:
304307
x, y = point.coords[0]
305308
if method == "topleft":
@@ -346,7 +349,12 @@ def get_point_bb(
346349
del tmp_points
347350
del tmp_polygons
348351

349-
extent[cs_name][e_id] = x_dims + y_dims
352+
xmin = np.min(x_dims)
353+
xmax = np.max(x_dims)
354+
ymin = np.min(y_dims)
355+
ymax = np.max(y_dims)
356+
357+
extent[cs_name][e_id] = [xmin, xmax, ymin, ymax]
350358

351359
transformations = get_transformation(sdata.shapes[e_id], to_coordinate_system=cs_name)
352360
transformations = _flatten_transformation_sequence(transformations)
@@ -1166,6 +1174,51 @@ def _robust_transform(element: Any, cs: str) -> Any:
11661174
return element
11671175

11681176

1177+
def _split_multipolygon_into_outer_and_inner(mp: shapely.MultiPolygon): # type: ignore
1178+
# https://stackoverflow.com/a/21922058
1179+
1180+
for geom in mp.geoms:
1181+
if geom.geom_type == "Polygon":
1182+
exterior_coords = geom.exterior.coords[:]
1183+
interior_coords = []
1184+
for interior in geom.interiors:
1185+
interior_coords += interior.coords[:]
1186+
elif geom.geom_type == "MultiPolygon":
1187+
exterior_coords = []
1188+
interior_coords = []
1189+
for part in geom:
1190+
epc = _split_multipolygon_into_outer_and_inner(part) # Recursive call
1191+
exterior_coords += epc["exterior_coords"]
1192+
interior_coords += epc["interior_coords"]
1193+
else:
1194+
raise ValueError("Unhandled geometry type: " + repr(geom.type))
1195+
1196+
return interior_coords, exterior_coords
1197+
1198+
1199+
def _make_patch_from_multipolygon(mp: shapely.MultiPolygon) -> mpatches.PathPatch:
1200+
# https://matplotlib.org/stable/gallery/shapes_and_collections/donut.html
1201+
1202+
patches = []
1203+
for geom in mp.geoms:
1204+
if len(geom.interiors) == 0:
1205+
# polygon has no holes
1206+
patches += [mpatches.Polygon(geom.exterior.coords, closed=True)]
1207+
else:
1208+
inside, outside = _split_multipolygon_into_outer_and_inner(mp)
1209+
if len(inside) > 0:
1210+
codes = np.ones(len(inside), dtype=mpath.Path.code_type) * mpath.Path.LINETO
1211+
codes[0] = mpath.Path.MOVETO
1212+
all_codes = np.concatenate((codes, codes))
1213+
vertices = np.concatenate((outside, inside[::-1]))
1214+
else:
1215+
all_codes = []
1216+
vertices = np.concatenate(outside)
1217+
patches += [mpatches.PathPatch(mpath.Path(vertices, all_codes))]
1218+
1219+
return patches
1220+
1221+
11691222
def _mpl_ax_contains_elements(ax: Axes) -> bool:
11701223
"""Check if any objects have been plotted on the axes object.
11711224
Loading
Loading
Loading
369 Bytes
Loading
Loading
Loading
Loading
Loading

tests/pl/test_render_shapes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import anndata
12
import geopandas as gpd
23
import matplotlib
4+
import pandas as pd
35
import scanpy as sc
46
import spatialdata_plot # noqa: F401
7+
from shapely.geometry import MultiPolygon, Point, Polygon
58
from spatialdata import SpatialData
9+
from spatialdata.models import ShapesModel, TableModel
610

711
from tests.conftest import PlotTester, PlotTesterMeta
812

@@ -57,3 +61,31 @@ def test_plot_can_render_circles_with_default_outline_width(self, sdata_blobs: S
5761

5862
def test_plot_can_render_circles_with_specified_outline_width(self, sdata_blobs: SpatialData):
5963
sdata_blobs.pl.render_shapes(elements="blobs_circles", outline=True, outline_width=3.0).pl.show()
64+
65+
def test_plot_can_render_multipolygons(self):
66+
def _make_multi():
67+
hole = MultiPolygon(
68+
[(((0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0)), [((0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2))])]
69+
)
70+
overlap = MultiPolygon(
71+
[
72+
Polygon([(2.0, 0.0), (2.0, 0.8), (2.8, 0.8), (2.8, 0.0)]),
73+
Polygon([(2.2, 0.2), (2.2, 1.0), (3.0, 1.0), (3.0, 0.2)]),
74+
]
75+
)
76+
poly = Polygon([(4.0, 0.0), (4.0, 1.0), (5.0, 1.0), (5.0, 0.0)])
77+
circ = Point(6.0, 0.5)
78+
polygon_series = gpd.GeoSeries([hole, overlap, poly, circ])
79+
cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series)
80+
sd_polygons = ShapesModel.parse(cell_polygon_table)
81+
sd_polygons.loc[:, "radius"] = [None, None, None, 0.3]
82+
83+
return sd_polygons
84+
85+
sdata = SpatialData(shapes={"p": _make_multi()})
86+
adata = anndata.AnnData(pd.DataFrame({"p": ["hole", "overlap", "square", "circle"]}))
87+
adata.obs.loc[:, "region"] = "p"
88+
adata.obs.loc[:, "val"] = [1, 2, 3, 4]
89+
table = TableModel.parse(adata, region="p", region_key="region", instance_key="val")
90+
sdata.table = table
91+
sdata.pl.render_shapes(color="val", outline=True, fill_alpha=0.3).pl.show()

0 commit comments

Comments
 (0)