Skip to content

Commit 790c78f

Browse files
committed
chore(//py/torch_tensorrt): _compile.py conforms to mypy
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent b0ae48e commit 790c78f

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1-
from typing import List, Dict, Any
1+
from typing import List, Dict, Any, Set, Union, Callable, TypeGuard
2+
23
import torch_tensorrt.ts
34

4-
from torch_tensorrt import logging
5+
from torch_tensorrt import logging, Input, dtype
56
import torch
67
import torch.fx
78
from enum import Enum
89

910
import torch_tensorrt.fx
11+
from torch_tensorrt.fx import InputTensorSpec
1012
from torch_tensorrt.fx.utils import LowerPrecision
1113

1214

15+
def _non_fx_input_interface(inputs: List[Input | torch.Tensor | InputTensorSpec]) -> TypeGuard[List[Input | torch.Tensor]]:
16+
return all([isinstance(i, torch.Tensor | Input) for i in inputs])
17+
18+
def _fx_input_interface(inputs: List[Input | torch.Tensor | InputTensorSpec]) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
19+
return all([isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs])
20+
1321
class _IRType(Enum):
1422
"""Enum to set the minimum required logging level to print a message to stdout"""
1523

@@ -80,11 +88,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
8088

8189
def compile(
8290
module: Any,
83-
ir="default",
84-
inputs=[],
85-
enabled_precisions=set([torch.float]),
86-
**kwargs,
87-
):
91+
ir: str = "default",
92+
inputs: List[Union[Input, torch.Tensor, InputTensorSpec]] = [],
93+
enabled_precisions: Set[Union[torch.dtype, dtype]] = set([torch.float]),
94+
**kwargs: Any,
95+
) -> Union[torch.nn.Module, torch.jit.ScriptModule, torch.fx.GraphModule, Callable[[Any], Any]]:
8896
"""Compile a PyTorch module for NVIDIA GPUs using TensorRT
8997
9098
Takes a existing PyTorch module and a set of settings to configure the compiler
@@ -130,9 +138,11 @@ def compile(
130138
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
131139
)
132140
ts_mod = torch.jit.script(module)
133-
return torch_tensorrt.ts.compile(
141+
assert _non_fx_input_interface(inputs)
142+
compiled_ts_module: torch.jit.ScriptModule = torch_tensorrt.ts.compile(
134143
ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
135144
)
145+
return compiled_ts_module
136146
elif target_ir == _IRType.fx:
137147
if (
138148
torch.float16 in enabled_precisions
@@ -147,38 +157,31 @@ def compile(
147157
else:
148158
raise ValueError(f"Precision {enabled_precisions} not supported on FX")
149159

150-
return torch_tensorrt.fx.compile(
160+
assert _fx_input_interface(inputs)
161+
compiled_fx_module: torch.nn.Module = torch_tensorrt.fx.compile(
151162
module,
152163
inputs,
153164
lower_precision=lower_precision,
154-
max_batch_size=inputs[0].size(0),
155165
explicit_batch_dimension=True,
156166
dynamic_batch=False,
157167
**kwargs,
158168
)
169+
return compiled_fx_module
159170
elif target_ir == _IRType.dynamo:
160-
from torch_tensorrt import Device
161-
from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
162-
import collections.abc
163-
164-
if not isinstance(inputs, collections.abc.Sequence):
165-
inputs = [inputs]
166-
device = kwargs.get("device", Device._current_device())
167-
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
168-
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
169171
return torch_tensorrt.dynamo.compile(
170172
module,
171173
inputs=inputs,
172174
enabled_precisions=enabled_precisions,
173175
**kwargs,
174176
)
177+
return compiled_aten_module
175178
elif target_ir == _IRType.torch_compile:
176179
return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
177180
else:
178181
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
179182

180183

181-
def torch_compile(module, **kwargs):
184+
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
182185
"""
183186
Returns a boxed model which is the output of torch.compile.
184187
This does not compile the model to TRT. Execute this model on
@@ -194,11 +197,11 @@ def torch_compile(module, **kwargs):
194197
def convert_method_to_trt_engine(
195198
module: Any,
196199
method_name: str,
197-
ir="default",
198-
inputs=[],
199-
enabled_precisions=set([torch.float]),
200-
**kwargs,
201-
):
200+
ir: str = "default",
201+
inputs: List[Union[Input, torch.Tensor]] = [],
202+
enabled_precisions: Set[Union[torch.dtype, dtype]] = set([torch.float]),
203+
**kwargs: Any,
204+
) -> bytes:
202205
"""Convert a TorchScript module method to a serialized TensorRT engine
203206
204207
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings

0 commit comments

Comments
 (0)