Skip to content

Commit fc74ec4

Browse files
cyyever authored and malfet committed
[2/N] Avoid copy in std::get (pytorch#141826)
Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#141826 Approved by: https://github.com/Skylion007, https://github.com/malfet Co-authored-by: Nikita Shulga <[email protected]>
1 parent b2fe1b9 commit fc74ec4

File tree

12 files changed

+57
-110
lines changed

12 files changed

+57
-110
lines changed

aten/src/ATen/functorch/BatchRulesConvolution.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -205,43 +205,41 @@ convolution_backward_input_batch_rule(
205205
const auto result = at::convolution_backward_symint(
206206
grad_output, dummy_input, weight_, std::nullopt, stride, padding,
207207
dilation, transposed, output_padding, groups, mask);
208-
const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
209-
return std::make_tuple(grad_input, 1);
208+
auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
209+
return std::make_tuple(std::move(grad_input), 1);
210210
}
211211
Tensor grad_input;
212212
if (!transposed) {
213213
// N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
214214
const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
215215
auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
216-
const auto result = at::convolution_backward_symint(
216+
grad_input = std::get<0>(at::convolution_backward_symint(
217217
grad_output, dummy_input, weight_, std::nullopt, stride, padding,
218-
dilation, transposed, output_padding, groups, mask);
219-
grad_input = std::get<0>(result); // N(GBI)
218+
dilation, transposed, output_padding, groups, mask)); // N(GBI)
220219
} else {
221220
// N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
222221
auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
223222
weight_ = reshape_dim_outof_symint(1, groups, weight_); // BGIO
224223
weight_ = weight_.transpose(0, 1); // GBIO
225224
weight_ = weight_.flatten(0, 2); // (GBI)O
226225
const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
227-
const auto result = at::convolution_backward_symint(
226+
grad_input = std::get<0>(at::convolution_backward_symint(
228227
grad_output, dummy_input, weight_, std::nullopt, stride, padding,
229-
dilation, transposed, output_padding, groups, mask);
230-
grad_input = std::get<0>(result); // N(GBI)
228+
dilation, transposed, output_padding, groups, mask)); // N(GBI)
231229
}
232230
// N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
233231
grad_input = reshape_dim_outof_symint(1, groups, grad_input);
234232
grad_input = reshape_dim_outof_symint(2, batch_size, grad_input);
235233
grad_input = grad_input.transpose(1, 2);
236234
grad_input = reshape_dim_into(2, 2, grad_input);
237-
return std::make_tuple(grad_input, 1);
235+
return std::make_tuple(std::move(grad_input), 1);
238236
} else {
239237
TORCH_INTERNAL_ASSERT(input_bdim);
240238
const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
241-
const auto result = at::convolution_backward_symint(
239+
auto result = at::convolution_backward_symint(
242240
grad_output, dummy_input, weight, std::nullopt, stride, padding,
243241
dilation, transposed, output_padding, groups, mask);
244-
return std::make_tuple(std::get<0>(result), std::nullopt);
242+
return std::make_tuple(std::move(std::get<0>(result)), std::nullopt);
245243
}
246244
}
247245
static std::tuple<Tensor, std::optional<int64_t>>
@@ -258,12 +256,12 @@ convolution_backward_weight_batch_rule(
258256
const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
259257
const auto input_ = reshape_dim_into(*input_bdim, 1, input);
260258
const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
261-
const auto result = at::convolution_backward_symint(
259+
auto result = at::convolution_backward_symint(
262260
grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
263261
dilation, transposed, output_padding, groups * batch_size, mask);
264-
auto grad_weight = std::get<1>(result);
262+
auto& grad_weight = std::get<1>(result);
265263
grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
266-
return std::make_tuple(grad_weight, 0);
264+
return std::make_tuple(std::move(grad_weight), 0);
267265
} else if (grad_output_bdim && !input_bdim) {
268266
const auto batch_size = grad_output.size(*grad_output_bdim);
269267
if (groups == 1) {
@@ -327,10 +325,10 @@ convolution_backward_weight_batch_rule(
327325
if (!transposed) {
328326
// regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
329327
const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
330-
const auto result = at::convolution_backward_symint(
328+
auto result = at::convolution_backward_symint(
331329
grad_output, input_, dummy_weight, std::nullopt, stride, padding,
332330
dilation, transposed, output_padding, groups, mask);
333-
auto grad_weight = std::get<1>(result);
331+
auto& grad_weight = std::get<1>(result);
334332
grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
335333
return std::make_tuple(grad_weight, 1);
336334
} else {
@@ -423,23 +421,23 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
423421
Tensor grad_input;
424422
if (output_mask[0]) {
425423
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
426-
auto result = convolution_backward_input_batch_rule(
424+
auto [tensor, bdim] = convolution_backward_input_batch_rule(
427425
grad_output, grad_output_bdim,
428426
input, input_bdim,
429427
weight, weight_bdim,
430428
stride, padding, dilation, transposed, output_padding, groups);
431-
grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
429+
grad_input = makeBatched(tensor, bdim, cur_level);
432430
}
433431

434432
Tensor grad_weight;
435433
if (output_mask[1]) {
436434
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
437-
const auto result = convolution_backward_weight_batch_rule(
435+
auto [tensor, bdim] = convolution_backward_weight_batch_rule(
438436
grad_output, grad_output_bdim,
439437
input, input_bdim,
440438
weight, weight_bdim,
441439
stride, padding, dilation, transposed, output_padding, groups);
442-
grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
440+
grad_weight = makeBatched(tensor, bdim, cur_level);
443441
}
444442
return std::make_tuple(grad_input, grad_weight, grad_bias);
445443

aten/src/ATen/functorch/BatchRulesModules.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,14 @@ grid_sample_backward_helper_in(
161161

162162
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
163163
grid_sample_backward_helper_out(
164-
const std::tuple<Tensor, Tensor> & bw_out,
164+
std::tuple<Tensor, Tensor> bw_out,
165165
std::optional<int64_t> grad_input_out_bdim,
166166
std::optional<int64_t> grad_grid_out_bdim,
167167
int64_t bdim_size) {
168-
auto grad_input = std::get<0>(bw_out);
169-
auto grad_grid = std::get<1>(bw_out);
168+
auto& [grad_input, grad_grid] = bw_out;
170169
grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
171170
grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
172-
auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
173-
return result;
171+
return std::make_tuple(std::move(grad_input), grad_input_out_bdim, std::move(grad_grid), grad_grid_out_bdim);
174172
}
175173

176174

@@ -185,34 +183,26 @@ grid_sample_backward_batch_rule(
185183
auto new_bw_input = grid_sample_backward_helper_in(
186184
grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
187185

188-
auto new_grad_output = std::get<0>(new_bw_input);
189-
auto new_input = std::get<1>(new_bw_input);
190-
auto new_grid = std::get<2>(new_bw_input);
191-
int64_t batch_size = std::get<3>(new_bw_input);
186+
auto [new_grad_output, new_input, new_grid, batch_size] = new_bw_input;
192187

193-
auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
188+
auto bw_out = Func(std::move(new_grad_output), std::move(new_input), std::move(new_grid), std::forward<ExtraArgs>(extra_args)...);
194189

195-
return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
190+
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, batch_size);
196191
}
197192

198193
template<typename F, F Func>
199194
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
200-
cudnn_grid_sample_backward_batch_rule(
195+
static cudnn_grid_sample_backward_batch_rule(
201196
const Tensor& input, std::optional<int64_t> input_bdim,
202197
const Tensor& grid, std::optional<int64_t> grid_bdim,
203198
const Tensor& grad_output, std::optional<int64_t> grad_output_bdim) {
204199

205-
auto new_bw_input = grid_sample_backward_helper_in(
200+
auto [new_grad_output,new_input,new_grid,bdim_size]= grid_sample_backward_helper_in(
206201
grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
207202

208-
auto new_grad_output = std::get<0>(new_bw_input);
209-
auto new_input = std::get<1>(new_bw_input);
210-
auto new_grid = std::get<2>(new_bw_input);
211-
int64_t bdim_size = std::get<3>(new_bw_input);
212-
213-
auto bw_out = Func(new_input, new_grid, new_grad_output);
203+
auto bw_out = Func(std::move(new_input), std::move(new_grid), std::move(new_grad_output));
214204

215-
return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
205+
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
216206
}
217207

218208
// TODO: replace with targetable functionalization

aten/src/ATen/functorch/BatchRulesScatterOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ static std::vector<std::optional<Tensor>> batchIndices(
8888
bool indices_batched = any_has_value(indices_bdims);
8989

9090
for (size_t i = 0; i < indices.size(); i++) {
91-
auto index = indices[i];
91+
auto const & index = indices[i];
9292
if (index.has_value() && index->sym_numel() != 0) {
9393
const auto idx_bdim = indices_bdims[i];
9494
indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));

aten/src/ATen/native/Math.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,15 +1476,11 @@ calc_i0(T _x) {
14761476
T x = std::abs(_x);
14771477

14781478
if (x <= T{8.0}) {
1479-
auto coeff_pair = chebyshev_coefficients_i0e_A<T>();
1480-
auto A = std::get<0>(coeff_pair);
1481-
auto len = std::get<1>(coeff_pair);
1479+
auto [A, len] = chebyshev_coefficients_i0e_A<T>();
14821480
T y = (x / T{2.0}) - T{2.0};
14831481
return static_cast<T>(std::exp(x) * chbevl(y, A, len));
14841482
}
1485-
auto coeff_pair = chebyshev_coefficients_i0e_B<T>();
1486-
auto B = std::get<0>(coeff_pair);
1487-
auto len = std::get<1>(coeff_pair);
1483+
auto [B, len] = chebyshev_coefficients_i0e_B<T>();
14881484
return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
14891485
}
14901486

@@ -1507,16 +1503,12 @@ calc_i1(T _x) {
15071503
T x = std::abs(_x);
15081504

15091505
if (x <= T{8.0}) {
1510-
auto coeff_pair = chebyshev_coefficients_i1e_A<T>();
1511-
auto A = std::get<0>(coeff_pair);
1512-
auto len = std::get<1>(coeff_pair);
1506+
auto [A, len] = chebyshev_coefficients_i1e_A<T>();
15131507
T y = (x / T{2.0}) - T{2.0};
15141508
const T out = std::exp(x) * x * chbevl(y, A, len);
15151509
return (_x < T{0.0}) ? -out : out;
15161510
}
1517-
auto coeff_pair = chebyshev_coefficients_i1e_B<T>();
1518-
auto B = std::get<0>(coeff_pair);
1519-
auto len = std::get<1>(coeff_pair);
1511+
auto [B, len] = chebyshev_coefficients_i1e_B<T>();
15201512
const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x);
15211513
return (_x < T{0.0}) ? -out : out;
15221514
}
@@ -1541,16 +1533,12 @@ calc_i1e(T _x) {
15411533
T x = std::abs(_x);
15421534

15431535
if (x <= T{8.0}) {
1544-
auto coeff_pair = chebyshev_coefficients_i1e_A<T>();
1545-
auto A = std::get<0>(coeff_pair);
1546-
auto len = std::get<1>(coeff_pair);
1536+
auto [A, len] = chebyshev_coefficients_i1e_A<T>();
15471537
T y = (x / T{2.0}) - T{2.0};
15481538
const T out = chbevl(y, A, len) * x;
15491539
return (_x < T{0.0}) ? -out : out;
15501540
}
1551-
auto coeff_pair = chebyshev_coefficients_i1e_B<T>();
1552-
auto B = std::get<0>(coeff_pair);
1553-
auto len = std::get<1>(coeff_pair);
1541+
auto [B, len] = chebyshev_coefficients_i1e_B<T>();
15541542
const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
15551543
return (_x < T{0.0}) ? -out : out;
15561544
}

aten/src/ATen/native/cuda/Math.cuh

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3210,16 +3210,12 @@ static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
32103210
scalar_t x = ::abs(_x);
32113211

32123212
if (x <= scalar_t{8.0}) {
3213-
auto coeff_pair = chebyshev_coefficients_i0e_A<scalar_t>();
3214-
auto A = std::get<0>(coeff_pair);
3215-
auto len = std::get<1>(coeff_pair);
3213+
auto [A, len] = chebyshev_coefficients_i0e_A<scalar_t>();
32163214
scalar_t y = (x / scalar_t{2.0}) - scalar_t{2.0};
32173215
return (::exp(x) * chbevl(y, A, len));
32183216
}
32193217

3220-
auto coeff_pair = chebyshev_coefficients_i0e_B<scalar_t>();
3221-
auto B = std::get<0>(coeff_pair);
3222-
auto len = std::get<1>(coeff_pair);
3218+
auto [B, len] = chebyshev_coefficients_i0e_B<scalar_t>();
32233219
return (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x));
32243220
}
32253221

@@ -3334,17 +3330,13 @@ template <typename scalar_t>
33343330
static inline C10_HOST_DEVICE scalar_t calc_i1(scalar_t _x) {
33353331
const auto x = ::abs(_x);
33363332
if (x <= scalar_t{8.0}) {
3337-
auto coeff_pair = chebyshev_coefficients_i1e_A<scalar_t>();
3338-
auto A = std::get<0>(coeff_pair);
3339-
auto len = std::get<1>(coeff_pair);
3333+
auto [A, len] = chebyshev_coefficients_i1e_A<scalar_t>();
33403334
scalar_t y = x / scalar_t{2.0} - scalar_t{2.0};
33413335
const scalar_t out = ::exp(x) * x * chbevl(y, A, len);
33423336
return (_x < scalar_t{0.0}) ? -out : out;
33433337
}
33443338

3345-
auto coeff_pair = chebyshev_coefficients_i1e_B<scalar_t>();
3346-
auto B = std::get<0>(coeff_pair);
3347-
auto len = std::get<1>(coeff_pair);
3339+
auto [B, len] = chebyshev_coefficients_i1e_B<scalar_t>();
33483340
const scalar_t out = (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len)) / ::sqrt(x);
33493341
return (_x < scalar_t{0.0}) ? -out : out;
33503342
}
@@ -3353,17 +3345,13 @@ template <typename scalar_t>
33533345
static inline C10_HOST_DEVICE scalar_t calc_i1e(scalar_t _x) {
33543346
const auto x = ::abs(_x);
33553347
if (x <= scalar_t{8.0}) {
3356-
auto coeff_pair = chebyshev_coefficients_i1e_A<scalar_t>();
3357-
auto A = std::get<0>(coeff_pair);
3358-
auto len = std::get<1>(coeff_pair);
3348+
auto [A, len] = chebyshev_coefficients_i1e_A<scalar_t>();
33593349
const scalar_t y = x / scalar_t{2.0} - scalar_t{2.0};
33603350
const scalar_t out = chbevl(y, A, len) * x;
33613351
return (_x < scalar_t{0.0}) ? -out : out;
33623352
}
33633353

3364-
auto coeff_pair = chebyshev_coefficients_i1e_B<scalar_t>();
3365-
auto B = std::get<0>(coeff_pair);
3366-
auto len = std::get<1>(coeff_pair);
3354+
auto [B, len] = chebyshev_coefficients_i1e_B<scalar_t>();
33673355
const scalar_t out = chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x);
33683356
return (_x < scalar_t{0.0}) ? -out : out;
33693357
}

aten/src/ATen/native/cuda/Shape.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -378,22 +378,15 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
378378
if (max_elements_per_tensor == 0)
379379
continue;
380380

381-
dim3 applyBlock, catGrid;
382-
383381
#ifdef USE_ROCM
384382
// always base grid size on max_elements_per_tensor
385-
{
386-
std::tuple<dim3, dim3> launchParams = getCatGridRocm<scalar_t>(
383+
auto [catGrid, applyBlock] = getCatGridRocm<scalar_t>(
387384
max_elements_per_tensor, batchCounter);
388-
catGrid = std::get<0>(launchParams);
389-
applyBlock = std::get<1>(launchParams);
390-
}
391385
#else
386+
dim3 applyBlock, catGrid;
392387
if (isContig && sizeof(scalar_t) > 2) {
393-
std::tuple<dim3, dim3> launchParams = getCatGridContig<scalar_t>(
388+
std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t>(
394389
max_elements_per_tensor, batchCounter);
395-
catGrid = std::get<0>(launchParams);
396-
applyBlock = std::get<1>(launchParams);
397390
} else {
398391
applyBlock = dim3(32 * 16);
399392
getCatGrid(batchCounter, catGrid);

aten/src/ATen/native/cudnn/RNN.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2709,8 +2709,7 @@ void lstm_cudnn(
27092709
bidirectional,
27102710
batch_first);
27112711
output = result.first;
2712-
hy = std::get<0>(result.second);
2713-
cy = std::get<1>(result.second);
2712+
std::tie(hy, cy) = result.second;
27142713
}
27152714

27162715
void lstm_packed_cudnn(
@@ -2738,8 +2737,7 @@ void lstm_packed_cudnn(
27382737
train,
27392738
bidirectional);
27402739
output = result.first;
2741-
hy = std::get<0>(result.second);
2742-
cy = std::get<1>(result.second);
2740+
std::tie(hy, cy) = result.second;
27432741
}
27442742

27452743
REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn)

aten/src/ATen/native/group_norm.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,17 @@
1919
#endif
2020

2121
#include <array>
22-
#include <functional>
2322
#include <tuple>
2423
#include <vector>
2524

2625
namespace at::native {
2726

2827
template <typename T>
29-
void check_group_norm_inputs(
28+
static void check_group_norm_inputs(
3029
const Tensor& input,
3130
const Tensor& weight,
3231
const Tensor& bias,
33-
T C,
32+
const T& C,
3433
int64_t num_groups) {
3534
TORCH_CHECK(
3635
num_groups > 0,
@@ -237,8 +236,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> math_group_norm(
237236
/*training=*/true,
238237
/*momentum=*/0,
239238
eps);
240-
at::Tensor out = std::get<0>(outputs);
241-
out = out.view(input_shape);
239+
auto out = std::get<0>(outputs).view(input_shape);
242240
std::vector<int64_t> affine_param_shape(input.dim(), 1);
243241
affine_param_shape[1] = C;
244242
if (weight.defined() && bias.defined()) {
@@ -253,6 +251,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> math_group_norm(
253251
// This follows the same behavior as the CPU and CUDA kernels.
254252
at::Tensor mean = std::get<1>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
255253
at::Tensor rstd = std::get<2>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
256-
return std::make_tuple(out, mean, rstd);
254+
return std::make_tuple(std::move(out), std::move(mean), std::move(rstd));
257255
}
258256
} // namespace at::native

aten/src/ATen/native/layer_norm.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,7 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
241241
auto outputs = at::native_batch_norm(
242242
input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
243243
/*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
244-
at::Tensor out = std::get<0>(outputs);
245-
out = out.view(input_shape);
244+
auto out = std::get<0>(outputs).view(input_shape);
246245
if (weight.defined() && bias.defined()) {
247246
out = bias.addcmul(out, weight, 1);
248247
} else if (weight.defined()) {
@@ -297,7 +296,7 @@ Tensor rms_norm_symint(
297296
c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
298297
Tensor upcasted_input = input.to(opmath_t);
299298

300-
Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val));
299+
auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
301300
Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
302301

303302
if (weight_opt.has_value()) {

0 commit comments

Comments
 (0)