Skip to content

Commit 0a9947d

Browse files
manuelcandales authored and facebook-github-bot committed
Fix & cleanup op logit (#695)
Summary: Pull Request resolved: #695
ghstack-source-id: 203341583 (exported using ghexport)
Reviewed By: SS-JIA
Differential Revision: D49625518
fbshipit-source-id: 1f5922e0f3ca7d2f36f69eaa07e9bc42618ca93d
1 parent 7a838f7 commit 0a9947d

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

kernels/portable/cpu/op_logit.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ Tensor& logit_out(
2525
(void)ctx;
2626

2727
// Resize for dynamic shape
28-
auto error = resize_tensor(out, in.sizes());
29-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
30-
ET_CHECK_SAME_SHAPE2(in, out);
28+
ET_KERNEL_CHECK(
29+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
3130

32-
ET_SWITCH_REAL_TYPES_AND(Bool, in.scalar_type(), ctx, "logit", CTYPE_IN, [&] {
33-
ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "logit", CTYPE_OUT, [&] {
31+
ScalarType in_type = in.scalar_type();
32+
ScalarType out_type = out.scalar_type();
33+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&] {
34+
ET_SWITCH_FLOAT_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&] {
3435
apply_unary_map_fn(
3536
[eps](const CTYPE_IN val_in) {
3637
CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
@@ -41,7 +42,6 @@ Tensor& logit_out(
4142
xi = 1 - eps.value();
4243
}
4344
}
44-
ET_CHECK_MSG(xi > 0.0 && xi < 1.0, "input must be in (0, 1).");
4545
return static_cast<CTYPE_OUT>(
4646
log(xi / (static_cast<CTYPE_OUT>(1.0) - xi)));
4747
},

kernels/test/op_logit_test.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ void test_integer_logit_out() {
3737
// Destination for the logit operator.
3838
Tensor out = tf_out.zeros(sizes);
3939

40-
ET_EXPECT_KERNEL_FAILURE(
41-
op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0, out));
40+
op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0, out);
41+
EXPECT_TENSOR_CLOSE(
42+
out,
43+
tf_out.make(sizes, /*data=*/{INFINITY, INFINITY, INFINITY, INFINITY}));
4244
}
4345

4446
template <>
@@ -79,19 +81,13 @@ void test_integer_logit_out_eps_set() {
7981
}
8082

8183
TEST(OpLogitOutKernelTest, AllRealInputFloatOutputSupport) {
82-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
83-
GTEST_SKIP() << "ATen kernel can handle this";
84-
}
8584
#define TEST_ENTRY(ctype, dtype) \
8685
test_integer_logit_out<ScalarType::dtype, ScalarType::Float>();
8786
ET_FORALL_REAL_TYPES(TEST_ENTRY);
8887
#undef TEST_ENTRY
8988
}
9089

9190
TEST(OpLogitOutKernelTest, AllRealInputDoubleOutputSupport) {
92-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
93-
GTEST_SKIP() << "ATen kernel can handle this";
94-
}
9591
#define TEST_ENTRY(ctype, dtype) \
9692
test_integer_logit_out<ScalarType::dtype, ScalarType::Double>();
9793
ET_FORALL_REAL_TYPES(TEST_ENTRY);

0 commit comments

Comments (0)