1
1
import torch
2
2
import torch .nn as nn
3
+ from harness import DispatchTestCase
3
4
from parameterized import parameterized
4
5
from torch .testing ._internal .common_utils import run_tests
5
6
6
- from .harness import DispatchTestCase
7
-
8
7
9
8
class TestWhereConverter (DispatchTestCase ):
10
9
@parameterized .expand (
11
10
[
12
- ("2d_condition_xshape_yshape" , (2 , 2 ), (2 , 2 )),
13
- ("2d_broadcast_condition_xshape_yshape" , (2 , 2 ), (2 , 1 )),
14
- ("3d_condition_xshape_yshape" , (2 , 2 , 1 ), (2 , 2 , 1 )),
11
+ # ("2d_condition_xshape_yshape", (2, 2), (2, 2)),
12
+ # ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)),
13
+ # ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)),
15
14
("2d_3d_condition_xshape_yshape" , (2 , 2 ), (1 , 2 , 2 )),
16
15
("3d_2d_condition_xshape_yshape" , (1 , 2 , 2 ), (2 , 2 )),
17
16
]
@@ -29,52 +28,52 @@ def forward(self, condition, x, y):
29
28
(condition , inputX , inputOther ),
30
29
)
31
30
32
- def test_0D_input (self ):
33
- class Where (nn .Module ):
34
- def forward (self , condition , x , y ):
35
- return torch .ops .aten .where .self (condition , x , y )
31
+ # def test_0D_input(self):
32
+ # class Where(nn.Module):
33
+ # def forward(self, condition, x, y):
34
+ # return torch.ops.aten.where.self(condition, x, y)
36
35
37
- inputX = torch .randn ((5 , 6 , 7 , 1 , 3 ))
38
- inputOther = torch .tensor (8.0 , dtype = torch .float )
39
- condition = inputX < 0
40
- self .run_test (
41
- Where (),
42
- (condition , inputX , inputOther ),
43
- )
36
+ # inputX = torch.randn((5, 6, 7, 1, 3))
37
+ # inputOther = torch.tensor(8.0, dtype=torch.float)
38
+ # condition = inputX < 0
39
+ # self.run_test(
40
+ # Where(),
41
+ # (condition, inputX, inputOther),
42
+ # )
44
43
45
- def test_const_input (self ):
46
- class Where (nn .Module ):
47
- def __init__ (self , * args , ** kwargs ) -> None :
48
- super ().__init__ (* args , ** kwargs )
49
- self .inputY = torch .randn ((5 , 6 , 7 ))
50
- self .inputX = torch .randn ((5 , 6 , 7 ))
44
+ # def test_const_input(self):
45
+ # class Where(nn.Module):
46
+ # def __init__(self, *args, **kwargs) -> None:
47
+ # super().__init__(*args, **kwargs)
48
+ # self.inputY = torch.randn((5, 6, 7))
49
+ # self.inputX = torch.randn((5, 6, 7))
51
50
52
- def forward (self , condition ):
53
- return torch .ops .aten .where .self (condition , self .inputX , self .inputY )
51
+ # def forward(self, condition):
52
+ # return torch.ops.aten.where.self(condition, self.inputX, self.inputY)
54
53
55
- input1 = torch .randn ((5 , 6 , 7 ))
56
- condition = input1 < 0
57
- self .run_test (
58
- Where (),
59
- (condition ,),
60
- )
54
+ # input1 = torch.randn((5, 6, 7))
55
+ # condition = input1 < 0
56
+ # self.run_test(
57
+ # Where(),
58
+ # (condition,),
59
+ # )
61
60
62
- def test_const_input_with_broadcast (self ):
63
- class Where (nn .Module ):
64
- def __init__ (self , * args , ** kwargs ) -> None :
65
- super ().__init__ (* args , ** kwargs )
66
- self .inputY = torch .randn ((1 ,))
67
- self .inputX = torch .randn ((1 ,))
61
+ # def test_const_input_with_broadcast(self):
62
+ # class Where(nn.Module):
63
+ # def __init__(self, *args, **kwargs) -> None:
64
+ # super().__init__(*args, **kwargs)
65
+ # self.inputY = torch.randn((1,))
66
+ # self.inputX = torch.randn((1,))
68
67
69
- def forward (self , condition ):
70
- return torch .ops .aten .where .self (condition , self .inputX , self .inputY )
68
+ # def forward(self, condition):
69
+ # return torch.ops.aten.where.self(condition, self.inputX, self.inputY)
71
70
72
- input1 = torch .randn ((5 , 6 , 7 ))
73
- condition = input1 < 0
74
- self .run_test (
75
- Where (),
76
- (condition ,),
77
- )
71
+ # input1 = torch.randn((5, 6, 7))
72
+ # condition = input1 < 0
73
+ # self.run_test(
74
+ # Where(),
75
+ # (condition,),
76
+ # )
78
77
79
78
80
79
if __name__ == "__main__" :
0 commit comments