Skip to content

Commit 1b03650

Browse files
committed
feat: Add decorator to improve legacy support
- Add general-purpose utility for intercepting invalid Torch version imports and displaying an error message to the user, halting compilation - Utility automatically parses string inputs to semantic versions using the `packaging` utility package - Decorator can accept variable input versions and checks whether the current version in use in the environment is at least as large as the minimum specified version - Display clear error message upon calling the aten tracer when using a legacy Torch version, for which the Dynamo import is invalid
1 parent fce0a01 commit 1b03650

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch._dynamo as torchdynamo
1010

1111
from torch.fx.passes.infra.pass_base import PassResult
12-
12+
from torch_tensorrt.fx.utils import req_torch_version
1313
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
1414
compose_bmm,
1515
compose_chunk,
@@ -91,6 +91,7 @@ def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None,
9191
sys.setrecursionlimit(default)
9292

9393

94+
@req_torch_version("2.0")
9495
def dynamo_trace(
9596
f: Callable[..., Value],
9697
# pyre-ignore
@@ -104,11 +105,6 @@ def dynamo_trace(
104105
this config option altogether. For now, it helps with quick
105106
experiments with playing around with TorchDynamo
106107
"""
107-
if torch.__version__.startswith("1"):
108-
raise ValueError(
109-
f"The aten tracer requires Torch version >= 2.0. Detected version {torch.__version__}"
110-
)
111-
112108
if dynamo_config is None:
113109
dynamo_config = DynamoConfig()
114110
with using_config(dynamo_config), setting_python_recursive_limit(2000):
@@ -131,11 +127,13 @@ def dynamo_trace(
131127
) from exc
132128

133129

130+
@req_torch_version("2.0")
def trace(f, args, *rest):
    """Symbolically trace ``f`` via TorchDynamo.

    Thin convenience wrapper around ``dynamo_trace`` with
    ``aten_graph=True`` and ``tracing_mode="symbolic"``.

    Returns:
        Tuple of ``(graph_module, guards)`` as produced by ``dynamo_trace``.
    """
    gm, guard_set = dynamo_trace(f, args, True, "symbolic")
    return gm, guard_set
137134

138135

136+
@req_torch_version("2.0")
139137
def opt_trace(f, args, *rest):
140138
"""
141139
Optimized trace with necessary passes which re-compose some ops or replace some ops

py/torch_tensorrt/fx/utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from enum import Enum
2-
from typing import List
2+
from typing import List, Callable
3+
from packaging import version
34

45
# @manual=//deeplearning/trt/python:py_tensorrt
56
import tensorrt as trt
@@ -104,3 +105,36 @@ def f(*inp):
104105
mod = run_const_fold(mod)
105106
mod = replace_op_with_indices(mod)
106107
return mod
108+
109+
110+
def req_torch_version(min_torch_version: str = "2.0"):
    """
    Create a decorator which verifies the Torch version installed
    against a specified version range

    Args:
        min_torch_version (str): The minimum required Torch version
        for the decorated function to work properly

    Returns:
        A decorator which raises a descriptive error message if
        an unsupported Torch version is used
    """
    # Local import keeps this utility self-contained.
    from functools import wraps

    def nested_decorator(f: Callable):
        # wraps preserves the wrapped function's metadata (__name__,
        # __doc__, etc.) so tracebacks and introspection point at the
        # real function rather than at function_wrapper.
        @wraps(f)
        def function_wrapper(*args, **kwargs):
            # Parse minimum and current Torch versions on each call.
            # version.parse handles suffixed versions such as
            # "2.1.0.dev20230401" correctly, unlike raw string comparison.
            min_version = version.parse(min_torch_version)
            current_version = version.parse(torch.__version__)

            if current_version < min_version:
                raise AssertionError(
                    f"Expected Torch version {min_torch_version} or greater, "
                    + f"when calling {f}. Detected version {torch.__version__}"
                )
            # Version requirement satisfied: delegate to the wrapped function.
            return f(*args, **kwargs)

        return function_wrapper

    return nested_decorator

0 commit comments

Comments
 (0)