4
4
from copy import copy
5
5
from typing import Union
6
6
7
+ import dask
7
8
import geopandas as gpd
8
9
import matplotlib
9
10
import numpy as np
18
19
from spatialdata .models import (
19
20
Image2DModel ,
20
21
Labels2DModel ,
22
+ PointsModel ,
21
23
)
22
24
23
25
from spatialdata_plot ._logging import logger
@@ -57,6 +59,12 @@ def _render_shapes(
57
59
) -> None :
58
60
elements = render_params .elements
59
61
62
+ if render_params .groups is not None :
63
+ if isinstance (render_params .groups , str ):
64
+ render_params .groups = [render_params .groups ]
65
+ if not all (isinstance (g , str ) for g in render_params .groups ):
66
+ raise TypeError ("All groups must be strings." )
67
+
60
68
sdata_filt = sdata .filter_by_coordinate_system (
61
69
coordinate_system = coordinate_system ,
62
70
filter_table = sdata .table is not None ,
@@ -68,7 +76,6 @@ def _render_shapes(
68
76
elements = list (sdata_filt .shapes .keys ())
69
77
70
78
for e in elements :
71
- # shapes = [sdata.shapes[e] for e in elements]
72
79
shapes = sdata .shapes [e ]
73
80
n_shapes = sum ([len (s ) for s in shapes ])
74
81
@@ -88,6 +95,7 @@ def _render_shapes(
88
95
palette = render_params .palette ,
89
96
na_color = render_params .cmap_params .na_color ,
90
97
alpha = render_params .fill_alpha ,
98
+ cmap_params = render_params .cmap_params ,
91
99
)
92
100
93
101
values_are_categorical = color_source_vector is not None
@@ -101,7 +109,15 @@ def _render_shapes(
101
109
if len (color_vector ) == 0 :
102
110
color_vector = [render_params .cmap_params .na_color ]
103
111
112
+ # filter by `groups`
113
+ if render_params .groups is not None and color_source_vector is not None :
114
+ mask = color_source_vector .isin (render_params .groups )
115
+ shapes = shapes [mask ]
116
+ shapes = shapes .reset_index ()
117
+ color_source_vector = color_source_vector [mask ]
118
+ color_vector = color_vector [mask ]
104
119
shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
120
+
105
121
_cax = _get_collection_shape (
106
122
shapes = shapes ,
107
123
s = render_params .scale ,
@@ -122,9 +138,12 @@ def _render_shapes(
122
138
cax = ax .add_collection (_cax )
123
139
124
140
# Using dict.fromkeys here since set returns in arbitrary order
125
- palette = (
126
- ListedColormap (dict .fromkeys (color_vector )) if render_params .palette is None else render_params .palette
127
- )
141
+ # remove the color of NaN values, else it might be assigned to a category
142
+ # order of color in the palette should agree to order of occurence
143
+ if color_source_vector is None :
144
+ palette = ListedColormap (dict .fromkeys (color_vector ))
145
+ else :
146
+ palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
128
147
129
148
if not (
130
149
len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
@@ -159,6 +178,12 @@ def _render_points(
159
178
scalebar_params : ScalebarParams ,
160
179
legend_params : LegendParams ,
161
180
) -> None :
181
+ if render_params .groups is not None :
182
+ if isinstance (render_params .groups , str ):
183
+ render_params .groups = [render_params .groups ]
184
+ if not all (isinstance (g , str ) for g in render_params .groups ):
185
+ raise TypeError ("All groups must be strings." )
186
+
162
187
elements = render_params .elements
163
188
164
189
sdata_filt = sdata .filter_by_coordinate_system (
@@ -178,6 +203,14 @@ def _render_points(
178
203
color = [render_params .color ] if isinstance (render_params .color , str ) else render_params .color
179
204
coords .extend (color )
180
205
206
+ points = points [coords ].compute ()
207
+ # points[color[0]].cat.set_categories(render_params.groups, inplace=True)
208
+ if render_params .groups is not None :
209
+ points = points [points [color ].isin (render_params .groups ).values ]
210
+ points [color [0 ]] = points [color [0 ]].cat .set_categories (render_params .groups )
211
+ points = dask .dataframe .from_pandas (points , npartitions = 1 )
212
+ sdata_filt .points [e ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
213
+
181
214
point_df = points [coords ].compute ()
182
215
183
216
# we construct an anndata to hack the plotting functions
@@ -204,6 +237,7 @@ def _render_points(
204
237
palette = render_params .palette ,
205
238
na_color = render_params .cmap_params .na_color ,
206
239
alpha = render_params .alpha ,
240
+ cmap_params = render_params .cmap_params ,
207
241
)
208
242
209
243
# color_source_vector is None when the values aren't categorical
@@ -226,14 +260,19 @@ def _render_points(
226
260
if not (
227
261
len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
228
262
):
263
+ if color_source_vector is None :
264
+ palette = ListedColormap (dict .fromkeys (color_vector ))
265
+ else :
266
+ palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
267
+
229
268
_ = _decorate_axs (
230
269
ax = ax ,
231
270
cax = cax ,
232
271
fig_params = fig_params ,
233
272
adata = adata ,
234
273
value_to_plot = render_params .color ,
235
274
color_source_vector = color_source_vector ,
236
- palette = render_params . palette ,
275
+ palette = palette ,
237
276
alpha = render_params .alpha ,
238
277
na_color = render_params .cmap_params .na_color ,
239
278
legend_fontsize = legend_params .legend_fontsize ,
@@ -415,6 +454,12 @@ def _render_labels(
415
454
) -> None :
416
455
elements = render_params .elements
417
456
457
+ if render_params .groups is not None :
458
+ if isinstance (render_params .groups , str ):
459
+ render_params .groups = [render_params .groups ]
460
+ if not all (isinstance (g , str ) for g in render_params .groups ):
461
+ raise TypeError ("All groups must be strings." )
462
+
418
463
sdata_filt = sdata .filter_by_coordinate_system (
419
464
coordinate_system = coordinate_system ,
420
465
filter_table = sdata .table is not None ,
@@ -441,7 +486,7 @@ def _render_labels(
441
486
442
487
table = sdata .table [sdata .table .obs [region_key ].isin ([label_key ])]
443
488
444
- # get isntance id based on subsetted table
489
+ # get instance id based on subsetted table
445
490
instance_id = table .obs [instance_key ].values
446
491
447
492
# get color vector (categorical or continuous)
@@ -455,6 +500,7 @@ def _render_labels(
455
500
palette = render_params .palette ,
456
501
na_color = render_params .cmap_params .na_color ,
457
502
alpha = render_params .fill_alpha ,
503
+ cmap_params = render_params .cmap_params ,
458
504
)
459
505
460
506
if (render_params .fill_alpha != render_params .outline_alpha ) and render_params .contour_px is not None :
0 commit comments