
Commit a9709d6

jasowang authored and mstsirkin committed
vhost: convert pre sorted vhost memory array to interval tree
The current pre-sorted memory region array has some limitations for the future device IOTLB conversion:

1) Adding or removing a single region needs extra work, and it is expected to be slow because of the sorting or memory re-allocation.
2) Removing a large range that may intersect several regions of different sizes needs extra work.
3) A replacement policy such as LRU needs tricks.

To overcome the above shortcomings, this patch converts the array to an interval tree, which addresses the above issues with almost no extra work.

The patch could be used to:

- Extend the current API so that userspace only sends diffs of the memory table.
- Simplify the device IOTLB implementation.

Signed-off-by: Jason Wang <[email protected]>
Signed-off-by: Michael S. Tsirkin <[email protected]>
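A minimal sketch of what the conversion buys, assuming only the vhost_umem_interval_tree_* helpers that INTERVAL_TREE_DEFINE() generates in the vhost.c hunk below; the helper names vhost_umem_find() and vhost_umem_remove_range() are illustrative and not part of this patch. A stabbing query returns the region covering one guest-physical address, and the same iterators walk every region overlapping an arbitrary range, which is what makes bulk removal and a future device IOTLB invalidation cheap compared to the sorted array:

/* Illustrative only -- not part of this commit.  Uses the iterators
 * generated by INTERVAL_TREE_DEFINE() in drivers/vhost/vhost.c below.
 */
static const struct vhost_umem_node *
vhost_umem_find(struct vhost_umem *umem, __u64 addr)
{
	/* Stabbing query: first node whose [start, last] contains addr. */
	return vhost_umem_interval_tree_iter_first(&umem->umem_tree,
						   addr, addr);
}

static void vhost_umem_remove_range(struct vhost_umem *umem,
				    __u64 start, __u64 last)
{
	struct vhost_umem_node *node, *next;

	/* Visit only the regions overlapping [start, last]; with the
	 * pre-sorted array this would mean finding every intersecting
	 * region, re-packing the array and re-sorting it.
	 */
	node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
						   start, last);
	while (node) {
		next = vhost_umem_interval_tree_iter_next(node, start, last);
		vhost_umem_interval_tree_remove(node, &umem->umem_tree);
		list_del(&node->link);
		kvfree(node);
		node = next;
	}
}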
1 parent bfe2bc5 commit a9709d6

File tree

3 files changed: +128, -89 lines changed


drivers/vhost/net.c

Lines changed: 4 additions & 4 deletions
@@ -1036,20 +1036,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 	struct socket *tx_sock = NULL;
 	struct socket *rx_sock = NULL;
 	long err;
-	struct vhost_memory *memory;
+	struct vhost_umem *umem;

 	mutex_lock(&n->dev.mutex);
 	err = vhost_dev_check_owner(&n->dev);
 	if (err)
 		goto done;
-	memory = vhost_dev_reset_owner_prepare();
-	if (!memory) {
+	umem = vhost_dev_reset_owner_prepare();
+	if (!umem) {
 		err = -ENOMEM;
 		goto done;
 	}
 	vhost_net_stop(n, &tx_sock, &rx_sock);
 	vhost_net_flush(n);
-	vhost_dev_reset_owner(&n->dev, memory);
+	vhost_dev_reset_owner(&n->dev, umem);
 	vhost_net_vq_reset(n);
 done:
 	mutex_unlock(&n->dev.mutex);

drivers/vhost/vhost.c

Lines changed: 101 additions & 81 deletions
@@ -27,6 +27,7 @@
 #include <linux/cgroup.h>
 #include <linux/module.h>
 #include <linux/sort.h>
+#include <linux/interval_tree_generic.h>

 #include "vhost.h"

@@ -42,6 +43,10 @@ enum {
 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])

+INTERVAL_TREE_DEFINE(struct vhost_umem_node,
+		     rb, __u64, __subtree_last,
+		     START, LAST, , vhost_umem_interval_tree);
+
 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
 {
@@ -297,10 +302,10 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
-	vq->memory = NULL;
 	vhost_reset_is_le(vq);
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
+	vq->umem = NULL;
 }

 static int vhost_worker(void *data)
@@ -394,7 +399,7 @@ void vhost_dev_init(struct vhost_dev *dev,
 	mutex_init(&dev->mutex);
 	dev->log_ctx = NULL;
 	dev->log_file = NULL;
-	dev->memory = NULL;
+	dev->umem = NULL;
 	dev->mm = NULL;
 	dev->worker = NULL;
 	init_llist_head(&dev->work_list);
@@ -499,27 +504,36 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);

-struct vhost_memory *vhost_dev_reset_owner_prepare(void)
+static void *vhost_kvzalloc(unsigned long size)
 {
-	return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
+	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+
+	if (!n)
+		n = vzalloc(size);
+	return n;
+}
+
+struct vhost_umem *vhost_dev_reset_owner_prepare(void)
+{
+	return vhost_kvzalloc(sizeof(struct vhost_umem));
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);

 /* Caller should have device mutex */
-void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
 {
 	int i;

 	vhost_dev_cleanup(dev, true);

 	/* Restore memory to default empty mapping. */
-	memory->nregions = 0;
-	dev->memory = memory;
+	INIT_LIST_HEAD(&umem->umem_list);
+	dev->umem = umem;
 	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
 	 * VQs aren't running.
 	 */
 	for (i = 0; i < dev->nvqs; ++i)
-		dev->vqs[i]->memory = memory;
+		dev->vqs[i]->umem = umem;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);

@@ -536,6 +550,21 @@ void vhost_dev_stop(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);

+static void vhost_umem_clean(struct vhost_umem *umem)
+{
+	struct vhost_umem_node *node, *tmp;
+
+	if (!umem)
+		return;
+
+	list_for_each_entry_safe(node, tmp, &umem->umem_list, link) {
+		vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+		list_del(&node->link);
+		kvfree(node);
+	}
+	kvfree(umem);
+}
+
 /* Caller should have device mutex if and only if locked is set */
 void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 {
@@ -562,8 +591,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kvfree(dev->memory);
-	dev->memory = NULL;
+	vhost_umem_clean(dev->umem);
+	dev->umem = NULL;
 	WARN_ON(!llist_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -589,33 +618,33 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
 }

 /* Caller should have vq mutex and device mutex. */
-static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
+static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
 			       int log_all)
 {
-	int i;
+	struct vhost_umem_node *node;

-	if (!mem)
+	if (!umem)
 		return 0;

-	for (i = 0; i < mem->nregions; ++i) {
-		struct vhost_memory_region *m = mem->regions + i;
-		unsigned long a = m->userspace_addr;
-		if (m->memory_size > ULONG_MAX)
+	list_for_each_entry(node, &umem->umem_list, link) {
+		unsigned long a = node->userspace_addr;
+
+		if (node->size > ULONG_MAX)
 			return 0;
 		else if (!access_ok(VERIFY_WRITE, (void __user *)a,
-				    m->memory_size))
+				    node->size))
 			return 0;
 		else if (log_all && !log_access_ok(log_base,
-						   m->guest_phys_addr,
-						   m->memory_size))
+						   node->start,
+						   node->size))
 			return 0;
 	}
 	return 1;
 }

 /* Can we switch to this memory table? */
 /* Caller should have device mutex but not vq mutex */
-static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
+static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 			    int log_all)
 {
 	int i;
@@ -628,7 +657,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 		/* If ring is inactive, will check when it's enabled. */
 		if (d->vqs[i]->private_data)
-			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
+			ok = vq_memory_access_ok(d->vqs[i]->log_base,
+						 umem, log);
 		else
 			ok = 1;
 		mutex_unlock(&d->vqs[i]->mutex);
@@ -671,7 +701,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	return memory_access_ok(dev, dev->memory, 1);
+	return memory_access_ok(dev, dev->umem, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);

@@ -682,7 +712,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;

-	return vq_memory_access_ok(log_base, vq->memory,
+	return vq_memory_access_ok(log_base, vq->umem,
 				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -698,28 +728,12 @@ int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);

-static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2)
-{
-	const struct vhost_memory_region *r1 = p1, *r2 = p2;
-	if (r1->guest_phys_addr < r2->guest_phys_addr)
-		return 1;
-	if (r1->guest_phys_addr > r2->guest_phys_addr)
-		return -1;
-	return 0;
-}
-
-static void *vhost_kvzalloc(unsigned long size)
-{
-	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
-
-	if (!n)
-		n = vzalloc(size);
-	return n;
-}
-
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
-	struct vhost_memory mem, *newmem, *oldmem;
+	struct vhost_memory mem, *newmem;
+	struct vhost_memory_region *region;
+	struct vhost_umem_node *node;
+	struct vhost_umem *newumem, *oldumem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
 	int i;

@@ -739,24 +753,52 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		kvfree(newmem);
 		return -EFAULT;
 	}
-	sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
-		vhost_memory_reg_sort_cmp, NULL);

-	if (!memory_access_ok(d, newmem, 0)) {
+	newumem = vhost_kvzalloc(sizeof(*newumem));
+	if (!newumem) {
 		kvfree(newmem);
-		return -EFAULT;
+		return -ENOMEM;
+	}
+
+	newumem->umem_tree = RB_ROOT;
+	INIT_LIST_HEAD(&newumem->umem_list);
+
+	for (region = newmem->regions;
+	     region < newmem->regions + mem.nregions;
+	     region++) {
+		node = vhost_kvzalloc(sizeof(*node));
+		if (!node)
+			goto err;
+		node->start = region->guest_phys_addr;
+		node->size = region->memory_size;
+		node->last = node->start + node->size - 1;
+		node->userspace_addr = region->userspace_addr;
+		INIT_LIST_HEAD(&node->link);
+		list_add_tail(&node->link, &newumem->umem_list);
+		vhost_umem_interval_tree_insert(node, &newumem->umem_tree);
 	}
-	oldmem = d->memory;
-	d->memory = newmem;
+
+	if (!memory_access_ok(d, newumem, 0))
+		goto err;
+
+	oldumem = d->umem;
+	d->umem = newumem;

 	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
-		d->vqs[i]->memory = newmem;
+		d->vqs[i]->umem = newumem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
-	kvfree(oldmem);
+
+	kvfree(newmem);
+	vhost_umem_clean(oldumem);
 	return 0;
+
+err:
+	vhost_umem_clean(newumem);
+	kvfree(newmem);
+	return -EFAULT;
 }

 long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -1059,28 +1101,6 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);

-static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
-						     __u64 addr, __u32 len)
-{
-	const struct vhost_memory_region *reg;
-	int start = 0, end = mem->nregions;
-
-	while (start < end) {
-		int slot = start + (end - start) / 2;
-		reg = mem->regions + slot;
-		if (addr >= reg->guest_phys_addr)
-			end = slot;
-		else
-			start = slot + 1;
-	}
-
-	reg = mem->regions + start;
-	if (addr >= reg->guest_phys_addr &&
-	    reg->guest_phys_addr + reg->memory_size > addr)
-		return reg;
-	return NULL;
-}
-
 /* TODO: This is really inefficient. We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
  * returning -EFAULT). See Documentation/x86/exception-tables.txt.
@@ -1231,29 +1251,29 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access);
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
-	const struct vhost_memory_region *reg;
-	struct vhost_memory *mem;
+	const struct vhost_umem_node *node;
+	struct vhost_umem *umem = vq->umem;
 	struct iovec *_iov;
 	u64 s = 0;
 	int ret = 0;

-	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
 			ret = -ENOBUFS;
 			break;
 		}
-		reg = find_region(mem, addr, len);
-		if (unlikely(!reg)) {
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							addr, addr + len - 1);
+		if (node == NULL || node->start > addr) {
 			ret = -EFAULT;
 			break;
 		}
 		_iov = iov + ret;
-		size = reg->memory_size - addr + reg->guest_phys_addr;
+		size = node->size - addr + node->start;
 		_iov->iov_len = min((u64)len - s, size);
 		_iov->iov_base = (void __user *)(unsigned long)
-			(reg->userspace_addr + addr - reg->guest_phys_addr);
+			(node->userspace_addr + addr - node->start);
 		s += size;
 		addr += size;
 		++ret;
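The third file changed by this commit, drivers/vhost/vhost.h, is not shown in this view. The vhost.c hunks above assume it introduces the umem types and the START()/LAST() accessors passed to INTERVAL_TREE_DEFINE(); a plausible reconstruction of those declarations, inferred only from how the fields are used above, is:

/* Assumed declarations (the drivers/vhost/vhost.h hunk is not shown here);
 * field names are taken from their uses in the vhost.c hunks above.
 */
struct vhost_umem_node {
	struct rb_node rb;		/* node in umem->umem_tree */
	struct list_head link;		/* node in umem->umem_list */
	__u64 start;			/* first guest-physical address */
	__u64 last;			/* last guest-physical address, inclusive */
	__u64 size;
	__u64 userspace_addr;
	__u64 __subtree_last;		/* maintained by the interval tree */
};

struct vhost_umem {
	struct rb_root umem_tree;
	struct list_head umem_list;
};

/* Accessors handed to INTERVAL_TREE_DEFINE() as ITSTART/ITLAST. */
#define START(node)	((node)->start)
#define LAST(node)	((node)->last)

Correspondingly, struct vhost_dev and struct vhost_virtqueue would replace their struct vhost_memory *memory member with struct vhost_umem *umem, matching the dev->umem and vq->umem accesses above.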
