Skip to content

Commit 7716d43

Browse files
render_shapes now respects the cmap parameter (#436)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b01cf09 commit 7716d43

File tree

5 files changed

+317
-59
lines changed

5 files changed

+317
-59
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 87 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_get_extent_and_range_for_datashader_canvas,
4646
_get_linear_colormap,
4747
_get_transformation_matrix_for_datashader,
48+
_hex_no_alpha,
4849
_is_coercable_to_float,
4950
_map_color_seg,
5051
_maybe_set_colors,
@@ -191,7 +192,10 @@ def _render_shapes(
191192
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
192193
)
193194
transformed_element = ShapesModel.parse(
194-
gpd.GeoDataFrame(data=sdata_filt.shapes[element].drop("geometry", axis=1), geometry=transformed_element)
195+
gpd.GeoDataFrame(
196+
data=sdata_filt.shapes[element].drop("geometry", axis=1),
197+
geometry=transformed_element,
198+
)
195199
)
196200

197201
plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
@@ -208,15 +212,23 @@ def _render_shapes(
208212
aggregate_with_reduction = None
209213
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
210214
if color_by_categorical:
211-
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count()))
215+
agg = cvs.polygons(
216+
transformed_element,
217+
geometry="geometry",
218+
agg=ds.by(col_for_color, ds.count()),
219+
)
212220
else:
213221
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean"
214222
logger.info(
215223
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
216224
"to the matplotlib result."
217225
)
218226
agg = _datashader_aggregate_with_function(
219-
render_params.ds_reduction, cvs, transformed_element, col_for_color, "shapes"
227+
render_params.ds_reduction,
228+
cvs,
229+
transformed_element,
230+
col_for_color,
231+
"shapes",
220232
)
221233
# save min and max values for drawing the colorbar
222234
aggregate_with_reduction = (agg.min(), agg.max())
@@ -246,7 +258,7 @@ def _render_shapes(
246258
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
247259

248260
color_key = (
249-
[x[:-2] for x in color_vector.categories.values]
261+
[_hex_no_alpha(x) for x in color_vector.categories.values]
250262
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
251263
and (len(color_vector.categories.values) > 1)
252264
else None
@@ -257,7 +269,7 @@ def _render_shapes(
257269
if color_vector is not None:
258270
ds_cmap = color_vector[0]
259271
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
260-
ds_cmap = ds_cmap[:-2]
272+
ds_cmap = _hex_no_alpha(ds_cmap)
261273

262274
ds_result = _datashader_map_aggregate_to_color(
263275
agg,
@@ -272,7 +284,10 @@ def _render_shapes(
272284
# else: all elements would get alpha=0 and the color bar would have a weird range
273285
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
274286
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
275-
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
287+
aggregate_with_reduction = (
288+
aggregate_with_reduction[0],
289+
aggregate_with_reduction[0] + 1,
290+
)
276291

277292
ds_result = _datashader_map_aggregate_to_color(
278293
agg,
@@ -468,7 +483,9 @@ def _render_points(
468483
# we construct an anndata to hack the plotting functions
469484
if table_name is None:
470485
adata = AnnData(
471-
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
486+
X=points[["x", "y"]].values,
487+
obs=points[coords].reset_index(),
488+
dtype=points[["x", "y"]].values.dtype,
472489
)
473490
else:
474491
adata_obs = sdata_filt[table_name].obs
@@ -496,7 +513,9 @@ def _render_points(
496513
sdata_filt.points[element] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})
497514
# restore transformation in coordinate system of interest
498515
set_transformation(
499-
element=sdata_filt.points[element], transformation=transformation_in_cs, to_coordinate_system=coordinate_system
516+
element=sdata_filt.points[element],
517+
transformation=transformation_in_cs,
518+
to_coordinate_system=coordinate_system,
500519
)
501520

502521
if col_for_color is not None:
@@ -586,7 +605,11 @@ def _render_points(
586605
"to the matplotlib result."
587606
)
588607
agg = _datashader_aggregate_with_function(
589-
render_params.ds_reduction, cvs, transformed_element, col_for_color, "points"
608+
render_params.ds_reduction,
609+
cvs,
610+
transformed_element,
611+
col_for_color,
612+
"points",
590613
)
591614
# save min and max values for drawing the colorbar
592615
aggregate_with_reduction = (agg.min(), agg.max())
@@ -642,7 +665,10 @@ def _render_points(
642665
# else: all elements would get alpha=0 and the color bar would have a weird range
643666
if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]):
644667
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
645-
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
668+
aggregate_with_reduction = (
669+
aggregate_with_reduction[0],
670+
aggregate_with_reduction[0] + 1,
671+
)
646672

647673
ds_result = _datashader_map_aggregate_to_color(
648674
agg,
@@ -805,7 +831,12 @@ def _render_images(
805831

806832
# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
807833
_ax_show_and_transform(
808-
layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
834+
layer,
835+
trans_data,
836+
ax,
837+
cmap=cmap,
838+
zorder=render_params.zorder,
839+
norm=render_params.cmap_params.norm,
809840
)
810841

811842
if legend_params.colorbar:
@@ -832,7 +863,11 @@ def _render_images(
832863
else: # -> use given cmap for each channel
833864
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
834865
stacked = (
835-
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
866+
np.stack(
867+
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
868+
0,
869+
).sum(0)
870+
/ n_channels
836871
)
837872
stacked = stacked[:, :, :3]
838873
logger.warning(
@@ -844,7 +879,13 @@ def _render_images(
844879
"Consider using 'palette' instead."
845880
)
846881

847-
_ax_show_and_transform(stacked, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
882+
_ax_show_and_transform(
883+
stacked,
884+
trans_data,
885+
ax,
886+
render_params.alpha,
887+
zorder=render_params.zorder,
888+
)
848889

849890
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
850891
elif palette is None and not got_multiple_cmaps:
@@ -858,7 +899,13 @@ def _render_images(
858899
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
859900
colored = colored[:, :, :3]
860901

861-
_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
902+
_ax_show_and_transform(
903+
colored,
904+
trans_data,
905+
ax,
906+
render_params.alpha,
907+
zorder=render_params.zorder,
908+
)
862909

863910
# 2C) Image has n channels and palette info
864911
elif palette is not None and not got_multiple_cmaps:
@@ -869,16 +916,32 @@ def _render_images(
869916
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
870917
colored = colored[:, :, :3]
871918

872-
_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
919+
_ax_show_and_transform(
920+
colored,
921+
trans_data,
922+
ax,
923+
render_params.alpha,
924+
zorder=render_params.zorder,
925+
)
873926

874927
elif palette is None and got_multiple_cmaps:
875928
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
876929
colored = (
877-
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
930+
np.stack(
931+
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
932+
0,
933+
).sum(0)
934+
/ n_channels
878935
)
879936
colored = colored[:, :, :3]
880937

881-
_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
938+
_ax_show_and_transform(
939+
colored,
940+
trans_data,
941+
ax,
942+
render_params.alpha,
943+
zorder=render_params.zorder,
944+
)
882945

883946
elif palette is not None and got_multiple_cmaps:
884947
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
@@ -999,7 +1062,9 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
9991062
# outline-only case
10001063
elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0:
10011064
cax = _draw_labels(
1002-
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
1065+
seg_erosionpx=render_params.contour_px,
1066+
seg_boundaries=True,
1067+
alpha=render_params.outline_alpha,
10031068
)
10041069
alpha_to_decorate_ax = render_params.outline_alpha
10051070

@@ -1010,7 +1075,9 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
10101075

10111076
# ... then overlay the contour
10121077
cax_contour = _draw_labels(
1013-
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
1078+
seg_erosionpx=render_params.contour_px,
1079+
seg_boundaries=True,
1080+
alpha=render_params.outline_alpha,
10141081
)
10151082

10161083
# pass the less-transparent _cax for the legend
@@ -1035,7 +1102,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
10351102
legend_fontweight=legend_params.legend_fontweight,
10361103
legend_loc=legend_params.legend_loc,
10371104
legend_fontoutline=legend_params.legend_fontoutline,
1038-
na_in_legend=legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector)),
1105+
na_in_legend=(legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector))),
10391106
colorbar=legend_params.colorbar,
10401107
scalebar_dx=scalebar_params.scalebar_dx,
10411108
scalebar_units=scalebar_params.scalebar_units,

0 commit comments

Comments
 (0)