Commit 31a2478

Reapply #11294 and #11295 (improve GLU test and implement using internal views to avoid copying)
Pull Request resolved: #11509

These were reverted due to internal test failures. Sending this as an exported internal diff so that we can make sure we get internal signal.

Original summary for #11294 (making the GLU test input asymmetric): this way it will produce different results along each tested dim.

Original summary for #11295: GLU requires slicing the input Tensor into two halves. Currently, we accomplish this by copying; ExecuTorch does not support views in general because it requires Tensors to be contiguous. However, nothing stops us from implementing [the ATen approach that uses views](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GatedLinearUnit.cpp#L35) entirely internally to the op. To support this, I added `support_noncontiguous_tensors` as an optional template argument to BroadcastIndexesRange and plumbed it through to the elementwise_util functions as an optional `SupportNoncontiguousTensors` parameter.

Differential Revision: [D76311585](https://our.internmc.facebook.com/intern/diff/D76311585/)

ghstack-source-id: 289429482
1 parent ef1d2ff commit 31a2478
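For readers skimming the diff below: the core of the change is to compute GLU as first_half * sigmoid(second_half) over two halves of the input that are described as views (same data, shifted pointer, same strides) instead of being materialized by copying. The following is a minimal standalone sketch of that idea on a raw contiguous buffer; it is illustrative only and does not use the ExecuTorch Tensor/TensorImpl APIs that the actual commit relies on.

```cpp
// Conceptual sketch only: GLU over a contiguous 2-D buffer of shape
// [rows, cols], split along the last dimension, without copying either half.
// The real op builds TensorImpl views and uses the elementwise utilities.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> glu_last_dim(
    const std::vector<float>& in, std::size_t rows, std::size_t cols) {
  const std::size_t half = cols / 2;
  std::vector<float> out(rows * half);
  for (std::size_t r = 0; r < rows; ++r) {
    // "Views": pointers into the same buffer, offset by 0 and by `half`.
    const float* a = in.data() + r * cols; // first half of the row
    const float* b = a + half; // second half of the row
    for (std::size_t c = 0; c < half; ++c) {
      out[r * half + c] = a[c] * (1.0f / (1.0f + std::exp(-b[c])));
    }
  }
  return out;
}
```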

File tree: 5 files changed, +220 −113 lines

kernels/portable/cpu/op_glu.cpp (77 additions, 91 deletions)

@@ -8,6 +8,7 @@
 
 #include <c10/util/irange.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 #include <cinttypes>
@@ -23,92 +24,45 @@ using ScalarType = executorch::aten::ScalarType;
 
 namespace {
 
-double exp_overload(double d) {
-  return exp(d);
-}
-
-float exp_overload(float f) {
-  return expf(f);
-}
-
-/**
- * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
- */
-// TODO: T146333648, refactor this as a common helper function
-template <typename CTYPE_OUT>
-void sigmoid_tensor(Tensor& out) {
-  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  for (const auto i : c10::irange(out.numel())) {
-    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-  }
-}
-
-/**
- * Element-wise multiplication of the first half of `in` along the specified
- * dimension and `out`, overwriting `out`.
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data =
-        input_data_base + i * dim_length_in * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
+struct SplitGLUInputTensor {
+  explicit SplitGLUInputTensor(const Tensor& self, int64_t dim);
+  using SizesArray = std::array<executorch::aten::SizesType, kTensorDimensionLimit>;
+  SizesArray half_sizes;
+  TensorImpl first_half_impl;
+  TensorImpl second_half_impl;
+  Tensor first_half;
+  Tensor second_half;
+
+ private:
+  static SizesArray get_half_sizes(const Tensor& self, int64_t dim) {
+    SizesArray half_sizes;
+    std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+    half_sizes[dim] /= 2;
+    return half_sizes;
   }
-}
-
-/**
- * Slice the tensor in the given dim, from start to end, assume tensor in and
- * out have same shape and dtype, the dim is a non-negative number and start,
- * end are valid non-negative number
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void slice_tensor(
-    const Tensor& in,
-    int64_t dim,
-    int64_t start,
-    int64_t end,
-    Tensor& out) {
-  size_t num_values = static_cast<size_t>(end - start);
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t non_negative_start = static_cast<size_t>(start);
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data = input_data_base +
-        (i * dim_length_in + non_negative_start) * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
+};
+
+SplitGLUInputTensor::SplitGLUInputTensor(const Tensor& self, int64_t dim)
+    : half_sizes(get_half_sizes(self, dim)),
+      first_half_impl(
+          self.scalar_type(),
+          self.dim(),
+          half_sizes.data(),
+          self.mutable_data_ptr(),
+          const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+          const_cast<executorch::aten::StridesType*>(self.strides().data()),
+          self.shape_dynamism()),
+      second_half_impl(
+          self.scalar_type(),
+          self.dim(),
+          half_sizes.data(),
+          reinterpret_cast<char*>(self.mutable_data_ptr()) +
+              self.strides()[dim] * self.size(dim) / 2 * self.element_size(),
+          const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+          const_cast<executorch::aten::StridesType*>(self.strides().data()),
+          self.shape_dynamism()),
+      first_half(&first_half_impl),
+      second_half(&second_half_impl) {}
 
 /**
  * Applies the gated linear unit function
@@ -120,11 +74,43 @@ void slice_tensor(
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
-  const auto self_size = self.size(dim);
-  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-  sigmoid_tensor<CTYPE_OUT>(out);
-  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+Tensor& glu_out_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    Tensor& out) {
+  ET_KERNEL_CHECK(
+      ctx,
+      self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
+      InvalidArgument,
+      out);
+  SplitGLUInputTensor split_input(self, dim);
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(self.scalar_type())
+      ? self.scalar_type()
+      : ScalarType::Float;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "glu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable? the
+          // tensors might not be contiguous; need to have
+          // apply_bitensor_elementwise_fn check that.
+          const auto one = static_cast<decltype(val_a)>(1.0);
+          return val_a * (one / (one + std::exp(-val_b)));
+        },
+        ctx,
+        split_input.first_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        split_input.second_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::internal::SupportNoncontiguousTensors());
+  });
   return out;
 }
 } // namespace
@@ -158,7 +144,7 @@ Tensor& glu_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
    });
  });
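The one non-obvious piece of arithmetic above is the data pointer for second_half_impl: the view reuses the input's sizes (with dim halved) and its original strides, and only shifts the base pointer by self.strides()[dim] * self.size(dim) / 2 * self.element_size() bytes. A small self-contained illustration with hypothetical shape values (not taken from the commit):

```cpp
// Illustration of the byte offset used for the second-half view, assuming a
// contiguous float32 tensor of shape [4, 6] split along dim = 1 (hypothetical
// example values; the formula mirrors the constructor above).
#include <cassert>
#include <cstdint>

int main() {
  const int64_t sizes[2] = {4, 6};
  const int64_t strides[2] = {6, 1}; // contiguous row-major strides
  const int64_t element_size = 4; // sizeof(float)
  const int64_t dim = 1;

  // strides[dim] * sizes[dim] / 2 * element_size
  const int64_t offset_bytes = strides[dim] * sizes[dim] / 2 * element_size;
  assert(offset_bytes == 12); // 3 floats in: the second half of row 0

  // With the same strides [6, 1] and halved sizes [4, 3], element (i, j) of
  // the second-half view addresses input element (i, j + 3), i.e. the slice
  // that the view-based ATen implementation takes.
  return 0;
}
```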

kernels/portable/cpu/util/broadcast_indexes_range.h (16 additions, 6 deletions)

@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,16 +57,20 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
-             ...)
+            !support_noncontiguous_tensors &&
+                    (sizes_match_ignoring_leading_1s(
+                         args.sizes(),
+                         output.sizes()) &&
+                     ...)
                 ? 0
                 : output.dim()),
         output_shape_(output.sizes()) {
     static_assert(
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+    if (support_noncontiguous_tensors ||
+        output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
     }
@@ -249,11 +253,17 @@
  * Unlike looping using delinearize_index() and
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
+ *
+ * The support_noncontiguous_tensors argument disables an optimization
+ * that causes the iterators not to respect strides in some
+ * cases. This optimization is normally safe because ExecuTorch
+ * tensors are contiguous.
  */
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesRange {
  public:
-  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
+  using iterator = internal::
+      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)
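Since support_noncontiguous_tensors defaults to false in both templates, existing callers of BroadcastIndexesIterator and BroadcastIndexesRange are unaffected; only code that explicitly opts in (such as the GLU views above, via elementwise_util's SupportNoncontiguousTensors tag) pays for stride-respecting iteration. A hedged usage sketch, under the assumptions stated in the comments:

```cpp
// Sketch only (not from the commit): iterating two possibly non-contiguous
// tensors against an output with the new opt-in template argument. Assumes
// BroadcastIndexesRange and Tensor are in scope from the ExecuTorch headers,
// and that dereferencing the iterator yields [out_index, in0_index, in1_index],
// as the elementwise_util helpers consume it.
template <typename CTYPE>
void mul_into(const Tensor& a, const Tensor& b, Tensor& out) {
  const CTYPE* a_data = a.const_data_ptr<CTYPE>();
  const CTYPE* b_data = b.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  // Second template argument opts in to stride-respecting iteration.
  for (const auto indexes :
       BroadcastIndexesRange<2, /*support_noncontiguous_tensors=*/true>(
           out, a, b)) {
    out_data[indexes[0]] = a_data[indexes[1]] * b_data[indexes[2]];
  }
}
```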
