Skip to content

Commit a1788c8

Browse files
Updated render_shapes doc (#199)
* Updated render_shapes doc * Updated render_shapes * halfway points * Updated type hints and handling * Fixed color behaviour * Changed source of logger; fixed type in test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typos in documentation; minor correction to type checks * Fixed test * Added tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent adb6bb8 commit a1788c8

File tree

8 files changed

+633
-214
lines changed

8 files changed

+633
-214
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 560 additions & 157 deletions
Large diffs are not rendered by default.

src/spatialdata_plot/pl/render.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _render_shapes(
8282

8383
for e in elements:
8484
shapes = sdata.shapes[e]
85-
n_shapes = sum([len(s) for s in shapes])
85+
n_shapes = sum(len(s) for s in shapes)
8686

8787
if sdata.table is None:
8888
table = AnnData(None, obs=pd.DataFrame(index=pd.Index(np.arange(n_shapes), dtype=str)))
@@ -94,11 +94,11 @@ def _render_shapes(
9494
sdata=sdata_filt,
9595
element=sdata_filt.shapes[e],
9696
element_name=e,
97-
value_to_plot=render_params.color,
97+
value_to_plot=render_params.col_for_color,
9898
layer=render_params.layer,
9999
groups=render_params.groups,
100100
palette=render_params.palette,
101-
na_color=render_params.cmap_params.na_color,
101+
na_color=render_params.color or render_params.cmap_params.na_color,
102102
alpha=render_params.fill_alpha,
103103
cmap_params=render_params.cmap_params,
104104
)
@@ -162,14 +162,18 @@ def _render_shapes(
162162
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
163163
):
164164
# necessary in case different shapes elements are annotated with one table
165-
if color_source_vector is not None:
165+
if color_source_vector is not None and render_params.col_for_color is not None:
166166
color_source_vector = color_source_vector.remove_unused_categories()
167+
168+
# False if user specified color-like with 'color' parameter
169+
colorbar = False if render_params.col_for_color is None else legend_params.colorbar
170+
167171
_ = _decorate_axs(
168172
ax=ax,
169173
cax=cax,
170174
fig_params=fig_params,
171175
adata=table,
172-
value_to_plot=render_params.color,
176+
value_to_plot=render_params.col_for_color,
173177
color_source_vector=color_source_vector,
174178
palette=palette,
175179
alpha=render_params.fill_alpha,
@@ -179,7 +183,7 @@ def _render_shapes(
179183
legend_loc=legend_params.legend_loc,
180184
legend_fontoutline=legend_params.legend_fontoutline,
181185
na_in_legend=legend_params.na_in_legend,
182-
colorbar=legend_params.colorbar,
186+
colorbar=colorbar,
183187
scalebar_dx=scalebar_params.scalebar_dx,
184188
scalebar_units=scalebar_params.scalebar_units,
185189
)
@@ -194,12 +198,6 @@ def _render_points(
194198
scalebar_params: ScalebarParams,
195199
legend_params: LegendParams,
196200
) -> None:
197-
if render_params.groups is not None:
198-
if isinstance(render_params.groups, str):
199-
render_params.groups = [render_params.groups]
200-
if not all(isinstance(g, str) for g in render_params.groups):
201-
raise TypeError("All groups must be strings.")
202-
203201
elements = render_params.elements
204202

205203
sdata_filt = sdata.filter_by_coordinate_system(
@@ -214,43 +212,56 @@ def _render_points(
214212

215213
for e in elements:
216214
points = sdata.points[e]
215+
col_for_color = render_params.col_for_color
216+
217217
coords = ["x", "y"]
218-
if render_params.color is not None:
219-
color = [render_params.color] if isinstance(render_params.color, str) else render_params.color
220-
coords.extend(color)
218+
if col_for_color is not None:
219+
if col_for_color not in points.columns:
220+
# no error in case there are multiple elements, but onyl some have color key
221+
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
222+
logger.warning(msg)
223+
else:
224+
coords += [col_for_color]
221225

222226
points = points[coords].compute()
223-
if render_params.groups is not None:
224-
points = points[points[color].isin(render_params.groups).values]
225-
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
226-
points = dask.dataframe.from_pandas(points, npartitions=1)
227-
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})
228-
229-
point_df = points[coords].compute()
227+
if render_params.groups is not None and col_for_color is not None:
228+
points = points[points[col_for_color].isin(render_params.groups)]
230229

231230
# we construct an anndata to hack the plotting functions
232231
adata = AnnData(
233-
X=point_df[["x", "y"]].values, obs=point_df[coords].reset_index(), dtype=point_df[["x", "y"]].values.dtype
232+
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
234233
)
235-
if render_params.color is not None:
236-
cols = sc.get.obs_df(adata, render_params.color)
234+
235+
# Convert back to dask dataframe to modify sdata
236+
points = dask.dataframe.from_pandas(points, npartitions=1)
237+
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})
238+
239+
if render_params.col_for_color is not None:
240+
cols = sc.get.obs_df(adata, render_params.col_for_color)
237241
# maybe set color based on type
238242
if is_categorical_dtype(cols):
239243
_maybe_set_colors(
240244
source=adata,
241245
target=adata,
242-
key=render_params.color,
246+
key=render_params.col_for_color,
243247
palette=render_params.palette,
244248
)
245249

250+
# when user specified a single color, we overwrite na with it
251+
default_color = (
252+
render_params.color
253+
if render_params.col_for_color is None and render_params.color is not None
254+
else render_params.cmap_params.na_color
255+
)
256+
246257
color_source_vector, color_vector, _ = _set_color_source_vec(
247258
sdata=sdata_filt,
248259
element=points,
249260
element_name=e,
250-
value_to_plot=render_params.color,
261+
value_to_plot=render_params.col_for_color,
251262
groups=render_params.groups,
252263
palette=render_params.palette,
253-
na_color=render_params.cmap_params.na_color,
264+
na_color=default_color,
254265
alpha=render_params.alpha,
255266
cmap_params=render_params.cmap_params,
256267
)
@@ -278,9 +289,7 @@ def _render_points(
278289
)
279290
cax = ax.add_collection(_cax)
280291

281-
if not (
282-
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
283-
):
292+
if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
284293
if color_source_vector is None:
285294
palette = ListedColormap(dict.fromkeys(color_vector))
286295
else:
@@ -291,7 +300,7 @@ def _render_points(
291300
cax=cax,
292301
fig_params=fig_params,
293302
adata=adata,
294-
value_to_plot=render_params.color,
303+
value_to_plot=render_params.col_for_color,
295304
color_source_vector=color_source_vector,
296305
palette=palette,
297306
alpha=render_params.alpha,
@@ -629,8 +638,8 @@ def _render_labels(
629638
_cax = ax.imshow(
630639
labels_infill,
631640
rasterized=True,
632-
cmap=render_params.cmap_params.cmap if not categorical else None,
633-
norm=render_params.cmap_params.norm if not categorical else None,
641+
cmap=None if categorical else render_params.cmap_params.cmap,
642+
norm=None if categorical else render_params.cmap_params.norm,
634643
alpha=render_params.fill_alpha,
635644
origin="lower",
636645
)
@@ -652,14 +661,11 @@ def _render_labels(
652661
_cax = ax.imshow(
653662
labels_contour,
654663
rasterized=True,
655-
cmap=render_params.cmap_params.cmap if not categorical else None,
656-
norm=render_params.cmap_params.norm if not categorical else None,
664+
cmap=None if categorical else render_params.cmap_params.cmap,
665+
norm=None if categorical else render_params.cmap_params.norm,
657666
alpha=render_params.outline_alpha,
658667
origin="lower",
659668
)
660-
_cax.set_transform(trans_data)
661-
cax = ax.add_image(_cax)
662-
663669
else:
664670
# Default: no alpha, contour = infill
665671
label = _map_color_seg(
@@ -676,13 +682,13 @@ def _render_labels(
676682
_cax = ax.imshow(
677683
label,
678684
rasterized=True,
679-
cmap=render_params.cmap_params.cmap if not categorical else None,
680-
norm=render_params.cmap_params.norm if not categorical else None,
685+
cmap=None if categorical else render_params.cmap_params.cmap,
686+
norm=None if categorical else render_params.cmap_params.norm,
681687
alpha=render_params.fill_alpha,
682688
origin="lower",
683689
)
684-
_cax.set_transform(trans_data)
685-
cax = ax.add_image(_cax)
690+
_cax.set_transform(trans_data)
691+
cax = ax.add_image(_cax)
686692

687693
_ = _decorate_axs(
688694
ax=ax,

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class ShapesRenderParams:
7272
outline_params: OutlineParams
7373
elements: str | Sequence[str] | None = None
7474
color: str | None = None
75+
col_for_color: str | None = None
7576
groups: str | Sequence[str] | None = None
7677
contour_px: int | None = None
7778
layer: str | None = None
@@ -89,6 +90,7 @@ class PointsRenderParams:
8990
cmap_params: CmapParams
9091
elements: str | Sequence[str] | None = None
9192
color: str | None = None
93+
col_for_color: str | None = None
9294
groups: str | Sequence[str] | None = None
9395
palette: ListedColormap | str | None = None
9496
alpha: float = 1.0

src/spatialdata_plot/pl/utils.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@
5252
from spatial_image import SpatialImage
5353
from spatialdata._core.operations.rasterize import rasterize
5454
from spatialdata._core.query.relational_query import _locate_value, get_values
55-
from spatialdata._logging import logger as logging
5655
from spatialdata._types import ArrayLike
5756
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement
5857

58+
from spatialdata_plot._logging import logger
5959
from spatialdata_plot.pl.render_params import (
6060
CmapParams,
6161
FigParams,
@@ -379,7 +379,7 @@ def _set_outline(
379379
if outline_width == 0.0:
380380
outline = False
381381
if outline_width < 0.0:
382-
logging.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1)*outline_width}")
382+
logger.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1)*outline_width}")
383383
outline_width *= -1
384384

385385
# the default black and white colors can be changed using the contour_config parameter
@@ -561,7 +561,7 @@ def _get_colors_for_categorical_obs(
561561
palette = default_102
562562
else:
563563
palette = ["grey" for _ in range(len_cat)]
564-
logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
564+
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
565565
else:
566566
# raise error when user didn't provide the right number of colors in palette
567567
if isinstance(palette, list) and len(palette) != len(categories):
@@ -623,7 +623,7 @@ def _set_color_source_vec(
623623
# numerical case, return early
624624
if not is_categorical_dtype(color_source_vector):
625625
if palette is not None:
626-
logging.warning(
626+
logger.warning(
627627
"Ignoring categorical palette which is given for a continuous variable. "
628628
"Consider using `cmap` to pass a ColorMap."
629629
)
@@ -651,7 +651,7 @@ def _set_color_source_vec(
651651

652652
return color_source_vector, color_vector, True
653653

654-
logging.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.")
654+
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.")
655655
color = np.full(sdata.table.n_obs, to_hex(na_color))
656656
return color, color, False
657657

@@ -723,7 +723,7 @@ def _get_palette(
723723
)
724724
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette)}
725725
except KeyError as e:
726-
logging.warning(e)
726+
logger.warning(e)
727727
return None
728728

729729
len_cat = len(categories)
@@ -737,7 +737,7 @@ def _get_palette(
737737
palette = default_102
738738
else:
739739
palette = ["grey" for _ in range(len_cat)]
740-
logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
740+
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
741741
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette[:len_cat])}
742742

743743
if isinstance(palette, str):
@@ -904,9 +904,9 @@ def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "p
904904
try:
905905
path.parent.mkdir(parents=True, exist_ok=True)
906906
except OSError as e:
907-
logging.debug(f"Unable to create directory `{path.parent}`. Reason: `{e}`")
907+
logger.debug(f"Unable to create directory `{path.parent}`. Reason: `{e}`")
908908

909-
logging.debug(f"Saving figure to `{path!r}`")
909+
logger.debug(f"Saving figure to `{path!r}`")
910910

911911
kwargs.setdefault("bbox_inches", "tight")
912912
kwargs.setdefault("transparent", True)
@@ -1070,13 +1070,13 @@ def _mpl_ax_contains_elements(ax: Axes) -> bool:
10701070

10711071
def _get_valid_cs(
10721072
sdata: sd.SpatialData,
1073-
coordinate_systems: Sequence[str],
1073+
coordinate_systems: list[str],
10741074
render_images: bool,
10751075
render_labels: bool,
10761076
render_points: bool,
10771077
render_shapes: bool,
10781078
elements: list[str],
1079-
) -> Sequence[str]:
1079+
) -> list[str]:
10801080
"""Get names of the valid coordinate systems.
10811081
10821082
Valid cs are cs that contain elements to be rendered:
@@ -1090,8 +1090,10 @@ def _get_valid_cs(
10901090
cs_mapping = _get_coordinate_system_mapping(sdata)
10911091
valid_cs = []
10921092
for cs in coordinate_systems:
1093-
if (len(elements) > 0 and any(e in elements for e in cs_mapping[cs])) or (
1094-
len(elements) == 0
1093+
if (
1094+
elements
1095+
and any(e in elements for e in cs_mapping[cs])
1096+
or not elements
10951097
and (
10961098
(len(sdata.images.keys()) > 0 and render_images)
10971099
or (len(sdata.labels.keys()) > 0 and render_labels)
@@ -1101,7 +1103,7 @@ def _get_valid_cs(
11011103
): # not nice, but ruff wants it (SIM114)
11021104
valid_cs.append(cs)
11031105
else:
1104-
logging.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
1106+
logger.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
11051107
return valid_cs
11061108

11071109

Loading
Loading

tests/pl/test_render_points.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,6 @@ def test_plot_can_stack_render_points(self, sdata_blobs: SpatialData):
3737
.pl.render_points(elements="blobs_points", na_color="blue", size=10)
3838
.pl.show()
3939
)
40+
41+
def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialData):
42+
sdata_blobs.pl.render_points(elements="blobs_points", color="red").pl.show()

tests/pl/test_render_shapes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,6 @@ def test_plot_can_stack_render_shapes(self, sdata_blobs: SpatialData):
261261
.pl.render_shapes(elements="blobs_polygons", na_color="blue", fill_alpha=0.5)
262262
.pl.show()
263263
)
264+
265+
def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialData):
266+
(sdata_blobs.pl.render_shapes(elements="blobs_circles", color="red").pl.show())

0 commit comments

Comments
 (0)