Skip to content

Commit bcecb0e

Browse files
committed
basic shapes datashader rendering
1 parent 3bfab39 commit bcecb0e

File tree

1 file changed

+63
-29
lines changed

1 file changed

+63
-29
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -122,35 +122,6 @@ def _render_shapes(
122122
shapes = shapes.reset_index()
123123
color_source_vector = color_source_vector[mask]
124124
color_vector = color_vector[mask]
125-
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
126-
127-
_cax = _get_collection_shape(
128-
shapes=shapes,
129-
s=render_params.scale,
130-
c=color_vector,
131-
render_params=render_params,
132-
rasterized=sc_settings._vector_friendly,
133-
cmap=render_params.cmap_params.cmap,
134-
norm=norm,
135-
fill_alpha=render_params.fill_alpha,
136-
outline_alpha=render_params.outline_alpha,
137-
zorder=render_params.zorder,
138-
# **kwargs,
139-
)
140-
141-
# Sets the limits of the colorbar to the values instead of [0, 1]
142-
if not norm and not values_are_categorical:
143-
_cax.set_clim(min(color_vector), max(color_vector))
144-
145-
cax = ax.add_collection(_cax)
146-
147-
# Apply the transformation to the PatchCollection's paths
148-
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
149-
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
150-
trans = mtransforms.Affine2D(matrix=affine_trans)
151-
152-
for path in _cax.get_paths():
153-
path.vertices = trans.transform(path.vertices)
154125

155126
# Using dict.fromkeys here since set returns in arbitrary order
156127
# remove the color of NaN values, else it might be assigned to a category
@@ -160,6 +131,69 @@ def _render_shapes(
160131
else:
161132
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))
162133

134+
# Apply the transformation to the PatchCollection's paths
135+
trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system]
136+
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
137+
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
138+
139+
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
140+
141+
if len(shapes) < 1:
142+
logger.info("Using matplotlib")
143+
_cax = _get_collection_shape(
144+
shapes=shapes,
145+
s=render_params.scale,
146+
c=color_vector,
147+
render_params=render_params,
148+
rasterized=sc_settings._vector_friendly,
149+
cmap=render_params.cmap_params.cmap,
150+
norm=norm,
151+
fill_alpha=render_params.fill_alpha,
152+
outline_alpha=render_params.outline_alpha,
153+
zorder=render_params.zorder,
154+
# **kwargs,
155+
)
156+
cax = ax.add_collection(_cax)
157+
158+
# Transform the paths in PatchCollection
159+
for path in _cax.get_paths():
160+
path.vertices = trans.transform(path.vertices)
161+
cax = ax.add_collection(_cax)
162+
else:
163+
logger.info("Using datashader")
164+
extent = get_extent(sdata.shapes[e])
165+
x_ext = extent["x"][1]
166+
y_ext = extent["y"][1]
167+
# previous_xlim = fig_params.ax.get_xlim()
168+
# previous_ylim = fig_params.ax.get_ylim()
169+
x_range = [0, x_ext]
170+
y_range = [0, y_ext]
171+
# round because we need integers
172+
plot_width = int(np.round(x_range[1] - x_range[0]))
173+
plot_height = int(np.round(y_range[1] - y_range[0]))
174+
175+
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_range, y_range=y_range)
176+
177+
_geometry = shapes["geometry"]
178+
is_point = _geometry.type == "Point"
179+
180+
# Handle circles encoded as points with radius
181+
if is_point.any(): # TODO
182+
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(shapes[is_point]["radius"])
183+
184+
agg = cvs.polygons(shapes, geometry="geometry", agg=ds.count())
185+
ds_result = ds.tf.shade(agg)
186+
187+
# Render image
188+
rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
189+
_cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder)
190+
_cax.set_transform(trans)
191+
cax = ax.add_image(_cax)
192+
193+
# Sets the limits of the colorbar to the values instead of [0, 1]
194+
if not norm and not values_are_categorical:
195+
_cax.set_clim(min(color_vector), max(color_vector))
196+
163197
if not (
164198
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
165199
):

0 commit comments

Comments
 (0)