@@ -106,8 +106,7 @@ convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const
     result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
   }
   if (separate_bias) {
-    auto A = std::get<0>(result);
-    auto A_batch_dim = std::get<1>(result);
+    auto& [A, A_batch_dim] = result;
     auto B = *bias;
     auto B_batch_dim = bias_bdim;
     A = moveBatchDimToFront(A, A_batch_dim);
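Note: the structured binding `auto& [A, A_batch_dim] = result;` replaces the two `std::get` copies with references into the tuple, so the later assignments to `A` write back into `result` rather than into a detached copy. A minimal standalone sketch of the pattern (plain C++, no ATen types; the values are illustrative):

#include <string>
#include <tuple>

int main() {
  std::tuple<std::string, int> result{"out", 0};
  auto& [A, A_batch_dim] = result;  // references bound to the tuple elements
  A += "_batched";                  // writes through into result, no copy made
  return std::get<0>(result) == "out_batched" ? 0 : 1;
}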
@@ -273,12 +272,12 @@ convolution_backward_weight_batch_rule(
       const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
       const auto out_ch_dim = transposed ? 1 : 0;
       const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
-      const auto result = at::convolution_backward_symint(
+      auto result = at::convolution_backward_symint(
          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
-      auto grad_weight = std::get<1>(result);
+      auto& grad_weight = std::get<1>(result);
       grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
-      return std::make_tuple(grad_weight, out_ch_dim);
+      return std::make_tuple(std::move(grad_weight), out_ch_dim);
     } else {
       auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
       grad_output_ = reshape_dim_outof_symint(2, groups, grad_output_); // BNGO
@@ -287,23 +286,23 @@ convolution_backward_weight_batch_rule(
       if (!transposed) {
         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
         grad_weight = grad_weight.transpose(0, 1); // BGOI
         grad_weight = grad_weight.flatten(1, 2); // B(GO)I
-        return std::make_tuple(grad_weight, 0);
+        return std::make_tuple(std::move(grad_weight), 0);
       } else {
         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
-        return std::make_tuple(grad_weight, 1);
+        return std::make_tuple(std::move(grad_weight), 1);
       }
     }
   } else if (!grad_output_bdim && input_bdim) {
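For reference, the GBOI -> B(GO)I shuffle in the non-transposed branch above, replayed on a concrete tensor (a sketch assuming libtorch; the plain `unflatten`/`transpose`/`flatten` overloads stand in for the `_symint` variants, and the sizes G=2, B=3, O=4, I=5 are made up):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto grad_weight = torch::randn({2 * 3 * 4, 5});    // (GBO)I: one flat leading dim
  grad_weight = grad_weight.unflatten(0, {2, 3, 4});  // GBOI
  grad_weight = grad_weight.transpose(0, 1);          // BGOI
  grad_weight = grad_weight.flatten(1, 2);            // B(GO)I
  std::cout << grad_weight.sizes() << '\n';           // prints [3, 8, 5]
}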
@@ -314,12 +313,12 @@ convolution_backward_weight_batch_rule(
       const auto input_ = reshape_dim_into(*input_bdim, 1, input);
       const auto in_ch_dim = transposed ? 0 : 1;
       const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
-      const auto result = at::convolution_backward_symint(
+      auto result = at::convolution_backward_symint(
          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
-      auto grad_weight = std::get<1>(result);
+      auto& grad_weight = std::get<1>(result);
       grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
-      return std::make_tuple(grad_weight, in_ch_dim);
+      return std::make_tuple(std::move(grad_weight), in_ch_dim);
     } else {
       auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
       input_ = reshape_dim_outof_symint(2, groups, input_); // BNGI
@@ -337,23 +336,23 @@ convolution_backward_weight_batch_rule(
       } else {
         // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
         grad_weight = grad_weight.transpose(0, 1); // BGIO
         grad_weight = grad_weight.flatten(1, 2); // B(GI)O
-        return std::make_tuple(grad_weight, 0);
+        return std::make_tuple(std::move(grad_weight), 0);
       }
     }
   } else {
     TORCH_INTERNAL_ASSERT(weight_bdim);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
-    const auto result = at::convolution_backward_symint(
+    auto result = at::convolution_backward_symint(
        grad_output, input, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
-    return std::make_tuple(std::get<1>(result), std::nullopt);
+    return std::make_tuple(std::move(std::get<1>(result)), std::nullopt);

   }
 }
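Note: every branch of this function applies the same pattern: drop `const` from `result`, bind `grad_weight` as a reference to the tuple element, and `std::move` it into the returned tuple. Copying an `at::Tensor` is only an intrusive refcount bump, but the move avoids even that. A standalone sketch of the same ownership dance with a plain `std::vector` standing in for a Tensor (names are illustrative):

#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for at::convolution_backward_symint.
std::tuple<std::vector<float>, std::vector<float>> fake_backward() {
  return {std::vector<float>(16), std::vector<float>(64)};
}

int main() {
  auto result = fake_backward();            // non-const so the element can be moved from
  auto& grad_weight = std::get<1>(result);  // reference, not a copy
  grad_weight.resize(32);                   // in-place transforms go through the reference
  auto out = std::make_tuple(std::move(grad_weight), 1);  // steals the buffer, no copy
  return std::get<0>(out).size() == 32 ? 0 : 1;
}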
@@ -424,7 +423,7 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
   Tensor grad_input;
   if (output_mask[0]) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = convolution_backward_input_batch_rule(
+    auto result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,