Commit b4b64f7

guilhermeleobas authored and pytorchmergebot committed
Ensure tensor devices match in the torch.index_put batch rule impl (pytorch#130479)

Pull Request resolved: pytorch#130479
Approved by: https://github.com/zou3519
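In practice, this change makes vmap over torch.index_put accept a 0-d values tensor that lives on a different device than self, matching eager-mode behavior. A minimal sketch of the previously failing call, assuming a CUDA device is available (same shapes and values as the test added below):

    import torch
    from torch.func import vmap

    t = torch.zeros((2, 5), device="cuda")
    idx = torch.tensor([1, 3])                        # indices may live on CPU
    v = torch.tensor(1, dtype=t.dtype, device="cpu")  # 0-d value on a different device

    def f(row, idx, v):
        return torch.index_put(row, idx, v)

    # Before this change, the batch rule hit a device-mismatch error here;
    # now the 0-d value is moved to t's device and the call succeeds.
    out = vmap(f, in_dims=(0, None, None))(t, (idx,), v)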
1 parent 00d71b3 commit b4b64f7

File tree: 2 files changed, +36 −6 lines

aten/src/ATen/functorch/BatchRulesScatterOps.cpp

Lines changed: 20 additions & 6 deletions
@@ -498,11 +498,18 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indices
   auto maybe_layer = maybeCurrentDynamicLayer();
   vmap_check_escaped(maybe_layer, "index_put__plumbing");
   int64_t cur_level = maybe_layer->layerId();
-  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
-    return self.index_put_(indices, values, accumulate);
+
+  // on device mismatch, we can move 0d tensors to self device
+  auto values_ = values;
+  if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
+    values_ = values.to(self.device());
+  }
+
+  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
+    return self.index_put_(indices, values_, accumulate);
   }
   auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
-      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
+      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
   index_put__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
   return self;
 }
@@ -645,11 +652,18 @@ Tensor index_put_plumbing(const Tensor & self, const List<optional<Tensor>> & indices
   auto maybe_layer = maybeCurrentDynamicLayer();
   vmap_check_escaped(maybe_layer, "index_put_plumbing");
   int64_t cur_level = maybe_layer->layerId();
-  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
-    return self.index_put(indices, values, accumulate);
+
+  // on device mismatch, we can move 0d tensors to self device
+  auto values_ = values;
+  if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
+    values_ = values.to(self.device());
+  }
+
+  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
+    return self.index_put(indices, values_, accumulate);
   }
   auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
-      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
+      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
   auto results = index_put_batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
   return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
 }
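For readers who do not follow C++, here is a Python-level mirror of the new guard; the helper name is purely illustrative and not part of any PyTorch API. Note that a shape-(1,) tensor fails the dim() == 0 check, so only true scalars are moved and genuine device mismatches still surface as errors:

    import torch

    def _maybe_move_values(self_tensor, values):
        # Mirror of the batch rule's guard: copy `values` to self's device only
        # when it is a genuine 0-d, single-element tensor on a different device.
        if (
            values.device != self_tensor.device
            and values.numel() == 1
            and values.dim() == 0
        ):
            return values.to(self_tensor.device)
        return values

    if torch.cuda.is_available():
        t = torch.zeros(2, 5, device="cuda")
        assert _maybe_move_values(t, torch.tensor(1.0)).device == t.device      # 0-d: moved
        assert _maybe_move_values(t, torch.tensor([1.0])).device.type == "cpu"  # shape (1,): untouched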

test/functorch/test_vmap.py

Lines changed: 16 additions & 0 deletions
@@ -49,6 +49,7 @@
 from torch.testing._internal.common_cuda import with_tf32_off
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
+    onlyCUDA,
     OpDTypes,
     ops,
     tol,
@@ -4793,6 +4794,21 @@ def f(x, gy):

         self.vmap_outplace_test(f, (x, gy), {}, in_dims=(None, 0))

+    @onlyCUDA
+    @parametrize("inplace", [True, False])
+    def test_0d_tensor_index_put(self, device, inplace):
+        def f(t, idx, v):
+            fn = torch.index_put_ if inplace else torch.index_put
+            return fn(t, idx, v)
+
+        N = 2
+        t = torch.zeros((N, 5), device="cuda")
+        idx = torch.tensor([1, 3])
+        v = torch.tensor(1, dtype=t.dtype, device="cpu")
+
+        expected = torch.tensor([[0, 1, 0, 1, 0], [0, 1, 0, 1, 0]], dtype=t.dtype)
+        self.assertEqual(expected, vmap(f, in_dims=(0, None, None))(t, (idx,), v))
+
     @parametrize("training", [True, False])
     @parametrize("track_running_stats", [True, False])
     @parametrize("affine", [True, False])
