Skip to content

Commit b181b58

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Fix Storage.filename to not track the filename when storage was mmap-ed with MAP_PRIVATE (pytorch#128725)
Pull Request resolved: pytorch#128725 Approved by: https://github.com/albanD
1 parent 213eba7 commit b181b58

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

aten/src/ATen/MapAllocator.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class TORCH_API MapAllocator {
5555
return base_ptr_;
5656
}
5757

58+
int flags() const {
59+
return flags_;
60+
}
61+
5862
static MapAllocator* fromDataPtr(const at::DataPtr&);
5963
static at::DataPtr makeDataPtr(
6064
c10::string_view filename,

test/test_tensor_creation_ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3177,18 +3177,19 @@ def test_from_file(self, device, shared):
31773177
dtype = torch.float64
31783178
t = torch.randn(2, 5, dtype=dtype, device=device)
31793179
with tempfile.NamedTemporaryFile() as f:
3180+
expected_filename = f.name if shared else None
31803181
t.numpy().tofile(f)
31813182
t_mapped = torch.from_file(f.name, shared=shared, size=t.numel(), dtype=dtype)
3182-
self.assertTrue(t_mapped.storage().filename == f.name)
3183+
self.assertTrue(t_mapped.untyped_storage().filename == expected_filename)
31833184
self.assertEqual(torch.flatten(t), t_mapped)
31843185

31853186
s = torch.UntypedStorage.from_file(f.name, shared, t.numel() * dtype.itemsize)
3186-
self.assertTrue(s.filename == f.name)
3187+
self.assertTrue(s.filename == expected_filename)
31873188

31883189
@onlyCPU
31893190
def test_storage_filename(self, device):
31903191
t = torch.randn(2, 5, device=device)
3191-
self.assertIsNone(t.storage().filename)
3192+
self.assertIsNone(t.untyped_storage().filename)
31923193

31933194

31943195
# Class for testing random tensor creation ops, like torch.randint

torch/csrc/StorageMethods.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,8 @@ static PyObject* THPStorage__get_filename(PyObject* self, PyObject* noargs) {
624624
const c10::DataPtr& data_ptr = self_.data_ptr();
625625
at::MapAllocator* map_allocator = at::MapAllocator::fromDataPtr(data_ptr);
626626

627-
if (map_allocator == nullptr) {
627+
if (map_allocator == nullptr ||
628+
!(map_allocator->flags() & at::ALLOCATOR_MAPPED_SHARED)) {
628629
Py_RETURN_NONE;
629630
}
630631
std::string filename = map_allocator->filename();

torch/storage.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,10 @@ def is_hpu(self):
363363

364364
@property
365365
def filename(self) -> _Optional[str]:
366-
"""Returns the file name associated with this storage if the storage was memory mapped from a file.
367-
or ``None`` if the storage was not created by memory mapping a file."""
366+
"""Returns the file name associated with this storage.
367+
368+
The file name will be a string if the storage is on CPU and was created via
369+
:meth:`~torch.from_file()` with ``shared`` as ``True``. This attribute is ``None`` otherwise."""
368370
return self._get_filename()
369371

370372
@_share_memory_lock_protected

0 commit comments

Comments
 (0)