Commit 2eb7341

add dynamic support for embedding_bag/index_select (#3032)
1 parent e6a1932 commit 2eb7341

File tree

3 files changed: +201 -4 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 9 additions & 3 deletions
@@ -283,10 +283,14 @@ def embedding_bag_validator(node: Node) -> bool:


 @dynamo_tensorrt_converter(
-    torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator
+    torch.ops.aten.embedding_bag.default,
+    capability_validator=embedding_bag_validator,
+    supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(
-    torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator
+    torch.ops.aten._embedding_bag.default,
+    capability_validator=embedding_bag_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3379,7 +3383,9 @@ def aten_ops_roll(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.index_select.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
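
The registrations above only declare that these converters can handle symbolic input dimensions; the tests added below exercise the full flow. As a rough usage sketch (the module, shapes, and dimension bounds here are illustrative rather than taken from the commit; a CUDA device and a torch_tensorrt build with the dynamo frontend are assumed), compiling index_select with a dynamic leading dimension looks like this:

import torch
import torch_tensorrt


class IndexSelect(torch.nn.Module):
    def forward(self, source, indices):
        # dim 0 of `source` is declared dynamic at export time below
        return torch.ops.aten.index_select.default(source, 0, indices)


source = torch.randn((3, 4), dtype=torch.float32)
indices = torch.tensor([0, 2], dtype=torch.int32)

# Mark dim 0 of `source` as a symbolic dimension in the exported graph.
dyn_dim = torch.export.Dim("dyn_dim", min=2, max=8)
exported = torch.export.export(
    IndexSelect(),
    (source, indices),
    dynamic_shapes={"source": {0: dyn_dim}, "indices": {}},
)

# With supports_dynamic_shapes=True on the converter, the index_select node
# can stay in the TensorRT block instead of falling back to PyTorch.
trt_mod = torch_tensorrt.dynamo.compile(
    exported, inputs=(source, indices), min_block_size=1
)

# Run with a different size along the dynamic dimension.
print(trt_mod(torch.randn((6, 4)).cuda(), indices.cuda()))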

tests/py/dynamo/conversion/test_embedding_bag_aten.py

Lines changed: 99 additions & 0 deletions
@@ -1,6 +1,8 @@
 import torch
+import torch_tensorrt
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -408,6 +410,103 @@ def forward(self, weight, indices, offsets):
             propagate_shapes=True,
         )

+    @parameterized.expand(
+        [
+            param(
+                # 1d_indices_mode_0_with_per_sample_weights
+                # weights is for compile
+                weights=torch.randn((5, 2), dtype=torch.float32),
+                # weights_1 is for inference
+                weights_1=torch.randn((6, 2), dtype=torch.float32),
+                dynamic_shapes={
+                    "weights": {0: torch.export.Dim("dyn_dim", min=2, max=6)},
+                    "indices": {},
+                    "offsets": {},
+                },
+                indices=torch.tensor([1, 2, 4], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 3], dtype=torch.int32),
+                mode=0,
+                per_sample_weights=torch.randn((3,), dtype=torch.float32),
+            ),
+            param(
+                # 1d_indices_mode_1_without_per_sample_weights
+                # weights is for compile
+                weights=torch.randn((5, 2), dtype=torch.float32),
+                # weights_1 is for inference
+                weights_1=torch.randn((6, 3), dtype=torch.float32),
+                dynamic_shapes={
+                    "weights": {
+                        0: torch.export.Dim("dyn_dim", min=2, max=8),
+                        1: torch.export.Dim("dyn_dim_1", min=1, max=3),
+                    },
+                    "indices": {},
+                    "offsets": {},
+                },
+                indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                mode=1,
+                per_sample_weights=None,
+            ),
+        ]
+    )
+    def test_embedding_bag_with_weights_dynamic_shape(
+        self,
+        weights,
+        weights_1,
+        dynamic_shapes,
+        indices,
+        offsets,
+        mode,
+        per_sample_weights,
+    ):
+        class EmbeddingBag(torch.nn.Module):
+            def forward(self, weights, indices, offsets, per_sample_weights=None):
+                return torch.ops.aten._embedding_bag.default(
+                    weight=weights,
+                    indices=indices,
+                    offsets=offsets,
+                    per_sample_weights=per_sample_weights,
+                    scale_grad_by_freq=False,
+                    mode=mode,
+                    sparse=False,
+                    include_last_offset=False,
+                    padding_idx=-1,
+                )
+
+        if per_sample_weights is None:
+            inputs = (weights, indices, offsets)
+        else:
+            inputs = (weights, indices, offsets, per_sample_weights)
+        mod = EmbeddingBag()
+
+        if per_sample_weights is not None:
+            dynamic_shapes["per_sample_weights"] = {}
+        fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes)
+        trt_mod = torch_tensorrt.dynamo.compile(
+            fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1
+        )
+        # use the inputs with different shape to inference:
+        if per_sample_weights is None:
+            inputs = (weights_1, indices, offsets)
+        else:
+            inputs = (weights_1, indices, offsets, per_sample_weights)
+
+        with torch.no_grad():
+            cuda_inputs = []
+            for i in inputs:
+                cuda_inputs.append(i.cuda())
+            ref_outputs = mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_inputs)
+            for out, ref in zip(outputs, ref_outputs):
+                torch.testing.assert_close(
+                    out,
+                    ref,
+                    rtol=0.001,
+                    atol=0.001,
+                    equal_nan=True,
+                    check_dtype=True,
+                )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_index_select_aten.py

Lines changed: 93 additions & 1 deletion
@@ -1,7 +1,9 @@
 import torch
 import torch.nn as nn
-from parameterized import parameterized
+import torch_tensorrt
+from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -36,6 +38,96 @@ def forward(self, source_tensor, indices_tensor):
             input,
         )

+    @parameterized.expand(
+        [
+            param(
+                # 1d_source_tensor
+                # source_tensor is for compile
+                source_tensor=torch.randn((3,), dtype=torch.float32),
+                # source_tensor_1 is for inference
+                source_tensor_1=torch.randn((5,), dtype=torch.float32),
+                dynamic_shapes={
+                    "source_tensor": {0: torch.export.Dim("dyn_dim", min=3, max=6)},
+                    "indice_tensor": {},
+                },
+                dim=0,
+                indice_tensor=torch.tensor(
+                    [
+                        1,
+                    ],
+                    dtype=torch.int32,
+                ),
+            ),
+            param(
+                # 2d_source_tensor
+                # source_tensor is for compile
+                source_tensor=torch.randn((3, 3), dtype=torch.float32),
+                # source_tensor_1 is for inference
+                source_tensor_1=torch.randn((4, 6), dtype=torch.float32),
+                dynamic_shapes={
+                    "source_tensor": {
+                        0: torch.export.Dim("dyn_dim1", min=3, max=6),
+                        1: torch.export.Dim("dyn_dim2", min=2, max=7),
+                    },
+                    "indice_tensor": {},
+                },
+                dim=-1,
+                indice_tensor=torch.tensor([0, 2], dtype=torch.int32),
+            ),
+            param(
+                # 3d_source_tensor
+                # source_tensor is for compile
+                source_tensor=torch.randn((3, 4, 2), dtype=torch.float32),
+                # source_tensor_1 is for inference
+                source_tensor_1=torch.randn((6, 7, 2), dtype=torch.float32),
+                dynamic_shapes={
+                    "source_tensor": {
+                        0: torch.export.Dim("dyn_dim1", min=3, max=6),
+                        1: torch.export.Dim("dyn_dim2", min=2, max=7),
+                    },
+                    "indice_tensor": {},
+                },
+                dim=-2,
+                indice_tensor=torch.tensor([0, 0, 2], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_index_select_dynamic_shape(
+        self, source_tensor, source_tensor_1, dynamic_shapes, dim, indice_tensor
+    ):
+        class IndexSelect(torch.nn.Module):
+            def forward(self, source_tensor, indice_tensor):
+                return torch.ops.aten.index_select.default(
+                    source_tensor,
+                    dim,
+                    indice_tensor,
+                )
+
+        inputs = (source_tensor, indice_tensor)
+        mod = IndexSelect()
+
+        fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes)
+        trt_mod = torch_tensorrt.dynamo.compile(
+            fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1
+        )
+        # use different shape of inputs for inference:
+        inputs = (source_tensor_1, indice_tensor)
+        with torch.no_grad():
+            cuda_inputs = []
+            for i in inputs:
+                cuda_inputs.append(i.cuda())
+            ref_outputs = mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_inputs)
+            for out, ref in zip(outputs, ref_outputs):
+                torch.testing.assert_close(
+                    out,
+                    ref,
+                    rtol=0.001,
+                    atol=0.001,
+                    equal_nan=True,
+                    check_dtype=True,
+                )
+

 if __name__ == "__main__":
     run_tests()
