
Commit 5cb5947

Authored by Wei

Changes done internally at Facebook (#1204)
6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park <[email protected]> Add support for generic torch ops to be used in training.
e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati <[email protected]> Test dynamic shape support for repeat interleave
c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati <[email protected]> Test dynamic shape support for reduce ops
863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.convolution
68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao <[email protected]> [fbcode][GPU][DHEN] fuse split squeeze cat as reshape
f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao <[email protected]> [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat
5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or <[email protected]> [Quant][fx] Rename convert_to_reference to convert_to_reference_fx
996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.expand
084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.to_dtype
b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati <[email protected]> Test dynamic shape support for std
a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.tile
3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati <[email protected]> Test dynamic shape support for squeeze
09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.topk
65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li <[email protected]> temporarily skip gelu tests
d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu <[email protected]> Suppress accuracy check for remove_reshape_with_batch_size_change
6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla <[email protected]> [GPULowering] Suppress accuracy check for fuse_unsqueeze_cat_sum
e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang <[email protected]> Lower xrayvideo2022 to fx2trt
433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 <[email protected]> [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2
66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok <[email protected]> [fx2ait] Minor Python cleanup in acc_ops_getitem
188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 <[email protected]> [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT`
4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei <[email protected]> [fx2trt] support sub
064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.interpolate
9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati <[email protected]> Test dynamic shape support for unary_ops
39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati <[email protected]> Test dynamic shape support for unsqueeze
2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.split
64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu <[email protected]> Group LN trt plugin
438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.avgpool
df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops masked fill
44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.pad
4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei <[email protected]> [fx2trt] torch.max dynamic shape test
bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati <[email protected]> Change the name of the test from full_reduce to dim_reduce
1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.type_as
33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.min
f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei <[email protected]> [fx2trt] plugin for grid_sample
57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269 <[email protected]> [AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK`
eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati <[email protected]> Updated the dynamic shape support for narrow op
521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati <[email protected]> Removing the comment for 4 dims dynamic shape support after analysis
e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati <[email protected]> Updated the pad test for dynamic batch for analysis
3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov <[email protected]> [trt_bc] Some improvements
dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati <[email protected]> Updated the test for as_strided op for analysis
11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang <[email protected]> [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm
932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei <[email protected]> [fx2trt] bridge the dynamic batch and fixed shape
f911463393d8a671cfee6de6d1b5ef4d4f3991a6 Shirong Wu <[email protected]> group swish LN plugin
ea65970f23dd7a468e5bc43240f2a9bfa07c9b3b Shirong Wu <[email protected]> Create backend specific lower pass
38183e4a724e5514db2be7193cf4897b59759252 Alex Beloi <[email protected]> [fx] run acc_linter.lint in acc_tracer.trace
1 parent e113b48 commit 5cb5947

File tree

4 files changed (+51 additions, -13 deletions)


py/torch_tensorrt/fx/input_tensor_spec.py

Lines changed: 1 addition & 6 deletions
@@ -7,12 +7,7 @@
 
 
 def generate_input_specs(inputs, lower_setting, additional_inputs=None):
-    # AIT lower setting doesn't have explicit_batch_dimension field and
-    # we just return None.
-    if not hasattr(lower_setting, "explicit_batch_dimension"):
-        return None
-
-    # dynamic_batch is TRT only flag. It does not exist in AIT lower setting
+    # dynamic_batch is TRT only flag.
     if (
         not lower_setting.explicit_batch_dimension
         or lower_setting.dynamic_batch is False
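The net effect of this hunk is easier to see outside diff context. Below is a minimal, self-contained sketch of the new control flow; DemoLowerSetting and the return values are hypothetical stand-ins, and only the guard condition mirrors the repo code.

# Sketch of the simplified generate_input_specs guard. DemoLowerSetting is a
# hypothetical stand-in for LowerSetting; the branch bodies are placeholders.
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class DemoLowerSetting:
    explicit_batch_dimension: bool = True
    dynamic_batch: bool = True


def generate_input_specs_sketch(
    inputs: Sequence[Any],
    lower_setting: DemoLowerSetting,
    additional_inputs: Optional[Sequence[Any]] = None,
):
    # dynamic_batch is a TRT-only flag. The hasattr() escape hatch for AIT
    # settings is gone, so callers are now expected to pass a TRT-style setting.
    if (
        not lower_setting.explicit_batch_dimension
        or lower_setting.dynamic_batch is False
    ):
        return "fixed-shape path"  # placeholder for the static-shape spec logic
    return "dynamic-batch path"    # placeholder for the dynamic-batch spec logic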

py/torch_tensorrt/fx/lower.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def __call__(
             x.half() if x is not None and x.dtype == torch.float32 else x
             for x in inputs
         )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
             inputs, additional_inputs
         )
 
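For orientation, here is a short sketch of what this call site does: cast fp32 inputs to half precision, then ask the builder for the TRT-specific pipeline. The builder argument and function name are hypothetical stand-ins, not repo code.

# Sketch of the call-site behavior; `builder` stands in for a
# LowerPassManagerBuilder instance.
import torch


def prepare_and_build(builder, inputs, additional_inputs=None, fp16=True):
    if fp16:
        # Same fp32 -> fp16 input conversion as in Lowerer.__call__ above.
        inputs = tuple(
            x.half() if x is not None and x.dtype == torch.float32 else x
            for x in inputs
        )
    # Renamed in this commit from build_lower_pipeline.
    return builder.build_trt_lower_pipeline(inputs, additional_inputs)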

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 47 additions & 3 deletions
@@ -121,7 +121,7 @@ def _split_pass(self) -> PassManager:
         )
         return PassManager.build_from_passlist(passes)
 
-    def _lower_pass(self) -> PassManager:
+    def _trt_lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
             if (
                 hasattr(self.lower_setting, "explicit_batch_dimension")
@@ -169,7 +169,51 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
         return PassManager.build_from_passlist([lower_func])
 
-    def build_lower_pipeline(
+    def _default_lower_pass(self) -> PassManager:
+        def lower_func(split_result: SplitResult) -> nn.Module:
+
+            for submod_name, submod_inputs in split_result.submodule_inputs.items():
+                submod = getattr(split_result.split_module, submod_name)
+
+                LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
+
+                # Only acc submodules will be lowered.
+                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    print("Now lowering submodule", submod_name)
+                    lowering_start_time = datetime.datetime.now()
+
+                    lowered_module = self._lower_func(
+                        submod, submod_inputs, self.lower_setting, submod_name
+                    )
+                    setattr(split_result.split_module, submod_name, lowered_module)
+                    LOWER_SPLIT_POST_OBSERVER.observe(
+                        submod_name, lowered_module, submod_inputs
+                    )
+                    print(
+                        f"Lowering submodule {submod_name} elapsed time",
+                        datetime.datetime.now() - lowering_start_time,
+                    )
+
+            return split_result.split_module
+
+        return PassManager.build_from_passlist([lower_func])
+
+    def build_trt_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+        self._input = input
+        self._additional_input = additional_input
+        passes = []
+
+        passes.append(self._const_fold_pass())
+        passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._trt_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
+
+    def build_default_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
         self._input = input
@@ -179,7 +223,7 @@ def build_lower_pipeline(
         passes.append(self._const_fold_pass())
         passes.append(self.graph_optimization_pass())
         passes.append(self._split_pass())
-        passes.append(self._lower_pass())
+        passes.append(self._default_lower_pass())
 
         pm = PassManager.build_from_passlist(passes)
         return pm
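The pattern here is a shared front end (const fold, graph optimization, split) composed with a backend-specific tail, selected by the build_trt_lower_pipeline / build_default_lower_pipeline entry points. A toy sketch of that composition follows; ToyPassManager and the pass functions are stand-ins, not the torch.fx pass-manager classes, and only the build_* analogues mirror the repo.

# Toy illustration of the pipeline split introduced in this file.
from typing import Callable, List


class ToyPassManager:
    def __init__(self, passes: List[Callable]):
        self.passes = passes

    def __call__(self, obj):
        # Apply each pass in order, threading the result through.
        for p in self.passes:
            obj = p(obj)
        return obj


def const_fold(mod):       # shared front-end stage
    return mod


def graph_optimize(mod):   # shared front-end stage
    return mod


def split(mod):            # shared front-end stage
    return mod


def trt_lower(mod):        # analogue of _trt_lower_pass
    return mod


def default_lower(mod):    # analogue of _default_lower_pass
    return mod


def build_pipeline(lower_stage: Callable) -> ToyPassManager:
    # Mirrors build_trt_lower_pipeline / build_default_lower_pipeline:
    # identical front end, backend-specific final stage.
    return ToyPassManager([const_fold, graph_optimize, split, lower_stage])


trt_pipeline = build_pipeline(trt_lower)
default_pipeline = build_pipeline(default_lower)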

py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py

Lines changed: 2 additions & 3 deletions
@@ -528,7 +528,6 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     return cat_node
 
 
-@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
 @register_acc_op_mapping(op_and_target=("call_method", "clamp"))
 @register_acc_op
@@ -1743,7 +1742,7 @@ def quantized_conv2d(
     dilation,
     groups,
     padding_mode,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.nn.quantized.functional.conv2d(
@@ -2041,7 +2040,7 @@ def quantized_batch_norm2d(
     weight,
     bias,
     eps,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.ops.quantized.batch_norm2d(
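One caveat worth noting: acc_out_ty now defaults to None in both signatures, but each body still dereferences it on its first line, so the argument remains effectively required; the default only relaxes the signature. A small stand-in sketch (not the repo op) shows the behavior:

# Stand-in demonstrating the new default; the attribute access mirrors the
# `qparams = acc_out_ty.qparams` line in both quantized ops.
def quantized_op_sketch(input, weight, *, acc_out_ty=None):
    # Defaults to None but is dereferenced immediately, so omitting it
    # raises AttributeError rather than selecting a fallback.
    qparams = acc_out_ty.qparams
    return qparams


try:
    quantized_op_sketch(None, None)  # acc_out_ty omitted
except AttributeError as err:
    print("acc_out_ty is still effectively required:", err)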
