Skip to content

Commit b84c1ab

Browse files
authored
ruff fixes (#102)
1 parent 18a15d8 commit b84c1ab

File tree

4 files changed

+27
-27
lines changed

4 files changed

+27
-27
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import OrderedDict
55
from collections.abc import Sequence
66
from pathlib import Path
7-
from typing import Any, Optional, Union
7+
from typing import Any
88

99
import matplotlib.pyplot as plt
1010
import scanpy as sc
@@ -82,11 +82,11 @@ def __init__(self, sdata: sd.SpatialData) -> None:
8282

8383
def _copy(
8484
self,
85-
images: Union[None, dict[str, Union[SpatialImage, MultiscaleSpatialImage]]] = None,
86-
labels: Union[None, dict[str, Union[SpatialImage, MultiscaleSpatialImage]]] = None,
87-
points: Union[None, dict[str, DaskDataFrame]] = None,
88-
shapes: Union[None, dict[str, GeoDataFrame]] = None,
89-
table: Union[None, AnnData] = None,
85+
images: None | dict[str, SpatialImage | MultiscaleSpatialImage] = None,
86+
labels: None | dict[str, SpatialImage | MultiscaleSpatialImage] = None,
87+
points: None | dict[str, DaskDataFrame] = None,
88+
shapes: None | dict[str, GeoDataFrame] = None,
89+
table: None | AnnData = None,
9090
) -> sd.SpatialData:
9191
"""Copy the current `SpatialData` object, optionally modifying some of its attributes.
9292
@@ -150,7 +150,7 @@ def render_shapes(
150150
layer: str | None = None,
151151
palette: Palette_t = None,
152152
cmap: Colormap | str | None = None,
153-
norm: Optional[Normalize] = None,
153+
norm: None | Normalize = None,
154154
na_color: str | tuple[float, ...] | None = "lightgrey",
155155
outline_alpha: float = 1.0,
156156
fill_alpha: float = 1.0,
@@ -232,7 +232,7 @@ def render_points(
232232
size: float = 1.0,
233233
palette: Palette_t = None,
234234
cmap: Colormap | str | None = None,
235-
norm: Optional[Normalize] = None,
235+
norm: None | Normalize = None,
236236
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
237237
alpha: float = 1.0,
238238
**kwargs: Any,
@@ -296,7 +296,7 @@ def render_images(
296296
elements: str | list[str] | None = None,
297297
channel: list[str] | list[int] | int | str | None = None,
298298
cmap: Colormap | str | None = None,
299-
norm: Optional[Normalize] = None,
299+
norm: None | Normalize = None,
300300
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
301301
palette: Palette_t = None,
302302
alpha: float = 1.0,
@@ -358,7 +358,7 @@ def render_labels(
358358
layer: str | None = None,
359359
palette: Palette_t = None,
360360
cmap: Colormap | str | None = None,
361-
norm: Optional[Normalize] = None,
361+
norm: None | Normalize = None,
362362
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
363363
outline_alpha: float = 1.0,
364364
fill_alpha: float = 0.3,
@@ -452,11 +452,11 @@ def show(
452452
figsize: tuple[float, float] | None = None,
453453
dpi: int | None = None,
454454
fig: Figure | None = None,
455-
title: Optional[Union[str, Sequence[str]]] = None,
455+
title: None | str | Sequence[str] = None,
456456
share_extent: bool = True,
457457
ax: Axes | Sequence[Axes] | None = None,
458458
return_ax: bool = False,
459-
save: Optional[Union[str, Path]] = None,
459+
save: None | str | Path = None,
460460
) -> sd.SpatialData:
461461
"""
462462
Plot the images in the SpatialData object.
@@ -517,7 +517,7 @@ def show(
517517
if isinstance(title, str):
518518
title = [title]
519519

520-
if not all([isinstance(t, str) for t in title]):
520+
if not all(isinstance(t, str) for t in title):
521521
raise TypeError("All titles must be strings.")
522522

523523
# Simplicstic solution: If the images are multiscale, just use the first

src/spatialdata_plot/pl/render.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def _get_collection_shape(
114114
c: Any,
115115
s: float,
116116
norm: Any,
117-
fill_alpha: Optional[float] = None,
118-
outline_alpha: Optional[float] = None,
117+
fill_alpha: None | float = None,
118+
outline_alpha: None | float = None,
119119
**kwargs: Any,
120120
) -> PatchCollection:
121121
patches = []

src/spatialdata_plot/pl/utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _prepare_params_plot(
109109
fig, grid = _panel_grid(
110110
num_panels=num_panels, hspace=hspace, wspace=wspace, ncols=ncols, dpi=dpi, figsize=figsize
111111
)
112-
axs: Union[Sequence[Axes], None] = [plt.subplot(grid[c]) for c in range(num_panels)]
112+
axs: None | Sequence[Axes] = [plt.subplot(grid[c]) for c in range(num_panels)]
113113
elif num_panels > 1 and ax is not None:
114114
if len(ax) != num_panels:
115115
raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
@@ -148,10 +148,10 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
148148

149149
for cs_name, element_ids in cs_mapping.items():
150150
# determine if coordinate system has the respective elements
151-
cs_has_images = bool(any([(e in sdata.images) for e in element_ids]))
152-
cs_has_labels = bool(any([(e in sdata.labels) for e in element_ids]))
153-
cs_has_points = bool(any([(e in sdata.points) for e in element_ids]))
154-
cs_has_shapes = bool(any([(e in sdata.shapes) for e in element_ids]))
151+
cs_has_images = bool(any((e in sdata.images) for e in element_ids))
152+
cs_has_labels = bool(any((e in sdata.labels) for e in element_ids))
153+
cs_has_points = bool(any((e in sdata.points) for e in element_ids))
154+
cs_has_shapes = bool(any((e in sdata.shapes) for e in element_ids))
155155

156156
cs_contents = pd.concat(
157157
[
@@ -178,7 +178,7 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
178178

179179
def _get_extent(
180180
sdata: sd.SpatialData,
181-
coordinate_systems: Optional[Union[str, Sequence[str]]] = None,
181+
coordinate_systems: None | str | Sequence[str] = None,
182182
has_images: bool = True,
183183
has_labels: bool = True,
184184
has_points: bool = True,
@@ -477,7 +477,7 @@ class OutlineParams:
477477
gap_size: float
478478
gap_color: str
479479
bg_size: float
480-
bg_color: Union[str, tuple[float, ...]]
480+
bg_color: str | tuple[float, ...]
481481

482482

483483
def _set_outline(
@@ -501,7 +501,7 @@ def _set_outline(
501501
return OutlineParams(outline, gap_size, gap_color, bg_size, bg_color)
502502

503503

504-
def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> Union[plt.Figure, plt.Axes]:
504+
def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> plt.Figure | plt.Axes:
505505
"""Set up the axs objects.
506506
507507
Parameters
@@ -627,7 +627,7 @@ def _normalize(
627627
return norm
628628

629629

630-
def _get_colors_for_categorical_obs(categories: Sequence[Union[str, int]], palette: Palette_t = None) -> list[str]:
630+
def _get_colors_for_categorical_obs(categories: Sequence[str | int], palette: Palette_t = None) -> list[str]:
631631
"""
632632
Return a list of colors for a categorical observation.
633633
@@ -768,7 +768,7 @@ def _map_color_seg(
768768
def _get_palette(
769769
categories: Sequence[Any],
770770
adata: AnnData | None = None,
771-
cluster_key: Optional[str] | None = None,
771+
cluster_key: None | str = None,
772772
palette: Palette_t = None,
773773
alpha: float = 1.0,
774774
) -> Mapping[str, str] | None:
@@ -1079,7 +1079,7 @@ def _flatten_transformation_sequence(
10791079
transformations = list(transformation_sequence.transformations)
10801080
found_bottom_of_tree = False
10811081
while not found_bottom_of_tree:
1082-
if all([not isinstance(t, sd.transformations.transformations.Sequence) for t in transformations]):
1082+
if all(not isinstance(t, sd.transformations.transformations.Sequence) for t in transformations):
10831083
found_bottom_of_tree = True
10841084
else:
10851085
for idx, t in enumerate(transformations):

src/spatialdata_plot/pp/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def get_elements(self, elements: Union[str, list[str]]) -> sd.SpatialData:
116116
if not isinstance(elements, (str, list)):
117117
raise TypeError("Parameter 'elements' must be a string or a list of strings.")
118118

119-
if not all([isinstance(e, str) for e in elements]):
119+
if not all(isinstance(e, str) for e in elements):
120120
raise TypeError("When parameter 'elements' is a list, all elements must be strings.")
121121

122122
if isinstance(elements, str):

0 commit comments

Comments
 (0)