Skip to content

Commit f8bab8f

Browse files
committed
fix: Repair argument passing in both Dynamo paths
- Pass through the new TensorRT arguments (`max_aux_streams`, `version_compatible`, `optimization_level`) in the export path
- Pass through the `pass_through_build_failures` argument in the compile path
- Remove deprecated options, including `explicit_batch_dimension` and `explicit_precision`, from the Dynamo utilities, and update references to those options in the settings
1 parent cae6b7c commit f8bab8f

File tree

5 files changed

+16
-24
lines changed

5 files changed

+16
-24
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def compile(
4545
min_block_size=MIN_BLOCK_SIZE,
4646
torch_executed_ops=[],
4747
torch_executed_modules=[],
48+
pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
4849
**kwargs,
4950
):
5051
if debug:
@@ -86,6 +87,7 @@ def compile(
8687
workspace_size=workspace_size,
8788
min_block_size=min_block_size,
8889
torch_executed_ops=torch_executed_ops,
90+
pass_through_build_failures=pass_through_build_failures,
8991
**kwargs,
9092
)
9193

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def convert_module(
3636
interpreter = TRTInterpreter(
3737
module,
3838
InputTensorSpec.from_tensors(inputs),
39-
explicit_batch_dimension=True,
4039
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
4140
output_dtypes=output_dtypes,
4241
)

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ def __init__(
3838
self,
3939
module: torch.fx.GraphModule,
4040
input_specs: List[InputTensorSpec],
41-
explicit_batch_dimension: bool = True,
42-
explicit_precision: bool = False,
4341
logger_level=None,
4442
output_dtypes=None,
4543
):
@@ -49,17 +47,11 @@ def __init__(
4947
self.builder = trt.Builder(self.logger)
5048

5149
flag = 0
52-
if explicit_batch_dimension:
53-
EXPLICIT_BATCH = 1 << (int)(
54-
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
55-
)
56-
flag |= EXPLICIT_BATCH
5750

58-
if explicit_precision:
59-
EXPLICIT_PRECISION = 1 << (int)(
60-
trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
61-
)
62-
flag |= EXPLICIT_PRECISION
51+
# It is deprecated to not use this flag
52+
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
53+
flag |= EXPLICIT_BATCH
54+
6355
self.network = self.builder.create_network(flag)
6456

6557
missing_ops = self.validate_conversion()

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def compile(
4949
cuda_graph_batch_size=-1,
5050
is_aten=False,
5151
use_experimental_fx_rt=False,
52+
max_aux_streams=None,
53+
version_compatible=False,
54+
optimization_level=None,
5255
num_avg_timing_iters=1,
5356
torch_executed_ops=[],
5457
torch_executed_modules=[],
@@ -68,14 +71,12 @@ def compile(
6871
save_timing_cache: Update timing cache with current timing cache data if set to True.
6972
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
7073
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74+
max_aux_streams: maximum number of auxiliary streams to use
75+
version_compatible: enable version compatible feature
76+
optimization_level: builder optimization level
7177
Returns:
7278
A torch.nn.Module lowered by TensorRT.
7379
"""
74-
if use_experimental_fx_rt and not explicit_batch_dimension:
75-
raise ValueError(
76-
"The experimental unified runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
77-
)
78-
7980
logger.warn(
8081
"For ir=fx_ts_compat backend only the "
8182
+ "following arguments are supported: "
@@ -123,6 +124,9 @@ def compile(
123124
cuda_graph_batch_size=cuda_graph_batch_size,
124125
is_aten=is_aten,
125126
use_experimental_rt=use_experimental_fx_rt,
127+
max_aux_streams=max_aux_streams,
128+
version_compatible=version_compatible,
129+
optimization_level=optimization_level,
126130
)
127131
lowerer = Lowerer.create(lower_setting=lower_setting)
128132
return lowerer(module, inputs)
@@ -162,8 +166,6 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
162166
interpreter = TRTInterpreter(
163167
mod,
164168
input_specs=self.lower_setting.input_specs,
165-
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
166-
explicit_precision=self.lower_setting.explicit_precision,
167169
logger_level=trt.Logger.VERBOSE
168170
if self.lower_setting.debug
169171
else trt.Logger.WARNING,
@@ -198,7 +200,7 @@ def default_split_function(
198200
model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
199201
) -> SplitResult:
200202
splitter_setting = TRTSplitterSetting()
201-
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
203+
splitter_setting.use_implicit_batch_dim = False
202204
splitter_setting.min_block_size = lower_setting.min_block_size
203205
splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
204206
splitter = TRTSplitter(model, inputs, settings=splitter_setting)

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class LowerSetting(LowerSettingBasic):
4444
Args:
4545
input_specs: Specs for inputs to engine, can either be a single size or a
4646
range defined by Min, Optimal, Max sizes.
47-
explicit_precision: Use explicit precision during lowering.
4847
workspace_size: The maximum workspace size. The maximum GPU temporary
4948
memory which the TensorRT engine can use at execution time.
5049
strict_type_constraints: Require TensorRT engine to strictly follow data type
@@ -76,8 +75,6 @@ class LowerSetting(LowerSettingBasic):
7675
"""
7776

7877
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
79-
explicit_batch_dimension: bool = True
80-
explicit_precision: bool = False
8178
workspace_size: int = 0
8279
strict_type_constraints: bool = False
8380
customized_fuse_pass: PassManager = dc.field(

0 commit comments

Comments (0)