Skip to content

Commit 662b83f

Browse files
authored
render_shapes now plots polygons (#41)
1 parent d9d58ed commit 662b83f

File tree

3 files changed

+69
-37
lines changed

3 files changed

+69
-37
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import OrderedDict
22
from typing import Callable, Optional, Union
33

4+
import geopandas as gpd
45
import matplotlib
56
import matplotlib.pyplot as plt
67
import numpy as np
@@ -637,15 +638,33 @@ def show(
637638
# get biggest image after transformations to set ax size
638639
x_dims = []
639640
y_dims = []
641+
640642
for cmd, _ in render_cmds.items():
641643
if cmd == "render_images":
642644
y_dims += [(0, x.shape[1]) for x in sdata.images.values()]
643645
x_dims += [(0, x.shape[2]) for x in sdata.images.values()]
644646

645647
elif cmd == "render_shapes":
646-
for k in sdata.shapes.keys():
647-
x_dims += [(min(sdata.shapes[k].geometry.x), max(sdata.shapes[k].geometry.x))]
648-
y_dims += [(min(sdata.shapes[k].geometry.y), max(sdata.shapes[k].geometry.y))]
648+
for key in sdata.shapes.keys():
649+
points = []
650+
polygons = []
651+
652+
for _, row in sdata.shapes[key].iterrows():
653+
if row["geometry"].type == "Point":
654+
points.append(row)
655+
else:
656+
polygons.append(row)
657+
658+
if len(points) > 0:
659+
points_df = gpd.GeoDataFrame(data=points)
660+
x_dims += [(min(points_df.geometry.x), max(points_df.geometry.x))]
661+
y_dims += [(min(points_df.geometry.y), max(points_df.geometry.y))]
662+
663+
if len(polygons) > 0:
664+
for p in polygons:
665+
minx, miny, maxx, maxy = p.geometry.bounds
666+
x_dims += [(minx, maxx)]
667+
y_dims += [(miny, maxy)]
649668

650669
elif cmd == "render_labels":
651670
y_dims += [(0, x.shape[0]) for x in sdata.labels.values()]
@@ -767,7 +786,6 @@ def show(
767786
cell_ids_per_label = {}
768787
for key in list(sdata.labels.keys()):
769788
cell_ids_per_label[key] = sdata.labels[key].values.max()
770-
print(cell_ids_per_label)
771789
region_key = "tmp_label_id"
772790
instance_key = "tmp_cell_id"
773791
params["instance_key"] = instance_key
@@ -786,7 +804,6 @@ def show(
786804
distinct_cells = max(list(cell_ids_per_label.values()))
787805

788806
if sdata.table is not None:
789-
# print("Plotting a lot of cells with random colors, might take a while...")
790807
sdata.table.uns[f"{instance_key}_colors"] = _get_random_hex_colors(distinct_cells)
791808

792809
elif sdata.table is None:

src/spatialdata_plot/pl/render.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Iterable
2-
from typing import Callable, Union
2+
from typing import Callable, Optional, Union
33

44
import matplotlib
55
import matplotlib.patches as mpatches
@@ -53,31 +53,53 @@ def _render_channels(
5353

5454
def _render_shapes(
5555
sdata: sd.SpatialData,
56-
params: dict[str, Union[str, int, float, Iterable[str]]],
56+
params: dict[str, Optional[Union[str, int, float, Iterable[str]]]],
5757
key: str,
5858
ax: matplotlib.axes.SubplotBase,
5959
extent: dict[str, list[int]],
6060
) -> None:
61+
colors: Optional[Union[str, int, float, Iterable[str]]] = None # to shut up mypy
6162
if sdata.table is not None and isinstance(params["instance_key"], str) and isinstance(params["color_key"], str):
6263
colors = [to_rgb(c) for c in sdata.table.uns[f"{params['color_key']}_colors"]]
6364
elif isinstance(params["palette"], str):
6465
colors = [params["palette"]]
6566
elif isinstance(params["palette"], Iterable):
6667
colors = [to_rgb(c) for c in list(params["palette"])]
6768
else:
68-
colors = [params["palette"]]
69+
colors = params["palette"]
6970

7071
ax.set_xlim(extent["x"][0], extent["x"][1])
7172
ax.set_ylim(extent["y"][0], extent["y"][1])
7273

73-
shape = sdata.shapes[key]
74+
points = []
75+
polygons = []
7476

75-
ax.scatter(
76-
x=shape.geometry.x,
77-
y=shape.geometry.y,
78-
s=shape.radius,
79-
color=colors,
80-
)
77+
for _, row in sdata.shapes[key].iterrows():
78+
if row["geometry"].geom_type == "Point":
79+
points.append((row[0], row[1])) # (point, radius)
80+
elif row["geometry"].geom_type == "Polygon":
81+
polygons.append(row[0]) # just polygon
82+
else:
83+
raise NotImplementedError(f"Geometry type {row['geometry'].type} not supported.")
84+
85+
if len(polygons) > 0:
86+
for polygon in polygons:
87+
ax.add_patch(
88+
mpatches.Polygon(
89+
polygon.exterior.coords,
90+
color=colors,
91+
)
92+
)
93+
94+
if len(points) > 0:
95+
for point, radius in points:
96+
ax.add_patch(
97+
mpatches.Circle(
98+
(point.x, point.y),
99+
radius=radius,
100+
color=colors,
101+
)
102+
)
81103

82104
ax.set_title(key)
83105

@@ -146,11 +168,6 @@ def _render_images(
146168
elif n_channels == 2:
147169
colors = ListedColormap(["#d30cb8", "#6df1d8"])
148170
elif n_channels == 3:
149-
# bg = [(1, 1, 1, 1)]
150-
# cmap_red = ListedColormap([(1, 0, 0, i) for i in reversed(range(0, 256, 1))] + bg)
151-
# cmap_green = ListedColormap([(0, 1, 0, i) for i in reversed(range(0, 256, 1))] + bg)
152-
# cmap_blue = ListedColormap([(0, 0, 1, i) for i in reversed(range(0, 256, 1))] + bg)
153-
# colors = [cmap_red, cmap_green, cmap_blue]
154171
colors = ListedColormap(["red", "blue", "green"])
155172
else:
156173
# we do PCA to reduce to 3 channels
@@ -171,10 +188,6 @@ def _render_images(
171188
)
172189

173190
ax.set_title(key)
174-
# ax.set_xlabel("spatial1")
175-
# ax.set_ylabel("spatial2")
176-
# ax.set_xticks([])
177-
# ax.set_yticks([])
178191

179192

180193
def _render_labels(
@@ -249,7 +262,3 @@ def _render_labels(
249262
ax.legend(handles=patches, bbox_to_anchor=(0.9, 0.9), loc="upper left", frameon=False)
250263

251264
ax.set_title(key)
252-
# ax.set_xlabel("spatial1")
253-
# ax.set_ylabel("spatial2")
254-
# ax.set_xticks([])
255-
# ax.set_yticks([])

test_sandbox_data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
# - You can use `jupyter nbconvert --to python test_sandbox_data.ipynb` to convert it back
77
# - Otherwise run it from the CLI and verify that the plots are okay
88

9-
# In[4]:
9+
# In[1]:
1010

1111

1212
import matplotlib.pyplot as plt
1313
import spatialdata as sd
14+
from spatialdata.datasets import blobs
1415

1516
import spatialdata_plot
1617

@@ -19,26 +20,31 @@
1920
DATA_DIR = "/Users/tim.treis/Documents/GitHub/spatialdata-sandbox/"
2021

2122

22-
# ## Load data
23-
# Adjust paths as neccecary
23+
# In[2]:
24+
2425

25-
# In[15]:
26+
(blobs().pl.render_images().pl.render_labels().pl.render_shapes().pl.render_points(color_key="genes").pl.show())
27+
28+
29+
# In[4]:
2630

2731

2832
# Mibi
2933

30-
sdata = sd.read_zarr(DATA_DIR + "mibitof/data.zarr")
31-
sdata.pl.render_images().pl.render_labels().pl.show()
34+
mibitof = sd.read_zarr(DATA_DIR + "mibitof/data.zarr")
35+
36+
(mibitof.pl.render_images().pl.render_labels().pl.show())
3237

3338
plt.savefig("mibi.png")
3439

3540

36-
# In[16]:
41+
# In[5]:
3742

3843

3944
# Visium
4045

41-
sdata = sd.read_zarr(DATA_DIR + "visium/data.zarr")
42-
sdata.pl.render_images().pl.render_shapes().pl.show(width=12, height=12)
46+
visium = sd.read_zarr(DATA_DIR + "visium/data.zarr")
47+
48+
(visium.pl.render_images().pl.render_shapes().pl.show())
4349

4450
plt.savefig("visium.png")

0 commit comments

Comments
 (0)