Skip to content

Commit 16aa9e9

Browse files
Sonja-StockhausSonja Stockhauspre-commit-ci[bot]timtreis
authored
Can now specify which 'layer' of the AnnData table to use (#402)
Co-authored-by: Sonja Stockhaus <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Treis <[email protected]>
1 parent b91de80 commit 16aa9e9

14 files changed

+232
-49
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ dynamic= [
2121
license = {file = "LICENSE"}
2222
readme = "README.md"
2323
dependencies = [
24-
"spatialdata>=0.2.6",
24+
"spatialdata>=0.3.0",
2525
"matplotlib",
2626
"scikit-learn",
2727
"scanpy",

src/spatialdata_plot/pl/basic.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def render_shapes(
169169
scale: float | int = 1.0,
170170
method: str | None = None,
171171
table_name: str | None = None,
172+
table_layer: str | None = None,
172173
**kwargs: Any,
173174
) -> sd.SpatialData:
174175
"""
@@ -228,6 +229,9 @@ def render_shapes(
228229
Name of the table containing the color(s) columns. If one name is given than the table is used for each
229230
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
230231
elements, as specified under element.
232+
table_layer: str | None
233+
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
234+
:attr:`sdata.table.X` is used for coloring.
231235
232236
**kwargs : Any
233237
Additional arguments for customization. This can include:
@@ -271,6 +275,7 @@ def render_shapes(
271275
norm=norm,
272276
scale=scale,
273277
table_name=table_name,
278+
table_layer=table_layer,
274279
method=method,
275280
ds_reduction=kwargs.get("datashader_reduction"),
276281
)
@@ -298,6 +303,7 @@ def render_shapes(
298303
fill_alpha=param_values["fill_alpha"],
299304
transfunc=kwargs.get("transfunc"),
300305
table_name=param_values["table_name"],
306+
table_layer=param_values["table_layer"],
301307
zorder=n_steps,
302308
method=param_values["method"],
303309
ds_reduction=param_values["ds_reduction"],
@@ -320,6 +326,7 @@ def render_points(
320326
size: float | int = 1.0,
321327
method: str | None = None,
322328
table_name: str | None = None,
329+
table_layer: str | None = None,
323330
**kwargs: Any,
324331
) -> sd.SpatialData:
325332
"""
@@ -370,6 +377,9 @@ def render_points(
370377
Name of the table containing the color(s) columns. If one name is given than the table is used for each
371378
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
372379
elements, as specified under element.
380+
table_layer: str | None
381+
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
382+
:attr:`sdata.table.X` is used for coloring.
373383
374384
**kwargs : Any
375385
Additional arguments for customization. This can include:
@@ -403,6 +413,7 @@ def render_points(
403413
norm=norm,
404414
size=size,
405415
table_name=table_name,
416+
table_layer=table_layer,
406417
ds_reduction=kwargs.get("datashader_reduction"),
407418
)
408419

@@ -433,6 +444,7 @@ def render_points(
433444
transfunc=kwargs.get("transfunc"),
434445
size=param_values["size"],
435446
table_name=param_values["table_name"],
447+
table_layer=param_values["table_layer"],
436448
zorder=n_steps,
437449
method=method,
438450
ds_reduction=param_values["ds_reduction"],
@@ -573,6 +585,7 @@ def render_labels(
573585
fill_alpha: float | int = 0.4,
574586
scale: str | None = None,
575587
table_name: str | None = None,
588+
table_layer: str | None = None,
576589
**kwargs: Any,
577590
) -> sd.SpatialData:
578591
"""
@@ -590,10 +603,10 @@ def render_labels(
590603
The name of the labels element to render. If `None`, all label
591604
elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible.
592605
color : list[str] | str | None
593-
Can either be string representing a color-like or key in :attr:`sdata.table.obs`. The latter can be used to
594-
color by categorical or continuous variables. If the color column is found in multiple locations, please
595-
provide the table_name to be used for the element if you would like a specific table to be used. By default
596-
one table will automatically be choosen.
606+
Can either be string representing a color-like or key in :attr:`sdata.table.obs` or in the index of
607+
:attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the
608+
color column is found in multiple locations, please provide the table_name to be used for the element if you
609+
would like a specific table to be used. By default one table will automatically be choosen.
597610
groups : list[str] | str | None
598611
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
599612
them. Other values are set to NA. The list can contain multiple discrete labels to be visualized.
@@ -626,6 +639,9 @@ def render_labels(
626639
with the highest resolution is selected. This can lead to long computing times for large images!
627640
table_name: str | None
628641
Name of the table containing the color columns.
642+
table_layer: str | None
643+
Layer of the AnnData table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None,
644+
:attr:`sdata.table.X` of the default table is used for coloring.
629645
kwargs
630646
Additional arguments to be passed to cmap and norm.
631647
@@ -654,6 +670,7 @@ def render_labels(
654670
palette=palette,
655671
scale=scale,
656672
table_name=table_name,
673+
table_layer=table_layer,
657674
)
658675

659676
sdata = self._copy()
@@ -678,6 +695,7 @@ def render_labels(
678695
transfunc=kwargs.get("transfunc"),
679696
scale=param_values["scale"],
680697
table_name=param_values["table_name"],
698+
table_layer=param_values["table_layer"],
681699
zorder=n_steps,
682700
)
683701
n_steps += 1
@@ -811,7 +829,6 @@ def show(
811829
ax_x_min, ax_x_max = ax.get_xlim()
812830
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left
813831

814-
# handle coordinate system
815832
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
816833
if isinstance(coordinate_systems, str):
817834
coordinate_systems = [coordinate_systems]

src/spatialdata_plot/pl/render.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from matplotlib.cm import ScalarMappable
1818
from matplotlib.colors import ListedColormap, Normalize
1919
from scanpy._settings import settings as sc_settings
20-
from spatialdata import get_extent, join_spatialelement_table
20+
from spatialdata import get_extent, get_values, join_spatialelement_table
2121
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
2222
from spatialdata.transformations import get_transformation, set_transformation
2323
from spatialdata.transformations.transformations import Identity
@@ -70,6 +70,7 @@ def _render_shapes(
7070
element = render_params.element
7171
col_for_color = render_params.col_for_color
7272
groups = render_params.groups
73+
table_layer = render_params.table_layer
7374

7475
sdata_filt = sdata.filter_by_coordinate_system(
7576
coordinate_system=coordinate_system,
@@ -115,6 +116,7 @@ def _render_shapes(
115116
na_color=render_params.color or render_params.cmap_params.na_color,
116117
cmap_params=render_params.cmap_params,
117118
table_name=table_name,
119+
table_layer=table_layer,
118120
)
119121

120122
values_are_categorical = color_source_vector is not None
@@ -397,6 +399,7 @@ def _render_points(
397399
element = render_params.element
398400
col_for_color = render_params.col_for_color
399401
table_name = render_params.table_name
402+
table_layer = render_params.table_layer
400403
color = render_params.color
401404
groups = render_params.groups
402405
palette = render_params.palette
@@ -409,10 +412,22 @@ def _render_points(
409412
points = sdata.points[element]
410413
coords = ["x", "y"]
411414

412-
if col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns):
415+
if table_name is not None and col_for_color not in points.columns:
416+
warnings.warn(
417+
f"Annotating points with {col_for_color} which is stored in the table `{table_name}`. "
418+
f"To improve performance, it is advisable to store point annotations directly in the .parquet file.",
419+
UserWarning,
420+
stacklevel=2,
421+
)
422+
423+
if col_for_color is None or (
424+
table_name is not None
425+
and (col_for_color in sdata_filt[table_name].obs.columns or col_for_color in sdata_filt[table_name].var_names)
426+
):
413427
points = points[coords].compute()
414428
if (
415429
col_for_color
430+
and col_for_color in sdata_filt[table_name].obs.columns
416431
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
417432
and not _is_coercable_to_float(color_col)
418433
):
@@ -428,7 +443,20 @@ def _render_points(
428443
points = points[coords].compute()
429444

430445
if groups is not None and col_for_color is not None:
431-
points = points[points[col_for_color].isin(groups)]
446+
if col_for_color in points.columns:
447+
points_color_values = points[col_for_color]
448+
else:
449+
points_color_values = get_values(
450+
value_key=col_for_color,
451+
sdata=sdata_filt,
452+
element_name=element,
453+
table_name=table_name,
454+
table_layer=table_layer,
455+
)
456+
points_color_values = points.merge(points_color_values, how="left", left_index=True, right_index=True)[
457+
col_for_color
458+
]
459+
points = points[points_color_values.isin(groups)]
432460
if len(points) <= 0:
433461
raise ValueError(f"None of the groups {groups} could be found in the column '{col_for_color}'.")
434462

@@ -438,9 +466,18 @@ def _render_points(
438466
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
439467
)
440468
else:
469+
adata_obs = sdata_filt[table_name].obs
470+
# if the points are colored by values in X (or a different layer), add the values to obs
471+
if col_for_color in sdata_filt[table_name].var_names:
472+
if table_layer is None:
473+
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy()
474+
else:
475+
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy()
476+
if groups is not None:
477+
adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)]
441478
adata = AnnData(
442479
X=points[["x", "y"]].values,
443-
obs=sdata_filt[table_name].obs,
480+
obs=adata_obs,
444481
dtype=points[["x", "y"]].values.dtype,
445482
uns=sdata_filt[table_name].uns,
446483
)
@@ -847,6 +884,7 @@ def _render_labels(
847884
) -> None:
848885
element = render_params.element
849886
table_name = render_params.table_name
887+
table_layer = render_params.table_layer
850888
palette = render_params.palette
851889
color = render_params.color
852890
groups = render_params.groups
@@ -882,7 +920,7 @@ def _render_labels(
882920
extent=extent,
883921
)
884922

885-
# the avove adds a useless c dimension of 1 (y, x) -> (1, y, x)
923+
# the above adds a useless c dimension of 1 (y, x) -> (1, y, x)
886924
label = label.squeeze()
887925

888926
if table_name is None:
@@ -907,6 +945,7 @@ def _render_labels(
907945
na_color=render_params.cmap_params.na_color,
908946
cmap_params=render_params.cmap_params,
909947
table_name=table_name,
948+
table_layer=table_layer,
910949
)
911950

912951
def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:

src/spatialdata_plot/pl/render_params.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class ShapesRenderParams:
8989
method: str | None = None
9090
zorder: int = 0
9191
table_name: str | None = None
92+
table_layer: str | None = None
9293
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
9394

9495

@@ -108,6 +109,7 @@ class PointsRenderParams:
108109
method: str | None = None
109110
zorder: int = 0
110111
table_name: str | None = None
112+
table_layer: str | None = None
111113
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
112114

113115

@@ -141,4 +143,5 @@ class LabelsRenderParams:
141143
transfunc: Callable[[float], float] | None = None
142144
scale: str | None = None
143145
table_name: str | None = None
146+
table_layer: str | None = None
144147
zorder: int = 0

0 commit comments

Comments
 (0)