Skip to content

Commit e160a30

Browse files
committed
chore(//py/torch_tensorrt/dynamo/conversion): mypy conforming
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c8719d8 commit e160a30

File tree

3 files changed

+51
-40
lines changed

3 files changed

+51
-40
lines changed

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py renamed to py/torch_tensorrt/dynamo/conversion/TRTInterpreter.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import warnings
33
from datetime import datetime
44
from packaging import version
5-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
5+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
66

77
import numpy
88

99
# @manual=//deeplearning/trt/python:py_tensorrt
1010
import tensorrt as trt
1111
import torch
1212
import torch.fx
13+
from torch.fx.node import Target
1314
from torch._ops import OpOverload
1415
from torch.fx.node import _get_qualified_name
1516
from torch.fx.passes.shape_prop import TensorMetadata
@@ -42,12 +43,13 @@ def __init__(
4243
self,
4344
module: torch.fx.GraphModule,
4445
input_specs: List[Input],
45-
logger_level=None,
46-
output_dtypes=None,
46+
logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
47+
output_dtypes: Optional[List[torch.dtype]] = None,
4748
):
4849
super().__init__(module)
4950

50-
self.logger = trt.Logger(logger_level or trt.Logger.WARNING)
51+
# TODO: @narendasan replace with Torch-TensorRT Logger
52+
self.logger = trt.Logger(logger_level)
5153
self.builder = trt.Builder(self.logger)
5254

5355
flag = 0
@@ -60,12 +62,13 @@ def __init__(
6062

6163
missing_ops = self.validate_conversion()
6264
if missing_ops:
65+
# TODO: @narendasan make sure to set logging.captureWarnings(True)
6366
warnings.warn(
6467
"Interpretation will fail due to missing operations \n"
6568
+ "\n".join(f"{i}" for i in missing_ops)
6669
)
6770

68-
self.optimization_profiles: Optional[List] = None
71+
self.optimization_profiles: List[trt.IOptimizationProfile] = []
6972
self.input_specs = input_specs
7073
self.input_specs_iter = 0
7174
self._cur_node_name: Optional[str] = None
@@ -78,37 +81,37 @@ def __init__(
7881
# Data types for TRT Module output Tensors
7982
self.output_dtypes = output_dtypes
8083

81-
def validate_conversion(self):
82-
missing_converter = set()
84+
def validate_conversion(self) -> Set[str]:
85+
missing_converters = set()
8386

8487
for node in self.module.graph.nodes:
8588
if node.op == "call_function" and not CONVERTERS.get(node.target):
86-
missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
89+
missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}")
8790
elif node.op == "call_method" and not CONVERTERS.get(node.target):
88-
missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
91+
missing_converters.add(f"{node.op} torch.Tensor.{node.target}")
8992
elif node.op == "call_module":
9093
submod = self.fetch_attr(node.target)
9194
submod_type = getattr(submod, "_base_class_origin", type(submod))
9295
if not CONVERTERS.get(submod_type):
93-
missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
96+
missing_converters.add(f"{node.op} {torch.typename(submod_type)}") # type: ignore[no-untyped-call]
9497

95-
return missing_converter
98+
return missing_converters
9699

97100
def run(
98101
self,
99-
workspace_size=0,
100-
precision=torch.float32,
101-
sparse_weights=False,
102-
disable_tf32=False,
103-
force_fp32_output=False,
104-
strict_type_constraints=False,
105-
algorithm_selector=None,
106-
timing_cache=None,
107-
profiling_verbosity=None,
108-
tactic_sources=None,
109-
max_aux_streams=None,
110-
version_compatible=False,
111-
optimization_level=None,
102+
workspace_size: int = 0,
103+
precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set
104+
sparse_weights: bool = False,
105+
disable_tf32: bool = False,
106+
force_fp32_output: bool = False,
107+
strict_type_constraints: bool = False,
108+
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
109+
timing_cache: Optional[trt.ITimingCache] =None,
110+
profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
111+
tactic_sources: Optional[int] = None,
112+
max_aux_streams: Optional[int] = None,
113+
version_compatible: bool = False,
114+
optimization_level: Optional[int] = None,
112115
) -> TRTInterpreterResult:
113116
"""
114117
Build TensorRT engine with some configs.
@@ -196,7 +199,7 @@ def run(
196199
if strict_type_constraints:
197200
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
198201

199-
if self.optimization_profiles:
202+
if len(self.optimization_profiles) > 0:
200203
for optimization_profile in self.optimization_profiles:
201204
builder_config.add_optimization_profile(optimization_profile)
202205

@@ -224,55 +227,63 @@ def run(
224227
engine, self._input_names, self._output_names, serialized_cache
225228
)
226229

227-
def run_node(self, n):
230+
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
228231
self._cur_node_name = str(n)
229232
# add "_itensor_to_tensor_meta"
230233
kwargs = dict(n.kwargs)
231234
kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
232235
n.kwargs = kwargs
233236

234237
# run the node
235-
trt_node = super().run_node(n)
238+
trt_node: torch.fx.Node = super().run_node(n)
236239

237240
# remove "_itensor_to_tensor_meta"
238241
kwargs = dict(n.kwargs)
239242
del kwargs["_itensor_to_tensor_meta"]
240243
n.kwargs = kwargs
241244

242245
if isinstance(trt_node, trt.tensorrt.ITensor):
243-
self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")
246+
self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") #type: ignore[assignment]
244247

245248
return trt_node
246249

247-
def placeholder(self, target, args, kwargs):
250+
def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
248251
self._input_names.append(target)
249252
current_input = self.input_specs[self.input_specs_iter]
250253
self.input_specs_iter += 1
251254
# Set optimization profile for dynamic input shape
252-
shape = current_input.shape
255+
shape = None
253256
if current_input.shape_mode == Input._ShapeMode.DYNAMIC:
257+
assert isinstance(current_input.shape, dict)
254258
shape = []
255259
min_shape = current_input.shape["min_shape"]
256260
opt_shape = current_input.shape["opt_shape"]
257261
max_shape = current_input.shape["max_shape"]
258-
self.optimization_profiles[0].set_shape(
262+
# TODO: Does not support disjoint optimization profiles?
263+
self.optimization_profiles.append(trt.IOptimizationProfile().set_shape(
259264
target, [min_shape, opt_shape, max_shape]
260-
)
265+
))
261266
assert len(min_shape) == len(opt_shape) == len(max_shape)
262267
for i in range(len(min_shape)):
263268
if min_shape[i] == opt_shape[i] == max_shape[i]:
264269
shape.append(min_shape[i])
265270
else:
266271
# -1 to represent the dynamic dimension
267272
shape.append(-1)
273+
elif current_input.shape_mode == Input._ShapeMode.STATIC:
274+
assert isinstance(current_input.shape, tuple)
275+
shape = list(current_input.shape)
276+
else:
277+
raise RuntimeError(f"Unable to access shape spec for input: {target} (got: {current_input})")
278+
268279

269280
return self.network.add_input(
270281
name=target,
271282
shape=tuple(shape),
272283
dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
273284
)
274285

275-
def call_module(self, target, args, kwargs):
286+
def call_module(self, target: str, args: Any, kwargs: Any) -> Any: #Probably should be Tuple[trt.ITensor]? Case for Any?
276287
assert isinstance(target, str)
277288
submod = self.fetch_attr(target)
278289
submod_type = getattr(submod, "_base_class_origin", type(submod))
@@ -286,17 +297,17 @@ def call_module(self, target, args, kwargs):
286297
assert self._cur_node_name is not None
287298
return converter(self.network, submod, args, kwargs, self._cur_node_name)
288299

289-
def call_function(self, target, args, kwargs):
300+
def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
290301
converter = CONVERTERS.get(target)
291302
if not converter:
292303
raise RuntimeError(
293-
f"Conversion of function {torch.typename(target)} not currently supported!"
304+
f"Conversion of function {torch.typename(target)} not currently supported!" # type: ignore[no-untyped-call]
294305
)
295306

296307
assert self._cur_node_name is not None
297308
return converter(self.network, target, args, kwargs, self._cur_node_name)
298309

299-
def call_method(self, target, args, kwargs):
310+
def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
300311
assert isinstance(target, str)
301312
converter = CONVERTERS.get(target)
302313

@@ -308,7 +319,7 @@ def call_method(self, target, args, kwargs):
308319
assert self._cur_node_name is not None
309320
return converter(self.network, target, args, kwargs, self._cur_node_name)
310321

311-
def output(self, target, args, kwargs):
322+
def output(self, target: str, args: Any, kwargs: Any) -> None:
312323
assert len(args) == 1
313324
if isinstance(args[0], tuple):
314325
outputs = args[0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .trt_interpreter import *
1+
from .TRTInterpreter import *
22
from .conversion import *

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Sequence, Union
22
import torch
33
import io
4-
from torch_tensorrt.dynamo.runtime import _PythonTorchTRTModule
4+
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
55
from torch_tensorrt.dynamo import CompilationSettings
66
from torch_tensorrt import Input
77
from torch_tensorrt.dynamo.conversion import TRTInterpreter
@@ -15,7 +15,7 @@ def convert_module(
1515
inputs: Sequence[torch.Tensor],
1616
settings: CompilationSettings = CompilationSettings(),
1717
name: str = "",
18-
):
18+
) -> Union[PythonTorchTensorRTModule, TorchTensorRTModule]:
1919
"""Convert an FX module to a TRT module
2020
Args:
2121
module: FX GraphModule to convert

0 commit comments

Comments
 (0)