Skip to content

Commit 66af8a4

Browse files
authored
Merge pull request #1997 from pytorch/dynamo_arg_issues
fix: Repair argument passing in both Dynamo paths
2 parents 6dcd1fc + f8bab8f commit 66af8a4

File tree

5 files changed

+16
-24
lines changed

5 files changed

+16
-24
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def compile(
4444
min_block_size=MIN_BLOCK_SIZE,
4545
torch_executed_ops=[],
4646
torch_executed_modules=[],
47+
pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
4748
**kwargs,
4849
):
4950
if debug:
@@ -89,6 +90,7 @@ def compile(
8990
workspace_size=workspace_size,
9091
min_block_size=min_block_size,
9192
torch_executed_ops=torch_executed_ops,
93+
pass_through_build_failures=pass_through_build_failures,
9294
**kwargs,
9395
)
9496

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def convert_module(
3636
interpreter = TRTInterpreter(
3737
module,
3838
InputTensorSpec.from_tensors(inputs),
39-
explicit_batch_dimension=True,
4039
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
4140
output_dtypes=output_dtypes,
4241
)

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ def __init__(
4343
self,
4444
module: torch.fx.GraphModule,
4545
input_specs: List[InputTensorSpec],
46-
explicit_batch_dimension: bool = True,
47-
explicit_precision: bool = False,
4846
logger_level=None,
4947
output_dtypes=None,
5048
):
@@ -54,17 +52,11 @@ def __init__(
5452
self.builder = trt.Builder(self.logger)
5553

5654
flag = 0
57-
if explicit_batch_dimension:
58-
EXPLICIT_BATCH = 1 << (int)(
59-
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
60-
)
61-
flag |= EXPLICIT_BATCH
6255

63-
if explicit_precision:
64-
EXPLICIT_PRECISION = 1 << (int)(
65-
trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
66-
)
67-
flag |= EXPLICIT_PRECISION
56+
# Omitting the EXPLICIT_BATCH flag is deprecated in TensorRT
57+
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
58+
flag |= EXPLICIT_BATCH
59+
6860
self.network = self.builder.create_network(flag)
6961

7062
missing_ops = self.validate_conversion()

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def compile(
4949
cuda_graph_batch_size=-1,
5050
is_aten=False,
5151
use_experimental_fx_rt=False,
52+
max_aux_streams=None,
53+
version_compatible=False,
54+
optimization_level=None,
5255
num_avg_timing_iters=1,
5356
torch_executed_ops=[],
5457
torch_executed_modules=[],
@@ -68,14 +71,12 @@ def compile(
6871
save_timing_cache: Update timing cache with current timing cache data if set to True.
6972
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
7073
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74+
max_aux_streams: maximum number of auxiliary streams to use
75+
version_compatible: enable version compatible feature
76+
optimization_level: builder optimization level
7177
Returns:
7278
A torch.nn.Module lowered by TensorRT.
7379
"""
74-
if use_experimental_fx_rt and not explicit_batch_dimension:
75-
raise ValueError(
76-
"The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
77-
)
78-
7980
logger.warn(
8081
"For ir=fx_ts_compat backend only the "
8182
+ "following arguments are supported: "
@@ -123,6 +124,9 @@ def compile(
123124
cuda_graph_batch_size=cuda_graph_batch_size,
124125
is_aten=is_aten,
125126
use_experimental_rt=use_experimental_fx_rt,
127+
max_aux_streams=max_aux_streams,
128+
version_compatible=version_compatible,
129+
optimization_level=optimization_level,
126130
)
127131
lowerer = Lowerer.create(lower_setting=lower_setting)
128132
return lowerer(module, inputs)
@@ -162,8 +166,6 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
162166
interpreter = TRTInterpreter(
163167
mod,
164168
input_specs=self.lower_setting.input_specs,
165-
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
166-
explicit_precision=self.lower_setting.explicit_precision,
167169
logger_level=trt.Logger.VERBOSE
168170
if self.lower_setting.debug
169171
else trt.Logger.WARNING,
@@ -198,7 +200,7 @@ def default_split_function(
198200
model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
199201
) -> SplitResult:
200202
splitter_setting = TRTSplitterSetting()
201-
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
203+
splitter_setting.use_implicit_batch_dim = False
202204
splitter_setting.min_block_size = lower_setting.min_block_size
203205
splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
204206
splitter = TRTSplitter(model, inputs, settings=splitter_setting)

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class LowerSetting(LowerSettingBasic):
4444
Args:
4545
input_specs: Specs for inputs to engine, can either be a single size or a
4646
range defined by Min, Optimal, Max sizes.
47-
explicit_precision: Use explicit precision during lowering.
4847
workspace_size: The maximum workspace size. The maximum GPU temporary
4948
memory which the TensorRT engine can use at execution time.
5049
strict_type_constraints: Require TensorRT engine to strictly follow data type
@@ -76,8 +75,6 @@ class LowerSetting(LowerSettingBasic):
7675
"""
7776

7877
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
79-
explicit_batch_dimension: bool = True
80-
explicit_precision: bool = False
8178
workspace_size: int = 0
8279
strict_type_constraints: bool = False
8380
customized_fuse_pass: PassManager = dc.field(

0 commit comments

Comments
 (0)