Skip to content

Commit b02aa5e

Browse files
wbigat authored and pytorchmergebot committed
[Feature] storage resize_ support custom device. (pytorch#99882)
Fixes pytorch#99326. Supports storage resize_ for a custom device by calling dispatched tensor operations. @ezyang this PR is another case that was brought up in issue pytorch#99326; please take a moment to review this change. Pull Request resolved: pytorch#99882. Approved by: https://github.com/ezyang
1 parent 9834358 commit b02aa5e

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

test/cpp_extensions/open_registration_extension.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ bool custom_is_pinned(const at::Tensor& self, c10::optional<at::Device> device)
128128
return false;
129129
}
130130

131+
// In-place resize for tensors on the PrivateUse1 (custom) backend.
// Marks the tensor contiguous with the new sizes and grows the backing
// storage's byte count when the new shape needs more room; shrinking a
// tensor never releases or shrinks the existing allocation. The
// memory-format argument is accepted for schema compatibility but unused.
const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
    c10::optional<at::MemoryFormat> optional_memory_format) {
  auto* impl = self.unsafeGetTensorImpl();
  impl->set_sizes_contiguous(size);

  // Bytes required for a contiguous layout of `size`, accounting for the
  // tensor's storage offset.
  const auto elem_bytes = impl->dtype().itemsize();
  const auto required_bytes = at::detail::computeStorageNbytesContiguous(
      size, elem_bytes, impl->storage_offset());

  const auto& storage = impl->unsafe_storage();
  if (required_bytes > storage.nbytes()) {
    // Grow-only: record the larger byte count on the storage impl.
    storage.unsafeGetStorageImpl()->set_nbytes(required_bytes);
  }
  return self;
}
144+
131145
// This macro does the heavy lifting.
132146
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
133147
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
@@ -146,6 +160,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
146160
m.impl("set_.source_Storage", &custom_set_source_Storage);
147161
m.impl("_pin_memory", &custom__pin_memory);
148162
m.impl("is_pinned", &custom_is_pinned);
163+
m.impl("resize_", &custom_resize_);
149164
}
150165

151166
// This basic implementation doesn't bother dealing with different device indices

test/test_cpp_extensions_open_device_registration.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,17 @@ def test_open_device_serialization():
264264
foo_storage = torch.serialization.default_restore_location(cpu_storage, 'foo:0')
265265
self.assertTrue(foo_storage.is_foo)
266266

267+
def test_open_device_storage_resize():
    # Nested helper like its siblings (e.g. test_open_device_serialization):
    # takes no parameters and captures the enclosing test's `self` via
    # closure — the bare call `test_open_device_storage_resize()` below
    # would raise TypeError if `self` were a positional parameter.
    torch.utils.rename_privateuse1_backend('foo')
    cpu_tensor = torch.randn([8])
    foo_tensor = cpu_tensor.foo()
    foo_storage = foo_tensor.storage()
    self.assertTrue(foo_storage.size() == 8)
    # Resizing to the current size must be a no-op on the reported size.
    foo_storage.resize_(8)
    self.assertTrue(foo_storage.size() == 8)
    # 8**29 bytes (2**87) cannot be represented as int64, so the C++
    # resize path is expected to raise an overflow RuntimeError.
    with self.assertRaisesRegex(RuntimeError, 'overflow'):
        foo_storage.resize_(8**29)
277+
267278
test_base_device_registration()
268279
test_before_common_registration()
269280
test_common_registration()
@@ -274,6 +285,8 @@ def test_open_device_serialization():
274285
test_open_device_storage()
275286
test_open_device_storage_pin_memory()
276287
test_open_device_serialization()
288+
test_open_device_storage_resize()
289+
277290

278291
if __name__ == "__main__":
279292
common.run_tests()

torch/csrc/StorageMethods.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,34 @@ static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
134134
const auto size_bytes = static_cast<size_t>(size_bytes_i);
135135
at::native::resize_bytes_cuda(storage.unsafeGetStorageImpl(), size_bytes);
136136
#endif
137+
} else if (device_type == at::kPrivateUse1) {
138+
ptrdiff_t size_bytes_i = newsize;
139+
TORCH_CHECK(
140+
!c10::overflows<int64_t>(size_bytes_i),
141+
"Requested storage size (",
142+
size_bytes_i,
143+
") cannot be represented as a int64_t");
144+
const auto size_bytes = static_cast<int64_t>(size_bytes_i);
145+
void* original_data_ptr = storage.data_ptr().get();
146+
147+
auto src_option =
148+
c10::TensorOptions().device(storage.device()).dtype(at::kByte);
149+
auto src_tensor = at::empty({0}, {}, src_option).set_(storage);
150+
src_tensor.resize_({size_bytes});
151+
152+
// When using resize_ to replace resize_bytes_xxx, in some cases
153+
// the original data_ptr is still returned, which is an inconsistent
154+
// behavior when compared to resize_bytes_xxx. For these cases,
155+
// an additional memory copy and update for storage are required.
156+
if (original_data_ptr == src_tensor.storage().data_ptr().get()) {
157+
auto new_tensor = at::empty(src_tensor.sizes(), src_tensor.options());
158+
new_tensor.copy_(src_tensor);
159+
storage.set_data_ptr_noswap(
160+
std::move(const_cast<at::DataPtr&>(new_tensor.storage().data_ptr())));
161+
storage.unsafeGetStorageImpl()->set_allocator(
162+
new_tensor.storage().unsafeGetStorageImpl()->allocator());
163+
storage.set_nbytes(new_tensor.storage().nbytes());
164+
}
137165
} else {
138166
TORCH_CHECK(
139167
false,

0 commit comments

Comments
 (0)