Skip to content

Coloring labels by a continuous variable fixed #344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.2.6] - tbd

### Added

-

### Changed

- Lowered the RMSE threshold for plot-based tests from 45 to 15 (#344)
- When subsetting to `groups`, `NA` isn't automatically added to the legend (#344)

### Fixed

- Filtering with `groups` now preserves original cmap (#344)
- Non-selected `groups` are no longer shown in `na_color` (#344)

## [0.2.5] - 2024-08-23

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
legend_fontweight=legend_params.legend_fontweight,
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend,
na_in_legend=legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector)),
colorbar=legend_params.colorbar,
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
Expand Down
68 changes: 27 additions & 41 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,6 @@ def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int =
Union[plt.Figure, plt.Axes]
Matplotlib figure and axes object.
"""
# if num_images <= 1:
# raise ValueError("Number of images must be greater than 1.")

if num_images < ncols:
nrows = 1
ncols = num_images
Expand Down Expand Up @@ -733,8 +730,6 @@ def _set_color_source_vec(
color = np.full(len(element), na_color)
return color, color, False

# model = get_model(sdata[element_name])

# Figure out where to get the color from
origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)

Expand Down Expand Up @@ -778,16 +773,13 @@ def _set_color_source_vec(
palette=palette,
na_color=na_color,
)

color_source_vector = color_source_vector.set_categories(color_mapping.keys())
if color_mapping is None:
raise ValueError("Unable to create color palette.")

# do not rename categories, as colors need not be unique
color_vector = color_source_vector.map(color_mapping)
if color_vector.isna().any():
if (na_cat_color := to_hex(na_color)) not in color_vector.categories:
color_vector = color_vector.add_categories([na_cat_color])
color_vector = color_vector.fillna(to_hex(na_color))

return color_source_vector, color_vector, True

Expand All @@ -808,44 +800,43 @@ def _map_color_seg(
seg_boundaries: bool = False,
) -> ArrayLike:
cell_id = np.array(cell_id)
if color_vector is not None and isinstance(color_vector.dtype, pd.CategoricalDtype):
# user wants to plot a categorical column

if pd.api.types.is_categorical_dtype(color_vector.dtype):
# Case A: user wants to plot a categorical column
if np.any(color_source_vector.isna()):
cell_id[color_source_vector.isna()] = 0
val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1)
val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, is the copy really required here? I thought that `map_array` returns a new array with the same shape as the original array, so it would not share memory with it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I remember. Don't ask me why, but if I don't copy(), one of the tests fails with:

ValueError: buffer source array is read-only

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same issue with the current main: whenever I tried to color labels by a column in table.obs, I'd get ValueError: buffer source array is read-only, and putting in those .copy() calls fixed it. This happened regardless of whether the SpatialData object was buffered to disk or not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, this is probably a bug in skimage actually. For now we could keep it as it is, but please add a comment with the issue so we remove this when we can.

cols = colors.to_rgba_array(color_vector.categories)

elif pd.api.types.is_numeric_dtype(color_vector.dtype):
# user wants to plot a continuous column
# Case B: user wants to plot a continuous column
if isinstance(color_vector, pd.Series):
color_vector = color_vector.to_numpy()
val_im = map_array(seg, cell_id, color_vector)
cols = cmap_params.cmap(cmap_params.norm(color_vector))

val_im = map_array(seg.copy(), cell_id, cell_id)
else:
val_im = map_array(seg.copy(), cell_id, cell_id) # replace with same seg id to remove missing segs

if val_im.shape[0] == 1:
val_im = np.squeeze(val_im, axis=0)
if "#" in str(color_vector[0]):
# we have hex colors
assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
cols = colors.to_rgba_array(color_vector)
# Case C: User didn't specify any colors
if color_source_vector is not None and (
set(color_vector) == set(color_source_vector)
and len(set(color_vector)) == 1
and set(color_vector) == {na_color}
and not na_color_modified_by_user
):
val_im = map_array(seg.copy(), cell_id, cell_id)
RNG = default_rng(42)
cols = RNG.random((len(color_vector), 3))
else:
cols = cmap_params.cmap(cmap_params.norm(color_vector))
# Case D: User didn't specify a column to color by, but modified the na_color
val_im = map_array(seg.copy(), cell_id, cell_id)
if "#" in str(color_vector[0]):
# we have hex colors
assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
cols = colors.to_rgba_array(color_vector)
else:
cols = cmap_params.cmap(cmap_params.norm(color_vector))

if seg_erosionpx is not None:
val_im[val_im == erosion(val_im, square(seg_erosionpx))] = 0

if color_source_vector is not None and (
set(color_vector) == set(color_source_vector)
and len(set(color_vector)) == 1
and set(color_vector) == {na_color}
and not na_color_modified_by_user
):
RNG = default_rng(42)
cols = RNG.random((len(cols), 3))

seg_im: ArrayLike = label2rgb(
label=val_im,
colors=cols,
Expand Down Expand Up @@ -948,7 +939,7 @@ def _get_categorical_color_mapping(
else:
base_mapping = _generate_base_categorial_color_mapping(adata, cluster_key, color_source_vector, na_color)

return _modify_categorical_color_mapping(base_mapping, groups, palette)
return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette)


def _maybe_set_colors(
Expand Down Expand Up @@ -1587,19 +1578,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st

palette = param_dict["palette"]

if (groups := param_dict.get("groups")) is not None and palette is None:
warnings.warn(
"Groups is specified but palette is not. Setting palette to default 'lightgray'", UserWarning, stacklevel=2
)
param_dict["palette"] = ["lightgray" for _ in range(len(groups))]

if isinstance((palette := param_dict["palette"]), list):
if not all(isinstance(p, str) for p in palette):
raise ValueError("If specified, parameter 'palette' must contain only strings.")
elif isinstance(palette, (str, type(None))) and "palette" in param_dict:
param_dict["palette"] = [palette] if palette is not None else None

if element_type in ["shapes", "points", "labels"] and (palette := param_dict.get("palette")) is not None:
groups = param_dict.get("groups")
if groups is None:
raise ValueError("When specifying 'palette', 'groups' must also be specified.")
if len(groups) != len(palette):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_color_labels_by_continuous_variable.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_control_label_infill.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_control_label_outline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_render_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_render_multiscale_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_stack_render_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Labels_can_stop_rasterization_with_scale_full.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed tests/_images/Points_can_filter_with_groups.png
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Points_can_filter_with_groups_default_palette.png
Binary file modified tests/_images/Shapes_can_filter_with_groups.png
Binary file modified tests/_images/Utils_can_set_zero_in_cmap_to_transparent.png
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

EXPECTED = HERE / "_images"
ACTUAL = HERE / "figures"
TOL = 45
TOL = 15
DPI = 80

RNG = np.random.default_rng(seed=42)
Expand Down
42 changes: 28 additions & 14 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import pandas as pd
import pytest
import scanpy as sc
from anndata import AnnData
from spatial_image import to_spatial_image
from spatialdata import SpatialData, deepcopy, get_element_instances
from spatialdata import SpatialData, deepcopy
from spatialdata.models import TableModel

import spatialdata_plot # noqa: F401
Expand Down Expand Up @@ -82,26 +81,31 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()

def test_plot_can_color_labels_by_categorical_variable(self, sdata_blobs: SpatialData):
max_col = sdata_blobs.table.to_df().idxmax(axis=1)
max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True)
sdata_blobs.table.obs["which_max"] = max_col

sdata_blobs.pl.render_labels("blobs_labels", color="which_max").pl.show()

@pytest.mark.parametrize(
"label",
[
"blobs_labels",
"blobs_multiscale_labels",
],
)
def test_plot_can_color_labels_by_categorical_variable(self, sdata_blobs: SpatialData, label: str):
def test_plot_can_color_labels_by_categorical_variable_in_other_table(self, sdata_blobs: SpatialData, label: str):

def _make_tablemodel_with_categorical_labels(sdata_blobs, label):

n_obs = len(get_element_instances(sdata_blobs[label]))
vals = np.arange(n_obs) + 1
adata = AnnData(vals.reshape(-1, 1), obs=pd.DataFrame({"instance_id": vals}))
adata.obs["category"] = pd.Categorical(
list(["a", "b", "c"] * ((n_obs // 3) + 1))[:n_obs],
categories=["a", "b", "c"],
ordered=True,
)
adata = sdata_blobs.tables["table"].copy()
max_col = adata.to_df().idxmax(axis=1)
max_col = max_col.str.replace("channel_", "ch").str.replace("_sum", "")
max_col = pd.Categorical(max_col, categories=set(max_col), ordered=True)
adata.obs["which_max"] = max_col
adata.obs["region"] = label
del adata.uns["spatialdata_attrs"]
table = TableModel.parse(
adata=adata,
region_key="region",
Expand All @@ -110,7 +114,17 @@ def _make_tablemodel_with_categorical_labels(sdata_blobs, label):
)
sdata_blobs.tables["other_table"] = table

sdata_blobs.pl.render_labels(label, color="category", table="other_table", scale="scale0").pl.show()
_, axs = plt.subplots(nrows=1, ncols=3, layout="tight")

sdata_blobs.pl.render_labels(label, color="channel_1_sum", table="other_table", scale="scale0").pl.show(
ax=axs[0], title="ch_1_sum", colorbar=False
)
sdata_blobs.pl.render_labels(label, color="channel_2_sum", table="other_table", scale="scale0").pl.show(
ax=axs[1], title="ch_2_sum", colorbar=False
)
sdata_blobs.pl.render_labels(label, color="which_max", table="other_table", scale="scale0").pl.show(
ax=axs[2], legend_fontsize=6
)

# we're modifying the data here, so we need an independent copy
sdata_blobs_local = deepcopy(sdata_blobs)
Expand Down Expand Up @@ -176,7 +190,7 @@ def test_plot_subset_categorical_label_maintains_order(self, sdata_blobs: Spatia

_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs.pl.render_labels("blobs_labels", color="which_max").pl.show(ax=axs[0])
sdata_blobs.pl.render_labels("blobs_labels", color="which_max").pl.show(ax=axs[0], legend_fontsize=6)
sdata_blobs.pl.render_labels(
"blobs_labels",
color="which_max",
Expand All @@ -190,7 +204,7 @@ def test_plot_subset_categorical_label_maintains_order_when_palette_overwrite(se

_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs.pl.render_labels("blobs_labels", color="which_max").pl.show(ax=axs[0])
sdata_blobs.pl.render_labels("blobs_labels", color="which_max").pl.show(ax=axs[0], legend_fontsize=6)
sdata_blobs.pl.render_labels(
"blobs_labels", color="which_max", groups=["channel_0_sum"], palette="red"
).pl.show(ax=axs[1])
19 changes: 15 additions & 4 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
Expand Down Expand Up @@ -27,15 +28,25 @@ class TestPoints(PlotTester, metaclass=PlotTesterMeta):
def test_plot_can_render_points(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(element="blobs_points").pl.show()

def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
def test_plot_can_filter_with_groups_default_palette(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"
sdata_blobs.pl.render_points(color="genes", groups="gene_b", palette="red").pl.show()

def test_plot_can_filter_with_groups_default_palette(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(color="genes", size=10).pl.show(ax=axs[0], legend_fontsize=6)
sdata_blobs.pl.render_points(color="genes", groups="gene_b", size=10).pl.show(ax=axs[1], legend_fontsize=6)

def test_plot_can_filter_with_groups_custom_palette(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points"
sdata_blobs.pl.render_points(color="genes", groups="gene_b").pl.show()

sdata_blobs.pl.render_points(color="genes", size=10).pl.show(ax=axs[0], legend_fontsize=6)
sdata_blobs.pl.render_points(color="genes", groups="gene_b", size=10, palette="red").pl.show(
ax=axs[1], legend_fontsize=6
)

def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_points"] * sdata_blobs["table"].n_obs
Expand Down
8 changes: 7 additions & 1 deletion tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anndata
import geopandas as gpd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
Expand Down Expand Up @@ -107,6 +108,8 @@ def test_plot_can_scale_shapes(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", scale=0.5).pl.show()

def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
Expand All @@ -115,7 +118,10 @@ def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
"category"
)

sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups="c1").pl.show()
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster").pl.show(ax=axs[0], legend_fontsize=6)
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups="c1").pl.show(
ax=axs[1], legend_fontsize=6
)

def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
Expand Down
22 changes: 10 additions & 12 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,23 @@ def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData
from spatialdata_plot.pl.utils import set_zero_in_cmap_to_transparent

# set up figure and modify the data to add 0s
fig, axs = plt.subplots(ncols=2, figsize=(6, 3))
table = sdata_blobs.table.copy()
x = table.X.todense()
x[:10, 0] = 0
table.X = x
sdata_blobs.tables["modified_table"] = table
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
sdata_blobs.tables["table"].obs["my_var"] = list(range(len(sdata_blobs.tables["table"].obs)))
sdata_blobs.tables["table"].obs["my_var"] += 2 # shift the values to not have 0s

# create a new cmap with 0 as transparent
new_cmap = set_zero_in_cmap_to_transparent(cmap="plasma")
new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis")

# baseline img
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", cmap="viridis", table="table").pl.show(
sdata_blobs.pl.render_labels("blobs_labels", color="my_var", cmap="viridis", table="table").pl.show(
ax=axs[0], colorbar=False
)

sdata_blobs.tables["table"].obs.iloc[8:12, 2] = 0

# image with 0s as transparent, so some labels are "missing"
sdata_blobs.pl.render_labels(
"blobs_labels", color="channel_0_sum", cmap=new_cmap, table="modified_table"
).pl.show(ax=axs[1], colorbar=False)
sdata_blobs.pl.render_labels("blobs_labels", color="my_var", cmap=new_cmap, table="table").pl.show(
ax=axs[1], colorbar=False
)


@pytest.mark.parametrize(
Expand Down
Loading