Skip to content

mvp for render_points #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/spatialdata_plot/pl/_categorical_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,41 @@ def _add_categorical_legend(
fontsize=legend_fontsize,
path_effects=legend_fontoutline,
)


def _get_colors_for_categorical_obs(categories: Sequence[Union[str, int]]) -> list[str]:
"""
Return a list of colors for a categorical observation.

Parameters
----------
adata
AnnData object
value_to_plot
Name of a valid categorical observation
categories
categories of the categorical observation.

Returns
-------
None
"""
length = len(categories)

# check if default matplotlib palette has enough colors
if len(rcParams["axes.prop_cycle"].by_key()["color"]) >= length:
cc = rcParams["axes.prop_cycle"]()
palette = [next(cc)["color"] for _ in range(length)]

else:
if length <= 20:
palette = default_20
elif length <= 28:
palette = default_28
elif length <= len(default_102): # 103 colors
palette = default_102
else:
palette = ["grey" for _ in range(length)]
logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")

return palette[:length]
69 changes: 68 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@

from ..accessor import register_spatial_data_accessor
from ..pp.utils import _get_instance_key, _get_region_key, _verify_plotting_tree_exists
from .render import _render_channels, _render_images, _render_labels, _render_shapes
from .render import (
_render_channels,
_render_images,
_render_labels,
_render_points,
_render_shapes,
)
from .utils import (
_get_color_key_dtype,
_get_color_key_values,
Expand Down Expand Up @@ -197,6 +203,55 @@ def render_shapes(

return sdata

def render_points(
self,
palette: Optional[Union[str, list[str], None]] = None,
color_key: Optional[str] = None,
**scatter_kwargs: Optional[str],
) -> sd.SpatialData:
"""Render the points contained in the given sd.SpatialData object

Parameters
----------
self : sd.SpatialData
The sd.SpatialData object.
palette : list[str], optional (default: None)
A list of colors to use for rendering the images. If `None`, the
default colors will be used.
instance_key : str
The name of the column in the table that identifies individual shapes
color_key : str or None, optional (default: None)
The name of the column in the table to use for coloring shapes.

Returns
-------
sd.SpatialData
The input sd.SpatialData with a command added to the plotting tree

"""
if palette is not None:
if isinstance(palette, str):
palette = [palette]

if isinstance(palette, list):
if not all(isinstance(p, str) for p in palette):
raise TypeError("The palette argument must be a list of strings or a single string.")
else:
raise TypeError("The palette argument must be a list of strings or a single string.")

if color_key is not None and not isinstance(color_key, str):
raise TypeError("When giving a 'color_key', it must be of type 'str'.")

sdata = self._copy()
sdata = _verify_plotting_tree_exists(sdata)
n_steps = len(sdata.plotting_tree.keys())
sdata.plotting_tree[f"{n_steps+1}_render_points"] = {
"palette": palette,
"color_key": color_key,
}

return sdata

def render_images(
self,
palette: Optional[Union[str, list[str]]] = None,
Expand Down Expand Up @@ -511,6 +566,7 @@ def show(
"render_images",
"render_shapes",
"render_labels",
"render_points",
]

if len(plotting_tree.keys()) > 0:
Expand Down Expand Up @@ -646,6 +702,17 @@ def show(
key = list(sdata.shapes.keys())[idx]
_render_shapes(sdata=sdata, params=params, key=key, ax=ax, extent=extent)

elif cmd == "render_points":
for idx, ax in enumerate(axs):
key = list(sdata.points.keys())[idx]
if params["color_key"] is not None:
if params["color_key"] not in sdata.points[key].columns:
raise ValueError(
f"The column '{params['color_key']}' is not present in the 'metadata' of the points."
)

_render_points(sdata=sdata, params=params, key=key, ax=ax, extent=extent)

elif cmd == "render_labels":
if (
sdata.table is not None
Expand Down
43 changes: 41 additions & 2 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import spatialdata as sd
import xarray as xr
from matplotlib.colors import ListedColormap, to_rgb
from pandas.api.types import is_categorical_dtype
from skimage.segmentation import find_boundaries
from sklearn.decomposition import PCA

from ..pl._categorical_utils import _get_colors_for_categorical_obs
from ..pl.utils import _normalize
from ..pp.utils import _get_linear_colormap, _get_region_key

Expand Down Expand Up @@ -58,13 +60,11 @@ def _render_shapes(
) -> None:
if sdata.table is not None and isinstance(params["instance_key"], str) and isinstance(params["color_key"], str):
colors = [to_rgb(c) for c in sdata.table.uns[f"{params['color_key']}_colors"]]

elif isinstance(params["palette"], str):
colors = [params["palette"]]
elif isinstance(params["palette"], Iterable):
colors = [to_rgb(c) for c in list(params["palette"])]
else:
# assert isinstance(params["palette"], Iterable)
colors = [params["palette"]]

ax.set_xlim(extent["x"][0], extent["x"][1])
Expand All @@ -82,6 +82,45 @@ def _render_shapes(
ax.set_title(key)


def _render_points(
sdata: sd.SpatialData,
params: dict[str, Union[str, int, float, Iterable[str]]],
key: str,
ax: matplotlib.axes.SubplotBase,
extent: dict[str, list[int]],
) -> None:
ax.set_xlim(extent["x"][0], extent["x"][1])
ax.set_ylim(extent["y"][0], extent["y"][1])

if isinstance(params["color_key"], str):
colors = sdata.points[key][params["color_key"]].compute()

if is_categorical_dtype(colors):
category_colors = _get_colors_for_categorical_obs(colors.cat.categories)

for i, cat in enumerate(colors.cat.categories):
ax.scatter(
x=sdata.points[key]["x"].compute()[colors == cat],
y=sdata.points[key]["y"].compute()[colors == cat],
color=category_colors[i],
label=cat,
)

else:
ax.scatter(
x=sdata.points[key]["x"].compute(),
y=sdata.points[key]["y"].compute(),
c=colors,
)
else:
ax.scatter(
x=sdata.points[key]["x"].compute(),
y=sdata.points[key]["y"].compute(),
)

ax.set_title(key)


def _render_images(
sdata: sd.SpatialData,
params: dict[str, Union[str, int, float]],
Expand Down