Skip to content

Commit dc37bf1

Browse files
Switch from list comprehensions to vectorized approach for axes3d.plot_wireframe
1 parent f2717a5 commit dc37bf1

File tree

1 file changed

+12
-19
lines changed

1 file changed

+12
-19
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,41 +2419,34 @@ def plot_wireframe(self, X, Y, Z, *, axlim_clip=False, **kwargs):
24192419
tX, tY, tZ = np.transpose(X), np.transpose(Y), np.transpose(Z)
24202420

24212421
if rstride:
2422-
rii = list(range(0, rows, rstride))
2422+
rii = np.arange(0, rows, rstride)
24232423
# Add the last index only if needed
24242424
if rows > 0 and rii[-1] != (rows - 1):
2425-
rii += [rows-1]
2425+
rii = np.append(rii, rows-1)
24262426
else:
2427-
rii = []
2427+
rii = np.array([], dtype=int)
2428+
24282429
if cstride:
2429-
cii = list(range(0, cols, cstride))
2430+
cii = np.arange(0, cols, cstride)
24302431
# Add the last index only if needed
24312432
if cols > 0 and cii[-1] != (cols - 1):
2432-
cii += [cols-1]
2433+
cii = np.append(cii, cols-1)
24332434
else:
2434-
cii = []
2435+
cii = np.array([], dtype=int)
24352436

24362437
if rstride == 0 and cstride == 0:
24372438
raise ValueError("Either rstride or cstride must be non zero")
24382439

24392440
# If the inputs were empty, then just
24402441
# reset everything.
24412442
if Z.size == 0:
2442-
rii = []
2443-
cii = []
2444-
2445-
xlines = [X[i] for i in rii]
2446-
ylines = [Y[i] for i in rii]
2447-
zlines = [Z[i] for i in rii]
2443+
rii = np.array([], dtype=int)
2444+
cii = np.array([], dtype=int)
24482445

2449-
txlines = [tX[i] for i in cii]
2450-
tylines = [tY[i] for i in cii]
2451-
tzlines = [tZ[i] for i in cii]
2446+
row_lines = np.stack([X[rii], Y[rii], Z[rii]], axis=-1)
2447+
col_lines = np.stack([tX[cii], tY[cii], tZ[cii]], axis=-1)
24522448

2453-
lines = ([list(zip(xl, yl, zl))
2454-
for xl, yl, zl in zip(xlines, ylines, zlines)]
2455-
+ [list(zip(xl, yl, zl))
2456-
for xl, yl, zl in zip(txlines, tylines, tzlines)])
2449+
lines = np.concatenate([row_lines, col_lines])
24572450

24582451
linec = art3d.Line3DCollection(lines, axlim_clip=axlim_clip, **kwargs)
24592452
self.add_collection(linec)

0 commit comments

Comments
 (0)