Skip to content

Commit 939779f

Browse files
jasowangmstsirkin
authored andcommitted
virtio_ring: validate used buffer length
This patch validate the used buffer length provided by the device before trying to use it. This is done by record the in buffer length in a new field in desc_state structure during virtqueue_add(), then we can fail the virtqueue_get_buf() when we find the device is trying to give us a used buffer length which is greater than the in buffer length. Since some drivers have already done the validation by themselves, this patch tries to makes the core validation optional. For the driver that doesn't want the validation, it can set the suppress_used_validation to be true (which could be overridden by force_used_validation module parameter). To be more efficient, a dedicate array is used for storing the validate used length, this helps to eliminate the cache stress if validation is done by the driver. Signed-off-by: Jason Wang <[email protected]> Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Michael S. Tsirkin <[email protected]>
1 parent f083937 commit 939779f

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

drivers/virtio/virtio_ring.c

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#include <linux/spinlock.h>
1515
#include <xen/xen.h>
1616

17+
static bool force_used_validation = false;
18+
module_param(force_used_validation, bool, 0444);
19+
1720
#ifdef DEBUG
1821
/* For development, we want to crash whenever the ring is screwed. */
1922
#define BAD_RING(_vq, fmt, args...) \
@@ -182,6 +185,9 @@ struct vring_virtqueue {
182185
} packed;
183186
};
184187

188+
/* Per-descriptor in buffer length */
189+
u32 *buflen;
190+
185191
/* How to notify other side. FIXME: commonalize hcalls! */
186192
bool (*notify)(struct virtqueue *vq);
187193

@@ -490,6 +496,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
490496
unsigned int i, n, avail, descs_used, prev, err_idx;
491497
int head;
492498
bool indirect;
499+
u32 buflen = 0;
493500

494501
START_USE(vq);
495502

@@ -571,6 +578,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
571578
VRING_DESC_F_NEXT |
572579
VRING_DESC_F_WRITE,
573580
indirect);
581+
buflen += sg->length;
574582
}
575583
}
576584
/* Last one doesn't continue. */
@@ -610,6 +618,10 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
610618
else
611619
vq->split.desc_state[head].indir_desc = ctx;
612620

621+
/* Store in buffer length if necessary */
622+
if (vq->buflen)
623+
vq->buflen[head] = buflen;
624+
613625
/* Put entry in available array (but don't update avail->idx until they
614626
* do sync). */
615627
avail = vq->split.avail_idx_shadow & (vq->split.vring.num - 1);
@@ -784,6 +796,11 @@ static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq,
784796
BAD_RING(vq, "id %u is not a head!\n", i);
785797
return NULL;
786798
}
799+
if (vq->buflen && unlikely(*len > vq->buflen[i])) {
800+
BAD_RING(vq, "used len %d is larger than in buflen %u\n",
801+
*len, vq->buflen[i]);
802+
return NULL;
803+
}
787804

788805
/* detach_buf_split clears data, so grab it now. */
789806
ret = vq->split.desc_state[i].data;
@@ -1062,6 +1079,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
10621079
unsigned int i, n, err_idx;
10631080
u16 head, id;
10641081
dma_addr_t addr;
1082+
u32 buflen = 0;
10651083

10661084
head = vq->packed.next_avail_idx;
10671085
desc = alloc_indirect_packed(total_sg, gfp);
@@ -1091,6 +1109,8 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
10911109
desc[i].addr = cpu_to_le64(addr);
10921110
desc[i].len = cpu_to_le32(sg->length);
10931111
i++;
1112+
if (n >= out_sgs)
1113+
buflen += sg->length;
10941114
}
10951115
}
10961116

@@ -1144,6 +1164,10 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
11441164
vq->packed.desc_state[id].indir_desc = desc;
11451165
vq->packed.desc_state[id].last = id;
11461166

1167+
/* Store in buffer length if necessary */
1168+
if (vq->buflen)
1169+
vq->buflen[id] = buflen;
1170+
11471171
vq->num_added += 1;
11481172

11491173
pr_debug("Added buffer head %i to %p\n", head, vq);
@@ -1179,6 +1203,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
11791203
__le16 head_flags, flags;
11801204
u16 head, id, prev, curr, avail_used_flags;
11811205
int err;
1206+
u32 buflen = 0;
11821207

11831208
START_USE(vq);
11841209

@@ -1258,6 +1283,8 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
12581283
1 << VRING_PACKED_DESC_F_AVAIL |
12591284
1 << VRING_PACKED_DESC_F_USED;
12601285
}
1286+
if (n >= out_sgs)
1287+
buflen += sg->length;
12611288
}
12621289
}
12631290

@@ -1277,6 +1304,10 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
12771304
vq->packed.desc_state[id].indir_desc = ctx;
12781305
vq->packed.desc_state[id].last = prev;
12791306

1307+
/* Store in buffer length if necessary */
1308+
if (vq->buflen)
1309+
vq->buflen[id] = buflen;
1310+
12801311
/*
12811312
* A driver MUST NOT make the first descriptor in the list
12821313
* available before all subsequent descriptors comprising
@@ -1463,6 +1494,11 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
14631494
BAD_RING(vq, "id %u is not a head!\n", id);
14641495
return NULL;
14651496
}
1497+
if (vq->buflen && unlikely(*len > vq->buflen[id])) {
1498+
BAD_RING(vq, "used len %d is larger than in buflen %u\n",
1499+
*len, vq->buflen[id]);
1500+
return NULL;
1501+
}
14661502

14671503
/* detach_buf_packed clears data, so grab it now. */
14681504
ret = vq->packed.desc_state[id].data;
@@ -1668,6 +1704,7 @@ static struct virtqueue *vring_create_virtqueue_packed(
16681704
struct vring_virtqueue *vq;
16691705
struct vring_packed_desc *ring;
16701706
struct vring_packed_desc_event *driver, *device;
1707+
struct virtio_driver *drv = drv_to_virtio(vdev->dev.driver);
16711708
dma_addr_t ring_dma_addr, driver_event_dma_addr, device_event_dma_addr;
16721709
size_t ring_size_in_bytes, event_size_in_bytes;
16731710

@@ -1757,6 +1794,15 @@ static struct virtqueue *vring_create_virtqueue_packed(
17571794
if (!vq->packed.desc_extra)
17581795
goto err_desc_extra;
17591796

1797+
if (!drv->suppress_used_validation || force_used_validation) {
1798+
vq->buflen = kmalloc_array(num, sizeof(*vq->buflen),
1799+
GFP_KERNEL);
1800+
if (!vq->buflen)
1801+
goto err_buflen;
1802+
} else {
1803+
vq->buflen = NULL;
1804+
}
1805+
17601806
/* No callback? Tell other side not to bother us. */
17611807
if (!callback) {
17621808
vq->packed.event_flags_shadow = VRING_PACKED_EVENT_FLAG_DISABLE;
@@ -1769,6 +1815,8 @@ static struct virtqueue *vring_create_virtqueue_packed(
17691815
spin_unlock(&vdev->vqs_list_lock);
17701816
return &vq->vq;
17711817

1818+
err_buflen:
1819+
kfree(vq->packed.desc_extra);
17721820
err_desc_extra:
17731821
kfree(vq->packed.desc_state);
17741822
err_desc_state:
@@ -2176,6 +2224,7 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
21762224
void (*callback)(struct virtqueue *),
21772225
const char *name)
21782226
{
2227+
struct virtio_driver *drv = drv_to_virtio(vdev->dev.driver);
21792228
struct vring_virtqueue *vq;
21802229

21812230
if (virtio_has_feature(vdev, VIRTIO_F_RING_PACKED))
@@ -2235,6 +2284,15 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
22352284
if (!vq->split.desc_extra)
22362285
goto err_extra;
22372286

2287+
if (!drv->suppress_used_validation || force_used_validation) {
2288+
vq->buflen = kmalloc_array(vring.num, sizeof(*vq->buflen),
2289+
GFP_KERNEL);
2290+
if (!vq->buflen)
2291+
goto err_buflen;
2292+
} else {
2293+
vq->buflen = NULL;
2294+
}
2295+
22382296
/* Put everything in free lists. */
22392297
vq->free_head = 0;
22402298
memset(vq->split.desc_state, 0, vring.num *
@@ -2245,6 +2303,8 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
22452303
spin_unlock(&vdev->vqs_list_lock);
22462304
return &vq->vq;
22472305

2306+
err_buflen:
2307+
kfree(vq->split.desc_extra);
22482308
err_extra:
22492309
kfree(vq->split.desc_state);
22502310
err_state:

include/linux/virtio.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ size_t virtio_max_dma_size(struct virtio_device *vdev);
152152
* @feature_table_size: number of entries in the feature table array.
153153
* @feature_table_legacy: same as feature_table but when working in legacy mode.
154154
* @feature_table_size_legacy: number of entries in feature table legacy array.
155+
* @suppress_used_validation: set to not have core validate used length
155156
* @probe: the function to call when a device is found. Returns 0 or -errno.
156157
* @scan: optional function to call after successful probe; intended
157158
* for virtio-scsi to invoke a scan.
@@ -168,6 +169,7 @@ struct virtio_driver {
168169
unsigned int feature_table_size;
169170
const unsigned int *feature_table_legacy;
170171
unsigned int feature_table_size_legacy;
172+
bool suppress_used_validation;
171173
int (*validate)(struct virtio_device *dev);
172174
int (*probe)(struct virtio_device *dev);
173175
void (*scan)(struct virtio_device *dev);

0 commit comments

Comments
 (0)