Skip to content

Commit 4a244c5

Browse files
committed
another update
1 parent 9431eed commit 4a244c5

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

src/spatialdata_plot/pl/_categorical_utils.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,8 @@ def _get_colors_for_categorical_obs(categories: Sequence[Union[str, int]]) -> li
498498

499499

500500
Palette_t = Optional[Union[str, ListedColormap]]
501+
from matplotlib.colors import ListedColormap, to_hex, to_rgba
502+
import matplotlib.pyplot as plt
501503

502504

503505
def _get_palette(
@@ -509,15 +511,15 @@ def _get_palette(
509511
) -> Mapping[str, str] | None:
510512
if palette is None:
511513
try:
512-
palette = adata.uns[Key.uns.colors(cluster_key)] # type: ignore[arg-type]
514+
palette = adata.uns[f"{cluster_key}_colors"] # type: ignore[arg-type]
513515
if len(palette) != len(categories):
514516
raise ValueError(
515517
f"Expected palette to be of length `{len(categories)}`, found `{len(palette)}`. "
516518
+ f"Removing the colors in `adata.uns` with `adata.uns.pop('{cluster_key}_colors')` may help."
517519
)
518520
return {cat: to_hex(to_rgba(col)[:3] + (alpha,), keep_alpha=True) for cat, col in zip(categories, palette)}
519521
except KeyError as e:
520-
logg.error(f"Unable to fetch palette, reason: {e}. Using `None`.")
522+
print(e)
521523
return None
522524

523525
len_cat = len(categories)
@@ -530,3 +532,27 @@ def _get_palette(
530532
raise TypeError(f"Palette is {type(palette)} but should be string or `ListedColormap`.")
531533

532534
return dict(zip(categories, palette))
535+
536+
537+
def _maybe_set_colors(
538+
source: AnnData, target: AnnData, key: str, palette: str | ListedColormap | Cycler | Sequence[Any] | None = None
539+
) -> None:
540+
color_key = f"{key}_colors"
541+
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
542+
543+
# this is insane, basically the version copied here was from napari
544+
# in napari is modified because we have to do some tricks to plot the categorical values
545+
# hence it requires the argument vec. But here we don't, so am re importing the original one
546+
# from scanpy here.
547+
# this is a testament to how broken the categorical color handling is in the scanpy ecosystem and
548+
# to the fact that, because I've never fixed it, an embarassing amount of intellectual debt has
549+
# been accumulated.
550+
551+
try:
552+
if palette is not None:
553+
raise KeyError("Unable to copy the palette when there was other explicitly specified.")
554+
target.uns[color_key] = source.uns[color_key]
555+
except KeyError:
556+
if isinstance(palette, ListedColormap): # `scanpy` requires it
557+
palette = cycler(color=palette.colors)
558+
add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette)

src/spatialdata_plot/pl/basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
_get_subplots,
4646
)
4747

48+
from spatialdata_plot.pl._categorical_utils import _maybe_set_colors
49+
4850

4951
@register_spatial_data_accessor("pl")
5052
class PlotAccessor:
@@ -793,6 +795,9 @@ def show(
793795
vec=_get_color_key_values(sdata, params["color_key"]),
794796
palette=params["palette"],
795797
)
798+
_maybe_set_colors(
799+
source=sdata.table, target=sdata.table, key=params["color_key"], palette=None
800+
)
796801
else:
797802
# If any of the previous conditions are not met, generate random
798803
# colors for each cell id

src/spatialdata_plot/pl/render.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from skimage.segmentation import find_boundaries
3131
from skimage.util import map_array
3232
from functools import partial
33-
from ..pl._categorical_utils import _get_colors_for_categorical_obs, _get_palette
33+
from ..pl._categorical_utils import _get_colors_for_categorical_obs, _get_palette, _maybe_set_colors
3434
from ..pl.utils import _normalize
3535
from ..pp.utils import _get_linear_colormap, _get_region_key
3636

@@ -221,8 +221,10 @@ def _render_images(
221221

222222
ax.set_title(key)
223223

224+
224225
import matplotlib.pyplot as plt
225226

227+
226228
def _render_labels(
227229
sdata: sd.SpatialData,
228230
params: dict[str, Union[str, int, float]],
@@ -234,7 +236,6 @@ def _render_labels(
234236

235237
# subset table to only the entires specified by 'key'
236238
table = sdata.table[sdata.table.obs[region_key] == key]
237-
238239
segmentation = sdata.labels[key].values
239240

240241
norm = Normalize(vmin=None, vmax=None)
@@ -271,7 +272,20 @@ def _render_labels(
271272

272273
# ax.legend(handles=patches, bbox_to_anchor=(0.9, 0.9), loc="upper left", frameon=False)
273274
# ax.colorbar(pad=0.01, fraction=0.08, aspect=30)
274-
plt.colorbar(cax, ax=ax, pad=0.01, fraction=0.08, aspect=30)
275+
if is_categorical_dtype(color_source_vector):
276+
clusters = color_source_vector.categories
277+
palette = _get_palette(table, cluster_key=params["color_key"], categories=clusters)
278+
for label in clusters:
279+
ax.scatter([], [], c=palette[label], label=label)
280+
ax.legend(
281+
frameon=False,
282+
loc="center left",
283+
bbox_to_anchor=(1, 0.5),
284+
ncol=(1 if len(clusters) <= 14 else 2 if len(clusters) <= 30 else 3),
285+
fontsize=None,
286+
)
287+
else:
288+
plt.colorbar(cax, ax=ax, pad=0.01, fraction=0.08, aspect=30)
275289
ax.set_title(key)
276290

277291

0 commit comments

Comments
 (0)