@@ -498,11 +498,18 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
498
498
auto maybe_layer = maybeCurrentDynamicLayer ();
499
499
vmap_check_escaped (maybe_layer, " index_put__plumbing" );
500
500
int64_t cur_level = maybe_layer->layerId ();
501
- if (!isBatchedAtLevel (self, cur_level) && !isBatchedAtLevel (indices, cur_level) && !isBatchedAtLevel (values, cur_level)) {
502
- return self.index_put_ (indices, values, accumulate);
501
+
502
+ // If values is a 0-dim (scalar) tensor on a different device than self, move it to self's device.
503
+ auto values_ = values;
504
+ if (values.device () != self.device () && values.numel () == 1 && values.dim () == 0 ) {
505
+ values_ = values.to (self.device ());
506
+ }
507
+
508
+ if (!isBatchedAtLevel (self, cur_level) && !isBatchedAtLevel (indices, cur_level) && !isBatchedAtLevel (values_, cur_level)) {
509
+ return self.index_put_ (indices, values_, accumulate);
503
510
}
504
511
auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
505
- unpackSelfAndIndicesAndValuesAtCurrentLevel (self, indices, values , cur_level);
512
+ unpackSelfAndIndicesAndValuesAtCurrentLevel (self, indices, values_ , cur_level);
506
513
index_put__batch_rule (self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
507
514
return self;
508
515
}
@@ -645,11 +652,18 @@ Tensor index_put_plumbing(const Tensor & self, const List<optional<Tensor>> & in
645
652
auto maybe_layer = maybeCurrentDynamicLayer ();
646
653
vmap_check_escaped (maybe_layer, " index_put_plumbing" );
647
654
int64_t cur_level = maybe_layer->layerId ();
648
- if (!isBatchedAtLevel (self, cur_level) && !isBatchedAtLevel (indices, cur_level) && !isBatchedAtLevel (values, cur_level)) {
649
- return self.index_put (indices, values, accumulate);
655
+
656
+ // If values is a 0-dim (scalar) tensor on a different device than self, move it to self's device.
657
+ auto values_ = values;
658
+ if (values.device () != self.device () && values.numel () == 1 && values.dim () == 0 ) {
659
+ values_ = values.to (self.device ());
660
+ }
661
+
662
+ if (!isBatchedAtLevel (self, cur_level) && !isBatchedAtLevel (indices, cur_level) && !isBatchedAtLevel (values_, cur_level)) {
663
+ return self.index_put (indices, values_, accumulate);
650
664
}
651
665
auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
652
- unpackSelfAndIndicesAndValuesAtCurrentLevel (self, indices, values , cur_level);
666
+ unpackSelfAndIndicesAndValuesAtCurrentLevel (self, indices, values_ , cur_level);
653
667
auto results = index_put_batch_rule (self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
654
668
return makeBatched (std::get<0 >(results), std::get<1 >(results), cur_level);
655
669
}
0 commit comments