4
4
from copy import copy
5
5
from dataclasses import dataclass
6
6
from functools import partial
7
- from typing import Any , Callable , Optional , Union
7
+ from typing import Any , Callable , Union
8
8
9
9
import matplotlib
10
10
import numpy as np
21
21
from pandas .api .types import is_categorical_dtype
22
22
from scanpy ._settings import settings as sc_settings
23
23
24
+ from spatialdata_plot ._logging import logger
24
25
from spatialdata_plot .pl .utils import (
25
26
CmapParams ,
26
27
FigParams ,
37
38
)
38
39
from spatialdata_plot .pp .utils import _get_instance_key , _get_region_key
39
40
40
- Palette_t = Optional [Union [str , ListedColormap ]]
41
41
_Normalize = Union [Normalize , Sequence [Normalize ]]
42
42
to_hex = partial (colors .to_hex , keep_alpha = True )
43
43
@@ -54,7 +54,7 @@ class ShapesRenderParams:
54
54
contour_px : int | None = None
55
55
alt_var : str | None = None
56
56
layer : str | None = None
57
- palette : Palette_t = None
57
+ palette : ListedColormap | str | None = None
58
58
outline_alpha : float = 1.0
59
59
fill_alpha : float = 0.3
60
60
size : float = 1.0
@@ -208,7 +208,7 @@ class PointsRenderParams:
208
208
elements : str | Sequence [str ] | None = None
209
209
color : str | None = None
210
210
groups : str | Sequence [str ] | None = None
211
- palette : Palette_t = None
211
+ palette : ListedColormap | str | None = None
212
212
alpha : float = 1.0
213
213
size : float = 1.0
214
214
transfunc : Callable [[float ], float ] | None = None
@@ -312,11 +312,12 @@ def _render_points(
312
312
class ImageRenderParams :
313
313
"""Labels render parameters.."""
314
314
315
- cmap_params : CmapParams
315
+ cmap_params : list [ CmapParams ] | CmapParams
316
316
elements : str | Sequence [str ] | None = None
317
317
channel : list [str ] | list [int ] | int | str | None = None
318
- palette : Palette_t = None
318
+ palette : ListedColormap | str | None = None
319
319
alpha : float = 1.0
320
+ quantiles_for_norm : tuple [float | None , float | None ] = (3.0 , 99.8 ) # defaults from CSBDeep
320
321
321
322
322
323
def _render_images (
@@ -347,47 +348,126 @@ def _render_images(
347
348
for img in images :
348
349
if (len (img .c ) > 3 or len (img .c ) == 2 ) and render_params .channel is None :
349
350
raise NotImplementedError ("Only 1 or 3 channels are supported at the moment." )
350
- if render_params .channel is None and len (img .c ) == 1 :
351
- render_params .channel = 0
352
- if render_params .channel is not None :
351
+
352
+ if render_params .channel is None :
353
+ channels = img .coords ["c" ].values
354
+ else :
353
355
channels = (
354
356
[render_params .channel ] if isinstance (render_params .channel , (str , int )) else render_params .channel
355
357
)
356
- img = img .sel (c = channels )
357
- num_channels = img .sizes ["c" ]
358
+
359
+ n_channels = len (channels )
360
+
361
+ got_multiple_cmaps = isinstance (render_params .cmap_params , list )
362
+
363
+ if not isinstance (render_params .cmap_params , list ):
364
+ render_params .cmap_params = [render_params .cmap_params ] * n_channels
365
+
366
+ if got_multiple_cmaps :
367
+ logger .warning (
368
+ "You're blending multiple cmaps. "
369
+ "If the plot doesn't look like you expect, it might be because your "
370
+ "cmaps go from a given color to 'white', and not to 'transparent'. "
371
+ "Therefore, the 'white' of higher layers will overlay the lower layers. "
372
+ "Consider using 'palette' instead."
373
+ )
358
374
359
375
if render_params .palette is not None :
360
- if num_channels > len (render_params .palette ):
361
- raise ValueError ("If palette is provided, it must match the number of channels." )
376
+ logger .warning ("Parameter 'palette' is ignored when a 'cmap' is provided." )
377
+
378
+ for idx , channel in enumerate (channels ):
379
+ layer = img .sel (c = channel )
380
+
381
+ if render_params .quantiles_for_norm != (None , None ):
382
+ layer = _normalize (
383
+ layer ,
384
+ pmin = render_params .quantiles_for_norm [0 ],
385
+ pmax = render_params .quantiles_for_norm [1 ],
386
+ clip = True ,
387
+ )
388
+
389
+ if render_params .cmap_params [idx ].norm is not None :
390
+ layer = render_params .cmap_params [idx ].norm (layer )
391
+
392
+ ax .imshow (
393
+ layer ,
394
+ cmap = render_params .cmap_params [idx ].cmap ,
395
+ alpha = (1 / n_channels ),
396
+ )
397
+ break
398
+
399
+ if n_channels == 1 :
400
+ layer = img .sel (c = channels )
401
+
402
+ if render_params .quantiles_for_norm != (None , None ):
403
+ layer = _normalize (
404
+ layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
405
+ )
406
+
407
+ if render_params .cmap_params [0 ].norm is not None :
408
+ layer = render_params .cmap_params [0 ].norm (layer )
362
409
363
- color = render_params .palette
410
+ if render_params .palette is None :
411
+ ax .imshow (
412
+ layer .squeeze (), # get rid of the channel dimension
413
+ cmap = render_params .cmap_params [0 ].cmap ,
414
+ )
364
415
365
416
else :
366
- color = _get_colors_for_categorical_obs (
367
- img .coords ["c" ].values .tolist (), palette = render_params .cmap_params .cmap
417
+ ax .imshow (
418
+ layer .squeeze (), # get rid of the channel dimension
419
+ cmap = _get_linear_colormap ([render_params .palette ], "k" )[0 ],
368
420
)
369
421
370
- cmaps = _get_linear_colormap ([str (c ) for c in color [:num_channels ]], "k" )
371
- img = _normalize (img , clip = True )
372
- colored = np .stack ([cmaps [i ](img .values [i ]) for i in range (num_channels )], 0 ).sum (0 )
373
- img = xr .DataArray (
422
+ break
423
+
424
+ if render_params .palette is not None and n_channels != len (render_params .palette ):
425
+ raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
426
+
427
+ if n_channels > 1 :
428
+ layer = img .sel (c = channels ).copy (deep = True )
429
+
430
+ channel_colors : list [str ] | Any
431
+ if render_params .palette is None :
432
+ channel_colors = _get_colors_for_categorical_obs (
433
+ layer .coords ["c" ].values .tolist (), palette = render_params .cmap_params [0 ].cmap
434
+ )
435
+ else :
436
+ channel_colors = render_params .palette
437
+
438
+ channel_cmaps = _get_linear_colormap ([str (c ) for c in channel_colors [:n_channels ]], "k" )
439
+
440
+ layer_vals = []
441
+ if render_params .quantiles_for_norm != (None , None ):
442
+ for i in range (n_channels ):
443
+ layer_vals .append (
444
+ _normalize (
445
+ layer .values [i ],
446
+ pmin = render_params .quantiles_for_norm [0 ],
447
+ pmax = render_params .quantiles_for_norm [1 ],
448
+ clip = True ,
449
+ )
450
+ )
451
+
452
+ colored = np .stack ([channel_cmaps [i ](layer_vals [i ]) for i in range (n_channels )], 0 ).sum (0 )
453
+
454
+ layer = xr .DataArray (
374
455
data = colored ,
375
456
coords = [
376
- img .coords ["y" ],
377
- img .coords ["x" ],
457
+ layer .coords ["y" ],
458
+ layer .coords ["x" ],
378
459
["R" , "G" , "B" , "A" ],
379
460
],
380
461
dims = ["y" , "x" , "c" ],
381
462
)
463
+ layer = layer .transpose ("y" , "x" , "c" ) # for plotting
382
464
383
- img = img .transpose ("y" , "x" , "c" ) # for plotting
384
-
385
- ax .imshow (
386
- img .data ,
387
- cmap = render_params .cmap_params .cmap ,
388
- alpha = render_params .alpha ,
389
- # extent=extent,
390
- )
465
+ ax .imshow (
466
+ layer .data ,
467
+ cmap = channel_cmaps [0 ],
468
+ alpha = render_params .alpha ,
469
+ norm = render_params .cmap_params [0 ].norm ,
470
+ )
391
471
392
472
393
473
@dataclass
@@ -402,7 +482,7 @@ class LabelsRenderParams:
402
482
outline : bool = False
403
483
alt_var : str | None = None
404
484
layer : str | None = None
405
- palette : Palette_t = None
485
+ palette : ListedColormap | str | None = None
406
486
outline_alpha : float = 1.0
407
487
fill_alpha : float = 0.4
408
488
transfunc : Callable [[float ], float ] | None = None
0 commit comments