Skip to content

Commit b314b75

Browse files
authored
Fix unused cmap in render_points (#432)
1 parent 3b82c78 commit b314b75

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,9 @@ def _render_points(
523523
palette=palette,
524524
na_color=default_color,
525525
cmap_params=render_params.cmap_params,
526+
alpha=render_params.alpha,
526527
table_name=table_name,
528+
render_type="points",
527529
)
528530

529531
# color_source_vector is None when the values aren't categorical

src/spatialdata_plot/pl/utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,8 +711,10 @@ def _set_color_source_vec(
711711
groups: list[str] | str | None = None,
712712
palette: list[str] | str | None = None,
713713
cmap_params: CmapParams | None = None,
714+
alpha: float = 1.0,
714715
table_name: str | None = None,
715716
table_layer: str | None = None,
717+
render_type: Literal["points"] | None = None,
716718
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
717719
if value_to_plot is None and element is not None:
718720
color = np.full(len(element), na_color)
@@ -757,9 +759,12 @@ def _set_color_source_vec(
757759
adata=sdata.table,
758760
cluster_key=value_to_plot,
759761
color_source_vector=color_source_vector,
762+
cmap_params=cmap_params,
763+
alpha=alpha,
760764
groups=groups,
761765
palette=palette,
762766
na_color=na_color,
767+
render_type=render_type,
763768
)
764769

765770
color_source_vector = color_source_vector.set_categories(color_mapping.keys())
@@ -912,15 +917,28 @@ def _get_categorical_color_mapping(
912917
na_color: ColorLike,
913918
cluster_key: str | None = None,
914919
color_source_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None,
920+
cmap_params: CmapParams | None = None,
921+
alpha: float = 1,
915922
groups: list[str] | str | None = None,
916923
palette: list[str] | str | None = None,
924+
render_type: Literal["points"] | None = None,
917925
) -> Mapping[str, str]:
918926
if not isinstance(color_source_vector, Categorical):
919927
raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")
920928

921929
if isinstance(groups, str):
922930
groups = [groups]
923931

932+
if not palette and render_type == "points" and cmap_params is not None and not cmap_params.cmap_is_default:
933+
palette = cmap_params.cmap
934+
935+
color_idx = color_idx = np.linspace(0, 1, len(color_source_vector.categories))
936+
if isinstance(palette, ListedColormap):
937+
palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)]
938+
elif isinstance(palette, LinearSegmentedColormap):
939+
palette = [to_hex(palette(x, alpha=alpha)) for x in color_idx] # type: ignore[attr-defined]
940+
return dict(zip(color_source_vector.categories, palette, strict=True))
941+
924942
if isinstance(palette, str):
925943
palette = [palette]
926944

@@ -2011,7 +2029,7 @@ def _is_coercable_to_float(series: pd.Series) -> bool:
20112029

20122030

20132031
def _ax_show_and_transform(
2014-
array: MaskedArray[tuple[int, ...], Any],
2032+
array: MaskedArray[tuple[int, ...], Any] | npt.NDArray[Any],
20152033
trans_data: CompositeGenericTransform,
20162034
ax: Axes,
20172035
alpha: float | None = None,

tests/pl/test_render_labels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: Spatial
8484
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()
8585

8686
def test_plot_can_color_labels_by_categorical_variable(self, sdata_blobs: SpatialData):
87-
max_col = sdata_blobs.table.to_df().idxmax(axis=1)
88-
max_col = pd.Categorical(max_col, categories=sdata_blobs.table.to_df().columns, ordered=True)
89-
sdata_blobs.table.obs["which_max"] = max_col
87+
max_col = sdata_blobs["table"].to_df().idxmax(axis=1)
88+
max_col = pd.Categorical(max_col, categories=sdata_blobs["table"].to_df().columns, ordered=True)
89+
sdata_blobs["table"].obs["which_max"] = max_col
9090

9191
sdata_blobs.pl.render_labels("blobs_labels", color="which_max").pl.show()
9292

0 commit comments

Comments
 (0)