
Commit ac0126f (parent: 1eeb319)

chore(//py/torch_tensorrt/ts): Making compile mypy compliant

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

2 files changed: +48 −45 lines

py/torch_tensorrt/_Device.py (2 additions & 1 deletion)
@@ -1,3 +1,4 @@
+from typing import Self
 import torch
 
 # from torch_tensorrt import _enums
@@ -137,7 +138,7 @@ def _from_torch_device(cls, torch_dev: torch.device):
         return cls(gpu_id=gpu_id)
 
     @classmethod
-    def _current_device(cls):
+    def _current_device(cls) -> Self:
         try:
             dev = _C._get_current_device()
         except RuntimeError:
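The change above is the standard pattern for typing alternate constructors: annotating a classmethod as returning `Self` tells mypy the method returns whichever (sub)class it is invoked on. A minimal standalone sketch of the same idea — the `Device` body here is a toy stand-in, not the real torch_tensorrt class, and `Self` requires Python 3.11+ (or `typing_extensions` on older interpreters):

```python
from typing import Self  # Python 3.11+; import from typing_extensions otherwise


class Device:
    """Toy stand-in for torch_tensorrt._Device.Device."""

    def __init__(self, gpu_id: int = 0) -> None:
        self.gpu_id = gpu_id

    @classmethod
    def _current_device(cls) -> Self:
        # Because the annotation is Self rather than "Device", a subclass
        # calling this classmethod gets its own type back, not the base type.
        return cls(gpu_id=0)
```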

py/torch_tensorrt/ts/_compiler.py (46 additions & 44 deletions)
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Tuple, Union, Set, Optional
 import torch
 from torch import nn
 
@@ -7,31 +7,32 @@
 from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device
 from torch_tensorrt._Device import Device
 from types import FunctionType
+from torch_tensorrt._Input import Input
 
 
 def compile(
     module: torch.jit.ScriptModule,
-    inputs=[],
-    input_signature=None,
-    device=Device._current_device(),
-    disable_tf32=False,
-    sparse_weights=False,
-    enabled_precisions=set(),
-    refit=False,
-    debug=False,
-    capability=_enums.EngineCapability.default,
-    num_avg_timing_iters=1,
-    workspace_size=0,
-    dla_sram_size=1048576,
-    dla_local_dram_size=1073741824,
-    dla_global_dram_size=536870912,
-    calibrator=None,
-    truncate_long_and_double=False,
-    require_full_compilation=False,
-    min_block_size=3,
-    torch_executed_ops=[],
-    torch_executed_modules=[],
-    allow_shape_tensors=False,
+    inputs: List[Union[Input, torch.Tensor]] = [],
+    input_signature: Optional[Tuple[Union[Input, torch.Tensor]]] = None,
+    device: Device = Device._current_device(),
+    disable_tf32: bool = False,
+    sparse_weights: bool = False,
+    enabled_precisions: Set[Union[torch.dtype, _enums.dtype]] = set(), # type: ignore[name-defined]
+    refit: bool = False,
+    debug: bool = False,
+    capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined]
+    num_avg_timing_iters: int = 1,
+    workspace_size: int = 0,
+    dla_sram_size: int = 1048576,
+    dla_local_dram_size: int = 1073741824,
+    dla_global_dram_size: int = 536870912,
+    calibrator: object = None,
+    truncate_long_and_double: bool = False,
+    require_full_compilation: bool = False,
+    min_block_size: int = 3,
+    torch_executed_ops: List[str] = [],
+    torch_executed_modules: List[str] = [],
+    allow_shape_tensors: bool = False,
 ) -> torch.jit.ScriptModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
 
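As an aside, here is a hedged call-site sketch showing the argument types the new annotations advertise. The model, shapes, and precision set are purely illustrative, and a CUDA device is assumed:

```python
import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


scripted: torch.jit.ScriptModule = torch.jit.script(ToyModel().eval())
trt_mod = torch_tensorrt.ts.compile(
    scripted,
    inputs=[torch.randn(1, 3, 224, 224).cuda()],  # List[Union[Input, torch.Tensor]]
    enabled_precisions={torch.half},              # Set[Union[torch.dtype, _enums.dtype]]
)
```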
@@ -137,30 +138,30 @@ def compile(
     }
 
     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
-    compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
+    compiled_module: torch.jit.ScriptModule = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod) # type: ignore[no-untyped-call]
     return compiled_module
 
 
 def convert_method_to_trt_engine(
     module: torch.jit.ScriptModule,
     method_name: str = "forward",
-    inputs=[],
-    device=Device._current_device(),
-    disable_tf32=False,
-    sparse_weights=False,
-    enabled_precisions=set(),
-    refit=False,
-    debug=False,
-    capability=_enums.EngineCapability.default,
-    num_avg_timing_iters=1,
-    workspace_size=0,
-    dla_sram_size=1048576,
-    dla_local_dram_size=1073741824,
-    dla_global_dram_size=536870912,
-    truncate_long_and_double=False,
-    calibrator=None,
-    allow_shape_tensors=False,
-) -> bytearray:
+    inputs: List[Union[Input, torch.Tensor]] = [],
+    device: Device = Device._current_device(),
+    disable_tf32: bool = False,
+    sparse_weights: bool = False,
+    enabled_precisions: Set[Union[torch.dtype, _enums.dtype]] = set(), # type: ignore[name-defined]
+    refit: bool = False,
+    debug: bool = False,
+    capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined]
+    num_avg_timing_iters: int = 1,
+    workspace_size: int = 0,
+    dla_sram_size: int = 1048576,
+    dla_local_dram_size: int = 1073741824,
+    dla_global_dram_size: int = 536870912,
+    truncate_long_and_double: bool = False,
+    calibrator: object = None,
+    allow_shape_tensors: bool = False,
+) -> bytes:
     """Convert a TorchScript module method to a serialized TensorRT engine
 
     Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
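The `compiled_module` line above shows the recurring trick in this commit: when a dependency's function is unannotated, bind its result to an explicitly typed variable and suppress only the `no-untyped-call` error on that one line. Annotating the receiving variable, rather than using a bare `# type: ignore`, keeps the suppression narrow while everything downstream stays checked. A toy illustration, not Torch-TensorRT code:

```python
def untyped_factory():  # imagine this lives in a third-party module without annotations
    return "engine"


def build() -> str:
    # The narrow ignore silences mypy's complaint about calling an untyped
    # function; the annotation pins the type we expect from that point on.
    result: str = untyped_factory()  # type: ignore[no-untyped-call]
    return result
```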
@@ -221,7 +222,7 @@ def convert_method_to_trt_engine(
         allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
-        bytearray: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
+        bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if isinstance(module, torch.jit.ScriptFunction):
         raise TypeError(
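With the return type corrected to `bytes`, the serialized engine can be written straight to a binary file. A hypothetical end-to-end snippet, with the module, shapes, and file name invented for illustration (assumes a CUDA device):

```python
import torch
import torch_tensorrt


class Identity(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


scripted = torch.jit.script(Identity().eval())
engine_bytes = torch_tensorrt.ts.convert_method_to_trt_engine(
    scripted,
    method_name="forward",
    inputs=[torch.randn(1, 3, 8, 8).cuda()],
)
with open("identity.engine", "wb") as f:  # bytes write cleanly to a binary file
    f.write(engine_bytes)
```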
@@ -293,8 +294,8 @@ def embed_engine_in_new_module(
         input_binding_names,
         output_binding_names,
     )
-    return torch.jit._recursive.wrap_cpp_module(cpp_mod)
-
+    wrapped_mod: torch.jit.ScriptModule = torch.jit._recursive.wrap_cpp_module(cpp_mod) # type: ignore[no-untyped-call]
+    return wrapped_mod
 
 def check_method_op_support(
     module: torch.jit.ScriptModule, method_name: str = "forward"
@@ -312,4 +313,5 @@ def check_method_op_support(
     Returns:
         bool: True if supported Method
     """
-    return _C.check_method_op_support(module._c, method_name)
+    supported: bool = _C.check_method_op_support(module._c, method_name)
+    return supported
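The `supported` binding exists because, under mypy's `warn_return_any`, returning the result of a call into an unannotated C extension directly is an error: the value's inferred type is `Any`. Assigning it to a declared `bool` first narrows it at one audited point. A self-contained sketch of the same fix, with hypothetical names:

```python
from typing import Any


def _c_extension_call() -> Any:  # stand-in for an untyped _C binding
    return True


def check_supported() -> bool:
    # Returning _c_extension_call() directly would trip warn_return_any;
    # the annotated local converts the Any into a checked bool.
    supported: bool = _c_extension_call()
    return supported
```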
