Skip to content

Commit 01fc627

Browse files
support for datatree from xarray (#380)
2 parents 0667ef8 + f069e78 commit 01fc627

File tree

9 files changed

+60
-63
lines changed

9 files changed

+60
-63
lines changed

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
strategy:
1919
fail-fast: false
2020
matrix:
21-
python: ["3.9", "3.10"]
21+
python: ["3.10", "3.12"]
2222
os: [ubuntu-latest]
2323

2424
env:
@@ -49,7 +49,7 @@ jobs:
4949
pip install pytest-cov
5050
- name: Install dependencies
5151
run: |
52-
pip install --pre -e ".[dev,test]"
52+
pip install --pre -e ".[dev,test,pre]"
5353
- name: Test
5454
env:
5555
MPLBACKEND: agg

.mypy.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[mypy]
2-
python_version = 3.9
2+
python_version = 3.10
33
plugins = numpy.typing.mypy_plugin
44

55
ignore_errors = False
@@ -25,4 +25,4 @@ no_warn_no_return = True
2525

2626
show_error_codes = True
2727
show_column_numbers = True
28-
error_summary = True
28+
error_summary = True

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ maintainers = [
1414
urls.Documentation = "https://spatialdata.scverse.org/projects/plot/en/latest/index.html"
1515
urls.Source = "https://github.com/scverse/spatialdata-plot.git"
1616
urls.Home-page = "https://github.com/scverse/spatialdata-plot.git"
17-
requires-python = ">=3.9"
17+
requires-python = ">=3.10"
1818
dynamic= [
1919
"version" # allow version to be set by git tags
2020
]
@@ -75,7 +75,7 @@ filterwarnings = [
7575

7676
[tool.black]
7777
line-length = 120
78-
target-version = ['py39']
78+
target-version = ['py310']
7979
include = '\.pyi?$'
8080
exclude = '''
8181
(
@@ -158,7 +158,7 @@ lint.select = [
158158
"PGH", # pygrep-hooks
159159
]
160160
lint.unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
161-
target-version = "py39"
161+
target-version = "py310"
162162
[tool.ruff.lint.per-file-ignores]
163163
"tests/*" = ["D", "PT", "B024"]
164164
"*/__init__.py" = ["F401", "D104", "D107", "E402"]

src/spatialdata_plot/pl/basic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import OrderedDict
66
from copy import deepcopy
77
from pathlib import Path
8-
from typing import Any, Union
8+
from typing import Any
99

1010
import matplotlib.pyplot as plt
1111
import numpy as np
@@ -14,14 +14,13 @@
1414
import spatialdata as sd
1515
from anndata import AnnData
1616
from dask.dataframe import DataFrame as DaskDataFrame
17-
from datatree import DataTree
1817
from geopandas import GeoDataFrame
1918
from matplotlib.axes import Axes
2019
from matplotlib.colors import Colormap, Normalize
2120
from matplotlib.figure import Figure
2221
from spatialdata import get_extent
2322
from spatialdata._utils import _deprecation_alias
24-
from xarray import DataArray
23+
from xarray import DataArray, DataTree
2524

2625
from spatialdata_plot._accessor import register_spatial_data_accessor
2726
from spatialdata_plot.pl.render import (
@@ -62,7 +61,7 @@
6261
# replace with
6362
# from spatialdata._types import ColorLike
6463
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
65-
ColorLike = Union[tuple[float, ...], str]
64+
ColorLike = tuple[float, ...] | str
6665

6766

6867
@register_spatial_data_accessor("pl")
@@ -950,7 +949,7 @@ def show(
950949
if wanted_labels_on_this_cs:
951950
if (table := params_copy.table_name) is not None:
952951
colors = sc.get.obs_df(sdata[table], params_copy.color)
953-
if isinstance(colors.dtype, pd.CategoricalDtype):
952+
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
954953
_maybe_set_colors(
955954
source=sdata[table],
956955
target=sdata[table],

src/spatialdata_plot/pl/render.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import warnings
44
from collections import abc
55
from copy import copy
6-
from typing import Union
76

87
import dask
98
import datashader as ds
@@ -15,7 +14,6 @@
1514
import scanpy as sc
1615
import spatialdata as sd
1716
from anndata import AnnData
18-
from datatree import DataTree
1917
from matplotlib.cm import ScalarMappable
2018
from matplotlib.colors import ListedColormap, Normalize
2119
from scanpy._settings import settings as sc_settings
@@ -24,6 +22,7 @@
2422
from spatialdata.transformations import (
2523
set_transformation,
2624
)
25+
from xarray import DataTree
2726

2827
from spatialdata_plot._logging import logger
2928
from spatialdata_plot.pl.render_params import (
@@ -56,7 +55,7 @@
5655
to_hex,
5756
)
5857

59-
_Normalize = Union[Normalize, abc.Sequence[Normalize]]
58+
_Normalize = Normalize | abc.Sequence[Normalize]
6059

6160

6261
def _render_shapes(
@@ -442,7 +441,7 @@ def _render_points(
442441
if col_for_color is not None:
443442
cols = sc.get.obs_df(adata, col_for_color)
444443
# maybe set color based on type
445-
if isinstance(cols.dtype, pd.CategoricalDtype):
444+
if isinstance(cols[col_for_color].dtype, pd.CategoricalDtype):
446445
_maybe_set_colors(
447446
source=adata,
448447
target=adata,

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable, Sequence
44
from dataclasses import dataclass
5-
from typing import Literal, Union
5+
from typing import Literal
66

77
from matplotlib.axes import Axes
88
from matplotlib.colors import Colormap, ListedColormap, Normalize
@@ -14,7 +14,7 @@
1414
# replace with
1515
# from spatialdata._types import ColorLike
1616
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
17-
ColorLike = Union[tuple[float, ...], str]
17+
ColorLike = tuple[float, ...] | str
1818

1919

2020
@dataclass

0 commit comments

Comments
 (0)