Skip to content

Commit 1052e3b

Browse files
swolchokfacebook-github-bot
authored andcommitted
Unbreak optimized sub in the case where one input is a scalar and dtype is mixed or both are Half (#5894)
Summary: Pull Request resolved: #5894 I misplaced a return. Caught this during development in div, but not sub. ghstack-source-id: 246398371 exported-using-ghexport Reviewed By: kirklandsign Differential Revision: D63919553 fbshipit-source-id: 84e27efc400711ae0253e48a4dbdb4b419140925
1 parent 94289ad commit 1052e3b

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

kernels/optimized/cpu/op_sub.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ Tensor& opt_sub_out(
134134
}
135135
});
136136
});
137+
return out;
137138
}
138-
return out;
139139
}
140140

141141
auto selected_optimized_path = select_optimized_path(a, b, out);

kernels/test/op_sub_test.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,27 @@ class OpSubOutTest : public OperatorTest {
107107

108108
#undef ENUMERATE_TEST_ENTRY
109109
}
110+
111+
template <ScalarType DTYPE>
112+
void test_broadcast_rank1_scalar() {
113+
TensorFactory<DTYPE> tf;
114+
115+
Tensor a = tf.make({2, 1, 3}, {2, 3, 4, 5, 6, 7});
116+
Tensor b = tf.make({1}, {2});
117+
118+
// Destination for the broadcasting div. Follow the broadcasting rules in
119+
// https://fburl.com/n9wl4d0o
120+
Tensor out = tf.zeros({2, 1, 3});
121+
122+
op_sub_out(a, b, 1, out);
123+
124+
Tensor ret = tf.make({2, 1, 3}, {0, 1, 2, 3, 4, 5});
125+
EXPECT_TENSOR_EQ(out, ret);
126+
127+
op_sub_out(b, a, 1, out);
128+
ret = tf.make({2, 1, 3}, {0, -1, -2, -3, -4, -5});
129+
EXPECT_TENSOR_EQ(out, ret);
130+
}
110131
};
111132

112133
class OpSubScalarOutTest : public OperatorTest {
@@ -171,19 +192,8 @@ TEST_F(OpSubOutTest, BroadcastSupported2) {
171192
}
172193

173194
TEST_F(OpSubOutTest, BroadcastScalarSupported1) {
174-
TensorFactory<ScalarType::Float> tf;
175-
176-
Tensor a = tf.make({2, 1, 3}, {2, 3, 4, 5, 6, 7});
177-
Tensor b = tf.make({1}, {2});
178-
179-
// Destination for the broadcasting div. Follow the broadcasting rules in
180-
// https://fburl.com/n9wl4d0o
181-
Tensor out = tf.zeros({2, 1, 3});
182-
183-
op_sub_out(a, b, 1, out);
184-
185-
Tensor ret = tf.make({2, 1, 3}, {0, 1, 2, 3, 4, 5});
186-
EXPECT_TENSOR_EQ(out, ret);
195+
test_broadcast_rank1_scalar<ScalarType::Float>();
196+
test_broadcast_rank1_scalar<ScalarType::Half>();
187197
}
188198

189199
TEST_F(OpSubOutTest, BroadcastScalarSupported2) {

0 commit comments

Comments
 (0)