Commit 41210ea

atalman and malfet authored
[MPS] Fix out-of-bounds fill to sliced tensor (pytorch#114958)
This fixes a regression introduced by pytorch#81951 that caused out-of-bounds access when a sliced tensor is filled with zeros.

Remove the bogus `TORCH_INTERNAL_ASSERT(length >= offset)`, as the [NSMakeRange](https://developer.apple.com/documentation/foundation/1417188-nsmakerange?language=objc) arguments are a location and a length rather than start and end offsets.

In `fill_mps_tensor_`:
- Pass the `value` argument to `MPSStream::fill`
- Pass `self.nbytes()` rather than `self.storage().nbytes()` as the length of the buffer to fill, as the latter always results in an out-of-bounds write when the offset within the storage is non-zero

Add a regression test.

Fixes pytorch#114692

Cherry-pick of pytorch#114838 into the release/2.1 branch

Co-authored-by: Nikita Shulga <[email protected]>
1 parent 3183bcd commit 41210ea
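For context, a minimal sketch of the failure mode described above (shapes follow the regression test rather than the original issue verbatim; an MPS-capable machine is assumed):

```python
# Sketch of the reported out-of-bounds zero-fill on a sliced MPS tensor.
# Assumes a torch build with MPS support.
import torch

t = torch.ones(1, 10, device="mps")  # contiguous storage of 10 float32 values (40 bytes)
t[:, 5].fill_(0.0)                   # slice with a non-zero storage offset hits the fast fill path
print(t.cpu())                       # expected: only column 5 zeroed; before the fix the fill
                                     # could write past the end of the underlying Metal buffer
```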

File tree

3 files changed, +9 -4 lines changed


aten/src/ATen/mps/MPSStream.mm

Lines changed: 2 additions & 2 deletions
@@ -147,9 +147,9 @@ @interface MPSGraphExecutionDescriptor ()
 }
 
 void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) {
-  TORCH_INTERNAL_ASSERT(length >= offset);
-  if (length == 0)
+  if (length == 0) {
     return;
+  }
   dispatch_sync(_serialQueue, ^() {
     @autoreleasepool {
       endKernelCoalescing();
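The removed assert treated `(offset, length)` as start and end offsets, but as the commit message notes they form a location/length pair in the NSMakeRange sense. A small numeric sketch (with a hypothetical buffer size) of why the old check rejected valid fills:

```python
# Why `length >= offset` was the wrong check: offset is a location, length is a size.
buffer_size = 40          # hypothetical MTLBuffer size in bytes
offset, length = 20, 4    # fill 4 bytes starting at byte 20: entirely in bounds
assert offset + length <= buffer_size   # the bound that actually matters
print(length >= offset)   # False: the removed assert would have rejected this valid fill
```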

aten/src/ATen/native/mps/operations/ConstantOps.mm

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) {
   if (self.is_contiguous()) {
     MPSStream* stream = getCurrentMPSStream();
     auto storage_byte_offset = self.storage_offset() * self.itemsize();
-    stream->fill(mps::getMTLBufferStorage(self), 0, self.storage().nbytes(), storage_byte_offset);
+    stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset);
     return true;
   }
   return false;
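A short sketch of the byte arithmetic behind this one-line change (runs on CPU, no MPS required; the tensor shape is just an example):

```python
# Why self.storage().nbytes() overflows when the storage offset is non-zero.
import torch

t = torch.ones(1, 10)                                    # 10 float32 elements -> 40-byte storage
s = t[:, 5]                                              # contiguous view with a non-zero storage offset
offset_bytes = s.storage_offset() * s.element_size()     # 5 * 4 = 20
print(offset_bytes, s.nbytes, t.untyped_storage().nbytes())  # 20 4 40
# Filling storage().nbytes() (40) bytes starting at byte 20 runs 20 bytes past the buffer;
# filling self.nbytes() (4) bytes from the same offset stays inside it.
```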

test/test_mps.py

Lines changed: 6 additions & 1 deletion
@@ -1197,17 +1197,22 @@ def test_fill_storage_offset(self):
         tensor_cpu = tensor_0[:][1].fill_(val)
 
         self.assertEqual(tensor_mps, tensor_cpu)
+        self.assertEqual(tensor, tensor_0)
 
         shape = [1, 10]
         val = 0.0
         tensor = torch.ones(shape, device="mps")
         val_tensor_mps = torch.tensor(val, device="mps")
         tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
+        # Regression test for https://github.com/pytorch/pytorch/issues/114692
+        tensor[:, 5].fill_(val_tensor_mps)
         tensor_0 = torch.ones(shape, device="cpu")
         val_tensor_cpu = torch.tensor(val, device="cpu")
         tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
+        tensor_0[:, 5].fill_(val_tensor_cpu)
 
-        self.assertEqual(tensor_mps, tensor_cpu)
+        self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu)
+        self.assertEqual(tensor.to(device="cpu"), tensor_0)
 
     def test_cdist_large(self, device="mps"):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
