Skip to content

Added utils function for 0-transparent cmaps #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning][].

### Added

- Added utils function for 0-transparent cmaps (#302)

### Changed

-
Expand Down
23 changes: 23 additions & 0 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Literal, Union

import matplotlib
import matplotlib.cm as cm
import matplotlib.patches as mpatches
import matplotlib.patches as mplp
import matplotlib.path as mpath
Expand Down Expand Up @@ -1961,3 +1962,25 @@ def _ax_show_and_transform(
zorder=zorder,
)
im.set_transform(trans_data)


def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = None) -> ListedColormap:
"""
Modify colormap so that 0s are transparent.

Parameters
----------
cmap (Colormap | str): A matplotlib Colormap instance or a colormap name string.
steps (int): The number of steps in the colormap.

Returns
-------
ListedColormap: A new colormap instance with modified alpha values.
"""
if isinstance(cmap, str):
cmap = cm.get_cmap(cmap)

colors = cmap(np.arange(steps or cmap.N))
colors[0, :] = [1.0, 1.0, 1.0, 0.0]

return ListedColormap(colors)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 57 additions & 13 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,58 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytest
from spatialdata.datasets import blobs


@pytest.mark.parametrize(
"outline_color",
[
(0.0, 1.0, 0.0, 1.0),
"#00ff00",
],
)
def test_set_outline_accepts_str_or_float_or_list_thereof(outline_color):
sdata = blobs()
sdata.pl.render_shapes(element="blobs_polygons", outline=True, outline_color=outline_color).pl.show()
import scanpy as sc
import spatialdata_plot
from spatialdata import SpatialData

from tests.conftest import DPI, PlotTester, PlotTesterMeta

RNG = np.random.default_rng(seed=42)
sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=DPI, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestUtils(PlotTester, metaclass=PlotTesterMeta):
@pytest.mark.parametrize(
"outline_color",
[
(0.0, 1.0, 0.0, 1.0),
"#00ff00",
],
)
def test_plot_set_outline_accepts_str_or_float_or_list_thereof(self, sdata_blobs: SpatialData, outline_color):
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline=True, outline_color=outline_color).pl.show()

def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData):
from spatialdata_plot.pl.utils import set_zero_in_cmap_to_transparent

# set up figure and modify the data to add 0s
fig, axs = plt.subplots(ncols=2, figsize=(6, 3))
table = sdata_blobs.table.copy()
x = table.X.todense()
x[:10, 0] = 0
table.X = x
sdata_blobs.tables["modified_table"] = table

# create a new cmap with 0 as transparent
new_cmap = set_zero_in_cmap_to_transparent(cmap="plasma")

# baseline img
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", cmap="viridis", table="table").pl.show(
ax=axs[0], colorbar=False
)

# image with 0s as transparent, so some labels are "missing"
sdata_blobs.pl.render_labels(
"blobs_labels", color="channel_0_sum", cmap=new_cmap, table="modified_table"
).pl.show(ax=axs[1], colorbar=False)
Loading