Skip to content

Commit 55170fd

Browse files
author
Sonja Stockhaus
committed
improved cmap handling for images
1 parent 74afb01 commit 55170fd

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,8 @@ def render_images(
362362
sdata = _verify_plotting_tree(sdata)
363363
n_steps = len(sdata.plotting_tree.keys())
364364

365-
if channel is None and cmap is None:
366-
cmap = "brg"
365+
# if channel is None and cmap is None:
366+
# cmap = "brg"
367367

368368
cmap_params: list[CmapParams] | CmapParams
369369
if isinstance(cmap, list):

src/spatialdata_plot/pl/render.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,29 @@ def _render_images(
434434
if render_params.cmap_params[i].norm is not None:
435435
layers[c] = render_params.cmap_params[i].norm(layers[c])
436436

437-
# 2A) Image has 3 channels, no palette/cmap info -> use RGB
438-
if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps:
437+
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
438+
if n_channels == 3 and render_params.palette is None and not isinstance(render_params.cmap_params, list):
439+
if render_params.cmap_params.is_default: # -> use RGB
440+
stacked = np.stack([layers[c] for c in channels], axis=-1)
441+
else: # -> use given cmap for each channel
442+
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
443+
# Apply cmaps to each channel, add up and normalize to [0, 1]
444+
stacked = (
445+
np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) / n_channels
446+
)
447+
# Remove alpha channel so we can overwrite it from render_params.alpha
448+
stacked = stacked[:, :, :3]
449+
logger.warning(
450+
"One cmap was given for multiple channels and is now used for each channel. "
451+
"You're blending multiple cmaps. "
452+
"If the plot doesn't look like you expect, it might be because your "
453+
"cmaps go from a given color to 'white', and not to 'transparent'. "
454+
"Therefore, the 'white' of higher layers will overlay the lower layers. "
455+
"Consider using 'palette' instead."
456+
)
457+
439458
im = ax.imshow(
440-
np.stack([layers[c] for c in channels], axis=-1),
459+
stacked,
441460
alpha=render_params.alpha,
442461
)
443462
im.set_transform(trans_data)

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,13 @@ def _prepare_cmap_norm(
344344
**kwargs: Any,
345345
) -> CmapParams:
346346
is_default = cmap is None
347-
cmap = copy(matplotlib.colormaps[rcParams["image.cmap"] if cmap is None else cmap])
347+
if cmap is None:
348+
cmap = rcParams["image.cmap"]
349+
if isinstance(cmap, str):
350+
cmap = matplotlib.colormaps[cmap]
351+
352+
cmap = copy(cmap)
353+
348354
cmap.set_bad("lightgray" if na_color is None else na_color)
349355

350356
if isinstance(norm, Normalize) or not norm:

0 commit comments

Comments
 (0)