Skip to content

Commit 267922a

Browse files
Corrected default 3 channel colormap (#127)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cbd454a commit 267922a

19 files changed

+136
-80
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 82 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import pandas as pd
1212
import scanpy as sc
1313
import spatialdata as sd
14-
import xarray as xr
1514
from anndata import AnnData
1615
from geopandas import GeoDataFrame
1716
from matplotlib import colors
@@ -328,7 +327,6 @@ def _render_images(
328327
fig_params: FigParams,
329328
scalebar_params: ScalebarParams,
330329
legend_params: LegendParams,
331-
# extent: tuple[float, float, float, float] | None = None,
332330
) -> None:
333331
elements = render_params.elements
334332

@@ -346,9 +344,6 @@ def _render_images(
346344
images = [sdata.images[e] for e in elements]
347345

348346
for img in images:
349-
if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
350-
raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
351-
352347
if render_params.channel is None:
353348
channels = img.coords["c"].values
354349
else:
@@ -358,11 +353,8 @@ def _render_images(
358353

359354
n_channels = len(channels)
360355

356+
# True if user gave n cmaps for n channels
361357
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
362-
363-
if not isinstance(render_params.cmap_params, list):
364-
render_params.cmap_params = [render_params.cmap_params] * n_channels
365-
366358
if got_multiple_cmaps:
367359
logger.warning(
368360
"You're blending multiple cmaps. "
@@ -372,102 +364,113 @@ def _render_images(
372364
"Consider using 'palette' instead."
373365
)
374366

375-
if render_params.palette is not None:
376-
logger.warning("Parameter 'palette' is ignored when a 'cmap' is provided.")
367+
# not using got_multiple_cmaps here because of ruff :(
368+
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
369+
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")
370+
371+
# 1) Image has only 1 channel
372+
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
373+
layer = img.sel(c=channels).squeeze()
374+
375+
if render_params.quantiles_for_norm != (None, None):
376+
layer = _normalize(
377+
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
378+
)
379+
380+
if render_params.cmap_params.norm is not None: # type: ignore[attr-defined]
381+
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
377382

378-
for idx, channel in enumerate(channels):
379-
layer = img.sel(c=channel)
383+
if render_params.palette is None:
384+
cmap = render_params.cmap_params.cmap # type: ignore[attr-defined]
385+
else:
386+
cmap = _get_linear_colormap([render_params.palette], "k")[0]
387+
388+
ax.imshow(
389+
layer, # get rid of the channel dimension
390+
cmap=cmap,
391+
alpha=render_params.alpha,
392+
)
393+
394+
# 2) Image has any number of channels but 1
395+
else:
396+
layers = {}
397+
for i, c in enumerate(channels):
398+
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
380399

381400
if render_params.quantiles_for_norm != (None, None):
382-
layer = _normalize(
383-
layer,
401+
layers[c] = _normalize(
402+
layers[c],
384403
pmin=render_params.quantiles_for_norm[0],
385404
pmax=render_params.quantiles_for_norm[1],
386405
clip=True,
387406
)
388407

389-
if render_params.cmap_params[idx].norm is not None:
390-
layer = render_params.cmap_params[idx].norm(layer)
408+
if not isinstance(render_params.cmap_params, list):
409+
if render_params.cmap_params.norm is not None:
410+
layers[c] = render_params.cmap_params.norm(layers[c])
411+
else:
412+
if render_params.cmap_params[i].norm is not None:
413+
layers[c] = render_params.cmap_params[i].norm(layers[c])
391414

392-
ax.imshow(
393-
layer,
394-
cmap=render_params.cmap_params[idx].cmap,
395-
alpha=(1 / n_channels),
396-
)
397-
break
415+
# 2A) Image has 3 channels, no palette/cmap info -> use RGB
416+
if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps:
417+
ax.imshow(np.stack([layers[c] for c in channels], axis=-1), alpha=render_params.alpha)
398418

399-
if n_channels == 1:
400-
layer = img.sel(c=channels)
419+
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
420+
elif render_params.palette is None and not got_multiple_cmaps:
421+
# overwrite if n_channels == 2 for intuitive result
422+
if n_channels == 2:
423+
seed_colors = ["#ff0000ff", "#00ff00ff"]
424+
else:
425+
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
401426

402-
if render_params.quantiles_for_norm != (None, None):
403-
layer = _normalize(
404-
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
405-
)
427+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
406428

407-
if render_params.cmap_params[0].norm is not None:
408-
layer = render_params.cmap_params[0].norm(layer)
429+
# Apply cmaps to each channel and add up
430+
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
409431

410-
if render_params.palette is None:
411-
ax.imshow(
412-
layer.squeeze(), # get rid of the channel dimension
413-
cmap=render_params.cmap_params[0].cmap,
414-
)
432+
# Remove alpha channel so we can overwrite it from render_params.alpha
433+
colored = colored[:, :, :3]
415434

416-
else:
417435
ax.imshow(
418-
layer.squeeze(), # get rid of the channel dimension
419-
cmap=_get_linear_colormap([render_params.palette], "k")[0],
436+
colored,
437+
alpha=render_params.alpha,
420438
)
421439

422-
break
440+
# 2C) Image has n channels and palette info
441+
elif render_params.palette is not None and not got_multiple_cmaps:
442+
if len(render_params.palette) != n_channels:
443+
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
423444

424-
if render_params.palette is not None and n_channels != len(render_params.palette):
425-
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
445+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in render_params.palette]
426446

427-
if n_channels > 1:
428-
layer = img.sel(c=channels).copy(deep=True)
447+
# Apply cmaps to each channel and add up
448+
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
429449

430-
channel_colors: list[str] | Any
431-
if render_params.palette is None:
432-
channel_colors = _get_colors_for_categorical_obs(
433-
layer.coords["c"].values.tolist(), palette=render_params.cmap_params[0].cmap
450+
# Remove alpha channel so we can overwrite it from render_params.alpha
451+
colored = colored[:, :, :3]
452+
453+
ax.imshow(
454+
colored,
455+
alpha=render_params.alpha,
434456
)
435-
else:
436-
channel_colors = render_params.palette
437457

438-
channel_cmaps = _get_linear_colormap([str(c) for c in channel_colors[:n_channels]], "k")
458+
elif render_params.palette is None and got_multiple_cmaps:
459+
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
439460

440-
layer_vals = []
441-
if render_params.quantiles_for_norm != (None, None):
442-
for i in range(n_channels):
443-
layer_vals.append(
444-
_normalize(
445-
layer.values[i],
446-
pmin=render_params.quantiles_for_norm[0],
447-
pmax=render_params.quantiles_for_norm[1],
448-
clip=True,
449-
)
450-
)
461+
# Apply cmaps to each channel, add up and normalize to [0, 1]
462+
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) / n_channels
451463

452-
colored = np.stack([channel_cmaps[i](layer_vals[i]) for i in range(n_channels)], 0).sum(0)
464+
# Remove alpha channel so we can overwrite it from render_params.alpha
465+
colored = colored[:, :, :3]
453466

454-
layer = xr.DataArray(
455-
data=colored,
456-
coords=[
457-
layer.coords["y"],
458-
layer.coords["x"],
459-
["R", "G", "B", "A"],
460-
],
461-
dims=["y", "x", "c"],
462-
)
463-
layer = layer.transpose("y", "x", "c") # for plotting
467+
ax.imshow(
468+
colored,
469+
alpha=render_params.alpha,
470+
)
464471

465-
ax.imshow(
466-
layer.data,
467-
cmap=channel_cmaps[0],
468-
alpha=render_params.alpha,
469-
norm=render_params.cmap_params[0].norm,
470-
)
472+
elif render_params.palette is not None and got_multiple_cmaps:
473+
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
471474

472475

473476
@dataclass
Loading
Loading
Loading
Binary file not shown.
Loading
Loading
Loading
1.95 KB
Loading
Loading

tests/_images/Labels_labels.png

-14 KB
Binary file not shown.
Loading
Loading
Loading
Loading

tests/_images/Points_points.png

-7.03 KB
Binary file not shown.
883 Bytes
Loading

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from shapely.geometry import MultiPolygon, Polygon
1919
from spatial_image import SpatialImage
2020
from spatialdata import SpatialData
21-
from spatialdata.datasets import blobs
21+
from spatialdata.datasets import blobs, raccoon
2222
from spatialdata.models import (
2323
Image2DModel,
2424
Image3DModel,
@@ -56,6 +56,11 @@ def sdata_blobs() -> SpatialData:
5656
return blobs()
5757

5858

59+
@pytest.fixture()
60+
def sdata_raccoon() -> SpatialData:
61+
return raccoon()
62+
63+
5964
@pytest.fixture
6065
def test_sdata_single_image():
6166
"""Creates a simple sdata object."""

tests/pl/test_upstream_plots.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import matplotlib
2+
import matplotlib.pyplot as plt
3+
import scanpy as sc
4+
import spatialdata_plot # noqa: F401
5+
from spatialdata import SpatialData
6+
from spatialdata.transformations import (
7+
MapAxis,
8+
Scale,
9+
set_transformation,
10+
)
11+
12+
from tests.conftest import PlotTester, PlotTesterMeta
13+
14+
sc.pl.set_rcParams_defaults()
15+
sc.set_figure_params(dpi=40, color_map="viridis")
16+
matplotlib.use("agg") # same as GitHub action runner
17+
_ = spatialdata_plot
18+
19+
# WARNING:
20+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
21+
# 2. tests which produce a plot must be prefixed with `test_plot_`
22+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
23+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
24+
# ".png" is appended to <your_filename>, no need to set it
25+
26+
27+
class TestNotebooksTransformations(PlotTester, metaclass=PlotTesterMeta):
28+
def test_plot_can_render_transformations_raccoon_split(self, sdata_raccoon: SpatialData):
29+
_, axs = plt.subplots(ncols=3, figsize=(12, 3))
30+
31+
sdata_raccoon.pl.render_images().pl.show(ax=axs[0])
32+
sdata_raccoon.pl.render_labels().pl.show(ax=axs[1])
33+
sdata_raccoon.pl.render_shapes().pl.show(ax=axs[2])
34+
35+
def test_plot_can_render_transformations_raccoon_overlay(self, sdata_raccoon: SpatialData):
36+
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()
37+
38+
def test_plot_can_render_transformations_raccoon_scale(self, sdata_raccoon: SpatialData):
39+
scale = Scale([2.0], axes=("x",))
40+
set_transformation(sdata_raccoon.images["raccoon"], scale, to_coordinate_system="global")
41+
42+
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()
43+
44+
def test_plot_can_render_transformations_raccoon_mapaxis(self, sdata_raccoon: SpatialData):
45+
map_axis = MapAxis({"x": "y", "y": "x"})
46+
set_transformation(sdata_raccoon.images["raccoon"], map_axis, to_coordinate_system="global")
47+
48+
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()

0 commit comments

Comments
 (0)