Skip to content

Commit 8d13c48

Browse files
authored
Image render param refactor (#267)
* add temporary deprecation decorator * initial refactor * mostly fixed tests * fix remaining 2 tests * some more refactor * some more refactor * mypy * add images properly to elements to render
1 parent fc427d0 commit 8d13c48

File tree

5 files changed

+485
-254
lines changed

5 files changed

+485
-254
lines changed

src/spatialdata_plot/_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import warnings
5+
from typing import Any, Callable, TypeVar
6+
7+
RT = TypeVar("RT")
8+
9+
10+
def deprecation_alias(**aliases: str) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
11+
"""
12+
Decorate a function to warn user of use of arguments set for deprecation.
13+
14+
Parameters
15+
----------
16+
aliases
17+
Deprecation argument aliases to be mapped to the new arguments.
18+
19+
Returns
20+
-------
21+
A decorator that can be used to mark an argument for deprecation and substituting it with the new argument.
22+
23+
Raises
24+
------
25+
TypeError
26+
If the provided aliases are not of string type.
27+
28+
Example
29+
-------
30+
Assuming we have an argument 'table' set for deprecation and we want to warn the user and substitute with 'tables':
31+
32+
```python
33+
@deprecation_alias(table="tables")
34+
def my_function(tables: AnnData | dict[str, AnnData]):
35+
pass
36+
```
37+
"""
38+
39+
def deprecation_decorator(f: Callable[..., RT]) -> Callable[..., RT]:
40+
@functools.wraps(f)
41+
def wrapper(*args: Any, **kwargs: Any) -> RT:
42+
class_name = f.__qualname__
43+
rename_kwargs(f.__name__, kwargs, aliases, class_name)
44+
return f(*args, **kwargs)
45+
46+
return wrapper
47+
48+
return deprecation_decorator
49+
50+
51+
def rename_kwargs(func_name: str, kwargs: dict[str, Any], aliases: dict[str, str], class_name: None | str) -> None:
52+
"""Rename function arguments set for deprecation and gives warning in case of usage of these arguments."""
53+
for alias, new in aliases.items():
54+
if alias in kwargs:
55+
class_name = class_name + "." if class_name else ""
56+
if new in kwargs:
57+
raise TypeError(
58+
f"{class_name}{func_name} received both {alias} and {new} as arguments!"
59+
f" {alias} is being deprecated in spatialdata-plot version 0.3, only use {new} instead."
60+
)
61+
warnings.warn(
62+
message=(
63+
f"`{alias}` is being deprecated as an argument to `{class_name}{func_name}` in spatialdata-plot "
64+
f"version 0.3, switch to `{new}` instead."
65+
),
66+
category=DeprecationWarning,
67+
stacklevel=3,
68+
)
69+
kwargs[new] = kwargs.pop(alias)

src/spatialdata_plot/pl/basic.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from spatialdata._core.data_extent import get_extent
2424

2525
from spatialdata_plot._accessor import register_spatial_data_accessor
26+
from spatialdata_plot._utils import deprecation_alias
2627
from spatialdata_plot.pl.render import (
2728
_render_images,
2829
_render_labels,
@@ -50,6 +51,7 @@
5051
_prepare_params_plot,
5152
_set_outline,
5253
_update_params,
54+
_validate_image_render_params,
5355
_validate_render_params,
5456
_validate_show_parameters,
5557
save_fig,
@@ -390,59 +392,63 @@ def render_points(
390392

391393
return sdata
392394

395+
@deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0")
393396
def render_images(
394397
self,
395-
elements: list[str] | str | None = None,
398+
element: str | None = None,
396399
channel: list[str] | list[int] | str | int | None = None,
397-
cmap: list[Colormap] | Colormap | str | None = None,
400+
cmap: list[Colormap | str] | Colormap | str | None = None,
398401
norm: Normalize | None = None,
399402
na_color: ColorLike | None = (0.0, 0.0, 0.0, 0.0),
400-
palette: list[list[str | None]] | list[str | None] | str | None = None,
403+
palette: list[str] | str | None = None,
401404
alpha: float | int = 1.0,
402-
quantiles_for_norm: tuple[float | None, float | None] | None = None,
403-
scale: list[str] | str | None = None,
405+
percentiles_for_norm: tuple[float, float] | None = None,
406+
scale: str | None = None,
404407
**kwargs: Any,
405408
) -> sd.SpatialData:
406409
"""
407410
Render image elements in SpatialData.
408411
412+
In case of no elements specified, "broadcasting" of parameters is applied. This means that for any particular
413+
SpatialElement, we validate whether a given parameter is valid. If not valid for a particular SpatialElement the
414+
specific parameter for that particular SpatialElement will be ignored. If you want to set specific parameters
415+
for specific elements please chain the render functions: `pl.render_images(...).pl.render_images(...).pl.show()`
416+
.
417+
409418
Parameters
410419
----------
411-
elements : list[str] | str | None, optional
412-
The name(s) of the image element(s) to render. If `None`, all image
413-
elements in the `SpatialData` object will be used. If a string is provided,
414-
it is converted into a single-element list.
415-
channel : list[str] | list[int] | str | int | None, optional
420+
element : str | None
421+
The name of the image element to render. If `None`, all image
422+
elements in the `SpatialData` object will be used.
423+
channels : list[str] | list[int] | str | int | None, optional
416424
To select specific channels to plot. Can be a single channel name/int or a
417425
list of channel names/ints. If `None`, all channels will be used.
418-
cmap : list[Colormap] | Colormap | str | None, optional
426+
cmap : list[Colormap | str] | Colormap | str | None, optional
419427
Colormap or list of colormaps for continuous annotations, see :class:`matplotlib.colors.Colormap`.
420428
Each colormap applies to a corresponding channel.
421429
norm : Normalize | None, optional
422430
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
423431
Applies to all channels if set.
424432
na_color : ColorLike | None, default (0.0, 0.0, 0.0, 0.0)
425433
Color to be used for NA values. Accepts color-like values (string, hex, RGB(A)).
426-
palette : list[list[str | None]] | list[str | None] | str | None
434+
palette : list[str] | None
427435
Palette to color images. In the case of a list of
428436
lists means that there is one list per element to be plotted in the list and this list contains the string
429437
indicating the palette to be used. If not provided as list of lists, broadcasting behaviour is
430438
attempted (use the same values for all elements).
431439
alpha : float | int, default 1.0
432440
Alpha value for the images. Must be a numeric between 0 and 1.
433-
quantiles_for_norm : tuple[float | None, float | None] | None, optional
441+
percentiles_for_norm : tuple[float, float] | None
434442
Optional pair of floats (pmin < pmax, 0-100) which will be used for quantile normalization.
435-
scale : list[str] | str | None, optional
443+
scale : str | None
436444
Influences the resolution of the rendering. Possibilities include:
437445
1) `None` (default): The image is rasterized to fit the canvas size. For
438446
multiscale images, the best scale is selected before rasterization.
439-
2) A scale name: Renders the specified scale as-is (with adjustments for dpi
440-
in `show()`).
447+
2) A scale name: Renders the specified scale ( of a multiscale image) as-is
448+
(with adjustments for dpi in `show()`).
441449
3) "full": Renders the full image without rasterization. In the case of
442450
multiscale images, the highest resolution scale is selected. Note that
443451
this may result in long computing times for large images.
444-
4) A list matching the list of elements. Can contain `None`, scale names, or
445-
"full". Each scale applies to the corresponding element.
446452
kwargs
447453
Additional arguments to be passed to cmap, norm, and other rendering functions.
448454
@@ -451,19 +457,19 @@ def render_images(
451457
sd.SpatialData
452458
The SpatialData object with the rendered images.
453459
"""
454-
params_dict = _validate_render_params(
455-
"images",
460+
params_dict = _validate_image_render_params(
456461
self._sdata,
457-
elements=elements,
462+
element=element,
458463
channel=channel,
459464
alpha=alpha,
460465
palette=palette,
461466
na_color=na_color,
462467
cmap=cmap,
463468
norm=norm,
464469
scale=scale,
465-
quantiles_for_norm=quantiles_for_norm,
470+
percentiles_for_norm=percentiles_for_norm,
466471
)
472+
467473
sdata = self._copy()
468474
sdata = _verify_plotting_tree(sdata)
469475
n_steps = len(sdata.plotting_tree.keys())
@@ -488,15 +494,17 @@ def render_images(
488494
**kwargs,
489495
)
490496

491-
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
492-
elements=params_dict["elements"],
493-
channel=channel,
494-
cmap_params=cmap_params,
495-
palette=params_dict["palette"],
496-
alpha=alpha,
497-
quantiles_for_norm=params_dict["quantiles_for_norm"],
498-
scale=params_dict["scale"],
499-
)
497+
for element, param_values in params_dict.items():
498+
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
499+
element=element,
500+
channel=param_values["channel"],
501+
cmap_params=cmap_params,
502+
palette=param_values["palette"],
503+
alpha=param_values["alpha"],
504+
percentiles_for_norm=param_values["percentiles_for_norm"],
505+
scale=param_values["scale"],
506+
)
507+
n_steps += 1
500508

501509
return sdata
502510

0 commit comments

Comments
 (0)