Skip to content

Commit 6b83753

Browse files
committed
refactored render_images
1 parent ce26519 commit 6b83753

19 files changed

+140
-78
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 86 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import pandas as pd
1212
import scanpy as sc
1313
import spatialdata as sd
14-
import xarray as xr
1514
from anndata import AnnData
1615
from geopandas import GeoDataFrame
1716
from matplotlib import colors
@@ -29,6 +28,7 @@
2928
OutlineParams,
3029
ScalebarParams,
3130
_decorate_axs,
31+
_get_colors_for_categorical_obs,
3232
_get_linear_colormap,
3333
_map_color_seg,
3434
_maybe_set_colors,
@@ -327,7 +327,6 @@ def _render_images(
327327
fig_params: FigParams,
328328
scalebar_params: ScalebarParams,
329329
legend_params: LegendParams,
330-
# extent: tuple[float, float, float, float] | None = None,
331330
) -> None:
332331
elements = render_params.elements
333332

@@ -345,8 +344,8 @@ def _render_images(
345344
images = [sdata.images[e] for e in elements]
346345

347346
for img in images:
348-
if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
349-
raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
347+
# if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
348+
# raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
350349

351350
if render_params.channel is None:
352351
channels = img.coords["c"].values
@@ -357,11 +356,8 @@ def _render_images(
357356

358357
n_channels = len(channels)
359358

359+
# True if user gave n cmaps for n channels
360360
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
361-
362-
if not isinstance(render_params.cmap_params, list):
363-
render_params.cmap_params = [render_params.cmap_params] * n_channels
364-
365361
if got_multiple_cmaps:
366362
logger.warning(
367363
"You're blending multiple cmaps. "
@@ -371,100 +367,113 @@ def _render_images(
371367
"Consider using 'palette' instead."
372368
)
373369

374-
if render_params.palette is not None:
375-
logger.warning("Parameter 'palette' is ignored when a 'cmap' is provided.")
370+
# not using got_multiple_cmaps here because of ruff :(
371+
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
372+
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")
373+
374+
# 1) Image has only 1 channel
375+
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
376+
layer = img.sel(c=channels).squeeze()
377+
378+
if render_params.quantiles_for_norm != (None, None):
379+
layer = _normalize(
380+
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
381+
)
382+
383+
if render_params.cmap_params.norm is not None: # type: ignore[attr-defined]
384+
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
385+
386+
if render_params.palette is None:
387+
cmap = render_params.cmap_params.cmap # type: ignore[attr-defined]
388+
else:
389+
cmap = _get_linear_colormap([render_params.palette], "k")[0]
390+
391+
ax.imshow(
392+
layer, # get rid of the channel dimension
393+
cmap=cmap,
394+
alpha=render_params.alpha,
395+
)
376396

377-
for idx, channel in enumerate(channels):
378-
layer = img.sel(c=channel)
397+
# 2) Image has any number of channels but 1
398+
else:
399+
layers = {}
400+
for i, c in enumerate(channels):
401+
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
379402

380403
if render_params.quantiles_for_norm != (None, None):
381-
layer = _normalize(
382-
layer,
404+
layers[c] = _normalize(
405+
layers[c],
383406
pmin=render_params.quantiles_for_norm[0],
384407
pmax=render_params.quantiles_for_norm[1],
385408
clip=True,
386409
)
387410

388-
if render_params.cmap_params[idx].norm is not None:
389-
layer = render_params.cmap_params[idx].norm(layer)
411+
if not isinstance(render_params.cmap_params, list):
412+
if render_params.cmap_params.norm is not None:
413+
layers[c] = render_params.cmap_params.norm(layers[c])
414+
else:
415+
if render_params.cmap_params[i].norm is not None:
416+
layers[c] = render_params.cmap_params[i].norm(layers[c])
390417

391-
ax.imshow(
392-
layer,
393-
cmap=render_params.cmap_params[idx].cmap,
394-
alpha=(1 / n_channels),
395-
)
396-
break
418+
# 2A) Image has 3 channels, no palette/cmap info -> use RGB
419+
if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps:
420+
ax.imshow(np.stack([layers[c] for c in channels], axis=-1), alpha=render_params.alpha)
397421

398-
if n_channels == 1:
399-
layer = img.sel(c=channels)
422+
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
423+
elif render_params.palette is None and not got_multiple_cmaps:
424+
# overwrite if n_channels == 2 for intuitive result
425+
if n_channels == 2:
426+
seed_colors = ["#ff0000ff", "#00ff00ff"]
427+
else:
428+
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
400429

401-
if render_params.quantiles_for_norm != (None, None):
402-
layer = _normalize(
403-
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
404-
)
430+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
405431

406-
if render_params.cmap_params[0].norm is not None:
407-
layer = render_params.cmap_params[0].norm(layer)
432+
# Apply cmaps to each channel and add up
433+
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
408434

409-
if render_params.palette is None:
410-
ax.imshow(
411-
layer.squeeze(), # get rid of the channel dimension
412-
cmap=render_params.cmap_params[0].cmap,
413-
)
435+
# Remove alpha channel so we can overwrite it from render_params.alpha
436+
colored = colored[:, :, :3]
414437

415-
else:
416438
ax.imshow(
417-
layer.squeeze(), # get rid of the channel dimension
418-
cmap=_get_linear_colormap([render_params.palette], "k")[0],
439+
colored,
440+
alpha=render_params.alpha,
419441
)
420442

421-
break
443+
# 2C) Image has n channels and palette info
444+
elif render_params.palette is not None and not got_multiple_cmaps:
445+
if len(render_params.palette) != n_channels:
446+
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
422447

423-
if render_params.palette is not None and n_channels != len(render_params.palette):
424-
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
448+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in render_params.palette]
425449

426-
if n_channels > 1: # to capture n_channels = 3 and custom number cases
427-
layer = img.sel(c=channels).copy(deep=True)
450+
# Apply cmaps to each channel and add up
451+
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
428452

429-
channel_colors: list[str] | Any
430-
if render_params.palette is None:
431-
channel_colors = ["#ff0000ff", "#00ff00ff", "#0000ffff"]
432-
else:
433-
channel_colors = render_params.palette
453+
# Remove alpha channel so we can overwrite it from render_params.alpha
454+
colored = colored[:, :, :3]
455+
456+
ax.imshow(
457+
colored,
458+
alpha=render_params.alpha,
459+
)
434460

435-
channel_cmaps = _get_linear_colormap([str(c) for c in channel_colors[:n_channels]], "k")
461+
elif render_params.palette is None and got_multiple_cmaps:
462+
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
436463

437-
layer_vals = []
438-
if render_params.quantiles_for_norm != (None, None):
439-
for i in range(n_channels):
440-
layer_vals.append(
441-
_normalize(
442-
layer.values[i],
443-
pmin=render_params.quantiles_for_norm[0],
444-
pmax=render_params.quantiles_for_norm[1],
445-
clip=True,
446-
)
447-
)
464+
# Apply cmaps to each channel, add up and normalize to [0, 1]
465+
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) / n_channels
448466

449-
colored = np.stack([channel_cmaps[i](layer_vals[i]) for i in range(n_channels)], 0).sum(0)
467+
# Remove alpha channel so we can overwrite it from render_params.alpha
468+
colored = colored[:, :, :3]
450469

451-
layer = xr.DataArray(
452-
data=colored,
453-
coords=[
454-
layer.coords["y"],
455-
layer.coords["x"],
456-
["R", "G", "B", "A"],
457-
],
458-
dims=["y", "x", "c"],
459-
)
460-
layer = layer.transpose("y", "x", "c") # for plotting
470+
ax.imshow(
471+
colored,
472+
alpha=render_params.alpha,
473+
)
461474

462-
ax.imshow(
463-
layer.data,
464-
cmap=channel_cmaps[0],
465-
alpha=render_params.alpha,
466-
norm=render_params.cmap_params[0].norm,
467-
)
475+
elif render_params.palette is not None and got_multiple_cmaps:
476+
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
468477

469478

470479
@dataclass
Loading
Loading
Loading
Binary file not shown.
Loading
Loading
Loading
1.95 KB
Loading
Loading

tests/_images/Labels_labels.png

-14 KB
Binary file not shown.
Loading
Loading
Loading
Loading

tests/_images/Points_points.png

-7.03 KB
Binary file not shown.
883 Bytes
Loading

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from shapely.geometry import MultiPolygon, Polygon
1919
from spatial_image import SpatialImage
2020
from spatialdata import SpatialData
21-
from spatialdata.datasets import blobs
21+
from spatialdata.datasets import blobs, raccoon
2222
from spatialdata.models import (
2323
Image2DModel,
2424
Image3DModel,
@@ -56,6 +56,11 @@ def sdata_blobs() -> SpatialData:
5656
return blobs()
5757

5858

59+
@pytest.fixture()
60+
def sdata_raccoon() -> SpatialData:
61+
return raccoon()
62+
63+
5964
@pytest.fixture
6065
def test_sdata_single_image():
6166
"""Creates a simple sdata object."""

tests/pl/test_upstream_plots.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import matplotlib
2+
import matplotlib.pyplot as plt
3+
import scanpy as sc
4+
import spatialdata_plot # noqa: F401
5+
from spatialdata import SpatialData
6+
from spatialdata.transformations import (
7+
MapAxis,
8+
Scale,
9+
set_transformation,
10+
)
11+
12+
from tests.conftest import PlotTester, PlotTesterMeta
13+
14+
sc.pl.set_rcParams_defaults()
15+
sc.set_figure_params(dpi=40, color_map="viridis")
16+
matplotlib.use("agg") # same as GitHub action runner
17+
_ = spatialdata_plot
18+
19+
# WARNING:
20+
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
21+
# 2. tests which produce a plot must be prefixed with `test_plot_`
22+
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
23+
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
24+
# ".png" is appended to <your_filename>, no need to set it
25+
26+
27+
class TestNotebooksTransformations(PlotTester, metaclass=PlotTesterMeta):
28+
def test_plot_can_render_transformations_raccoon_split(self, sdata_raccoon: SpatialData):
29+
_, axs = plt.subplots(ncols=3, figsize=(12, 3))
30+
31+
sdata_raccoon.pl.render_images().pl.show(ax=axs[0])
32+
sdata_raccoon.pl.render_labels().pl.show(ax=axs[1])
33+
sdata_raccoon.pl.render_shapes().pl.show(ax=axs[2])
34+
35+
def test_plot_can_render_transformations_raccoon_overlay(self, sdata_raccoon: SpatialData):
36+
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()
37+
38+
def test_plot_can_render_transformations_raccoon_scale(self, sdata_raccoon: SpatialData):
39+
scale = Scale([2.0], axes=("x",))
40+
set_transformation(sdata_raccoon.images["raccoon"], scale, to_coordinate_system="global")
41+
42+
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()
43+
44+
def test_plot_can_render_transformations_raccoon_mapaxis(self, sdata_raccoon: SpatialData):
45+
map_axis = MapAxis({"x": "y", "y": "x"})
46+
set_transformation(sdata_raccoon.images["raccoon"], map_axis, to_coordinate_system="global")
47+
48+
sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show()

0 commit comments

Comments
 (0)