Skip to content

Fix categorical plotting #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ exclude = [
"dist",
"setup.py",
]
ignore = [
lint.ignore = [
# Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
"E731",
# allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
Expand All @@ -137,7 +137,7 @@ ignore = [
"D105",
]
line-length = 120
select = [
lint.select = [
"D", # flake8-docstrings
"I", # isort
"E", # pycodestyle
Expand All @@ -156,16 +156,16 @@ select = [
"RET", # flake8-raise
"PGH", # pygrep-hooks
]
unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
lint.unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
target-version = "py39"
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D", "PT", "B024"]
"*/__init__.py" = ["F401", "D104", "D107", "E402"]
"docs/*" = ["D","B","E","A"]
# "src/spatialdata/transformations/transformations.py" = ["D101","D102", "D106", "B024", "T201", "RET504"]
"tests/conftest.py"= ["E402", "RET504"]
"src/spatialdata_plot/pl/utils.py"= ["PGH003"]
[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.bumpver]
Expand Down
75 changes: 60 additions & 15 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections import abc
from copy import copy
from typing import Union, cast
Expand Down Expand Up @@ -37,6 +38,7 @@
_get_collection_shape,
_get_colors_for_categorical_obs,
_get_linear_colormap,
_is_coercable_to_float,
_map_color_seg,
_maybe_set_colors,
_multiscale_to_spatial_image,
Expand Down Expand Up @@ -70,6 +72,7 @@ def _render_shapes(
elements = list(sdata_filt.shapes.keys())

for index, e in enumerate(elements):
col_for_color = render_params.col_for_color[index]
shapes = sdata.shapes[e]

table_name = element_table_mapping.get(e)
Expand All @@ -79,13 +82,28 @@ def _render_shapes(
_, region_key, _ = get_table_keys(sdata[table_name])
table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])]

if (
col_for_color is not None
and table_name is not None
and col_for_color in sdata_filt[table_name].obs.columns
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
and not _is_coercable_to_float(color_col)
):
warnings.warn(
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical plotting. "
f"Consider converting before plotting.",
UserWarning,
stacklevel=2,
)
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")

# get color vector (categorical or continuous)
color_source_vector, color_vector, _ = _set_color_source_vec(
sdata=sdata_filt,
element=sdata_filt.shapes[e],
element_index=index,
element_name=e,
value_to_plot=render_params.col_for_color[index],
value_to_plot=col_for_color,
groups=render_params.groups[index] if render_params.groups[index][0] is not None else None,
palette=(
render_params.palette[index] if render_params.palette is not None else None
Expand Down Expand Up @@ -170,7 +188,7 @@ def _render_shapes(
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=render_params.col_for_color[index],
value_to_plot=col_for_color,
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.fill_alpha,
Expand Down Expand Up @@ -212,22 +230,48 @@ def _render_points(
table_name = element_table_mapping.get(e)

coords = ["x", "y"]
if col_for_color is not None:
if col_for_color not in points.columns:
# no error in case there are multiple elements, but onyl some have color key
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
logger.warning(msg)
else:
coords += [col_for_color]
# if col_for_color is not None:
if (
col_for_color is not None
and col_for_color not in points.columns
and col_for_color not in sdata_filt[table_name].obs.columns
):
# no error in case there are multiple elements, but onyl some have color key
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
logger.warning(msg)
elif col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns):
points = points[coords].compute()
if (
col_for_color
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
and not _is_coercable_to_float(color_col)
):
warnings.warn(
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical "
f"plotting. Consider converting before plotting.",
UserWarning,
stacklevel=2,
)
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
else:
coords += [col_for_color]
points = points[coords].compute()

points = points[coords].compute()
if render_params.groups[index][0] is not None and col_for_color is not None:
points = points[points[col_for_color].isin(render_params.groups[index])]

# we construct an anndata to hack the plotting functions
adata = AnnData(
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
)
if table_name is None:
adata = AnnData(
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
)
else:
adata = AnnData(
X=points[["x", "y"]].values, obs=sdata_filt[table_name].obs, dtype=points[["x", "y"]].values.dtype
)
sdata_filt[table_name] = adata

# we can do this because of dealing with a copy

# Convert back to dask dataframe to modify sdata
points = dask.dataframe.from_pandas(points, npartitions=1)
Expand Down Expand Up @@ -559,6 +603,7 @@ def _render_labels(
label = sdata_filt.labels[e]
extent = get_extent(label, coordinate_system=coordinate_system)
scale = render_params.scale[i] if isinstance(render_params.scale, list) else render_params.scale
color = render_params.color[i]

# get best scale out of multiscale label
if isinstance(label, MultiscaleSpatialImage):
Expand Down Expand Up @@ -603,7 +648,7 @@ def _render_labels(
element=label,
element_index=i,
element_name=e,
value_to_plot=cast(list[str], render_params.color)[i],
value_to_plot=color,
groups=render_params.groups[i],
palette=render_params.palette[i],
na_color=render_params.cmap_params.na_color,
Expand Down Expand Up @@ -684,7 +729,7 @@ def _render_labels(
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=cast(list[str], render_params.color)[i],
value_to_plot=color,
color_source_vector=color_source_vector,
palette=render_params.palette[i],
alpha=render_params.fill_alpha,
Expand Down
56 changes: 51 additions & 5 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.query.relational_query import _get_element_annotators, _locate_value, get_values
from spatialdata._core.query.relational_query import _get_element_annotators, _locate_value, _ValueOrigin, get_values
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, TableModel
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, TableModel, get_model
from spatialdata.transformations.operations import get_transformation

from spatialdata_plot._logging import logger
Expand Down Expand Up @@ -211,7 +211,13 @@ def _get_collection_shape(
if norm is None:
c = cmap(c)
else:
norm = colors.Normalize(vmin=min(c), vmax=max(c))
try:
norm = colors.Normalize(vmin=min(c), vmax=max(c))
except ValueError as e:
raise ValueError(
"Could not convert values in the `color` column to float, if `color` column represents"
" categories, set the column to categorical dtype."
) from e
c = cmap(norm(c))

fill_c = ColorConverter().to_rgba_array(c)
Expand Down Expand Up @@ -589,6 +595,29 @@ def _get_colors_for_categorical_obs(
return palette[:len_cat] # type: ignore[return-value]


def _locate_points_value_in_table(value_key: str, sdata: SpatialData, element_name: str, table_name: str):
table = sdata[table_name]

if value_key in table.obs.columns:
value = table.obs[value_key]
is_categorical = isinstance(value.dtype, CategoricalDtype)
return _ValueOrigin(origin="obs", is_categorical=is_categorical, value_key=value_key)

is_categorical = False
return _ValueOrigin(origin="var", is_categorical=is_categorical, value_key=value_key)


# TODO consider move to relational query in spatialdata
def get_values_point_table(sdata: SpatialData, origin: _ValueOrigin, table_name: str):
"""Get a particular column stored in _ValueOrigin from the table in the spatialdata object."""
table = sdata[table_name]
if origin.origin == "obs":
return table.obs[origin.value_key]
if origin.origin == "var":
return table[:, table.var_names.isin([origin.value_key])].X.copy()
raise ValueError(f"Color column `{origin.value_key}` not found in table {table_name}")


def _set_color_source_vec(
sdata: sd.SpatialData,
element: SpatialElement | None,
Expand All @@ -605,16 +634,28 @@ def _set_color_source_vec(
color = np.full(len(element), to_hex(na_color)) # type: ignore[arg-type]
return color, color, False

model = get_model(sdata[element_name])

# Figure out where to get the color from
origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
if model == PointsModel and table_name is not None:
origin = _locate_points_value_in_table(
value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name
)
if origin is not None:
origins.append(origin)

if len(origins) > 1:
raise ValueError(
f"Color key '{value_to_plot}' for element '{element_name}' been found in multiple locations: {origins}."
)

if len(origins) == 1:
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
color_source_vector = vals[value_to_plot]
if model == PointsModel and table_name is not None:
color_source_vector = get_values_point_table(sdata=sdata, origin=origin, table_name=table_name)
else:
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
color_source_vector = vals[value_to_plot]

# numerical case, return early
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
Expand Down Expand Up @@ -1857,3 +1898,8 @@ def _update_params(sdata, params, wanted_elements_on_cs, element_type: Literal["
# params.palette = [[None] for _ in wanted_elements_on_cs]
image_flag = element_type == "images"
return _match_length_elements_groups_palette(params, wanted_elements_on_cs, image=image_flag)


def _is_coercable_to_float(series):
numeric_series = pd.to_numeric(series, errors="coerce")
return not numeric_series.isnull().any()
Binary file added tests/_images/Labels_can_color_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_stack_render_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Labels_label_categorical_color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Points_can_stack_render_points.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Points_points_categorical_color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Shapes_shapes_categorical_color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 59 additions & 1 deletion tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import dask.array as da
import matplotlib
import numpy as np
import pandas as pd
import scanpy as sc
import spatialdata_plot # noqa: F401
from anndata import AnnData
from spatial_image import to_spatial_image
from spatialdata import SpatialData
from spatialdata._core.query.relational_query import _get_unique_label_values_as_index
from spatialdata.models import TableModel

from tests.conftest import PlotTester, PlotTesterMeta

Expand All @@ -12,6 +17,7 @@
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

RNG = np.random.default_rng(seed=42)
# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
Expand Down Expand Up @@ -90,4 +96,56 @@ def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData):
table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
table = table[:, ~table.var_names.isin(["channel_0_sum"])]
sdata_blobs["multi_table"] = table
sdata_blobs.pl.render_labels(color=["channel_0_sum"], table_name=["table"]).pl.show()
sdata_blobs.pl.render_labels(
color=["channel_0_sum", "channel_1_sum"], table_name=["table", "multi_table"]
).pl.show()

def test_plot_label_categorical_color(self, sdata_blobs: SpatialData):
n_obs = max(_get_unique_label_values_as_index(sdata_blobs["blobs_labels"]))
adata = AnnData(
RNG.normal(size=(n_obs, 10)), obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"])
)
adata.obs["instance_id"] = np.arange(adata.n_obs)
adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
adata.obs["instance_id"] = list(range(adata.n_obs))
adata.obs["region"] = "blobs_labels"
table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_labels")
sdata_blobs["other_table"] = table

# with pytest.raises(ValueError, match="could not convert string"):
# sdata_blobs.pl.render_labels('blobs_labels', color='category').pl.show()
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()

def test_plot_multiscale_label_categorical_color(self, sdata_blobs: SpatialData):
n_obs = max(_get_unique_label_values_as_index(sdata_blobs["blobs_multiscale_labels"]))
adata = AnnData(
RNG.normal(size=(n_obs, 10)), obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"])
)
adata.obs["instance_id"] = np.arange(adata.n_obs)
adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
adata.obs["instance_id"] = list(range(adata.n_obs))
adata.obs["region"] = "blobs_multiscale_labels"
table = TableModel.parse(
adata=adata, region_key="region", instance_key="instance_id", region="blobs_multiscale_labels"
)
sdata_blobs["other_table"] = table

sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
sdata_blobs.pl.render_labels("blobs_multiscale_labels", color="category").pl.show()

# def test_plot_multiscale_label_coercable_categorical_color(self, sdata_blobs: SpatialData):
# n_obs = max(_get_unique_label_values_as_index(sdata_blobs["blobs_multiscale_labels"]))
# adata = AnnData(
# RNG.normal(size=(n_obs, 10)), obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"])
# )
# adata.obs["instance_id"] = np.arange(adata.n_obs)
# adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
# adata.obs["instance_id"] = list(range(adata.n_obs))
# adata.obs["region"] = "blobs_multiscale_labels"
# table = TableModel.parse(
# adata=adata, region_key="region", instance_key="instance_id", region="blobs_multiscale_labels"
# )
# sdata_blobs["other_table"] = table
#
# sdata_blobs.pl.render_labels("blobs_multiscale_labels", color="category").pl.show()
Loading