Skip to content

Commit d9d58ed

Browse files
authored
mvp for render_points (#39)
* mvp for render_points * render_points can now do continous/categorical column-coloring
1 parent 5f9917c commit d9d58ed

File tree

3 files changed

+147
-3
lines changed

3 files changed

+147
-3
lines changed

src/spatialdata_plot/pl/_categorical_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,41 @@ def _add_categorical_legend(
450450
fontsize=legend_fontsize,
451451
path_effects=legend_fontoutline,
452452
)
453+
454+
455+
def _get_colors_for_categorical_obs(categories: Sequence[Union[str, int]]) -> list[str]:
456+
"""
457+
Return a list of colors for a categorical observation.
458+
459+
Parameters
460+
----------
461+
adata
462+
AnnData object
463+
value_to_plot
464+
Name of a valid categorical observation
465+
categories
466+
categories of the categorical observation.
467+
468+
Returns
469+
-------
470+
None
471+
"""
472+
length = len(categories)
473+
474+
# check if default matplotlib palette has enough colors
475+
if len(rcParams["axes.prop_cycle"].by_key()["color"]) >= length:
476+
cc = rcParams["axes.prop_cycle"]()
477+
palette = [next(cc)["color"] for _ in range(length)]
478+
479+
else:
480+
if length <= 20:
481+
palette = default_20
482+
elif length <= 28:
483+
palette = default_28
484+
elif length <= len(default_102): # 103 colors
485+
palette = default_102
486+
else:
487+
palette = ["grey" for _ in range(length)]
488+
logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
489+
490+
return palette[:length]

src/spatialdata_plot/pl/basic.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121

2222
from ..accessor import register_spatial_data_accessor
2323
from ..pp.utils import _get_instance_key, _get_region_key, _verify_plotting_tree_exists
24-
from .render import _render_channels, _render_images, _render_labels, _render_shapes
24+
from .render import (
25+
_render_channels,
26+
_render_images,
27+
_render_labels,
28+
_render_points,
29+
_render_shapes,
30+
)
2531
from .utils import (
2632
_get_color_key_dtype,
2733
_get_color_key_values,
@@ -197,6 +203,55 @@ def render_shapes(
197203

198204
return sdata
199205

206+
def render_points(
207+
self,
208+
palette: Optional[Union[str, list[str], None]] = None,
209+
color_key: Optional[str] = None,
210+
**scatter_kwargs: Optional[str],
211+
) -> sd.SpatialData:
212+
"""Render the points contained in the given sd.SpatialData object
213+
214+
Parameters
215+
----------
216+
self : sd.SpatialData
217+
The sd.SpatialData object.
218+
palette : list[str], optional (default: None)
219+
A list of colors to use for rendering the images. If `None`, the
220+
default colors will be used.
221+
instance_key : str
222+
The name of the column in the table that identifies individual shapes
223+
color_key : str or None, optional (default: None)
224+
The name of the column in the table to use for coloring shapes.
225+
226+
Returns
227+
-------
228+
sd.SpatialData
229+
The input sd.SpatialData with a command added to the plotting tree
230+
231+
"""
232+
if palette is not None:
233+
if isinstance(palette, str):
234+
palette = [palette]
235+
236+
if isinstance(palette, list):
237+
if not all(isinstance(p, str) for p in palette):
238+
raise TypeError("The palette argument must be a list of strings or a single string.")
239+
else:
240+
raise TypeError("The palette argument must be a list of strings or a single string.")
241+
242+
if color_key is not None and not isinstance(color_key, str):
243+
raise TypeError("When giving a 'color_key', it must be of type 'str'.")
244+
245+
sdata = self._copy()
246+
sdata = _verify_plotting_tree_exists(sdata)
247+
n_steps = len(sdata.plotting_tree.keys())
248+
sdata.plotting_tree[f"{n_steps+1}_render_points"] = {
249+
"palette": palette,
250+
"color_key": color_key,
251+
}
252+
253+
return sdata
254+
200255
def render_images(
201256
self,
202257
palette: Optional[Union[str, list[str]]] = None,
@@ -511,6 +566,7 @@ def show(
511566
"render_images",
512567
"render_shapes",
513568
"render_labels",
569+
"render_points",
514570
]
515571

516572
if len(plotting_tree.keys()) > 0:
@@ -646,6 +702,17 @@ def show(
646702
key = list(sdata.shapes.keys())[idx]
647703
_render_shapes(sdata=sdata, params=params, key=key, ax=ax, extent=extent)
648704

705+
elif cmd == "render_points":
706+
for idx, ax in enumerate(axs):
707+
key = list(sdata.points.keys())[idx]
708+
if params["color_key"] is not None:
709+
if params["color_key"] not in sdata.points[key].columns:
710+
raise ValueError(
711+
f"The column '{params['color_key']}' is not present in the 'metadata' of the points."
712+
)
713+
714+
_render_points(sdata=sdata, params=params, key=key, ax=ax, extent=extent)
715+
649716
elif cmd == "render_labels":
650717
if (
651718
sdata.table is not None

src/spatialdata_plot/pl/render.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import spatialdata as sd
99
import xarray as xr
1010
from matplotlib.colors import ListedColormap, to_rgb
11+
from pandas.api.types import is_categorical_dtype
1112
from skimage.segmentation import find_boundaries
1213
from sklearn.decomposition import PCA
1314

15+
from ..pl._categorical_utils import _get_colors_for_categorical_obs
1416
from ..pl.utils import _normalize
1517
from ..pp.utils import _get_linear_colormap, _get_region_key
1618

@@ -58,13 +60,11 @@ def _render_shapes(
5860
) -> None:
5961
if sdata.table is not None and isinstance(params["instance_key"], str) and isinstance(params["color_key"], str):
6062
colors = [to_rgb(c) for c in sdata.table.uns[f"{params['color_key']}_colors"]]
61-
6263
elif isinstance(params["palette"], str):
6364
colors = [params["palette"]]
6465
elif isinstance(params["palette"], Iterable):
6566
colors = [to_rgb(c) for c in list(params["palette"])]
6667
else:
67-
# assert isinstance(params["palette"], Iterable)
6868
colors = [params["palette"]]
6969

7070
ax.set_xlim(extent["x"][0], extent["x"][1])
@@ -82,6 +82,45 @@ def _render_shapes(
8282
ax.set_title(key)
8383

8484

85+
def _render_points(
86+
sdata: sd.SpatialData,
87+
params: dict[str, Union[str, int, float, Iterable[str]]],
88+
key: str,
89+
ax: matplotlib.axes.SubplotBase,
90+
extent: dict[str, list[int]],
91+
) -> None:
92+
ax.set_xlim(extent["x"][0], extent["x"][1])
93+
ax.set_ylim(extent["y"][0], extent["y"][1])
94+
95+
if isinstance(params["color_key"], str):
96+
colors = sdata.points[key][params["color_key"]].compute()
97+
98+
if is_categorical_dtype(colors):
99+
category_colors = _get_colors_for_categorical_obs(colors.cat.categories)
100+
101+
for i, cat in enumerate(colors.cat.categories):
102+
ax.scatter(
103+
x=sdata.points[key]["x"].compute()[colors == cat],
104+
y=sdata.points[key]["y"].compute()[colors == cat],
105+
color=category_colors[i],
106+
label=cat,
107+
)
108+
109+
else:
110+
ax.scatter(
111+
x=sdata.points[key]["x"].compute(),
112+
y=sdata.points[key]["y"].compute(),
113+
c=colors,
114+
)
115+
else:
116+
ax.scatter(
117+
x=sdata.points[key]["x"].compute(),
118+
y=sdata.points[key]["y"].compute(),
119+
)
120+
121+
ax.set_title(key)
122+
123+
85124
def _render_images(
86125
sdata: sd.SpatialData,
87126
params: dict[str, Union[str, int, float]],

0 commit comments

Comments
 (0)