Skip to content

Commit fe4e70d

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix & cleanup op tril (#696)
Summary: Pull Request resolved: #696 Fix for empty tensor & resize out tensor ghstack-source-id: 203341570 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D49620875 fbshipit-source-id: 776793a63ac289b1041cd382ba717ab8e2be6bfd
1 parent 0a9947d commit fe4e70d

File tree

4 files changed

+24
-21
lines changed

4 files changed

+24
-21
lines changed

kernels/portable/cpu/op_tril.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
910
#include <executorch/runtime/kernel/kernel_includes.h>
1011
#include <cstring>
11-
#include <type_traits>
1212

1313
namespace torch {
1414
namespace executor {
@@ -130,30 +130,25 @@ Tensor& tril_out(
130130
Tensor& out) {
131131
(void)ctx;
132132

133-
// Assert `self` has at least 2 dims.
134-
ET_CHECK_MSG(self.dim() >= 2, "self.dim() %zd < 2", self.dim());
133+
ET_KERNEL_CHECK(ctx, check_tril_args(self, out), InvalidArgument, out);
135134

136-
// Assert `self` and `out` have the same tensor shape.
137-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, out);
135+
ET_KERNEL_CHECK(
136+
ctx,
137+
resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
138+
InvalidArgument,
139+
out);
140+
141+
if (self.numel() == 0) {
142+
return out;
143+
}
138144

139145
// Fill `out` with 0s prior to executing tril.
140146
clear_out(out);
141147

142-
// Create switches for all dtypes (real + bool).
143-
#define TRIL_OUT(ctype, dtype) \
144-
case ScalarType::dtype: \
145-
tril_kernel<ctype>(self, diagonal, out); \
146-
break;
147-
148-
switch (out.scalar_type()) {
149-
ET_FORALL_REAL_TYPES_AND(Bool, TRIL_OUT)
150-
default:
151-
ET_CHECK_MSG(
152-
false,
153-
"out tensor should be a real or bool dtype, but got %" PRId8,
154-
static_cast<int8_t>(out.scalar_type()));
155-
}
156-
#undef TRIL_OUT
148+
ScalarType out_type = out.scalar_type();
149+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&]() {
150+
tril_kernel<CTYPE>(self, diagonal, out);
151+
});
157152

158153
return out;
159154
}

kernels/portable/cpu/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ _ATEN_OPS = (
757757
op_target(
758758
name = "op_tril",
759759
deps = [
760-
"//executorch/runtime/core/exec_aten/util:tensor_util",
760+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
761761
],
762762
),
763763
op_target(

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,5 +355,11 @@ void get_stack_out_target_size(
355355
}
356356
}
357357

358+
bool check_tril_args(const Tensor& in, Tensor& out) {
359+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
360+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(in, 2));
361+
return true;
362+
}
363+
358364
} // namespace executor
359365
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,7 @@ void get_stack_out_target_size(
8989
Tensor::SizesType* out_sizes,
9090
size_t* out_ndim);
9191

92+
bool check_tril_args(const Tensor& in, Tensor& out);
93+
9294
} // namespace executor
9395
} // namespace torch

0 commit comments

Comments
 (0)