Skip to content

Commit 9b09a3f

Browse files
committed
merge
2 parents fd55b7e + 538d081 commit 9b09a3f

File tree

54 files changed

+642
-212
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+642
-212
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,6 @@ tests/figures/
3838

3939
# sphinx files
4040
*.swp
41+
42+
# other
43+
_version.py

CHANGELOG.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,22 @@ and this project adheres to [Semantic Versioning][].
88
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
99
[semantic versioning]: https://semver.org/spec/v2.0.0.html
1010

11-
## [Unreleased]
11+
## [0.0.4] - 2023-08-11
12+
13+
### Fixed
14+
15+
- Multi-scale images/labels are now correctly substituted and the action is logged (#131).
16+
- Empty geometries among the shapes can be handeled (#133).
17+
- `outline_width` parameter in render_shapes is now a float that actually determines the line width (#139).
18+
19+
## [0.0.2] - 2023-06-25
20+
21+
### Fixed
22+
23+
- Multiple bugfixes of which I didn't keep track of.
24+
25+
## [0.0.1] - 2023-04-04
1226

1327
### Added
1428

15-
- Basic tool, preprocessing and plotting functions
29+
- Initial release of `spatialdata-plot` with support for `images`, `labels`, `points` and `shapes`.

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
[![Tests][badge-tests]][link-tests]
66
[![Documentation][badge-docs]][link-docs]
77
[![Codecov][badge-codecov]][link-codecov]
8+
[![Documentation][badge-pypi]][link-pypi]
89

910
[badge-tests]: https://img.shields.io/github/actions/workflow/status/scverse/spatialdata-plot/test_and_deploy.yaml?branch=main
1011
[link-tests]: https://github.com/scverse/spatialdata-plot/actions/workflows/test.yml
1112
[badge-docs]: https://img.shields.io/readthedocs/spatialdata-plot
1213
[badge-codecov]: https://codecov.io/gh/scverse/spatialdata-plot/branch/main/graph/badge.svg?token=C45F3ATSVI
1314
[link-codecov]: https://app.codecov.io/gh/scverse/spatialdata-plot
15+
[badge-pypi]: https://badge.fury.io/py/spatialdata_plot.svg
16+
[link-pypi]: https://pypi.org/project/spatialdata-plot/
1417

1518
The `spatialdata-plot` package extends `spatialdata` with a declarative plotting API that enables to quickly visualize `spatialdata` objects and their respective elements (i.e. `images`, `labels`, `points` and `shapes`).
1619

@@ -20,13 +23,10 @@ SpatialData’s plotting capabilities allow to quickly visualise all contained m
2023

2124
## Getting started
2225

23-
For more information on the `spatialdata` format, please refer to the [documentation](https://spatialdata.scverse.org/en/latest/). In particular, the
26+
For more information on the `spatialdata-plot` library, please refer to the [documentation](https://spatialdata.scverse.org/projects/plot/en/latest/index.html). In particular, the
2427

2528
- [API documentation][link-api].
26-
- [Design doc][link-design-doc].
27-
- [Example notebooks][link-notebooks].
28-
29-
For usage examples, please refer to the ["Visualizations"](https://spatialdata.scverse.org/en/latest/tutorials/notebooks/notebooks.html#visualizations) section of `spatialdata`.
29+
- [Example notebooks][link-notebooks] (section "Visiualizations")
3030

3131
## Installation
3232

pyproject.toml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
[build-system]
2+
build-backend = "hatchling.build"
3+
requires = ["hatchling", "hatch-vcs"]
4+
15
[project]
26
name = "spatialdata-plot"
37
description = "Static plotting for spatial data."
@@ -161,3 +165,25 @@ target-version = "py39"
161165
"src/spatialdata_plot/pl/utils.py"= ["PGH003"]
162166
[tool.ruff.pydocstyle]
163167
convention = "numpy"
168+
169+
[tool.bumpver]
170+
current_version = "0.0.2"
171+
version_pattern = "MAJOR.MINOR.PATCH"
172+
commit_message = "bump version {old_version} -> {new_version}"
173+
tag_message = "{new_version}"
174+
tag_scope = "default"
175+
pre_commit_hook = ""
176+
post_commit_hook = ""
177+
commit = true
178+
tag = true
179+
push = false
180+
181+
[tool.bumpver.file_patterns]
182+
"pyproject.toml" = [
183+
'current_version = "{version}"',
184+
]
185+
"README.md" = [
186+
"{version}",
187+
"{pep440_version}",
188+
]
189+

src/spatialdata_plot/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from importlib.metadata import version
2+
13
from . import pl, pp
24

35
__all__ = ["pl", "pp"]
46

5-
# __version__ = version("spatialdata-plot")
7+
__version__ = version("spatialdata-plot")

src/spatialdata_plot/_logging.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# from https://github.com/scverse/spatialdata/blob/main/src/spatialdata/_logging.py
2+
3+
import logging
4+
5+
6+
def _setup_logger() -> "logging.Logger":
7+
from rich.console import Console
8+
from rich.logging import RichHandler
9+
10+
logger = logging.getLogger(__name__)
11+
logger.setLevel(logging.INFO)
12+
console = Console(force_terminal=True)
13+
if console.is_jupyter is True:
14+
console.is_jupyter = False
15+
ch = RichHandler(show_path=False, console=console, show_time=False)
16+
logger.addHandler(ch)
17+
18+
# this prevents double outputs
19+
logger.propagate = False
20+
return logger
21+
22+
23+
logger = _setup_logger()

src/spatialdata_plot/pl/basic.py

Lines changed: 82 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from typing import Any
88

99
import matplotlib.pyplot as plt
10+
import numpy as np
1011
import scanpy as sc
1112
import spatialdata as sd
1213
from anndata import AnnData
1314
from dask.dataframe.core import DataFrame as DaskDataFrame
1415
from geopandas import GeoDataFrame
1516
from matplotlib.axes import Axes
16-
from matplotlib.colors import Colormap, Normalize
17+
from matplotlib.colors import Colormap, ListedColormap, Normalize
1718
from matplotlib.figure import Figure
1819
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
1920
from pandas.api.types import is_categorical_dtype
@@ -32,14 +33,14 @@
3233
_render_shapes,
3334
)
3435
from spatialdata_plot.pl.utils import (
36+
CmapParams,
3537
LegendParams,
36-
Palette_t,
3738
_FontSize,
3839
_FontWeight,
3940
_get_cs_contents,
4041
_get_extent,
4142
_maybe_set_colors,
42-
_multiscale_to_image,
43+
_mpl_ax_contains_elements,
4344
_prepare_cmap_norm,
4445
_prepare_params_plot,
4546
_robust_transform,
@@ -144,11 +145,11 @@ def render_shapes(
144145
groups: str | Sequence[str] | None = None,
145146
size: float = 1.0,
146147
outline: bool = False,
147-
outline_width: tuple[float, float] = (0.3, 0.05),
148-
outline_color: tuple[str, str] = ("#000000ff", "#ffffffff"), # black, white
148+
outline_width: float = 1.5,
149+
outline_color: str | list[float] = "#000000ff",
149150
alt_var: str | None = None,
150151
layer: str | None = None,
151-
palette: Palette_t = None,
152+
palette: ListedColormap | str | None = None,
152153
cmap: Colormap | str | None = None,
153154
norm: None | Normalize = None,
154155
na_color: str | tuple[float, ...] | None = "lightgrey",
@@ -194,6 +195,11 @@ def render_shapes(
194195
kwargs
195196
Additional arguments to be passed to cmap and norm.
196197
198+
Notes
199+
-----
200+
Empty geometries will be removed at the time of plotting.
201+
An ``outline_width`` of 0.0 leads to no border being plotted.
202+
197203
Returns
198204
-------
199205
None
@@ -230,7 +236,7 @@ def render_points(
230236
color: str | None = None,
231237
groups: str | Sequence[str] | None = None,
232238
size: float = 1.0,
233-
palette: Palette_t = None,
239+
palette: ListedColormap | str | None = None,
234240
cmap: Colormap | str | None = None,
235241
norm: None | Normalize = None,
236242
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
@@ -295,11 +301,12 @@ def render_images(
295301
self,
296302
elements: str | list[str] | None = None,
297303
channel: list[str] | list[int] | int | str | None = None,
298-
cmap: Colormap | str | None = None,
304+
cmap: list[Colormap] | list[str] | Colormap | str | None = None,
299305
norm: None | Normalize = None,
300306
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
301-
palette: Palette_t = None,
307+
palette: ListedColormap | str | None = None,
302308
alpha: float = 1.0,
309+
quantiles_for_norm: tuple[float | None, float | None] = (3.0, 99.8), # defaults from CSBDeep
303310
**kwargs: Any,
304311
) -> sd.SpatialData:
305312
"""
@@ -320,6 +327,8 @@ def render_images(
320327
Color to be used for NAs values, if present.
321328
alpha
322329
Alpha value for the shapes.
330+
quantiles_for_norm
331+
Tuple of (pmin, pmax) which will be used for quantile normalization.
323332
kwargs
324333
Additional arguments to be passed to cmap and norm.
325334
@@ -331,18 +340,36 @@ def render_images(
331340
sdata = _verify_plotting_tree(sdata)
332341
n_steps = len(sdata.plotting_tree.keys())
333342

334-
cmap_params = _prepare_cmap_norm(
335-
cmap=cmap,
336-
norm=norm,
337-
na_color=na_color, # type: ignore[arg-type]
338-
**kwargs,
339-
)
343+
if channel is None and cmap is None:
344+
cmap = "brg"
345+
346+
cmap_params: list[CmapParams] | CmapParams
347+
if isinstance(cmap, list):
348+
cmap_params = [
349+
_prepare_cmap_norm(
350+
cmap=c,
351+
norm=norm,
352+
na_color=na_color, # type: ignore[arg-type]
353+
**kwargs,
354+
)
355+
for c in cmap
356+
]
357+
358+
else:
359+
cmap_params = _prepare_cmap_norm(
360+
cmap=cmap,
361+
norm=norm,
362+
na_color=na_color, # type: ignore[arg-type]
363+
**kwargs,
364+
)
365+
340366
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
341367
elements=elements,
342368
channel=channel,
343369
cmap_params=cmap_params,
344370
palette=palette,
345371
alpha=alpha,
372+
quantiles_for_norm=quantiles_for_norm,
346373
)
347374

348375
return sdata
@@ -356,7 +383,7 @@ def render_labels(
356383
outline: bool = False,
357384
alt_var: str | None = None,
358385
layer: str | None = None,
359-
palette: Palette_t = None,
386+
palette: ListedColormap | str | None = None,
360387
cmap: Colormap | str | None = None,
361388
norm: None | Normalize = None,
362389
na_color: str | tuple[float, ...] | None = (0.0, 0.0, 0.0, 0.0),
@@ -454,6 +481,7 @@ def show(
454481
fig: Figure | None = None,
455482
title: None | str | Sequence[str] = None,
456483
share_extent: bool = True,
484+
pad_extent: int = 0,
457485
ax: Axes | Sequence[Axes] | None = None,
458486
return_ax: bool = False,
459487
save: None | str | Path = None,
@@ -511,7 +539,7 @@ def show(
511539
render_cmds[cmd] = params
512540

513541
if len(render_cmds.keys()) == 0:
514-
raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow().")
542+
raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow()'.")
515543

516544
if title is not None:
517545
if isinstance(title, str):
@@ -520,8 +548,13 @@ def show(
520548
if not all(isinstance(t, str) for t in title):
521549
raise TypeError("All titles must be strings.")
522550

523-
# Simplicstic solution: If the images are multiscale, just use the first
524-
sdata = _multiscale_to_image(sdata)
551+
# get original axis extent for later comparison
552+
x_min_orig, x_max_orig = (np.inf, -np.inf)
553+
y_min_orig, y_max_orig = (np.inf, -np.inf)
554+
555+
if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax):
556+
x_min_orig, x_max_orig = ax.get_xlim()
557+
y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left
525558

526559
# handle coordinate system
527560
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
@@ -532,12 +565,38 @@ def show(
532565
if cs not in sdata.coordinate_systems:
533566
raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}")
534567

568+
# Check if user specified only certain elements to be plotted
569+
cs_contents = _get_cs_contents(sdata)
570+
elements_to_be_rendered = []
571+
for cmd, params in render_cmds.items():
572+
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
573+
if params.elements is not None:
574+
elements_to_be_rendered += (
575+
[params.elements] if isinstance(params.elements, str) else params.elements
576+
)
577+
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114
578+
if params.elements is not None:
579+
elements_to_be_rendered += (
580+
[params.elements] if isinstance(params.elements, str) else params.elements
581+
)
582+
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114
583+
if params.elements is not None:
584+
elements_to_be_rendered += (
585+
[params.elements] if isinstance(params.elements, str) else params.elements
586+
)
587+
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102
588+
if params.elements is not None:
589+
elements_to_be_rendered += (
590+
[params.elements] if isinstance(params.elements, str) else params.elements
591+
)
592+
535593
extent = _get_extent(
536594
sdata=sdata,
537595
has_images="render_images" in render_cmds,
538596
has_labels="render_labels" in render_cmds,
539597
has_points="render_points" in render_cmds,
540598
has_shapes="render_shapes" in render_cmds,
599+
elements=elements_to_be_rendered,
541600
coordinate_systems=coordinate_systems,
542601
)
543602

@@ -550,19 +609,6 @@ def show(
550609
logg.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
551610
coordinate_systems = valid_cs
552611

553-
# print(coordinate_systems)
554-
# cs_mapping = _get_coordinate_system_mapping(sdata)
555-
# print(cs_mapping)
556-
557-
# check that coordinate system and elements to be rendered match
558-
# for cmd, params in render_cmds.items():
559-
# if params.elements is not None and len([params.elements]) != len(coordinate_systems):
560-
# print(params.elements)
561-
# raise ValueError(
562-
# f"Number of coordinate systems ({len(coordinate_systems)}) does not match number of elements "
563-
# f"({len(params.elements)}) in command {cmd}."
564-
# )
565-
566612
# set up canvas
567613
fig_params, scalebar_params = _prepare_params_plot(
568614
num_panels=len(coordinate_systems),
@@ -585,7 +631,6 @@ def show(
585631
)
586632

587633
# go through tree
588-
cs_contents = _get_cs_contents(sdata)
589634
for i, cs in enumerate(coordinate_systems):
590635
sdata = self._copy()
591636
# properly transform all elements to the current coordinate system
@@ -693,12 +738,10 @@ def show(
693738
]
694739
):
695740
# If the axis already has limits, only expand them but not overwrite
696-
x_min, x_max = ax.get_xlim()
697-
y_min, y_max = ax.get_ylim()
698-
x_min = min(x_min, extent[cs][0])
699-
x_max = max(x_max, extent[cs][1])
700-
y_min = min(y_min, extent[cs][2])
701-
y_max = max(y_max, extent[cs][3])
741+
x_min = min(x_min_orig, extent[cs][0]) - pad_extent
742+
x_max = max(x_max_orig, extent[cs][1]) + pad_extent
743+
y_min = min(y_min_orig, extent[cs][2]) - pad_extent
744+
y_max = max(y_max_orig, extent[cs][3]) + pad_extent
702745
ax.set_xlim(x_min, x_max)
703746
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
704747

0 commit comments

Comments
 (0)