@@ -119,6 +119,34 @@ Tensor& opt_mul_out(
119
119
ScalarType b_type = b.scalar_type ();
120
120
ScalarType out_type = out.scalar_type ();
121
121
122
+ if (b.numel () == 1 ) {
123
+ if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
124
+ auto error = resize_tensor (out, a.sizes ());
125
+ ET_KERNEL_CHECK_MSG (
126
+ ctx,
127
+ error == Error::Ok,
128
+ InvalidArgument,
129
+ out,
130
+ " Failed to resize output tensor." );
131
+ ET_SWITCH_REALB_TYPES (a_type, ctx, " mul.out" , CTYPE, [&]() {
132
+ ET_SWITCH_REALB_TYPES (b_type, ctx, " mul.out" , CTYPE_B, [&]() {
133
+ CTYPE_B b_val = *b.const_data_ptr <CTYPE_B>();
134
+ CTYPE b_casted = static_cast <CTYPE>(b_val);
135
+
136
+ using Vec = executorch::vec::Vectorized<CTYPE>;
137
+ executorch::vec::map<CTYPE>(
138
+ [b_casted](Vec x) { return x * Vec (b_casted); },
139
+ out.mutable_data_ptr <CTYPE>(),
140
+ a.const_data_ptr <CTYPE>(),
141
+ out.numel ());
142
+ });
143
+ });
144
+ return out;
145
+ }
146
+ } else if (a.numel () == 1 ) {
147
+ return opt_mul_out (ctx, b, a, out);
148
+ }
149
+
122
150
if (can_use_optimized_path (a, b, out)) {
123
151
// Resize for dynamic shape
124
152
auto error = resize_tensor (out, a.sizes ());
0 commit comments