@@ -27,6 +27,7 @@
 #include <linux/cgroup.h>
 #include <linux/module.h>
 #include <linux/sort.h>
+#include <linux/interval_tree_generic.h>
 
 #include "vhost.h"
 
@@ -42,6 +43,10 @@ enum {
 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
 
+INTERVAL_TREE_DEFINE(struct vhost_umem_node,
+		     rb, __u64, __subtree_last,
+		     START, LAST, , vhost_umem_interval_tree);
+
 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
 {
@@ -297,10 +302,10 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
-	vq->memory = NULL;
 	vhost_reset_is_le(vq);
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
+	vq->umem = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -394,7 +399,7 @@ void vhost_dev_init(struct vhost_dev *dev,
 	mutex_init(&dev->mutex);
 	dev->log_ctx = NULL;
 	dev->log_file = NULL;
-	dev->memory = NULL;
+	dev->umem = NULL;
 	dev->mm = NULL;
 	dev->worker = NULL;
 	init_llist_head(&dev->work_list);
@@ -499,27 +504,36 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
 
-struct vhost_memory *vhost_dev_reset_owner_prepare(void)
+static void *vhost_kvzalloc(unsigned long size)
 {
-	return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
+	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+
+	if (!n)
+		n = vzalloc(size);
+	return n;
+}
+
+struct vhost_umem *vhost_dev_reset_owner_prepare(void)
+{
+	return vhost_kvzalloc(sizeof(struct vhost_umem));
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 
 /* Caller should have device mutex */
-void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
 {
 	int i;
 
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
-	memory->nregions = 0;
-	dev->memory = memory;
+	INIT_LIST_HEAD(&umem->umem_list);
+	dev->umem = umem;
 	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
 	 * VQs aren't running.
 	 */
 	for (i = 0; i < dev->nvqs; ++i)
-		dev->vqs[i]->memory = memory;
+		dev->vqs[i]->umem = umem;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -536,6 +550,21 @@ void vhost_dev_stop(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
+static void vhost_umem_clean(struct vhost_umem *umem)
+{
+	struct vhost_umem_node *node, *tmp;
+
+	if (!umem)
+		return;
+
+	list_for_each_entry_safe(node, tmp, &umem->umem_list, link) {
+		vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+		list_del(&node->link);
+		kvfree(node);
+	}
+	kvfree(umem);
+}
+
 /* Caller should have device mutex if and only if locked is set */
 void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 {
@@ -562,8 +591,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kvfree(dev->memory);
-	dev->memory = NULL;
+	vhost_umem_clean(dev->umem);
+	dev->umem = NULL;
 	WARN_ON(!llist_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -589,33 +618,33 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
 }
 
 /* Caller should have vq mutex and device mutex. */
-static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
+static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
 			       int log_all)
 {
-	int i;
+	struct vhost_umem_node *node;
 
-	if (!mem)
+	if (!umem)
 		return 0;
 
-	for (i = 0; i < mem->nregions; ++i) {
-		struct vhost_memory_region *m = mem->regions + i;
-		unsigned long a = m->userspace_addr;
-		if (m->memory_size > ULONG_MAX)
+	list_for_each_entry(node, &umem->umem_list, link) {
+		unsigned long a = node->userspace_addr;
+
+		if (node->size > ULONG_MAX)
 			return 0;
 		else if (!access_ok(VERIFY_WRITE, (void __user *)a,
-				    m->memory_size))
+				    node->size))
 			return 0;
 		else if (log_all && !log_access_ok(log_base,
-						   m->guest_phys_addr,
-						   m->memory_size))
+						   node->start,
+						   node->size))
 			return 0;
 	}
 	return 1;
 }
 
 /* Can we switch to this memory table? */
 /* Caller should have device mutex but not vq mutex */
-static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
+static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 			    int log_all)
 {
 	int i;
@@ -628,7 +657,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
 		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 		/* If ring is inactive, will check when it's enabled. */
 		if (d->vqs[i]->private_data)
-			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
+			ok = vq_memory_access_ok(d->vqs[i]->log_base,
+						 umem, log);
 		else
 			ok = 1;
 		mutex_unlock(&d->vqs[i]->mutex);
@@ -671,7 +701,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	return memory_access_ok(dev, dev->memory, 1);
+	return memory_access_ok(dev, dev->umem, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
@@ -682,7 +712,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	return vq_memory_access_ok(log_base, vq->memory,
+	return vq_memory_access_ok(log_base, vq->umem,
 				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -698,28 +728,12 @@ int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
-static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2)
-{
-	const struct vhost_memory_region *r1 = p1, *r2 = p2;
-	if (r1->guest_phys_addr < r2->guest_phys_addr)
-		return 1;
-	if (r1->guest_phys_addr > r2->guest_phys_addr)
-		return -1;
-	return 0;
-}
-
-static void *vhost_kvzalloc(unsigned long size)
-{
-	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
-
-	if (!n)
-		n = vzalloc(size);
-	return n;
-}
-
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
-	struct vhost_memory mem, *newmem, *oldmem;
+	struct vhost_memory mem, *newmem;
+	struct vhost_memory_region *region;
+	struct vhost_umem_node *node;
+	struct vhost_umem *newumem, *oldumem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
 	int i;
 
@@ -739,24 +753,52 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		kvfree(newmem);
 		return -EFAULT;
 	}
-	sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
-	     vhost_memory_reg_sort_cmp, NULL);
 
-	if (!memory_access_ok(d, newmem, 0)) {
+	newumem = vhost_kvzalloc(sizeof(*newumem));
+	if (!newumem) {
 		kvfree(newmem);
-		return -EFAULT;
+		return -ENOMEM;
+	}
+
+	newumem->umem_tree = RB_ROOT;
+	INIT_LIST_HEAD(&newumem->umem_list);
+
+	for (region = newmem->regions;
+	     region < newmem->regions + mem.nregions;
+	     region++) {
+		node = vhost_kvzalloc(sizeof(*node));
+		if (!node)
+			goto err;
+		node->start = region->guest_phys_addr;
+		node->size = region->memory_size;
+		node->last = node->start + node->size - 1;
+		node->userspace_addr = region->userspace_addr;
+		INIT_LIST_HEAD(&node->link);
+		list_add_tail(&node->link, &newumem->umem_list);
+		vhost_umem_interval_tree_insert(node, &newumem->umem_tree);
 	}
-	oldmem = d->memory;
-	d->memory = newmem;
+
+	if (!memory_access_ok(d, newumem, 0))
+		goto err;
+
+	oldumem = d->umem;
+	d->umem = newumem;
 
 	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
-		d->vqs[i]->memory = newmem;
+		d->vqs[i]->umem = newumem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
-	kvfree(oldmem);
+
+	kvfree(newmem);
+	vhost_umem_clean(oldumem);
 	return 0;
+
+err:
+	vhost_umem_clean(newumem);
+	kvfree(newmem);
+	return -EFAULT;
 }
 
 long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -1059,28 +1101,6 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
 
-static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
-						     __u64 addr, __u32 len)
-{
-	const struct vhost_memory_region *reg;
-	int start = 0, end = mem->nregions;
-
-	while (start < end) {
-		int slot = start + (end - start) / 2;
-		reg = mem->regions + slot;
-		if (addr >= reg->guest_phys_addr)
-			end = slot;
-		else
-			start = slot + 1;
-	}
-
-	reg = mem->regions + start;
-	if (addr >= reg->guest_phys_addr &&
-	    reg->guest_phys_addr + reg->memory_size > addr)
-		return reg;
-	return NULL;
-}
-
 /* TODO: This is really inefficient. We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
  * returning -EFAULT). See Documentation/x86/exception-tables.txt.
@@ -1231,29 +1251,29 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access);
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
-	const struct vhost_memory_region *reg;
-	struct vhost_memory *mem;
+	const struct vhost_umem_node *node;
+	struct vhost_umem *umem = vq->umem;
 	struct iovec *_iov;
 	u64 s = 0;
 	int ret = 0;
 
-	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
 			ret = -ENOBUFS;
 			break;
 		}
-		reg = find_region(mem, addr, len);
-		if (unlikely(!reg)) {
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   addr, addr + len - 1);
+		if (node == NULL || node->start > addr) {
 			ret = -EFAULT;
 			break;
 		}
 		_iov = iov + ret;
-		size = reg->memory_size - addr + reg->guest_phys_addr;
+		size = node->size - addr + node->start;
 		_iov->iov_len = min((u64)len - s, size);
 		_iov->iov_base = (void __user *)(unsigned long)
-			(reg->userspace_addr + addr - reg->guest_phys_addr);
+			(node->userspace_addr + addr - node->start);
 		s += size;
 		addr += size;
 		++ret;
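
Reviewer note (not part of the patch): INTERVAL_TREE_DEFINE() from <linux/interval_tree_generic.h> generates the vhost_umem_interval_tree_{insert,remove,iter_first,iter_next} helpers used above, keyed by the START and LAST accessors, so translate_desc() now looks up the vhost_umem_node whose [start, last] range covers a guest physical address instead of binary-searching the pre-sorted region array via find_region(). The translation itself is unchanged: uaddr = node->userspace_addr + (addr - node->start). Below is a minimal, standalone userspace sketch of that lookup-and-translate arithmetic; the struct and field names mirror the patch, but the program is purely illustrative and a plain linear scan stands in for the generated interval tree.

/* Illustrative sketch only -- mirrors the vhost_umem_node fields from the
 * patch, with a linear scan in place of vhost_umem_interval_tree_*. */
#include <stdint.h>
#include <stdio.h>

struct umem_node {
	uint64_t start;          /* guest physical start (node->start)     */
	uint64_t last;           /* inclusive end: start + size - 1        */
	uint64_t size;           /* region length in bytes                 */
	uint64_t userspace_addr; /* host userspace address mapping 'start' */
};

/* Return the host userspace address for guest physical address gpa,
 * or 0 when no region covers it (the kernel path returns -EFAULT). */
static uint64_t translate_gpa(const struct umem_node *nodes, int n, uint64_t gpa)
{
	for (int i = 0; i < n; i++)
		if (gpa >= nodes[i].start && gpa <= nodes[i].last)
			return nodes[i].userspace_addr + (gpa - nodes[i].start);
	return 0;
}

int main(void)
{
	/* One 1 GiB region at guest physical 0, mapped at an arbitrary
	 * (hypothetical) host userspace address. */
	struct umem_node nodes[] = {
		{ .start = 0, .size = 1ULL << 30, .last = (1ULL << 30) - 1,
		  .userspace_addr = 0x7f0000000000ULL },
	};

	printf("gpa 0x1000 -> uaddr 0x%llx\n",
	       (unsigned long long)translate_gpa(nodes, 1, 0x1000));
	return 0;
}

Compiled with a plain `cc sketch.c`, it prints "gpa 0x1000 -> uaddr 0x7f0000001000", i.e. the offset into the region added to the region's userspace mapping.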