@@ -26,6 +26,34 @@ namespace torch {
26
26
namespace executor {
27
27
namespace util {
28
28
29
+ namespace {
30
+
31
+ struct Range {
32
+ // Address or offset.
33
+ uintptr_t start;
34
+ // Size in bytes.
35
+ size_t size;
36
+ };
37
+
38
+ /* *
39
+ * Given an address region, returns the start offset and byte size of the set of
40
+ * pages that completely covers the region.
41
+ */
42
+ Range get_overlapping_pages (uintptr_t offset, size_t size, size_t page_size) {
43
+ size_t page_mask = ~(page_size - 1 );
44
+ // The address of the page that starts at or before the beginning of the
45
+ // region.
46
+ uintptr_t start = offset & page_mask;
47
+ // The address of the page that starts after the end of the region.
48
+ uintptr_t end = (offset + size + ~page_mask) & page_mask;
49
+ return {
50
+ /* start=*/ start,
51
+ /* size=*/ static_cast <size_t >(end - start),
52
+ };
53
+ }
54
+
55
+ } // namespace
56
+
29
57
MmapDataLoader::~MmapDataLoader () {
30
58
// file_name_ can be nullptr if this instance was moved from, but freeing a
31
59
// null pointer is safe.
@@ -41,7 +69,7 @@ Result<MmapDataLoader> MmapDataLoader::from(
41
69
// Cache the page size.
42
70
long page_size = sysconf (_SC_PAGESIZE);
43
71
if (page_size < 0 ) {
44
- ET_LOG (Error, " Could not get page size: %s (%d)" , strerror (errno), errno);
72
+ ET_LOG (Error, " Could not get page size: %s (%d)" , :: strerror (errno), errno);
45
73
return Error::AccessFailed;
46
74
}
47
75
if ((page_size & ~(page_size - 1 )) != page_size) {
@@ -53,7 +81,11 @@ Result<MmapDataLoader> MmapDataLoader::from(
53
81
int fd = ::open (file_name, O_RDONLY);
54
82
if (fd < 0 ) {
55
83
ET_LOG (
56
- Error, " Failed to open %s: %s (%d)" , file_name, strerror (errno), errno);
84
+ Error,
85
+ " Failed to open %s: %s (%d)" ,
86
+ file_name,
87
+ ::strerror (errno),
88
+ errno);
57
89
return Error::AccessFailed;
58
90
}
59
91
@@ -89,9 +121,28 @@ Result<MmapDataLoader> MmapDataLoader::from(
89
121
}
90
122
91
123
namespace {
92
- // / FreeableBuffer::FreeFn-compatible callback.
93
- void MunmapSegment (__ET_UNUSED void * context, void * data, size_t size) {
94
- ::munmap (data, size);
124
+ /* *
125
+ * FreeableBuffer::FreeFn-compatible callback.
126
+ *
127
+ * `context` is actually the OS page size as a uintptr_t.
128
+ */
129
+ void MunmapSegment (void * context, void * data, size_t size) {
130
+ const uintptr_t page_size = reinterpret_cast <uintptr_t >(context);
131
+
132
+ Range range =
133
+ get_overlapping_pages (reinterpret_cast <uintptr_t >(data), size, page_size);
134
+ int ret = ::munmap (reinterpret_cast <void *>(range.start ), range.size );
135
+ if (ret < 0 ) {
136
+ // Let the user know that something went wrong, but there's nothing we can
137
+ // do about it.
138
+ ET_LOG (
139
+ Error,
140
+ " munmap(0x%zx, %zu) failed: %s (ignored)" ,
141
+ (size_t )range.start ,
142
+ range.size ,
143
+ ::strerror (errno),
144
+ errno);
145
+ }
95
146
}
96
147
} // namespace
97
148
@@ -109,13 +160,6 @@ Result<FreeableBuffer> MmapDataLoader::Load(size_t offset, size_t size) {
109
160
offset,
110
161
size,
111
162
file_size_);
112
- ET_CHECK_OR_RETURN_ERROR (
113
- (offset & ~(page_size_ - 1 )) == offset,
114
- InvalidArgument,
115
- " File %s: offset 0x%zx not aligned to 0x%zx" ,
116
- file_name_,
117
- offset,
118
- page_size_);
119
163
ET_CHECK_OR_RETURN_ERROR (
120
164
// Recommended by a lint warning.
121
165
offset <= std::numeric_limits<off_t >::max (),
@@ -128,43 +172,64 @@ Result<FreeableBuffer> MmapDataLoader::Load(size_t offset, size_t size) {
128
172
return FreeableBuffer (nullptr , 0 , /* free_fn=*/ nullptr );
129
173
}
130
174
175
+ // Find the range of pages that covers the requested region.
176
+ Range range =
177
+ get_overlapping_pages (static_cast <uintptr_t >(offset), size, page_size_);
178
+
131
179
// Map the pages read-only. MAP_PRIVATE vs. MAP_SHARED doesn't matter since
132
180
// the data is read-only, but use PRIVATE just to further avoid accidentally
133
181
// modifying the file.
134
- void * pages = mmap (
135
- nullptr , size, PROT_READ, MAP_PRIVATE, fd_, static_cast <off_t >(offset));
182
+ void * pages = ::mmap (
183
+ nullptr ,
184
+ range.size ,
185
+ PROT_READ,
186
+ MAP_PRIVATE,
187
+ fd_,
188
+ static_cast <off_t >(range.start ));
136
189
ET_CHECK_OR_RETURN_ERROR (
137
190
pages != MAP_FAILED,
138
191
AccessFailed,
139
192
" Failed to map %s: mmap(..., size=%zd, ..., fd=%d, offset=0x%zx)" ,
140
193
file_name_,
141
- size,
194
+ range. size ,
142
195
fd_,
143
- offset );
196
+ range. start );
144
197
145
198
if (mlock_config_ == MlockConfig::UseMlock ||
146
199
mlock_config_ == MlockConfig::UseMlockIgnoreErrors) {
147
- int err = mlock (pages, size);
200
+ int err = :: mlock (pages, size);
148
201
if (err < 0 ) {
149
202
ET_LOG (
150
203
Error,
151
204
" File %s: mlock(%p, %zu) failed: %s (%d)" ,
152
205
file_name_,
153
206
pages,
154
207
size,
155
- strerror (errno),
208
+ :: strerror (errno),
156
209
errno);
157
210
if (mlock_config_ == MlockConfig::UseMlockIgnoreErrors) {
158
211
ET_LOG (Info, " Ignoring mlock() error" );
159
212
} else {
160
- munmap (pages, size);
213
+ :: munmap (pages, size);
161
214
return Error::NotSupported;
162
215
}
163
216
}
164
217
// No need to keep track of this. munmap() will unlock as a side effect.
165
218
}
166
219
167
- return FreeableBuffer (pages, size, MunmapSegment);
220
+ // The requested data is at an offset into the mapped pages.
221
+ const void * data = static_cast <const uint8_t *>(pages) + offset - range.start ;
222
+
223
+ return FreeableBuffer (
224
+ // The callback knows to unmap the whole pages that encompass this region.
225
+ data,
226
+ size,
227
+ MunmapSegment,
228
+ /* free_fn_context=*/
229
+ reinterpret_cast <void *>(
230
+ // Pass the cached OS page size to the callback so it doesn't need to
231
+ // query it again.
232
+ static_cast <uintptr_t >(page_size_)));
168
233
}
169
234
170
235
Result<size_t > MmapDataLoader::size () const {
0 commit comments