
Commit 884e0ea

chore(//py/torch_tensorrt/dynamo/conversion): mypy conforming
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 92c4355 commit 884e0ea

File tree: 3 files changed (+47 / -36 lines)

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py renamed to py/torch_tensorrt/dynamo/conversion/TRTInterpreter.py

Lines changed: 44 additions & 33 deletions
@@ -2,7 +2,7 @@
 import warnings
 from datetime import datetime
 from packaging import version
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

 import numpy

@@ -41,12 +41,13 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         input_specs: List[Input],
-        logger_level=None,
-        output_dtypes=None,
+        logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
+        output_dtypes: Optional[List[torch.dtype]] = None,
     ):
         super().__init__(module)

-        self.logger = trt.Logger(logger_level or trt.Logger.WARNING)
+        # TODO: @narendasan replace with Torch-TensorRT Logger
+        self.logger = trt.Logger(logger_level)
         self.builder = trt.Builder(self.logger)

         flag = 0
@@ -59,12 +60,13 @@ def __init__(

         missing_ops = self.validate_conversion()
         if missing_ops:
+            # TODO: @narendasan make sure to set logging.captureWarnings(True)
             warnings.warn(
                 "Interpretation will fail due to missing operations \n"
                 + "\n".join(f"{i}" for i in missing_ops)
             )

-        self.optimization_profiles = (
+        self.optimization_profiles: Optional[List[trt.IOptimizationProfile]] = (
             [self.builder.create_optimization_profile()]
             if any(
                 input_spec.shape_mode == Input._ShapeMode.DYNAMIC
@@ -86,8 +88,8 @@ def __init__(
         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes

-    def validate_conversion(self):
-        missing_converter = set()
+    def validate_conversion(self) -> Set[str]:
+        missing_converters = set()

         for node in self.module.graph.nodes:
             if node.op == "call_function" and not CONVERTERS.get(node):
@@ -98,25 +100,25 @@ def validate_conversion(self):
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
                 if not CONVERTERS.get(node):
-                    missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
+                    missing_converter.add(f"{node.op} {torch.typename(submod_type)}")  # type: ignore[no-untyped-call]

-        return missing_converter
+        return missing_converters

     def run(
         self,
-        workspace_size=0,
-        precision=torch.float32,
-        sparse_weights=False,
-        disable_tf32=False,
-        force_fp32_output=False,
-        strict_type_constraints=False,
-        algorithm_selector=None,
-        timing_cache=None,
-        profiling_verbosity=None,
-        tactic_sources=None,
-        max_aux_streams=None,
-        version_compatible=False,
-        optimization_level=None,
+        workspace_size: int = 0,
+        precision: torch.dtype = torch.float32,  # TODO: @peri044 Needs to be expanded to set
+        sparse_weights: bool = False,
+        disable_tf32: bool = False,
+        force_fp32_output: bool = False,
+        strict_type_constraints: bool = False,
+        algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
+        timing_cache: Optional[trt.ITimingCache] = None,
+        profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
+        tactic_sources: Optional[int] = None,
+        max_aux_streams: Optional[int] = None,
+        version_compatible: bool = False,
+        optimization_level: Optional[int] = None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -204,7 +206,7 @@ def run(
         if strict_type_constraints:
            builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

-        if self.optimization_profiles:
+        if len(self.optimization_profiles) > 0:
             for optimization_profile in self.optimization_profiles:
                 builder_config.add_optimization_profile(optimization_profile)

@@ -232,7 +234,7 @@ def run(
             engine, self._input_names, self._output_names, serialized_cache
         )

-    def run_node(self, n):
+    def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         self._cur_node_name = str(n)
         self._cur_node = n
         # add "_itensor_to_tensor_meta"
@@ -241,29 +243,31 @@ def run_node(self, n):
         n.kwargs = kwargs

         # run the node
-        trt_node = super().run_node(n)
+        trt_node: torch.fx.Node = super().run_node(n)

         # remove "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         del kwargs["_itensor_to_tensor_meta"]
         n.kwargs = kwargs

         if isinstance(trt_node, trt.tensorrt.ITensor):
-            self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")
+            self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")  # type: ignore[assignment]

         return trt_node

-    def placeholder(self, target, args, kwargs):
+    def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
         self._input_names.append(target)
         current_input = self.input_specs[self.input_specs_iter]
         self.input_specs_iter += 1
         # Set optimization profile for dynamic input shape
-        shape = current_input.shape
+        shape = None
         if current_input.shape_mode == Input._ShapeMode.DYNAMIC:
+            assert isinstance(current_input.shape, dict)
             shape = []
             min_shape = current_input.shape["min_shape"]
             opt_shape = current_input.shape["opt_shape"]
             max_shape = current_input.shape["max_shape"]
+            # TODO: Does not support disjoint optimization profiles?
             self.optimization_profiles[0].set_shape(
                 target, min_shape, opt_shape, max_shape
             )
@@ -274,14 +278,20 @@ def placeholder(self, target, args, kwargs):
                 else:
                     # -1 to represent the dynamic dimension
                     shape.append(-1)
+        elif current_input.shape_mode == Input._ShapeMode.STATIC:
+            assert isinstance(current_input.shape, tuple)
+            shape = list(current_input.shape)
+        else:
+            raise RuntimeError(f"Unable to access shape spec for input: {target} (got: {current_input})")
+

         return self.network.add_input(
             name=target,
             shape=tuple(shape),
             dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
         )

-    def call_module(self, target, args, kwargs):
+    def call_module(self, target: str, args: Any, kwargs: Any) -> Any:  # Probably should be Tuple[trt.ITensor]? Case for Any?
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
@@ -295,17 +305,18 @@ def call_module(self, target, args, kwargs):
         assert self._cur_node_name is not None
         return converter(self.network, submod, args, kwargs, self._cur_node_name)

-    def call_function(self, target, args, kwargs):
+    def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
+        # TODO: Why is this stateful? We should be able to take in the inputs
         converter = CONVERTERS.get(self._cur_node)
         if not converter:
             raise RuntimeError(
-                f"Conversion of function {torch.typename(target)} not currently supported!"
+                f"Conversion of function {torch.typename(target)} not currently supported!"  # type: ignore[no-untyped-call]
             )

         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)

-    def call_method(self, target, args, kwargs):
+    def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
         converter = CONVERTERS.get(self._cur_node)

@@ -317,7 +328,7 @@ def call_method(self, target, args, kwargs):
         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)

-    def output(self, target, args, kwargs):
+    def output(self, target: str, args: Any, kwargs: Any) -> None:
         assert len(args) == 1
         if isinstance(args[0], tuple):
             outputs = args[0]
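
To make the new annotations concrete, here is a minimal, hypothetical call-site sketch (not part of this commit). The toy module, the static Input spec, and the chosen logger level are assumptions; the sketch only shows how the annotated constructor, validate_conversion() -> Set[str], and run() signatures line up so mypy can check the call site.

    # Illustrative sketch only; the model and Input spec below are assumptions.
    import tensorrt as trt
    import torch
    from torch.fx import symbolic_trace

    from torch_tensorrt import Input
    from torch_tensorrt.dynamo.conversion import TRTInterpreter


    class ReLUOnly(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x)


    # Trace to a GraphModule; in the dynamo path this would normally come from the compiler.
    graph_module = symbolic_trace(ReLUOnly())

    # logger_level and output_dtypes now carry explicit annotations, so this call is type-checked.
    interpreter = TRTInterpreter(
        graph_module,
        input_specs=[Input(shape=(1, 3, 224, 224), dtype=torch.float32)],
        logger_level=trt.ILogger.Severity.WARNING,
        output_dtypes=[torch.float32],
    )

    # validate_conversion() is now annotated to return Set[str]: the unsupported ops, if any.
    missing = interpreter.validate_conversion()
    if not missing:
        # run() builds the engine and returns a TRTInterpreterResult.
        result = interpreter.run(precision=torch.float32)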
py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .aten_ops_converters import *
-from .trt_interpreter import *
+from .TRTInterpreter import *
 from .conversion import *
 from .truncate_long_and_double import repair_long_or_double_inputs
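
The file rename also changes the fully qualified module path. A hypothetical call site that imported the interpreter module directly (illustrative only, assuming the class keeps its TRTInterpreter name) would now read:

    # Before: from torch_tensorrt.dynamo.conversion.trt_interpreter import TRTInterpreter
    from torch_tensorrt.dynamo.conversion.TRTInterpreter import TRTInterpreter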

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from typing import Sequence, Union
 import torch
 import io
-from torch_tensorrt.dynamo.runtime import _PythonTorchTRTModule
+from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt import Input
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
@@ -15,7 +15,7 @@ def convert_module(
     inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
     name: str = "",
-):
+) -> Union[PythonTorchTensorRTModule, TorchTensorRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
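
A hedged sketch of a convert_module call site (not part of this commit): the toy module, sample inputs, and settings values are assumptions, and it simply shows that the result is now annotated as Union[PythonTorchTensorRTModule, TorchTensorRTModule] rather than an implicit Any.

    # Illustrative sketch only; model, inputs, and settings are assumptions.
    import torch
    from torch.fx import symbolic_trace

    from torch_tensorrt.dynamo import CompilationSettings
    from torch_tensorrt.dynamo.conversion import convert_module


    class AddOne(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1.0


    gm = symbolic_trace(AddOne())
    sample_inputs = [torch.randn(2, 4)]

    # Downstream code can narrow the annotated Union before touching runtime-specific attributes.
    trt_module = convert_module(gm, sample_inputs, settings=CompilationSettings(), name="add_one")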
