Skip to content

Commit 7f36c70

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix & cleanup op scalar_tensor (#702)
Summary: Pull Request resolved: #702 Resize out tensor ghstack-source-id: 203341576 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D49792556 fbshipit-source-id: f2b276b0b986f9d1e0d5fc5720149454a5a3223a
1 parent c1032d7 commit 7f36c70

File tree

2 files changed

+8
-20
lines changed

2 files changed

+8
-20
lines changed

kernels/portable/cpu/op_scalar_tensor.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
11-
#include <executorch/runtime/platform/assert.h>
12-
13-
#include <cstdint>
14-
#include <cstring>
1511

1612
namespace torch {
1713
namespace executor {
@@ -20,16 +16,17 @@ namespace native {
2016
Tensor& scalar_tensor_out(RuntimeContext& ctx, const Scalar& s, Tensor& out) {
2117
(void)ctx;
2218

23-
ET_CHECK_MSG(out.numel() == 1, "Output tensor must have only one element");
19+
ET_KERNEL_CHECK(
20+
ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
2421

2522
ScalarType s_type = utils::get_scalar_dtype(s);
2623
ScalarType out_type = out.scalar_type();
2724

28-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "scalar_tensor", CTYPE, [&]() {
29-
ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, "scalar_tensor", CTYPE_S, [&]() {
25+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&]() {
26+
ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, __func__, CTYPE_S, [&]() {
3027
CTYPE_S val_s;
3128
ET_EXTRACT_SCALAR(s, val_s);
32-
out.mutable_data_ptr<CTYPE>()[0] = val_s;
29+
out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
3330
});
3431
});
3532

kernels/test/op_scalar_tensor_test.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,38 +55,29 @@ void test_scalar_tensor_out_1d(CTYPE value) {
5555
TensorFactory<DTYPE> tf;
5656

5757
std::vector<int32_t> sizes{1};
58-
Tensor expected = tf.make(sizes, /*data=*/{value});
59-
6058
Tensor out = tf.ones(sizes);
61-
op_scalar_tensor_out(value, out);
6259

63-
EXPECT_TENSOR_EQ(out, expected);
60+
ET_EXPECT_KERNEL_FAILURE(op_scalar_tensor_out(value, out));
6461
}
6562

6663
template <typename CTYPE, ScalarType DTYPE>
6764
void test_scalar_tensor_out_2d(CTYPE value) {
6865
TensorFactory<DTYPE> tf;
6966

7067
std::vector<int32_t> sizes{1, 1};
71-
Tensor expected = tf.make(sizes, /*data=*/{value});
72-
7368
Tensor out = tf.ones(sizes);
74-
op_scalar_tensor_out(value, out);
7569

76-
EXPECT_TENSOR_EQ(out, expected);
70+
ET_EXPECT_KERNEL_FAILURE(op_scalar_tensor_out(value, out));
7771
}
7872

7973
template <typename CTYPE, ScalarType DTYPE>
8074
void test_scalar_tensor_out_3d(CTYPE value) {
8175
TensorFactory<DTYPE> tf;
8276

8377
std::vector<int32_t> sizes{1, 1, 1};
84-
Tensor expected = tf.make(sizes, /*data=*/{value});
85-
8678
Tensor out = tf.ones(sizes);
87-
op_scalar_tensor_out(value, out);
8879

89-
EXPECT_TENSOR_EQ(out, expected);
80+
ET_EXPECT_KERNEL_FAILURE(op_scalar_tensor_out(value, out));
9081
}
9182

9283
#define GENERATE_TEST(ctype, dtype) \

0 commit comments

Comments
 (0)