
Commit 22d79c9

RDMA/odp: Consolidate umem_odp initialization
This is done in two different places; consolidate all the post-allocation
initialization into a single function.

Link: https://lore.kernel.org/r/[email protected]
Signed-off-by: Leon Romanovsky <[email protected]>
Signed-off-by: Jason Gunthorpe <[email protected]>
1 parent fd7dbf0 commit 22d79c9
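
For orientation: after this patch both allocation paths funnel their post-allocation setup through the new ib_init_umem_odp() helper. The sketch below is a condensed reading of the diff that follows, not a verbatim copy of the file; function bodies are elided into comments and only the two call sites are shown.

/* ib_umem_odp_get(): regular and implicit ODP umems created by a driver. */
int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
        /* ... implicit-ODP detection and mmap_sem handling as before ... */
        /* No per_mm is known yet, so pass NULL and let the helper look one
         * up on ctx->per_mm_list (or allocate a new one). */
        return ib_init_umem_odp(umem_odp, NULL);
}

/* ib_alloc_odp_umem(): child umems hanging off an implicit ODP root. */
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
                                      unsigned long addr, size_t size)
{
        /* ... kzalloc() odp_data and copy context/length/address ... */
        /* The root already has a per_mm, so reuse it directly instead of
         * searching the ucontext list again. */
        ret = ib_init_umem_odp(odp_data, root->per_mm);
        /* ... kfree() on failure, mmgrab() only on success ... */
}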

1 file changed (+86, -114)

drivers/infiniband/core/umem_odp.c

Lines changed: 86 additions & 114 deletions
@@ -171,23 +171,6 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
         .invalidate_range_end = ib_umem_notifier_invalidate_range_end,
 };
 
-static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
-{
-        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-        down_write(&per_mm->umem_rwsem);
-        /*
-         * Note that the representation of the intervals in the interval tree
-         * considers the ending point as contained in the interval, while the
-         * function ib_umem_end returns the first address which is not
-         * contained in the umem.
-         */
-        umem_odp->interval_tree.start = ib_umem_start(umem_odp);
-        umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
-        interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree);
-        up_write(&per_mm->umem_rwsem);
-}
-
 static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 {
         struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
@@ -237,33 +220,23 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
         return ERR_PTR(ret);
 }
 
-static int get_per_mm(struct ib_umem_odp *umem_odp)
+static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
 {
         struct ib_ucontext *ctx = umem_odp->umem.context;
         struct ib_ucontext_per_mm *per_mm;
 
+        lockdep_assert_held(&ctx->per_mm_list_lock);
+
         /*
          * Generally speaking we expect only one or two per_mm in this list,
          * so no reason to optimize this search today.
          */
-        mutex_lock(&ctx->per_mm_list_lock);
         list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
                 if (per_mm->mm == umem_odp->umem.owning_mm)
-                        goto found;
-        }
-
-        per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-        if (IS_ERR(per_mm)) {
-                mutex_unlock(&ctx->per_mm_list_lock);
-                return PTR_ERR(per_mm);
+                        return per_mm;
         }
 
-found:
-        umem_odp->per_mm = per_mm;
-        per_mm->odp_mrs_count++;
-        mutex_unlock(&ctx->per_mm_list_lock);
-
-        return 0;
+        return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
 }
 
 static void free_per_mm(struct rcu_head *rcu)
@@ -304,79 +277,114 @@ static void put_per_mm(struct ib_umem_odp *umem_odp)
         mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
 }
 
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
+                                   struct ib_ucontext_per_mm *per_mm)
+{
+        struct ib_ucontext *ctx = umem_odp->umem.context;
+        int ret;
+
+        umem_odp->umem.is_odp = 1;
+        if (!umem_odp->is_implicit_odp) {
+                size_t pages = ib_umem_odp_num_pages(umem_odp);
+
+                if (!pages)
+                        return -EINVAL;
+
+                /*
+                 * Note that the representation of the intervals in the
+                 * interval tree considers the ending point as contained in
+                 * the interval, while the function ib_umem_end returns the
+                 * first address which is not contained in the umem.
+                 */
+                umem_odp->interval_tree.start = ib_umem_start(umem_odp);
+                umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
+
+                umem_odp->page_list = vzalloc(
+                        array_size(sizeof(*umem_odp->page_list), pages));
+                if (!umem_odp->page_list)
+                        return -ENOMEM;
+
+                umem_odp->dma_list =
+                        vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
+                if (!umem_odp->dma_list) {
+                        ret = -ENOMEM;
+                        goto out_page_list;
+                }
+        }
+
+        mutex_lock(&ctx->per_mm_list_lock);
+        if (!per_mm) {
+                per_mm = get_per_mm(umem_odp);
+                if (IS_ERR(per_mm)) {
+                        ret = PTR_ERR(per_mm);
+                        goto out_unlock;
+                }
+        }
+        umem_odp->per_mm = per_mm;
+        per_mm->odp_mrs_count++;
+        mutex_unlock(&ctx->per_mm_list_lock);
+
+        mutex_init(&umem_odp->umem_mutex);
+        init_completion(&umem_odp->notifier_completion);
+
+        if (!umem_odp->is_implicit_odp) {
+                down_write(&per_mm->umem_rwsem);
+                interval_tree_insert(&umem_odp->interval_tree,
+                                     &per_mm->umem_tree);
+                up_write(&per_mm->umem_rwsem);
+        }
+
+        return 0;
+
+out_unlock:
+        mutex_unlock(&ctx->per_mm_list_lock);
+        vfree(umem_odp->dma_list);
+out_page_list:
+        vfree(umem_odp->page_list);
+        return ret;
+}
+
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
                                       unsigned long addr, size_t size)
 {
-        struct ib_ucontext_per_mm *per_mm = root->per_mm;
-        struct ib_ucontext *ctx = per_mm->context;
+        /*
+         * Caller must ensure that root cannot be freed during the call to
+         * ib_alloc_odp_umem.
+         */
         struct ib_umem_odp *odp_data;
         struct ib_umem *umem;
-        int pages = size >> PAGE_SHIFT;
         int ret;
 
-        if (!size)
-                return ERR_PTR(-EINVAL);
-
         odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
         if (!odp_data)
                 return ERR_PTR(-ENOMEM);
         umem = &odp_data->umem;
-        umem->context = ctx;
+        umem->context = root->umem.context;
         umem->length = size;
         umem->address = addr;
-        odp_data->page_shift = PAGE_SHIFT;
         umem->writable = root->umem.writable;
-        umem->is_odp = 1;
-        odp_data->per_mm = per_mm;
-        umem->owning_mm = per_mm->mm;
-        mmgrab(umem->owning_mm);
-
-        mutex_init(&odp_data->umem_mutex);
-        init_completion(&odp_data->notifier_completion);
-
-        odp_data->page_list =
-                vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
-        if (!odp_data->page_list) {
-                ret = -ENOMEM;
-                goto out_odp_data;
-        }
+        umem->owning_mm = root->umem.owning_mm;
+        odp_data->page_shift = PAGE_SHIFT;
 
-        odp_data->dma_list =
-                vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
-        if (!odp_data->dma_list) {
-                ret = -ENOMEM;
-                goto out_page_list;
+        ret = ib_init_umem_odp(odp_data, root->per_mm);
+        if (ret) {
+                kfree(odp_data);
+                return ERR_PTR(ret);
         }
 
-        /*
-         * Caller must ensure that the umem_odp that the per_mm came from
-         * cannot be freed during the call to ib_alloc_odp_umem.
-         */
-        mutex_lock(&ctx->per_mm_list_lock);
-        per_mm->odp_mrs_count++;
-        mutex_unlock(&ctx->per_mm_list_lock);
-        add_umem_to_per_mm(odp_data);
+        mmgrab(umem->owning_mm);
 
         return odp_data;
-
-out_page_list:
-        vfree(odp_data->page_list);
-out_odp_data:
-        mmdrop(umem->owning_mm);
-        kfree(odp_data);
-        return ERR_PTR(ret);
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
-        struct ib_umem *umem = &umem_odp->umem;
         /*
          * NOTE: This must called in a process context where umem->owning_mm
          * == current->mm
          */
-        struct mm_struct *mm = umem->owning_mm;
-        int ret_val;
+        struct mm_struct *mm = umem_odp->umem.owning_mm;
 
         if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
                 umem_odp->is_implicit_odp = 1;
@@ -397,43 +405,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                 up_read(&mm->mmap_sem);
         }
 
-        mutex_init(&umem_odp->umem_mutex);
-
-        init_completion(&umem_odp->notifier_completion);
-
-        if (!umem_odp->is_implicit_odp) {
-                if (!ib_umem_odp_num_pages(umem_odp))
-                        return -EINVAL;
-
-                umem_odp->page_list =
-                        vzalloc(array_size(sizeof(*umem_odp->page_list),
-                                           ib_umem_odp_num_pages(umem_odp)));
-                if (!umem_odp->page_list)
-                        return -ENOMEM;
-
-                umem_odp->dma_list =
-                        vzalloc(array_size(sizeof(*umem_odp->dma_list),
-                                           ib_umem_odp_num_pages(umem_odp)));
-                if (!umem_odp->dma_list) {
-                        ret_val = -ENOMEM;
-                        goto out_page_list;
-                }
-        }
-
-        ret_val = get_per_mm(umem_odp);
-        if (ret_val)
-                goto out_dma_list;
-
-        if (!umem_odp->is_implicit_odp)
-                add_umem_to_per_mm(umem_odp);
-
-        return 0;
-
-out_dma_list:
-        vfree(umem_odp->dma_list);
-out_page_list:
-        vfree(umem_odp->page_list);
-        return ret_val;
+        return ib_init_umem_odp(umem_odp, NULL);
 }
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
