Skip to content

fix index bug after spatial query #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning][].
- Images no longer normalised by default (#150)
- Filtering of shapes and points using the `groups` argument is now possible, coloring by palette and cmap arguments works for shapes and points (#153)
- Colorbar no longer autoscales to [0, 1] (#155)
- Plotting shapes after a spatial query is now possible (#163)

## [0.0.4] - 2023-08-11

Expand Down
3 changes: 3 additions & 0 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def _render_shapes(
if not (
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
# necessary in case different shapes elements are annotated with one table
if color_source_vector is not None:
color_source_vector = color_source_vector.remove_unused_categories()
_ = _decorate_axs(
ax=ax,
cax=cax,
Expand Down
3 changes: 3 additions & 0 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ def _get_collection_shape(
# remove empty points/polygons
shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)]

# reset index of shapes_df for case of spatial query
shapes_df = shapes_df.reset_index()

rows = []

def assign_fill_and_outline_to_row(
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
126 changes: 126 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anndata
import geopandas as gpd
import matplotlib
import numpy as np
import pandas as pd
import scanpy as sc
import spatialdata_plot # noqa: F401
Expand Down Expand Up @@ -128,3 +129,128 @@ def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData):
def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=True).pl.show()

def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData):
# subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included
blob = SpatialData.from_elements_dict(
{
"blobs_circles": sdata_blobs.shapes["blobs_circles"],
"blobs_multipolygons": sdata_blobs.shapes["blobs_multipolygons"],
"blobs_polygons": sdata_blobs.shapes["blobs_polygons"],
}
)
cropped_blob = blob.query.bounding_box(
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[300, 300], target_coordinate_system="global"
)
cropped_blob.pl.render_shapes().pl.show()

def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
new_table = sdata_blobs.table.copy()
sdata_blobs.table.obs["region"] = "blobs_circles"
new_table = sdata_blobs.table[:5]
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
new_table.obs["instance_id"] = np.array(range(5))

new_table.obs["annotation"] = ["a", "b", "c", "d", "e"]
new_table.obs["annotation"] = new_table.obs["annotation"].astype("category")

del sdata_blobs.table
sdata_blobs.table = new_table

# random permutation of table and shapes
sdata_blobs.table.obs = sdata_blobs.table.obs.sample(frac=1, random_state=83)
temp = sdata_blobs["blobs_circles"].sample(frac=1, random_state=47)
del sdata_blobs.shapes["blobs_circles"]
sdata_blobs["blobs_circles"] = temp

sdata_blobs.pl.render_shapes("blobs_circles", color="annotation").pl.show()

def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
new_table = sdata_blobs.table.copy()
sdata_blobs.table.obs["region"] = "blobs_circles"
new_table = sdata_blobs.table[:5]
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
new_table.obs["instance_id"] = np.array(range(5))

new_table.obs["annotation"] = ["a", "b", "c", "d", "e"]
new_table.obs["annotation"] = new_table.obs["annotation"].astype("category")

del sdata_blobs.table
sdata_blobs.table = new_table

# random permutation of table and shapes
sdata_blobs.table.obs = sdata_blobs.table.obs.sample(frac=1, random_state=83)
temp = sdata_blobs["blobs_circles"].sample(frac=1, random_state=47)
del sdata_blobs.shapes["blobs_circles"]
sdata_blobs["blobs_circles"] = temp

# subsetting the data
sdata_cropped = sdata_blobs.query.bounding_box(
axes=("x", "y"),
min_coordinate=[100, 150],
max_coordinate=[400, 250],
target_coordinate_system="global",
filter_table=True,
)

# workaround for bug that should be gone in later versions
del sdata_cropped.images["blobs_multiscale_image"]
del sdata_cropped.labels["blobs_labels"]
del sdata_cropped.labels["blobs_multiscale_labels"]

sdata_cropped.pl.render_shapes("blobs_circles", color="annotation").pl.show()

def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
new_table = sdata_blobs.table.copy()
sdata_blobs.table.obs["region"] = "blobs_circles"
new_table = sdata_blobs.table[:10]
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))

new_table.obs.loc[5 * [False] + 5 * [True], "region"] = "blobs_polygons"
new_table.obs["annotation"] = ["a", "b", "c", "d", "e", "v", "w", "x", "y", "z"]
new_table.obs["annotation"] = new_table.obs["annotation"].astype("category")

del sdata_blobs.table
sdata_blobs.table = new_table

sdata_blobs.pl.render_shapes(["blobs_circles", "blobs_polygons"], color="annotation").pl.show()

def test_plot_can_color_two_queried_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
new_table = sdata_blobs.table.copy()
sdata_blobs.table.obs["region"] = "blobs_circles"
new_table = sdata_blobs.table[:10]
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))

new_table.obs.loc[5 * [False] + 5 * [True], "region"] = "blobs_polygons"
new_table.obs["annotation"] = ["a", "b", "c", "d", "e", "v", "w", "x", "y", "z"]
new_table.obs["annotation"] = new_table.obs["annotation"].astype("category")

del sdata_blobs.table
sdata_blobs.table = new_table

# random permutation of table and shapes
sdata_blobs.table.obs = sdata_blobs.table.obs.sample(frac=1, random_state=83)
temp = sdata_blobs["blobs_circles"].sample(frac=1, random_state=47)
del sdata_blobs.shapes["blobs_circles"]
sdata_blobs["blobs_circles"] = temp
temp = sdata_blobs["blobs_polygons"].sample(frac=1, random_state=71)
del sdata_blobs.shapes["blobs_polygons"]
sdata_blobs["blobs_polygons"] = temp

# subsetting the data
sdata_cropped = sdata_blobs.query.bounding_box(
axes=("x", "y"),
min_coordinate=[100, 150],
max_coordinate=[350, 300],
target_coordinate_system="global",
filter_table=True,
)

# workaround for bug that should be gone in later versions
del sdata_cropped.images["blobs_multiscale_image"]
del sdata_cropped.labels["blobs_labels"]
del sdata_cropped.labels["blobs_multiscale_labels"]

sdata_cropped.pl.render_shapes(["blobs_circles", "blobs_polygons"], color="annotation").pl.show()