Skip to content

Donut-MultiPolygons are now correctly rendered again #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,11 @@ jobs:
DISPLAY: :42
run: |
pytest -v --cov --color=yes --cov-report=xml
# - name: Generate GH action "groundtruth" figures as artifacts, uncomment if needed
# if: always()
# uses: actions/upload-artifact@v3
# with:
# name: groundtruth-figures
# path: /home/runner/work/spatialdata-plot/spatialdata-plot/tests/_images/*
- name: Archive figures generated during testing
if: always()
uses: actions/upload-artifact@v3
with:
name: plotting-results
name: visual_test_results_${{ matrix.os }}-python${{ matrix.python }}
path: /home/runner/work/spatialdata-plot/spatialdata-plot/tests/figures/*
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
Expand Down
262 changes: 76 additions & 186 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@

import matplotlib
import matplotlib.patches as mpatches
import matplotlib.patches as mplp
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import multiscale_spatial_image as msi
import numpy as np
import pandas as pd
import shapely
Expand Down Expand Up @@ -48,7 +46,6 @@
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
from scanpy.plotting.palettes import default_20, default_28, default_102
from shapely.geometry import LineString, Polygon
from skimage.color import label2rgb
from skimage.morphology import erosion, square
from skimage.segmentation import find_boundaries
Expand Down Expand Up @@ -233,6 +230,30 @@ def _sanitise_na_color(na_color: ColorLike | None) -> tuple[str, bool]:
raise ValueError(f"Invalid na_color value: {na_color}")


def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, float]:
# Extract the vertices from the PathPatch
path = pathpatch.get_path()
vertices = path.vertices
x = vertices[:, 0]
y = vertices[:, 1]

area = 0.5 * np.sum(x[:-1] * y[1:] - x[1:] * y[:-1])

# Calculate the centroid coordinates
centroid_x = np.sum((x[:-1] + x[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area)
centroid_y = np.sum((y[:-1] + y[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area)

return centroid_x, centroid_y


def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None:

centroid = _get_centroid_of_pathpatch(pathpatch)
vertices = pathpatch.get_path().vertices
scaled_vertices = np.array([centroid + (vertex - centroid) * scale_factor for vertex in vertices])
pathpatch.get_path().vertices = scaled_vertices


def _get_collection_shape(
shapes: list[GeoDataFrame],
c: Any,
Expand Down Expand Up @@ -302,63 +323,64 @@ def _get_collection_shape(
outline_c = outline_c * fill_c.shape[0]

shapes_df = pd.DataFrame(shapes, copy=True)

# remove empty points/polygons
shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)]

# reset index of shapes_df for case of spatial query
shapes_df = shapes_df.reset_index(drop=True)

rows = []

def assign_fill_and_outline_to_row(
shapes: list[GeoDataFrame], fill_c: list[Any], outline_c: list[Any], row: pd.Series, idx: int
def _assign_fill_and_outline_to_row(
fill_c: list[Any], outline_c: list[Any], row: dict[str, Any], idx: int, is_multiple_shapes: bool
) -> None:
try:
if len(shapes) > 1 and len(fill_c) == 1:
row["fill_c"] = fill_c
row["outline_c"] = outline_c
if is_multiple_shapes and len(fill_c) == 1:
row["fill_c"] = fill_c[0]
row["outline_c"] = outline_c[0]
else:
row["fill_c"] = fill_c[idx]
row["outline_c"] = outline_c[idx]
except IndexError as e:
raise IndexError("Could not assign fill and outline colors due to a mismatch in row-numbers.") from e

# Match colors to the geometry, potentially expanding the row in case of
# multipolygons
for idx, row in shapes_df.iterrows():
geom = row["geometry"]
if geom.geom_type == "Polygon":
row = row.to_dict()
coords = np.array(geom.exterior.coords)
centroid = np.mean(coords, axis=0)
scaled_coords = [(centroid + (np.array(coord) - centroid) * s).tolist() for coord in geom.exterior.coords]
row["geometry"] = mplp.Polygon(scaled_coords, closed=True)
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
rows.append(row)

elif geom.geom_type == "MultiPolygon":
# mp = _make_patch_from_multipolygon(geom)
for polygon in geom.geoms:
mp_copy = row.to_dict()
coords = np.array(polygon.exterior.coords)
centroid = np.mean(coords, axis=0)
scaled_coords = [(centroid + (coord - centroid) * s).tolist() for coord in coords]
mp_copy["geometry"] = mplp.Polygon(scaled_coords, closed=True)
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, mp_copy, idx)
rows.append(mp_copy)

elif geom.geom_type == "Point":
row = row.to_dict()
scaled_radius = row["radius"] * s
row["geometry"] = mplp.Circle(
(geom.x, geom.y), radius=scaled_radius
) # Circle is always scaled from its center
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
rows.append(row)

patches = pd.DataFrame(rows)

raise IndexError("Could not assign fill and outline colors due to a mismatch in row numbers.") from e

def _process_polygon(row: pd.Series, s: float) -> dict[str, Any]:
coords = np.array(row["geometry"].exterior.coords)
centroid = np.mean(coords, axis=0)
scaled_coords = (centroid + (coords - centroid) * s).tolist()
return {**row.to_dict(), "geometry": mpatches.Polygon(scaled_coords, closed=True)}

def _process_multipolygon(row: pd.Series, s: float) -> list[dict[str, Any]]:
mp = _make_patch_from_multipolygon(row["geometry"])
row_dict = row.to_dict()
for m in mp:
_scale_pathpatch_around_centroid(m, s)

return [{**row_dict, "geometry": m} for m in mp]

def _process_point(row: pd.Series, s: float) -> dict[str, Any]:
return {
**row.to_dict(),
"geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=row["radius"] * s),
}

def _create_patches(shapes_df: GeoDataFrame, fill_c: list[Any], outline_c: list[Any], s: float) -> pd.DataFrame:
rows = []
is_multiple_shapes = len(shapes_df) > 1

for idx, row in shapes_df.iterrows():
geom_type = row["geometry"].geom_type
processed_rows = []

if geom_type == "Polygon":
processed_rows.append(_process_polygon(row, s))
elif geom_type == "MultiPolygon":
processed_rows.extend(_process_multipolygon(row, s))
elif geom_type == "Point":
processed_rows.append(_process_point(row, s))

for processed_row in processed_rows:
_assign_fill_and_outline_to_row(fill_c, outline_c, processed_row, idx, is_multiple_shapes)
rows.append(processed_row)

return pd.DataFrame(rows)

patches = _create_patches(shapes_df, fill_c, outline_c, s)
return PatchCollection(
patches["geometry"].values.tolist(),
snap=False,
Expand Down Expand Up @@ -738,7 +760,7 @@ def _map_color_seg(
cell_id = np.array(cell_id)
if color_vector is not None and isinstance(color_vector.dtype, pd.CategoricalDtype):
# users wants to plot a categorical column
if isinstance(na_color, tuple) and len(na_color) == 4 and np.any(color_source_vector.isna()):
if np.any(color_source_vector.isna()):
cell_id[color_source_vector.isna()] = 0
val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1)
cols = colors.to_rgba_array(color_vector.categories)
Expand Down Expand Up @@ -823,9 +845,9 @@ def _modify_categorical_color_mapping(
modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"}
elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list):
modified_mapping = dict(zip(groups, palette))

else:
raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.")

return modified_mapping


Expand All @@ -841,7 +863,7 @@ def _get_default_categorial_color_mapping(
palette = default_102
else:
palette = ["grey" for _ in range(len_cat)]
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")

return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(color_source_vector.categories, palette[:len_cat])}

Expand Down Expand Up @@ -872,54 +894,6 @@ def _get_categorical_color_mapping(
return _modify_categorical_color_mapping(base_mapping, groups, palette)


def _get_palette(
categories: Sequence[Any],
adata: AnnData | None = None,
cluster_key: None | str = None,
palette: ListedColormap | str | list[str] | None = None,
alpha: float = 1.0,
) -> Mapping[str, str] | None:
palette = None if isinstance(palette, list) and palette[0] is None else palette
if adata is not None and palette is None:
try:
palette = adata.uns[f"{cluster_key}_colors"] # type: ignore[arg-type]
if len(palette) != len(categories):
raise ValueError(
f"Expected palette to be of length `{len(categories)}`, found `{len(palette)}`. "
+ f"Removing the colors in `adata.uns` with `adata.uns.pop('{cluster_key}_colors')` may help."
)
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette)}
except KeyError as e:
logger.warning(e)
return None

len_cat = len(categories)

if palette is None:
if len_cat <= 20:
palette = default_20
elif len_cat <= 28:
palette = default_28
elif len_cat <= len(default_102): # 103 colors
palette = default_102
else:
palette = ["grey" for _ in range(len_cat)]
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette[:len_cat])}

if isinstance(palette, str):
cmap = ListedColormap([palette])
elif isinstance(palette, list):
cmap = ListedColormap(palette)
elif isinstance(palette, ListedColormap):
cmap = palette
else:
raise TypeError(f"Palette is {type(palette)} but should be string or list.")
palette = [to_hex(np.round(x, 5)) for x in cmap(np.linspace(0, 1, len_cat), alpha=alpha)]

return dict(zip(categories, palette))


def _maybe_set_colors(
source: AnnData, target: AnnData, key: str, palette: str | ListedColormap | Cycler | Sequence[Any] | None = None
) -> None:
Expand Down Expand Up @@ -1087,34 +1061,6 @@ def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "p
fig.savefig(path, **kwargs)


def _get_cs_element_map(
element: str | Sequence[str] | None,
element_map: Mapping[str, Any],
) -> Mapping[str, str]:
"""Get the mapping between the coordinate system and the class."""
# from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel
element = list(element_map.keys())[0] if element is None else element
element = [element] if isinstance(element, str) else element
d = {}
for e in element:
cs = list(element_map[e].attrs["transform"].keys())[0]
d[cs] = e
# model = get_model(element_map["blobs_labels"])
# if model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
return d


def _multiscale_to_image(sdata: sd.SpatialData) -> sd.SpatialData:
if sdata.images is None:
raise ValueError("No images found in the SpatialData object.")

for k, v in sdata.images.items():
if isinstance(v, msi.multiscale_spatial_image.DataTree):
sdata.images[k] = Image2DModel.parse(v["scale0"].ds.to_array().squeeze(axis=0))

return sdata


def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegmentedColormap]:
return [LinearSegmentedColormap.from_list(c, [background, c], N=256) for c in colors]

Expand All @@ -1126,62 +1072,6 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
return ListedColormap(["black"] + colors, N=len(colors) + 1)


def _translate_image(
image: DataArray,
translation: sd.transformations.transformations.Translation,
) -> DataArray:
shifts: dict[str, int] = {axis: int(translation.translation[idx]) for idx, axis in enumerate(translation.axes)}
img = image.values.copy()
# for yx images (important for rasterized MultiscaleImages as labels)
expanded_dims = False
if len(img.shape) == 2:
img = np.expand_dims(img, axis=0)
expanded_dims = True

shifted_channels = []

# split channels, shift axes individually, them recombine
if len(img.shape) == 3:
for c in range(img.shape[0]):
channel = img[c, :, :]

# iterates over [x, y]
for axis, shift in shifts.items():
pad_x, pad_y = (0, 0), (0, 0)
if axis == "x" and shift > 0:
pad_x = (abs(shift), 0)
elif axis == "x" and shift < 0:
pad_x = (0, abs(shift))

if axis == "y" and shift > 0:
pad_y = (abs(shift), 0)
elif axis == "y" and shift < 0:
pad_y = (0, abs(shift))

channel = np.pad(channel, (pad_y, pad_x), mode="constant")

shifted_channels.append(channel)

if expanded_dims:
return Labels2DModel.parse(
np.array(shifted_channels[0]),
dims=["y", "x"],
transformations=image.attrs["transform"],
)
return Image2DModel.parse(
np.array(shifted_channels),
dims=["c", "y", "x"],
transformations=image.attrs["transform"],
)


def _convert_polygon_to_linestrings(polygon: Polygon) -> list[LineString]:
b = polygon.boundary.coords
linestrings = [LineString(b[k : k + 2]) for k in range(len(b) - 1)]

return [list(ls.coords) for ls in linestrings]


def _split_multipolygon_into_outer_and_inner(mp: shapely.MultiPolygon): # type: ignore
# https://stackoverflow.com/a/21922058

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed tests/_images/Labels_label_categorical_color.png
Binary file not shown.
Binary file modified tests/_images/Shapes_can_render_multipolygons.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading