@@ -106,6 +106,15 @@ template <
106
106
typename CTYPE_OUT>
107
107
struct MulInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
108
108
: public ReportCanCastBug {};
109
+
110
+ Scalar tensor_to_scalar (RuntimeContext& ctx, const Tensor& t) {
111
+ ET_DCHECK (t.numel () == 1 );
112
+ Scalar result;
113
+ ET_SWITCH_REALHB_TYPES (t.scalar_type (), ctx, " mul.out" , CTYPE, [&]() {
114
+ result = Scalar (*t.const_data_ptr <CTYPE>());
115
+ });
116
+ return result;
117
+ }
109
118
} // namespace
110
119
111
120
Tensor& opt_mul_out (
@@ -119,6 +128,35 @@ Tensor& opt_mul_out(
119
128
ScalarType b_type = b.scalar_type ();
120
129
ScalarType out_type = out.scalar_type ();
121
130
131
+ if (b.numel () == 1 ) {
132
+ if (a_type == b_type && a_type == out_type &&
133
+ a_type != ScalarType::Half) {
134
+ auto error = resize_tensor (out, a.sizes ());
135
+ ET_KERNEL_CHECK_MSG (
136
+ ctx,
137
+ error == Error::Ok,
138
+ InvalidArgument,
139
+ out,
140
+ " Failed to resize output tensor." );
141
+ ET_SWITCH_REALB_TYPES (a_type, ctx, " mul.out" , CTYPE, [&]() {
142
+ ET_SWITCH_REALB_TYPES (b_type, ctx, " mul.out" , CTYPE_B, [&]() {
143
+ CTYPE_B b_val = *b.const_data_ptr <CTYPE_B>();
144
+ CTYPE b_casted = static_cast <CTYPE>(b_val);
145
+
146
+ using Vec = executorch::vec::Vectorized<CTYPE>;
147
+ executorch::vec::map<CTYPE>(
148
+ [b_casted](Vec x) { return x * Vec (b_casted); },
149
+ out.mutable_data_ptr <CTYPE>(),
150
+ a.const_data_ptr <CTYPE>(),
151
+ out.numel ());
152
+ });
153
+ });
154
+ return out;
155
+ }
156
+ } else if (a.numel () == 1 ) {
157
+ return opt_mul_out (ctx, b, a, out);
158
+ }
159
+
122
160
if (can_use_optimized_path (a, b, out)) {
123
161
// Resize for dynamic shape
124
162
auto error = resize_tensor (out, a.sizes ());
0 commit comments