Skip to content

Commit bf65839

Browse files
committed
chore: rebase and update in test
1 parent cc47ac8 commit bf65839

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

tests/py/dynamo/conversion/test_eq_aten.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ class eq(nn.Module):
7575
def forward(self, lhs_val, rhs_val):
7676
return torch.ops.aten.eq.Tensor(lhs_val, rhs_val)
7777

78+
class eq_operator(nn.Module):
79+
def forward(self, lhs_val, rhs_val):
80+
return lhs_val == rhs_val
81+
7882
input_specs = [
7983
Input(
8084
dtype=torch.float32,
@@ -93,6 +97,10 @@ def forward(self, lhs_val, rhs_val):
9397
eq(),
9498
input_specs,
9599
)
100+
self.run_test_with_dynamic_shape(
101+
eq_operator(),
102+
input_specs,
103+
)
96104

97105
@parameterized.expand(
98106
[
@@ -107,6 +115,10 @@ class eq(nn.Module):
107115
def forward(self, lhs_val):
108116
return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(1))
109117

118+
class eq_operator(nn.Module):
119+
def forward(self, lhs_val):
120+
return lhs_val == torch.tensor(1)
121+
110122
input_specs = [
111123
Input(
112124
dtype=torch.int32,
@@ -119,6 +131,10 @@ def forward(self, lhs_val):
119131
eq(),
120132
input_specs,
121133
)
134+
self.run_test_with_dynamic_shape(
135+
eq_operator(),
136+
input_specs,
137+
)
122138

123139
@parameterized.expand(
124140
[
@@ -133,6 +149,10 @@ class eq(nn.Module):
133149
def forward(self, lhs_val):
134150
return torch.ops.aten.eq.Scalar(lhs_val, 1.0)
135151

152+
class eq_operator(nn.Module):
153+
def forward(self, lhs_val):
154+
return lhs_val == 1.0
155+
136156
input_specs = [
137157
Input(
138158
dtype=torch.int32,
@@ -145,6 +165,10 @@ def forward(self, lhs_val):
145165
eq(),
146166
input_specs,
147167
)
168+
self.run_test_with_dynamic_shape(
169+
eq_operator(),
170+
input_specs,
171+
)
148172

149173

150174
if __name__ == "__main__":

tests/py/dynamo/conversion/test_pow_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_pow_dynamic_shape(
8484
):
8585
class pow(nn.Module):
8686
def forward(self, lhs_val, rhs_val):
87-
return torch.ops.aten.floor_divide.default(lhs_val, rhs_val)
87+
return torch.ops.aten.pow.Tensor_Tensor(lhs_val, rhs_val)
8888

8989
class pow_scalar(nn.Module):
9090
def forward(self, lhs_val, rhs_val):

0 commit comments

Comments
 (0)