Skip to content

Commit b415604

Browse files
authored
Optimized manim.utils.bezier.is_closed() (#3768)
* Optimized manim.utils.bezier.is_closed() * oops, that shouldn't have been there * Slightly optimized is_closed() even more * Added doctest for is_closed()
1 parent 203a536 commit b415604

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

manim/utils/bezier.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,78 @@ def get_quadratic_approximation_of_cubic(
681681

682682

683683
def is_closed(points: Point3D_Array) -> bool:
684-
return np.allclose(points[0], points[-1]) # type: ignore
684+
"""Returns ``True`` if the spline given by ``points`` is closed, by
685+
checking if its first and last points are close to each other, or``False``
686+
otherwise.
687+
688+
.. note::
689+
690+
This function reimplements :meth:`np.allclose`, because repeated
691+
calling of :meth:`np.allclose` for only 2 points is inefficient.
692+
693+
Parameters
694+
----------
695+
points
696+
An array of points defining a spline.
697+
698+
Returns
699+
-------
700+
:class:`bool`
701+
Whether the first and last points of the array are close enough or not
702+
to be considered the same, thus considering the defined spline as
703+
closed.
704+
705+
Examples
706+
--------
707+
.. code-block:: pycon
708+
709+
>>> import numpy as np
710+
>>> from manim import is_closed
711+
>>> is_closed(
712+
... np.array(
713+
... [
714+
... [0, 0, 0],
715+
... [1, 2, 3],
716+
... [3, 2, 1],
717+
... [0, 0, 0],
718+
... ]
719+
... )
720+
... )
721+
True
722+
>>> is_closed(
723+
... np.array(
724+
... [
725+
... [0, 0, 0],
726+
... [1, 2, 3],
727+
... [3, 2, 1],
728+
... [1e-10, 1e-10, 1e-10],
729+
... ]
730+
... )
731+
... )
732+
True
733+
>>> is_closed(
734+
... np.array(
735+
... [
736+
... [0, 0, 0],
737+
... [1, 2, 3],
738+
... [3, 2, 1],
739+
... [1e-2, 1e-2, 1e-2],
740+
... ]
741+
... )
742+
... )
743+
False
744+
"""
745+
start, end = points[0], points[-1]
746+
rtol = 1e-5
747+
atol = 1e-8
748+
tolerance = atol + rtol * start
749+
if abs(end[0] - start[0]) > tolerance[0]:
750+
return False
751+
if abs(end[1] - start[1]) > tolerance[1]:
752+
return False
753+
if abs(end[2] - start[2]) > tolerance[2]:
754+
return False
755+
return True
685756

686757

687758
def proportions_along_bezier_curve_for_point(

0 commit comments

Comments
 (0)