Skip to content

Commit 150f055

Browse files
committed
adding select_scatter decomp
1 parent e7986df commit 150f055

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
aten.full,
175175
aten.repeat,
176176
aten.var_mean,
177+
aten.select_scatter,
177178
}
178179
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
179180
aten._softmax.default,

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def forward(self, x, src, dim, index):
434434
torch.ops.aten.slice.Tensor,
435435
torch.ops.aten.squeeze.dim,
436436
torch.ops.aten.cat.default,
437+
torch.ops.aten.reshape.default,
437438
}
438439
unexpected_ops = {torch.ops.aten.select_scatter.default}
439440

@@ -496,6 +497,7 @@ def forward(self, x, src, dim, index):
496497
expected_ops = {
497498
torch.ops.aten.slice.Tensor,
498499
torch.ops.aten.squeeze.dim,
500+
torch.ops.aten.unsqueeze.default,
499501
torch.ops.aten.cat.default,
500502
}
501503
unexpected_ops = {torch.ops.aten.select_scatter.default}

0 commit comments

Comments
 (0)