Skip to content

Commit 0d21a7e

Browse files
SirJamesClarkMaxwellpre-commit-ci[bot]chopan050
authored
Fixed Arrow3D.put_start_and_end_on() to use the actual end of the arrow (#3706)
* my test is not passing, i need to add a little bit of docs. except that everything is fine. Issue is solved! * fixed the issue #3655 * removed comments * fix: 3706 original issue, without adding unnecessary dot added: i added self.height parameter in Cone class my tests passes * Changes that way how end point of Arrow3D is calculated. * I've improved the methods get_start and get_end for the Cone class, and get_end for the Arrow3D class to ensure they return accurate geometrical points after transformations. Additionally, I've included unit tests to verify the correctness of these methods for the Cone class. * Finished! Replaced VMobject by VectorizedPoint as Ben suggested while ago * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Francisco Manríquez Novoa <[email protected]>
1 parent 93a20cd commit 0d21a7e

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

manim/mobject/three_d/three_dimensions.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from manim.mobject.mobject import *
3333
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
3434
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
35-
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
35+
from manim.mobject.types.vectorized_mobject import VectorizedPoint, VGroup, VMobject
3636
from manim.utils.color import (
3737
ManimColor,
3838
ParsableManimColor,
@@ -616,17 +616,18 @@ def __init__(
616616
**kwargs,
617617
)
618618
# used for rotations
619+
self.new_height = height
619620
self._current_theta = 0
620621
self._current_phi = 0
621-
622+
self.base_circle = Circle(
623+
radius=base_radius,
624+
color=self.fill_color,
625+
fill_opacity=self.fill_opacity,
626+
stroke_width=0,
627+
)
628+
self.base_circle.shift(height * IN)
629+
self._set_start_and_end_attributes(direction)
622630
if show_base:
623-
self.base_circle = Circle(
624-
radius=base_radius,
625-
color=self.fill_color,
626-
fill_opacity=self.fill_opacity,
627-
stroke_width=0,
628-
)
629-
self.base_circle.shift(height * IN)
630631
self.add(self.base_circle)
631632

632633
self._rotate_to_direction()
@@ -656,6 +657,12 @@ def func(self, u: float, v: float) -> np.ndarray:
656657
],
657658
)
658659

660+
def get_start(self) -> np.ndarray:
661+
return self.start_point.get_center()
662+
663+
def get_end(self) -> np.ndarray:
664+
return self.end_point.get_center()
665+
659666
def _rotate_to_direction(self) -> None:
660667
x, y, z = self.direction
661668

@@ -710,6 +717,15 @@ def get_direction(self) -> np.ndarray:
710717
"""
711718
return self.direction
712719

720+
def _set_start_and_end_attributes(self, direction):
721+
normalized_direction = direction * np.linalg.norm(direction)
722+
723+
start = self.base_circle.get_center()
724+
end = start + normalized_direction * self.new_height
725+
self.start_point = VectorizedPoint(start)
726+
self.end_point = VectorizedPoint(end)
727+
self.add(self.start_point, self.end_point)
728+
713729

714730
class Cylinder(Surface):
715731
"""A cylinder, defined by its height, radius and direction,
@@ -1150,14 +1166,20 @@ def __init__(
11501166
self.end - height * self.direction,
11511167
**kwargs,
11521168
)
1153-
11541169
self.cone = Cone(
1155-
direction=self.direction, base_radius=base_radius, height=height, **kwargs
1170+
direction=self.direction,
1171+
base_radius=base_radius,
1172+
height=height,
1173+
**kwargs,
11561174
)
11571175
self.cone.shift(end)
1158-
self.add(self.cone)
1176+
self.end_point = VectorizedPoint(end)
1177+
self.add(self.end_point, self.cone)
11591178
self.set_color(color)
11601179

1180+
def get_end(self) -> np.ndarray:
1181+
return self.end_point.get_center()
1182+
11611183

11621184
class Torus(Surface):
11631185
"""A torus.

tests/test_graphical_units/test_threed.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ def test_Cone(scene):
3333
scene.add(Cone(resolution=16))
3434

3535

36+
def test_Cone_get_start_and_get_end():
37+
cone = Cone().shift(RIGHT).rotate(PI / 4, about_point=ORIGIN, about_edge=OUT)
38+
start = [0.70710678, 0.70710678, -1.0]
39+
end = [0.70710678, 0.70710678, 0.0]
40+
assert np.allclose(
41+
cone.get_start(), start, atol=0.01
42+
), "start points of Cone do not match"
43+
assert np.allclose(
44+
cone.get_end(), end, atol=0.01
45+
), "end points of Cone do not match"
46+
47+
3648
@frames_comparison(base_scene=ThreeDScene)
3749
def test_Cylinder(scene):
3850
scene.add(Cylinder())
@@ -149,3 +161,14 @@ def param_surface(u, v):
149161
axes=axes, colorscale=[(RED, -0.4), (YELLOW, 0), (GREEN, 0.4)], axis=1
150162
)
151163
scene.add(axes, surface_plane)
164+
165+
166+
def test_get_start_and_end_Arrow3d():
167+
start, end = ORIGIN, np.array([2, 1, 0])
168+
arrow = Arrow3D(start, end)
169+
assert np.allclose(
170+
arrow.get_start(), start, atol=0.01
171+
), "start points of Arrow3D do not match"
172+
assert np.allclose(
173+
arrow.get_end(), end, atol=0.01
174+
), "end points of Arrow3D do not match"

0 commit comments

Comments
 (0)