Skip to content

Commit cad33a1

Browse files
Sonja-StockhausSonja Stockhauspre-commit-ci[bot]timtreis
authored
Fix norm behavior, add tests (#419)
Co-authored-by: Sonja Stockhaus <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Treis <[email protected]> Co-authored-by: Tim Treis <[email protected]>
1 parent cf42151 commit cad33a1

24 files changed

+247
-60
lines changed

docs/extensions/typed_returns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
1313
for line in lines:
1414
if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
15-
yield f'-{m["param"]} (:class:`~{m["type"]}`)'
15+
yield f"-{m['param']} (:class:`~{m['type']}`)"
1616
else:
1717
yield line
1818

src/spatialdata_plot/pl/basic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def render_shapes(
290290
norm=norm,
291291
na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
292292
)
293-
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
293+
sdata.plotting_tree[f"{n_steps + 1}_render_shapes"] = ShapesRenderParams(
294294
element=element,
295295
color=param_values["color"],
296296
col_for_color=param_values["col_for_color"],
@@ -433,7 +433,7 @@ def render_points(
433433
norm=norm,
434434
na_color=param_values["na_color"], # type: ignore[arg-type]
435435
)
436-
sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
436+
sdata.plotting_tree[f"{n_steps + 1}_render_points"] = PointsRenderParams(
437437
element=element,
438438
color=param_values["color"],
439439
col_for_color=param_values["col_for_color"],
@@ -538,7 +538,6 @@ def render_images(
538538
n_steps = len(sdata.plotting_tree.keys())
539539

540540
for element, param_values in params_dict.items():
541-
542541
cmap_params: list[CmapParams] | CmapParams
543542
if isinstance(cmap, list):
544543
cmap_params = [
@@ -557,7 +556,7 @@ def render_images(
557556
na_color=param_values["na_color"],
558557
**kwargs,
559558
)
560-
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
559+
sdata.plotting_tree[f"{n_steps + 1}_render_images"] = ImageRenderParams(
561560
element=element,
562561
channel=param_values["channel"],
563562
cmap_params=cmap_params,
@@ -683,7 +682,7 @@ def render_labels(
683682
norm=norm,
684683
na_color=param_values["na_color"], # type: ignore[arg-type]
685684
)
686-
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
685+
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
687686
element=element,
688687
color=param_values["color"],
689688
groups=param_values["groups"],

src/spatialdata_plot/pl/render.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
_ax_show_and_transform,
3838
_create_image_from_datashader_result,
3939
_datashader_aggregate_with_function,
40+
_datashader_map_aggregate_to_color,
4041
_datshader_get_how_kw_for_spread,
4142
_decorate_axs,
4243
_get_collection_shape,
@@ -229,18 +230,20 @@ def _render_shapes(
229230
line_width=render_params.outline_params.linewidth,
230231
)
231232

233+
ds_span = None
232234
if norm.vmin is not None or norm.vmax is not None:
233235
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
234236
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
235-
norm.clip = True # NOTE: mpl currently behaves like clip is always True
237+
ds_span = [norm.vmin, norm.vmax]
236238
if norm.vmin == norm.vmax:
237-
# data is mapped to 0
238-
agg = agg - agg
239-
else:
240-
agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
239+
# edge case, value vmin is rendered as the middle of the cmap
240+
ds_span = [0, 1]
241241
if norm.clip:
242-
agg = np.maximum(agg, 0)
243-
agg = np.minimum(agg, 1)
242+
agg = (agg - agg) + 0.5
243+
else:
244+
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
245+
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
246+
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
244247

245248
color_key = (
246249
[x[:-2] for x in color_vector.categories.values]
@@ -256,13 +259,12 @@ def _render_shapes(
256259
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
257260
ds_cmap = ds_cmap[:-2]
258261

259-
ds_result = ds.tf.shade(
262+
ds_result = _datashader_map_aggregate_to_color(
260263
agg,
261264
cmap=ds_cmap,
262265
color_key=color_key,
263266
min_alpha=np.min([254, render_params.fill_alpha * 255]),
264-
how="linear",
265-
)
267+
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
266268
elif aggregate_with_reduction is not None: # to shut up mypy
267269
ds_cmap = render_params.cmap_params.cmap
268270
# in case all elements have the same value X: we render them using cmap(0.0),
@@ -272,12 +274,13 @@ def _render_shapes(
272274
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
273275
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
274276

275-
ds_result = ds.tf.shade(
277+
ds_result = _datashader_map_aggregate_to_color(
276278
agg,
277279
cmap=ds_cmap,
278-
how="linear",
279280
min_alpha=np.min([254, render_params.fill_alpha * 255]),
280-
)
281+
span=ds_span,
282+
clip=norm.clip,
283+
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
281284

282285
# shade outlines if needed
283286
outline_color = render_params.outline_params.outline_color
@@ -294,7 +297,7 @@ def _render_shapes(
294297
cmap=outline_color,
295298
min_alpha=np.min([254, render_params.outline_alpha * 255]),
296299
how="linear",
297-
)
300+
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
298301

299302
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
300303
_cax = _ax_show_and_transform(
@@ -322,8 +325,10 @@ def _render_shapes(
322325
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
323326
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
324327
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
325-
vmin = norm.vmin
326-
vmax = norm.vmin + 1
328+
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
329+
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
330+
vmin = norm.vmin - 0.5
331+
vmax = norm.vmin + 0.5
327332
cax = ScalarMappable(
328333
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
329334
cmap=render_params.cmap_params.cmap,
@@ -586,18 +591,21 @@ def _render_points(
586591
else:
587592
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())
588593

594+
ds_span = None
589595
if norm.vmin is not None or norm.vmax is not None:
590596
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
591597
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
592-
norm.clip = True # NOTE: mpl currently behaves like clip is always True
598+
ds_span = [norm.vmin, norm.vmax]
593599
if norm.vmin == norm.vmax:
594-
# data is mapped to 0
595-
agg = agg - agg
596-
else:
597-
agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
600+
ds_span = [0, 1]
598601
if norm.clip:
599-
agg = np.maximum(agg, 0)
600-
agg = np.minimum(agg, 1)
602+
# all data is mapped to 0.5
603+
agg = (agg - agg) + 0.5
604+
else:
605+
# values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2
606+
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
607+
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
608+
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
601609

602610
color_key = (
603611
list(color_vector.categories.values)
@@ -615,13 +623,12 @@ def _render_points(
615623
color_vector = np.asarray([x[:-2] for x in color_vector])
616624

617625
if color_by_categorical or col_for_color is None:
618-
ds_result = ds.tf.shade(
626+
ds_result = _datashader_map_aggregate_to_color(
619627
ds.tf.spread(agg, px=px),
620628
cmap=color_vector[0],
621629
color_key=color_key,
622630
min_alpha=np.min([254, render_params.alpha * 255]),
623-
how="linear",
624-
)
631+
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
625632
else:
626633
spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
627634
agg = ds.tf.spread(agg, px=px, how=spread_how)
@@ -631,15 +638,17 @@ def _render_points(
631638
# in case all elements have the same value X: we render them using cmap(0.0),
632639
# using an artificial "span" of [X, X + 1] for the color bar
633640
# else: all elements would get alpha=0 and the color bar would have a weird range
634-
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
641+
if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]):
635642
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
636643
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
637644

638-
ds_result = ds.tf.shade(
645+
ds_result = _datashader_map_aggregate_to_color(
639646
agg,
640647
cmap=ds_cmap,
641-
how="linear",
642-
)
648+
span=ds_span,
649+
clip=norm.clip,
650+
min_alpha=np.min([254, render_params.alpha * 255]),
651+
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
643652

644653
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
645654
_ax_show_and_transform(
@@ -656,8 +665,10 @@ def _render_points(
656665
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
657666
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
658667
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
659-
vmin = norm.vmin
660-
vmax = norm.vmin + 1
668+
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
669+
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
670+
vmin = norm.vmin - 0.5
671+
vmax = norm.vmin + 0.5
661672
cax = ScalarMappable(
662673
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
663674
cmap=render_params.cmap_params.cmap,
@@ -723,7 +734,6 @@ def _render_images(
723734
legend_params: LegendParams,
724735
rasterize: bool,
725736
) -> None:
726-
727737
sdata_filt = sdata.filter_by_coordinate_system(
728738
coordinate_system=coordinate_system,
729739
filter_tables=False,
@@ -781,9 +791,6 @@ def _render_images(
781791
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
782792
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()
783793

784-
if render_params.cmap_params.norm: # type: ignore[attr-defined]
785-
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
786-
787794
cmap = (
788795
_get_linear_colormap(palette, "k")[0]
789796
if isinstance(palette, list) and all(isinstance(p, str) for p in palette)
@@ -794,7 +801,10 @@ def _render_images(
794801
cmap._init()
795802
cmap._lut[:, -1] = render_params.alpha
796803

797-
_ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)
804+
# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
805+
_ax_show_and_transform(
806+
layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
807+
)
798808

799809
if legend_params.colorbar:
800810
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)

0 commit comments

Comments
 (0)