Skip to content

Corrected default 3 channel colormap #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 82 additions & 79 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pandas as pd
import scanpy as sc
import spatialdata as sd
import xarray as xr
from anndata import AnnData
from geopandas import GeoDataFrame
from matplotlib import colors
Expand Down Expand Up @@ -328,7 +327,6 @@ def _render_images(
fig_params: FigParams,
scalebar_params: ScalebarParams,
legend_params: LegendParams,
# extent: tuple[float, float, float, float] | None = None,
) -> None:
elements = render_params.elements

Expand All @@ -346,9 +344,6 @@ def _render_images(
images = [sdata.images[e] for e in elements]

for img in images:
if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")

if render_params.channel is None:
channels = img.coords["c"].values
else:
Expand All @@ -358,11 +353,8 @@ def _render_images(

n_channels = len(channels)

# True if user gave n cmaps for n channels
got_multiple_cmaps = isinstance(render_params.cmap_params, list)

if not isinstance(render_params.cmap_params, list):
render_params.cmap_params = [render_params.cmap_params] * n_channels

if got_multiple_cmaps:
logger.warning(
"You're blending multiple cmaps. "
Expand All @@ -372,102 +364,113 @@ def _render_images(
"Consider using 'palette' instead."
)

if render_params.palette is not None:
logger.warning("Parameter 'palette' is ignored when a 'cmap' is provided.")
# not using got_multiple_cmaps here because of ruff :(
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")

# 1) Image has only 1 channel
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
layer = img.sel(c=channels).squeeze()

if render_params.quantiles_for_norm != (None, None):
layer = _normalize(
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
)

if render_params.cmap_params.norm is not None: # type: ignore[attr-defined]
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]

for idx, channel in enumerate(channels):
layer = img.sel(c=channel)
if render_params.palette is None:
cmap = render_params.cmap_params.cmap # type: ignore[attr-defined]
else:
cmap = _get_linear_colormap([render_params.palette], "k")[0]

ax.imshow(
layer, # get rid of the channel dimension
cmap=cmap,
alpha=render_params.alpha,
)

# 2) Image has any number of channels but 1
else:
layers = {}
for i, c in enumerate(channels):
layers[c] = img.sel(c=c).copy(deep=True).squeeze()

if render_params.quantiles_for_norm != (None, None):
layer = _normalize(
layer,
layers[c] = _normalize(
layers[c],
pmin=render_params.quantiles_for_norm[0],
pmax=render_params.quantiles_for_norm[1],
clip=True,
)

if render_params.cmap_params[idx].norm is not None:
layer = render_params.cmap_params[idx].norm(layer)
if not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.norm is not None:
layers[c] = render_params.cmap_params.norm(layers[c])
else:
if render_params.cmap_params[i].norm is not None:
layers[c] = render_params.cmap_params[i].norm(layers[c])

ax.imshow(
layer,
cmap=render_params.cmap_params[idx].cmap,
alpha=(1 / n_channels),
)
break
# 2A) Image has 3 channels, no palette/cmap info -> use RGB
if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps:
ax.imshow(np.stack([layers[c] for c in channels], axis=-1), alpha=render_params.alpha)

if n_channels == 1:
layer = img.sel(c=channels)
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
elif render_params.palette is None and not got_multiple_cmaps:
# overwrite if n_channels == 2 for intuitive result
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
else:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))

if render_params.quantiles_for_norm != (None, None):
layer = _normalize(
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
)
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

if render_params.cmap_params[0].norm is not None:
layer = render_params.cmap_params[0].norm(layer)
# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)

if render_params.palette is None:
ax.imshow(
layer.squeeze(), # get rid of the channel dimension
cmap=render_params.cmap_params[0].cmap,
)
# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

else:
ax.imshow(
layer.squeeze(), # get rid of the channel dimension
cmap=_get_linear_colormap([render_params.palette], "k")[0],
colored,
alpha=render_params.alpha,
)

break
# 2C) Image has n channels and palette info
elif render_params.palette is not None and not got_multiple_cmaps:
if len(render_params.palette) != n_channels:
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

if render_params.palette is not None and n_channels != len(render_params.palette):
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in render_params.palette]

if n_channels > 1:
layer = img.sel(c=channels).copy(deep=True)
# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)

channel_colors: list[str] | Any
if render_params.palette is None:
channel_colors = _get_colors_for_categorical_obs(
layer.coords["c"].values.tolist(), palette=render_params.cmap_params[0].cmap
# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

ax.imshow(
colored,
alpha=render_params.alpha,
)
else:
channel_colors = render_params.palette

channel_cmaps = _get_linear_colormap([str(c) for c in channel_colors[:n_channels]], "k")
elif render_params.palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]

layer_vals = []
if render_params.quantiles_for_norm != (None, None):
for i in range(n_channels):
layer_vals.append(
_normalize(
layer.values[i],
pmin=render_params.quantiles_for_norm[0],
pmax=render_params.quantiles_for_norm[1],
clip=True,
)
)
# Apply cmaps to each channel, add up and normalize to [0, 1]
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) / n_channels

colored = np.stack([channel_cmaps[i](layer_vals[i]) for i in range(n_channels)], 0).sum(0)
# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

layer = xr.DataArray(
data=colored,
coords=[
layer.coords["y"],
layer.coords["x"],
["R", "G", "B", "A"],
],
dims=["y", "x", "c"],
)
layer = layer.transpose("y", "x", "c") # for plotting
ax.imshow(
colored,
alpha=render_params.alpha,
)

ax.imshow(
layer.data,
cmap=channel_cmaps[0],
alpha=render_params.alpha,
norm=render_params.cmap_params[0].norm,
)
elif render_params.palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")


@dataclass
Expand Down
Binary file modified tests/_images/Extent_extent_of_img_full_canvas.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Extent_extent_of_partial_canvas_on_full_canvas.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_cmap_to_each_channel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_cmap_to_render_images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_render_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_render_two_channels_from_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed tests/_images/Labels_labels.png
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed tests/_images/Points_points.png
Binary file not shown.
Binary file modified tests/_images/Show_pad_extent_adds_padding.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from shapely.geometry import MultiPolygon, Polygon
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata.datasets import blobs
from spatialdata.datasets import blobs, raccoon
from spatialdata.models import (
Image2DModel,
Image3DModel,
Expand Down Expand Up @@ -56,6 +56,11 @@ def sdata_blobs() -> SpatialData:
return blobs()


@pytest.fixture()
def sdata_raccoon() -> SpatialData:
return raccoon()


@pytest.fixture
def test_sdata_single_image():
"""Creates a simple sdata object."""
Expand Down
48 changes: 48 additions & 0 deletions tests/pl/test_upstream_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import matplotlib
import matplotlib.pyplot as plt
import scanpy as sc
import spatialdata_plot # noqa: F401
from spatialdata import SpatialData
from spatialdata.transformations import (
MapAxis,
Scale,
set_transformation,
)

from tests.conftest import PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=40, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestNotebooksTransformations(PlotTester, metaclass=PlotTesterMeta):
def test_plot_can_render_transformations_raccoon_split(self, sdata_raccoon: SpatialData):
_, axs = plt.subplots(ncols=3, figsize=(12, 3))

sdata_raccoon.pl.render_images().pl.show(ax=axs[0])
sdata_raccoon.pl.render_labels().pl.show(ax=axs[1])
sdata_raccoon.pl.render_shapes().pl.show(ax=axs[2])

def test_plot_can_render_transformations_raccoon_overlay(self, sdata_raccoon: SpatialData):
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()

def test_plot_can_render_transformations_raccoon_scale(self, sdata_raccoon: SpatialData):
scale = Scale([2.0], axes=("x",))
set_transformation(sdata_raccoon.images["raccoon"], scale, to_coordinate_system="global")

sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()

def test_plot_can_render_transformations_raccoon_mapaxis(self, sdata_raccoon: SpatialData):
map_axis = MapAxis({"x": "y", "y": "x"})
set_transformation(sdata_raccoon.images["raccoon"], map_axis, to_coordinate_system="global")

sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()