// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <executorch/core/Assert.h>
#include <executorch/kernels/kernel_includes.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

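// Computes out = a + alpha * b. Takes a vectorized fast path when both
// tensors share a dtype and shape with the output; otherwise falls back to
// type promotion with broadcasting.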
Tensor& opt_add_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

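  // Fast path: identical dtypes and shapes allow a single vectorized
  // traversal over the contiguous data.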
  if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes())) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "add", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_EXTRACT_SCALAR(alpha, alpha_val);

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else {
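    // Slow path: promote the input dtypes to a common type, broadcast the
    // shapes, and apply the addition element by element.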
    ScalarType common_type = promoteTypes(a_type, b_type);
    ET_CHECK(canCast(common_type, out_type));

    Error error = resize_to_broadcast_target_size(a, b, out);
    ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE_A, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
        ET_SWITCH_REAL_TYPES_AND(
            Bool, common_type, ctx, "add", CTYPE_IN, [&]() {
              ET_SWITCH_REAL_TYPES_AND(
                  Bool, out_type, ctx, "add", CTYPE_OUT, [&]() {
                    CTYPE_IN alpha_val;
                    ET_EXTRACT_SCALAR(alpha, alpha_val);

                    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
                          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                          CTYPE_IN value = a_casted + alpha_val * b_casted;

                          return static_cast<CTYPE_OUT>(value);
                        },
                        a,
                        b,
                        out);
                  });
            });
      });
    });
  }

  return out;
}

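// Computes out = a + alpha * b for a Scalar b. The output is resized to
// match a's shape before the computation.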
Tensor& opt_add_scalar_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

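  // Fast path: no type promotion needed, so alpha * b can be folded into a
  // single vectorized map over the tensor data.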
  if (a_type == common_type && a_type == out_type) {
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);
        CTYPE alpha_val;
        ET_EXTRACT_SCALAR(alpha, alpha_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [alpha_val, b_casted](Vec x) {
              return x + Vec(alpha_val * b_casted);
            },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
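    // Slow path: compute in the promoted type and cast each result to the
    // output dtype.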
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE_A, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
        ET_SWITCH_REAL_TYPES_AND(
            Bool, common_type, ctx, "add", CTYPE_IN, [&]() {
              ET_SWITCH_REAL_TYPES_AND(
                  Bool, out_type, ctx, "add", CTYPE_OUT, [&]() {
                    CTYPE_B b_val;
                    ET_EXTRACT_SCALAR(b, b_val);
                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                    CTYPE_IN alpha_val;
                    ET_EXTRACT_SCALAR(alpha, alpha_val);

                    const size_t n = a.numel();
                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                    for (size_t i = 0; i < n; ++i) {
                      out_data[i] = static_cast<CTYPE_OUT>(
                          static_cast<CTYPE_IN>(a_data[i]) +
                          alpha_val * b_casted);
                    }
                  });
            });
      });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch