@@ -171,23 +171,6 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
         .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 };
 
-static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
-{
-        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-        down_write(&per_mm->umem_rwsem);
-        /*
-         * Note that the representation of the intervals in the interval tree
-         * considers the ending point as contained in the interval, while the
-         * function ib_umem_end returns the first address which is not
-         * contained in the umem.
-         */
-        umem_odp->interval_tree.start = ib_umem_start(umem_odp);
-        umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
-        interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree);
-        up_write(&per_mm->umem_rwsem);
-}
-
 static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 {
         struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
@@ -237,33 +220,23 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
         return ERR_PTR(ret);
 }
 
-static int get_per_mm(struct ib_umem_odp *umem_odp)
+static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
 {
         struct ib_ucontext *ctx = umem_odp->umem.context;
         struct ib_ucontext_per_mm *per_mm;
 
+        lockdep_assert_held(&ctx->per_mm_list_lock);
+
         /*
          * Generally speaking we expect only one or two per_mm in this list,
          * so no reason to optimize this search today.
          */
-        mutex_lock(&ctx->per_mm_list_lock);
         list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
                 if (per_mm->mm == umem_odp->umem.owning_mm)
-                        goto found;
-        }
-
-        per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-        if (IS_ERR(per_mm)) {
-                mutex_unlock(&ctx->per_mm_list_lock);
-                return PTR_ERR(per_mm);
+                        return per_mm;
         }
 
-found:
-        umem_odp->per_mm = per_mm;
-        per_mm->odp_mrs_count++;
-        mutex_unlock(&ctx->per_mm_list_lock);
-
-        return 0;
+        return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
 }
 
 static void free_per_mm(struct rcu_head *rcu)
@@ -304,79 +277,114 @@ static void put_per_mm(struct ib_umem_odp *umem_odp)
         mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
 }
 
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
+                                   struct ib_ucontext_per_mm *per_mm)
+{
+        struct ib_ucontext *ctx = umem_odp->umem.context;
+        int ret;
+
+        umem_odp->umem.is_odp = 1;
+        if (!umem_odp->is_implicit_odp) {
+                size_t pages = ib_umem_odp_num_pages(umem_odp);
+
+                if (!pages)
+                        return -EINVAL;
+
+                /*
+                 * Note that the representation of the intervals in the
+                 * interval tree considers the ending point as contained in
+                 * the interval, while the function ib_umem_end returns the
+                 * first address which is not contained in the umem.
+                 */
+                umem_odp->interval_tree.start = ib_umem_start(umem_odp);
+                umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
+
+                umem_odp->page_list = vzalloc(
+                        array_size(sizeof(*umem_odp->page_list), pages));
+                if (!umem_odp->page_list)
+                        return -ENOMEM;
+
+                umem_odp->dma_list =
+                        vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
+                if (!umem_odp->dma_list) {
+                        ret = -ENOMEM;
+                        goto out_page_list;
+                }
+        }
+
+        mutex_lock(&ctx->per_mm_list_lock);
+        if (!per_mm) {
+                per_mm = get_per_mm(umem_odp);
+                if (IS_ERR(per_mm)) {
+                        ret = PTR_ERR(per_mm);
+                        goto out_unlock;
+                }
+        }
+        umem_odp->per_mm = per_mm;
+        per_mm->odp_mrs_count++;
+        mutex_unlock(&ctx->per_mm_list_lock);
+
+        mutex_init(&umem_odp->umem_mutex);
+        init_completion(&umem_odp->notifier_completion);
+
+        if (!umem_odp->is_implicit_odp) {
+                down_write(&per_mm->umem_rwsem);
+                interval_tree_insert(&umem_odp->interval_tree,
+                                     &per_mm->umem_tree);
+                up_write(&per_mm->umem_rwsem);
+        }
+
+        return 0;
+
+out_unlock:
+        mutex_unlock(&ctx->per_mm_list_lock);
+        vfree(umem_odp->dma_list);
+out_page_list:
+        vfree(umem_odp->page_list);
+        return ret;
+}
+
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
                                       unsigned long addr, size_t size)
 {
-        struct ib_ucontext_per_mm *per_mm = root->per_mm;
-        struct ib_ucontext *ctx = per_mm->context;
+        /*
+         * Caller must ensure that root cannot be freed during the call to
+         * ib_alloc_odp_umem.
+         */
         struct ib_umem_odp *odp_data;
         struct ib_umem *umem;
-        int pages = size >> PAGE_SHIFT;
         int ret;
 
-        if (!size)
-                return ERR_PTR(-EINVAL);
-
         odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
         if (!odp_data)
                 return ERR_PTR(-ENOMEM);
         umem = &odp_data->umem;
-        umem->context = ctx;
+        umem->context = root->umem.context;
         umem->length = size;
         umem->address = addr;
-        odp_data->page_shift = PAGE_SHIFT;
         umem->writable = root->umem.writable;
-        umem->is_odp = 1;
-        odp_data->per_mm = per_mm;
-        umem->owning_mm = per_mm->mm;
-        mmgrab(umem->owning_mm);
-
-        mutex_init(&odp_data->umem_mutex);
-        init_completion(&odp_data->notifier_completion);
-
-        odp_data->page_list =
-                vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
-        if (!odp_data->page_list) {
-                ret = -ENOMEM;
-                goto out_odp_data;
-        }
+        umem->owning_mm = root->umem.owning_mm;
+        odp_data->page_shift = PAGE_SHIFT;
 
-        odp_data->dma_list =
-                vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
-        if (!odp_data->dma_list) {
-                ret = -ENOMEM;
-                goto out_page_list;
+        ret = ib_init_umem_odp(odp_data, root->per_mm);
+        if (ret) {
+                kfree(odp_data);
+                return ERR_PTR(ret);
         }
 
-        /*
-         * Caller must ensure that the umem_odp that the per_mm came from
-         * cannot be freed during the call to ib_alloc_odp_umem.
-         */
-        mutex_lock(&ctx->per_mm_list_lock);
-        per_mm->odp_mrs_count++;
-        mutex_unlock(&ctx->per_mm_list_lock);
-        add_umem_to_per_mm(odp_data);
+        mmgrab(umem->owning_mm);
 
         return odp_data;
-
-out_page_list:
-        vfree(odp_data->page_list);
-out_odp_data:
-        mmdrop(umem->owning_mm);
-        kfree(odp_data);
-        return ERR_PTR(ret);
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
-        struct ib_umem *umem = &umem_odp->umem;
         /*
          * NOTE: This must called in a process context where umem->owning_mm
          * == current->mm
          */
-        struct mm_struct *mm = umem->owning_mm;
-        int ret_val;
+        struct mm_struct *mm = umem_odp->umem.owning_mm;
 
         if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
                 umem_odp->is_implicit_odp = 1;
@@ -397,43 +405,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                 up_read(&mm->mmap_sem);
         }
 
-        mutex_init(&umem_odp->umem_mutex);
-
-        init_completion(&umem_odp->notifier_completion);
-
-        if (!umem_odp->is_implicit_odp) {
-                if (!ib_umem_odp_num_pages(umem_odp))
-                        return -EINVAL;
-
-                umem_odp->page_list =
-                        vzalloc(array_size(sizeof(*umem_odp->page_list),
-                                           ib_umem_odp_num_pages(umem_odp)));
-                if (!umem_odp->page_list)
-                        return -ENOMEM;
-
-                umem_odp->dma_list =
-                        vzalloc(array_size(sizeof(*umem_odp->dma_list),
-                                           ib_umem_odp_num_pages(umem_odp)));
-                if (!umem_odp->dma_list) {
-                        ret_val = -ENOMEM;
-                        goto out_page_list;
-                }
-        }
-
-        ret_val = get_per_mm(umem_odp);
-        if (ret_val)
-                goto out_dma_list;
-
-        if (!umem_odp->is_implicit_odp)
-                add_umem_to_per_mm(umem_odp);
-
-        return 0;
-
-out_dma_list:
-        vfree(umem_odp->dma_list);
-out_page_list:
-        vfree(umem_odp->page_list);
-        return ret_val;
+        return ib_init_umem_odp(umem_odp, NULL);
 }
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
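For reference, a minimal standalone sketch of the locking pattern this refactor moves to: get_per_mm() now only looks up or allocates an entry and asserts the lock, while its caller (ib_init_umem_odp() in the diff) holds the list lock across both the lookup and the odp_mrs_count increment, so they form one atomic step. This is userspace C with pthreads and simplified stand-in types, not the kernel structures; names here are illustrative only.

/* Model of "lookup-or-create under a caller-held lock, then bump refcount". */
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>

struct per_mm {
        int mm_id;              /* stand-in for struct mm_struct * */
        int odp_mrs_count;      /* how many umems share this per_mm */
        struct per_mm *next;
};

static pthread_mutex_t per_mm_list_lock = PTHREAD_MUTEX_INITIALIZER;
static struct per_mm *per_mm_list;

/* Caller must hold per_mm_list_lock (mirrors the lockdep_assert_held()). */
static struct per_mm *get_per_mm(int mm_id)
{
        struct per_mm *p;

        for (p = per_mm_list; p; p = p->next)
                if (p->mm_id == mm_id)
                        return p;

        p = calloc(1, sizeof(*p));
        if (!p)
                return NULL;
        p->mm_id = mm_id;
        p->next = per_mm_list;
        per_mm_list = p;
        return p;
}

/* Mirrors the per_mm half of ib_init_umem_odp(): one locked section covers
 * the lookup-or-create and the reference count bump. */
static struct per_mm *attach_umem(int mm_id)
{
        struct per_mm *p;

        pthread_mutex_lock(&per_mm_list_lock);
        p = get_per_mm(mm_id);
        if (p)
                p->odp_mrs_count++;
        pthread_mutex_unlock(&per_mm_list_lock);
        return p;
}

int main(void)
{
        struct per_mm *a = attach_umem(1);
        struct per_mm *b = attach_umem(1);   /* same mm: reuses a */
        struct per_mm *c = attach_umem(2);   /* different mm: new entry */

        printf("mm 1 shared: %s, refs=%d\n", a == b ? "yes" : "no", a->odp_mrs_count);
        printf("mm 2 refs=%d\n", c->odp_mrs_count);
        return 0;
}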