-
Notifications
You must be signed in to change notification settings - Fork 17
coloring shapes by categorical variable #153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
d32604c
coloring shapes by categorical variable
24c1392
update for case of array instead of categorical
d171e59
unittests added
567df8e
filter points and shapes using groups
0735840
Merge branch 'main' into bugfix/145_palette_in_render_shapes
Sonja-Stockhaus 9c65706
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 03779da
PointsModel coordinate argument
19dd3e6
unittests for points and filtering by groups
a1cf6ae
changing palette and cmap usage
14e166d
some docstrings, changelog
b3a7805
docstrings
40288fe
remove comments and brackets
f67dee5
Merge branch 'main' into bugfix/145_palette_in_render_shapes
timtreis File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
from copy import copy | ||
from typing import Union | ||
|
||
import dask | ||
import geopandas as gpd | ||
import matplotlib | ||
import numpy as np | ||
|
@@ -18,6 +19,7 @@ | |
from spatialdata.models import ( | ||
Image2DModel, | ||
Labels2DModel, | ||
PointsModel, | ||
) | ||
|
||
from spatialdata_plot._logging import logger | ||
|
@@ -57,6 +59,12 @@ def _render_shapes( | |
) -> None: | ||
elements = render_params.elements | ||
|
||
if render_params.groups is not None: | ||
if isinstance(render_params.groups, str): | ||
render_params.groups = [render_params.groups] | ||
if not all(isinstance(g, str) for g in render_params.groups): | ||
raise TypeError("All groups must be strings.") | ||
|
||
sdata_filt = sdata.filter_by_coordinate_system( | ||
coordinate_system=coordinate_system, | ||
filter_table=sdata.table is not None, | ||
|
@@ -68,7 +76,6 @@ def _render_shapes( | |
elements = list(sdata_filt.shapes.keys()) | ||
|
||
for e in elements: | ||
# shapes = [sdata.shapes[e] for e in elements] | ||
shapes = sdata.shapes[e] | ||
n_shapes = sum([len(s) for s in shapes]) | ||
|
||
|
@@ -88,6 +95,7 @@ def _render_shapes( | |
palette=render_params.palette, | ||
na_color=render_params.cmap_params.na_color, | ||
alpha=render_params.fill_alpha, | ||
cmap_params=render_params.cmap_params, | ||
) | ||
|
||
values_are_categorical = color_source_vector is not None | ||
|
@@ -101,7 +109,15 @@ def _render_shapes( | |
if len(color_vector) == 0: | ||
color_vector = [render_params.cmap_params.na_color] | ||
|
||
# filter by `groups` | ||
if render_params.groups is not None and color_source_vector is not None: | ||
mask = color_source_vector.isin(render_params.groups) | ||
shapes = shapes[mask] | ||
shapes = shapes.reset_index() | ||
color_source_vector = color_source_vector[mask] | ||
color_vector = color_vector[mask] | ||
shapes = gpd.GeoDataFrame(shapes, geometry="geometry") | ||
|
||
_cax = _get_collection_shape( | ||
shapes=shapes, | ||
s=render_params.scale, | ||
|
@@ -122,9 +138,12 @@ def _render_shapes( | |
cax = ax.add_collection(_cax) | ||
|
||
# Using dict.fromkeys here since set returns in arbitrary order | ||
palette = ( | ||
ListedColormap(dict.fromkeys(color_vector)) if render_params.palette is None else render_params.palette | ||
) | ||
# remove the color of NaN values, else it might be assigned to a category | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to have this comment directly next to the |
||
# order of color in the palette should agree to order of occurence | ||
if color_source_vector is None: | ||
palette = ListedColormap(dict.fromkeys(color_vector)) | ||
else: | ||
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) | ||
|
||
if not ( | ||
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color) | ||
|
@@ -159,6 +178,12 @@ def _render_points( | |
scalebar_params: ScalebarParams, | ||
legend_params: LegendParams, | ||
) -> None: | ||
if render_params.groups is not None: | ||
if isinstance(render_params.groups, str): | ||
render_params.groups = [render_params.groups] | ||
if not all(isinstance(g, str) for g in render_params.groups): | ||
raise TypeError("All groups must be strings.") | ||
|
||
elements = render_params.elements | ||
|
||
sdata_filt = sdata.filter_by_coordinate_system( | ||
|
@@ -178,6 +203,14 @@ def _render_points( | |
color = [render_params.color] if isinstance(render_params.color, str) else render_params.color | ||
coords.extend(color) | ||
|
||
points = points[coords].compute() | ||
# points[color[0]].cat.set_categories(render_params.groups, inplace=True) | ||
if render_params.groups is not None: | ||
points = points[points[color].isin(render_params.groups).values] | ||
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups) | ||
points = dask.dataframe.from_pandas(points, npartitions=1) | ||
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"}) | ||
|
||
point_df = points[coords].compute() | ||
|
||
# we construct an anndata to hack the plotting functions | ||
|
@@ -204,6 +237,7 @@ def _render_points( | |
palette=render_params.palette, | ||
na_color=render_params.cmap_params.na_color, | ||
alpha=render_params.alpha, | ||
cmap_params=render_params.cmap_params, | ||
) | ||
|
||
# color_source_vector is None when the values aren't categorical | ||
|
@@ -226,14 +260,19 @@ def _render_points( | |
if not ( | ||
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color) | ||
): | ||
if color_source_vector is None: | ||
palette = ListedColormap(dict.fromkeys(color_vector)) | ||
else: | ||
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) | ||
|
||
_ = _decorate_axs( | ||
ax=ax, | ||
cax=cax, | ||
fig_params=fig_params, | ||
adata=adata, | ||
value_to_plot=render_params.color, | ||
color_source_vector=color_source_vector, | ||
palette=render_params.palette, | ||
palette=palette, | ||
alpha=render_params.alpha, | ||
na_color=render_params.cmap_params.na_color, | ||
legend_fontsize=legend_params.legend_fontsize, | ||
|
@@ -415,6 +454,12 @@ def _render_labels( | |
) -> None: | ||
elements = render_params.elements | ||
|
||
if render_params.groups is not None: | ||
if isinstance(render_params.groups, str): | ||
render_params.groups = [render_params.groups] | ||
if not all(isinstance(g, str) for g in render_params.groups): | ||
raise TypeError("All groups must be strings.") | ||
|
||
sdata_filt = sdata.filter_by_coordinate_system( | ||
coordinate_system=coordinate_system, | ||
filter_table=sdata.table is not None, | ||
|
@@ -441,7 +486,7 @@ def _render_labels( | |
|
||
table = sdata.table[sdata.table.obs[region_key].isin([label_key])] | ||
|
||
# get isntance id based on subsetted table | ||
# get instance id based on subsetted table | ||
instance_id = table.obs[instance_key].values | ||
|
||
# get color vector (categorical or continuous) | ||
|
@@ -455,6 +500,7 @@ def _render_labels( | |
palette=render_params.palette, | ||
na_color=render_params.cmap_params.na_color, | ||
alpha=render_params.fill_alpha, | ||
cmap_params=render_params.cmap_params, | ||
) | ||
|
||
if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good descriptions 👌