
Commit bd19b41

chore(//py/torch_tensorrt/dynamo): Tracer mypy compliance
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent: 0e9def4

File tree: 3 files changed, +43 -6 lines changed

py/torch_tensorrt/dynamo/aten_tracer.py

Lines changed: 7 additions & 5 deletions

@@ -6,8 +6,10 @@
 
 import torch
 import torch._dynamo as torchdynamo
+from torch import _guards
+from torch.fx.passes.infra.pass_base import PassResult
 
-from torch_tensorrt.fx.utils import req_torch_version
+from torch_tensorrt.dynamo.utils import req_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
@@ -97,7 +99,7 @@ def dynamo_trace(
     aten_graph: bool,
     tracing_mode: str = "real",
     dynamo_config: Optional[DynamoConfig] = None,
-) -> Tuple[torch.fx.GraphModule, Set]:
+) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
     """
     TODO: Once we fully migrate to torchdynamo frontend, we will remove
     this config option alltogether. For now, it helps with quick
@@ -126,7 +128,7 @@ def dynamo_trace(
 
 
 @req_torch_version("2.dev")
-def trace(model, inputs, **kwargs):
+def trace(model: Union[torch.nn.Module, torch.fx.GraphModule], inputs: Tuple[Any, ...], **kwargs: Any) -> torch.fx.GraphModule:
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
     These passes should be general and functional purpose
@@ -147,11 +149,11 @@ def trace(model, inputs, **kwargs):
     fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
     print(fx_module.graph)
     for passes in passes_list:
-        pr: PassResult = passes(fx_module)
+        pr: PassResult = passes(fx_module)  # type: ignore[assignment]  # The type hints in fx are wrong
         fx_module = pr.graph_module
 
     fx_module(*inputs)
 
     fx_module = run_const_fold(fx_module)
     print(fx_module.graph)
-    return fx_module
+    return fx_module  # type: ignore[no-any-return]
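For orientation, a minimal sketch of how the now fully annotated trace() entry point might be called; the toy module and input shapes below are illustrative assumptions, not part of the commit:

import torch
from torch_tensorrt.dynamo import aten_tracer

# Hypothetical toy model; any traceable nn.Module would do.
class MatMul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y)

inputs = (torch.randn(4, 8), torch.randn(8, 4))
# With this commit, trace() advertises torch.fx.GraphModule as its
# return type, so the annotation below type-checks under mypy.
gm: torch.fx.GraphModule = aten_tracer.trace(MatMul(), inputs)
print(gm.graph)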

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 0 deletions

@@ -139,6 +139,8 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )
 
+        assert submodule_inputs is not None
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,
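The added assert is the standard way to make mypy narrow an Optional value before use. A self-contained sketch of the pattern, using stand-in names rather than the real get_submod_inputs() helper:

from typing import Optional, Sequence

import torch

def get_sample_inputs(available: bool) -> Optional[Sequence[torch.Tensor]]:
    # Stand-in for a helper that may return None.
    return (torch.randn(2, 2),) if available else None

inputs = get_sample_inputs(available=True)
# mypy sees Optional[Sequence[torch.Tensor]] here; the assert narrows
# it to Sequence[torch.Tensor], so the indexing below type-checks.
assert inputs is not None
print(inputs[0].shape)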

py/torch_tensorrt/dynamo/utils.py

Lines changed: 34 additions & 1 deletion

@@ -4,7 +4,8 @@
 from torch_tensorrt.dynamo import CompilationSettings
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import Input, Device
-from typing import Optional
+from typing import Optional, Callable, Any
+from packaging import version
 
 logger = logging.getLogger(__name__)
 
@@ -160,3 +161,35 @@ def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
     logger.debug(f"Compiling with Settings:\n{settings}")
 
     return settings
+
+def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]:
+    """
+    Create a decorator which verifies the Torch version installed
+    against a specified version range
+
+    Args:
+        min_torch_version (str): The minimum required Torch version
+        for the decorated function to work properly
+
+    Returns:
+        A decorator which raises a descriptive error message if
+        an unsupported Torch version is used
+    """
+
+    def nested_decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+        def function_wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
+            # Parse minimum and current Torch versions
+            min_version = version.parse(min_torch_version)
+            current_version = version.parse(torch.__version__)
+
+            if current_version < min_version:
+                raise AssertionError(
+                    f"Expected Torch version {min_torch_version} or greater, "
+                    + f"when calling {f}. Detected version {torch.__version__}"
+                )
+            else:
+                return f(*args, **kwargs)
+
+        return function_wrapper
+
+    return nested_decorator
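A usage sketch of the relocated req_torch_version decorator; the decorated helper below is hypothetical, not part of the commit:

import torch
from torch_tensorrt.dynamo.utils import req_torch_version

@req_torch_version("2.dev")
def build_graph_module(model: torch.nn.Module) -> torch.fx.GraphModule:
    # Hypothetical helper: the decorator raises AssertionError on
    # Torch versions older than 2.dev before this body ever runs.
    return torch.fx.symbolic_trace(model)

gm = build_graph_module(torch.nn.ReLU())
print(type(gm))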
