Skip to content

Commit 0ac9f28

Browse files
authored
Fix categorical plotting (#229)
* Add error when category is not categorical dtype * partial tests * fix categorical * add image artefacts
1 parent 9ca03ad commit 0ac9f28

15 files changed

+238
-26
lines changed

pyproject.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ exclude = [
118118
"dist",
119119
"setup.py",
120120
]
121-
ignore = [
121+
lint.ignore = [
122122
# Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
123123
"E731",
124124
# allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
@@ -137,7 +137,7 @@ ignore = [
137137
"D105",
138138
]
139139
line-length = 120
140-
select = [
140+
lint.select = [
141141
"D", # flake8-docstrings
142142
"I", # isort
143143
"E", # pycodestyle
@@ -156,16 +156,16 @@ select = [
156156
"RET", # flake8-raise
157157
"PGH", # pygrep-hooks
158158
]
159-
unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
159+
lint.unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
160160
target-version = "py39"
161-
[tool.ruff.per-file-ignores]
161+
[tool.ruff.lint.per-file-ignores]
162162
"tests/*" = ["D", "PT", "B024"]
163163
"*/__init__.py" = ["F401", "D104", "D107", "E402"]
164164
"docs/*" = ["D","B","E","A"]
165165
# "src/spatialdata/transformations/transformations.py" = ["D101","D102", "D106", "B024", "T201", "RET504"]
166166
"tests/conftest.py"= ["E402", "RET504"]
167167
"src/spatialdata_plot/pl/utils.py"= ["PGH003"]
168-
[tool.ruff.pydocstyle]
168+
[tool.ruff.lint.pydocstyle]
169169
convention = "numpy"
170170

171171
[tool.bumpver]

src/spatialdata_plot/pl/render.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from collections import abc
45
from copy import copy
56
from typing import Union, cast
@@ -37,6 +38,7 @@
3738
_get_collection_shape,
3839
_get_colors_for_categorical_obs,
3940
_get_linear_colormap,
41+
_is_coercable_to_float,
4042
_map_color_seg,
4143
_maybe_set_colors,
4244
_multiscale_to_spatial_image,
@@ -70,6 +72,7 @@ def _render_shapes(
7072
elements = list(sdata_filt.shapes.keys())
7173

7274
for index, e in enumerate(elements):
75+
col_for_color = render_params.col_for_color[index]
7376
shapes = sdata.shapes[e]
7477

7578
table_name = element_table_mapping.get(e)
@@ -79,13 +82,28 @@ def _render_shapes(
7982
_, region_key, _ = get_table_keys(sdata[table_name])
8083
table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])]
8184

85+
if (
86+
col_for_color is not None
87+
and table_name is not None
88+
and col_for_color in sdata_filt[table_name].obs.columns
89+
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
90+
and not _is_coercable_to_float(color_col)
91+
):
92+
warnings.warn(
93+
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical plotting. "
94+
f"Consider converting before plotting.",
95+
UserWarning,
96+
stacklevel=2,
97+
)
98+
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
99+
82100
# get color vector (categorical or continuous)
83101
color_source_vector, color_vector, _ = _set_color_source_vec(
84102
sdata=sdata_filt,
85103
element=sdata_filt.shapes[e],
86104
element_index=index,
87105
element_name=e,
88-
value_to_plot=render_params.col_for_color[index],
106+
value_to_plot=col_for_color,
89107
groups=render_params.groups[index] if render_params.groups[index][0] is not None else None,
90108
palette=(
91109
render_params.palette[index] if render_params.palette is not None else None
@@ -170,7 +188,7 @@ def _render_shapes(
170188
cax=cax,
171189
fig_params=fig_params,
172190
adata=table,
173-
value_to_plot=render_params.col_for_color[index],
191+
value_to_plot=col_for_color,
174192
color_source_vector=color_source_vector,
175193
palette=palette,
176194
alpha=render_params.fill_alpha,
@@ -212,22 +230,48 @@ def _render_points(
212230
table_name = element_table_mapping.get(e)
213231

214232
coords = ["x", "y"]
215-
if col_for_color is not None:
216-
if col_for_color not in points.columns:
217-
# no error in case there are multiple elements, but onyl some have color key
218-
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
219-
logger.warning(msg)
220-
else:
221-
coords += [col_for_color]
233+
# if col_for_color is not None:
234+
if (
235+
col_for_color is not None
236+
and col_for_color not in points.columns
237+
and col_for_color not in sdata_filt[table_name].obs.columns
238+
):
239+
# no error in case there are multiple elements, but onyl some have color key
240+
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
241+
logger.warning(msg)
242+
elif col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns):
243+
points = points[coords].compute()
244+
if (
245+
col_for_color
246+
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
247+
and not _is_coercable_to_float(color_col)
248+
):
249+
warnings.warn(
250+
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical "
251+
f"plotting. Consider converting before plotting.",
252+
UserWarning,
253+
stacklevel=2,
254+
)
255+
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
256+
else:
257+
coords += [col_for_color]
258+
points = points[coords].compute()
222259

223-
points = points[coords].compute()
224260
if render_params.groups[index][0] is not None and col_for_color is not None:
225261
points = points[points[col_for_color].isin(render_params.groups[index])]
226262

227263
# we construct an anndata to hack the plotting functions
228-
adata = AnnData(
229-
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
230-
)
264+
if table_name is None:
265+
adata = AnnData(
266+
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
267+
)
268+
else:
269+
adata = AnnData(
270+
X=points[["x", "y"]].values, obs=sdata_filt[table_name].obs, dtype=points[["x", "y"]].values.dtype
271+
)
272+
sdata_filt[table_name] = adata
273+
274+
# we can do this because of dealing with a copy
231275

232276
# Convert back to dask dataframe to modify sdata
233277
points = dask.dataframe.from_pandas(points, npartitions=1)
@@ -559,6 +603,7 @@ def _render_labels(
559603
label = sdata_filt.labels[e]
560604
extent = get_extent(label, coordinate_system=coordinate_system)
561605
scale = render_params.scale[i] if isinstance(render_params.scale, list) else render_params.scale
606+
color = render_params.color[i]
562607

563608
# get best scale out of multiscale label
564609
if isinstance(label, MultiscaleSpatialImage):
@@ -603,7 +648,7 @@ def _render_labels(
603648
element=label,
604649
element_index=i,
605650
element_name=e,
606-
value_to_plot=cast(list[str], render_params.color)[i],
651+
value_to_plot=color,
607652
groups=render_params.groups[i],
608653
palette=render_params.palette[i],
609654
na_color=render_params.cmap_params.na_color,
@@ -684,7 +729,7 @@ def _render_labels(
684729
cax=cax,
685730
fig_params=fig_params,
686731
adata=table,
687-
value_to_plot=cast(list[str], render_params.color)[i],
732+
value_to_plot=color,
688733
color_source_vector=color_source_vector,
689734
palette=render_params.palette[i],
690735
alpha=render_params.fill_alpha,

src/spatialdata_plot/pl/utils.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@
5454
from spatial_image import SpatialImage
5555
from spatialdata import SpatialData
5656
from spatialdata._core.operations.rasterize import rasterize
57-
from spatialdata._core.query.relational_query import _get_element_annotators, _locate_value, get_values
57+
from spatialdata._core.query.relational_query import _get_element_annotators, _locate_value, _ValueOrigin, get_values
5858
from spatialdata._types import ArrayLike
59-
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, TableModel
59+
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, TableModel, get_model
6060
from spatialdata.transformations.operations import get_transformation
6161

6262
from spatialdata_plot._logging import logger
@@ -211,7 +211,13 @@ def _get_collection_shape(
211211
if norm is None:
212212
c = cmap(c)
213213
else:
214-
norm = colors.Normalize(vmin=min(c), vmax=max(c))
214+
try:
215+
norm = colors.Normalize(vmin=min(c), vmax=max(c))
216+
except ValueError as e:
217+
raise ValueError(
218+
"Could not convert values in the `color` column to float, if `color` column represents"
219+
" categories, set the column to categorical dtype."
220+
) from e
215221
c = cmap(norm(c))
216222

217223
fill_c = ColorConverter().to_rgba_array(c)
@@ -589,6 +595,29 @@ def _get_colors_for_categorical_obs(
589595
return palette[:len_cat] # type: ignore[return-value]
590596

591597

598+
def _locate_points_value_in_table(value_key: str, sdata: SpatialData, element_name: str, table_name: str):
    """
    Describe where ``value_key`` is expected to live in the table ``table_name``.

    Parameters
    ----------
    value_key
        Name of the value (column) to look up.
    sdata
        Object that is indexed as ``sdata[table_name]`` to obtain the table.
    element_name
        Not used by this function.
    table_name
        Key of the table to inspect.

    Returns
    -------
    A ``_ValueOrigin`` with ``origin="obs"`` when ``value_key`` is a column of ``table.obs`` (its
    ``is_categorical`` flag reflects whether that column has a categorical dtype); otherwise a
    ``_ValueOrigin`` with ``origin="var"`` and ``is_categorical=False``.
    """
    obs = sdata[table_name].obs

    if value_key not in obs.columns:
        # NOTE(review): the "var" origin is returned without checking `table.var_names`, so a key that
        # is absent from the table altogether is still reported as living in `var` - confirm intended.
        return _ValueOrigin(origin="var", is_categorical=False, value_key=value_key)

    return _ValueOrigin(
        origin="obs",
        is_categorical=isinstance(obs[value_key].dtype, CategoricalDtype),
        value_key=value_key,
    )
608+
609+
610+
# TODO consider move to relational query in spatialdata
611+
def get_values_point_table(sdata: SpatialData, origin: _ValueOrigin, table_name: str):
    """
    Get a particular column stored in _ValueOrigin from the table in the spatialdata object.

    Parameters
    ----------
    sdata
        Object that is indexed as ``sdata[table_name]`` to obtain the table.
    origin
        Describes the value to fetch: ``origin.origin`` selects the location (``"obs"`` or ``"var"``)
        and ``origin.value_key`` names the column / variable.
    table_name
        Key of the table to read from.

    Returns
    -------
    ``table.obs[origin.value_key]`` for an ``"obs"`` origin, or a copy of the ``X`` slice restricted to
    the variables named ``origin.value_key`` for a ``"var"`` origin.

    Raises
    ------
    ValueError
        If ``origin.origin`` is neither ``"obs"`` nor ``"var"``.
    """
    table = sdata[table_name]
    location = origin.origin

    if location == "obs":
        return table.obs[origin.value_key]

    if location == "var":
        # Boolean mask over the variables; the matching slice of X is copied before being returned.
        var_mask = table.var_names.isin([origin.value_key])
        return table[:, var_mask].X.copy()

    raise ValueError(f"Color column `{origin.value_key}` not found in table {table_name}")
619+
620+
592621
def _set_color_source_vec(
593622
sdata: sd.SpatialData,
594623
element: SpatialElement | None,
@@ -605,16 +634,28 @@ def _set_color_source_vec(
605634
color = np.full(len(element), to_hex(na_color)) # type: ignore[arg-type]
606635
return color, color, False
607636

637+
model = get_model(sdata[element_name])
638+
608639
# Figure out where to get the color from
609640
origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
641+
if model == PointsModel and table_name is not None:
642+
origin = _locate_points_value_in_table(
643+
value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name
644+
)
645+
if origin is not None:
646+
origins.append(origin)
647+
610648
if len(origins) > 1:
611649
raise ValueError(
612650
f"Color key '{value_to_plot}' for element '{element_name}' been found in multiple locations: {origins}."
613651
)
614652

615653
if len(origins) == 1:
616-
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
617-
color_source_vector = vals[value_to_plot]
654+
if model == PointsModel and table_name is not None:
655+
color_source_vector = get_values_point_table(sdata=sdata, origin=origin, table_name=table_name)
656+
else:
657+
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
658+
color_source_vector = vals[value_to_plot]
618659

619660
# numerical case, return early
620661
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
@@ -1857,3 +1898,8 @@ def _update_params(sdata, params, wanted_elements_on_cs, element_type: Literal["
18571898
# params.palette = [[None] for _ in wanted_elements_on_cs]
18581899
image_flag = element_type == "images"
18591900
return _match_length_elements_groups_palette(params, wanted_elements_on_cs, image=image_flag)
1901+
1902+
1903+
def _is_coercable_to_float(series):
    """
    Check whether every entry of ``series`` can be parsed as a number.

    Parameters
    ----------
    series
        The values to check; passed to :func:`pandas.to_numeric`.

    Returns
    -------
    ``True`` if no entry becomes null after numeric coercion (this includes an empty series),
    ``False`` otherwise. Entries that are already missing stay null after coercion, so a series
    containing missing values is reported as not coercible.
    """
    # Unparseable entries are turned into NaN instead of raising; any NaN left over means failure.
    coerced = pd.to_numeric(series, errors="coerce")
    return bool(coerced.notna().all())
11.9 KB
Loading
0 Bytes
Loading
17.4 KB
Loading
Loading
-33 Bytes
Loading
Loading
Loading
Loading
Loading

tests/pl/test_render_labels.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import dask.array as da
22
import matplotlib
3+
import numpy as np
4+
import pandas as pd
35
import scanpy as sc
46
import spatialdata_plot # noqa: F401
7+
from anndata import AnnData
58
from spatial_image import to_spatial_image
69
from spatialdata import SpatialData
10+
from spatialdata._core.query.relational_query import _get_unique_label_values_as_index
11+
from spatialdata.models import TableModel
712

813
from tests.conftest import PlotTester, PlotTesterMeta
914

@@ -12,6 +17,7 @@
1217
matplotlib.use("agg") # same as GitHub action runner
1318
_ = spatialdata_plot
1419

20+
RNG = np.random.default_rng(seed=42)
1521
# WARNING:
1622
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
1723
# 2. tests which produce a plot must be prefixed with `test_plot_`
@@ -90,4 +96,56 @@ def test_can_plot_with_one_element_color_table(self, sdata_blobs: SpatialData):
9096
table.uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
9197
table = table[:, ~table.var_names.isin(["channel_0_sum"])]
9298
sdata_blobs["multi_table"] = table
93-
sdata_blobs.pl.render_labels(color=["channel_0_sum"], table_name=["table"]).pl.show()
99+
sdata_blobs.pl.render_labels(
100+
color=["channel_0_sum", "channel_1_sum"], table_name=["table", "multi_table"]
101+
).pl.show()
102+
103+
def test_plot_label_categorical_color(self, sdata_blobs: SpatialData):
    """Render ``blobs_labels`` coloured by a categorical ``obs`` column of a freshly added table."""
    element = "blobs_labels"
    n_obs = max(_get_unique_label_values_as_index(sdata_blobs[element]))

    # X is drawn before obs so the shared module-level RNG is consumed in a fixed order.
    x = RNG.normal(size=(n_obs, 10))
    obs = pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"])
    adata = AnnData(x, obs=obs)

    adata.obs["instance_id"] = list(range(adata.n_obs))
    adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
    adata.obs["region"] = element

    sdata_blobs["other_table"] = TableModel.parse(
        adata=adata,
        region_key="region",
        instance_key="instance_id",
        region=element,
    )

    # with pytest.raises(ValueError, match="could not convert string"):
    #     sdata_blobs.pl.render_labels('blobs_labels', color='category').pl.show()
    # The column is cast to categorical dtype on the stored table before rendering.
    as_category = sdata_blobs["other_table"].obs["category"].astype("category")
    sdata_blobs["other_table"].obs["category"] = as_category
    sdata_blobs.pl.render_labels(element, color="category").pl.show()
119+
120+
def test_plot_multiscale_label_categorical_color(self, sdata_blobs: SpatialData):
    """Render ``blobs_multiscale_labels`` coloured by a categorical ``obs`` column of a new table."""
    element = "blobs_multiscale_labels"
    n_obs = max(_get_unique_label_values_as_index(sdata_blobs[element]))

    # X is drawn before obs so the shared module-level RNG is consumed in a fixed order.
    x = RNG.normal(size=(n_obs, 10))
    obs = pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"])
    adata = AnnData(x, obs=obs)

    adata.obs["instance_id"] = list(range(adata.n_obs))
    adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
    adata.obs["region"] = element

    sdata_blobs["other_table"] = TableModel.parse(
        adata=adata,
        region_key="region",
        instance_key="instance_id",
        region=element,
    )

    # The column is cast to categorical dtype on the stored table before rendering.
    as_category = sdata_blobs["other_table"].obs["category"].astype("category")
    sdata_blobs["other_table"].obs["category"] = as_category
    sdata_blobs.pl.render_labels(element, color="category").pl.show()
136+
137+
# def test_plot_multiscale_label_coercable_categorical_color(self, sdata_blobs: SpatialData):
138+
# n_obs = max(_get_unique_label_values_as_index(sdata_blobs["blobs_multiscale_labels"]))
139+
# adata = AnnData(
140+
# RNG.normal(size=(n_obs, 10)), obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"])
141+
# )
142+
# adata.obs["instance_id"] = np.arange(adata.n_obs)
143+
# adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs)
144+
# adata.obs["instance_id"] = list(range(adata.n_obs))
145+
# adata.obs["region"] = "blobs_multiscale_labels"
146+
# table = TableModel.parse(
147+
# adata=adata, region_key="region", instance_key="instance_id", region="blobs_multiscale_labels"
148+
# )
149+
# sdata_blobs["other_table"] = table
150+
#
151+
# sdata_blobs.pl.render_labels("blobs_multiscale_labels", color="category").pl.show()

0 commit comments

Comments
 (0)