Skip to content

Commit 0851de5

Browse files
Revert "[ONNX] Remove beartype usage (pytorch#130484)"
This reverts commit 1794c35. Reverted pytorch#130484 on behalf of https://github.com/clee2000 due to test_sympy_utils failure is real https://github.com/pytorch/pytorch/actions/runs/9961499559/job/27523758780 https://hud.pytorch.org/pytorch/pytorch/commit/1794c35912025aa44b0d70f67ff664b4f7bd1014. Dr CI is matching with commits in current commit? ([comment](pytorch#130484 (comment)))
1 parent 09b1b11 commit 0851de5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1265
-107
lines changed

.ci/docker/common/install_onnx.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ retry () {
1010

1111
# A bunch of custom pip dependencies for ONNX
1212
pip_install \
13+
beartype==0.15.0 \
1314
filelock==3.9.0 \
1415
flatbuffers==2.0 \
1516
mock==5.0.1 \

test/onnx/dynamo/test_exporter_api.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33
import os
44

55
import onnx
6+
from beartype import roar
67

78
import torch
89
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
9-
from torch.onnx._internal import exporter
10+
from torch.onnx._internal import exporter, io_adapter
1011
from torch.onnx._internal.exporter import (
1112
LargeProtobufONNXProgramSerializer,
1213
ONNXProgramSerializer,
1314
ProtobufONNXProgramSerializer,
1415
ResolvedExportOptions,
1516
)
17+
from torch.onnx._internal.fx import diagnostics
1618

1719
from torch.testing._internal import common_utils
1820

@@ -47,6 +49,15 @@ def forward(self, x):
4749

4850

4951
class TestExportOptionsAPI(common_utils.TestCase):
52+
def test_raise_on_invalid_argument_type(self):
53+
expected_exception_type = roar.BeartypeException
54+
with self.assertRaises(expected_exception_type):
55+
ExportOptions(dynamic_shapes=2) # type: ignore[arg-type]
56+
with self.assertRaises(expected_exception_type):
57+
ExportOptions(diagnostic_options="DEBUG") # type: ignore[arg-type]
58+
with self.assertRaises(expected_exception_type):
59+
ResolvedExportOptions(options=12) # type: ignore[arg-type]
60+
5061
def test_dynamic_shapes_default(self):
5162
options = ResolvedExportOptions(ExportOptions())
5263
self.assertFalse(options.dynamic_shapes)
@@ -109,6 +120,7 @@ def test_save_to_file_using_specified_serializer_without_inheritance(self):
109120

110121
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
111122
# Because `ONNXProgramSerializer` is a Protocol class.
123+
# `beartype` will not complain.
112124
class CustomSerializer:
113125
def serialize(
114126
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
@@ -184,8 +196,27 @@ def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_i
184196
),
185197
)
186198

199+
def test_raise_on_invalid_save_argument_type(self):
200+
with self.assertRaises(roar.BeartypeException):
201+
ONNXProgram(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
202+
onnx_program = ONNXProgram(
203+
onnx.ModelProto(),
204+
io_adapter.InputAdapter(),
205+
io_adapter.OutputAdapter(),
206+
diagnostics.DiagnosticContext("test", "1.0"),
207+
fake_context=None,
208+
)
209+
with self.assertRaises(roar.BeartypeException):
210+
onnx_program.save(None) # type: ignore[arg-type]
211+
onnx_program.model_proto
212+
187213

188214
class TestProtobufONNXProgramSerializerAPI(common_utils.TestCase):
215+
def test_raise_on_invalid_argument_type(self):
216+
with self.assertRaises(roar.BeartypeException):
217+
serializer = ProtobufONNXProgramSerializer()
218+
serializer.serialize(None, None) # type: ignore[arg-type]
219+
189220
def test_serialize_raises_when_model_greater_than_2gb(self):
190221
onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
191222
serializer = ProtobufONNXProgramSerializer()

test/onnx/internal/test_beartype.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Owner(s): ["module: onnx"]
2+
"""Unit tests for the internal beartype wrapper module."""
3+
4+
import unittest
5+
6+
from torch.onnx._internal import _beartype
7+
from torch.testing._internal import common_utils
8+
9+
10+
def beartype_installed():
11+
try:
12+
import beartype # noqa: F401
13+
except ImportError:
14+
return False
15+
return True
16+
17+
18+
def skip_if_beartype_not_installed(test_case):
19+
return unittest.skipIf(not beartype_installed(), "beartype is not installed")(
20+
test_case
21+
)
22+
23+
24+
def func_with_type_hint(x: int) -> int:
25+
return x
26+
27+
28+
def func_with_incorrect_type_hint(x: int) -> str:
29+
return x # type: ignore[return-value]
30+
31+
32+
@common_utils.instantiate_parametrized_tests
33+
class TestBeartype(common_utils.TestCase):
34+
def test_create_beartype_decorator_returns_no_op_decorator_when_disabled(self):
35+
decorator = _beartype._create_beartype_decorator(
36+
_beartype.RuntimeTypeCheckState.DISABLED,
37+
)
38+
decorated = decorator(func_with_incorrect_type_hint)
39+
decorated("string_input") # type: ignore[arg-type]
40+
41+
@skip_if_beartype_not_installed
42+
def test_create_beartype_decorator_warns_when_warnings(self):
43+
decorator = _beartype._create_beartype_decorator(
44+
_beartype.RuntimeTypeCheckState.WARNINGS,
45+
)
46+
decorated = decorator(func_with_incorrect_type_hint)
47+
with self.assertWarns(_beartype.CallHintViolationWarning):
48+
decorated("string_input") # type: ignore[arg-type]
49+
50+
@common_utils.parametrize("arg", [1, "string_input"])
51+
@skip_if_beartype_not_installed
52+
def test_create_beartype_decorator_errors_when_errors(self, arg):
53+
import beartype
54+
55+
decorator = _beartype._create_beartype_decorator(
56+
_beartype.RuntimeTypeCheckState.ERRORS,
57+
)
58+
decorated = decorator(func_with_incorrect_type_hint)
59+
with self.assertRaises(beartype.roar.BeartypeCallHintViolation):
60+
decorated(arg)
61+
62+
@skip_if_beartype_not_installed
63+
def test_create_beartype_decorator_warning_calls_function_once(self):
64+
call_count = 0
65+
66+
def func_with_incorrect_type_hint_and_side_effect(x: int) -> str:
67+
nonlocal call_count
68+
call_count += 1
69+
return x # type: ignore[return-value]
70+
71+
decorator = _beartype._create_beartype_decorator(
72+
_beartype.RuntimeTypeCheckState.WARNINGS,
73+
)
74+
decorated = decorator(func_with_incorrect_type_hint_and_side_effect)
75+
decorated("string_input") # type: ignore[arg-type]
76+
self.assertEqual(call_count, 1)
77+
decorated(1)
78+
# The return value violates the type hint, but the function is called
79+
# only once.
80+
self.assertEqual(call_count, 2)
81+
82+
83+
if __name__ == "__main__":
84+
common_utils.run_tests()

test/onnx/onnx_test_common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import torch
3535
from torch import export as torch_export
3636
from torch.onnx import _constants, verification
37+
from torch.onnx._internal import _beartype
3738
from torch.onnx._internal.fx import diagnostics
3839
from torch.testing._internal import common_utils
3940
from torch.testing._internal.opinfo import core as opinfo_core
@@ -205,6 +206,7 @@ def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
205206
if not is_model_script and not self.is_script:
206207
_run_test(model, tracing_remained_onnx_input_idx)
207208

209+
@_beartype.beartype
208210
def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
209211
self,
210212
model: _ModelType,
@@ -358,6 +360,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
358360
)
359361

360362

363+
@_beartype.beartype
361364
def run_ort(
362365
onnx_model: Union[str, torch.onnx.ONNXProgram],
363366
pytorch_inputs: Sequence[_InputArgsType],
@@ -403,6 +406,7 @@ def run_ort(
403406
return session.run(None, ort_input)
404407

405408

409+
@_beartype.beartype
406410
def _try_clone_model(model: _ModelType) -> _ModelType:
407411
"""Used for preserving original model in case forward mutates model states."""
408412
try:
@@ -414,12 +418,14 @@ def _try_clone_model(model: _ModelType) -> _ModelType:
414418
return model
415419

416420

421+
@_beartype.beartype
417422
def _try_clone_inputs(input_args, input_kwargs):
418423
ref_input_args = copy.deepcopy(input_args)
419424
ref_input_kwargs = copy.deepcopy(input_kwargs)
420425
return ref_input_args, ref_input_kwargs
421426

422427

428+
@_beartype.beartype
423429
def _compare_pytorch_onnx_with_ort(
424430
onnx_program: torch.onnx.ONNXProgram,
425431
model: _ModelType,

test/onnx/test_fx_to_onnx_with_onnxruntime.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch import nn
2323

2424
from torch._subclasses import fake_tensor
25-
from torch.onnx._internal import exporter
25+
from torch.onnx._internal import _beartype, exporter
2626
from torch.onnx._internal.fx import (
2727
diagnostics,
2828
fx_symbolic_graph_extractor,
@@ -721,6 +721,7 @@ def forward(self, x):
721721
CustomModule(), (torch.randn(1, 2, 3),)
722722
)
723723

724+
@_beartype.beartype
724725
def _test_fx_symbolic_tracer_large_scale_exporter(
725726
self,
726727
model_name: str,
@@ -954,6 +955,7 @@ def setUp(self):
954955
super().setUp()
955956
self.ort_version = onnxruntime.__version__
956957

958+
@_beartype.beartype
957959
def _test_fake_tensor_mode_exporter(
958960
self,
959961
model_name: str,

torch/onnx/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.
3535
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
3636
transparently dispatched to their non inplace versions in
3737
"run_symbolic_function". See Note [Export inplace](#export-inplace)
38+
- Required: Annotate new symbolic functions with type annotations and decorate
39+
with `@_beartype.beartype` to enable runtime type checking.
40+
`@_beartype.beartype` should typically be the closest to the function to
41+
ensure proper type checking.
3842

3943
### A note on Tensor types
4044

torch/onnx/_internal/_beartype.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# mypy: allow-untyped-defs
2+
"""An internal wrapper for the beartype library.
3+
4+
The module returns a no-op decorator when the beartype library is not installed.
5+
"""
6+
import enum
7+
import functools
8+
import os
9+
import traceback
10+
import typing
11+
import warnings
12+
from types import ModuleType
13+
14+
try:
15+
import beartype as _beartype_lib # type: ignore[import]
16+
from beartype import roar as _roar # type: ignore[import]
17+
18+
# Beartype warns when we import from typing because the types are deprecated
19+
# in Python 3.9. But there will be a long time until we can move to using
20+
# the native container types for type annotations (when 3.9 is the lowest
21+
# supported version). So we silence the warning.
22+
warnings.filterwarnings(
23+
"ignore",
24+
category=_roar.BeartypeDecorHintPep585DeprecationWarning,
25+
)
26+
27+
if _beartype_lib.__version__ == "0.16.0":
28+
# beartype 0.16.0 has a bug that causes it to crash when used with
29+
# PyTorch. See https://github.com/beartype/beartype/issues/282
30+
warnings.warn("beartype 0.16.0 is not supported. Please upgrade to 0.16.1+.")
31+
_beartype_lib = None # type: ignore[assignment]
32+
except ImportError:
33+
_beartype_lib = None # type: ignore[assignment]
34+
except Exception as e:
35+
# Warn errors that are not import errors (unexpected).
36+
warnings.warn(f"{e}")
37+
_beartype_lib = None # type: ignore[assignment]
38+
39+
40+
@enum.unique
41+
class RuntimeTypeCheckState(enum.Enum):
42+
"""Runtime type check state."""
43+
44+
# Runtime type checking is disabled.
45+
DISABLED = enum.auto()
46+
# Runtime type checking is enabled but warnings are shown only.
47+
WARNINGS = enum.auto()
48+
# Runtime type checking is enabled.
49+
ERRORS = enum.auto()
50+
51+
52+
class CallHintViolationWarning(UserWarning):
53+
"""Warning raised when a type hint is violated during a function call."""
54+
55+
pass
56+
57+
58+
def _no_op_decorator(func):
59+
return func
60+
61+
62+
def _create_beartype_decorator(
63+
runtime_check_state: RuntimeTypeCheckState,
64+
):
65+
# beartype needs to be imported outside of the function and aliased because
66+
# this module overwrites the name "beartype".
67+
68+
if runtime_check_state == RuntimeTypeCheckState.DISABLED:
69+
return _no_op_decorator
70+
if _beartype_lib is None:
71+
# If the beartype library is not installed, return a no-op decorator
72+
return _no_op_decorator
73+
74+
assert isinstance(_beartype_lib, ModuleType)
75+
76+
if runtime_check_state == RuntimeTypeCheckState.ERRORS:
77+
# Enable runtime type checking which errors on any type hint violation.
78+
return _beartype_lib.beartype
79+
80+
# Warnings only
81+
def beartype(func):
82+
"""Warn on type hint violation."""
83+
84+
if "return" in func.__annotations__:
85+
# Remove the return type from the func function's
86+
# annotations so that the beartype decorator does not complain
87+
# about the return type.
88+
return_type = func.__annotations__["return"]
89+
del func.__annotations__["return"]
90+
beartyped = _beartype_lib.beartype(func)
91+
# Restore the return type to the func function's annotations
92+
func.__annotations__["return"] = return_type
93+
else:
94+
beartyped = _beartype_lib.beartype(func)
95+
96+
@functools.wraps(func)
97+
def _coerce_beartype_exceptions_to_warnings(*args, **kwargs):
98+
try:
99+
return beartyped(*args, **kwargs)
100+
except _roar.BeartypeCallHintParamViolation:
101+
# Fall back to the original function if the beartype hint is violated.
102+
warnings.warn(
103+
traceback.format_exc(),
104+
category=CallHintViolationWarning,
105+
stacklevel=2,
106+
)
107+
108+
return func(*args, **kwargs) # noqa: B012
109+
110+
return _coerce_beartype_exceptions_to_warnings
111+
112+
return beartype
113+
114+
115+
if typing.TYPE_CHECKING:
116+
# This is a hack to make mypy play nicely with the beartype decorator.
117+
def beartype(func):
118+
return func
119+
120+
else:
121+
_TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK = os.getenv(
122+
"TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK"
123+
)
124+
if _TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK == "ERRORS":
125+
_runtime_type_check_state = RuntimeTypeCheckState.ERRORS
126+
elif _TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK == "DISABLED":
127+
_runtime_type_check_state = RuntimeTypeCheckState.DISABLED
128+
else:
129+
_runtime_type_check_state = RuntimeTypeCheckState.WARNINGS
130+
beartype = _create_beartype_decorator(_runtime_type_check_state)
131+
# Make sure that the beartype decorator is enabled whichever path we took.
132+
assert beartype is not None

0 commit comments

Comments
 (0)