Commit f3c7fc7

refactor: Modify prepare_inputs, remove lower_precision

Signed-off-by: Dheeraj Peri <[email protected]>

chore: refactor

Signed-off-by: Dheeraj Peri <[email protected]>

chore: Address review comments

Signed-off-by: Dheeraj Peri <[email protected]>

chore: address review comments

Signed-off-by: Dheeraj Peri <[email protected]>

1 parent 4985372 commit f3c7fc7

15 files changed: +344 additions, -388 deletions

.circleci/config.yml

Lines changed: 0 additions & 3 deletions

@@ -742,7 +742,6 @@ commands:
       command: |
         cd tests/py/dynamo/backend/
         pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
-        popd

     - store_test_results:
         path: /tmp/artifacts
@@ -759,7 +758,6 @@ commands:
         pip3 install timm
         pip3 install transformers
         pytest test_models.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir torch_compile
-        popd

     - store_test_results:
         path: /tmp/artifacts
@@ -776,7 +774,6 @@ commands:
         pip3 install timm
         pip3 install transformers
         pytest test_models_export.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo
-        popd

     - store_test_results:
         path: /tmp/artifacts

py/torch_tensorrt/_compile.py

Lines changed: 28 additions & 4 deletions

@@ -67,9 +67,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
             )
             return _IRType.dynamo
         elif module_is_tsable:
-            raise ValueError(
-                "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to compile."
+            logging.log(
+                logging.Level.Warning,
+                "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=ts",
             )
+            return _IRType.ts
         else:
             raise ValueError("Module was provided with in an unsupported format")
     else:
@@ -154,18 +156,40 @@ def compile(
             dynamic_batch=False,
             **kwargs,
         )
-    elif target_ir == _IRType.dynamo or target_ir == _IRType.torch_compile:
+    elif target_ir == _IRType.dynamo:
         return torch_tensorrt.dynamo.compile(
             module,
             inputs=inputs,
             enabled_precisions=enabled_precisions,
-            ir=target_ir.name,
             **kwargs,
         )
+    elif target_ir == _IRType.torch_compile:
+        return torch_compile(
+            module, inputs, enabled_precisions=enabled_precisions, **kwargs
+        )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")


+def torch_compile(module, inputs, **kwargs):
+
+    from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
+    from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
+    from torch_tensorrt import Device
+    import collections.abc
+
+    if not isinstance(inputs, collections.abc.Sequence):
+        inputs = [inputs]
+
+    device = kwargs.get("device", Device._current_device())
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+    model = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs})
+    # Ensure compilation occurs by calling the function with provided inputs
+    model(*torch_inputs)
+
+    return model
+
+
 def convert_method_to_trt_engine(
     module: Any,
     method_name: str,
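For context, a minimal sketch of how the new torch_compile path might be exercised through the top-level API. The toy module and input shapes here are hypothetical, and a CUDA-enabled Torch-TensorRT install is assumed:

import torch
import torch_tensorrt

# Toy stand-in model and input (hypothetical).
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 16).cuda()]

# With ir="torch_compile", compile() now dispatches to the torch_compile()
# helper above, which wraps the module with torch.compile using the
# torch_tensorrt_backend and runs it once on the inputs to trigger compilation.
trt_model = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
out = trt_model(*inputs)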

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from ._settings import *
 from .compile import compile
+from .aten_tracer import trace
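With this re-export in place, the tracer should be importable directly from the package, e.g.:

# trace() can now be imported alongside compile() from the dynamo package
from torch_tensorrt.dynamo import compile, trace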

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 3 deletions

@@ -1,7 +1,6 @@
-from torch_tensorrt.fx.utils import LowerPrecision
+import torch

-
-PRECISION = LowerPrecision.FP32
+PRECISION = torch.float32
 DEBUG = False
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 3 deletions

@@ -1,7 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Optional, Sequence
-
-from torch_tensorrt.fx.utils import LowerPrecision
+import torch
 from torch_tensorrt.dynamo._defaults import (
     PRECISION,
     DEBUG,
@@ -17,7 +16,7 @@

 @dataclass
 class CompilationSettings:
-    precision: LowerPrecision = PRECISION
+    precision: torch.dtype = PRECISION
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
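With the LowerPrecision enum gone, CompilationSettings takes a plain torch.dtype. A minimal sketch of constructing the dataclass under this change:

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings

# precision is now an ordinary torch.dtype instead of a LowerPrecision enum value
settings = CompilationSettings(precision=torch.float16, debug=True)
assert settings.precision == torch.float16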
py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,2 +1 @@
 from .backends import torch_tensorrt_backend
-from .compile import compile

py/torch_tensorrt/dynamo/compile.py

Lines changed: 28 additions & 96 deletions

@@ -6,7 +6,6 @@

 from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,117 +77,50 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]

-    inputs = prepare_inputs(inputs, prepare_device(device))
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))

     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP16
+        precision = torch.float16
     elif (
         torch.float32 in enabled_precisions
         or torch_tensorrt.dtype.float in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP32
+        precision = torch.float32
     elif len(enabled_precisions) == 0:
         logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
+        precision = PRECISION
     else:
         raise ValueError(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )

-    if kwargs.get("ir", "dynamo") == "torch_compile":
-        custom_backend = create_backend(
-            precision=lower_precision,
-            debug=debug,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-            **kwargs,
-        )
-        model = torch.compile(gm, backend=custom_backend)
-        # Ensure compilation occurs by calling the function with provided inputs
-        model(*inputs)
-        return model
-
+    compilation_options = {
+        "precision": precision,
+        "debug": debug,
+        "workspace_size": workspace_size,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
+        "max_aux_streams": max_aux_streams,
+        "version_compatible": version_compatible,
+        "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+    }
+
+    settings = CompilationSettings(**compilation_options)
+    model = trace(gm, torch_inputs, **kwargs)
+
+    if kwargs.get("use_capability_partitioner", None):
+        model = lower_model(model, torch_inputs)
+        return _compile_module(model, torch_inputs, settings)
     else:
-        settings = CompilationSettings(
-            debug=debug,
-            precision=lower_precision,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-        )
+        split_result = lower_model_using_trt_splitter(model, torch_inputs)
+        trt_module = _compile_graph(split_result, torch_inputs, settings)

-        model = trace(gm, inputs, **kwargs)
-
-        if kwargs.get("use_capability_partitioner", None):
-            model = lower_model(model, inputs)
-            return _compile_module(model, inputs, settings)
-        else:
-            split_result = lower_model_using_trt_splitter(model, inputs)
-            trt_module = _compile_graph(split_result, inputs, settings)
-
-            return trt_module
-
-
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
-            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
-            argument as None
-    Returns:
-        Backend for torch.compile
-    """
-    return partial(
-        torch_tensorrt_backend,
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
+        return trt_module


 def _compile_graph(
@@ -234,7 +166,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model = graph_optimization_pm(model)
-    if isinstance(lowered_model, torch.fx.GraphModule):
-        ShapeProp(lowered_model).propagate(*inputs)
+    # if isinstance(lowered_model, torch.fx.GraphModule):
+    #     ShapeProp(lowered_model).propagate(*inputs)

     return lowered_model
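A minimal sketch of the resulting call path, assuming a CUDA build and a hypothetical toy module. Entries in enabled_precisions are now mapped to a single torch.dtype before being packed into CompilationSettings:

import torch
import torch_tensorrt

model = torch.nn.Linear(8, 8).eval().cuda()  # hypothetical toy module

# torch.half in enabled_precisions now yields precision=torch.float16,
# which flows into CompilationSettings via the compilation_options dict.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=[torch.randn(2, 8).cuda()],
    enabled_precisions={torch.half},
)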

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def convert_module(
     )
     interpreter_result = interpreter.run(
         workspace_size=settings.workspace_size,
-        lower_precision=settings.precision,
+        precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
             if settings.debug

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 8 additions & 17 deletions

@@ -19,7 +19,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import (
     get_dynamic_dims,
-    LowerPrecision,
     unified_dtype_converter,
     Frameworks,
 )
@@ -98,7 +97,7 @@ def validate_conversion(self):
     def run(
         self,
         workspace_size=0,
-        lower_precision=LowerPrecision.FP16,
+        precision=torch.float32,
         sparse_weights=False,
         disable_tf32=False,
         force_fp32_output=False,
@@ -115,7 +114,7 @@ def run(
         Build TensorRT engine with some configs.
         Args:
             workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
-            lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
+            precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
             sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
             force_fp32_output: force output to be fp32
             strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
@@ -131,22 +130,14 @@
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

-        # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
+        # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
         # force_fp32_output=False. Overriden by specifying output_dtypes
-        self.output_fp16 = (
-            not force_fp32_output and lower_precision == LowerPrecision.FP16
-        )
+        self.output_fp16 = not force_fp32_output and precision == torch.float16

-        if (
-            lower_precision == LowerPrecision.INT8
-            and not self.builder.platform_has_fast_int8
-        ):
+        if precision == torch.int8 and not self.builder.platform_has_fast_int8:
             raise RuntimeError("Current platform doesn't support fast native int8!")

-        if (
-            lower_precision == LowerPrecision.FP16
-            and not self.builder.platform_has_fast_fp16
-        ):
+        if precision == torch.float16 and not self.builder.platform_has_fast_fp16:
             warnings.warn("Current platform doesn't support fast native fp16!")

         self.input_specs_iter = 0
@@ -190,10 +181,10 @@ def run(
             _LOGGER.info(f"Using optimization level {optimization_level}")
             builder_config.builder_optimization_level = optimization_level

-        if lower_precision == LowerPrecision.FP16:
+        if precision == torch.float16:
             builder_config.set_flag(trt.BuilderFlag.FP16)

-        if lower_precision == LowerPrecision.INT8:
+        if precision == torch.int8:
             builder_config.set_flag(trt.BuilderFlag.INT8)

         if sparse_weights:
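The flag selection above now keys directly off torch dtypes. A hedged sketch of that mapping as a standalone helper (the helper name is illustrative, not part of the library; a tensorrt install is assumed):

import torch
import tensorrt as trt

def _set_precision_flags(builder_config: trt.IBuilderConfig, precision: torch.dtype) -> None:
    # Illustrative helper mirroring the logic above: FP32 needs no flag,
    # while FP16/INT8 enable the corresponding TensorRT builder flags.
    if precision == torch.float16:
        builder_config.set_flag(trt.BuilderFlag.FP16)
    elif precision == torch.int8:
        builder_config.set_flag(trt.BuilderFlag.INT8)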

py/torch_tensorrt/dynamo/runtime/_PythonTRTModule.py renamed to py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py

Lines changed: 6 additions & 0 deletions

@@ -7,6 +7,12 @@


 class TRTModule(torch.nn.Module):
+    """TRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
+
+    This module is backed by the Torch-TensorRT runtime and is only compatibile with
+    FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
+    """
+
     def __init__(
         self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1
     ):

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ class TorchTensorRTModule(torch.nn.Module):

     This module is backed by the Torch-TensorRT runtime and is fully compatibile with both
     FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
-    well as TorchScript / C++ deployments since TRTModule can be passed to ``torch.jit.trace``
+    well as TorchScript / C++ deployments since TorchTensorRTModule can be passed to ``torch.jit.trace``
     and then saved.

     The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
py/torch_tensorrt/dynamo/runtime/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
-from ._PythonTRTModule import TRTModule
+from ._PythonTorchTRTModule import TRTModule
 from ._TorchTensorRTModule import TorchTensorRTModule
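Net effect on the runtime package: both modules stay exported under the same names, with the Python-runtime TRTModule now living in _PythonTorchTRTModule.py. A small sketch of the distinction the updated docstrings draw:

from torch_tensorrt.dynamo.runtime import TRTModule, TorchTensorRTModule

# TRTModule: Python-runtime engine wrapper; not traceable via torch.jit.trace,
# so it cannot be serialized to TorchScript for C++ deployment.
# TorchTensorRTModule: C++-runtime wrapper; can be passed through
# torch.jit.trace and saved for TorchScript / C++ deployments.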
