Skip to content

Commit ced1978

Browse files
Yinghai LuWei Wei
authored andcommitted
Remove fuse_unsqueeze_cat_sum from OSS (#11)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/11 This pass is a case specific one and it should be applied to the whole net instead of the lowering subnet because there are special cases where #cat input is 1 and it can potentially remove all the ops and make input and output the same tensor which Trt doesn't like. Reviewed By: wushirong Differential Revision: D34709569 fbshipit-source-id: 70438ba582063aa65b27f74f11d39b27897370b8
1 parent 74187d8 commit ced1978

File tree

3 files changed

+0
-77
lines changed

3 files changed

+0
-77
lines changed

fx/lower.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from .passes.fuse_pass import (
2424
fuse_permute_linear,
2525
fuse_permute_matmul,
26-
fuse_unsqueeze_cat_sum,
2726
)
2827
from .passes.remove_duplicate_output_args import (
2928
remove_duplicate_output_args,
@@ -252,7 +251,6 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
252251
if self.lower_setting.enable_fuse:
253252
mod = fuse_permute_matmul(mod)
254253
mod = fuse_permute_linear(mod)
255-
mod = fuse_unsqueeze_cat_sum(mod)
256254
FUSE_PASSES_POST_OBSERVER.observe(mod, input)
257255

258256
# Prepare algorithm selector and timing_cache for TRTInterpreter

fx/passes/fuse_pass.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -151,40 +151,6 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule):
151151
return gm
152152

153153

154-
@observable()
155-
def fuse_unsqueeze_cat_sum(gm: torch.fx.GraphModule):
156-
for node in gm.graph.nodes:
157-
if node.target != acc_ops.sum:
158-
continue
159-
prev_node = node.kwargs["input"]
160-
if prev_node.target != acc_ops.cat or prev_node.kwargs["dim"] != 0:
161-
continue
162-
cat_inputs = prev_node.kwargs["tensors"]
163-
valid_pass = True
164-
for i in cat_inputs:
165-
if i.target != acc_ops.unsqueeze or i.kwargs["dim"] != 0:
166-
valid_pass = False
167-
break
168-
169-
if not valid_pass:
170-
continue
171-
input_val = [i.kwargs["input"] for i in cat_inputs]
172-
173-
with gm.graph.inserting_before(node):
174-
left = input_val[0]
175-
for i in range(1, len(input_val)):
176-
right = input_val[i]
177-
fused_node = gm.graph.call_function(acc_ops.add, kwargs={"input": left, "other": right})
178-
left = fused_node
179-
node.replace_all_uses_with(fused_node)
180-
181-
gm.graph.eliminate_dead_code()
182-
gm.graph.lint()
183-
gm.recompile()
184-
return gm
185-
186-
187-
188154
try:
189155
# @manual=//deeplearning/trt/python:py_tensorrt
190156
import tensorrt as trt

test/passes/test_fuse_unsqueeze_cat_sum_trt.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)