Skip to content

Commit cb2528a

Browse files
Refactor render channels to render images (#55)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5389ae1 commit cb2528a

File tree

4 files changed

+43
-19
lines changed

4 files changed

+43
-19
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def render_points(
290290
def render_images(
291291
self,
292292
element: str | None = None,
293-
channel: str | None = None,
293+
channel: list[str] | list[int] | int | str | None = None,
294294
cmap: Colormap | str | None = None,
295295
norm: Optional[Normalize] = None,
296296
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
@@ -326,6 +326,7 @@ def render_images(
326326
sdata = self._copy()
327327
sdata = _verify_plotting_tree(sdata)
328328
n_steps = len(sdata.plotting_tree.keys())
329+
329330
cmap_params = _prepare_cmap_norm(
330331
cmap=cmap,
331332
norm=norm,

src/spatialdata_plot/pl/render.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pandas as pd
1313
import scanpy as sc
1414
import spatialdata as sd
15+
import xarray as xr
1516
from anndata import AnnData
1617
from geopandas import GeoDataFrame
1718
from matplotlib import colors
@@ -28,6 +29,8 @@
2829
OutlineParams,
2930
ScalebarParams,
3031
_decorate_axs,
32+
_get_colors_for_categorical_obs,
33+
_get_linear_colormap,
3134
_map_color_seg,
3235
_maybe_set_colors,
3336
_normalize,
@@ -292,7 +295,7 @@ class ImageRenderParams:
292295

293296
cmap_params: CmapParams
294297
element: str | None = None
295-
channel: Sequence[str] | None = None
298+
channel: list[str] | list[int] | int | str | None = None
296299
palette: Palette_t = None
297300
alpha: float = 1.0
298301

@@ -319,10 +322,32 @@ def _render_images(
319322
if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
320323
raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
321324

322-
img = _normalize(img, clip=True)
323-
324325
if render_params.channel is not None:
325-
img = img.sel(c=[render_params.channel])
326+
channels = [render_params.channel] if isinstance(render_params.channel, (str, int)) else render_params.channel
327+
img = img.sel(c=channels)
328+
num_channels = img.sizes["c"]
329+
330+
if render_params.palette is not None:
331+
if num_channels > len(render_params.palette):
332+
raise ValueError("If palette is provided, it must match the number of channels.")
333+
334+
color = render_params.palette
335+
336+
else:
337+
color = _get_colors_for_categorical_obs(img.coords["c"].values.tolist())
338+
339+
cmaps = _get_linear_colormap([str(c) for c in color[:num_channels]], "k")
340+
img = _normalize(img, clip=True)
341+
colored = np.stack([cmaps[i](img.values[i]) for i in range(num_channels)], 0).sum(0)
342+
img = xr.DataArray(
343+
data=colored,
344+
coords=[
345+
img.coords["y"],
346+
img.coords["x"],
347+
["R", "G", "B", "A"],
348+
],
349+
dims=["y", "x", "c"],
350+
)
326351

327352
img = img.transpose("y", "x", "c") # for plotting
328353

src/spatialdata_plot/pl/utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from matplotlib.axes import Axes
2222
from matplotlib.cm import get_cmap
2323
from matplotlib.collections import PatchCollection
24-
from matplotlib.colors import Colormap, ListedColormap, Normalize, TwoSlopeNorm, to_rgba
24+
from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap, Normalize, TwoSlopeNorm, to_rgba
2525
from matplotlib.figure import Figure
2626
from matplotlib.gridspec import GridSpec
2727
from matplotlib_scalebar.scalebar import ScaleBar
@@ -918,3 +918,14 @@ def _multiscale_to_image(sdata: sd.SpatialData) -> sd.SpatialData:
918918
sdata.images[k] = Image2DModel.parse(v["scale0"].ds.to_array().squeeze(axis=0))
919919

920920
return sdata
921+
922+
923+
def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegmentedColormap]:
924+
return [LinearSegmentedColormap.from_list(c, [background, c], N=256) for c in colors]
925+
926+
927+
def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
928+
sorted_labels = sorted(color_dict.keys())
929+
colors = [color_dict[k] for k in sorted_labels]
930+
931+
return ListedColormap(["black"] + colors, N=len(colors) + 1)

src/spatialdata_plot/pp/utils.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,10 @@
11
from collections import OrderedDict
22

3-
import matplotlib
43
import spatialdata as sd
5-
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
64
from spatialdata.models import TableModel
75
from spatialdata.transformations import get_transformation
86

97

10-
def _get_linear_colormap(colors: list[str], background: str) -> list[matplotlib.colors.LinearSegmentedColormap]:
11-
return [LinearSegmentedColormap.from_list(c, [background, c], N=256) for c in colors]
12-
13-
14-
def _get_listed_colormap(color_dict: dict[str, str]) -> matplotlib.colors.ListedColormap:
15-
sorted_labels = sorted(color_dict.keys())
16-
colors = [color_dict[k] for k in sorted_labels]
17-
18-
return ListedColormap(["black"] + colors, N=len(colors) + 1)
19-
20-
218
def _get_region_key(sdata: sd.SpatialData) -> str:
229
"""Quick access to the data's region key."""
2310
return str(sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY])

0 commit comments

Comments
 (0)