Skip to content

Can now specify which 'layer' of the AnnData table to use #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dynamic= [
license = {file = "LICENSE"}
readme = "README.md"
dependencies = [
"spatialdata>=0.2.6",
"spatialdata>=0.3.0",
"matplotlib",
"scikit-learn",
"scanpy",
Expand Down
27 changes: 22 additions & 5 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def render_shapes(
scale: float | int = 1.0,
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -228,6 +229,9 @@ def render_shapes(
Name of the table containing the color(s) columns. If one name is given than the table is used for each
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
elements, as specified under element.
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.

**kwargs : Any
Additional arguments for customization. This can include:
Expand Down Expand Up @@ -271,6 +275,7 @@ def render_shapes(
norm=norm,
scale=scale,
table_name=table_name,
table_layer=table_layer,
method=method,
ds_reduction=kwargs.get("datashader_reduction"),
)
Expand Down Expand Up @@ -298,6 +303,7 @@ def render_shapes(
fill_alpha=param_values["fill_alpha"],
transfunc=kwargs.get("transfunc"),
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
zorder=n_steps,
method=param_values["method"],
ds_reduction=param_values["ds_reduction"],
Expand All @@ -320,6 +326,7 @@ def render_points(
size: float | int = 1.0,
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -370,6 +377,9 @@ def render_points(
Name of the table containing the color(s) columns. If one name is given than the table is used for each
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
elements, as specified under element.
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.

**kwargs : Any
Additional arguments for customization. This can include:
Expand Down Expand Up @@ -403,6 +413,7 @@ def render_points(
norm=norm,
size=size,
table_name=table_name,
table_layer=table_layer,
ds_reduction=kwargs.get("datashader_reduction"),
)

Expand Down Expand Up @@ -433,6 +444,7 @@ def render_points(
transfunc=kwargs.get("transfunc"),
size=param_values["size"],
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
zorder=n_steps,
method=method,
ds_reduction=param_values["ds_reduction"],
Expand Down Expand Up @@ -573,6 +585,7 @@ def render_labels(
fill_alpha: float | int = 0.4,
scale: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand All @@ -590,10 +603,10 @@ def render_labels(
The name of the labels element to render. If `None`, all label
elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible.
color : list[str] | str | None
Can either be string representing a color-like or key in :attr:`sdata.table.obs`. The latter can be used to
color by categorical or continuous variables. If the color column is found in multiple locations, please
provide the table_name to be used for the element if you would like a specific table to be used. By default
one table will automatically be choosen.
Can either be string representing a color-like or key in :attr:`sdata.table.obs` or in the index of
:attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the
color column is found in multiple locations, please provide the table_name to be used for the element if you
would like a specific table to be used. By default one table will automatically be choosen.
groups : list[str] | str | None
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
them. Other values are set to NA. The list can contain multiple discrete labels to be visualized.
Expand Down Expand Up @@ -626,6 +639,9 @@ def render_labels(
with the highest resolution is selected. This can lead to long computing times for large images!
table_name: str | None
Name of the table containing the color columns.
table_layer: str | None
Layer of the AnnData table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None,
:attr:`sdata.table.X` of the default table is used for coloring.
kwargs
Additional arguments to be passed to cmap and norm.

Expand Down Expand Up @@ -654,6 +670,7 @@ def render_labels(
palette=palette,
scale=scale,
table_name=table_name,
table_layer=table_layer,
)

sdata = self._copy()
Expand All @@ -678,6 +695,7 @@ def render_labels(
transfunc=kwargs.get("transfunc"),
scale=param_values["scale"],
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
zorder=n_steps,
)
n_steps += 1
Expand Down Expand Up @@ -811,7 +829,6 @@ def show(
ax_x_min, ax_x_max = ax.get_xlim()
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left

# handle coordinate system
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
if isinstance(coordinate_systems, str):
coordinate_systems = [coordinate_systems]
Expand Down
49 changes: 44 additions & 5 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from spatialdata import get_extent, join_spatialelement_table
from spatialdata import get_extent, get_values, join_spatialelement_table
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
from spatialdata.transformations import get_transformation, set_transformation
from spatialdata.transformations.transformations import Identity
Expand Down Expand Up @@ -70,6 +70,7 @@ def _render_shapes(
element = render_params.element
col_for_color = render_params.col_for_color
groups = render_params.groups
table_layer = render_params.table_layer

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
Expand Down Expand Up @@ -115,6 +116,7 @@ def _render_shapes(
na_color=render_params.color or render_params.cmap_params.na_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
)

values_are_categorical = color_source_vector is not None
Expand Down Expand Up @@ -397,6 +399,7 @@ def _render_points(
element = render_params.element
col_for_color = render_params.col_for_color
table_name = render_params.table_name
table_layer = render_params.table_layer
color = render_params.color
groups = render_params.groups
palette = render_params.palette
Expand All @@ -409,10 +412,22 @@ def _render_points(
points = sdata.points[element]
coords = ["x", "y"]

if col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns):
if table_name is not None and col_for_color not in points.columns:
warnings.warn(
f"Annotating points with {col_for_color} which is stored in the table `{table_name}`. "
f"To improve performance, it is advisable to store point annotations directly in the .parquet file.",
UserWarning,
stacklevel=2,
)

if col_for_color is None or (
table_name is not None
and (col_for_color in sdata_filt[table_name].obs.columns or col_for_color in sdata_filt[table_name].var_names)
):
points = points[coords].compute()
if (
col_for_color
and col_for_color in sdata_filt[table_name].obs.columns
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
and not _is_coercable_to_float(color_col)
):
Expand All @@ -428,7 +443,20 @@ def _render_points(
points = points[coords].compute()

if groups is not None and col_for_color is not None:
points = points[points[col_for_color].isin(groups)]
if col_for_color in points.columns:
points_color_values = points[col_for_color]
else:
points_color_values = get_values(
value_key=col_for_color,
sdata=sdata_filt,
element_name=element,
table_name=table_name,
table_layer=table_layer,
)
points_color_values = points.merge(points_color_values, how="left", left_index=True, right_index=True)[
col_for_color
]
points = points[points_color_values.isin(groups)]
if len(points) <= 0:
raise ValueError(f"None of the groups {groups} could be found in the column '{col_for_color}'.")

Expand All @@ -438,9 +466,18 @@ def _render_points(
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
)
else:
adata_obs = sdata_filt[table_name].obs
# if the points are colored by values in X (or a different layer), add the values to obs
if col_for_color in sdata_filt[table_name].var_names:
if table_layer is None:
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy()
else:
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy()
if groups is not None:
adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)]
adata = AnnData(
X=points[["x", "y"]].values,
obs=sdata_filt[table_name].obs,
obs=adata_obs,
dtype=points[["x", "y"]].values.dtype,
uns=sdata_filt[table_name].uns,
)
Expand Down Expand Up @@ -847,6 +884,7 @@ def _render_labels(
) -> None:
element = render_params.element
table_name = render_params.table_name
table_layer = render_params.table_layer
palette = render_params.palette
color = render_params.color
groups = render_params.groups
Expand Down Expand Up @@ -882,7 +920,7 @@ def _render_labels(
extent=extent,
)

# the avove adds a useless c dimension of 1 (y, x) -> (1, y, x)
# the above adds a useless c dimension of 1 (y, x) -> (1, y, x)
label = label.squeeze()

if table_name is None:
Expand All @@ -907,6 +945,7 @@ def _render_labels(
na_color=render_params.cmap_params.na_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
)

def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
Expand Down
3 changes: 3 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class ShapesRenderParams:
method: str | None = None
zorder: int = 0
table_name: str | None = None
table_layer: str | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


Expand All @@ -108,6 +109,7 @@ class PointsRenderParams:
method: str | None = None
zorder: int = 0
table_name: str | None = None
table_layer: str | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


Expand Down Expand Up @@ -141,4 +143,5 @@ class LabelsRenderParams:
transfunc: Callable[[float], float] | None = None
scale: str | None = None
table_name: str | None = None
table_layer: str | None = None
zorder: int = 0
Loading
Loading