Skip to content

Commit 79c9e36

Browse files
committed
fixed
1 parent c89d4e2 commit 79c9e36

File tree

3 files changed

+88
-42
lines changed

3 files changed

+88
-42
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from __future__ import annotations
22

3-
import contextlib
43
from collections.abc import Sequence
54
from copy import copy
65
from dataclasses import dataclass
76
from functools import partial
8-
from itertools import chain
97
from typing import Any, Callable
108

119
import geopandas as gpd
@@ -123,34 +121,22 @@ def _get_collection_shape(
123121
outline_alpha: None | float = None,
124122
**kwargs: Any,
125123
) -> PatchCollection:
126-
print(shapes)
127-
patches = []
128-
# remove empty points/polygons
129-
shapes = shapes[shapes["geometry"].apply(lambda geom: not geom.is_empty)]
130-
131-
polygon_df = shapes[
132-
shapes["geometry"].apply(lambda geom: geom.geom_type == "Polygon") # type: ignore[call-overload]
133-
]
134-
multipolygon_df = shapes[
135-
shapes["geometry"].apply(lambda geom: geom.geom_type == "MultiPolygon") # type: ignore[call-overload]
136-
]
137-
circle_df = shapes[
138-
shapes["geometry"].apply(lambda geom: geom.geom_type == "Point") # type: ignore[call-overload]
139-
]
140-
141-
if len(polygon_df) > 0:
142-
patches += [Polygon(p.exterior.coords, closed=True) for p in polygon_df["geometry"]]
143-
if len(circle_df) > 0:
144-
patches += [
145-
Circle((circ.x, circ.y), radius=r * s) for circ, r in zip(circle_df["geometry"], circle_df["radius"])
146-
]
147-
if len(multipolygon_df) > 0:
148-
patches += [_make_patch_from_multipolygon(mp) for mp in multipolygon_df["geometry"]]
149-
150-
# flatten list since multipolygons cause a nested list
151-
with contextlib.suppress(Exception):
152-
patches = list(chain.from_iterable([x] if not isinstance(x, list) else x for x in patches))
153-
124+
"""
125+
Get a PatchCollection for rendering given geometries with specified colors and outlines.
126+
127+
Args:
128+
- shapes (list[GeoDataFrame]): List of geometrical shapes.
129+
- c: Color parameter.
130+
- s (float): Size of the shape.
131+
- norm: Normalization for the color map.
132+
- fill_alpha (float, optional): Opacity for the fill color.
133+
- outline_alpha (float, optional): Opacity for the outline.
134+
- **kwargs: Additional keyword arguments.
135+
136+
Returns
137+
-------
138+
- PatchCollection: Collection of patches for rendering.
139+
"""
154140
cmap = kwargs["cmap"]
155141

156142
try:
@@ -169,16 +155,60 @@ def _get_collection_shape(
169155
if render_params.outline_params.outline:
170156
outline_c = ColorConverter().to_rgba_array(render_params.outline_params.outline_color)
171157
outline_c[..., -1] = render_params.outline_alpha
158+
outline_c = outline_c.tolist()
172159
else:
173-
outline_c = None
160+
outline_c = [None]
161+
outline_c = outline_c * fill_c.shape[0]
162+
163+
shapes_df = pd.DataFrame(shapes, copy=True)
164+
165+
# remove empty points/polygons
166+
shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)]
167+
168+
rows = []
169+
170+
def assign_fill_and_outline_to_row(
171+
shapes: list[GeoDataFrame], fill_c: list[Any], outline_c: list[Any], row: pd.Series, idx: int
172+
) -> None:
173+
if len(shapes) > 1 and len(fill_c) == 1:
174+
row["fill_c"] = fill_c
175+
row["outline_c"] = outline_c
176+
else:
177+
row["fill_c"] = fill_c[idx]
178+
row["outline_c"] = outline_c[idx]
179+
180+
# Match colors to the geometry, potentially expanding the row in case of
181+
# multipolygons
182+
for idx, row in shapes_df.iterrows():
183+
geom = row["geometry"]
184+
if geom.geom_type == "Polygon":
185+
row = row.to_dict()
186+
row["geometry"] = Polygon(geom.exterior.coords, closed=True)
187+
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
188+
rows.append(row)
189+
190+
elif geom.geom_type == "MultiPolygon":
191+
mp = _make_patch_from_multipolygon(geom)
192+
for _, m in enumerate(mp):
193+
mp_copy = row.to_dict()
194+
mp_copy["geometry"] = m
195+
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, mp_copy, idx)
196+
rows.append(mp_copy)
197+
198+
elif geom.geom_type == "Point":
199+
row = row.to_dict()
200+
row["geometry"] = Circle((geom.x, geom.y), radius=row["radius"])
201+
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
202+
rows.append(row)
203+
204+
patches = pd.DataFrame(rows)
174205

175206
return PatchCollection(
176-
patches,
207+
patches["geometry"].values.tolist(),
177208
snap=False,
178-
# zorder=4,
179209
lw=render_params.outline_params.linewidth,
180-
facecolor=fill_c,
181-
edgecolor=outline_c,
210+
facecolor=patches["fill_c"],
211+
edgecolor=None if all(outline is None for outline in outline_c) else outline_c,
182212
**kwargs,
183213
)
184214

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[in
302302
if shapes_key == e_id:
303303

304304
def get_point_bb(
305-
point: Point, radius: int, method: Literal["topleft", "bottomright"], buffer: int = 1
305+
point: Point, radius: int, method: Literal["topleft", "bottomright"], buffer: int = 0
306306
) -> Point:
307307
x, y = point.coords[0]
308308
if method == "topleft":
@@ -349,7 +349,12 @@ def get_point_bb(
349349
del tmp_points
350350
del tmp_polygons
351351

352-
extent[cs_name][e_id] = x_dims + y_dims
352+
xmin = np.min(x_dims)
353+
xmax = np.max(x_dims)
354+
ymin = np.min(y_dims)
355+
ymax = np.max(y_dims)
356+
357+
extent[cs_name][e_id] = [xmin, xmax, ymin, ymax]
353358

354359
transformations = get_transformation(sdata.shapes[e_id], to_coordinate_system=cs_name)
355360
transformations = _flatten_transformation_sequence(transformations)

tests/pl/test_render_shapes.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import anndata
12
import geopandas as gpd
23
import matplotlib
4+
import pandas as pd
35
import scanpy as sc
46
import spatialdata_plot # noqa: F401
5-
from shapely.geometry import MultiPolygon, Polygon
7+
from shapely.geometry import MultiPolygon, Point, Polygon
68
from spatialdata import SpatialData
7-
from spatialdata.models import ShapesModel
9+
from spatialdata.models import ShapesModel, TableModel
810

911
from tests.conftest import PlotTester, PlotTesterMeta
1012

@@ -72,9 +74,18 @@ def _make_multi():
7274
]
7375
)
7476
poly = Polygon([(4.0, 0.0), (4.0, 1.0), (5.0, 1.0), (5.0, 0.0)])
75-
polygon_series = gpd.GeoSeries([hole, overlap, poly])
77+
circ = Point(6.0, 0.5)
78+
polygon_series = gpd.GeoSeries([hole, overlap, poly, circ])
7679
cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series)
77-
return ShapesModel.parse(cell_polygon_table)
80+
sd_polygons = ShapesModel.parse(cell_polygon_table)
81+
sd_polygons.loc[:, "radius"] = [None, None, None, 0.3]
82+
83+
return sd_polygons
7884

7985
sdata = SpatialData(shapes={"p": _make_multi()})
80-
sdata.pl.render_shapes(outline=True, fill_alpha=0.3).pl.show()
86+
adata = anndata.AnnData(pd.DataFrame({"p": ["hole", "overlap", "square", "circle"]}))
87+
adata.obs.loc[:, "region"] = "p"
88+
adata.obs.loc[:, "val"] = [1, 2, 3, 4]
89+
table = TableModel.parse(adata, region="p", region_key="region", instance_key="val")
90+
sdata.table = table
91+
sdata.pl.render_shapes(col="val", outline=True, fill_alpha=0.3).pl.show()

0 commit comments

Comments
 (0)