Skip to content

Commit cfd5e56

Browse files
ShreyanshPrajapatiWei Wei
authored andcommitted
Add dynamic shape suport for acc_ops.max_pool1d (#83)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/83 Initial draft for adding dynamic shape support for acc_ops.max_pool1d. Reviewed By: frank-wei, wushirong Differential Revision: D36617147 fbshipit-source-id: 162cf66e0c8539f74a6c009e8eee624038b50d13
1 parent 0f10a31 commit cfd5e56

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,21 +1849,21 @@ def acc_ops_embedding(
18491849
return gather_layer.get_output(0)
18501850

18511851

1852-
@tensorrt_converter(acc_ops.max_pool1d, no_explicit_batch_dim=True)
1852+
@tensorrt_converter(acc_ops.max_pool1d)
18531853
def acc_ops_max_pool1d(
18541854
network: TRTNetwork,
18551855
target: Target,
18561856
args: Tuple[Argument, ...],
18571857
kwargs: Dict[str, Argument],
18581858
name: str,
18591859
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1860-
if not network.has_implicit_batch_dimension:
1860+
input_trt = kwargs["input"]
1861+
if not isinstance(input_trt, TRTTensor):
18611862
raise RuntimeError(
1862-
"Current implementation does not support dynamic shape. Make sure that the network has an explicit batch dimension!"
1863+
f"Max_pool1d received input {input_trt} that is not part "
1864+
"of the TensorRT region!"
18631865
)
18641866

1865-
input_trt = kwargs["input"]
1866-
18671867
# adds unsqueeze layer -> max pool 2d -> squeeze layer to emulate max pool 1d.
18681868
unsqueeze_layer = network.add_shuffle(input=input_trt)
18691869
unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])

test/converters/acc_op/test_maxpool.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,32 @@ def forward(self, x):
4040
TestModule(),
4141
inputs,
4242
expected_ops={acc_ops.max_pool1d},
43-
test_explicit_batch_dim=False,
43+
)
44+
45+
def test_max_pool1d_with_dynamic_shape(
46+
self,
47+
):
48+
class TestModule(torch.nn.Module):
49+
def __init__(self):
50+
super().__init__()
51+
self.max_pool = torch.nn.MaxPool1d(1)
52+
53+
def forward(self, x):
54+
return self.max_pool(x)
55+
56+
# shape is not set to (-1, -1, -1) as reshape dimension with
57+
# more than one -1 wildcard is not allowed while adding unsqueeze layer
58+
input_specs = [
59+
InputTensorSpec(
60+
shape=(1, 1, -1),
61+
dtype=torch.float32,
62+
shape_ranges=[((1, 1, 1), (1, 1, 4), (1, 1, 4))],
63+
),
64+
]
65+
self.run_test_with_dynamic_shape(
66+
TestModule(),
67+
input_specs,
68+
expected_ops={acc_ops.max_pool1d},
4469
)
4570

4671
@parameterized.expand(

0 commit comments

Comments
 (0)