Skip to content

Commit d877dd9

Browse files
committed
refactor: Modify prepare_inputs, remove lower_precision
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6c3b3f7 commit d877dd9

File tree

7 files changed

+54
-59
lines changed

7 files changed

+54
-59
lines changed

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from torch_tensorrt.fx.utils import LowerPrecision
1+
import torch
22

3-
4-
PRECISION = LowerPrecision.FP32
3+
PRECISION = torch.float32
54
DEBUG = False
65
WORKSPACE_SIZE = 0
76
MIN_BLOCK_SIZE = 5

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from dataclasses import dataclass, field
22
from typing import Optional, Sequence
3-
4-
from torch_tensorrt.fx.utils import LowerPrecision
3+
import torch
54
from torch_tensorrt.dynamo._defaults import (
65
PRECISION,
76
DEBUG,
@@ -17,7 +16,7 @@
1716

1817
@dataclass
1918
class CompilationSettings:
20-
precision: LowerPrecision = PRECISION
19+
precision: torch.dtype = PRECISION
2120
debug: bool = DEBUG
2221
workspace_size: int = WORKSPACE_SIZE
2322
min_block_size: int = MIN_BLOCK_SIZE
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
from .backends import torch_tensorrt_backend
2-
from .compile import compile

py/torch_tensorrt/dynamo/compile.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from typing import Any, Optional, Sequence
88
from torch_tensorrt import EngineCapability, Device
9-
from torch_tensorrt.fx.utils import LowerPrecision
109
from torch.fx.passes.pass_manager import PassManager
1110
from torch.fx.passes.shape_prop import ShapeProp
1211
from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,29 +77,29 @@ def compile(
7877
if not isinstance(inputs, collections.abc.Sequence):
7978
inputs = [inputs]
8079

81-
inputs = prepare_inputs(inputs, prepare_device(device))
80+
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
8281

8382
if (
8483
torch.float16 in enabled_precisions
8584
or torch_tensorrt.dtype.half in enabled_precisions
8685
):
87-
lower_precision = LowerPrecision.FP16
86+
precision = torch.float16
8887
elif (
8988
torch.float32 in enabled_precisions
9089
or torch_tensorrt.dtype.float in enabled_precisions
9190
):
92-
lower_precision = LowerPrecision.FP32
91+
precision = torch.float32
9392
elif len(enabled_precisions) == 0:
9493
logger.info(f"No precision specified, defaulting to {PRECISION}")
95-
lower_precision = PRECISION
94+
precision = PRECISION
9695
else:
9796
raise ValueError(
9897
f"Precision {enabled_precisions} not supported in the Dynamo Path"
9998
)
10099

101100
if kwargs.get("ir", "dynamo") == "torch_compile":
102101
custom_backend = create_backend(
103-
precision=lower_precision,
102+
precision=precision,
104103
debug=debug,
105104
workspace_size=workspace_size,
106105
min_block_size=min_block_size,
@@ -114,13 +113,13 @@ def compile(
114113
)
115114
model = torch.compile(gm, backend=custom_backend)
116115
# Ensure compilation occurs by calling the function with provided inputs
117-
model(*inputs)
116+
model(*torch_inputs)
118117
return model
119118

120119
else:
121120
settings = CompilationSettings(
122121
debug=debug,
123-
precision=lower_precision,
122+
precision=precision,
124123
workspace_size=workspace_size,
125124
min_block_size=min_block_size,
126125
torch_executed_ops=torch_executed_ops,
@@ -131,20 +130,20 @@ def compile(
131130
use_python_runtime=use_python_runtime,
132131
)
133132

134-
model = trace(gm, inputs, **kwargs)
133+
model = trace(gm, torch_inputs, **kwargs)
135134

136135
if kwargs.get("use_capability_partitioner", None):
137-
model = lower_model(model, inputs)
138-
return _compile_module(model, inputs, settings)
136+
model = lower_model(model, torch_inputs)
137+
return _compile_module(model, torch_inputs, settings)
139138
else:
140-
split_result = lower_model_using_trt_splitter(model, inputs)
141-
trt_module = _compile_graph(split_result, inputs, settings)
139+
split_result = lower_model_using_trt_splitter(model, torch_inputs)
140+
trt_module = _compile_graph(split_result, torch_inputs, settings)
142141

143142
return trt_module
144143

145144

146145
def create_backend(
147-
precision: LowerPrecision = PRECISION,
146+
precision: torch.dtype = PRECISION,
148147
debug: bool = DEBUG,
149148
workspace_size: int = WORKSPACE_SIZE,
150149
min_block_size: int = MIN_BLOCK_SIZE,
@@ -234,7 +233,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
234233
[fuse_permute_matmul, fuse_permute_linear]
235234
)
236235
lowered_model = graph_optimization_pm(model)
237-
if isinstance(lowered_model, torch.fx.GraphModule):
238-
ShapeProp(lowered_model).propagate(*inputs)
236+
# if isinstance(lowered_model, torch.fx.GraphModule):
237+
# ShapeProp(lowered_model).propagate(*inputs)
239238

240239
return lowered_model

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def convert_module(
4141
)
4242
interpreter_result = interpreter.run(
4343
workspace_size=settings.workspace_size,
44-
lower_precision=settings.precision,
44+
precision=settings.precision,
4545
profiling_verbosity=(
4646
trt.ProfilingVerbosity.VERBOSE
4747
if settings.debug

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from torch_tensorrt.fx.observer import Observer
2020
from torch_tensorrt.fx.utils import (
2121
get_dynamic_dims,
22-
LowerPrecision,
2322
unified_dtype_converter,
2423
Frameworks,
2524
)
@@ -98,7 +97,7 @@ def validate_conversion(self):
9897
def run(
9998
self,
10099
workspace_size=0,
101-
lower_precision=LowerPrecision.FP16,
100+
precision=torch.float32,
102101
sparse_weights=False,
103102
disable_tf32=False,
104103
force_fp32_output=False,
@@ -115,7 +114,7 @@ def run(
115114
Build TensorRT engine with some configs.
116115
Args:
117116
workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
118-
lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
117+
precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
119118
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
120119
force_fp32_output: force output to be fp32
121120
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
@@ -131,22 +130,14 @@ def run(
131130
"""
132131
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
133132

134-
# For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
133+
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
135134
# force_fp32_output=False. Overriden by specifying output_dtypes
136-
self.output_fp16 = (
137-
not force_fp32_output and lower_precision == LowerPrecision.FP16
138-
)
135+
self.output_fp16 = not force_fp32_output and precision == torch.float16
139136

140-
if (
141-
lower_precision == LowerPrecision.INT8
142-
and not self.builder.platform_has_fast_int8
143-
):
137+
if precision == torch.int8 and not self.builder.platform_has_fast_int8:
144138
raise RuntimeError("Current platform doesn't support fast native int8!")
145139

146-
if (
147-
lower_precision == LowerPrecision.FP16
148-
and not self.builder.platform_has_fast_fp16
149-
):
140+
if precision == torch.float16 and not self.builder.platform_has_fast_fp16:
150141
warnings.warn("Current platform doesn't support fast native fp16!")
151142

152143
self.input_specs_iter = 0
@@ -190,10 +181,10 @@ def run(
190181
_LOGGER.info(f"Using optimization level {optimization_level}")
191182
builder_config.builder_optimization_level = optimization_level
192183

193-
if lower_precision == LowerPrecision.FP16:
184+
if precision == torch.float16:
194185
builder_config.set_flag(trt.BuilderFlag.FP16)
195186

196-
if lower_precision == LowerPrecision.INT8:
187+
if precision == torch.int8:
197188
builder_config.set_flag(trt.BuilderFlag.INT8)
198189

199190
if sparse_weights:

py/torch_tensorrt/dynamo/utils.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import replace, fields
44
from torch_tensorrt.dynamo import CompilationSettings
55
from typing import Any, Union, Sequence, Dict
6-
from torch_tensorrt import _Input, Device
6+
from torch_tensorrt import Input, Device
77
from typing import Optional
88

99
logger = logging.getLogger(__name__)
@@ -55,43 +55,51 @@ def cosine_similarity(gt_tensor, pred_tensor):
5555

5656

5757
def prepare_inputs(
58-
inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
58+
inputs: Union[Input, torch.Tensor, Sequence, Dict],
5959
device: torch.device = torch.device("cuda"),
6060
) -> Any:
61-
if isinstance(inputs, _Input.Input):
61+
if isinstance(inputs, Input):
6262
if isinstance(inputs.shape, dict):
63-
return inputs.example_tensor(optimization_profile_field="opt_shape").to(
64-
device
65-
)
63+
return inputs, inputs.example_tensor(
64+
optimization_profile_field="opt_shape"
65+
).to(device)
6666
else:
67-
return inputs.example_tensor().to(device)
67+
return inputs, inputs.example_tensor().to(device)
6868

6969
elif isinstance(inputs, torch.Tensor):
70-
return inputs
70+
return Input.from_tensor(inputs), inputs
7171

7272
elif isinstance(inputs, list):
7373
prepared_input = list()
74-
74+
torchtrt_inputs = []
75+
torch_inputs = []
7576
for input_obj in inputs:
76-
prepared_input.append(prepare_inputs(input_obj))
77+
torchtrt_input, torch_input = prepare_inputs(input_obj)
78+
torchtrt_inputs.append(torchtrt_input)
79+
torch_inputs.append(torch_input)
7780

78-
return prepared_input
81+
return torchtrt_inputs, torch_inputs
7982

8083
elif isinstance(inputs, tuple):
81-
prepared_input = list()
82-
84+
torchtrt_inputs = []
85+
torch_inputs = []
8386
for input_obj in inputs:
84-
prepared_input.append(prepare_inputs(input_obj))
87+
torchtrt_input, torch_input = prepare_inputs(input_obj)
88+
torchtrt_inputs.append(torchtrt_input)
89+
torch_inputs.append(torch_input)
8590

86-
return tuple(prepared_input)
91+
return tuple(torchtrt_inputs), tuple(torch_inputs)
8792

8893
elif isinstance(inputs, dict):
89-
prepared_input = dict()
94+
torchtrt_inputs = dict()
95+
torch_inputs = dict()
9096

9197
for key, input_obj in inputs.items():
92-
prepared_input[key] = prepare_inputs(input_obj)
98+
torchtrt_input, torch_input = prepare_inputs(input_obj)
99+
torchtrt_inputs[key] = torchtrt_input
100+
torch_inputs[key] = torch_input
93101

94-
return prepared_input
102+
return torchtrt_inputs, torch_inputs
95103

96104
else:
97105
raise ValueError(

0 commit comments

Comments
 (0)