|
21 | 21 |
|
22 | 22 | from ..accessor import register_spatial_data_accessor
|
23 | 23 | from ..pp.utils import _get_instance_key, _get_region_key, _verify_plotting_tree_exists
|
24 |
| -from .render import _render_images, _render_labels, _render_shapes |
| 24 | +from .render import _render_channels, _render_images, _render_labels, _render_shapes |
25 | 25 | from .utils import (
|
26 | 26 | _get_color_key_dtype,
|
27 | 27 | _get_color_key_values,
|
@@ -231,6 +231,96 @@ def render_images(
|
231 | 231 |
|
232 | 232 | return sdata
|
233 | 233 |
|
| 234 | + def render_channels( |
| 235 | + self, |
| 236 | + channels: Union[list[str], list[int]], |
| 237 | + colors: list[str], |
| 238 | + normalize: bool = True, |
| 239 | + clip: bool = True, |
| 240 | + background: str = "black", |
| 241 | + pmin: float = 3.0, |
| 242 | + pmax: float = 99.8, |
| 243 | + ) -> sd.SpatialData: |
| 244 | + """Renders selected channels. |
| 245 | +
|
| 246 | + Parameters: |
| 247 | + ----------- |
| 248 | + self: object |
| 249 | + The SpatialData object |
| 250 | + channels: Union[List[str], List[int]] |
| 251 | + The channels to plot |
| 252 | + colors: List[str] |
| 253 | + The colors for the channels. Must be at least as long as len(channels). |
| 254 | + normalize: bool |
| 255 | + Perform quantile normalisation (using pmin, pmax) |
| 256 | + clip: bool |
| 257 | + Clips the merged image to the range (0, 1). |
| 258 | + background: str |
| 259 | + Background color (defaults to black). |
| 260 | + pmin: float |
| 261 | + Lower percentile for quantile normalisation (defaults to 3.-). |
| 262 | + pmax: float |
| 263 | + Upper percentile for quantile normalisation (defaults to 99.8). |
| 264 | +
|
| 265 | + Raises |
| 266 | + ------ |
| 267 | + TypeError |
| 268 | + If any of the parameters have an invalid type. |
| 269 | + ValueError |
| 270 | + If any of the parameters have an invalid value. |
| 271 | +
|
| 272 | + Returns |
| 273 | + ------- |
| 274 | + sd.SpatialData |
| 275 | + A new `SpatialData` object that is a copy of the original |
| 276 | + `SpatialData` object, with an updated plotting tree. |
| 277 | + """ |
| 278 | + if not isinstance(channels, list): |
| 279 | + raise TypeError("Parameter 'channels' must be a list.") |
| 280 | + |
| 281 | + if not isinstance(colors, list): |
| 282 | + raise TypeError("Parameter 'colors' must be a list.") |
| 283 | + |
| 284 | + if len(channels) > len(colors): |
| 285 | + raise ValueError("Number of colors must have at least the same length as the number of selected channels.") |
| 286 | + |
| 287 | + if not isinstance(clip, bool): |
| 288 | + raise TypeError("Parameter 'clip' must be a bool.") |
| 289 | + |
| 290 | + if not isinstance(normalize, bool): |
| 291 | + raise TypeError("Parameter 'normalize' must be a bool.") |
| 292 | + |
| 293 | + if not isinstance(background, str): |
| 294 | + raise TypeError("Parameter 'background' must be a str.") |
| 295 | + |
| 296 | + if not isinstance(pmin, float): |
| 297 | + raise TypeError("Parameter 'pmin' must be a str.") |
| 298 | + |
| 299 | + if not isinstance(pmax, float): |
| 300 | + raise TypeError("Parameter 'pmax' must be a str.") |
| 301 | + |
| 302 | + if (pmin < 0.0) or (pmin > 100.0) or (pmax < 0.0) or (pmax > 100.0): |
| 303 | + raise ValueError("Percentiles must be in the range 0 < pmin/pmax < 100.") |
| 304 | + |
| 305 | + if pmin > pmax: |
| 306 | + raise ValueError("Percentile parameters must satisfy pmin < pmax.") |
| 307 | + |
| 308 | + sdata = self._copy() |
| 309 | + sdata = _verify_plotting_tree_exists(sdata) |
| 310 | + n_steps = len(sdata.plotting_tree.keys()) |
| 311 | + |
| 312 | + sdata.plotting_tree[f"{n_steps+1}_render_channels"] = { |
| 313 | + "channels": channels, |
| 314 | + "colors": colors, |
| 315 | + "clip": clip, |
| 316 | + "normalize": normalize, |
| 317 | + "background": background, |
| 318 | + "pmin": pmin, |
| 319 | + "pmax": pmax, |
| 320 | + } |
| 321 | + |
| 322 | + return sdata |
| 323 | + |
234 | 324 | def render_labels(
|
235 | 325 | self,
|
236 | 326 | instance_key: Optional[Union[str, None]] = None,
|
@@ -458,12 +548,12 @@ def show(
|
458 | 548 | num_images = len(sdata.coordinate_systems)
|
459 | 549 | fig, axs = _get_subplots(num_images, ncols, width, height)
|
460 | 550 | elif isinstance(ax, matplotlib.pyplot.Axes):
|
461 |
| - axs = [ax] |
| 551 | + axs = np.array([ax]) |
462 | 552 | elif isinstance(ax, list):
|
463 | 553 | axs = ax
|
464 | 554 |
|
465 | 555 | # Set background color
|
466 |
| - for _, ax in enumerate(axs): |
| 556 | + for _, ax in enumerate(axs.flatten()): |
467 | 557 | ax.set_facecolor(bg_color)
|
468 | 558 | # key = list(sdata.labels.keys())[idx]
|
469 | 559 | # ax.imshow(sdata.labels[key].values, cmap=ListedColormap([bg_color]))
|
@@ -514,12 +604,16 @@ def show(
|
514 | 604 |
|
515 | 605 | # go through tree
|
516 | 606 | for cmd, params in render_cmds.items():
|
517 |
| - if cmd == "render_images": |
518 |
| - for idx, ax in enumerate(axs): |
519 |
| - key = list(sdata.images.keys())[idx] |
| 607 | + keys = list(sdata.images.keys()) |
520 | 608 |
|
| 609 | + if cmd == "render_images": |
| 610 | + for key, ax in zip(keys, axs.flatten()): |
521 | 611 | _render_images(sdata=sdata, params=params, key=key, ax=ax, extent=extent)
|
522 | 612 |
|
| 613 | + elif cmd == "render_channels": |
| 614 | + for key, ax in zip(keys, axs.flatten()): |
| 615 | + _render_channels(sdata=sdata, key=key, ax=ax, **params) |
| 616 | + |
523 | 617 | elif cmd == "render_shapes":
|
524 | 618 | if (
|
525 | 619 | sdata.table is not None
|
|
0 commit comments