@@ -75,6 +75,10 @@ class eq(nn.Module):
75
75
def forward (self , lhs_val , rhs_val ):
76
76
return torch .ops .aten .eq .Tensor (lhs_val , rhs_val )
77
77
78
+ class eq_operator (nn .Module ):
79
+ def forward (self , lhs_val , rhs_val ):
80
+ return lhs_val == rhs_val
81
+
78
82
input_specs = [
79
83
Input (
80
84
dtype = torch .float32 ,
@@ -93,6 +97,10 @@ def forward(self, lhs_val, rhs_val):
93
97
eq (),
94
98
input_specs ,
95
99
)
100
+ self .run_test_with_dynamic_shape (
101
+ eq_operator (),
102
+ input_specs ,
103
+ )
96
104
97
105
@parameterized .expand (
98
106
[
@@ -107,6 +115,10 @@ class eq(nn.Module):
107
115
def forward (self , lhs_val ):
108
116
return torch .ops .aten .eq .Tensor (lhs_val , torch .tensor (1 ))
109
117
118
+ class eq_operator (nn .Module ):
119
+ def forward (self , lhs_val ):
120
+ return lhs_val == torch .tensor (1 )
121
+
110
122
input_specs = [
111
123
Input (
112
124
dtype = torch .int32 ,
@@ -119,6 +131,10 @@ def forward(self, lhs_val):
119
131
eq (),
120
132
input_specs ,
121
133
)
134
+ self .run_test_with_dynamic_shape (
135
+ eq_operator (),
136
+ input_specs ,
137
+ )
122
138
123
139
@parameterized .expand (
124
140
[
@@ -133,6 +149,10 @@ class eq(nn.Module):
133
149
def forward (self , lhs_val ):
134
150
return torch .ops .aten .eq .Scalar (lhs_val , 1.0 )
135
151
152
+ class eq_operator (nn .Module ):
153
+ def forward (self , lhs_val ):
154
+ return lhs_val == 1.0
155
+
136
156
input_specs = [
137
157
Input (
138
158
dtype = torch .int32 ,
@@ -145,6 +165,10 @@ def forward(self, lhs_val):
145
165
eq (),
146
166
input_specs ,
147
167
)
168
+ self .run_test_with_dynamic_shape (
169
+ eq_operator (),
170
+ input_specs ,
171
+ )
148
172
149
173
150
174
if __name__ == "__main__" :
0 commit comments