Skip to content

Commit 256abbe

Browse files
SherlockNoMadWei Wei
authored andcommitted
Add support for torch.nn.functional.conv_transpose3d in fx2trt (#40)
Summary: Pull Request resolved: pytorch/fx2trt#40 add support for torch.nn.functional.conv_transpose3d Reviewed By: 842974287 Differential Revision: D35176171 fbshipit-source-id: 5b10fba4568ad9fb26203f3dd396e13a98bc3495
1 parent 8b82a98 commit 256abbe

File tree

4 files changed

+134
-10
lines changed

4 files changed

+134
-10
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def acc_ops_convnd(
172172

173173

174174
@tensorrt_converter(acc_ops.conv_transpose2d)
175-
def acc_ops_conv_transpose2d(
175+
@tensorrt_converter(acc_ops.conv_transpose3d)
176+
def acc_ops_conv_transposend(
176177
network: TRTNetwork,
177178
target: Target,
178179
args: Tuple[Argument, ...],
@@ -195,7 +196,7 @@ def acc_ops_conv_transpose2d(
195196
# right now
196197
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
197198
raise RuntimeError(
198-
f"conv {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
199+
f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
199200
)
200201
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
201202

@@ -205,10 +206,11 @@ def acc_ops_conv_transpose2d(
205206
# will need to use uninitialized weight and set it later to support
206207
# ITensor weights
207208
dummy_weight = trt.Weights()
208-
# nn.ConvTranspose2d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1)
209+
210+
# nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
209211
layer = network.add_deconvolution_nd(
210212
input=input_val,
211-
num_output_maps=weight.shape[1]*kwargs["groups"],
213+
num_output_maps=weight.shape[1] * kwargs["groups"],
212214
kernel_shape=weight.shape[2:],
213215
kernel=dummy_weight,
214216
bias=bias,
@@ -221,10 +223,10 @@ def acc_ops_conv_transpose2d(
221223
f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
222224
)
223225
weight = to_numpy(kwargs["weight"])
224-
# nn.ConvTranspose2d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1)
226+
# nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
225227
layer = network.add_deconvolution_nd(
226228
input=input_val,
227-
num_output_maps=weight.shape[1]*kwargs["groups"],
229+
num_output_maps=weight.shape[1] * kwargs["groups"],
228230
kernel_shape=weight.shape[2:],
229231
kernel=weight,
230232
bias=bias,

test/converters/acc_op/test_transpose_convolution.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_conv_transpose2d(
3232
class TestModule(torch.nn.Module):
3333
def __init__(self):
3434
super().__init__()
35-
self.conv = torch.nn.ConvTranspose2d(
35+
self.conv_transpose = torch.nn.ConvTranspose2d(
3636
in_channels=3,
3737
out_channels=6,
3838
kernel_size=kernel_size,
@@ -45,7 +45,7 @@ def __init__(self):
4545
)
4646

4747
def forward(self, x):
48-
return self.conv(x)
48+
return self.conv_transpose(x)
4949

5050
inputs = [torch.randn(1, 3, 224, 224)]
5151
self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv_transpose2d})
@@ -54,10 +54,10 @@ def test_conv_transpose2d_with_dynamic_shape(self):
5454
class TestModule(torch.nn.Module):
5555
def __init__(self):
5656
super().__init__()
57-
self.conv = torch.nn.ConvTranspose2d(3, 3, 1)
57+
self.conv_transpose = torch.nn.ConvTranspose2d(3, 3, 1)
5858

5959
def forward(self, x):
60-
return self.conv(x)
60+
return self.conv_transpose(x)
6161

6262
input_specs = [
6363
InputTensorSpec(
@@ -71,5 +71,67 @@ def forward(self, x):
7171
)
7272

7373

74+
@parameterized.expand(
75+
[
76+
("default", 1),
77+
param("no_bias", 1, bias=False),
78+
("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
79+
param("non_zero_padding", 1, padding=1),
80+
param("dilation", 1, dilation=2),
81+
param("groups", 1, groups=3),
82+
]
83+
)
84+
def test_conv_transpose3d(
85+
self,
86+
_,
87+
kernel_size,
88+
stride=1,
89+
padding=0,
90+
output_padding=0,
91+
groups=1,
92+
bias=True,
93+
dilation=1,
94+
):
95+
class TestModule(torch.nn.Module):
96+
def __init__(self):
97+
super().__init__()
98+
self.conv_transpose = torch.nn.ConvTranspose3d(
99+
in_channels=3,
100+
out_channels=6,
101+
kernel_size=kernel_size,
102+
stride=stride,
103+
padding=padding,
104+
output_padding=output_padding,
105+
groups=groups,
106+
bias=bias,
107+
dilation=dilation,
108+
)
109+
110+
def forward(self, x):
111+
return self.conv_transpose(x)
112+
113+
inputs = [torch.randn(1, 3, 32, 32, 32)]
114+
self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv_transpose3d})
115+
116+
def test_conv_transpose3d_with_dynamic_shape(self):
117+
class TestModule(torch.nn.Module):
118+
def __init__(self):
119+
super().__init__()
120+
self.conv_transpose = torch.nn.ConvTranspose3d(3, 6, 1)
121+
122+
def forward(self, x):
123+
return self.conv_transpose(x)
124+
125+
input_specs = [
126+
InputTensorSpec(
127+
shape=(-1, 3, -1, -1, -1),
128+
dtype=torch.float32,
129+
shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
130+
),
131+
]
132+
self.run_test_with_dynamic_shape(
133+
TestModule(), input_specs, expected_ops={acc_ops.conv_transpose3d}
134+
)
135+
74136
if __name__ == "__main__":
75137
run_tests()

test/tracer/test_acc_tracer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,50 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
449449

450450
self.assertTrue(torch.equal(m(input), traced(input)))
451451

452+
def test_conv_transpose3d(self):
453+
"""
454+
Test that a conv_transpose3d is traced as expected.
455+
"""
456+
457+
class TestModule(nn.Module):
458+
def __init__(self):
459+
super().__init__()
460+
self.conv = nn.ConvTranspose3d(8, 7, 3, stride=2)
461+
462+
def forward(self, a: torch.Tensor) -> torch.Tensor:
463+
return self.conv(a)
464+
465+
m = TestModule()
466+
input = torch.randn(3, 8, 8, 10, 10)
467+
traced = acc_tracer.trace(m, [input])
468+
469+
ph = weight_attr = bias_attr = conv = None
470+
for node in traced.graph.nodes:
471+
if node.op == "placeholder":
472+
self.assertEqual(str(node.target), "a")
473+
ph = node
474+
elif node.op == "get_attr" and node.target == "conv.weight":
475+
weight_attr = node
476+
elif node.op == "get_attr" and node.target == "conv.bias":
477+
bias_attr = node
478+
elif node.op == "call_function":
479+
self.assertEqual(node.target, acc_ops.conv_transpose3d)
480+
self.assertEqual(node.kwargs["input"], ph)
481+
self.assertEqual(node.kwargs["weight"], weight_attr)
482+
self.assertEqual(node.kwargs["bias"], bias_attr)
483+
self.assertEqual(node.kwargs["stride"], (2, 2, 2))
484+
self.assertEqual(node.kwargs["padding"], (0, 0, 0))
485+
self.assertEqual(node.kwargs["output_padding"], (0, 0, 0))
486+
self.assertEqual(node.kwargs["dilation"], (1, 1, 1))
487+
self.assertEqual(node.kwargs["groups"], 1)
488+
conv = node
489+
elif node.op == "output":
490+
self.assertEqual(conv, node.args[0])
491+
else:
492+
self.fail(f"Unexpected node: {node.format_node()}")
493+
494+
self.assertTrue(torch.equal(m(input), traced(input)))
495+
452496
def test_embedding_bag(self):
453497
"""
454498
Test that an embedding_bag is traced as expected.
@@ -2283,6 +2327,7 @@ def test_all_acc_ops_registered(self):
22832327
acc_ops.conv2d,
22842328
acc_ops.conv3d,
22852329
acc_ops.conv_transpose2d,
2330+
acc_ops.conv_transpose3d,
22862331
acc_ops.batch_norm,
22872332
acc_ops.embedding_bag,
22882333
acc_ops.embedding_bag_byte_rowwise_offsets,

tracer/acc_tracer/acc_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,21 @@ def conv_transpose2d(*, input, weight, bias, stride, padding, output_padding, gr
14461446
)
14471447

14481448

1449+
@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.conv_transpose3d))
1450+
@register_acc_op
1451+
def conv_transpose3d(*, input, weight, bias, stride, padding, output_padding, groups, dilation):
1452+
return nn.functional.conv_transpose3d(
1453+
input=input,
1454+
weight=weight,
1455+
bias=bias,
1456+
stride=stride,
1457+
padding=padding,
1458+
output_padding=output_padding,
1459+
groups=groups,
1460+
dilation=dilation,
1461+
)
1462+
1463+
14491464
@register_acc_op_mapping(op_and_target=("call_function", nn.functional.batch_norm))
14501465
@register_acc_op
14511466
def batch_norm(

0 commit comments

Comments
 (0)