Skip to content

Commit bc0947d

Browse files
sagar87timtreispre-commit-ci[bot]
authored
first version of render_channels
* pl.render_channels blueprint * _render_channels blueprint * implemented basic functionality of _render_channels * pl.render_channels checks types * pl.render_channels appends the params dict correctly to the plotting tree * _render_channels clips the final image * cast passed ax object in pl.show to an numpy array such that it is an iterable * added docstring and additional parameter validation to pl.render_channels * fixed input validation in pl.render_channels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Tim Treis <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4ab2215 commit bc0947d

File tree

3 files changed

+137
-9
lines changed

3 files changed

+137
-9
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ..accessor import register_spatial_data_accessor
2323
from ..pp.utils import _get_instance_key, _get_region_key, _verify_plotting_tree_exists
24-
from .render import _render_images, _render_labels, _render_shapes
24+
from .render import _render_channels, _render_images, _render_labels, _render_shapes
2525
from .utils import (
2626
_get_color_key_dtype,
2727
_get_color_key_values,
@@ -231,6 +231,96 @@ def render_images(
231231

232232
return sdata
233233

234+
def render_channels(
235+
self,
236+
channels: Union[list[str], list[int]],
237+
colors: list[str],
238+
normalize: bool = True,
239+
clip: bool = True,
240+
background: str = "black",
241+
pmin: float = 3.0,
242+
pmax: float = 99.8,
243+
) -> sd.SpatialData:
244+
"""Renders selected channels.
245+
246+
Parameters:
247+
-----------
248+
self: object
249+
The SpatialData object
250+
channels: Union[List[str], List[int]]
251+
The channels to plot
252+
colors: List[str]
253+
The colors for the channels. Must be at least as long as len(channels).
254+
normalize: bool
255+
Perform quantile normalisation (using pmin, pmax)
256+
clip: bool
257+
Clips the merged image to the range (0, 1).
258+
background: str
259+
Background color (defaults to black).
260+
pmin: float
261+
Lower percentile for quantile normalisation (defaults to 3.-).
262+
pmax: float
263+
Upper percentile for quantile normalisation (defaults to 99.8).
264+
265+
Raises
266+
------
267+
TypeError
268+
If any of the parameters have an invalid type.
269+
ValueError
270+
If any of the parameters have an invalid value.
271+
272+
Returns
273+
-------
274+
sd.SpatialData
275+
A new `SpatialData` object that is a copy of the original
276+
`SpatialData` object, with an updated plotting tree.
277+
"""
278+
if not isinstance(channels, list):
279+
raise TypeError("Parameter 'channels' must be a list.")
280+
281+
if not isinstance(colors, list):
282+
raise TypeError("Parameter 'colors' must be a list.")
283+
284+
if len(channels) > len(colors):
285+
raise ValueError("Number of colors must have at least the same length as the number of selected channels.")
286+
287+
if not isinstance(clip, bool):
288+
raise TypeError("Parameter 'clip' must be a bool.")
289+
290+
if not isinstance(normalize, bool):
291+
raise TypeError("Parameter 'normalize' must be a bool.")
292+
293+
if not isinstance(background, str):
294+
raise TypeError("Parameter 'background' must be a str.")
295+
296+
if not isinstance(pmin, float):
297+
raise TypeError("Parameter 'pmin' must be a str.")
298+
299+
if not isinstance(pmax, float):
300+
raise TypeError("Parameter 'pmax' must be a str.")
301+
302+
if (pmin < 0.0) or (pmin > 100.0) or (pmax < 0.0) or (pmax > 100.0):
303+
raise ValueError("Percentiles must be in the range 0 < pmin/pmax < 100.")
304+
305+
if pmin > pmax:
306+
raise ValueError("Percentile parameters must satisfy pmin < pmax.")
307+
308+
sdata = self._copy()
309+
sdata = _verify_plotting_tree_exists(sdata)
310+
n_steps = len(sdata.plotting_tree.keys())
311+
312+
sdata.plotting_tree[f"{n_steps+1}_render_channels"] = {
313+
"channels": channels,
314+
"colors": colors,
315+
"clip": clip,
316+
"normalize": normalize,
317+
"background": background,
318+
"pmin": pmin,
319+
"pmax": pmax,
320+
}
321+
322+
return sdata
323+
234324
def render_labels(
235325
self,
236326
instance_key: Optional[Union[str, None]] = None,
@@ -458,12 +548,12 @@ def show(
458548
num_images = len(sdata.coordinate_systems)
459549
fig, axs = _get_subplots(num_images, ncols, width, height)
460550
elif isinstance(ax, matplotlib.pyplot.Axes):
461-
axs = [ax]
551+
axs = np.array([ax])
462552
elif isinstance(ax, list):
463553
axs = ax
464554

465555
# Set background color
466-
for _, ax in enumerate(axs):
556+
for _, ax in enumerate(axs.flatten()):
467557
ax.set_facecolor(bg_color)
468558
# key = list(sdata.labels.keys())[idx]
469559
# ax.imshow(sdata.labels[key].values, cmap=ListedColormap([bg_color]))
@@ -514,12 +604,16 @@ def show(
514604

515605
# go through tree
516606
for cmd, params in render_cmds.items():
517-
if cmd == "render_images":
518-
for idx, ax in enumerate(axs):
519-
key = list(sdata.images.keys())[idx]
607+
keys = list(sdata.images.keys())
520608

609+
if cmd == "render_images":
610+
for key, ax in zip(keys, axs.flatten()):
521611
_render_images(sdata=sdata, params=params, key=key, ax=ax, extent=extent)
522612

613+
elif cmd == "render_channels":
614+
for key, ax in zip(keys, axs.flatten()):
615+
_render_channels(sdata=sdata, key=key, ax=ax, **params)
616+
523617
elif cmd == "render_shapes":
524618
if (
525619
sdata.table is not None

src/spatialdata_plot/pl/render.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,41 @@
1212
from sklearn.decomposition import PCA
1313

1414
from ..pl.utils import _normalize
15-
from ..pp.utils import _get_region_key
15+
from ..pp.utils import _get_linear_colormap, _get_region_key
16+
17+
18+
def _render_channels(
19+
sdata: sd.SpatialData,
20+
channels: list[Union[str, int]],
21+
colors: list[str],
22+
clip: bool,
23+
normalize: bool,
24+
background: str,
25+
pmin: float,
26+
pmax: float,
27+
key: str,
28+
ax: matplotlib.axes.SubplotBase,
29+
) -> None:
30+
selection = sdata.images[key].sel({"c": channels})
31+
n_channels, y_dim, x_dim = selection.shape # (c, y, x)
32+
img = selection.values.copy()
33+
img = img.astype("float")
34+
35+
if normalize:
36+
img = _normalize(img, pmin, pmax, clip)
37+
38+
cmaps = _get_linear_colormap(colors[:n_channels], background)
39+
colored = np.stack([cmaps[i](img[i]) for i in range(n_channels)], 0).sum(0)
40+
41+
if clip:
42+
colored = np.clip(colored, 0, 1)
43+
44+
ax.imshow(colored)
45+
ax.set_title(key)
46+
ax.set_xlabel("spatial1")
47+
ax.set_ylabel("spatial2")
48+
ax.set_xticks([])
49+
ax.set_yticks([])
1650

1751

1852
def _render_shapes(

src/spatialdata_plot/pl/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int =
4444
fig, axes = plt.subplots(nrows, ncols, figsize=(width * ncols, height * nrows))
4545

4646
if not isinstance(axes, Iterable):
47-
axes = [axes]
47+
axes = np.array([axes])
4848

4949
# get rid of the empty axes
50-
# _ = [ax.axis("off") for ax in axes.flatten()[num_images:]]
50+
_ = [ax.axis("off") for ax in axes.flatten()[num_images:]]
5151
return fig, axes
5252

5353

0 commit comments

Comments
 (0)