Skip to content

Commit 7558de3

Browse files
committed
fix pre-commit
1 parent 5ce5ed7 commit 7558de3

File tree

6 files changed

+47
-47
lines changed

6 files changed

+47
-47
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import OrderedDict
66
from copy import deepcopy
77
from pathlib import Path
8-
from typing import Any, Union
8+
from typing import Any
99

1010
import matplotlib.pyplot as plt
1111
import numpy as np
@@ -61,7 +61,7 @@
6161
# replace with
6262
# from spatialdata._types import ColorLike
6363
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
64-
ColorLike = Union[tuple[float, ...], str]
64+
ColorLike = tuple[float, ...] | str
6565

6666

6767
@register_spatial_data_accessor("pl")

src/spatialdata_plot/pl/render.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import warnings
44
from collections import abc
55
from copy import copy
6-
from typing import Union
76

87
import dask
98
import datashader as ds
@@ -56,7 +55,7 @@
5655
to_hex,
5756
)
5857

59-
_Normalize = Union[Normalize, abc.Sequence[Normalize]]
58+
_Normalize = Normalize | abc.Sequence[Normalize]
6059

6160

6261
def _render_shapes(

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable, Sequence
44
from dataclasses import dataclass
5-
from typing import Literal, Union
5+
from typing import Literal
66

77
from matplotlib.axes import Axes
88
from matplotlib.colors import Colormap, ListedColormap, Normalize
@@ -14,7 +14,7 @@
1414
# replace with
1515
# from spatialdata._types import ColorLike
1616
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
17-
ColorLike = Union[tuple[float, ...], str]
17+
ColorLike = tuple[float, ...] | str
1818

1919

2020
@dataclass

src/spatialdata_plot/pl/utils.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from functools import partial
99
from pathlib import Path
1010
from types import MappingProxyType
11-
from typing import Any, Literal, Union
11+
from typing import Any, Literal
1212

1313
import dask
1414
import datashader as ds
@@ -81,7 +81,7 @@
8181
# replace with
8282
# from spatialdata._types import ColorLike
8383
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
84-
ColorLike = Union[tuple[float, ...], str]
84+
ColorLike = tuple[float, ...] | str
8585

8686

8787
def _verify_plotting_tree(sdata: SpatialData) -> SpatialData:
@@ -526,7 +526,7 @@ def _set_outline(
526526
outline_color: str | list[float] = "#0000000ff", # black, white
527527
**kwargs: Any,
528528
) -> OutlineParams:
529-
if not isinstance(outline_width, (int, float)):
529+
if not isinstance(outline_width, int | float):
530530
raise TypeError(f"Invalid type of `outline_width`: {type(outline_width)}, expected `int` or `float`.")
531531
if outline_width == 0.0:
532532
outline = False
@@ -868,9 +868,9 @@ def _generate_base_categorial_color_mapping(
868868
na_color = to_hex(to_rgba(na_color)[:3])
869869

870870
if na_color and len(categories) > len(colors):
871-
return dict(zip(categories, colors + [na_color]))
871+
return dict(zip(categories, colors + [na_color], strict=True))
872872

873-
return dict(zip(categories, colors))
873+
return dict(zip(categories, colors, strict=True))
874874

875875
return _get_default_categorial_color_mapping(color_source_vector)
876876

@@ -887,7 +887,7 @@ def _modify_categorical_color_mapping(
887887
# subset base mapping to only those specified in groups
888888
modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"}
889889
elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list):
890-
modified_mapping = dict(zip(groups, palette))
890+
modified_mapping = dict(zip(groups, palette, strict=True))
891891
else:
892892
raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.")
893893

@@ -908,7 +908,10 @@ def _get_default_categorial_color_mapping(
908908
palette = ["grey" for _ in range(len_cat)]
909909
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
910910

911-
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(color_source_vector.categories, palette[:len_cat])}
911+
return {
912+
cat: to_hex(to_rgba(col)[:3])
913+
for cat, col in zip(color_source_vector.categories, palette[:len_cat], strict=True)
914+
}
912915

913916

914917
def _get_categorical_color_mapping(
@@ -1342,7 +1345,7 @@ def _multiscale_to_spatial_image(
13421345
optimal_index_x -= 1
13431346

13441347
# pick the scale with higher resolution (worst case: downscaled afterwards)
1345-
optimal_scale = scales[min(optimal_index_x, optimal_index_y)]
1348+
optimal_scale = scales[min(int(optimal_index_x), int(optimal_index_y))]
13461349

13471350
# NOTE: problematic if there are cases with > 1 data variable
13481351
data_var_keys = list(multiscale_image[optimal_scale].data_vars)
@@ -1412,12 +1415,12 @@ def _validate_show_parameters(
14121415
return_ax: bool,
14131416
save: str | Path | None,
14141417
) -> None:
1415-
if coordinate_systems is not None and not isinstance(coordinate_systems, (list, str)):
1418+
if coordinate_systems is not None and not isinstance(coordinate_systems, list | str):
14161419
raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.")
14171420

14181421
font_weights = ["light", "normal", "medium", "semibold", "bold", "heavy", "black"]
14191422
if legend_fontweight is not None and (
1420-
not isinstance(legend_fontweight, (int, str))
1423+
not isinstance(legend_fontweight, int | str)
14211424
or (isinstance(legend_fontweight, str) and legend_fontweight not in font_weights)
14221425
):
14231426
readable_font_weights = ", ".join(font_weights[:-1]) + ", or " + font_weights[-1]
@@ -1429,7 +1432,7 @@ def _validate_show_parameters(
14291432
font_sizes = ["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
14301433

14311434
if legend_fontsize is not None and (
1432-
not isinstance(legend_fontsize, (int, float, str))
1435+
not isinstance(legend_fontsize, int | float | str)
14331436
or (isinstance(legend_fontsize, str) and legend_fontsize not in font_sizes)
14341437
):
14351438
readable_font_sizes = ", ".join(font_sizes[:-1]) + ", or " + font_sizes[-1]
@@ -1471,22 +1474,22 @@ def _validate_show_parameters(
14711474
if fig is not None and not isinstance(fig, Figure):
14721475
raise TypeError("Parameter 'fig' must be a matplotlib.figure.Figure.")
14731476

1474-
if title is not None and not isinstance(title, (list, str)):
1477+
if title is not None and not isinstance(title, list | str):
14751478
raise TypeError("Parameter 'title' must be a string or a list of strings.")
14761479

14771480
if not isinstance(share_extent, bool):
14781481
raise TypeError("Parameter 'share_extent' must be a boolean.")
14791482

1480-
if not isinstance(pad_extent, (int, float)):
1483+
if not isinstance(pad_extent, int | float):
14811484
raise TypeError("Parameter 'pad_extent' must be numeric.")
14821485

1483-
if ax is not None and not isinstance(ax, (Axes, list)):
1486+
if ax is not None and not isinstance(ax, Axes | list):
14841487
raise TypeError("Parameter 'ax' must be a matplotlib.axes.Axes or a list of Axes.")
14851488

14861489
if not isinstance(return_ax, bool):
14871490
raise TypeError("Parameter 'return_ax' must be a boolean.")
14881491

1489-
if save is not None and not isinstance(save, (str, Path)):
1492+
if save is not None and not isinstance(save, str | Path):
14901493
raise TypeError("Parameter 'save' must be a string or a pathlib.Path.")
14911494

14921495

@@ -1505,10 +1508,10 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15051508
elif element_type == "shapes":
15061509
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys())
15071510

1508-
if (channel := param_dict.get("channel")) is not None and not isinstance(channel, (list, str, int)):
1511+
if (channel := param_dict.get("channel")) is not None and not isinstance(channel, list | str | int):
15091512
raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.")
15101513
if isinstance(channel, list):
1511-
if not all(isinstance(c, (str, int)) for c in channel):
1514+
if not all(isinstance(c, str | int) for c in channel):
15121515
raise TypeError("Each item in 'channel' list must be a string or an integer.")
15131516
if not all(isinstance(c, type(channel[0])) for c in channel):
15141517
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")
@@ -1533,27 +1536,27 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15331536
param_dict["col_for_color"] = None
15341537

15351538
if outline_width := param_dict.get("outline_width"):
1536-
if not isinstance(outline_width, (float, int)):
1539+
if not isinstance(outline_width, float | int):
15371540
raise TypeError("Parameter 'outline_width' must be numeric.")
15381541
if outline_width < 0:
15391542
raise ValueError("Parameter 'outline_width' cannot be negative.")
15401543

15411544
if (outline_alpha := param_dict.get("outline_alpha")) and (
1542-
not isinstance(outline_alpha, (float, int)) or not 0 <= outline_alpha <= 1
1545+
not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1
15431546
):
15441547
raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.")
15451548

15461549
if contour_px is not None and contour_px <= 0:
15471550
raise ValueError("Parameter 'contour_px' must be a positive number.")
15481551

15491552
if (alpha := param_dict.get("alpha")) is not None:
1550-
if not isinstance(alpha, (float, int)):
1553+
if not isinstance(alpha, float | int):
15511554
raise TypeError("Parameter 'alpha' must be numeric.")
15521555
if not 0 <= alpha <= 1:
15531556
raise ValueError("Parameter 'alpha' must be between 0 and 1.")
15541557

15551558
if (fill_alpha := param_dict.get("fill_alpha")) is not None:
1556-
if not isinstance(fill_alpha, (float, int)):
1559+
if not isinstance(fill_alpha, float | int):
15571560
raise TypeError("Parameter 'fill_alpha' must be numeric.")
15581561
if fill_alpha < 0:
15591562
raise ValueError("Parameter 'fill_alpha' cannot be negative.")
@@ -1563,7 +1566,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15631566
param_dict["cmap"] = cmap
15641567

15651568
if (groups := param_dict.get("groups")) is not None:
1566-
if not isinstance(groups, (list, str)):
1569+
if not isinstance(groups, list | str):
15671570
raise TypeError("Parameter 'groups' must be a string or a list of strings.")
15681571
if isinstance(groups, str):
15691572
param_dict["groups"] = [groups]
@@ -1575,7 +1578,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15751578
if isinstance((palette := param_dict["palette"]), list):
15761579
if not all(isinstance(p, str) for p in palette):
15771580
raise ValueError("If specified, parameter 'palette' must contain only strings.")
1578-
elif isinstance(palette, (str, type(None))) and "palette" in param_dict:
1581+
elif isinstance(palette, str | type(None)) and "palette" in param_dict:
15791582
param_dict["palette"] = [palette] if palette is not None else None
15801583

15811584
if element_type in ["shapes", "points", "labels"] and (palette := param_dict.get("palette")) is not None:
@@ -1589,9 +1592,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15891592
)
15901593

15911594
if isinstance(cmap, list):
1592-
if not all(isinstance(c, (Colormap, str)) for c in cmap):
1595+
if not all(isinstance(c, Colormap | str) for c in cmap):
15931596
raise TypeError("Each item in 'cmap' list must be a string or a Colormap.")
1594-
elif isinstance(cmap, (Colormap, str, type(None))):
1597+
elif isinstance(cmap, Colormap | str | type(None)):
15951598
if "cmap" in param_dict:
15961599
param_dict["cmap"] = [cmap] if cmap is not None else None
15971600
else:
@@ -1605,20 +1608,20 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
16051608
if (norm := param_dict.get("norm")) is not None:
16061609
if element_type in ["images", "labels"] and not isinstance(norm, Normalize):
16071610
raise TypeError("Parameter 'norm' must be of type Normalize.")
1608-
if element_type in ["shapes", "points"] and not isinstance(norm, (bool, Normalize)):
1611+
if element_type in ["shapes", "points"] and not isinstance(norm, bool | Normalize):
16091612
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")
16101613

16111614
if (scale := param_dict.get("scale")) is not None:
16121615
if element_type in ["images", "labels"] and not isinstance(scale, str):
16131616
raise TypeError("Parameter 'scale' must be a string if specified.")
16141617
if element_type == "shapes":
1615-
if not isinstance(scale, (float, int)):
1618+
if not isinstance(scale, float | int):
16161619
raise TypeError("Parameter 'scale' must be numeric.")
16171620
if scale < 0:
16181621
raise ValueError("Parameter 'scale' must be a positive number.")
16191622

16201623
if size := param_dict.get("size"):
1621-
if not isinstance(size, (float, int)):
1624+
if not isinstance(size, float | int):
16221625
raise TypeError("Parameter 'size' must be numeric.")
16231626
if size < 0:
16241627
raise ValueError("Parameter 'size' must be a positive number.")
@@ -1968,7 +1971,7 @@ def _is_coercable_to_float(series: pd.Series) -> bool:
19681971

19691972

19701973
def _ax_show_and_transform(
1971-
array: MaskedArray[np.float64, Any],
1974+
array: MaskedArray[tuple[int, ...], Any],
19721975
trans_data: CompositeGenericTransform,
19731976
ax: Axes,
19741977
alpha: float | None = None,
@@ -2052,7 +2055,7 @@ def _get_extent_and_range_for_datashader_canvas(
20522055

20532056
def _create_image_from_datashader_result(
20542057
ds_result: ds.transfer_functions.Image, factor: float, ax: Axes
2055-
) -> tuple[MaskedArray[np.float64, Any], matplotlib.transforms.CompositeGenericTransform]:
2058+
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.CompositeGenericTransform]:
20562059
# create SpatialImage from datashader output to get it back to original size
20572060
rgba_image_data = ds_result.to_numpy().base
20582061
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))

tests/conftest.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, ABCMeta
2+
from collections.abc import Callable
23
from functools import wraps
34
from pathlib import Path
4-
from typing import Callable, Optional, Union
55

66
import matplotlib.pyplot as plt
77
import numpy as np
@@ -216,7 +216,7 @@ def sdata(request) -> SpatialData:
216216
return s
217217

218218

219-
def _get_images() -> dict[str, Union[DataArray, DataTree]]:
219+
def _get_images() -> dict[str, DataArray | DataTree]:
220220
out = {}
221221
dims_2d = ("c", "y", "x")
222222
dims_3d = ("z", "y", "x", "c")
@@ -243,7 +243,7 @@ def _get_images() -> dict[str, Union[DataArray, DataTree]]:
243243
return out
244244

245245

246-
def _get_labels() -> dict[str, Union[DataArray, DataTree]]:
246+
def _get_labels() -> dict[str, DataArray | DataTree]:
247247
out = {}
248248
dims_2d = ("y", "x")
249249
dims_3d = ("z", "y", "x")
@@ -344,9 +344,9 @@ def _get_points() -> dict[str, pa.Table]:
344344

345345

346346
def _get_table(
347-
region: Optional[AnnData] = None,
348-
region_key: Optional[str] = None,
349-
instance_key: Optional[str] = None,
347+
region: AnnData | None = None,
348+
region_key: str | None = None,
349+
instance_key: str | None = None,
350350
) -> AnnData:
351351
region_key = region_key or "annotated_region"
352352
instance_key = instance_key or "instance_id"
@@ -374,7 +374,7 @@ def __new__(cls, clsname, superclasses, attributedict):
374374

375375
class PlotTester(ABC): # noqa: B024
376376
@classmethod
377-
def compare(cls, basename: str, tolerance: Optional[float] = None):
377+
def compare(cls, basename: str, tolerance: float | None = None):
378378
ACTUAL.mkdir(parents=True, exist_ok=True)
379379
out_path = ACTUAL / f"{basename}.png"
380380

@@ -397,7 +397,7 @@ def compare(cls, basename: str, tolerance: Optional[float] = None):
397397
assert res is None, res
398398

399399

400-
def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable:
400+
def _decorate(fn: Callable, clsname: str, name: str | None = None) -> Callable:
401401
@wraps(fn)
402402
def save_and_compare(self, *args, **kwargs):
403403
fn(self, *args, **kwargs)

tests/pl/test_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Union
2-
31
import matplotlib
42
import matplotlib.pyplot as plt
53
import numpy as np
@@ -28,7 +26,7 @@
2826
# replace with
2927
# from spatialdata._types import ColorLike
3028
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
31-
ColorLike = Union[tuple[float, ...], str]
29+
ColorLike = tuple[float, ...] | str
3230

3331

3432
class TestUtils(PlotTester, metaclass=PlotTesterMeta):

0 commit comments

Comments
 (0)