Skip to content

Commit e8c2786

Browse files
committed
Fixing aten::select test
1 parent 4f742a9 commit e8c2786

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class TestSelectConverter(DispatchTestCase):
1212
("select_dim_index", 2, 1),
1313
]
1414
)
15-
def test_select(self, dim_test, index_test):
15+
def test_select(self, _, dim_test, index_test):
1616
class TestModule(torch.nn.Module):
1717
def __init__(self, dim, index):
1818
super().__init__()
@@ -25,10 +25,14 @@ def forward(self, input):
2525
TestModule(dim_test, index_test), input, expected_ops={torch.ops.aten.select}, test_explicit_precision=True,
2626
)
2727

28-
def test_select_with_dynamic_shape(self, dim_test, index_test):
28+
def test_select_with_dynamic_shape(self, _, dim_test, index_test):
2929
class TestModule(torch.nn.Module):
30-
def forward(self, input, dim, index):
31-
return torch.select(input, dim, index)
30+
def __init__(self, dim, index):
31+
super().__init__()
32+
self.dim = dim
33+
self.index = index
34+
def forward(self, input):
35+
return torch.select(input, self.dim, self.index)
3236

3337
input_spec = [
3438
InputTensorSpec(

0 commit comments

Comments
 (0)