
Commit 2e7e26d

Zhengping Zhou authored and Wei Wei committed
[fx2trt][bootcamp] Add support for torch.nn.functional.embedding (#27)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/27

Follows the instruction in T93104604 to add support for torch.nn.functional.embedding in fx2trt.

Reviewed By: frank-wei

Differential Revision: D34945232

fbshipit-source-id: c0ed25e2b7585bfb11736be7dad60ddbbb065050
1 parent 8f1233e commit 2e7e26d

File tree

4 files changed: +121 −10 lines

fx/converters/acc_ops_converters.py

Lines changed: 43 additions & 0 deletions
@@ -1101,6 +1101,49 @@ def acc_ops_fmod(
     )
     return sub_value

+
+@tensorrt_converter(acc_ops.embedding, no_implicit_batch_dim=True)
+def acc_ops_embedding(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError("The `embedding` function should be called with explicit batch dimension.")
+
+    indices_tensor = kwargs["input"]
+    embedding_tensor = kwargs["weight"]
+
+    # unsupported parameters
+    padding_idx = kwargs["padding_idx"]
+    max_norm = kwargs["max_norm"]
+    norm_type = kwargs["norm_type"]
+    scale_grad_by_freq = kwargs["scale_grad_by_freq"]
+    sparse = kwargs["sparse"]
+
+    if padding_idx is not None:
+        raise RuntimeError(f"Currently we don't support specifying padding_idx, got {padding_idx}.")
+
+    if max_norm is not None:
+        raise RuntimeError(f"Currently we don't support specifying max_norm, got {max_norm}.")
+
+    if norm_type != 2.0:
+        raise RuntimeError(f"Currently we don't support specifying norm_type, got {norm_type}.")
+
+    if scale_grad_by_freq:
+        raise RuntimeError("Currently we don't support scale gradient by word frequency.")
+
+    if sparse:
+        raise RuntimeError("Currently we don't support sparse gradient.")
+
+    # Implement embedding lookup with gather layer
+    gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
+    set_layer_name(gather_layer, target, name + "_gather")
+    return gather_layer.get_output(0)
+
+
 @tensorrt_converter(acc_ops.max_pool1d, no_explicit_batch_dim=True)
 def acc_ops_max_pool1d(
     network: TRTNetwork,
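
The converter implements the embedding lookup as a single TensorRT gather over axis 0 of the weight tensor. For reference, here is a minimal standalone sketch of the same construction using the TensorRT Python API directly; the shapes and tensor names are illustrative assumptions, not part of the commit:

import tensorrt as trt

# Build a tiny explicit-batch network that gathers rows of `weight`
# along axis 0 using `indices`, as acc_ops_embedding does above.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
weight = network.add_input("weight", trt.float32, (5, 10))   # vocab of 5, 10-dim rows
indices = network.add_input("indices", trt.int32, (2, 3))    # 2x3 lookups
gather_layer = network.add_gather(weight, indices, axis=0)   # same call as the converter
network.mark_output(gather_layer.get_output(0))              # output shape (2, 3, 10)

Explicit batch is required here for the same reason the converter raises on has_implicit_batch_dimension: with an implicit batch, axis 0 of the weight matrix would be interpreted as the batch dimension.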
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Owner(s): ["oncall: aiacc"]
+
+import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
+import torch
+from parameterized import parameterized, param
+from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
+from torch.testing._internal.common_utils import run_tests
+
+
+class TestEmbeddingConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            param(
+                test_name="1d_indices",
+                indices_tensor=torch.tensor([3, 1, 2]),
+                weights_tensor=torch.randn(5, 10),
+            ),
+            param(
+                test_name="2d_indices",
+                indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]]),
+                weights_tensor=torch.randn(5, 10),
+            ),
+            param(
+                test_name="3d_indices",
+                indices_tensor=torch.tensor([[[0, 1], [2, 3]], [[3, 4], [4, 0]]]),
+                weights_tensor=torch.randn(5, 10),
+            ),
+        ]
+    )
+    def test_embedding(
+        self,
+        test_name,
+        indices_tensor,
+        weights_tensor,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+    ):
+        class TestEmbedding(torch.nn.Module):
+            def forward(self, indices, weights):
+                return torch.nn.functional.embedding(
+                    input=indices,
+                    weight=weights,
+                    padding_idx=padding_idx,
+                    max_norm=max_norm,
+                    norm_type=norm_type,
+                    scale_grad_by_freq=scale_grad_by_freq,
+                    sparse=sparse,
+                )
+
+        self.run_test(
+            TestEmbedding(),
+            inputs=[indices_tensor.int(), weights_tensor.float()],
+            expected_ops={acc_ops.embedding},
+            test_implicit_batch_dim=False,
+            test_explicit_batch_dim=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
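
Since padding_idx, max_norm, scale_grad_by_freq, and sparse are all left at their defaults, the numerical behavior the test validates reduces to row indexing into the weight matrix. A quick plain-PyTorch sketch of that equivalence (not part of the commit):

import torch
import torch.nn.functional as F

weight = torch.randn(5, 10)                     # 5-entry vocab, 10-dim embeddings
indices = torch.tensor([[3, 1, 2], [4, 1, 3]])  # the "2d_indices" case above

# With default parameters, embedding is a gather along axis 0 of the weight.
assert torch.equal(F.embedding(indices, weight), weight[indices])
print(F.embedding(indices, weight).shape)       # torch.Size([2, 3, 10])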

test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
@@ -2063,6 +2063,7 @@ def test_all_acc_ops_registered(self):
             acc_normalizer._acc_ops,
             {
                 acc_ops.linear,
+                acc_ops.embedding,
                 acc_ops.max_pool1d,
                 acc_ops.max_pool2d,
                 acc_ops.flatten,

tracer/acc_tracer/acc_ops.py

Lines changed: 14 additions & 10 deletions
@@ -80,6 +80,14 @@ def squeeze(*, input, dim=None):
     return input.squeeze(dim=dim)


+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding))
+@register_acc_op
+def embedding(
+    *, input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
+):
+    return torch.nn.functional.embedding(**locals())
+
+
 @register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool1d))
 @register_acc_op
 def max_pool1d(
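
The **locals() call above forwards every keyword-only parameter of the acc op unchanged to torch.nn.functional.embedding, so the acc op mirrors the functional signature exactly. With the mapping registered, tracing a module that calls nn.functional.embedding should produce a graph node targeting acc_ops.embedding; a rough sketch, assuming acc_tracer's trace entry point (module path and signature not verified against this revision):

import torch
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer  # assumed module path

class Lookup(torch.nn.Module):
    def forward(self, indices, weights):
        return torch.nn.functional.embedding(indices, weights)

traced = acc_tracer.trace(Lookup(), [torch.tensor([3, 1, 2]), torch.randn(5, 10)])
# After normalization, the embedding call should resolve to the registered acc op.
assert any(
    node.op == "call_function" and node.target is acc_ops.embedding
    for node in traced.graph.nodes
)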
@@ -95,6 +103,7 @@ def max_pool1d(
         return_indices=return_indices,
     )

+
 @register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool2d))
 @register_acc_op
 def max_pool2d(

@@ -118,24 +127,17 @@ def max_pool2d(
 def adaptive_avg_pool2d(*, input, output_size):
     return nn.functional.adaptive_avg_pool2d(input=input, output_size=output_size)

+
 @register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool1d))
 @register_acc_op
-def avg_pool1d(
-    *,
-    input,
-    kernel_size,
-    stride,
-    padding,
-    ceil_mode,
-    count_include_pad
-):
+def avg_pool1d(*, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
     return nn.functional.avg_pool1d(
         input=input,
         kernel_size=kernel_size,
         stride=stride,
         padding=padding,
         ceil_mode=ceil_mode,
-        count_include_pad=count_include_pad
+        count_include_pad=count_include_pad,
     )


@@ -499,6 +501,7 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):

 try:
     from torchvision.ops import stochastic_depth
+
     assert callable(stochastic_depth)
 except Exception as e:
     warnings.warn(f"Unable to import torchvision related libraries.: {e}")

@@ -903,6 +906,7 @@ def prod(*, input, dim=None, keepdim=False, dtype=None):
     else:
         return input.prod(dtype=dtype)

+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "prod"),
     arg_replacement_tuples=[
