Skip to content

Commit 95283b2

Browse files
committed
fix
1 parent e0981f5 commit 95283b2

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from collections import OrderedDict
4-
from typing import Any, Optional, Union
54
from collections.abc import Sequence
5+
from typing import Any, Optional, Union
66

77
import geopandas as gpd
88
import numpy as np
@@ -41,9 +41,7 @@
4141
_prepare_params_plot,
4242
_set_outline,
4343
)
44-
from spatialdata_plot.pp.utils import (
45-
_verify_plotting_tree,
46-
)
44+
from spatialdata_plot.pp.utils import _verify_plotting_tree
4745

4846

4947
@register_spatial_data_accessor("pl")
@@ -200,7 +198,7 @@ def render_shapes(
200198
def render_points(
201199
self,
202200
palette: Optional[Union[str, list[str], None]] = None,
203-
color_key: Optional[str] = None,
201+
color: Optional[str] = None,
204202
**scatter_kwargs: Optional[str],
205203
) -> sd.SpatialData:
206204
"""Render the points contained in the given sd.SpatialData object
@@ -214,7 +212,7 @@ def render_points(
214212
default colors will be used.
215213
instance_key : str
216214
The name of the column in the table that identifies individual shapes
217-
color_key : str or None, optional (default: None)
215+
color : str or None, optional (default: None)
218216
The name of the column in the table to use for coloring shapes.
219217
220218
Returns
@@ -233,15 +231,15 @@ def render_points(
233231
else:
234232
raise TypeError("The palette argument must be a list of strings or a single string.")
235233

236-
if color_key is not None and not isinstance(color_key, str):
234+
if color is not None and not isinstance(color, str):
237235
raise TypeError("When giving a 'color_key', it must be of type 'str'.")
238236

239237
sdata = self._copy()
240238
sdata = _verify_plotting_tree(sdata)
241239
n_steps = len(sdata.plotting_tree.keys())
242240
sdata.plotting_tree[f"{n_steps+1}_render_points"] = {
243241
"palette": palette,
244-
"color_key": color_key,
242+
"color": color,
245243
}
246244

247245
return sdata
@@ -367,6 +365,7 @@ def render_labels(
367365

368366
def show(
369367
self,
368+
coordinate_system: str | Sequence[str] | None = None,
370369
legend_fontsize: int | float | _FontSize | None = None,
371370
legend_fontweight: int | _FontWeight = "bold",
372371
legend_loc: str | None = "right margin",
@@ -406,6 +405,9 @@ def show(
406405
plotting_tree = self._sdata.plotting_tree
407406
sdata = self._copy()
408407

408+
if isinstance(coordinate_system, str):
409+
coordinate_system = [coordinate_system]
410+
409411
# Evaluate execution tree for plotting
410412
valid_commands = [
411413
"get_elements",
@@ -436,7 +438,9 @@ def show(
436438

437439
# set up canvas
438440
fig_params, scalebar_params = _prepare_params_plot(
439-
num_panels=1, # len(render_cmds),
441+
num_panels=len(sdata.coordinate_systems)
442+
if coordinate_system is None
443+
else len(coordinate_system), # len(render_cmds),
440444
figsize=figsize,
441445
dpi=dpi,
442446
fig=fig,
@@ -528,7 +532,7 @@ def show(
528532
for key in sdata.shapes.keys():
529533
points = []
530534
polygons = []
531-
535+
# TODO: improve getting extent of polygons
532536
for _, row in sdata.shapes[key].iterrows():
533537
if row["geometry"].geom_type == "Point":
534538
points.append(row)

src/spatialdata_plot/pl/render.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212
from geopandas import GeoDataFrame
1313
from matplotlib import colors
1414
from matplotlib.collections import PatchCollection
15-
from matplotlib.colors import (
16-
ColorConverter,
17-
ListedColormap,
18-
Normalize,
19-
)
15+
from matplotlib.colors import ColorConverter, ListedColormap, Normalize
2016
from matplotlib.patches import Circle, Polygon
2117
from pandas.api.types import is_categorical_dtype
2218
from scanpy._settings import settings as sc_settings
@@ -187,17 +183,17 @@ def _render_points(
187183
ax.set_xlim(extent["x"][0], extent["x"][1])
188184
ax.set_ylim(extent["y"][0], extent["y"][1])
189185

190-
if isinstance(params["color_key"], str):
191-
colors = sdata.points[key][params["color_key"]].compute()
186+
if isinstance(params["color"], str):
187+
colors = sdata.points[key][params["color"]].compute()
192188

193189
if is_categorical_dtype(colors):
194190
category_colors = _get_palette(categories=colors.cat.categories)
195191

196-
for i, cat in enumerate(colors.cat.categories):
192+
for cat in colors.cat.categories:
197193
ax.scatter(
198194
x=sdata.points[key]["x"].compute()[colors == cat],
199195
y=sdata.points[key]["y"].compute()[colors == cat],
200-
color=category_colors[i],
196+
color=category_colors[cat],
201197
label=cat,
202198
)
203199

tests/figures/Labels_images.png

19.8 KB
Loading

tests/figures/Labels_labels.png

5.21 KB
Loading

0 commit comments

Comments
 (0)