37
37
_ax_show_and_transform ,
38
38
_create_image_from_datashader_result ,
39
39
_datashader_aggregate_with_function ,
40
+ _datashader_map_aggregate_to_color ,
40
41
_datshader_get_how_kw_for_spread ,
41
42
_decorate_axs ,
42
43
_get_collection_shape ,
@@ -229,18 +230,20 @@ def _render_shapes(
229
230
line_width = render_params .outline_params .linewidth ,
230
231
)
231
232
233
+ ds_span = None
232
234
if norm .vmin is not None or norm .vmax is not None :
233
235
norm .vmin = np .min (agg ) if norm .vmin is None else norm .vmin
234
236
norm .vmax = np .max (agg ) if norm .vmax is None else norm .vmax
235
- norm . clip = True # NOTE: mpl currently behaves like clip is always True
237
+ ds_span = [ norm . vmin , norm . vmax ]
236
238
if norm .vmin == norm .vmax :
237
- # data is mapped to 0
238
- agg = agg - agg
239
- else :
240
- agg = (agg - norm .vmin ) / (norm .vmax - norm .vmin )
239
+ # edge case, value vmin is rendered as the middle of the cmap
240
+ ds_span = [0 , 1 ]
241
241
if norm .clip :
242
- agg = np .maximum (agg , 0 )
243
- agg = np .minimum (agg , 1 )
242
+ agg = (agg - agg ) + 0.5
243
+ else :
244
+ agg = agg .where ((agg >= norm .vmin ) | (np .isnan (agg )), other = - 1 )
245
+ agg = agg .where ((agg <= norm .vmin ) | (np .isnan (agg )), other = 2 )
246
+ agg = agg .where ((agg != norm .vmin ) | (np .isnan (agg )), other = 0.5 )
244
247
245
248
color_key = (
246
249
[x [:- 2 ] for x in color_vector .categories .values ]
@@ -256,13 +259,12 @@ def _render_shapes(
256
259
if isinstance (ds_cmap , str ) and ds_cmap [0 ] == "#" :
257
260
ds_cmap = ds_cmap [:- 2 ]
258
261
259
- ds_result = ds . tf . shade (
262
+ ds_result = _datashader_map_aggregate_to_color (
260
263
agg ,
261
264
cmap = ds_cmap ,
262
265
color_key = color_key ,
263
266
min_alpha = np .min ([254 , render_params .fill_alpha * 255 ]),
264
- how = "linear" ,
265
- )
267
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
266
268
elif aggregate_with_reduction is not None : # to shut up mypy
267
269
ds_cmap = render_params .cmap_params .cmap
268
270
# in case all elements have the same value X: we render them using cmap(0.0),
@@ -272,12 +274,13 @@ def _render_shapes(
272
274
ds_cmap = matplotlib .colors .to_hex (render_params .cmap_params .cmap (0.0 ), keep_alpha = False )
273
275
aggregate_with_reduction = (aggregate_with_reduction [0 ], aggregate_with_reduction [0 ] + 1 )
274
276
275
- ds_result = ds . tf . shade (
277
+ ds_result = _datashader_map_aggregate_to_color (
276
278
agg ,
277
279
cmap = ds_cmap ,
278
- how = "linear" ,
279
280
min_alpha = np .min ([254 , render_params .fill_alpha * 255 ]),
280
- )
281
+ span = ds_span ,
282
+ clip = norm .clip ,
283
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
281
284
282
285
# shade outlines if needed
283
286
outline_color = render_params .outline_params .outline_color
@@ -294,7 +297,7 @@ def _render_shapes(
294
297
cmap = outline_color ,
295
298
min_alpha = np .min ([254 , render_params .outline_alpha * 255 ]),
296
299
how = "linear" ,
297
- )
300
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
298
301
299
302
rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
300
303
_cax = _ax_show_and_transform (
@@ -322,8 +325,10 @@ def _render_shapes(
322
325
vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
323
326
vmax = aggregate_with_reduction [1 ].values if norm .vmin is None else norm .vmax
324
327
if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
325
- vmin = norm .vmin
326
- vmax = norm .vmin + 1
328
+ # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
329
+ # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
330
+ vmin = norm .vmin - 0.5
331
+ vmax = norm .vmin + 0.5
327
332
cax = ScalarMappable (
328
333
norm = matplotlib .colors .Normalize (vmin = vmin , vmax = vmax ),
329
334
cmap = render_params .cmap_params .cmap ,
@@ -586,18 +591,21 @@ def _render_points(
586
591
else :
587
592
agg = cvs .points (transformed_element , "x" , "y" , agg = ds .count ())
588
593
594
+ ds_span = None
589
595
if norm .vmin is not None or norm .vmax is not None :
590
596
norm .vmin = np .min (agg ) if norm .vmin is None else norm .vmin
591
597
norm .vmax = np .max (agg ) if norm .vmax is None else norm .vmax
592
- norm . clip = True # NOTE: mpl currently behaves like clip is always True
598
+ ds_span = [ norm . vmin , norm . vmax ]
593
599
if norm .vmin == norm .vmax :
594
- # data is mapped to 0
595
- agg = agg - agg
596
- else :
597
- agg = (agg - norm .vmin ) / (norm .vmax - norm .vmin )
600
+ ds_span = [0 , 1 ]
598
601
if norm .clip :
599
- agg = np .maximum (agg , 0 )
600
- agg = np .minimum (agg , 1 )
602
+ # all data is mapped to 0.5
603
+ agg = (agg - agg ) + 0.5
604
+ else :
605
+ # values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2
606
+ agg = agg .where ((agg >= norm .vmin ) | (np .isnan (agg )), other = - 1 )
607
+ agg = agg .where ((agg <= norm .vmin ) | (np .isnan (agg )), other = 2 )
608
+ agg = agg .where ((agg != norm .vmin ) | (np .isnan (agg )), other = 0.5 )
601
609
602
610
color_key = (
603
611
list (color_vector .categories .values )
@@ -615,13 +623,12 @@ def _render_points(
615
623
color_vector = np .asarray ([x [:- 2 ] for x in color_vector ])
616
624
617
625
if color_by_categorical or col_for_color is None :
618
- ds_result = ds . tf . shade (
626
+ ds_result = _datashader_map_aggregate_to_color (
619
627
ds .tf .spread (agg , px = px ),
620
628
cmap = color_vector [0 ],
621
629
color_key = color_key ,
622
630
min_alpha = np .min ([254 , render_params .alpha * 255 ]),
623
- how = "linear" ,
624
- )
631
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
625
632
else :
626
633
spread_how = _datshader_get_how_kw_for_spread (render_params .ds_reduction )
627
634
agg = ds .tf .spread (agg , px = px , how = spread_how )
@@ -631,15 +638,17 @@ def _render_points(
631
638
# in case all elements have the same value X: we render them using cmap(0.0),
632
639
# using an artificial "span" of [X, X + 1] for the color bar
633
640
# else: all elements would get alpha=0 and the color bar would have a weird range
634
- if aggregate_with_reduction [0 ] == aggregate_with_reduction [1 ]:
641
+ if aggregate_with_reduction [0 ] == aggregate_with_reduction [1 ] and ( ds_span is None or ds_span != [ 0 , 1 ]) :
635
642
ds_cmap = matplotlib .colors .to_hex (render_params .cmap_params .cmap (0.0 ), keep_alpha = False )
636
643
aggregate_with_reduction = (aggregate_with_reduction [0 ], aggregate_with_reduction [0 ] + 1 )
637
644
638
- ds_result = ds . tf . shade (
645
+ ds_result = _datashader_map_aggregate_to_color (
639
646
agg ,
640
647
cmap = ds_cmap ,
641
- how = "linear" ,
642
- )
648
+ span = ds_span ,
649
+ clip = norm .clip ,
650
+ min_alpha = np .min ([254 , render_params .alpha * 255 ]),
651
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
643
652
644
653
rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
645
654
_ax_show_and_transform (
@@ -656,8 +665,10 @@ def _render_points(
656
665
vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
657
666
vmax = aggregate_with_reduction [1 ].values if norm .vmax is None else norm .vmax
658
667
if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
659
- vmin = norm .vmin
660
- vmax = norm .vmin + 1
668
+ # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
669
+ # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
670
+ vmin = norm .vmin - 0.5
671
+ vmax = norm .vmin + 0.5
661
672
cax = ScalarMappable (
662
673
norm = matplotlib .colors .Normalize (vmin = vmin , vmax = vmax ),
663
674
cmap = render_params .cmap_params .cmap ,
@@ -723,7 +734,6 @@ def _render_images(
723
734
legend_params : LegendParams ,
724
735
rasterize : bool ,
725
736
) -> None :
726
-
727
737
sdata_filt = sdata .filter_by_coordinate_system (
728
738
coordinate_system = coordinate_system ,
729
739
filter_tables = False ,
@@ -781,9 +791,6 @@ def _render_images(
781
791
if n_channels == 1 and not isinstance (render_params .cmap_params , list ):
782
792
layer = img .sel (c = channels [0 ]).squeeze () if isinstance (channels [0 ], str ) else img .isel (c = channels [0 ]).squeeze ()
783
793
784
- if render_params .cmap_params .norm : # type: ignore[attr-defined]
785
- layer = render_params .cmap_params .norm (layer ) # type: ignore[attr-defined]
786
-
787
794
cmap = (
788
795
_get_linear_colormap (palette , "k" )[0 ]
789
796
if isinstance (palette , list ) and all (isinstance (p , str ) for p in palette )
@@ -794,7 +801,10 @@ def _render_images(
794
801
cmap ._init ()
795
802
cmap ._lut [:, - 1 ] = render_params .alpha
796
803
797
- _ax_show_and_transform (layer , trans_data , ax , cmap = cmap , zorder = render_params .zorder )
804
+ # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
805
+ _ax_show_and_transform (
806
+ layer , trans_data , ax , cmap = cmap , zorder = render_params .zorder , norm = render_params .cmap_params .norm
807
+ )
798
808
799
809
if legend_params .colorbar :
800
810
sm = plt .cm .ScalarMappable (cmap = cmap , norm = render_params .cmap_params .norm )
0 commit comments