Skip to content

Commit 6949c66

Browse files
authored
Optimized get_unit_normal() and replaced np.cross() with custom cross() in manim.utils.space_ops (#3494)
* Added cross and optimized get_unit_normal in manim.utils.space_ops * Added missing border case to new get_unit_normal where one vector is nonzero * Updated test_threed.py::test_Sphere test data
1 parent 7cead84 commit 6949c66

File tree

2 files changed

+50
-19
lines changed

2 files changed

+50
-19
lines changed

manim/utils/space_ops.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from manim.typing import Point3D_Array, Vector
5+
from manim.typing import Point3D_Array, Vector, Vector3
66

77
__all__ = [
88
"quaternion_mult",
@@ -53,6 +53,16 @@ def norm_squared(v: float) -> float:
5353
return np.dot(v, v)
5454

5555

56+
def cross(v1: Vector3, v2: Vector3) -> Vector3:
57+
return np.array(
58+
[
59+
v1[1] * v2[2] - v1[2] * v2[1],
60+
v1[2] * v2[0] - v1[0] * v2[2],
61+
v1[0] * v2[1] - v1[1] * v2[0],
62+
]
63+
)
64+
65+
5666
# Quaternions
5767
# TODO, implement quaternion type
5868

@@ -273,12 +283,12 @@ def z_to_vector(vector: np.ndarray) -> np.ndarray:
273283
(normalized) vector provided as an argument
274284
"""
275285
axis_z = normalize(vector)
276-
axis_y = normalize(np.cross(axis_z, RIGHT))
277-
axis_x = np.cross(axis_y, axis_z)
286+
axis_y = normalize(cross(axis_z, RIGHT))
287+
axis_x = cross(axis_y, axis_z)
278288
if np.linalg.norm(axis_y) == 0:
279289
# the vector passed just so happened to be in the x direction.
280-
axis_x = normalize(np.cross(UP, axis_z))
281-
axis_y = -np.cross(axis_x, axis_z)
290+
axis_x = normalize(cross(UP, axis_z))
291+
axis_y = -cross(axis_x, axis_z)
282292

283293
return np.array([axis_x, axis_y, axis_z]).T
284294

@@ -359,7 +369,7 @@ def normalize_along_axis(array: np.ndarray, axis: np.ndarray) -> np.ndarray:
359369
return array
360370

361371

362-
def get_unit_normal(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-6) -> np.ndarray:
372+
def get_unit_normal(v1: Vector3, v2: Vector3, tol: float = 1e-6) -> Vector3:
363373
"""Gets the unit normal of the vectors.
364374
365375
Parameters
@@ -376,16 +386,37 @@ def get_unit_normal(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-6) -> np.nda
376386
np.ndarray
377387
The normal of the two vectors.
378388
"""
379-
v1, v2 = (normalize(i) for i in (v1, v2))
380-
cp = np.cross(v1, v2)
381-
cp_norm = np.linalg.norm(cp)
382-
if cp_norm < tol:
383-
# Vectors align, so find a normal to them in the plane shared with the z-axis
384-
cp = np.cross(np.cross(v1, OUT), v1)
385-
cp_norm = np.linalg.norm(cp)
386-
if cp_norm < tol:
389+
# Instead of normalizing v1 and v2, just divide by the greatest
390+
# of all their absolute components, which is just enough
391+
div1, div2 = max(np.abs(v1)), max(np.abs(v2))
392+
if div1 == 0.0:
393+
if div2 == 0.0:
387394
return DOWN
388-
return normalize(cp)
395+
u = v2 / div2
396+
elif div2 == 0.0:
397+
u = v1 / div1
398+
else:
399+
# Normal scenario: v1 and v2 are both non-null
400+
u1, u2 = v1 / div1, v2 / div2
401+
cp = cross(u1, u2)
402+
cp_norm = np.sqrt(norm_squared(cp))
403+
if cp_norm > tol:
404+
return cp / cp_norm
405+
# Otherwise, v1 and v2 were aligned
406+
u = u1
407+
408+
# If you are here, you have an "unique", non-zero, unit-ish vector u
409+
# If it's also too aligned to the Z axis, just return DOWN
410+
if abs(u[0]) < tol and abs(u[1]) < tol:
411+
return DOWN
412+
# Otherwise rotate u in the plane it shares with the Z axis,
413+
# 90° TOWARDS the Z axis. This is done via (u x [0, 0, 1]) x u,
414+
# which gives [-xz, -yz, x²+y²] (slightly scaled as well)
415+
cp = np.array([-u[0] * u[2], -u[1] * u[2], u[0] * u[0] + u[1] * u[1]])
416+
cp_norm = np.sqrt(norm_squared(cp))
417+
# Because the norm(u) == 0 case was filtered in the beginning,
418+
# there is no need to check if the norm of cp is 0
419+
return cp / cp_norm
389420

390421

391422
###
@@ -529,8 +560,8 @@ def line_intersection(
529560
np.pad(np.array(i)[:, :2], ((0, 0), (0, 1)), constant_values=1)
530561
for i in (line1, line2)
531562
)
532-
line1, line2 = (np.cross(*i) for i in padded)
533-
x, y, z = np.cross(line1, line2)
563+
line1, line2 = (cross(*i) for i in padded)
564+
x, y, z = cross(line1, line2)
534565

535566
if z == 0:
536567
raise ValueError(
@@ -558,7 +589,7 @@ def find_intersection(
558589
result = []
559590

560591
for p0, v0, p1, v1 in zip(*[p0s, v0s, p1s, v1s]):
561-
normal = np.cross(v1, np.cross(v0, v1))
592+
normal = cross(v1, cross(v0, v1))
562593
denom = max(np.dot(v0, normal), threshold)
563594
result += [p0 + np.dot(p1 - p0, normal) / denom * v0]
564595
return result
@@ -814,6 +845,6 @@ def perpendicular_bisector(
814845
"""
815846
p1 = line[0]
816847
p2 = line[1]
817-
direction = np.cross(p1 - p2, norm_vector)
848+
direction = cross(p1 - p2, norm_vector)
818849
m = midpoint(p1, p2)
819850
return [m + direction, m - direction]
Binary file not shown.

0 commit comments

Comments
 (0)