Skip to content

Fix clims when plotting shapes element annotations with matplotlib rendering #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ def render_shapes(
sd.SpatialData
The modified SpatialData object with the rendered shapes.
"""
# TODO add Normalize object in tutorial notebook and point to that notebook here
if "vmin" in kwargs or "vmax" in kwargs:
warnings.warn(
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
DeprecationWarning,
stacklevel=2,
)
params_dict = _validate_shape_render_params(
self._sdata,
element=element,
Expand Down Expand Up @@ -269,7 +276,6 @@ def render_shapes(
cmap=cmap,
norm=norm,
na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
**kwargs,
)
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
element=element,
Expand Down Expand Up @@ -363,6 +369,13 @@ def render_points(
sd.SpatialData
The modified SpatialData object with the rendered shapes.
"""
# TODO add Normalize object in tutorial notebook and point to that notebook here
if "vmin" in kwargs or "vmax" in kwargs:
warnings.warn(
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
DeprecationWarning,
stacklevel=2,
)
params_dict = _validate_points_render_params(
self._sdata,
element=element,
Expand Down Expand Up @@ -392,7 +405,6 @@ def render_points(
cmap=cmap,
norm=norm,
na_color=param_values["na_color"], # type: ignore[arg-type]
**kwargs,
)
sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
element=element,
Expand Down Expand Up @@ -473,6 +485,13 @@ def render_images(
sd.SpatialData
The SpatialData object with the rendered images.
"""
# TODO add Normalize object in tutorial notebook and point to that notebook here
if "vmin" in kwargs or "vmax" in kwargs:
warnings.warn(
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
DeprecationWarning,
stacklevel=2,
)
params_dict = _validate_image_render_params(
self._sdata,
element=element,
Expand All @@ -498,7 +517,6 @@ def render_images(
cmap=c,
norm=norm,
na_color=param_values["na_color"],
**kwargs,
)
for c in cmap
]
Expand Down Expand Up @@ -598,6 +616,13 @@ def render_labels(
-------
None
"""
# TODO add Normalize object in tutorial notebook and point to that notebook here
if "vmin" in kwargs or "vmax" in kwargs:
warnings.warn(
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
DeprecationWarning,
stacklevel=2,
)
params_dict = _validate_label_render_params(
self._sdata,
element=element,
Expand All @@ -623,7 +648,6 @@ def render_labels(
cmap=cmap,
norm=norm,
na_color=param_values["na_color"], # type: ignore[arg-type]
**kwargs,
)
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
element=element,
Expand Down
18 changes: 4 additions & 14 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
LinearSegmentedColormap,
ListedColormap,
Normalize,
TwoSlopeNorm,
to_rgba,
)
from matplotlib.figure import Figure
Expand Down Expand Up @@ -339,7 +338,7 @@ def _get_collection_shape(
c = cmap(c)
else:
try:
norm = colors.Normalize(vmin=min(c), vmax=max(c))
norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this function Normalize() is initialized without clip, while in _prepare_cmap_norm() the default is to set clip=True. I would choose one of the two as our default choice. The user will be able to specify clip, vcenter, etc by passing a norm object directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm let me double check that if we don't pass norm as user, whether ultimately the norm is always created anyway, then we can get rid of normalize instance initiated here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok there is code left over of when vmin and vmax were removed. Not certain whether to address this in a different PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd address the choice of the value of clip in this PR please, because it's easy to forget about this in a new PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default is set to False now

except ValueError as e:
raise ValueError(
"Could not convert values in the `color` column to float, if `color` column represents"
Expand All @@ -353,7 +352,7 @@ def _get_collection_shape(
c = cmap(c)
else:
try:
norm = colors.Normalize(vmin=min(c), vmax=max(c))
norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm
except ValueError as e:
raise ValueError(
"Could not convert values in the `color` column to float, if `color` column represents"
Expand Down Expand Up @@ -491,11 +490,8 @@ def _prepare_cmap_norm(
cmap: Colormap | str | None = None,
norm: Normalize | None = None,
na_color: ColorLike | None = None,
vmin: float | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@timtreis I don't remember the outcome of the discussion with the user that reported this. Is this the way to go (=letting users only use norm and not vmin, vmax) or shall we remove vcenter only and keep vmin, vmax and use them to initialize the default Normalize object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the discussion it was stated that vmin and vmax are removed. This function is only internally called

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we agree on clip 'True' by default if user does not provide normaloze object?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If vmin and vmax are not exposed to the user (and hence they are None), then clip will have no effect because when exposed vmin, vmax are None, the data limits are used. So I would keep the default clip to be False (which is matplotlib's default).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing, if vmin, vmax are removed from pl.render_shapes(), we should throw an informative exception or deprecation warning, explaining to the user that norm should be used instead. Could you add that please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They have not been removed in this PR though and the public functions thus already did not contain these parameters

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did add a deprecation warning in case of the arguments being passed as kwargs

vmax: float | None = None,
vcenter: float | None = None,
**kwargs: Any,
) -> CmapParams:
# TODO: check refactoring norm out here as it gets overwritten later
cmap_is_default = cmap is None
if cmap is None:
cmap = rcParams["image.cmap"]
Expand All @@ -505,13 +501,7 @@ def _prepare_cmap_norm(
cmap = copy(cmap)

if norm is None:
norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
elif isinstance(norm, Normalize) or not norm:
pass # TODO
elif vcenter is None:
norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
else:
norm = TwoSlopeNorm(vmin=vmin, vmax=vmax, vcenter=vcenter)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that vcenter is removed, it should be removed also from the function signature. Also, kwargs is in the signature but not used, so I would remove it.

norm = Normalize(vmin=None, vmax=None, clip=False)

na_color, na_color_modified_by_user = _sanitise_na_color(na_color)
cmap.set_bad(na_color)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Shapes_can_set_clims_clip.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_colorbar_can_be_normalised.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
33 changes: 23 additions & 10 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _make_multi():

def test_plot_can_color_from_geodataframe(self, sdata_blobs: SpatialData):
blob = deepcopy(sdata_blobs)
blob["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
blob["table"].obs["region"] = "blobs_polygons"
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
blob.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
blob.pl.render_shapes(
Expand All @@ -111,7 +111,7 @@ def test_plot_can_scale_shapes(self, sdata_blobs: SpatialData):
def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
Expand All @@ -125,7 +125,7 @@ def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
)

def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
Expand All @@ -138,13 +138,13 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
).pl.show()

def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"]).pl.show()
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster").pl.show()

def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
norm = Normalize(vmin=0, vmax=5, clip=True)
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blob

def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
new_table = sdata_blobs["table"][:5]
new_table = sdata_blobs["table"][:5].copy()
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
new_table.obs["instance_id"] = np.array(range(5))

Expand Down Expand Up @@ -214,7 +214,7 @@ def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sd

def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
new_table = sdata_blobs["table"][:10]
new_table = sdata_blobs["table"][:10].copy()
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))

Expand All @@ -230,7 +230,7 @@ def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: Spa

def test_plot_can_color_two_queried_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = "blobs_circles"
new_table = sdata_blobs["table"][:10]
new_table = sdata_blobs["table"][:10].copy()
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))

Expand Down Expand Up @@ -312,7 +312,20 @@ def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", method="datashader").pl.show()

def test_plot_datashader_can_color_by_value(self, sdata_blobs: SpatialData):
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["region"] = "blobs_polygons"
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()

def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData):
table_shapes = sdata_blobs["table"][:5].copy()
table_shapes.obs.instance_id = list(range(5))
table_shapes.obs["region"] = "blobs_circles"
table_shapes.obs["dummy_gene_expression"] = [i * 10 for i in range(5)]
table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["new_table"] = table_shapes

norm = Normalize(vmin=20, vmax=40, clip=True)
sdata_blobs.pl.render_shapes(
"blobs_circles", color="dummy_gene_expression", norm=norm, table_name="new_table"
).pl.show()
Loading