@@ -641,14 +641,14 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
 
-static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
+static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
 {
 	u64 a = addr / VHOST_PAGE_SIZE / 8;
 
 	/* Make sure 64 bit math will not overflow. */
 	if (a > ULONG_MAX - (unsigned long)log_base ||
 	    a + (unsigned long)log_base > ULONG_MAX)
-		return 0;
+		return false;
 
 	return access_ok(VERIFY_WRITE, log_base + a,
 			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
@@ -661,30 +661,30 @@ static bool vhost_overflow(u64 uaddr, u64 size)
 }
 
 /* Caller should have vq mutex and device mutex. */
-static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
-			       int log_all)
+static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
+			       int log_all)
 {
 	struct vhost_umem_node *node;
 
 	if (!umem)
-		return 0;
+		return false;
 
 	list_for_each_entry(node, &umem->umem_list, link) {
 		unsigned long a = node->userspace_addr;
 
 		if (vhost_overflow(node->userspace_addr, node->size))
-			return 0;
+			return false;
 
 
 		if (!access_ok(VERIFY_WRITE, (void __user *)a,
 				    node->size))
-			return 0;
+			return false;
 		else if (log_all && !log_access_ok(log_base,
 						   node->start,
 						   node->size))
-			return 0;
+			return false;
 	}
-	return 1;
+	return true;
 }
 
 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
@@ -701,13 +701,13 @@ static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
 
 /* Can we switch to this memory table? */
 /* Caller should have device mutex but not vq mutex */
-static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
-			    int log_all)
+static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
+			    int log_all)
 {
 	int i;
 
 	for (i = 0; i < d->nvqs; ++i) {
-		int ok;
+		bool ok;
 		bool log;
 
 		mutex_lock(&d->vqs[i]->mutex);
@@ -717,12 +717,12 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 			ok = vq_memory_access_ok(d->vqs[i]->log_base,
 						 umem, log);
 		else
-			ok = 1;
+			ok = true;
 		mutex_unlock(&d->vqs[i]->mutex);
 		if (!ok)
-			return 0;
+			return false;
 	}
-	return 1;
+	return true;
 }
 
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
@@ -959,21 +959,21 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d,
 	spin_unlock(&d->iotlb_lock);
 }
 
-static int umem_access_ok(u64 uaddr, u64 size, int access)
+static bool umem_access_ok(u64 uaddr, u64 size, int access)
 {
 	unsigned long a = uaddr;
 
 	/* Make sure 64 bit math will not overflow. */
 	if (vhost_overflow(uaddr, size))
-		return -EFAULT;
+		return false;
 
 	if ((access & VHOST_ACCESS_RO) &&
 	    !access_ok(VERIFY_READ, (void __user *)a, size))
-		return -EFAULT;
+		return false;
 	if ((access & VHOST_ACCESS_WO) &&
 	    !access_ok(VERIFY_WRITE, (void __user *)a, size))
-		return -EFAULT;
-	return 0;
+		return false;
+	return true;
 }
 
 static int vhost_process_iotlb_msg(struct vhost_dev *dev,
@@ -988,7 +988,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
 			ret = -EFAULT;
 			break;
 		}
-		if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
+		if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
 			ret = -EFAULT;
 			break;
 		}
@@ -1135,10 +1135,10 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
 	return 0;
 }
 
-static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
-			struct vring_desc __user *desc,
-			struct vring_avail __user *avail,
-			struct vring_used __user *used)
+static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
+			struct vring_desc __user *desc,
+			struct vring_avail __user *avail,
+			struct vring_used __user *used)
 
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
@@ -1161,8 +1161,8 @@ static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
 		vq->meta_iotlb[type] = node;
 }
 
-static int iotlb_access_ok(struct vhost_virtqueue *vq,
-			   int access, u64 addr, u64 len, int type)
+static bool iotlb_access_ok(struct vhost_virtqueue *vq,
+			   int access, u64 addr, u64 len, int type)
 {
 	const struct vhost_umem_node *node;
 	struct vhost_umem *umem = vq->iotlb;
@@ -1220,16 +1220,16 @@ EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
 
 /* Can we log writes? */
 /* Caller should have device mutex but not vq mutex */
-int vhost_log_access_ok(struct vhost_dev *dev)
+bool vhost_log_access_ok(struct vhost_dev *dev)
 {
 	return memory_access_ok(dev, dev->umem, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
 /* Verify access for write logging. */
 /* Caller should have vq mutex and device mutex */
-static int vq_log_access_ok(struct vhost_virtqueue *vq,
-			    void __user *log_base)
+static bool vq_log_access_ok(struct vhost_virtqueue *vq,
+			    void __user *log_base)
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
@@ -1242,12 +1242,14 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 
 /* Can we start vq? */
 /* Caller should have vq mutex and device mutex */
-int vhost_vq_access_ok(struct vhost_virtqueue *vq)
+bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
 {
-	int ret = vq_log_access_ok(vq, vq->log_base);
+	if (!vq_log_access_ok(vq, vq->log_base))
+		return false;
 
-	if (ret || vq->iotlb)
-		return ret;
+	/* Access validation occurs at prefetch time with IOTLB */
+	if (vq->iotlb)
+		return true;
 
 	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
 }
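
A caller-side sketch (not part of the patch) of what the bool conversion implies: with the old contract, umem_access_ok() returned 0 on success and -EFAULT on failure, so `if (umem_access_ok(...))` was true on error; with the new contract the check is inverted and the errno is produced by the caller, as the vhost_process_iotlb_msg hunk above does. The helper name example_check_region below is hypothetical.

    /* Hypothetical caller in vhost.c mapping the new bool contract to an errno.
     * Mirrors the inverted check applied in vhost_process_iotlb_msg above.
     */
    static int example_check_region(u64 uaddr, u64 size, int perm)
    {
    	if (!umem_access_ok(uaddr, size, perm))
    		return -EFAULT;	/* region not accessible for the requested perm */
    	return 0;
    }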