Skip to content

Fix norm behavior, add tests #419

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/extensions/typed_returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
for line in lines:
if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
yield f'-{m["param"]} (:class:`~{m["type"]}`)'
yield f"-{m['param']} (:class:`~{m['type']}`)"
else:
yield line

Expand Down
9 changes: 4 additions & 5 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def render_shapes(
norm=norm,
na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
)
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
sdata.plotting_tree[f"{n_steps + 1}_render_shapes"] = ShapesRenderParams(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
Expand Down Expand Up @@ -433,7 +433,7 @@ def render_points(
norm=norm,
na_color=param_values["na_color"], # type: ignore[arg-type]
)
sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
sdata.plotting_tree[f"{n_steps + 1}_render_points"] = PointsRenderParams(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
Expand Down Expand Up @@ -538,7 +538,6 @@ def render_images(
n_steps = len(sdata.plotting_tree.keys())

for element, param_values in params_dict.items():

cmap_params: list[CmapParams] | CmapParams
if isinstance(cmap, list):
cmap_params = [
Expand All @@ -557,7 +556,7 @@ def render_images(
na_color=param_values["na_color"],
**kwargs,
)
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
sdata.plotting_tree[f"{n_steps + 1}_render_images"] = ImageRenderParams(
element=element,
channel=param_values["channel"],
cmap_params=cmap_params,
Expand Down Expand Up @@ -683,7 +682,7 @@ def render_labels(
norm=norm,
na_color=param_values["na_color"], # type: ignore[arg-type]
)
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
element=element,
color=param_values["color"],
groups=param_values["groups"],
Expand Down
84 changes: 47 additions & 37 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_ax_show_and_transform,
_create_image_from_datashader_result,
_datashader_aggregate_with_function,
_datashader_map_aggregate_to_color,
_datshader_get_how_kw_for_spread,
_decorate_axs,
_get_collection_shape,
Expand Down Expand Up @@ -229,18 +230,20 @@ def _render_shapes(
line_width=render_params.outline_params.linewidth,
)

ds_span = None
if norm.vmin is not None or norm.vmax is not None:
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
norm.clip = True # NOTE: mpl currently behaves like clip is always True
ds_span = [norm.vmin, norm.vmax]
if norm.vmin == norm.vmax:
# data is mapped to 0
agg = agg - agg
else:
agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
# edge case, value vmin is rendered as the middle of the cmap
ds_span = [0, 1]
if norm.clip:
agg = np.maximum(agg, 0)
agg = np.minimum(agg, 1)
agg = (agg - agg) + 0.5
else:
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)

color_key = (
[x[:-2] for x in color_vector.categories.values]
Expand All @@ -256,13 +259,12 @@ def _render_shapes(
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
ds_cmap = ds_cmap[:-2]

ds_result = ds.tf.shade(
ds_result = _datashader_map_aggregate_to_color(
agg,
cmap=ds_cmap,
color_key=color_key,
min_alpha=np.min([254, render_params.fill_alpha * 255]),
how="linear",
)
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
elif aggregate_with_reduction is not None: # to shut up mypy
ds_cmap = render_params.cmap_params.cmap
# in case all elements have the same value X: we render them using cmap(0.0),
Expand All @@ -272,12 +274,13 @@ def _render_shapes(
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)

ds_result = ds.tf.shade(
ds_result = _datashader_map_aggregate_to_color(
agg,
cmap=ds_cmap,
how="linear",
min_alpha=np.min([254, render_params.fill_alpha * 255]),
)
span=ds_span,
clip=norm.clip,
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes

# shade outlines if needed
outline_color = render_params.outline_params.outline_color
Expand All @@ -294,7 +297,7 @@ def _render_shapes(
cmap=outline_color,
min_alpha=np.min([254, render_params.outline_alpha * 255]),
how="linear",
)
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes

rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_cax = _ax_show_and_transform(
Expand Down Expand Up @@ -322,8 +325,10 @@ def _render_shapes(
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
vmin = norm.vmin
vmax = norm.vmin + 1
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
vmin = norm.vmin - 0.5
vmax = norm.vmin + 0.5
cax = ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
cmap=render_params.cmap_params.cmap,
Expand Down Expand Up @@ -586,18 +591,21 @@ def _render_points(
else:
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())

ds_span = None
if norm.vmin is not None or norm.vmax is not None:
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
norm.clip = True # NOTE: mpl currently behaves like clip is always True
ds_span = [norm.vmin, norm.vmax]
if norm.vmin == norm.vmax:
# data is mapped to 0
agg = agg - agg
else:
agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
ds_span = [0, 1]
if norm.clip:
agg = np.maximum(agg, 0)
agg = np.minimum(agg, 1)
# all data is mapped to 0.5
agg = (agg - agg) + 0.5
else:
# values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)

color_key = (
list(color_vector.categories.values)
Expand All @@ -615,13 +623,12 @@ def _render_points(
color_vector = np.asarray([x[:-2] for x in color_vector])

if color_by_categorical or col_for_color is None:
ds_result = ds.tf.shade(
ds_result = _datashader_map_aggregate_to_color(
ds.tf.spread(agg, px=px),
cmap=color_vector[0],
color_key=color_key,
min_alpha=np.min([254, render_params.alpha * 255]),
how="linear",
)
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
else:
spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
agg = ds.tf.spread(agg, px=px, how=spread_how)
Expand All @@ -631,15 +638,17 @@ def _render_points(
# in case all elements have the same value X: we render them using cmap(0.0),
# using an artificial "span" of [X, X + 1] for the color bar
# else: all elements would get alpha=0 and the color bar would have a weird range
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]):
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)

ds_result = ds.tf.shade(
ds_result = _datashader_map_aggregate_to_color(
agg,
cmap=ds_cmap,
how="linear",
)
span=ds_span,
clip=norm.clip,
min_alpha=np.min([254, render_params.alpha * 255]),
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes

rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_ax_show_and_transform(
Expand All @@ -656,8 +665,10 @@ def _render_points(
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
vmin = norm.vmin
vmax = norm.vmin + 1
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
vmin = norm.vmin - 0.5
vmax = norm.vmin + 0.5
cax = ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
cmap=render_params.cmap_params.cmap,
Expand Down Expand Up @@ -723,7 +734,6 @@ def _render_images(
legend_params: LegendParams,
rasterize: bool,
) -> None:

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_tables=False,
Expand Down Expand Up @@ -781,9 +791,6 @@ def _render_images(
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()

if render_params.cmap_params.norm: # type: ignore[attr-defined]
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]

cmap = (
_get_linear_colormap(palette, "k")[0]
if isinstance(palette, list) and all(isinstance(p, str) for p in palette)
Expand All @@ -794,7 +801,10 @@ def _render_images(
cmap._init()
cmap._lut[:, -1] = render_params.alpha

_ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)
# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
_ax_show_and_transform(
layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
)

if legend_params.colorbar:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
Expand Down
Loading
Loading