Commit 2048601

chore: address review comments
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 8d6c5aa commit 2048601

4 files changed: +273 -265 lines changed

py/torch_tensorrt/_compile.py

Lines changed: 30 additions & 19 deletions
@@ -73,7 +73,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
                 )
                 return _IRType.ts
             else:
-                raise ValueError("Module was provided with in an unsupported format")
+                raise ValueError("Module was provided in an unsupported format")
         else:
             raise ValueError("Unknown ir was requested")
 

@@ -157,37 +157,38 @@ def compile(
             **kwargs,
         )
     elif target_ir == _IRType.dynamo:
+        from torch_tensorrt import Device
+        from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
+        import collections.abc
+
+        if not isinstance(inputs, collections.abc.Sequence):
+            inputs = [inputs]
+        device = kwargs.get("device", Device._current_device())
+        torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+        module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
         return torch_tensorrt.dynamo.compile(
             module,
             inputs=inputs,
             enabled_precisions=enabled_precisions,
             **kwargs,
         )
     elif target_ir == _IRType.torch_compile:
-        return torch_compile(
-            module, inputs, enabled_precisions=enabled_precisions, **kwargs
-        )
+        return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
 
 
-def torch_compile(module, inputs, **kwargs):
-
-    from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
+def torch_compile(module, **kwargs):
+    """
+    Returns a boxed model which is the output of torch.compile.
+    This does not compile the model to TRT. Execute this model on
+    sample inputs to compile the model to TRT.
+    """
     from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
-    from torch_tensorrt import Device
-    import collections.abc
-
-    if not isinstance(inputs, collections.abc.Sequence):
-        inputs = [inputs]
 
-    device = kwargs.get("device", Device._current_device())
-    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
-    model = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs})
-    # Ensure compilation occurs by calling the function with provided inputs
-    model(*torch_inputs)
+    boxed_fn = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs})
 
-    return model
+    return boxed_fn
 
 
 def convert_method_to_trt_engine(
@@ -246,6 +247,16 @@ def convert_method_to_trt_engine(
             **kwargs,
         )
     elif target_ir == _IRType.fx:
-        raise RuntimeError("fx is currently not supported")
+        raise RuntimeError(
+            "convert_method_to_trt_engine call is not supported for ir=fx"
+        )
+    elif target_ir == _IRType.dynamo:
+        raise RuntimeError(
+            "convert_method_to_trt_engine call is not supported for ir=dynamo."
+        )
+    elif target_ir == _IRType.torch_compile:
+        raise RuntimeError(
+            "convert_method_to_trt_engine call is not supported for ir=torch_compile"
+        )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
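
As a rough usage sketch of the two branches touched above (the toy module, input shapes, and precision settings are illustrative assumptions; only the ir values, the tracing step for ir="dynamo", and the deferred compilation described in the new torch_compile docstring come from this diff):

import torch
import torch_tensorrt

# Hypothetical module and inputs, for illustration only.
model = torch.nn.Linear(16, 4).eval().cuda()
sample_inputs = [torch.randn(2, 16).cuda()]

# ir="dynamo": per the hunk above, inputs are normalized and the module is
# traced via torch_tensorrt.dynamo.trace before dynamo.compile runs.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=sample_inputs,
    enabled_precisions={torch.float},
)

# ir="torch_compile": a boxed torch.compile model is returned and, per the new
# docstring, TRT compilation is deferred until it is executed on sample inputs.
# (The torch_compile branch itself no longer receives the inputs argument.)
boxed_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=sample_inputs,
    enabled_precisions={torch.float},
)
outputs = boxed_model(*sample_inputs)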

py/torch_tensorrt/dynamo/compile.py

Lines changed: 3 additions & 6 deletions
@@ -9,7 +9,6 @@
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.fx.passes.splitter_base import SplitResult
-from torch_tensorrt.dynamo.aten_tracer import trace
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 from torch_tensorrt.dynamo.lowering import (
     fuse_permute_linear,
@@ -38,7 +37,7 @@
 
 
 def compile(
-    gm: torch.nn.Module,
+    gm: Any,
     inputs: Any,
     *,
     device=Device._current_device(),
@@ -113,13 +112,11 @@ def compile(
     }
 
     settings = CompilationSettings(**compilation_options)
-    model = trace(gm, torch_inputs, **kwargs)
-
     if kwargs.get("use_capability_partitioner", None):
-        model = lower_model(model, torch_inputs)
+        model = lower_model(gm, torch_inputs)
         return _compile_module(model, torch_inputs, settings)
     else:
-        split_result = lower_model_using_trt_splitter(model, torch_inputs)
+        split_result = lower_model_using_trt_splitter(gm, torch_inputs)
         trt_module = _compile_graph(split_result, torch_inputs, settings)
 
         return trt_module
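
Since the trace call was lifted out of dynamo.compile (it now happens in _compile.py, per the first file above), a direct caller is expected to hand in an already-traced module. A minimal sketch under that assumption, with a made-up module and shapes:

import torch
import torch_tensorrt

# Hypothetical module and inputs, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 4), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(2, 16).cuda()]

# Tracing is now the caller's responsibility, mirroring what _compile.py does.
traced = torch_tensorrt.dynamo.trace(model, inputs)

# dynamo.compile consumes the traced module directly (hence gm: Any).
trt_module = torch_tensorrt.dynamo.compile(
    traced,
    inputs=inputs,
    enabled_precisions={torch.float},
)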
