Skip to content

Commit 5cdc924

Browse files
committed
minor reformats
1 parent cf581b3 commit 5cdc924

File tree

3 files changed

+46
-50
lines changed

3 files changed

+46
-50
lines changed

CHANGELOG.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@ and this project adheres to [Semantic Versioning][].
1010

1111
## [0.2.3] - tbd
1212

13+
### Added
14+
15+
- Datashader support for points and shapes (#244)
16+
1317
### Changed
1418

15-
- All parameters are now provided for a single element. If element in pl.render is None then this value will be broadcasted
19+
- All parameters are now provided for a single element (#272)
1620

1721
### Fixed
1822

1923
- Fix color assignment for NaN values (#257)
20-
- Fix channel str support #221
2124

2225
## [0.2.2] - 2024-05-02
2326

2427
### Fixed
2528

26-
- Fixed `fill_alpha` ignoring `alpha` channel from custom cmap
29+
- Fixed `fill_alpha` ignoring `alpha` channel from custom cmap (#236)
2730
- Fix channel str support (#221)
2831

2932
## [0.2.1] - 2024-03-26

src/spatialdata_plot/pl/basic.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
from matplotlib.colors import Colormap, Normalize
2121
from matplotlib.figure import Figure
2222
from spatialdata import get_extent
23-
24-
# from spatialdata._core.data_extent import get_extent
2523
from spatialdata._utils import _deprecation_alias
2624
from xarray import DataArray
2725

@@ -261,11 +259,11 @@ def render_shapes(
261259
table_name=table_name,
262260
)
263261

264-
if method is not None:
265-
if not isinstance(method, str):
266-
raise TypeError("Parameter 'method' must be a string.")
267-
if method not in ["matplotlib", "datashader"]:
268-
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
262+
if not isinstance(method, str):
263+
raise TypeError("Parameter 'method' must be a string.")
264+
265+
if method not in ["matplotlib", "datashader"]:
266+
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
269267

270268
sdata = self._copy()
271269
sdata = _verify_plotting_tree(sdata)

src/spatialdata_plot/pl/render.py

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import spatialdata as sd
1717
from anndata import AnnData
1818
from datatree import DataTree
19-
20-
# from datatree.datatree import DataTree
2119
from matplotlib.cm import ScalarMappable
2220
from matplotlib.colors import ListedColormap, Normalize
2321
from scanpy._settings import settings as sc_settings
@@ -163,27 +161,7 @@ def _render_shapes(
163161
raise ValueError("Method must be either 'matplotlib' or 'datashader'.")
164162
logger.info(f"Using {method}")
165163

166-
if method == "matplotlib":
167-
_cax = _get_collection_shape(
168-
shapes=shapes,
169-
s=render_params.scale,
170-
c=color_vector,
171-
render_params=render_params,
172-
rasterized=sc_settings._vector_friendly,
173-
cmap=render_params.cmap_params.cmap,
174-
norm=norm,
175-
fill_alpha=render_params.fill_alpha,
176-
outline_alpha=render_params.outline_alpha,
177-
zorder=render_params.zorder,
178-
# **kwargs,
179-
)
180-
cax = ax.add_collection(_cax)
181-
182-
# Transform the paths in PatchCollection
183-
for path in _cax.get_paths():
184-
path.vertices = trans.transform(path.vertices)
185-
cax = ax.add_collection(_cax)
186-
elif method == "datashader":
164+
if method == "datashader":
187165
# TODO: Where to put this
188166
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
189167

@@ -209,11 +187,9 @@ def _render_shapes(
209187
# in case we are coloring by a column in table
210188
if col_for_color is not None and col_for_color not in sdata_filt.shapes[element].columns:
211189
# numerical
212-
if color_source_vector is None:
213-
sdata_filt.shapes[element][col_for_color] = color_vector
214-
else: # categorical
215-
sdata_filt.shapes[element][col_for_color] = color_source_vector
216-
190+
sdata_filt.shapes[element][col_for_color] = (
191+
color_vector if color_source_vector is None else color_source_vector
192+
)
217193
# Render shapes with datashader
218194
color_by_categorical = col_for_color is not None and color_source_vector is not None
219195
aggregate_with_sum = None
@@ -232,24 +208,24 @@ def _render_shapes(
232208

233209
color_key = (
234210
[x[:-2] for x in color_vector.categories.values]
235-
if (type(color_vector) == pd.core.arrays.categorical.Categorical)
211+
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
236212
and (len(color_vector.categories.values) > 1)
237213
else None
238214
)
239215

240-
if color_by_categorical or col_for_color is None:
241-
ds_result = ds.tf.shade(
216+
ds_result = (
217+
ds.tf.shade(
242218
agg,
243219
cmap=color_vector[0][:-2],
244220
color_key=color_key,
245221
min_alpha=np.min([150, render_params.fill_alpha * 255]),
246-
) # TODO: choose other value than 150 for min_alpha (here and below)?
247-
else:
248-
ds_result = ds.tf.shade(
222+
)
223+
if color_by_categorical or col_for_color is None
224+
else ds.tf.shade(
249225
agg,
250226
cmap=render_params.cmap_params.cmap,
251227
)
252-
228+
)
253229
# Render image
254230
rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
255231
_cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder)
@@ -261,6 +237,27 @@ def _render_shapes(
261237
cmap=render_params.cmap_params.cmap,
262238
)
263239

240+
elif method == "matplotlib":
241+
_cax = _get_collection_shape(
242+
shapes=shapes,
243+
s=render_params.scale,
244+
c=color_vector,
245+
render_params=render_params,
246+
rasterized=sc_settings._vector_friendly,
247+
cmap=render_params.cmap_params.cmap,
248+
norm=norm,
249+
fill_alpha=render_params.fill_alpha,
250+
outline_alpha=render_params.outline_alpha,
251+
zorder=render_params.zorder,
252+
# **kwargs,
253+
)
254+
cax = ax.add_collection(_cax)
255+
256+
# Transform the paths in PatchCollection
257+
for path in _cax.get_paths():
258+
path.vertices = trans.transform(path.vertices)
259+
cax = ax.add_collection(_cax)
260+
264261
# Sets the limits of the colorbar to the values instead of [0, 1]
265262
if not norm and not values_are_categorical:
266263
_cax.set_clim(min(color_vector), max(color_vector))
@@ -356,7 +353,7 @@ def _render_points(
356353
)
357354
sdata_filt[table_name] = adata
358355

359-
# we can do this because of dealing with a copy
356+
# we can modify the sdata because of dealing with a copy
360357

361358
# Convert back to dask dataframe to modify sdata
362359
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
@@ -456,7 +453,7 @@ def _render_points(
456453

457454
color_key = (
458455
[x[:-2] for x in color_vector.categories.values]
459-
if (type(color_vector) == pd.core.arrays.categorical.Categorical)
456+
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
460457
and (len(color_vector.categories.values) > 1)
461458
else None
462459
)
@@ -473,9 +470,8 @@ def _render_points(
473470
ds.tf.spread(agg, px=px),
474471
rescale_discrete_levels=True,
475472
cmap=render_params.cmap_params.cmap,
476-
# color_key=color_key,
477473
)
478-
# render image
474+
479475
rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
480476
cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
481477
if aggregate_with_sum is not None:
@@ -498,7 +494,6 @@ def _render_points(
498494
alpha=render_params.alpha,
499495
transform=trans,
500496
zorder=render_params.zorder,
501-
# **kwargs,
502497
)
503498
cax = ax.add_collection(_cax)
504499
if update_parameters:

0 commit comments

Comments
 (0)