Skip to content

Commit 8094fb1

Browse files
authored
Merge pull request matplotlib#16675 from apaszke/vectorize_surface_plot_perimeters
ENH: Vectorize patch extraction in Axes3D.plot_surface
2 parents a046d8c + 907235d commit 8094fb1

File tree

3 files changed

+171
-20
lines changed

3 files changed

+171
-20
lines changed

lib/matplotlib/cbook/__init__.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,107 @@ def _array_perimeter(arr):
19911991
))
19921992

19931993

1994+
def _unfold(arr, axis, size, step):
1995+
"""
1996+
Append an extra dimension containing sliding windows along *axis*.
1997+
1998+
All windows are of size *size* and begin with every *step* elements.
1999+
2000+
Parameters
2001+
----------
2002+
arr : ndarray, shape (N_1, ..., N_k)
2003+
The input array
2004+
axis : int
2005+
Axis along which the windows are extracted
2006+
size : int
2007+
Size of the windows
2008+
step : int
2009+
Stride between first elements of subsequent windows.
2010+
2011+
Returns
2012+
-------
2013+
windows : ndarray, shape (N_1, ..., 1 + (N_axis-size)/step, ..., N_k, size)
2014+
2015+
Examples
2016+
--------
2017+
>>> i, j = np.ogrid[:3,:7]
2018+
>>> a = i*10 + j
2019+
>>> a
2020+
array([[ 0, 1, 2, 3, 4, 5, 6],
2021+
[10, 11, 12, 13, 14, 15, 16],
2022+
[20, 21, 22, 23, 24, 25, 26]])
2023+
>>> _unfold(a, axis=1, size=3, step=2)
2024+
array([[[ 0, 1, 2],
2025+
[ 2, 3, 4],
2026+
[ 4, 5, 6]],
2027+
2028+
[[10, 11, 12],
2029+
[12, 13, 14],
2030+
[14, 15, 16]],
2031+
2032+
[[20, 21, 22],
2033+
[22, 23, 24],
2034+
[24, 25, 26]]])
2035+
"""
2036+
new_shape = [*arr.shape, size]
2037+
new_strides = [*arr.strides, arr.strides[axis]]
2038+
new_shape[axis] = (new_shape[axis] - size) // step + 1
2039+
new_strides[axis] = new_strides[axis] * step
2040+
return np.lib.stride_tricks.as_strided(arr,
2041+
shape=new_shape,
2042+
strides=new_strides,
2043+
writeable=False)
2044+
2045+
2046+
def _array_patch_perimeters(x, rstride, cstride):
2047+
"""
2048+
Extract perimeters of patches from *arr*.
2049+
2050+
Extracted patches are of size (*rstride* + 1) x (*cstride* + 1) and
2051+
share perimeters with their neighbors. The ordering of the vertices matches
2052+
that returned by ``_array_perimeter``.
2053+
2054+
Parameters
2055+
----------
2056+
x : ndarray, shape (N, M)
2057+
Input array
2058+
rstride : int
2059+
Vertical (row) stride between corresponding elements of each patch
2060+
cstride : int
2061+
Horizontal (column) stride between corresponding elements of each patch
2062+
2063+
Returns
2064+
-------
2065+
patches : ndarray, shape (N/rstride * M/cstride, 2 * (rstride + cstride))
2066+
"""
2067+
assert rstride > 0 and cstride > 0
2068+
assert (x.shape[0] - 1) % rstride == 0
2069+
assert (x.shape[1] - 1) % cstride == 0
2070+
# We build up each perimeter from four half-open intervals. Here is an
2071+
# illustrated explanation for rstride == cstride == 3
2072+
#
2073+
# T T T R
2074+
# L R
2075+
# L R
2076+
# L B B B
2077+
#
2078+
# where T means that this element will be in the top array, R for right,
2079+
# B for bottom and L for left. Each of the arrays below has a shape of:
2080+
#
2081+
# (number of perimeters that can be extracted vertically,
2082+
# number of perimeters that can be extracted horizontally,
2083+
# cstride for top and bottom and rstride for left and right)
2084+
#
2085+
# Note that _unfold doesn't incur any memory copies, so the only costly
2086+
# operation here is the np.concatenate.
2087+
top = _unfold(x[:-1:rstride, :-1], 1, cstride, cstride)
2088+
bottom = _unfold(x[rstride::rstride, 1:], 1, cstride, cstride)[..., ::-1]
2089+
right = _unfold(x[:-1, cstride::cstride], 0, rstride, rstride)
2090+
left = _unfold(x[1:, :-1:cstride], 0, rstride, rstride)[..., ::-1]
2091+
return (np.concatenate((top, right, bottom, left), axis=2)
2092+
.reshape(-1, 2 * (rstride + cstride)))
2093+
2094+
19942095
@contextlib.contextmanager
19952096
def _setattr_cm(obj, **kwargs):
19962097
"""Temporarily set some attributes; restore original state at context exit.

lib/matplotlib/tests/test_cbook.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,31 @@ def test_warn_external(recwarn):
613613
cbook._warn_external("oops")
614614
assert len(recwarn) == 1
615615
assert recwarn[0].filename == __file__
616+
617+
618+
def test_array_patch_perimeters():
619+
# This compares the old implementation as a reference for the
620+
# vectorized one.
621+
def check(x, rstride, cstride):
622+
rows, cols = x.shape
623+
row_inds = [*range(0, rows-1, rstride), rows-1]
624+
col_inds = [*range(0, cols-1, cstride), cols-1]
625+
polys = []
626+
for rs, rs_next in zip(row_inds[:-1], row_inds[1:]):
627+
for cs, cs_next in zip(col_inds[:-1], col_inds[1:]):
628+
# +1 ensures we share edges between polygons
629+
ps = cbook._array_perimeter(x[rs:rs_next+1, cs:cs_next+1]).T
630+
polys.append(ps)
631+
polys = np.asarray(polys)
632+
assert np.array_equal(polys,
633+
cbook._array_patch_perimeters(
634+
x, rstride=rstride, cstride=cstride))
635+
636+
def divisors(n):
637+
return [i for i in range(1, n + 1) if n % i == 0]
638+
639+
for rows, cols in [(5, 5), (7, 14), (13, 9)]:
640+
x = np.arange(rows * cols).reshape(rows, cols)
641+
for rstride, cstride in itertools.product(divisors(rows - 1),
642+
divisors(cols - 1)):
643+
check(x, rstride=rstride, cstride=cstride)

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,17 @@ def plot_surface(self, X, Y, Z, *args, norm=None, vmin=None,
14371437
the input data is larger, it will be downsampled (by slicing) to
14381438
these numbers of points.
14391439
1440+
.. note::
1441+
1442+
To maximize rendering speed consider setting *rstride* and *cstride*
1443+
to divisors of the number of rows minus 1 and columns minus 1
1444+
respectively. For example, given 51 rows rstride can be any of the
1445+
divisors of 50.
1446+
1447+
Similarly, a setting of *rstride* and *cstride* equal to 1 (or
1448+
*rcount* and *ccount* equal the number of rows and columns) can use
1449+
the optimized path.
1450+
14401451
Parameters
14411452
----------
14421453
X, Y, Z : 2d arrays
@@ -1540,25 +1551,33 @@ def plot_surface(self, X, Y, Z, *args, norm=None, vmin=None,
15401551
"semantic or raise an error in matplotlib 3.3. "
15411552
"Please use shade=False instead.")
15421553

1543-
# evenly spaced, and including both endpoints
1544-
row_inds = list(range(0, rows-1, rstride)) + [rows-1]
1545-
col_inds = list(range(0, cols-1, cstride)) + [cols-1]
1546-
15471554
colset = [] # the sampled facecolor
1548-
polys = []
1549-
for rs, rs_next in zip(row_inds[:-1], row_inds[1:]):
1550-
for cs, cs_next in zip(col_inds[:-1], col_inds[1:]):
1551-
ps = [
1552-
# +1 ensures we share edges between polygons
1553-
cbook._array_perimeter(a[rs:rs_next+1, cs:cs_next+1])
1554-
for a in (X, Y, Z)
1555-
]
1556-
# ps = np.stack(ps, axis=-1)
1557-
ps = np.array(ps).T
1558-
polys.append(ps)
1559-
1560-
if fcolors is not None:
1561-
colset.append(fcolors[rs][cs])
1555+
if (rows - 1) % rstride == 0 and \
1556+
(cols - 1) % cstride == 0 and \
1557+
fcolors is None:
1558+
polys = np.stack(
1559+
[cbook._array_patch_perimeters(a, rstride, cstride)
1560+
for a in (X, Y, Z)],
1561+
axis=-1)
1562+
else:
1563+
# evenly spaced, and including both endpoints
1564+
row_inds = list(range(0, rows-1, rstride)) + [rows-1]
1565+
col_inds = list(range(0, cols-1, cstride)) + [cols-1]
1566+
1567+
polys = []
1568+
for rs, rs_next in zip(row_inds[:-1], row_inds[1:]):
1569+
for cs, cs_next in zip(col_inds[:-1], col_inds[1:]):
1570+
ps = [
1571+
# +1 ensures we share edges between polygons
1572+
cbook._array_perimeter(a[rs:rs_next+1, cs:cs_next+1])
1573+
for a in (X, Y, Z)
1574+
]
1575+
# ps = np.stack(ps, axis=-1)
1576+
ps = np.array(ps).T
1577+
polys.append(ps)
1578+
1579+
if fcolors is not None:
1580+
colset.append(fcolors[rs][cs])
15621581

15631582
# note that the striding causes some polygons to have more coordinates
15641583
# than others
@@ -1571,8 +1590,11 @@ def plot_surface(self, X, Y, Z, *args, norm=None, vmin=None,
15711590
polyc.set_facecolors(colset)
15721591
polyc.set_edgecolors(colset)
15731592
elif cmap:
1574-
# doesn't vectorize because polys is jagged
1575-
avg_z = np.array([ps[:, 2].mean() for ps in polys])
1593+
# can't always vectorize, because polys might be jagged
1594+
if isinstance(polys, np.ndarray):
1595+
avg_z = polys[..., 2].mean(axis=-1)
1596+
else:
1597+
avg_z = np.array([ps[:, 2].mean() for ps in polys])
15761598
polyc.set_array(avg_z)
15771599
if vmin is not None or vmax is not None:
15781600
polyc.set_clim(vmin, vmax)

0 commit comments

Comments
 (0)