11
11
import pandas as pd
12
12
import scanpy as sc
13
13
import spatialdata as sd
14
- import xarray as xr
15
14
from anndata import AnnData
16
15
from geopandas import GeoDataFrame
17
16
from matplotlib import colors
29
28
OutlineParams ,
30
29
ScalebarParams ,
31
30
_decorate_axs ,
31
+ _get_colors_for_categorical_obs ,
32
32
_get_linear_colormap ,
33
33
_map_color_seg ,
34
34
_maybe_set_colors ,
@@ -327,7 +327,6 @@ def _render_images(
327
327
fig_params : FigParams ,
328
328
scalebar_params : ScalebarParams ,
329
329
legend_params : LegendParams ,
330
- # extent: tuple[float, float, float, float] | None = None,
331
330
) -> None :
332
331
elements = render_params .elements
333
332
@@ -345,8 +344,8 @@ def _render_images(
345
344
images = [sdata .images [e ] for e in elements ]
346
345
347
346
for img in images :
348
- if (len (img .c ) > 3 or len (img .c ) == 2 ) and render_params .channel is None :
349
- raise NotImplementedError ("Only 1 or 3 channels are supported at the moment." )
347
+ # if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
348
+ # raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
350
349
351
350
if render_params .channel is None :
352
351
channels = img .coords ["c" ].values
@@ -357,11 +356,8 @@ def _render_images(
357
356
358
357
n_channels = len (channels )
359
358
359
+ # True if user gave n cmaps for n channels
360
360
got_multiple_cmaps = isinstance (render_params .cmap_params , list )
361
-
362
- if not isinstance (render_params .cmap_params , list ):
363
- render_params .cmap_params = [render_params .cmap_params ] * n_channels
364
-
365
361
if got_multiple_cmaps :
366
362
logger .warning (
367
363
"You're blending multiple cmaps. "
@@ -371,100 +367,113 @@ def _render_images(
371
367
"Consider using 'palette' instead."
372
368
)
373
369
374
- if render_params .palette is not None :
375
- logger .warning ("Parameter 'palette' is ignored when a 'cmap' is provided." )
370
+ # not using got_multiple_cmaps here because of ruff :(
371
+ if isinstance (render_params .cmap_params , list ) and len (render_params .cmap_params ) != n_channels :
372
+ raise ValueError ("If 'cmap' is provided, its length must match the number of channels." )
373
+
374
+ # 1) Image has only 1 channel
375
+ if n_channels == 1 and not isinstance (render_params .cmap_params , list ):
376
+ layer = img .sel (c = channels ).squeeze ()
377
+
378
+ if render_params .quantiles_for_norm != (None , None ):
379
+ layer = _normalize (
380
+ layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
381
+ )
382
+
383
+ if render_params .cmap_params .norm is not None : # type: ignore[attr-defined]
384
+ layer = render_params .cmap_params .norm (layer ) # type: ignore[attr-defined]
385
+
386
+ if render_params .palette is None :
387
+ cmap = render_params .cmap_params .cmap # type: ignore[attr-defined]
388
+ else :
389
+ cmap = _get_linear_colormap ([render_params .palette ], "k" )[0 ]
390
+
391
+ ax .imshow (
392
+ layer , # get rid of the channel dimension
393
+ cmap = cmap ,
394
+ alpha = render_params .alpha ,
395
+ )
376
396
377
- for idx , channel in enumerate (channels ):
378
- layer = img .sel (c = channel )
397
+ # 2) Image has any number of channels but 1
398
+ else :
399
+ layers = {}
400
+ for i , c in enumerate (channels ):
401
+ layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
379
402
380
403
if render_params .quantiles_for_norm != (None , None ):
381
- layer = _normalize (
382
- layer ,
404
+ layers [ c ] = _normalize (
405
+ layers [ c ] ,
383
406
pmin = render_params .quantiles_for_norm [0 ],
384
407
pmax = render_params .quantiles_for_norm [1 ],
385
408
clip = True ,
386
409
)
387
410
388
- if render_params .cmap_params [idx ].norm is not None :
389
- layer = render_params .cmap_params [idx ].norm (layer )
411
+ if not isinstance (render_params .cmap_params , list ):
412
+ if render_params .cmap_params .norm is not None :
413
+ layers [c ] = render_params .cmap_params .norm (layers [c ])
414
+ else :
415
+ if render_params .cmap_params [i ].norm is not None :
416
+ layers [c ] = render_params .cmap_params [i ].norm (layers [c ])
390
417
391
- ax .imshow (
392
- layer ,
393
- cmap = render_params .cmap_params [idx ].cmap ,
394
- alpha = (1 / n_channels ),
395
- )
396
- break
418
+ # 2A) Image has 3 channels, no palette/cmap info -> use RGB
419
+ if n_channels == 3 and render_params .palette is None and not got_multiple_cmaps :
420
+ ax .imshow (np .stack ([layers [c ] for c in channels ], axis = - 1 ), alpha = render_params .alpha )
397
421
398
- if n_channels == 1 :
399
- layer = img .sel (c = channels )
422
+ # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
423
+ elif render_params .palette is None and not got_multiple_cmaps :
424
+ # overwrite if n_channels == 2 for intuitive result
425
+ if n_channels == 2 :
426
+ seed_colors = ["#ff0000ff" , "#00ff00ff" ]
427
+ else :
428
+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
400
429
401
- if render_params .quantiles_for_norm != (None , None ):
402
- layer = _normalize (
403
- layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
404
- )
430
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
405
431
406
- if render_params . cmap_params [ 0 ]. norm is not None :
407
- layer = render_params . cmap_params [ 0 ]. norm ( layer )
432
+ # Apply cmaps to each channel and add up
433
+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
408
434
409
- if render_params .palette is None :
410
- ax .imshow (
411
- layer .squeeze (), # get rid of the channel dimension
412
- cmap = render_params .cmap_params [0 ].cmap ,
413
- )
435
+ # Remove alpha channel so we can overwrite it from render_params.alpha
436
+ colored = colored [:, :, :3 ]
414
437
415
- else :
416
438
ax .imshow (
417
- layer . squeeze (), # get rid of the channel dimension
418
- cmap = _get_linear_colormap ([ render_params .palette ], "k" )[ 0 ] ,
439
+ colored ,
440
+ alpha = render_params .alpha ,
419
441
)
420
442
421
- break
443
+ # 2C) Image has n channels and palette info
444
+ elif render_params .palette is not None and not got_multiple_cmaps :
445
+ if len (render_params .palette ) != n_channels :
446
+ raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
422
447
423
- if render_params .palette is not None and n_channels != len (render_params .palette ):
424
- raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
448
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in render_params .palette ]
425
449
426
- if n_channels > 1 : # to capture n_channels = 3 and custom number cases
427
- layer = img . sel ( c = channels ). copy ( deep = True )
450
+ # Apply cmaps to each channel and add up
451
+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
428
452
429
- channel_colors : list [str ] | Any
430
- if render_params .palette is None :
431
- channel_colors = ["#ff0000ff" , "#00ff00ff" , "#0000ffff" ]
432
- else :
433
- channel_colors = render_params .palette
453
+ # Remove alpha channel so we can overwrite it from render_params.alpha
454
+ colored = colored [:, :, :3 ]
455
+
456
+ ax .imshow (
457
+ colored ,
458
+ alpha = render_params .alpha ,
459
+ )
434
460
435
- channel_cmaps = _get_linear_colormap ([str (c ) for c in channel_colors [:n_channels ]], "k" )
461
+ elif render_params .palette is None and got_multiple_cmaps :
462
+ channel_cmaps = [cp .cmap for cp in render_params .cmap_params ] # type: ignore[union-attr]
436
463
437
- layer_vals = []
438
- if render_params .quantiles_for_norm != (None , None ):
439
- for i in range (n_channels ):
440
- layer_vals .append (
441
- _normalize (
442
- layer .values [i ],
443
- pmin = render_params .quantiles_for_norm [0 ],
444
- pmax = render_params .quantiles_for_norm [1 ],
445
- clip = True ,
446
- )
447
- )
464
+ # Apply cmaps to each channel, add up and normalize to [0, 1]
465
+ colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 ) / n_channels
448
466
449
- colored = np .stack ([channel_cmaps [i ](layer_vals [i ]) for i in range (n_channels )], 0 ).sum (0 )
467
+ # Remove alpha channel so we can overwrite it from render_params.alpha
468
+ colored = colored [:, :, :3 ]
450
469
451
- layer = xr .DataArray (
452
- data = colored ,
453
- coords = [
454
- layer .coords ["y" ],
455
- layer .coords ["x" ],
456
- ["R" , "G" , "B" , "A" ],
457
- ],
458
- dims = ["y" , "x" , "c" ],
459
- )
460
- layer = layer .transpose ("y" , "x" , "c" ) # for plotting
470
+ ax .imshow (
471
+ colored ,
472
+ alpha = render_params .alpha ,
473
+ )
461
474
462
- ax .imshow (
463
- layer .data ,
464
- cmap = channel_cmaps [0 ],
465
- alpha = render_params .alpha ,
466
- norm = render_params .cmap_params [0 ].norm ,
467
- )
475
+ elif render_params .palette is not None and got_multiple_cmaps :
476
+ raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
468
477
469
478
470
479
@dataclass
0 commit comments