Skip to content

Commit 1794c35

Browse files
justinchubypytorchmergebot
authored andcommitted
[ONNX] Remove beartype usage (pytorch#130484)
beartype has served us well in identifying type errors and ensuring we call internal functions with the correct arguments (thanks!). However, the value of having beartype is diminished because of the following: 1. When beartype improves support for better Dict[] type checking, it discovered typing mistakes in some functions that were previously uncaught. This caused the exporter to fail with newer versions beartype when it used to succeed. Since we cannot fix PyTorch and release a new version just because of this, it creates confusion for users that have beartype in their environment from using torch.onnx 2. beartype adds an additional call line in the traceback, which makes the already thick dynamo stack even larger, affecting readability when users diagnose errors with the traceback. 3. Since the typing annotations need to be evaluated, we cannot use new syntaxes like `|` because we need to maintain compatibility with Python 3.8. We don't want to wait for PyTorch take py310 as the lowest supported Python before using the new typing syntaxes. Pull Request resolved: pytorch#130484 Approved by: https://github.com/titaiwangms
1 parent 67e22d6 commit 1794c35

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+107
-1265
lines changed

.ci/docker/common/install_onnx.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ retry () {
1010

1111
# A bunch of custom pip dependencies for ONNX
1212
pip_install \
13-
beartype==0.15.0 \
1413
filelock==3.9.0 \
1514
flatbuffers==2.0 \
1615
mock==5.0.1 \

test/onnx/dynamo/test_exporter_api.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,16 @@
33
import os
44

55
import onnx
6-
from beartype import roar
76

87
import torch
98
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
10-
from torch.onnx._internal import exporter, io_adapter
9+
from torch.onnx._internal import exporter
1110
from torch.onnx._internal.exporter import (
1211
LargeProtobufONNXProgramSerializer,
1312
ONNXProgramSerializer,
1413
ProtobufONNXProgramSerializer,
1514
ResolvedExportOptions,
1615
)
17-
from torch.onnx._internal.fx import diagnostics
1816

1917
from torch.testing._internal import common_utils
2018

@@ -49,15 +47,6 @@ def forward(self, x):
4947

5048

5149
class TestExportOptionsAPI(common_utils.TestCase):
52-
def test_raise_on_invalid_argument_type(self):
53-
expected_exception_type = roar.BeartypeException
54-
with self.assertRaises(expected_exception_type):
55-
ExportOptions(dynamic_shapes=2) # type: ignore[arg-type]
56-
with self.assertRaises(expected_exception_type):
57-
ExportOptions(diagnostic_options="DEBUG") # type: ignore[arg-type]
58-
with self.assertRaises(expected_exception_type):
59-
ResolvedExportOptions(options=12) # type: ignore[arg-type]
60-
6150
def test_dynamic_shapes_default(self):
6251
options = ResolvedExportOptions(ExportOptions())
6352
self.assertFalse(options.dynamic_shapes)
@@ -120,7 +109,6 @@ def test_save_to_file_using_specified_serializer_without_inheritance(self):
120109

121110
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
122111
# Because `ONNXProgramSerializer` is a Protocol class.
123-
# `beartype` will not complain.
124112
class CustomSerializer:
125113
def serialize(
126114
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
@@ -196,27 +184,8 @@ def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_i
196184
),
197185
)
198186

199-
def test_raise_on_invalid_save_argument_type(self):
200-
with self.assertRaises(roar.BeartypeException):
201-
ONNXProgram(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
202-
onnx_program = ONNXProgram(
203-
onnx.ModelProto(),
204-
io_adapter.InputAdapter(),
205-
io_adapter.OutputAdapter(),
206-
diagnostics.DiagnosticContext("test", "1.0"),
207-
fake_context=None,
208-
)
209-
with self.assertRaises(roar.BeartypeException):
210-
onnx_program.save(None) # type: ignore[arg-type]
211-
onnx_program.model_proto
212-
213187

214188
class TestProtobufONNXProgramSerializerAPI(common_utils.TestCase):
215-
def test_raise_on_invalid_argument_type(self):
216-
with self.assertRaises(roar.BeartypeException):
217-
serializer = ProtobufONNXProgramSerializer()
218-
serializer.serialize(None, None) # type: ignore[arg-type]
219-
220189
def test_serialize_raises_when_model_greater_than_2gb(self):
221190
onnx_program = torch.onnx.dynamo_export(_LargeModel(), torch.randn(1))
222191
serializer = ProtobufONNXProgramSerializer()

test/onnx/internal/test_beartype.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

test/onnx/onnx_test_common.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import torch
3535
from torch import export as torch_export
3636
from torch.onnx import _constants, verification
37-
from torch.onnx._internal import _beartype
3837
from torch.onnx._internal.fx import diagnostics
3938
from torch.testing._internal import common_utils
4039
from torch.testing._internal.opinfo import core as opinfo_core
@@ -206,7 +205,6 @@ def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
206205
if not is_model_script and not self.is_script:
207206
_run_test(model, tracing_remained_onnx_input_idx)
208207

209-
@_beartype.beartype
210208
def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
211209
self,
212210
model: _ModelType,
@@ -360,7 +358,6 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
360358
)
361359

362360

363-
@_beartype.beartype
364361
def run_ort(
365362
onnx_model: Union[str, torch.onnx.ONNXProgram],
366363
pytorch_inputs: Sequence[_InputArgsType],
@@ -406,7 +403,6 @@ def run_ort(
406403
return session.run(None, ort_input)
407404

408405

409-
@_beartype.beartype
410406
def _try_clone_model(model: _ModelType) -> _ModelType:
411407
"""Used for preserving original model in case forward mutates model states."""
412408
try:
@@ -418,14 +414,12 @@ def _try_clone_model(model: _ModelType) -> _ModelType:
418414
return model
419415

420416

421-
@_beartype.beartype
422417
def _try_clone_inputs(input_args, input_kwargs):
423418
ref_input_args = copy.deepcopy(input_args)
424419
ref_input_kwargs = copy.deepcopy(input_kwargs)
425420
return ref_input_args, ref_input_kwargs
426421

427422

428-
@_beartype.beartype
429423
def _compare_pytorch_onnx_with_ort(
430424
onnx_program: torch.onnx.ONNXProgram,
431425
model: _ModelType,

test/onnx/test_fx_to_onnx_with_onnxruntime.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch import nn
2323

2424
from torch._subclasses import fake_tensor
25-
from torch.onnx._internal import _beartype, exporter
25+
from torch.onnx._internal import exporter
2626
from torch.onnx._internal.fx import (
2727
diagnostics,
2828
fx_symbolic_graph_extractor,
@@ -721,7 +721,6 @@ def forward(self, x):
721721
CustomModule(), (torch.randn(1, 2, 3),)
722722
)
723723

724-
@_beartype.beartype
725724
def _test_fx_symbolic_tracer_large_scale_exporter(
726725
self,
727726
model_name: str,
@@ -955,7 +954,6 @@ def setUp(self):
955954
super().setUp()
956955
self.ort_version = onnxruntime.__version__
957956

958-
@_beartype.beartype
959957
def _test_fake_tensor_mode_exporter(
960958
self,
961959
model_name: str,

torch/onnx/README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.
3535
`_jit_pass_onnx_remove_inplace_ops_for_onnx`, and
3636
transparently dispatched to their non inplace versions in
3737
"run_symbolic_function". See Note [Export inplace](#export-inplace)
38-
- Required: Annotate new symbolic functions with type annotations and decorate
39-
with `@_beartype.beartype` to enable runtime type checking.
40-
`@_beartype.beartype` should typically be the closest to the function to
41-
ensure proper type checking.
4238

4339
### A note on Tensor types
4440

torch/onnx/_internal/_beartype.py

Lines changed: 0 additions & 132 deletions
This file was deleted.

0 commit comments

Comments
 (0)