Commit eee660d

Author: Wei Wei

[fx2trt] support masked_fill, repeat (#29)

Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/29

Changes needed for the BERT_pytorch model in torchbench:
1. Add support for masked_fill and repeat.
2. Fix a nit in the embedding op: padding_idx is now ignored (it only matters for training) instead of being rejected.

Reviewed By: yinghai, wushirong

Differential Revision: D35034664

fbshipit-source-id: 2c0e66df5a17f6960d71c3bc6e9bdca9a5daf3e5

1 parent 69b9457 commit eee660d
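For context, here is a minimal eager-mode sketch of the two ops this commit adds lowering support for; the shapes and values below are illustrative and not taken from the commit:

import torch

# masked_fill: wherever the boolean mask is True, replace the input element with `value`.
x = torch.ones(2, 3)
mask = torch.tensor([[True, False, False],
                     [False, False, True]])
filled = x.masked_fill(mask, -1.0)
# tensor([[-1.,  1.,  1.],
#         [ 1.,  1., -1.]])

# repeat: tile the tensor along each dimension the given number of times.
y = torch.arange(3)        # tensor([0, 1, 2])
repeated = y.repeat(2, 2)  # shape (2, 6)
# tensor([[0, 1, 2, 0, 1, 2],
#         [0, 1, 2, 0, 1, 2]])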

File tree

4 files changed: +151 -4 lines changed

fx/converters/acc_ops_converters.py

Lines changed: 49 additions & 4 deletions
@@ -1117,15 +1117,12 @@ def acc_ops_embedding(
     embedding_tensor = kwargs["weight"]

     # unsupported parameters
-    padding_idx = kwargs["padding_idx"]
+    # ignore padding_idx since it is meaningful for training only
     max_norm = kwargs["max_norm"]
     norm_type = kwargs["norm_type"]
     scale_grad_by_freq = kwargs["scale_grad_by_freq"]
     sparse = kwargs["sparse"]

-    if padding_idx is not None:
-        raise RuntimeError(f"Currently we don't support specifying padding_idx, got {padding_idx}.")
-
     if max_norm is not None:
         raise RuntimeError(f"Currently we don't support specifying max_norm, got {max_norm}.")

@@ -1684,6 +1681,54 @@ def acc_ops_expand_tensor(
     return layer.get_output(0)


+@tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
+def acc_ops_masked_fill_tensor(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_t = kwargs["input"]
+    mask_t = kwargs["mask"]
+    value_t = kwargs["value"]
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError("We don't support masked_fill with implicit batch dimension due to select layer!")
+
+    shape = list(input_t.shape)
+    mask_shape = list(mask_t.shape)
+
+    assert type(value_t) in (float, int, torch.Tensor), f"value {value_t} is not one of (float, int, torch.Tensor)!"
+
+    if type(mask_t) != TRTTensor:
+        assert mask_t.dtype == torch.bool, "mask dtype is not bool!"
+        if mask_shape != shape:
+            mask_t = mask_t.expand(shape)
+        mask_t = mask_t.to(torch.int32)
+        mask_const = get_trt_tensor(network, mask_t, f"{name}_mask")
+        mask_layer = network.add_identity(mask_const)
+        mask_layer.set_output_type(0, trt.bool)
+        set_layer_name(mask_layer, target, f"{name}_mask")
+        mask_val = mask_layer.get_output(0)
+    else:
+        assert mask_t.dtype == trt.bool, "mask dtype is not bool!"
+        if mask_shape != shape:
+            mask_val = acc_ops_expand_tensor(network, target, None, {"input": mask_t, "sizes": shape}, name=f"{name}_expand")
+        else:
+            mask_val = mask_t
+
+    if type(value_t) is torch.Tensor:
+        value_t = value_t.cpu().numpy()
+    value_t = float(value_t)
+    value_t = torch.ones(shape) * value_t
+
+    input_val = get_trt_tensor(network, input_t, f"{name}_input")
+    value_val = get_trt_tensor(network, value_t, f"{name}_input")
+    layer = network.add_select(mask_val, value_val, input_val)
+    set_layer_name(layer, target, f"{name}_select")
+    return layer.get_output(0)
+
+
 @tensorrt_converter(acc_ops.split, no_explicit_batch_dim=True)
 def acc_ops_split(
     network: TRTNetwork,
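The new acc_ops_masked_fill_tensor converter lowers masked_fill to a TensorRT select layer: the mask is expanded to the input shape, the scalar fill value is materialized as a full tensor, and the select picks the value wherever the mask is true. A minimal eager-mode sketch of that strategy (the helper name and shapes are illustrative, not part of the commit):

import torch

def masked_fill_via_select(input_t: torch.Tensor, mask: torch.Tensor, value: float) -> torch.Tensor:
    # Expand the mask to the input shape, as the converter does.
    mask = mask.expand(input_t.shape)
    # Materialize the scalar as a full tensor, mirroring torch.ones(shape) * value.
    value_t = torch.full_like(input_t, float(value))
    # Select: where the mask is True take the fill value, otherwise the input.
    return torch.where(mask, value_t, input_t)

x = torch.zeros(2, 3)
m = torch.rand(2, 3) > 0.5
assert torch.equal(masked_fill_via_select(x, m, 5.0), x.masked_fill(m, 5.0))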
Lines changed: 66 additions & 0 deletions (new file)
@@ -0,0 +1,66 @@
+import torch
+import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
+import torch.nn as nn
+from torch.testing._internal.common_fx2trt import AccTestCase
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+class TestMaskedFill(AccTestCase):
+    @parameterized.expand(
+        [
+            ("same_dims", (2, 3), 5),
+            ("same_dims_tensor", (2, 3), torch.tensor(5)),
+            ("not_same_dims", (2, 1), 5),
+            ("not_same_dims_tensor", (2, 1), torch.tensor(5)),
+        ]
+    )
+    def test_masked_fill(self, _, input_shape, value):
+        class MaskedFill(nn.Module):
+            def __init__(self, input_shape):
+                super().__init__()
+                self.mask = torch.zeros(input_shape)
+                self.mask[0,0] = 1
+                self.mask = self.mask.to(torch.bool)
+                self.value = value
+            def forward(self, x):
+                return x.masked_fill(self.mask, self.value)
+
+        inputs = [torch.ones(*input_shape)]
+        self.run_test(
+            MaskedFill(input_shape),
+            inputs,
+            expected_ops={acc_ops.masked_fill},
+            test_implicit_batch_dim = False
+        )
+
+    @parameterized.expand(
+        [
+            ("same_dims", (2, 3), (2,3), 5),
+            ("expand_first_dims", (2, 3), (1,3), 5),
+            ("expand_second_dims", (2, 3), (2,1), 5),
+            ("expand_third_dims", (2, 3, 4), (2, 3, 1), 5),
+        ]
+    )
+    def test_masked_fill_expand(self, _, input_shape, mask_shape, value):
+        class MaskedFill(nn.Module):
+            def __init__(self, input_shape):
+                super().__init__()
+                self.value = value
+            def forward(self, x, mask_input):
+                return x.masked_fill(mask_input, self.value)
+
+        mask_input = torch.zeros(*mask_shape)
+        index = (0)*len(mask_shape)
+        mask_input[index] = 1
+        mask_input = mask_input.to(torch.bool)
+        inputs = [torch.ones(*input_shape), mask_input]
+        self.run_test(
+            MaskedFill(input_shape),
+            inputs,
+            expected_ops={acc_ops.masked_fill},
+            test_implicit_batch_dim = False
+        )
+
+if __name__ == '__main__':
+    run_tests()
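The test_masked_fill_expand cases exercise masks that are smaller than the input and must be broadcast before the select. For reference, this is the eager-mode broadcasting behavior those cases rely on (an illustrative snippet, not part of the test file):

import torch

x = torch.ones(2, 3)
mask = torch.tensor([[True], [False]])  # shape (2, 1), broadcastable to (2, 3)
out = x.masked_fill(mask, 5.0)
# The single True in row 0 applies to the whole row after broadcasting:
# tensor([[5., 5., 5.],
#         [1., 1., 1.]])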

test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
@@ -2154,5 +2154,6 @@ def test_all_acc_ops_registered(self):
                 acc_ops.rescale_quantize_per_channel,
                 acc_ops.nan_to_num,
                 acc_ops.expand,
+                acc_ops.masked_fill,
             },
         )

tracer/acc_tracer/acc_ops.py

Lines changed: 35 additions & 0 deletions
@@ -256,6 +256,29 @@ def unsqueeze(*, input, dim):
 def tile(*, input, dims):
     return torch.tile(input=input, dims=dims)

+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "repeat"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "sizes"),
+    ],
+)
+def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Map repeat to tile.
+    """
+    with node.graph.inserting_before(node):
+        inputs = node.kwargs["input"]
+        dims = node.kwargs["sizes"][0]
+        new_node = node.graph.create_node(
+            "call_function",
+            tile,
+            kwargs={"input": inputs, "dims": dims},
+            name=f"{node.name}_repeat_map",
+        )
+        new_node.meta = node.meta.copy()
+        return new_node
+

 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.stack),
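The repeat_mapper above rewrites Tensor.repeat calls into the existing tile acc op. This is safe because torch.tile matches Tensor.repeat whenever the number of repeat sizes is at least the tensor's rank, which repeat itself requires. A quick illustrative check (not part of the commit):

import torch

x = torch.arange(3)
sizes = (2, 2)
# Tensor.repeat(*sizes) and torch.tile(input, dims=sizes) produce the same result
# whenever len(sizes) >= input.dim(), which repeat requires anyway.
assert torch.equal(x.repeat(*sizes), torch.tile(x, dims=sizes))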
@@ -1648,6 +1671,18 @@ def expand(*, input, sizes):
     return input.expand(*sizes)


+@register_acc_op_mapping(
+    op_and_target=("call_method", "masked_fill"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("mask", "mask"),
+        ("value", "value"),
+    ],
+)
+@register_acc_op
+def masked_fill(*, input, mask, value):
+    return input.masked_fill(mask, value)
+
 @register_acc_op_properties(AccOpProperty.unary)
 @register_acc_op
 def slice_tensor(*, input, dim, start, stop, step):
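The masked_fill acc op above is the keyword-only normalized form that the TensorRT converter consumes. Assuming the registration decorators leave the function directly callable, as acc ops generally are, a hedged usage sketch looks like this (the import path comes from the test file above):

import torch
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
# The acc op takes keyword-only arguments and simply defers to Tensor.masked_fill.
out = acc_ops.masked_fill(input=x, mask=mask, value=1.0)
assert torch.equal(out, x.masked_fill(mask, 1.0))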
