Skip to content

Commit f19a308

Browse files
authored
Merge branch 'main' into bugfix/issue108-outline_color-doesnt-work
2 parents 4a15039 + 139b774 commit f19a308

21 files changed

+197
-47
lines changed

src/spatialdata_plot/_logging.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# from https://github.com/scverse/spatialdata/blob/main/src/spatialdata/_logging.py
2+
3+
import logging
4+
5+
6+
def _setup_logger() -> "logging.Logger":
7+
from rich.console import Console
8+
from rich.logging import RichHandler
9+
10+
logger = logging.getLogger(__name__)
11+
logger.setLevel(logging.INFO)
12+
console = Console(force_terminal=True)
13+
if console.is_jupyter is True:
14+
console.is_jupyter = False
15+
ch = RichHandler(show_path=False, console=console, show_time=False)
16+
logger.addHandler(ch)
17+
18+
# this prevents double outputs
19+
logger.propagate = False
20+
return logger
21+
22+
23+
logger = _setup_logger()

src/spatialdata_plot/pl/basic.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_render_shapes,
3434
)
3535
from spatialdata_plot.pl.utils import (
36+
CmapParams,
3637
LegendParams,
3738
_FontSize,
3839
_FontWeight,
@@ -296,11 +297,12 @@ def render_images(
296297
self,
297298
elements: str | list[str] | None = None,
298299
channel: list[str] | list[int] | int | str | None = None,
299-
cmap: Colormap | str | None = None,
300+
cmap: list[Colormap] | list[str] | Colormap | str | None = None,
300301
norm: None | Normalize = None,
301302
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
302303
palette: ListedColormap | str | None = None,
303304
alpha: float = 1.0,
305+
quantiles_for_norm: tuple[float | None, float | None] = (3.0, 99.8), # defaults from CSBDeep
304306
**kwargs: Any,
305307
) -> sd.SpatialData:
306308
"""
@@ -321,6 +323,8 @@ def render_images(
321323
Color to be used for NAs values, if present.
322324
alpha
323325
Alpha value for the shapes.
326+
quantiles_for_norm
327+
Tuple of (pmin, pmax) which will be used for quantile normalization.
324328
kwargs
325329
Additional arguments to be passed to cmap and norm.
326330
@@ -332,18 +336,36 @@ def render_images(
332336
sdata = _verify_plotting_tree(sdata)
333337
n_steps = len(sdata.plotting_tree.keys())
334338

335-
cmap_params = _prepare_cmap_norm(
336-
cmap=cmap,
337-
norm=norm,
338-
na_color=na_color, # type: ignore[arg-type]
339-
**kwargs,
340-
)
339+
if channel is None and cmap is None:
340+
cmap = "brg"
341+
342+
cmap_params: list[CmapParams] | CmapParams
343+
if isinstance(cmap, list):
344+
cmap_params = [
345+
_prepare_cmap_norm(
346+
cmap=c,
347+
norm=norm,
348+
na_color=na_color, # type: ignore[arg-type]
349+
**kwargs,
350+
)
351+
for c in cmap
352+
]
353+
354+
else:
355+
cmap_params = _prepare_cmap_norm(
356+
cmap=cmap,
357+
norm=norm,
358+
na_color=na_color, # type: ignore[arg-type]
359+
**kwargs,
360+
)
361+
341362
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
342363
elements=elements,
343364
channel=channel,
344365
cmap_params=cmap_params,
345366
palette=palette,
346367
alpha=alpha,
368+
quantiles_for_norm=quantiles_for_norm,
347369
)
348370

349371
return sdata

src/spatialdata_plot/pl/render.py

Lines changed: 111 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from copy import copy
55
from dataclasses import dataclass
66
from functools import partial
7-
from typing import Any, Callable, Optional, Union
7+
from typing import Any, Callable, Union
88

99
import matplotlib
1010
import numpy as np
@@ -21,6 +21,7 @@
2121
from pandas.api.types import is_categorical_dtype
2222
from scanpy._settings import settings as sc_settings
2323

24+
from spatialdata_plot._logging import logger
2425
from spatialdata_plot.pl.utils import (
2526
CmapParams,
2627
FigParams,
@@ -37,7 +38,6 @@
3738
)
3839
from spatialdata_plot.pp.utils import _get_instance_key, _get_region_key
3940

40-
Palette_t = Optional[Union[str, ListedColormap]]
4141
_Normalize = Union[Normalize, Sequence[Normalize]]
4242
to_hex = partial(colors.to_hex, keep_alpha=True)
4343

@@ -54,7 +54,7 @@ class ShapesRenderParams:
5454
contour_px: int | None = None
5555
alt_var: str | None = None
5656
layer: str | None = None
57-
palette: Palette_t = None
57+
palette: ListedColormap | str | None = None
5858
outline_alpha: float = 1.0
5959
fill_alpha: float = 0.3
6060
size: float = 1.0
@@ -208,7 +208,7 @@ class PointsRenderParams:
208208
elements: str | Sequence[str] | None = None
209209
color: str | None = None
210210
groups: str | Sequence[str] | None = None
211-
palette: Palette_t = None
211+
palette: ListedColormap | str | None = None
212212
alpha: float = 1.0
213213
size: float = 1.0
214214
transfunc: Callable[[float], float] | None = None
@@ -312,11 +312,12 @@ def _render_points(
312312
class ImageRenderParams:
313313
"""Labels render parameters.."""
314314

315-
cmap_params: CmapParams
315+
cmap_params: list[CmapParams] | CmapParams
316316
elements: str | Sequence[str] | None = None
317317
channel: list[str] | list[int] | int | str | None = None
318-
palette: Palette_t = None
318+
palette: ListedColormap | str | None = None
319319
alpha: float = 1.0
320+
quantiles_for_norm: tuple[float | None, float | None] = (3.0, 99.8) # defaults from CSBDeep
320321

321322

322323
def _render_images(
@@ -347,47 +348,126 @@ def _render_images(
347348
for img in images:
348349
if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
349350
raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
350-
if render_params.channel is None and len(img.c) == 1:
351-
render_params.channel = 0
352-
if render_params.channel is not None:
351+
352+
if render_params.channel is None:
353+
channels = img.coords["c"].values
354+
else:
353355
channels = (
354356
[render_params.channel] if isinstance(render_params.channel, (str, int)) else render_params.channel
355357
)
356-
img = img.sel(c=channels)
357-
num_channels = img.sizes["c"]
358+
359+
n_channels = len(channels)
360+
361+
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
362+
363+
if not isinstance(render_params.cmap_params, list):
364+
render_params.cmap_params = [render_params.cmap_params] * n_channels
365+
366+
if got_multiple_cmaps:
367+
logger.warning(
368+
"You're blending multiple cmaps. "
369+
"If the plot doesn't look like you expect, it might be because your "
370+
"cmaps go from a given color to 'white', and not to 'transparent'. "
371+
"Therefore, the 'white' of higher layers will overlay the lower layers. "
372+
"Consider using 'palette' instead."
373+
)
358374

359375
if render_params.palette is not None:
360-
if num_channels > len(render_params.palette):
361-
raise ValueError("If palette is provided, it must match the number of channels.")
376+
logger.warning("Parameter 'palette' is ignored when a 'cmap' is provided.")
377+
378+
for idx, channel in enumerate(channels):
379+
layer = img.sel(c=channel)
380+
381+
if render_params.quantiles_for_norm != (None, None):
382+
layer = _normalize(
383+
layer,
384+
pmin=render_params.quantiles_for_norm[0],
385+
pmax=render_params.quantiles_for_norm[1],
386+
clip=True,
387+
)
388+
389+
if render_params.cmap_params[idx].norm is not None:
390+
layer = render_params.cmap_params[idx].norm(layer)
391+
392+
ax.imshow(
393+
layer,
394+
cmap=render_params.cmap_params[idx].cmap,
395+
alpha=(1 / n_channels),
396+
)
397+
break
398+
399+
if n_channels == 1:
400+
layer = img.sel(c=channels)
401+
402+
if render_params.quantiles_for_norm != (None, None):
403+
layer = _normalize(
404+
layer, pmin=render_params.quantiles_for_norm[0], pmax=render_params.quantiles_for_norm[1], clip=True
405+
)
406+
407+
if render_params.cmap_params[0].norm is not None:
408+
layer = render_params.cmap_params[0].norm(layer)
362409

363-
color = render_params.palette
410+
if render_params.palette is None:
411+
ax.imshow(
412+
layer.squeeze(), # get rid of the channel dimension
413+
cmap=render_params.cmap_params[0].cmap,
414+
)
364415

365416
else:
366-
color = _get_colors_for_categorical_obs(
367-
img.coords["c"].values.tolist(), palette=render_params.cmap_params.cmap
417+
ax.imshow(
418+
layer.squeeze(), # get rid of the channel dimension
419+
cmap=_get_linear_colormap([render_params.palette], "k")[0],
368420
)
369421

370-
cmaps = _get_linear_colormap([str(c) for c in color[:num_channels]], "k")
371-
img = _normalize(img, clip=True)
372-
colored = np.stack([cmaps[i](img.values[i]) for i in range(num_channels)], 0).sum(0)
373-
img = xr.DataArray(
422+
break
423+
424+
if render_params.palette is not None and n_channels != len(render_params.palette):
425+
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
426+
427+
if n_channels > 1:
428+
layer = img.sel(c=channels).copy(deep=True)
429+
430+
channel_colors: list[str] | Any
431+
if render_params.palette is None:
432+
channel_colors = _get_colors_for_categorical_obs(
433+
layer.coords["c"].values.tolist(), palette=render_params.cmap_params[0].cmap
434+
)
435+
else:
436+
channel_colors = render_params.palette
437+
438+
channel_cmaps = _get_linear_colormap([str(c) for c in channel_colors[:n_channels]], "k")
439+
440+
layer_vals = []
441+
if render_params.quantiles_for_norm != (None, None):
442+
for i in range(n_channels):
443+
layer_vals.append(
444+
_normalize(
445+
layer.values[i],
446+
pmin=render_params.quantiles_for_norm[0],
447+
pmax=render_params.quantiles_for_norm[1],
448+
clip=True,
449+
)
450+
)
451+
452+
colored = np.stack([channel_cmaps[i](layer_vals[i]) for i in range(n_channels)], 0).sum(0)
453+
454+
layer = xr.DataArray(
374455
data=colored,
375456
coords=[
376-
img.coords["y"],
377-
img.coords["x"],
457+
layer.coords["y"],
458+
layer.coords["x"],
378459
["R", "G", "B", "A"],
379460
],
380461
dims=["y", "x", "c"],
381462
)
463+
layer = layer.transpose("y", "x", "c") # for plotting
382464

383-
img = img.transpose("y", "x", "c") # for plotting
384-
385-
ax.imshow(
386-
img.data,
387-
cmap=render_params.cmap_params.cmap,
388-
alpha=render_params.alpha,
389-
# extent=extent,
390-
)
465+
ax.imshow(
466+
layer.data,
467+
cmap=channel_cmaps[0],
468+
alpha=render_params.alpha,
469+
norm=render_params.cmap_params[0].norm,
470+
)
391471

392472

393473
@dataclass
@@ -402,7 +482,7 @@ class LabelsRenderParams:
402482
outline: bool = False
403483
alt_var: str | None = None
404484
layer: str | None = None
405-
palette: Palette_t = None
485+
palette: ListedColormap | str | None = None
406486
outline_alpha: float = 1.0
407487
fill_alpha: float = 0.4
408488
transfunc: Callable[[float], float] | None = None

src/spatialdata_plot/pl/utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,8 @@ def _get_hex_colors_for_continous_values(values: pd.Series, cmap_name: str = "vi
603603

604604
def _normalize(
605605
img: xr.DataArray,
606-
pmin: float = 3.0,
607-
pmax: float = 99.8,
606+
pmin: float | None = 3.0,
607+
pmax: float | None = 99.8,
608608
eps: float = 1e-20,
609609
clip: bool = False,
610610
name: str = "normed",
@@ -618,9 +618,9 @@ def _normalize(
618618
dataarray
619619
A xarray DataArray with an image field.
620620
pmin
621-
Lower quantile (min value) used to perform qunatile normalization.
621+
Lower quantile (min value) used to perform quantile normalization.
622622
pmax
623-
Upper quantile (max value) used to perform qunatile normalization.
623+
Upper quantile (max value) used to perform quantile normalization.
624624
eps
625625
Epsilon float added to prevent 0 division.
626626
clip
@@ -631,9 +631,12 @@ def _normalize(
631631
xr.DataArray
632632
A min-max normalized image.
633633
"""
634-
perc = np.percentile(img, [pmin, pmax], axis=(1, 2)).T
634+
pmin = pmin or 0.0
635+
pmax = pmax or 100.0
635636

636-
norm = (img - np.expand_dims(perc[:, 0], (1, 2))) / (np.expand_dims(perc[:, 1] - perc[:, 0], (1, 2)) + eps)
637+
perc = np.percentile(img, [pmin, pmax])
638+
639+
norm = (img - perc[0]) / (perc[1] - perc[0] + eps)
637640

638641
if clip:
639642
norm = np.clip(norm, 0, 1)
@@ -691,7 +694,7 @@ def _get_colors_for_categorical_obs(
691694
elif isinstance(palette, ListedColormap):
692695
palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)]
693696
elif isinstance(palette, LinearSegmentedColormap):
694-
palette = [to_hex(palette(x, alpha=alpha)) for x in [color_idx]]
697+
palette = [to_hex(palette(x, alpha=alpha)) for x in color_idx] # type: ignore[attr-defined]
695698
else:
696699
raise TypeError(f"Palette is {type(palette)} but should be string or `ListedColormap`.")
697700

-174 Bytes
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
-174 Bytes
Loading
Loading

tests/_images/Labels_labels.png

14 KB
Loading

tests/_images/Points_points.png

7.03 KB
Loading
-121 Bytes
Loading

tests/pl/test_render_images.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,27 @@ class TestImages(PlotTester, metaclass=PlotTesterMeta):
2222
def test_plot_can_render_image(self, sdata_blobs: SpatialData):
2323
sdata_blobs.pl.render_images(elements="blobs_image").pl.show()
2424

25-
# def test_plot_can_render_multiscale_image(self, sdata_blobs: SpatialData):
26-
# sdata_blobs.pl.render_images(elements="blobs_multiscale_image").pl.show()
25+
def test_plot_can_pass_cmap_to_render_images(self, sdata_blobs: SpatialData):
26+
sdata_blobs.pl.render_images(elements="blobs_image", cmap="seismic").pl.show()
27+
28+
def test_plot_can_render_a_single_channel_from_image(self, sdata_blobs: SpatialData):
29+
sdata_blobs.pl.render_images(elements="blobs_image", channel=0).pl.show()
30+
31+
def test_plot_can_render_two_channels_from_image(self, sdata_blobs: SpatialData):
32+
sdata_blobs.pl.render_images(elements="blobs_image", channel=[0, 1]).pl.show()
33+
34+
def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData):
35+
sdata_blobs.pl.render_images(elements="blobs_image", channel=1, palette="red").pl.show()
36+
37+
def test_plot_can_pass_cmap_to_single_channel(self, sdata_blobs: SpatialData):
38+
sdata_blobs.pl.render_images(elements="blobs_image", channel=1, cmap="Reds").pl.show()
39+
40+
def test_plot_can_pass_color_to_each_channel(self, sdata_blobs: SpatialData):
41+
sdata_blobs.pl.render_images(
42+
elements="blobs_image", channel=[0, 1, 2], palette=["red", "green", "blue"]
43+
).pl.show()
44+
45+
def test_plot_can_pass_cmap_to_each_channel(self, sdata_blobs: SpatialData):
46+
sdata_blobs.pl.render_images(
47+
elements="blobs_image", channel=[0, 1, 2], cmap=["Reds", "Greens", "Blues"]
48+
).pl.show()

0 commit comments

Comments
 (0)