Skip to content

Commit 17d4a7f

Browse files
Add type annotations to manim.utils.* (#3999)
Co-authored-by: Francisco Manríquez Novoa <[email protected]> Co-authored-by: Francisco Manríquez <[email protected]>
1 parent dbad8a8 commit 17d4a7f

40 files changed

+708
-358
lines changed

manim/mobject/text/numbers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
__all__ = ["DecimalNumber", "Integer", "Variable"]
66

77
from collections.abc import Sequence
8+
from typing import Any
89

910
import numpy as np
1011

@@ -327,7 +328,9 @@ def construct(self):
327328
self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4))
328329
"""
329330

330-
def __init__(self, number=0, num_decimal_places=0, **kwargs):
331+
def __init__(
332+
self, number: float = 0, num_decimal_places: int = 0, **kwargs: Any
333+
) -> None:
331334
super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs)
332335

333336
def get_value(self):

manim/renderer/cairo_renderer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from manim.animation.animation import Animation
2020
from manim.scene.scene import Scene
2121

22+
from ..typing import PixelArray
23+
2224
__all__ = ["CairoRenderer"]
2325

2426

@@ -158,7 +160,7 @@ def render(self, scene, time, moving_mobjects):
158160
self.update_frame(scene, moving_mobjects)
159161
self.add_frame(self.get_frame())
160162

161-
def get_frame(self):
163+
def get_frame(self) -> PixelArray:
162164
"""
163165
Gets the current frame as NumPy array.
164166

manim/renderer/opengl_renderer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def get_texture_id(self, path):
388388

389389
return self.path_to_texture_id[repr(path)]
390390

391-
def update_skipping_status(self):
391+
def update_skipping_status(self) -> None:
392392
"""
393393
This method is used internally to check if the current
394394
animation needs to be skipped or not. It also checks if

manim/scene/scene.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def add(self, *mobjects: Mobject):
488488
self.moving_mobjects += mobjects
489489
return self
490490

491-
def add_mobjects_from_animations(self, animations):
491+
def add_mobjects_from_animations(self, animations: list[Animation]) -> None:
492492
curr_mobjects = self.get_mobject_family_members()
493493
for animation in animations:
494494
if animation.is_introducer():

manim/scene/scene_file_writer.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pydub import AudioSegment
2121

2222
from manim import __version__
23-
from manim.typing import PixelArray
23+
from manim.typing import PixelArray, StrPath
2424

2525
from .. import config, logger
2626
from .._config.logger_utils import set_file_logger
@@ -38,6 +38,7 @@
3838
from .section import DefaultSectionType, Section
3939

4040
if TYPE_CHECKING:
41+
from manim.renderer.cairo_renderer import CairoRenderer
4142
from manim.renderer.opengl_renderer import OpenGLRenderer
4243

4344

@@ -104,7 +105,12 @@ class SceneFileWriter:
104105

105106
force_output_as_scene_name = False
106107

107-
def __init__(self, renderer, scene_name, **kwargs):
108+
def __init__(
109+
self,
110+
renderer: CairoRenderer | OpenGLRenderer,
111+
scene_name: StrPath,
112+
**kwargs: Any,
113+
) -> None:
108114
self.renderer = renderer
109115
self.init_output_directories(scene_name)
110116
self.init_audio()
@@ -118,7 +124,7 @@ def __init__(self, renderer, scene_name, **kwargs):
118124
name="autocreated", type_=DefaultSectionType.NORMAL, skip_animations=False
119125
)
120126

121-
def init_output_directories(self, scene_name):
127+
def init_output_directories(self, scene_name: StrPath) -> None:
122128
"""Initialise output directories.
123129
124130
Notes
@@ -378,7 +384,9 @@ def add_sound(
378384
self.add_audio_segment(new_segment, time, **kwargs)
379385

380386
# Writers
381-
def begin_animation(self, allow_write: bool = False, file_path=None):
387+
def begin_animation(
388+
self, allow_write: bool = False, file_path: StrPath | None = None
389+
) -> None:
382390
"""
383391
Used internally by manim to stream the animation to FFMPEG for
384392
displaying or writing to a file.
@@ -391,7 +399,7 @@ def begin_animation(self, allow_write: bool = False, file_path=None):
391399
if write_to_movie() and allow_write:
392400
self.open_partial_movie_stream(file_path=file_path)
393401

394-
def end_animation(self, allow_write: bool = False):
402+
def end_animation(self, allow_write: bool = False) -> None:
395403
"""
396404
Internally used by Manim to stop streaming to
397405
FFMPEG gracefully.

manim/utils/bezier.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,12 @@
2929
from manim.utils.simple_functions import choose
3030

3131
if TYPE_CHECKING:
32-
import numpy.typing as npt
33-
3432
from manim.typing import (
3533
BezierPoints,
3634
BezierPoints_Array,
3735
BezierPointsLike,
3836
BezierPointsLike_Array,
3937
ColVector,
40-
ManimFloat,
4138
MatrixMN,
4239
Point3D,
4340
Point3D_Array,
@@ -64,7 +61,9 @@ def bezier(
6461
) -> Callable[[float | ColVector], Point3D_Array]: ...
6562

6663

67-
def bezier(points):
64+
def bezier(
65+
points: Point3D_Array | Sequence[Point3D_Array],
66+
) -> Callable[[float | ColVector], Point3D_Array]:
6867
"""Classic implementation of a Bézier curve.
6968
7069
Parameters
@@ -118,21 +117,21 @@ def bezier(points):
118117

119118
if degree == 0:
120119

121-
def zero_bezier(t):
120+
def zero_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
122121
return np.ones_like(t) * P[0]
123122

124123
return zero_bezier
125124

126125
if degree == 1:
127126

128-
def linear_bezier(t):
127+
def linear_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
129128
return P[0] + t * (P[1] - P[0])
130129

131130
return linear_bezier
132131

133132
if degree == 2:
134133

135-
def quadratic_bezier(t):
134+
def quadratic_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
136135
t2 = t * t
137136
mt = 1 - t
138137
mt2 = mt * mt
@@ -142,7 +141,7 @@ def quadratic_bezier(t):
142141

143142
if degree == 3:
144143

145-
def cubic_bezier(t):
144+
def cubic_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
146145
t2 = t * t
147146
t3 = t2 * t
148147
mt = 1 - t
@@ -152,11 +151,12 @@ def cubic_bezier(t):
152151

153152
return cubic_bezier
154153

155-
def nth_grade_bezier(t):
154+
def nth_grade_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
156155
is_scalar = not isinstance(t, np.ndarray)
157156
if is_scalar:
158157
B = np.empty((1, *P.shape))
159158
else:
159+
assert isinstance(t, np.ndarray)
160160
t = t.reshape(-1, *[1 for dim in P.shape])
161161
B = np.empty((t.shape[0], *P.shape))
162162
B[:] = P
@@ -169,7 +169,8 @@ def nth_grade_bezier(t):
169169
# In the end, there shall be the evaluation at t of a single Bezier curve of
170170
# grade d, stored in the first slot of B
171171
if is_scalar:
172-
return B[0, 0]
172+
val: Point3D = B[0, 0]
173+
return val
173174
return B[:, 0]
174175

175176
return nth_grade_bezier
@@ -1026,7 +1027,11 @@ def interpolate(start: Point3D, end: Point3D, alpha: float) -> Point3D: ...
10261027
def interpolate(start: Point3D, end: Point3D, alpha: ColVector) -> Point3D_Array: ...
10271028

10281029

1029-
def interpolate(start, end, alpha):
1030+
def interpolate(
1031+
start: float | Point3D,
1032+
end: float | Point3D,
1033+
alpha: float | ColVector,
1034+
) -> float | ColVector | Point3D | Point3D_Array:
10301035
"""Linearly interpolates between two values ``start`` and ``end``.
10311036
10321037
Parameters
@@ -1139,7 +1144,9 @@ def inverse_interpolate(start: Point3D, end: Point3D, value: Point3D) -> Point3D
11391144

11401145

11411146
def inverse_interpolate(
1142-
start: float | Point3D, end: float | Point3D, value: float | Point3D
1147+
start: float | Point3D,
1148+
end: float | Point3D,
1149+
value: float | Point3D,
11431150
) -> float | Point3D:
11441151
"""Perform inverse interpolation to determine the alpha
11451152
values that would produce the specified ``value``
@@ -1234,7 +1241,7 @@ def match_interpolate(
12341241
return interpolate(
12351242
new_start,
12361243
new_end,
1237-
old_alpha, # type: ignore[arg-type]
1244+
old_alpha,
12381245
)
12391246

12401247

@@ -1270,7 +1277,8 @@ def get_smooth_cubic_bezier_handle_points(
12701277
# they can only be an interpolation of these two anchors with alphas
12711278
# 1/3 and 2/3, which will draw a straight line between the anchors.
12721279
if n_anchors == 2:
1273-
return interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]]))
1280+
val = interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]]))
1281+
return (val[0], val[1])
12741282

12751283
# Handle different cases depending on whether the points form a closed
12761284
# curve or not
@@ -1745,7 +1753,12 @@ def get_quadratic_approximation_of_cubic(
17451753
) -> QuadraticBezierPath: ...
17461754

17471755

1748-
def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
1756+
def get_quadratic_approximation_of_cubic(
1757+
a0: Point3D | Point3D_Array,
1758+
h0: Point3D | Point3D_Array,
1759+
h1: Point3D | Point3D_Array,
1760+
a1: Point3D | Point3D_Array,
1761+
) -> QuadraticSpline | QuadraticBezierPath:
17491762
r"""If ``a0``, ``h0``, ``h1`` and ``a1`` are the control points of a cubic
17501763
Bézier curve, approximate the curve with two quadratic Bézier curves and
17511764
return an array of 6 points, where the first 3 points represent the first
@@ -1849,33 +1862,33 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
18491862
If ``a0``, ``h0``, ``h1`` and ``a1`` have different dimensions, or
18501863
if their number of dimensions is not 1 or 2.
18511864
"""
1852-
a0 = np.asarray(a0)
1853-
h0 = np.asarray(h0)
1854-
h1 = np.asarray(h1)
1855-
a1 = np.asarray(a1)
1856-
1857-
if all(arr.ndim == 1 for arr in (a0, h0, h1, a1)):
1858-
num_curves, dim = 1, a0.shape[0]
1859-
elif all(arr.ndim == 2 for arr in (a0, h0, h1, a1)):
1860-
num_curves, dim = a0.shape
1865+
a0c = np.asarray(a0)
1866+
h0c = np.asarray(h0)
1867+
h1c = np.asarray(h1)
1868+
a1c = np.asarray(a1)
1869+
1870+
if all(arr.ndim == 1 for arr in (a0c, h0c, h1c, a1c)):
1871+
num_curves, dim = 1, a0c.shape[0]
1872+
elif all(arr.ndim == 2 for arr in (a0c, h0c, h1c, a1c)):
1873+
num_curves, dim = a0c.shape
18611874
else:
18621875
raise ValueError("All arguments must be Point3D or Point3D_Array.")
18631876

1864-
m0 = 0.25 * (3 * h0 + a0)
1865-
m1 = 0.25 * (3 * h1 + a1)
1877+
m0 = 0.25 * (3 * h0c + a0c)
1878+
m1 = 0.25 * (3 * h1c + a1c)
18661879
k = 0.5 * (m0 + m1)
18671880

18681881
result = np.empty((6 * num_curves, dim))
1869-
result[0::6] = a0
1882+
result[0::6] = a0c
18701883
result[1::6] = m0
18711884
result[2::6] = k
18721885
result[3::6] = k
18731886
result[4::6] = m1
1874-
result[5::6] = a1
1887+
result[5::6] = a1c
18751888
return result
18761889

18771890

1878-
def is_closed(points: Point3DLike_Array) -> bool:
1891+
def is_closed(points: Point3D_Array) -> bool:
18791892
"""Returns ``True`` if the spline given by ``points`` is closed, by
18801893
checking if its first and last points are close to each other, or``False``
18811894
otherwise.
@@ -1952,7 +1965,7 @@ def proportions_along_bezier_curve_for_point(
19521965
point: Point3DLike,
19531966
control_points: BezierPointsLike,
19541967
round_to: float = 1e-6,
1955-
) -> npt.NDArray[ManimFloat]:
1968+
) -> MatrixMN:
19561969
"""Obtains the proportion along the bezier curve corresponding to a given point
19571970
given the bezier curve's control points.
19581971

manim/utils/caching.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
from __future__ import annotations
22

3-
from typing import Callable
3+
from typing import TYPE_CHECKING, Callable
44

55
from .. import config, logger
66
from ..utils.hashing import get_hash_from_play_call
77

88
__all__ = ["handle_caching_play"]
99

10+
if TYPE_CHECKING:
11+
from typing import Any
1012

11-
def handle_caching_play(func: Callable[..., None]):
13+
from manim.renderer.opengl_renderer import OpenGLRenderer
14+
from manim.scene.scene import Scene
15+
16+
17+
def handle_caching_play(func: Callable[..., None]) -> Callable[..., None]:
1218
"""Decorator that returns a wrapped version of func that will compute
1319
the hash of the play invocation.
1420
@@ -28,7 +34,7 @@ def handle_caching_play(func: Callable[..., None]):
2834
# the play logic of the latter has to be refactored in the same way the cairo renderer has been, and thus this
2935
# method has to be deleted.
3036

31-
def wrapper(self, scene, *args, **kwargs):
37+
def wrapper(self: OpenGLRenderer, scene: Scene, *args: Any, **kwargs: Any) -> None:
3238
self.skip_animations = self._original_skipping_status
3339
self.update_skipping_status()
3440
animations = scene.compile_animations(*args, **kwargs)
@@ -43,8 +49,9 @@ def wrapper(self, scene, *args, **kwargs):
4349
return
4450
if not config["disable_caching"]:
4551
mobjects_on_scene = scene.mobjects
52+
# TODO: the first argument seems wrong. Shouldn't it be scene instead?
4653
hash_play = get_hash_from_play_call(
47-
self,
54+
self, # type: ignore[arg-type]
4855
self.camera,
4956
animations,
5057
mobjects_on_scene,

0 commit comments

Comments
 (0)