[executorch] Optimized 2D-by-1D broadcasting in optimized op_mul #4808

Merged · 4 commits · Aug 28, 2024
110 changes: 86 additions & 24 deletions kernels/optimized/cpu/op_mul.cpp
@@ -22,45 +22,74 @@ using ScalarType = exec_aten::ScalarType;

namespace {

// NOTE: we bake the fact that ArrayRef iterators are pointers into the
// return type here because we assume that iterators are portable across
// ArrayRef copies.
const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
ArrayRef<Tensor::SizesType> arr) {
return std::find_if(
arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
}

bool sizes_match_ignoring_leading_1s(
ArrayRef<Tensor::SizesType> lhs,
ArrayRef<Tensor::SizesType> rhs) {
auto lhs_begin = lhs.begin();
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
auto lhs_end = lhs.end();
while (lhs_begin != lhs_end && *lhs_begin == 1) {
++lhs_begin;
}

auto rhs_begin = rhs.begin();
auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
auto rhs_end = rhs.end();
while (rhs_begin != rhs_end && *rhs_begin == 1) {
++rhs_begin;
}

return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
std::equal(lhs_begin, lhs_end, rhs_begin);
}

// Move to generic util as this is applicable to all binary ops
bool can_use_optimized_path(
const Tensor& a,
const Tensor& b,
const Tensor& out) {
enum class ElementwiseOptimizedPath {
kNone,
kTreatAs1d,
kBroadcast2dBy1d,
kBroadcast2dBy1dReverseArguments,
};

ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
const Tensor& lhs,
const Tensor& rhs) {
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
auto lhs_end = lhs.sizes().end();

auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
auto rhs_end = rhs.sizes().end();

const auto lhs_size = lhs_end - lhs_begin;
const auto rhs_size = rhs_end - rhs_begin;
if (lhs_size == 2 && rhs_size == 1 && lhs_begin[1] == rhs_begin[0]) {
return ElementwiseOptimizedPath::kBroadcast2dBy1d;
}

if (lhs_size == 1 && rhs_size == 2 && rhs_begin[1] == lhs_begin[0]) {
return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
}

return ElementwiseOptimizedPath::kNone;
}

ElementwiseOptimizedPath
select_optimized_path(const Tensor& a, const Tensor& b, const Tensor& out) {
ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
ScalarType out_type = out.scalar_type();

bool can_use_optimized_path = true;
can_use_optimized_path =
can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
can_use_optimized_path = can_use_optimized_path &&
(a_type != ScalarType::Half && b_type != ScalarType::Half);
can_use_optimized_path = can_use_optimized_path &&
(a.sizes().equals(b.sizes()) ||
(a.numel() == b.numel() &&
(a.numel() == out.numel() ||
sizes_match_ignoring_leading_1s(a.sizes(), b.sizes()))));
return can_use_optimized_path;
if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half) {
return ElementwiseOptimizedPath::kNone;
}
if (a.sizes().equals(b.sizes()) ||
(a.numel() == b.numel() &&
(a.numel() == out.numel() ||
sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
return ElementwiseOptimizedPath::kTreatAs1d;
}
return select_broadcast_2d_by_1d_optimized_path(a, b);
}

template <
@@ -147,7 +176,8 @@ Tensor& opt_mul_out(
return opt_mul_out(ctx, b, a, out);
}

if (can_use_optimized_path(a, b, out)) {
auto selected_optimized_path = select_optimized_path(a, b, out);
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
// Resize for dynamic shape
auto error = resize_tensor(out, a.sizes());
ET_KERNEL_CHECK_MSG(
@@ -166,6 +196,38 @@ Tensor& opt_mul_out(
b.const_data_ptr<CTYPE>(),
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d);
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
});
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
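
For context on the dispatch change above, here is a minimal standalone sketch of how the new selection routes input shapes to paths. It is not part of the PR: the `select`/`strip1s` helpers and shapes are illustrative, the dtype checks are omitted, and the `kTreatAs1d` rule is simplified relative to the real `select_optimized_path`.

```cpp
// Illustrative sketch of the shape rules the new selection encodes (simplified).
#include <cstdio>
#include <vector>

enum class Path { kNone, kTreatAs1d, kBroadcast2dBy1d, kBroadcast2dBy1dReverse };

// Drop leading 1s, mirroring arrayref_begin_ignoring_leading_1s.
static std::vector<long> strip1s(std::vector<long> s) {
  while (!s.empty() && s.front() == 1) s.erase(s.begin());
  return s;
}

static Path select(std::vector<long> a, std::vector<long> b) {
  auto sa = strip1s(a), sb = strip1s(b);
  if (sa == sb) return Path::kTreatAs1d;                 // same shape ignoring leading 1s
  if (sa.size() == 2 && sb.size() == 1 && sa[1] == sb[0])
    return Path::kBroadcast2dBy1d;                       // (M, N) * (N,)
  if (sa.size() == 1 && sb.size() == 2 && sb[1] == sa[0])
    return Path::kBroadcast2dBy1dReverse;                // (N,) * (M, N)
  return Path::kNone;                                    // fall back to the portable path
}

int main() {
  std::printf("%d\n", (int)select({5, 3}, {3}));     // kBroadcast2dBy1d
  std::printf("%d\n", (int)select({3}, {5, 3}));     // kBroadcast2dBy1dReverse
  std::printf("%d\n", (int)select({1, 4}, {4}));     // kTreatAs1d
  std::printf("%d\n", (int)select({4, 2}, {2, 2}));  // kNone
}
```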
35 changes: 35 additions & 0 deletions kernels/optimized/vec/functional_base.h
@@ -325,5 +325,40 @@ inline void map4(
}
}


// Map vec_fun across input_data and input_data2, where input_data is
// a two-dimensional array of size (size, size2), input_data2 is a
// one-dimensional array of size size2, and input_data2 is broadcast
// to be of size (size, size2).
template <typename scalar_t, typename Op>
inline void broadcasting_map_2d_by_1d(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
const scalar_t* input_data2,
int64_t size,
int64_t size2) {
using Vec = vec::Vectorized<scalar_t>;
for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
const scalar_t* input_data_row = input_data + outer_idx * size2;
scalar_t* output_data_row = output_data + outer_idx * size2;
int64_t inner_idx = 0;
for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
Vec data_vec = Vec::loadu(input_data_row + inner_idx);
Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx);
}
if (size2 - inner_idx > 0) {
Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
}
}
}



} // namespace vec
} // namespace executorch
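
As a reference for what the new helper computes, here is a plain scalar sketch of the same (size, size2)-by-(size2,) broadcast. This is an illustration only (the `_ref` name and the usage values are mine); the PR's version vectorizes the inner loop with `Vectorized<scalar_t>` and handles the tail with masked loads/stores.

```cpp
// Scalar reference for broadcasting_map_2d_by_1d: out[i][j] = fn(a[i][j], b[j]).
#include <cstdint>

template <typename scalar_t, typename Op>
void broadcasting_map_2d_by_1d_ref(
    const Op& fn,
    scalar_t* out,
    const scalar_t* a,   // 2D input, shape (size, size2), row-major
    const scalar_t* b,   // 1D input, length size2, broadcast across rows
    int64_t size,
    int64_t size2) {
  for (int64_t i = 0; i < size; ++i) {
    for (int64_t j = 0; j < size2; ++j) {
      out[i * size2 + j] = fn(a[i * size2 + j], b[j]);
    }
  }
}

// Usage sketch: multiply a 2x3 matrix by a length-3 row vector.
// float a[6] = {1, 2, 3, 4, 5, 6}, b[3] = {10, 20, 30}, out[6];
// broadcasting_map_2d_by_1d_ref([](float x, float y) { return x * y; },
//                               out, a, b, /*size=*/2, /*size2=*/3);
```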
29 changes: 18 additions & 11 deletions kernels/test/op_mul_test.cpp
@@ -181,19 +181,19 @@ TEST_F(OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
}

// Mismatched shape tests.
TEST_F(OpMulOutTest, MismatchedInputShapesDies) {
TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) {
if (SupportedFeatures::get()->is_aten) {
GTEST_SKIP() << "ATen currently supports mismatched shapes";
}

TensorFactory<ScalarType::Int> tf;

// Input tensors with different shapes.
Tensor a = tf.ones(/*sizes=*/{1, 2});
Tensor a = tf.ones(/*sizes=*/{4, 2});
Tensor b = tf.ones(/*sizes=*/{2, 2});

// Output tensor; matches the shape of one of the inputs.
Tensor out = tf.zeros(/*sizes=*/{4});
Tensor out = tf.zeros(/*sizes=*/{8});

// Multiplying the two mismatched tensors should cause an assertion and kill
// the test process.
@@ -204,16 +204,22 @@ TEST_F(OpMulOutTest, MismatchedInputShapesDies) {
TEST_F(OpMulOutTest, BroadcastA2BTest) {
TensorFactory<ScalarType::Int> tf_a;

// a and b of different shapes
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
Tensor b = tf_a.make({2}, /*data=*/{2, 2});
std::vector<std::vector<int32_t>> b_sizeses = {
{2},
{1, 2},
};
for (const auto& b_sizes : b_sizeses) {
// a and b of different shapes
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2});

// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2});
// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
}
}

// Broadcast tensor a's size to tensor b's size
@@ -262,6 +268,7 @@ TEST_F(OpMulOutTest, ScalarInputBroadcastTest) {

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
}

TEST_F(OpMulOutTest, MismatchedOutputShapesDies) {