Commit 96be048

cyyever authored and pytorchmergebot committed
[1/N] Avoid copy in std::get (pytorch#141812)
Fixes #ISSUE_NUMBER
Pull Request resolved: pytorch#141812
Approved by: https://github.com/Skylion007
1 parent c2fa544 commit 96be048

File tree

18 files changed: +138 -202 lines changed

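The theme of the patch: calling `std::get<I>` on a named tuple yields a reference, but storing the result in a plain `auto` variable copies the element. The hunks below replace such copies with structured bindings, reference bindings (`auto&`), or `std::move` of the element. A minimal standalone sketch of the three variants follows; the `Noisy` type is a hypothetical illustration, not code from this PR.

#include <iostream>
#include <tuple>
#include <utility>

// Hypothetical type that reports copies and moves, for illustration only.
struct Noisy {
  Noisy() = default;
  Noisy(const Noisy&) { std::cout << "copy\n"; }
  Noisy(Noisy&&) noexcept { std::cout << "move\n"; }
};

int main() {
  std::tuple<Noisy, int> t;

  // Before: `auto` deduces a value type, so the reference returned by
  // std::get is used to copy-construct `copied`.
  auto copied = std::get<0>(t);            // prints "copy"

  // After, option 1: bind a reference; no new Noisy is created.
  auto& ref = std::get<0>(t);              // prints nothing

  // After, option 2: structured bindings name the elements in place.
  auto& [n, i] = t;                        // prints nothing

  // After, option 3: move the element out once the tuple is no longer needed.
  auto moved = std::move(std::get<0>(t));  // prints "move"

  (void)copied; (void)ref; (void)n; (void)i; (void)moved;
  return 0;
}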

aten/src/ATen/functorch/BatchRulesBinaryOps.cpp

Lines changed: 11 additions & 16 deletions
@@ -14,18 +14,15 @@
 namespace at::functorch {
 
 template <typename F, F Func, typename... ExtraArgs>
-std::tuple<Tensor, std::optional<int64_t>> _binary_pointwise_batch_rule(
+static Tensor _binary_pointwise_batch_rule(
     const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
     const Tensor& other, std::optional<int64_t> other_batch_dim,
     ExtraArgs... extra_args) {
 
-  auto tensor_other = _binary_pointwise_helper(
+  auto [tensor_, other_] = _binary_pointwise_helper(
       tensor, tensor_batch_dim, other, other_batch_dim);
-  auto tensor_ = std::get<0>(tensor_other);
-  auto other_ = std::get<1>(tensor_other);
 
-  auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
-  return std::make_tuple(result, 0);
+  return Func(tensor_, std::move(other_), std::forward<ExtraArgs>(extra_args)...);
 }
 
 template <typename A, A a, typename C>
@@ -37,9 +34,9 @@ struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
       const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
       const Tensor& other, std::optional<int64_t> other_batch_dim,
       T... extra_args) {
-    return _binary_pointwise_batch_rule<F, Func, T...>(
+    return std::tuple(_binary_pointwise_batch_rule<F, Func, T...>(
         tensor, tensor_batch_dim, other, other_batch_dim,
-        std::forward<T>(extra_args)...);
+        std::forward<T>(extra_args)...), 0);
   }
 };
 
@@ -82,7 +79,7 @@ struct BinaryRandomPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
     auto res = _binary_pointwise_batch_rule<F, Func, T...>(
         tensor_value, tensor_bdim, other_value, other_bdim,
         std::forward<T>(extra_args)...);
-    return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
+    return makeBatched(std::move(res), 0, cur_level);
   }
 };
 
@@ -93,7 +90,7 @@
     c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
 
 template <typename M, M Meth, typename... ExtraArgs>
-void binary_pointwise_inplace_batch_rule(
+static void binary_pointwise_inplace_batch_rule(
     Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
     const Tensor& other, std::optional<int64_t> other_batch_dim,
     ExtraArgs... extra_args) {
@@ -120,7 +117,7 @@ void binary_pointwise_inplace_batch_rule(
 }
 
 template <typename F, F Func>
-std::tuple<Tensor, std::optional<int64_t>> comparison_pointwise_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> comparison_pointwise_batch_rule(
     const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
     const Tensor& other, std::optional<int64_t> other_batch_dim) {
   // compute max logical rank
@@ -165,9 +162,7 @@ static std::tuple<Tensor, std::optional<int64_t>> gelu_backward_batch_rule(
     c10::string_view approximate) {
 
   // repeat the preprocessing from _binary_pointwise_batch_rule
-  const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
-  auto grad_out_ = std::get<0>(tensor_other);
-  auto input_ = std::get<1>(tensor_other);
+  auto [grad_out_, input_] = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
 
   // gelu_backward doesn't broadcast well so we need to insist all inputs have a bdim
   const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
@@ -243,8 +238,8 @@ static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
   // We need to apply the same preprocessing on x1 and x2 as in the forward pass
   // _binary_pointwise_batch_rule
   auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
-  x1_ = std::get<0>(x12);
-  auto x2_ = std::get<1>(x12);
+  x1_ = std::move(std::get<0>(x12));
+  auto& x2_ = std::get<1>(x12);
 
   auto grad_ = moveBatchDimToFront(grad, grad_bdim);
   if ((x1_bdim || x2_bdim) && !grad_bdim) {
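
Note how the first two hunks split responsibilities: `_binary_pointwise_batch_rule` now returns the `Tensor` alone, and `BinaryPointwiseBatchRuleHelper` re-attaches the batch dimension with `std::tuple(result, 0)`. That expression relies on class template argument deduction plus `std::tuple`'s converting constructor; a hedged sketch with toy types (not PR code):

#include <cstdint>
#include <optional>
#include <tuple>

// Stand-in for the batch rule: computes the result without wrapping it.
static double rule(double x) { return x * 2.0; }

// Stand-in for the wrapper: CTAD deduces std::tuple<double, int> from
// std::tuple(rule(x), 0), which then converts to the declared return type
// std::tuple<double, std::optional<int64_t>>.
static std::tuple<double, std::optional<int64_t>> apply(double x) {
  return std::tuple(rule(x), 0);
}

int main() {
  return std::get<1>(apply(3.0)).value_or(-1) == 0 ? 0 : 1;
}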

aten/src/ATen/functorch/BatchRulesConvolution.cpp

Lines changed: 19 additions & 20 deletions
@@ -106,8 +106,7 @@ convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const
     result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
   }
   if (separate_bias) {
-    auto A = std::get<0>(result);
-    auto A_batch_dim = std::get<1>(result);
+    auto& [A, A_batch_dim] = result;
     auto B = *bias;
     auto B_batch_dim = bias_bdim;
     A = moveBatchDimToFront(A, A_batch_dim);
@@ -273,12 +272,12 @@
       const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
       const auto out_ch_dim = transposed ? 1 : 0;
       const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
-      const auto result = at::convolution_backward_symint(
+      auto result = at::convolution_backward_symint(
           grad_output_, input, dummy_weight, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
-      auto grad_weight = std::get<1>(result);
+      auto& grad_weight = std::get<1>(result);
       grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
-      return std::make_tuple(grad_weight, out_ch_dim);
+      return std::make_tuple(std::move(grad_weight), out_ch_dim);
     } else {
       auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
       grad_output_ = reshape_dim_outof_symint(2, groups, grad_output_); // BNGO
@@ -287,23 +286,23 @@
       if (!transposed) {
         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
             grad_output_, input, dummy_weight, std::nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
         grad_weight = grad_weight.transpose(0, 1); // BGOI
         grad_weight = grad_weight.flatten(1, 2); // B(GO)I
-        return std::make_tuple(grad_weight, 0);
+        return std::make_tuple(std::move(grad_weight), 0);
       } else {
         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
             grad_output_, input, dummy_weight, std::nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
-        return std::make_tuple(grad_weight, 1);
+        return std::make_tuple(std::move(grad_weight), 1);
       }
     }
   } else if (!grad_output_bdim && input_bdim) {
@@ -314,12 +313,12 @@
       const auto input_ = reshape_dim_into(*input_bdim, 1, input);
       const auto in_ch_dim = transposed ? 0 : 1;
       const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
-      const auto result = at::convolution_backward_symint(
+      auto result = at::convolution_backward_symint(
          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
-      auto grad_weight = std::get<1>(result);
+      auto& grad_weight = std::get<1>(result);
       grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
-      return std::make_tuple(grad_weight, in_ch_dim);
+      return std::make_tuple(std::move(grad_weight), in_ch_dim);
     } else {
       auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
       input_ = reshape_dim_outof_symint(2, groups, input_); // BNGI
@@ -337,23 +336,23 @@
       } else {
         // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
             grad_output, input_, dummy_weight, std::nullopt, stride, padding,
             dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
         grad_weight = grad_weight.transpose(0, 1); // BGIO
         grad_weight = grad_weight.flatten(1, 2); // B(GI)O
-        return std::make_tuple(grad_weight, 0);
+        return std::make_tuple(std::move(grad_weight), 0);
       }
     }
   } else {
     TORCH_INTERNAL_ASSERT(weight_bdim);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
-    const auto result = at::convolution_backward_symint(
+    auto result = at::convolution_backward_symint(
         grad_output, input, dummy_weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
-    return std::make_tuple(std::get<1>(result), std::nullopt);
+    return std::make_tuple(std::move(std::get<1>(result)), std::nullopt);
 
   }
 }
@@ -424,7 +423,7 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
   Tensor grad_input;
   if (output_mask[0]) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = convolution_backward_input_batch_rule(
+    auto result = convolution_backward_input_batch_rule(
         grad_output, grad_output_bdim,
         input, input_bdim,
         weight, weight_bdim,
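
A recurring detail in the hunks above: `result` loses its `const`. Both follow-up operations need a mutable tuple, since an `auto&` binding to an element of a `const` tuple would itself be `const` and could be neither reassigned nor moved from. A hedged sketch with standard-library types (not PR code):

#include <string>
#include <tuple>
#include <utility>

// Produces a tuple whose first element is expensive to copy.
std::tuple<std::string, int> produce() {
  return {std::string(64, 'x'), 1};
}

std::tuple<std::string, int> transform() {
  auto result = produce();                  // must not be const
  auto& s = std::get<0>(result);            // mutable alias, no copy
  s += "-suffix";                           // modify the element in place
  return std::make_tuple(std::move(s), 2);  // move out, no copy
}

int main() {
  return std::get<0>(transform()).size() > 64 ? 0 : 1;
}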

aten/src/ATen/functorch/BatchRulesHelper.h

Lines changed: 4 additions & 9 deletions
@@ -145,7 +145,7 @@ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::S
     const auto& ivalue = arguments[idx];
     if (ivalue.isTensor()) {
       auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-      tensor_inputs.emplace_back(tensor_value, tensor_bdim);
+      tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim);
       tensor_pos.push_back(static_cast<int64_t>(idx));
     }
   }
@@ -220,8 +220,7 @@ inline void find_and_unpack_tensors(
       continue;
     }
     auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-    const auto& tensor_value = std::get<0>(unpacked);
-    const auto tensor_bdim = std::get<1>(unpacked);
+    const auto& [tensor_value, tensor_bdim] = unpacked;
     if (tensor_bdim.has_value()) {
       auto candidate_batch_size = tensor_value.size(*tensor_bdim);
       if (computed_batch_size == -1) {
@@ -265,13 +264,9 @@ inline void boxed_existing_bdim_all_batch_rule(
 
   // for each tensor, ensure it has a bdim and reshape it.
   for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
-    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
-    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    const auto& [value, bdim] = tensor_inputs[tensor_idx];
     auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
-    if (!bdim.has_value()) {
-      bdim = 0;
-    }
-    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(bdim.value_or(0), 0, value_);
   }
 
   op.callBoxed(stack);
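
The last hunk also folds a four-line conditional into `std::optional::value_or`, which returns the contained value or the fallback without mutating the optional. A minimal sketch of that simplification (the function name is illustrative, not from the PR):

#include <cstdint>
#include <optional>

// Mirrors the change above: instead of `if (!bdim.has_value()) bdim = 0;`
// followed by `*bdim`, value_or(0) yields the value or the fallback in one
// expression and leaves the optional untouched.
int64_t dim_or_zero(std::optional<int64_t> bdim) {
  return bdim.value_or(0);
}

int main() {
  return dim_or_zero(std::nullopt) == 0 && dim_or_zero(3) == 3 ? 0 : 1;
}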
