8
8
from functools import partial
9
9
from pathlib import Path
10
10
from types import MappingProxyType
11
- from typing import Any , Literal , Union
11
+ from typing import Any , Literal
12
12
13
13
import dask
14
14
import datashader as ds
81
81
# replace with
82
82
# from spatialdata._types import ColorLike
83
83
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
84
- ColorLike = Union [ tuple [float , ...], str ]
84
+ ColorLike = tuple [float , ...] | str
85
85
86
86
87
87
def _verify_plotting_tree (sdata : SpatialData ) -> SpatialData :
@@ -526,7 +526,7 @@ def _set_outline(
526
526
outline_color : str | list [float ] = "#0000000ff" , # black, white
527
527
** kwargs : Any ,
528
528
) -> OutlineParams :
529
- if not isinstance (outline_width , ( int , float ) ):
529
+ if not isinstance (outline_width , int | float ):
530
530
raise TypeError (f"Invalid type of `outline_width`: { type (outline_width )} , expected `int` or `float`." )
531
531
if outline_width == 0.0 :
532
532
outline = False
@@ -868,9 +868,9 @@ def _generate_base_categorial_color_mapping(
868
868
na_color = to_hex (to_rgba (na_color )[:3 ])
869
869
870
870
if na_color and len (categories ) > len (colors ):
871
- return dict (zip (categories , colors + [na_color ]))
871
+ return dict (zip (categories , colors + [na_color ], strict = True ))
872
872
873
- return dict (zip (categories , colors ))
873
+ return dict (zip (categories , colors , strict = True ))
874
874
875
875
return _get_default_categorial_color_mapping (color_source_vector )
876
876
@@ -887,7 +887,7 @@ def _modify_categorical_color_mapping(
887
887
# subset base mapping to only those specified in groups
888
888
modified_mapping = {key : mapping [key ] for key in mapping if key in groups or key == "NaN" }
889
889
elif len (palette ) == len (groups ) and isinstance (groups , list ) and isinstance (palette , list ):
890
- modified_mapping = dict (zip (groups , palette ))
890
+ modified_mapping = dict (zip (groups , palette , strict = True ))
891
891
else :
892
892
raise ValueError (f"Expected palette to be of length `{ len (groups )} `, found `{ len (palette )} `." )
893
893
@@ -908,7 +908,10 @@ def _get_default_categorial_color_mapping(
908
908
palette = ["grey" for _ in range (len_cat )]
909
909
logger .info ("input has more than 103 categories. Uniform 'grey' color will be used for all categories." )
910
910
911
- return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (color_source_vector .categories , palette [:len_cat ])}
911
+ return {
912
+ cat : to_hex (to_rgba (col )[:3 ])
913
+ for cat , col in zip (color_source_vector .categories , palette [:len_cat ], strict = True )
914
+ }
912
915
913
916
914
917
def _get_categorical_color_mapping (
@@ -1342,7 +1345,7 @@ def _multiscale_to_spatial_image(
1342
1345
optimal_index_x -= 1
1343
1346
1344
1347
# pick the scale with higher resolution (worst case: downscaled afterwards)
1345
- optimal_scale = scales [min (optimal_index_x , optimal_index_y )]
1348
+ optimal_scale = scales [min (int ( optimal_index_x ), int ( optimal_index_y ) )]
1346
1349
1347
1350
# NOTE: problematic if there are cases with > 1 data variable
1348
1351
data_var_keys = list (multiscale_image [optimal_scale ].data_vars )
@@ -1412,12 +1415,12 @@ def _validate_show_parameters(
1412
1415
return_ax : bool ,
1413
1416
save : str | Path | None ,
1414
1417
) -> None :
1415
- if coordinate_systems is not None and not isinstance (coordinate_systems , ( list , str ) ):
1418
+ if coordinate_systems is not None and not isinstance (coordinate_systems , list | str ):
1416
1419
raise TypeError ("Parameter 'coordinate_systems' must be a string or a list of strings." )
1417
1420
1418
1421
font_weights = ["light" , "normal" , "medium" , "semibold" , "bold" , "heavy" , "black" ]
1419
1422
if legend_fontweight is not None and (
1420
- not isinstance (legend_fontweight , ( int , str ) )
1423
+ not isinstance (legend_fontweight , int | str )
1421
1424
or (isinstance (legend_fontweight , str ) and legend_fontweight not in font_weights )
1422
1425
):
1423
1426
readable_font_weights = ", " .join (font_weights [:- 1 ]) + ", or " + font_weights [- 1 ]
@@ -1429,7 +1432,7 @@ def _validate_show_parameters(
1429
1432
font_sizes = ["xx-small" , "x-small" , "small" , "medium" , "large" , "x-large" , "xx-large" ]
1430
1433
1431
1434
if legend_fontsize is not None and (
1432
- not isinstance (legend_fontsize , ( int , float , str ) )
1435
+ not isinstance (legend_fontsize , int | float | str )
1433
1436
or (isinstance (legend_fontsize , str ) and legend_fontsize not in font_sizes )
1434
1437
):
1435
1438
readable_font_sizes = ", " .join (font_sizes [:- 1 ]) + ", or " + font_sizes [- 1 ]
@@ -1471,22 +1474,22 @@ def _validate_show_parameters(
1471
1474
if fig is not None and not isinstance (fig , Figure ):
1472
1475
raise TypeError ("Parameter 'fig' must be a matplotlib.figure.Figure." )
1473
1476
1474
- if title is not None and not isinstance (title , ( list , str ) ):
1477
+ if title is not None and not isinstance (title , list | str ):
1475
1478
raise TypeError ("Parameter 'title' must be a string or a list of strings." )
1476
1479
1477
1480
if not isinstance (share_extent , bool ):
1478
1481
raise TypeError ("Parameter 'share_extent' must be a boolean." )
1479
1482
1480
- if not isinstance (pad_extent , ( int , float ) ):
1483
+ if not isinstance (pad_extent , int | float ):
1481
1484
raise TypeError ("Parameter 'pad_extent' must be numeric." )
1482
1485
1483
- if ax is not None and not isinstance (ax , ( Axes , list ) ):
1486
+ if ax is not None and not isinstance (ax , Axes | list ):
1484
1487
raise TypeError ("Parameter 'ax' must be a matplotlib.axes.Axes or a list of Axes." )
1485
1488
1486
1489
if not isinstance (return_ax , bool ):
1487
1490
raise TypeError ("Parameter 'return_ax' must be a boolean." )
1488
1491
1489
- if save is not None and not isinstance (save , ( str , Path ) ):
1492
+ if save is not None and not isinstance (save , str | Path ):
1490
1493
raise TypeError ("Parameter 'save' must be a string or a pathlib.Path." )
1491
1494
1492
1495
@@ -1505,10 +1508,10 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1505
1508
elif element_type == "shapes" :
1506
1509
param_dict ["element" ] = [element ] if element is not None else list (param_dict ["sdata" ].shapes .keys ())
1507
1510
1508
- if (channel := param_dict .get ("channel" )) is not None and not isinstance (channel , ( list , str , int ) ):
1511
+ if (channel := param_dict .get ("channel" )) is not None and not isinstance (channel , list | str | int ):
1509
1512
raise TypeError ("Parameter 'channel' must be a string, an integer, or a list of strings or integers." )
1510
1513
if isinstance (channel , list ):
1511
- if not all (isinstance (c , ( str , int ) ) for c in channel ):
1514
+ if not all (isinstance (c , str | int ) for c in channel ):
1512
1515
raise TypeError ("Each item in 'channel' list must be a string or an integer." )
1513
1516
if not all (isinstance (c , type (channel [0 ])) for c in channel ):
1514
1517
raise TypeError ("Each item in 'channel' list must be of the same type, either string or integer." )
@@ -1533,27 +1536,27 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1533
1536
param_dict ["col_for_color" ] = None
1534
1537
1535
1538
if outline_width := param_dict .get ("outline_width" ):
1536
- if not isinstance (outline_width , ( float , int ) ):
1539
+ if not isinstance (outline_width , float | int ):
1537
1540
raise TypeError ("Parameter 'outline_width' must be numeric." )
1538
1541
if outline_width < 0 :
1539
1542
raise ValueError ("Parameter 'outline_width' cannot be negative." )
1540
1543
1541
1544
if (outline_alpha := param_dict .get ("outline_alpha" )) and (
1542
- not isinstance (outline_alpha , ( float , int ) ) or not 0 <= outline_alpha <= 1
1545
+ not isinstance (outline_alpha , float | int ) or not 0 <= outline_alpha <= 1
1543
1546
):
1544
1547
raise TypeError ("Parameter 'outline_alpha' must be numeric and between 0 and 1." )
1545
1548
1546
1549
if contour_px is not None and contour_px <= 0 :
1547
1550
raise ValueError ("Parameter 'contour_px' must be a positive number." )
1548
1551
1549
1552
if (alpha := param_dict .get ("alpha" )) is not None :
1550
- if not isinstance (alpha , ( float , int ) ):
1553
+ if not isinstance (alpha , float | int ):
1551
1554
raise TypeError ("Parameter 'alpha' must be numeric." )
1552
1555
if not 0 <= alpha <= 1 :
1553
1556
raise ValueError ("Parameter 'alpha' must be between 0 and 1." )
1554
1557
1555
1558
if (fill_alpha := param_dict .get ("fill_alpha" )) is not None :
1556
- if not isinstance (fill_alpha , ( float , int ) ):
1559
+ if not isinstance (fill_alpha , float | int ):
1557
1560
raise TypeError ("Parameter 'fill_alpha' must be numeric." )
1558
1561
if fill_alpha < 0 :
1559
1562
raise ValueError ("Parameter 'fill_alpha' cannot be negative." )
@@ -1563,7 +1566,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1563
1566
param_dict ["cmap" ] = cmap
1564
1567
1565
1568
if (groups := param_dict .get ("groups" )) is not None :
1566
- if not isinstance (groups , ( list , str ) ):
1569
+ if not isinstance (groups , list | str ):
1567
1570
raise TypeError ("Parameter 'groups' must be a string or a list of strings." )
1568
1571
if isinstance (groups , str ):
1569
1572
param_dict ["groups" ] = [groups ]
@@ -1575,7 +1578,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1575
1578
if isinstance ((palette := param_dict ["palette" ]), list ):
1576
1579
if not all (isinstance (p , str ) for p in palette ):
1577
1580
raise ValueError ("If specified, parameter 'palette' must contain only strings." )
1578
- elif isinstance (palette , ( str , type (None ) )) and "palette" in param_dict :
1581
+ elif isinstance (palette , str | type (None )) and "palette" in param_dict :
1579
1582
param_dict ["palette" ] = [palette ] if palette is not None else None
1580
1583
1581
1584
if element_type in ["shapes" , "points" , "labels" ] and (palette := param_dict .get ("palette" )) is not None :
@@ -1589,9 +1592,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1589
1592
)
1590
1593
1591
1594
if isinstance (cmap , list ):
1592
- if not all (isinstance (c , ( Colormap , str ) ) for c in cmap ):
1595
+ if not all (isinstance (c , Colormap | str ) for c in cmap ):
1593
1596
raise TypeError ("Each item in 'cmap' list must be a string or a Colormap." )
1594
- elif isinstance (cmap , ( Colormap , str , type (None ) )):
1597
+ elif isinstance (cmap , Colormap | str | type (None )):
1595
1598
if "cmap" in param_dict :
1596
1599
param_dict ["cmap" ] = [cmap ] if cmap is not None else None
1597
1600
else :
@@ -1605,20 +1608,20 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1605
1608
if (norm := param_dict .get ("norm" )) is not None :
1606
1609
if element_type in ["images" , "labels" ] and not isinstance (norm , Normalize ):
1607
1610
raise TypeError ("Parameter 'norm' must be of type Normalize." )
1608
- if element_type in ["shapes" , "points" ] and not isinstance (norm , ( bool , Normalize ) ):
1611
+ if element_type in ["shapes" , "points" ] and not isinstance (norm , bool | Normalize ):
1609
1612
raise TypeError ("Parameter 'norm' must be a boolean or a mpl.Normalize." )
1610
1613
1611
1614
if (scale := param_dict .get ("scale" )) is not None :
1612
1615
if element_type in ["images" , "labels" ] and not isinstance (scale , str ):
1613
1616
raise TypeError ("Parameter 'scale' must be a string if specified." )
1614
1617
if element_type == "shapes" :
1615
- if not isinstance (scale , ( float , int ) ):
1618
+ if not isinstance (scale , float | int ):
1616
1619
raise TypeError ("Parameter 'scale' must be numeric." )
1617
1620
if scale < 0 :
1618
1621
raise ValueError ("Parameter 'scale' must be a positive number." )
1619
1622
1620
1623
if size := param_dict .get ("size" ):
1621
- if not isinstance (size , ( float , int ) ):
1624
+ if not isinstance (size , float | int ):
1622
1625
raise TypeError ("Parameter 'size' must be numeric." )
1623
1626
if size < 0 :
1624
1627
raise ValueError ("Parameter 'size' must be a positive number." )
@@ -1968,7 +1971,7 @@ def _is_coercable_to_float(series: pd.Series) -> bool:
1968
1971
1969
1972
1970
1973
def _ax_show_and_transform (
1971
- array : MaskedArray [np . float64 , Any ],
1974
+ array : MaskedArray [tuple [ int , ...] , Any ],
1972
1975
trans_data : CompositeGenericTransform ,
1973
1976
ax : Axes ,
1974
1977
alpha : float | None = None ,
@@ -2052,7 +2055,7 @@ def _get_extent_and_range_for_datashader_canvas(
2052
2055
2053
2056
def _create_image_from_datashader_result (
2054
2057
ds_result : ds .transfer_functions .Image , factor : float , ax : Axes
2055
- ) -> tuple [MaskedArray [np . float64 , Any ], matplotlib .transforms .CompositeGenericTransform ]:
2058
+ ) -> tuple [MaskedArray [tuple [ int , ...] , Any ], matplotlib .transforms .CompositeGenericTransform ]:
2056
2059
# create SpatialImage from datashader output to get it back to original size
2057
2060
rgba_image_data = ds_result .to_numpy ().base
2058
2061
rgba_image_data = np .transpose (rgba_image_data , (2 , 0 , 1 ))
0 commit comments