11
11
import pandas as pd
12
12
import scanpy as sc
13
13
import spatialdata as sd
14
- import xarray as xr
15
14
from anndata import AnnData
16
15
from geopandas import GeoDataFrame
17
16
from matplotlib import colors
@@ -328,7 +327,6 @@ def _render_images(
328
327
fig_params : FigParams ,
329
328
scalebar_params : ScalebarParams ,
330
329
legend_params : LegendParams ,
331
- # extent: tuple[float, float, float, float] | None = None,
332
330
) -> None :
333
331
elements = render_params .elements
334
332
@@ -346,9 +344,6 @@ def _render_images(
346
344
images = [sdata .images [e ] for e in elements ]
347
345
348
346
for img in images :
349
- if (len (img .c ) > 3 or len (img .c ) == 2 ) and render_params .channel is None :
350
- raise NotImplementedError ("Only 1 or 3 channels are supported at the moment." )
351
-
352
347
if render_params .channel is None :
353
348
channels = img .coords ["c" ].values
354
349
else :
@@ -358,11 +353,8 @@ def _render_images(
358
353
359
354
n_channels = len (channels )
360
355
356
+ # True if user gave n cmaps for n channels
361
357
got_multiple_cmaps = isinstance (render_params .cmap_params , list )
362
-
363
- if not isinstance (render_params .cmap_params , list ):
364
- render_params .cmap_params = [render_params .cmap_params ] * n_channels
365
-
366
358
if got_multiple_cmaps :
367
359
logger .warning (
368
360
"You're blending multiple cmaps. "
@@ -372,102 +364,113 @@ def _render_images(
372
364
"Consider using 'palette' instead."
373
365
)
374
366
375
- if render_params .palette is not None :
376
- logger .warning ("Parameter 'palette' is ignored when a 'cmap' is provided." )
367
+ # not using got_multiple_cmaps here because of ruff :(
368
+ if isinstance (render_params .cmap_params , list ) and len (render_params .cmap_params ) != n_channels :
369
+ raise ValueError ("If 'cmap' is provided, its length must match the number of channels." )
370
+
371
+ # 1) Image has only 1 channel
372
+ if n_channels == 1 and not isinstance (render_params .cmap_params , list ):
373
+ layer = img .sel (c = channels ).squeeze ()
374
+
375
+ if render_params .quantiles_for_norm != (None , None ):
376
+ layer = _normalize (
377
+ layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
378
+ )
379
+
380
+ if render_params .cmap_params .norm is not None : # type: ignore[attr-defined]
381
+ layer = render_params .cmap_params .norm (layer ) # type: ignore[attr-defined]
377
382
378
- for idx , channel in enumerate (channels ):
379
- layer = img .sel (c = channel )
383
+ if render_params .palette is None :
384
+ cmap = render_params .cmap_params .cmap # type: ignore[attr-defined]
385
+ else :
386
+ cmap = _get_linear_colormap ([render_params .palette ], "k" )[0 ]
387
+
388
+ ax .imshow (
389
+ layer , # get rid of the channel dimension
390
+ cmap = cmap ,
391
+ alpha = render_params .alpha ,
392
+ )
393
+
394
+ # 2) Image has any number of channels but 1
395
+ else :
396
+ layers = {}
397
+ for i , c in enumerate (channels ):
398
+ layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
380
399
381
400
if render_params .quantiles_for_norm != (None , None ):
382
- layer = _normalize (
383
- layer ,
401
+ layers [ c ] = _normalize (
402
+ layers [ c ] ,
384
403
pmin = render_params .quantiles_for_norm [0 ],
385
404
pmax = render_params .quantiles_for_norm [1 ],
386
405
clip = True ,
387
406
)
388
407
389
- if render_params .cmap_params [idx ].norm is not None :
390
- layer = render_params .cmap_params [idx ].norm (layer )
408
+ if not isinstance (render_params .cmap_params , list ):
409
+ if render_params .cmap_params .norm is not None :
410
+ layers [c ] = render_params .cmap_params .norm (layers [c ])
411
+ else :
412
+ if render_params .cmap_params [i ].norm is not None :
413
+ layers [c ] = render_params .cmap_params [i ].norm (layers [c ])
391
414
392
- ax .imshow (
393
- layer ,
394
- cmap = render_params .cmap_params [idx ].cmap ,
395
- alpha = (1 / n_channels ),
396
- )
397
- break
415
+ # 2A) Image has 3 channels, no palette/cmap info -> use RGB
416
+ if n_channels == 3 and render_params .palette is None and not got_multiple_cmaps :
417
+ ax .imshow (np .stack ([layers [c ] for c in channels ], axis = - 1 ), alpha = render_params .alpha )
398
418
399
- if n_channels == 1 :
400
- layer = img .sel (c = channels )
419
+ # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
420
+ elif render_params .palette is None and not got_multiple_cmaps :
421
+ # overwrite if n_channels == 2 for intuitive result
422
+ if n_channels == 2 :
423
+ seed_colors = ["#ff0000ff" , "#00ff00ff" ]
424
+ else :
425
+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
401
426
402
- if render_params .quantiles_for_norm != (None , None ):
403
- layer = _normalize (
404
- layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
405
- )
427
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
406
428
407
- if render_params . cmap_params [ 0 ]. norm is not None :
408
- layer = render_params . cmap_params [ 0 ]. norm ( layer )
429
+ # Apply cmaps to each channel and add up
430
+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
409
431
410
- if render_params .palette is None :
411
- ax .imshow (
412
- layer .squeeze (), # get rid of the channel dimension
413
- cmap = render_params .cmap_params [0 ].cmap ,
414
- )
432
+ # Remove alpha channel so we can overwrite it from render_params.alpha
433
+ colored = colored [:, :, :3 ]
415
434
416
- else :
417
435
ax .imshow (
418
- layer . squeeze (), # get rid of the channel dimension
419
- cmap = _get_linear_colormap ([ render_params .palette ], "k" )[ 0 ] ,
436
+ colored ,
437
+ alpha = render_params .alpha ,
420
438
)
421
439
422
- break
440
+ # 2C) Image has n channels and palette info
441
+ elif render_params .palette is not None and not got_multiple_cmaps :
442
+ if len (render_params .palette ) != n_channels :
443
+ raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
423
444
424
- if render_params .palette is not None and n_channels != len (render_params .palette ):
425
- raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
445
+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in render_params .palette ]
426
446
427
- if n_channels > 1 :
428
- layer = img . sel ( c = channels ). copy ( deep = True )
447
+ # Apply cmaps to each channel and add up
448
+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
429
449
430
- channel_colors : list [str ] | Any
431
- if render_params .palette is None :
432
- channel_colors = _get_colors_for_categorical_obs (
433
- layer .coords ["c" ].values .tolist (), palette = render_params .cmap_params [0 ].cmap
450
+ # Remove alpha channel so we can overwrite it from render_params.alpha
451
+ colored = colored [:, :, :3 ]
452
+
453
+ ax .imshow (
454
+ colored ,
455
+ alpha = render_params .alpha ,
434
456
)
435
- else :
436
- channel_colors = render_params .palette
437
457
438
- channel_cmaps = _get_linear_colormap ([str (c ) for c in channel_colors [:n_channels ]], "k" )
458
+ elif render_params .palette is None and got_multiple_cmaps :
459
+ channel_cmaps = [cp .cmap for cp in render_params .cmap_params ] # type: ignore[union-attr]
439
460
440
- layer_vals = []
441
- if render_params .quantiles_for_norm != (None , None ):
442
- for i in range (n_channels ):
443
- layer_vals .append (
444
- _normalize (
445
- layer .values [i ],
446
- pmin = render_params .quantiles_for_norm [0 ],
447
- pmax = render_params .quantiles_for_norm [1 ],
448
- clip = True ,
449
- )
450
- )
461
+ # Apply cmaps to each channel, add up and normalize to [0, 1]
462
+ colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 ) / n_channels
451
463
452
- colored = np .stack ([channel_cmaps [i ](layer_vals [i ]) for i in range (n_channels )], 0 ).sum (0 )
464
+ # Remove alpha channel so we can overwrite it from render_params.alpha
465
+ colored = colored [:, :, :3 ]
453
466
454
- layer = xr .DataArray (
455
- data = colored ,
456
- coords = [
457
- layer .coords ["y" ],
458
- layer .coords ["x" ],
459
- ["R" , "G" , "B" , "A" ],
460
- ],
461
- dims = ["y" , "x" , "c" ],
462
- )
463
- layer = layer .transpose ("y" , "x" , "c" ) # for plotting
467
+ ax .imshow (
468
+ colored ,
469
+ alpha = render_params .alpha ,
470
+ )
464
471
465
- ax .imshow (
466
- layer .data ,
467
- cmap = channel_cmaps [0 ],
468
- alpha = render_params .alpha ,
469
- norm = render_params .cmap_params [0 ].norm ,
470
- )
472
+ elif render_params .palette is not None and got_multiple_cmaps :
473
+ raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
471
474
472
475
473
476
@dataclass
0 commit comments