|
@@ -27,6 +27,7 @@
|
|
|
#include <linux/cgroup.h>
|
|
|
#include <linux/module.h>
|
|
|
#include <linux/sort.h>
|
|
|
+#include <linux/interval_tree_generic.h>
|
|
|
|
|
|
#include "vhost.h"
|
|
|
|
|
@@ -34,6 +35,10 @@ static ushort max_mem_regions = 64;
|
|
|
module_param(max_mem_regions, ushort, 0444);
|
|
|
MODULE_PARM_DESC(max_mem_regions,
|
|
|
"Maximum number of memory regions in memory map. (default: 64)");
|
|
|
+static int max_iotlb_entries = 2048;
|
|
|
+module_param(max_iotlb_entries, int, 0444);
|
|
|
+MODULE_PARM_DESC(max_iotlb_entries,
|
|
|
+ "Maximum number of iotlb entries. (default: 2048)");
|
|
|
|
|
|
enum {
|
|
|
VHOST_MEMORY_F_LOG = 0x1,
|
|
@@ -42,6 +47,10 @@ enum {
|
|
|
#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
|
|
|
#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
|
|
|
|
|
|
+INTERVAL_TREE_DEFINE(struct vhost_umem_node,
|
|
|
+ rb, __u64, __subtree_last,
|
|
|
+ START, LAST, , vhost_umem_interval_tree);
|
|
|
+
|
|
|
#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
|
|
|
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
|
|
|
{
|
|
@@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq)
|
|
|
vq->is_le = virtio_legacy_is_little_endian();
|
|
|
}
|
|
|
|
|
|
+/* Pairs a flush work item with the completion that is signalled when it runs. */
+struct vhost_flush_struct {
+	struct vhost_work work;
+	struct completion wait_event;
+};
+
+/* Work callback: complete the wait_event of the enclosing vhost_flush_struct. */
+static void vhost_flush_work(struct vhost_work *work)
+{
+	struct vhost_flush_struct *flush =
+		container_of(work, struct vhost_flush_struct, work);
+
+	complete(&flush->wait_event);
+}
|
|
|
+
|
|
|
static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
|
|
|
poll_table *pt)
|
|
|
{
|
|
@@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
|
|
|
|
|
|
void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
|
|
|
{
|
|
|
- INIT_LIST_HEAD(&work->node);
|
|
|
+ clear_bit(VHOST_WORK_QUEUED, &work->flags);
|
|
|
work->fn = fn;
|
|
|
init_waitqueue_head(&work->done);
|
|
|
- work->flushing = 0;
|
|
|
- work->queue_seq = work->done_seq = 0;
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_work_init);
|
|
|
|
|
@@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll)
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_poll_stop);
|
|
|
|
|
|
-static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
|
|
|
- unsigned seq)
|
|
|
-{
|
|
|
- int left;
|
|
|
-
|
|
|
- spin_lock_irq(&dev->work_lock);
|
|
|
- left = seq - work->done_seq;
|
|
|
- spin_unlock_irq(&dev->work_lock);
|
|
|
- return left <= 0;
|
|
|
-}
|
|
|
-
|
|
|
void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
|
|
|
{
|
|
|
- unsigned seq;
|
|
|
- int flushing;
|
|
|
+ struct vhost_flush_struct flush;
|
|
|
|
|
|
- spin_lock_irq(&dev->work_lock);
|
|
|
- seq = work->queue_seq;
|
|
|
- work->flushing++;
|
|
|
- spin_unlock_irq(&dev->work_lock);
|
|
|
- wait_event(work->done, vhost_work_seq_done(dev, work, seq));
|
|
|
- spin_lock_irq(&dev->work_lock);
|
|
|
- flushing = --work->flushing;
|
|
|
- spin_unlock_irq(&dev->work_lock);
|
|
|
- BUG_ON(flushing < 0);
|
|
|
+ if (dev->worker) {
|
|
|
+ init_completion(&flush.wait_event);
|
|
|
+ vhost_work_init(&flush.work, vhost_flush_work);
|
|
|
+
|
|
|
+ vhost_work_queue(dev, &flush.work);
|
|
|
+ wait_for_completion(&flush.wait_event);
|
|
|
+ }
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_work_flush);
|
|
|
|
|
@@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush);
|
|
|
|
|
|
void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
|
|
|
{
|
|
|
- unsigned long flags;
|
|
|
+ if (!dev->worker)
|
|
|
+ return;
|
|
|
|
|
|
- spin_lock_irqsave(&dev->work_lock, flags);
|
|
|
- if (list_empty(&work->node)) {
|
|
|
- list_add_tail(&work->node, &dev->work_list);
|
|
|
- work->queue_seq++;
|
|
|
- spin_unlock_irqrestore(&dev->work_lock, flags);
|
|
|
+ if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
|
|
|
+ /* We can only add the work to the list after we're
|
|
|
+ * sure it was not in the list.
|
|
|
+ */
|
|
|
+ smp_mb();
|
|
|
+ llist_add(&work->node, &dev->work_list);
|
|
|
wake_up_process(dev->worker);
|
|
|
- } else {
|
|
|
- spin_unlock_irqrestore(&dev->work_lock, flags);
|
|
|
}
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_work_queue);
|
|
@@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue);
|
|
|
/* A lockless hint for busy polling code to exit the loop */
|
|
|
bool vhost_has_work(struct vhost_dev *dev)
|
|
|
{
|
|
|
- return !list_empty(&dev->work_list);
|
|
|
+ return !llist_empty(&dev->work_list);
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_has_work);
|
|
|
|
|
@@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev,
|
|
|
vq->call_ctx = NULL;
|
|
|
vq->call = NULL;
|
|
|
vq->log_ctx = NULL;
|
|
|
- vq->memory = NULL;
|
|
|
vhost_reset_is_le(vq);
|
|
|
vhost_disable_cross_endian(vq);
|
|
|
vq->busyloop_timeout = 0;
|
|
|
+ vq->umem = NULL;
|
|
|
+ vq->iotlb = NULL;
|
|
|
}
|
|
|
|
|
|
static int vhost_worker(void *data)
|
|
|
{
|
|
|
struct vhost_dev *dev = data;
|
|
|
- struct vhost_work *work = NULL;
|
|
|
- unsigned uninitialized_var(seq);
|
|
|
+ struct vhost_work *work, *work_next;
|
|
|
+ struct llist_node *node;
|
|
|
mm_segment_t oldfs = get_fs();
|
|
|
|
|
|
set_fs(USER_DS);
|
|
@@ -320,35 +327,25 @@ static int vhost_worker(void *data)
|
|
|
/* mb paired w/ kthread_stop */
|
|
|
set_current_state(TASK_INTERRUPTIBLE);
|
|
|
|
|
|
- spin_lock_irq(&dev->work_lock);
|
|
|
- if (work) {
|
|
|
- work->done_seq = seq;
|
|
|
- if (work->flushing)
|
|
|
- wake_up_all(&work->done);
|
|
|
- }
|
|
|
-
|
|
|
if (kthread_should_stop()) {
|
|
|
- spin_unlock_irq(&dev->work_lock);
|
|
|
__set_current_state(TASK_RUNNING);
|
|
|
break;
|
|
|
}
|
|
|
- if (!list_empty(&dev->work_list)) {
|
|
|
- work = list_first_entry(&dev->work_list,
|
|
|
- struct vhost_work, node);
|
|
|
- list_del_init(&work->node);
|
|
|
- seq = work->queue_seq;
|
|
|
- } else
|
|
|
- work = NULL;
|
|
|
- spin_unlock_irq(&dev->work_lock);
|
|
|
|
|
|
- if (work) {
|
|
|
+ node = llist_del_all(&dev->work_list);
|
|
|
+ if (!node)
|
|
|
+ schedule();
|
|
|
+
|
|
|
+ node = llist_reverse_order(node);
|
|
|
+ /* make sure flag is seen after deletion */
|
|
|
+ smp_wmb();
|
|
|
+ llist_for_each_entry_safe(work, work_next, node, node) {
|
|
|
+ clear_bit(VHOST_WORK_QUEUED, &work->flags);
|
|
|
__set_current_state(TASK_RUNNING);
|
|
|
work->fn(work);
|
|
|
if (need_resched())
|
|
|
schedule();
|
|
|
- } else
|
|
|
- schedule();
|
|
|
-
|
|
|
+ }
|
|
|
}
|
|
|
unuse_mm(dev->mm);
|
|
|
set_fs(oldfs);
|
|
@@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev,
|
|
|
mutex_init(&dev->mutex);
|
|
|
dev->log_ctx = NULL;
|
|
|
dev->log_file = NULL;
|
|
|
- dev->memory = NULL;
|
|
|
+ dev->umem = NULL;
|
|
|
+ dev->iotlb = NULL;
|
|
|
dev->mm = NULL;
|
|
|
- spin_lock_init(&dev->work_lock);
|
|
|
- INIT_LIST_HEAD(&dev->work_list);
|
|
|
dev->worker = NULL;
|
|
|
+ init_llist_head(&dev->work_list);
|
|
|
+ init_waitqueue_head(&dev->wait);
|
|
|
+ INIT_LIST_HEAD(&dev->read_list);
|
|
|
+ INIT_LIST_HEAD(&dev->pending_list);
|
|
|
+ spin_lock_init(&dev->iotlb_lock);
|
|
|
+
|
|
|
|
|
|
for (i = 0; i < dev->nvqs; ++i) {
|
|
|
vq = dev->vqs[i];
|
|
@@ -512,27 +514,36 @@ err_mm:
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
|
|
|
|
|
|
-struct vhost_memory *vhost_dev_reset_owner_prepare(void)
|
|
|
+static void *vhost_kvzalloc(unsigned long size)
|
|
|
{
|
|
|
- return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
|
|
|
+ void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
|
|
|
+
|
|
|
+ if (!n)
|
|
|
+ n = vzalloc(size);
|
|
|
+ return n;
|
|
|
+}
|
|
|
+
|
|
|
+struct vhost_umem *vhost_dev_reset_owner_prepare(void)
|
|
|
+{
|
|
|
+ return vhost_kvzalloc(sizeof(struct vhost_umem));
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
|
|
|
|
|
|
/* Caller should have device mutex */
|
|
|
-void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
|
|
|
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
|
|
|
{
|
|
|
int i;
|
|
|
|
|
|
vhost_dev_cleanup(dev, true);
|
|
|
|
|
|
/* Restore memory to default empty mapping. */
|
|
|
- memory->nregions = 0;
|
|
|
- dev->memory = memory;
|
|
|
+ INIT_LIST_HEAD(&umem->umem_list);
|
|
|
+ dev->umem = umem;
|
|
|
/* We don't need VQ locks below since vhost_dev_cleanup makes sure
|
|
|
* VQs aren't running.
|
|
|
*/
|
|
|
for (i = 0; i < dev->nvqs; ++i)
|
|
|
- dev->vqs[i]->memory = memory;
|
|
|
+ dev->vqs[i]->umem = umem;
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
|
|
|
|
|
@@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev)
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_dev_stop);
|
|
|
|
|
|
+/*
+ * Unlink @node from the list and the interval tree of @umem, free it and
+ * decrement umem->numem.
+ */
+static void vhost_umem_free(struct vhost_umem *umem,
+			    struct vhost_umem_node *node)
+{
+	list_del(&node->link);
+	vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+	umem->numem--;
+	kfree(node);
+}
|
|
|
+
|
|
|
+/* Free every node on @umem's list, then @umem itself. A NULL @umem is a no-op. */
+static void vhost_umem_clean(struct vhost_umem *umem)
+{
+	struct vhost_umem_node *cur, *next;
+
+	if (umem) {
+		list_for_each_entry_safe(cur, next, &umem->umem_list, link)
+			vhost_umem_free(umem, cur);
+
+		kvfree(umem);
+	}
+}
|
|
|
+
|
|
|
+/*
+ * Free every message node queued on dev->read_list and dev->pending_list
+ * while holding dev->iotlb_lock.
+ */
+static void vhost_clear_msg(struct vhost_dev *dev)
+{
+	struct list_head *queues[] = { &dev->read_list, &dev->pending_list };
+	struct vhost_msg_node *cur, *next;
+	size_t q;
+
+	spin_lock(&dev->iotlb_lock);
+
+	for (q = 0; q < ARRAY_SIZE(queues); q++) {
+		list_for_each_entry_safe(cur, next, queues[q], node) {
+			list_del(&cur->node);
+			kfree(cur);
+		}
+	}
+
+	spin_unlock(&dev->iotlb_lock);
+}
|
|
|
+
|
|
|
/* Caller should have device mutex if and only if locked is set */
|
|
|
void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
|
|
|
{
|
|
@@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
|
|
|
fput(dev->log_file);
|
|
|
dev->log_file = NULL;
|
|
|
/* No one will access memory at this point */
|
|
|
- kvfree(dev->memory);
|
|
|
- dev->memory = NULL;
|
|
|
- WARN_ON(!list_empty(&dev->work_list));
|
|
|
+ vhost_umem_clean(dev->umem);
|
|
|
+ dev->umem = NULL;
|
|
|
+ vhost_umem_clean(dev->iotlb);
|
|
|
+ dev->iotlb = NULL;
|
|
|
+ vhost_clear_msg(dev);
|
|
|
+ wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
|
|
|
+ WARN_ON(!llist_empty(&dev->work_list));
|
|
|
if (dev->worker) {
|
|
|
kthread_stop(dev->worker);
|
|
|
dev->worker = NULL;
|
|
@@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
|
|
|
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
|
|
|
}
|
|
|
|
|
|
+/*
+ * Return true when @uaddr or @size does not fit in an unsigned long, or
+ * when @uaddr + @size would exceed ULONG_MAX.
+ */
+static bool vhost_overflow(u64 uaddr, u64 size)
+{
+	if (uaddr > ULONG_MAX)
+		return true;
+	if (size > ULONG_MAX)
+		return true;
+
+	/* Both operands fit, so the subtraction below cannot wrap. */
+	return uaddr > ULONG_MAX - size;
+}
|
|
|
+
|
|
|
/* Caller should have vq mutex and device mutex. */
|
|
|
-static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
|
|
|
+static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
|
|
|
int log_all)
|
|
|
{
|
|
|
- int i;
|
|
|
+ struct vhost_umem_node *node;
|
|
|
|
|
|
- if (!mem)
|
|
|
+ if (!umem)
|
|
|
return 0;
|
|
|
|
|
|
- for (i = 0; i < mem->nregions; ++i) {
|
|
|
- struct vhost_memory_region *m = mem->regions + i;
|
|
|
- unsigned long a = m->userspace_addr;
|
|
|
- if (m->memory_size > ULONG_MAX)
|
|
|
+ list_for_each_entry(node, &umem->umem_list, link) {
|
|
|
+ unsigned long a = node->userspace_addr;
|
|
|
+
|
|
|
+ if (vhost_overflow(node->userspace_addr, node->size))
|
|
|
return 0;
|
|
|
- else if (!access_ok(VERIFY_WRITE, (void __user *)a,
|
|
|
- m->memory_size))
|
|
|
+
|
|
|
+
|
|
|
+ if (!access_ok(VERIFY_WRITE, (void __user *)a,
|
|
|
+ node->size))
|
|
|
return 0;
|
|
|
else if (log_all && !log_access_ok(log_base,
|
|
|
- m->guest_phys_addr,
|
|
|
- m->memory_size))
|
|
|
+ node->start,
|
|
|
+ node->size))
|
|
|
return 0;
|
|
|
}
|
|
|
return 1;
|
|
@@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
|
|
|
|
|
|
/* Can we switch to this memory table? */
|
|
|
/* Caller should have device mutex but not vq mutex */
|
|
|
-static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
|
|
|
+static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
|
|
|
int log_all)
|
|
|
{
|
|
|
int i;
|
|
@@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
|
|
|
log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
|
|
|
/* If ring is inactive, will check when it's enabled. */
|
|
|
if (d->vqs[i]->private_data)
|
|
|
- ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
|
|
|
+ ok = vq_memory_access_ok(d->vqs[i]->log_base,
|
|
|
+ umem, log);
|
|
|
else
|
|
|
ok = 1;
|
|
|
mutex_unlock(&d->vqs[i]->mutex);
|
|
@@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
|
|
|
return 1;
|
|
|
}
|
|
|
|
|
|
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
|
|
|
+ struct iovec iov[], int iov_size, int access);
|
|
|
+
|
|
|
+/*
+ * Copy @size bytes from the kernel buffer @from to @to on behalf of @vq.
+ *
+ * Without a device IOTLB, @to is handed straight to __copy_to_user().
+ * With one, @to is first translated by translate_desc() into
+ * vq->iotlb_iov and the data is copied through an iov_iter.
+ *
+ * Returns 0 on success. On failure returns the non-zero result of
+ * __copy_to_user(), the negative value from translate_desc(), or -EFAULT
+ * when fewer than @size bytes could be copied.
+ */
+static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
+			      const void *from, unsigned size)
+{
+	struct iov_iter iter;
+	int ret;
+
+	if (!vq->iotlb)
+		return __copy_to_user(to, from, size);
+
+	/* This function should be called after iotlb
+	 * prefetch, which means we're sure that all vq
+	 * could be accessed through iotlb. So -EAGAIN should
+	 * not happen in this case.
+	 */
+	/* TODO: add a faster path */
+	ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
+			     ARRAY_SIZE(vq->iotlb_iov),
+			     VHOST_ACCESS_WO);
+	if (ret < 0)
+		return ret;
+
+	iov_iter_init(&iter, WRITE, vq->iotlb_iov, ret, size);
+
+	/*
+	 * A short copy used to leak the byte count to the caller, and a
+	 * copy of zero bytes was indistinguishable from success.
+	 */
+	if (copy_to_iter(from, size, &iter) != size)
+		return -EFAULT;
+
+	return 0;
+}
|
|
|
+
|
|
|
+/*
+ * Copy @size bytes from @from into the kernel buffer @to on behalf of @vq.
+ *
+ * Without a device IOTLB, @from is handed straight to __copy_from_user().
+ * With one, @from is first translated by translate_desc() into
+ * vq->iotlb_iov and the data is copied through an iov_iter.
+ *
+ * Returns 0 on success. On failure returns the non-zero result of
+ * __copy_from_user(), the negative value from translate_desc(), or
+ * -EFAULT when fewer than @size bytes could be copied.
+ */
+static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
+				void *from, unsigned size)
+{
+	struct iov_iter iter;
+	int ret;
+
+	if (!vq->iotlb)
+		return __copy_from_user(to, from, size);
+
+	/* This function should be called after iotlb
+	 * prefetch, which means we're sure that vq
+	 * could be accessed through iotlb. So -EAGAIN should
+	 * not happen in this case.
+	 */
+	/* TODO: add a faster path */
+	ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
+			     ARRAY_SIZE(vq->iotlb_iov),
+			     VHOST_ACCESS_RO);
+	if (ret < 0) {
+		vq_err(vq, "IOTLB translation failure: uaddr "
+		       "%p size 0x%llx\n", from,
+		       (unsigned long long) size);
+		return ret;
+	}
+
+	iov_iter_init(&iter, READ, vq->iotlb_iov, ret, size);
+
+	/*
+	 * A short copy used to leak the byte count to the caller, and a
+	 * copy of zero bytes was indistinguishable from success.
+	 */
+	if (copy_from_iter(to, size, &iter) != size)
+		return -EFAULT;
+
+	return 0;
+}
|
|
|
+
|
|
|
+/*
+ * Translate @addr / @size through translate_desc() and return the resulting
+ * user pointer, or NULL when the translation fails or does not come back as
+ * exactly one iovec of @size bytes.
+ */
+static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
+				     void *addr, unsigned size)
+{
+	int nr_iov;
+
+	/* This function should be called after iotlb
+	 * prefetch, which means we're sure that vq
+	 * could be accessed through iotlb. So -EAGAIN should
+	 * not happen in this case.
+	 */
+	/* TODO: add a faster path */
+	nr_iov = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
+				ARRAY_SIZE(vq->iotlb_iov),
+				VHOST_ACCESS_RO);
+	if (nr_iov < 0) {
+		vq_err(vq, "IOTLB translation failure: uaddr "
+		       "%p size 0x%llx\n", addr,
+		       (unsigned long long) size);
+		return NULL;
+	}
+
+	/* The access must be covered by a single, full-length segment. */
+	if (nr_iov == 1 && vq->iotlb_iov[0].iov_len == size)
+		return vq->iotlb_iov[0].iov_base;
+
+	vq_err(vq, "Non atomic userspace memory access: uaddr "
+	       "%p size 0x%llx\n", addr,
+	       (unsigned long long) size);
+	return NULL;
+}
|
|
|
+
|
|
|
+/*
+ * Store @x at the user pointer @ptr on behalf of @vq.
+ *
+ * Without a device IOTLB this is a plain __put_user(). With one, @ptr is
+ * first translated by __vhost_get_user(). The expression evaluates to the
+ * result of __put_user(), or -EFAULT when the translation returns NULL.
+ *
+ * Macro arguments are parenthesized and the locals carry a leading
+ * underscore so that caller expressions cannot be captured by them.
+ */
+#define vhost_put_user(vq, x, ptr)					\
+({									\
+	int _ret = -EFAULT;						\
+	if (!(vq)->iotlb) {						\
+		_ret = __put_user(x, ptr);				\
+	} else {							\
+		__typeof__(ptr) _to =					\
+			(__typeof__(ptr)) __vhost_get_user((vq), (ptr),	\
+						sizeof(*(ptr)));	\
+		if (_to != NULL)					\
+			_ret = __put_user(x, _to);			\
+		else							\
+			_ret = -EFAULT;					\
+	}								\
+	_ret;								\
+})
|
|
|
+
|
|
|
+/*
+ * Load @x from the user pointer @ptr on behalf of @vq.
+ *
+ * Without a device IOTLB this is a plain __get_user(). With one, @ptr is
+ * first translated by __vhost_get_user(). The expression evaluates to the
+ * result of __get_user(), or -EFAULT when the translation returns NULL.
+ *
+ * Macro arguments are parenthesized and the locals carry a leading
+ * underscore so that caller expressions cannot be captured by them.
+ */
+#define vhost_get_user(vq, x, ptr)					\
+({									\
+	int _ret;							\
+	if (!(vq)->iotlb) {						\
+		_ret = __get_user(x, ptr);				\
+	} else {							\
+		__typeof__(ptr) _from =					\
+			(__typeof__(ptr)) __vhost_get_user((vq), (ptr),	\
+						sizeof(*(ptr)));	\
+		if (_from != NULL)					\
+			_ret = __get_user(x, _from);			\
+		else							\
+			_ret = -EFAULT;					\
+	}								\
+	_ret;								\
+})
|
|
|
+
|
|
|
+/* Take the mutex of every virtqueue of @d, in ascending index order. */
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+	int idx;
+
+	for (idx = 0; idx < d->nvqs; idx++)
+		mutex_lock(&d->vqs[idx]->mutex);
+}
|
|
|
+
|
|
|
+/* Release the mutex of every virtqueue of @d, in ascending index order. */
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+	int idx;
+
+	for (idx = 0; idx < d->nvqs; idx++)
+		mutex_unlock(&d->vqs[idx]->mutex);
+}
|
|
|
+
|
|
|
+/*
+ * Add the range [@start, @end] of @size bytes, backed by @userspace_addr
+ * with permissions @perm, to @umem's list and interval tree.
+ *
+ * When umem->numem has reached max_iotlb_entries, the entry at the head of
+ * umem_list is freed first to make room.
+ *
+ * Returns 0 on success, -EFAULT for an empty range, -ENOMEM when the node
+ * cannot be allocated.
+ */
+static int vhost_new_umem_range(struct vhost_umem *umem,
+				u64 start, u64 size, u64 end,
+				u64 userspace_addr, int perm)
+{
+	struct vhost_umem_node *tmp, *node;
+
+	/*
+	 * Callers derive @end as start + size - 1, so an empty range would
+	 * put a node with last < start into the interval tree.
+	 */
+	if (!size)
+		return -EFAULT;
+
+	node = kmalloc(sizeof(*node), GFP_ATOMIC);
+	if (!node)
+		return -ENOMEM;
+
+	/*
+	 * Only evict when there is an entry to evict: with
+	 * max_iotlb_entries == 0 the counter matches on an empty list and
+	 * list_first_entry() would hand back a bogus node.
+	 */
+	if (umem->numem == max_iotlb_entries &&
+	    !list_empty(&umem->umem_list)) {
+		tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
+		vhost_umem_free(umem, tmp);
+	}
+
+	node->start = start;
+	node->size = size;
+	node->last = end;
+	node->userspace_addr = userspace_addr;
+	node->perm = perm;
+	INIT_LIST_HEAD(&node->link);
+	list_add_tail(&node->link, &umem->umem_list);
+	vhost_umem_interval_tree_insert(node, &umem->umem_tree);
+	umem->numem++;
+
+	return 0;
+}
|
|
|
+
|
|
|
+/* Free every node of @umem whose interval overlaps [@start, @end]. */
+static void vhost_del_umem_range(struct vhost_umem *umem,
+				 u64 start, u64 end)
+{
+	struct vhost_umem_node *hit;
+
+	for (;;) {
+		hit = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							  start, end);
+		if (!hit)
+			break;
+
+		vhost_umem_free(umem, hit);
+	}
+}
|
|
|
+
|
|
|
+/*
+ * Wake the virtqueues waiting on an IOTLB miss that the update @msg covers.
+ *
+ * Walks d->pending_list under d->iotlb_lock. Every pending VHOST_IOTLB_MISS
+ * whose iova lies inside [msg->iova, msg->iova + msg->size - 1] has its
+ * vq poll queued and its node removed and freed.
+ */
+static void vhost_iotlb_notify_vq(struct vhost_dev *d,
+				  struct vhost_iotlb_msg *msg)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&d->iotlb_lock);
+
+	list_for_each_entry_safe(node, n, &d->pending_list, node) {
+		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
+
+		/*
+		 * The range is inclusive of its last byte, so the upper
+		 * bound must be compared with >=. Using > left a miss on
+		 * the final address of the update pending.
+		 */
+		if (msg->iova <= vq_msg->iova &&
+		    msg->iova + msg->size - 1 >= vq_msg->iova &&
+		    vq_msg->type == VHOST_IOTLB_MISS) {
+			vhost_poll_queue(&node->vq->poll);
+			list_del(&node->node);
+			kfree(node);
+		}
+	}
+
+	spin_unlock(&d->iotlb_lock);
+}
|
|
|
+
|
|
|
+/*
+ * Check that the user range [@uaddr, @uaddr + @size) passes access_ok()
+ * for each direction requested in @access.
+ *
+ * Returns 0 when acceptable, -EFAULT when the range overflows or fails
+ * access_ok().
+ */
+static int umem_access_ok(u64 uaddr, u64 size, int access)
+{
+	void __user *ptr;
+
+	/* Make sure 64 bit math will not overflow. */
+	if (vhost_overflow(uaddr, size))
+		return -EFAULT;
+
+	ptr = (void __user *)(unsigned long)uaddr;
+
+	if (access & VHOST_ACCESS_RO) {
+		if (!access_ok(VERIFY_READ, ptr, size))
+			return -EFAULT;
+	}
+
+	if (access & VHOST_ACCESS_WO) {
+		if (!access_ok(VERIFY_WRITE, ptr, size))
+			return -EFAULT;
+	}
+
+	return 0;
+}
|
|
|
+
|
|
|
+/*
+ * Apply one IOTLB message to @dev while holding every virtqueue mutex.
+ *
+ * VHOST_IOTLB_UPDATE validates the user range and inserts it into
+ * dev->iotlb, then wakes virtqueues whose pending miss it covers.
+ * VHOST_IOTLB_INVALIDATE removes every entry overlapping the range.
+ *
+ * Returns 0 on success, -EFAULT when no device IOTLB is set up or the
+ * user range is not accessible, the error from vhost_new_umem_range()
+ * when the insert fails, and -EINVAL for an unknown message type.
+ */
+int vhost_process_iotlb_msg(struct vhost_dev *dev,
+			    struct vhost_iotlb_msg *msg)
+{
+	int ret = 0;
+
+	vhost_dev_lock_vqs(dev);
+	switch (msg->type) {
+	case VHOST_IOTLB_UPDATE:
+		if (!dev->iotlb) {
+			ret = -EFAULT;
+			break;
+		}
+		if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
+			ret = -EFAULT;
+			break;
+		}
+		ret = vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
+					   msg->iova + msg->size - 1,
+					   msg->uaddr, msg->perm);
+		if (ret)
+			break;
+		vhost_iotlb_notify_vq(dev, msg);
+		break;
+	case VHOST_IOTLB_INVALIDATE:
+		/*
+		 * dev->iotlb stays NULL until a device IOTLB is set up;
+		 * without this check the delete below dereferences NULL.
+		 */
+		if (!dev->iotlb) {
+			ret = -EFAULT;
+			break;
+		}
+		vhost_del_umem_range(dev->iotlb, msg->iova,
+				     msg->iova + msg->size - 1);
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+	vhost_dev_unlock_vqs(dev);
+	return ret;
+}
|
|
|
+/*
+ * Consume one struct vhost_msg from @from and dispatch it.
+ *
+ * Returns 0 when @from holds less than one message, the number of bytes
+ * copied when the copy is short or the message was processed, and a
+ * negative errno when processing fails or the message type is unknown.
+ */
+ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
+			     struct iov_iter *from)
+{
+	struct vhost_msg_node node;
+	unsigned size = sizeof(struct vhost_msg);
+	/* Signed: this also carries negative errno values back to the caller. */
+	ssize_t ret;
+	int err;
+
+	if (iov_iter_count(from) < size)
+		return 0;
+	ret = copy_from_iter(&node.msg, size, from);
+	if (ret != size)
+		goto done;
+
+	switch (node.msg.type) {
+	case VHOST_IOTLB_MSG:
+		err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
+		if (err)
+			ret = err;
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+done:
+	return ret;
+}
+EXPORT_SYMBOL(vhost_chr_write_iter);
|
|
|
+
|
|
|
+/*
+ * Poll handler: register on dev->wait and report POLLIN | POLLRDNORM when
+ * dev->read_list is not empty, 0 otherwise.
+ */
+unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
+			    poll_table *wait)
+{
+	poll_wait(file, &dev->wait, wait);
+
+	if (list_empty(&dev->read_list))
+		return 0;
+
+	return POLLIN | POLLRDNORM;
+}
+EXPORT_SYMBOL(vhost_chr_poll);
|
|
|
+
|
|
|
+/*
+ * Hand the next message queued on dev->read_list to @to.
+ *
+ * Returns 0 when @to cannot hold a whole struct vhost_msg. If no message
+ * is queued: -EAGAIN when @noblock is set; otherwise -ERESTARTSYS on a
+ * pending signal, -EBADFD when the device has no IOTLB, or sleep on
+ * dev->wait and retry. On delivery returns the number of bytes copied.
+ * A fully copied VHOST_IOTLB_MISS message is moved to dev->pending_list;
+ * any other node is freed.
+ */
+ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
+			    int noblock)
+{
+	DEFINE_WAIT(wait);
+	struct vhost_msg_node *node;
+	ssize_t ret = 0;
+	unsigned size = sizeof(struct vhost_msg);
+
+	if (iov_iter_count(to) < size)
+		return 0;
+
+	for (;;) {
+		if (!noblock)
+			prepare_to_wait(&dev->wait, &wait,
+					TASK_INTERRUPTIBLE);
+
+		node = vhost_dequeue_msg(dev, &dev->read_list);
+		if (node)
+			break;
+
+		/* Nothing queued: work out whether we may sleep. */
+		if (noblock)
+			ret = -EAGAIN;
+		else if (signal_pending(current))
+			ret = -ERESTARTSYS;
+		else if (!dev->iotlb)
+			ret = -EBADFD;
+
+		if (ret)
+			break;
+
+		schedule();
+	}
+
+	if (!noblock)
+		finish_wait(&dev->wait, &wait);
+
+	if (!node)
+		return ret;
+
+	ret = copy_to_iter(&node->msg, size, to);
+
+	if (ret == size && node->msg.type == VHOST_IOTLB_MISS)
+		vhost_enqueue_msg(dev, &dev->pending_list, node);
+	else
+		kfree(node);
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
|
|
|
+
|
|
|
+/*
+ * Queue a VHOST_IOTLB_MISS message for @iova with permission @access on
+ * the read_list of @vq's device.
+ *
+ * Returns 0 on success, -ENOMEM when vhost_new_msg() returns NULL.
+ */
+static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
+{
+	struct vhost_msg_node *node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
+
+	if (!node)
+		return -ENOMEM;
+
+	node->msg.iotlb.type = VHOST_IOTLB_MISS;
+	node->msg.iotlb.iova = iova;
+	node->msg.iotlb.perm = access;
+
+	vhost_enqueue_msg(vq->dev, &vq->dev->read_list, node);
+
+	return 0;
+}
|
|
|
+
|
|
|
static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
|
|
|
struct vring_desc __user *desc,
|
|
|
struct vring_avail __user *avail,
|
|
|
struct vring_used __user *used)
|
|
|
+
|
|
|
{
|
|
|
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
|
|
|
+
|
|
|
return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
|
|
|
access_ok(VERIFY_READ, avail,
|
|
|
sizeof *avail + num * sizeof *avail->ring + s) &&
|
|
@@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
|
|
|
sizeof *used + num * sizeof *used->ring + s);
|
|
|
}
|
|
|
|
|
|
+/*
+ * Check that [@addr, @addr + @len - 1] is fully covered by entries of
+ * vq->iotlb that allow @access.
+ *
+ * Returns true when the whole range is covered. Returns false when a gap
+ * is found (after queueing an IOTLB miss for the first uncovered address)
+ * or when a covering entry lacks the requested permission.
+ */
+static int iotlb_access_ok(struct vhost_virtqueue *vq,
+			   int access, u64 addr, u64 len)
+{
+	const struct vhost_umem_node *node;
+	struct vhost_umem *umem = vq->iotlb;
+	/*
+	 * Fix the end of the requested range once. Recomputing it from
+	 * the advancing @addr moved the lookup window past the request
+	 * and could wrap around.
+	 */
+	u64 s = 0, size, last = addr + len - 1;
+
+	while (len > s) {
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   addr,
+							   last);
+		if (node == NULL || node->start > addr) {
+			vhost_iotlb_miss(vq, addr, access);
+			return false;
+		} else if (!(node->perm & access)) {
+			/* Report the possible access violation by
+			 * requesting another translation from userspace.
+			 */
+			return false;
+		}
+
+		/* Bytes of this entry that lie at or after @addr. */
+		size = node->size - addr + node->start;
+		s += size;
+		addr += size;
+	}
+
+	return true;
+}
|
|
|
+
|
|
|
+/*
+ * Check through iotlb_access_ok() that the descriptor table, avail ring
+ * and used ring of @vq are reachable via the device IOTLB.
+ *
+ * Returns 1 when @vq has no IOTLB or all three areas pass, 0 otherwise.
+ */
+int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+{
+	size_t event_sz = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+	unsigned int num = vq->num;
+
+	if (!vq->iotlb)
+		return 1;
+
+	if (!iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
+			     num * sizeof *vq->desc))
+		return 0;
+
+	if (!iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
+			     sizeof *vq->avail +
+			     num * sizeof *vq->avail->ring + event_sz))
+		return 0;
+
+	return iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
+			       sizeof *vq->used +
+			       num * sizeof *vq->used->ring + event_sz) != 0;
+}
+EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
|
|
|
+
|
|
|
/* Can we log writes? */
|
|
|
/* Caller should have device mutex but not vq mutex */
|
|
|
int vhost_log_access_ok(struct vhost_dev *dev)
|
|
|
{
|
|
|
- return memory_access_ok(dev, dev->memory, 1);
|
|
|
+ return memory_access_ok(dev, dev->umem, 1);
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_log_access_ok);
|
|
|
|
|
@@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
|
|
|
{
|
|
|
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
|
|
|
|
|
|
- return vq_memory_access_ok(log_base, vq->memory,
|
|
|
+ return vq_memory_access_ok(log_base, vq->umem,
|
|
|
vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
|
|
|
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
|
|
|
sizeof *vq->used +
|
|
@@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
|
|
|
/* Caller should have vq mutex and device mutex */
|
|
|
int vhost_vq_access_ok(struct vhost_virtqueue *vq)
|
|
|
{
|
|
|
+ if (vq->iotlb) {
|
|
|
+ /* When device IOTLB was used, the access validation
|
|
|
+ * will be validated during prefetching.
|
|
|
+ */
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
|
|
|
vq_log_access_ok(vq, vq->log_base);
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
|
|
|
|
|
|
-static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2)
|
|
|
+static struct vhost_umem *vhost_umem_alloc(void)
|
|
|
{
|
|
|
- const struct vhost_memory_region *r1 = p1, *r2 = p2;
|
|
|
- if (r1->guest_phys_addr < r2->guest_phys_addr)
|
|
|
- return 1;
|
|
|
- if (r1->guest_phys_addr > r2->guest_phys_addr)
|
|
|
- return -1;
|
|
|
- return 0;
|
|
|
-}
|
|
|
+ struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
|
|
|
|
|
|
-static void *vhost_kvzalloc(unsigned long size)
|
|
|
-{
|
|
|
- void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
|
|
|
+ if (!umem)
|
|
|
+ return NULL;
|
|
|
|
|
|
- if (!n)
|
|
|
- n = vzalloc(size);
|
|
|
- return n;
|
|
|
+ umem->umem_tree = RB_ROOT;
|
|
|
+ umem->numem = 0;
|
|
|
+ INIT_LIST_HEAD(&umem->umem_list);
|
|
|
+
|
|
|
+ return umem;
|
|
|
}
|
|
|
|
|
|
static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
|
|
|
{
|
|
|
- struct vhost_memory mem, *newmem, *oldmem;
|
|
|
+ struct vhost_memory mem, *newmem;
|
|
|
+ struct vhost_memory_region *region;
|
|
|
+ struct vhost_umem *newumem, *oldumem;
|
|
|
unsigned long size = offsetof(struct vhost_memory, regions);
|
|
|
int i;
|
|
|
|
|
@@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
|
|
|
kvfree(newmem);
|
|
|
return -EFAULT;
|
|
|
}
|
|
|
- sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
|
|
|
- vhost_memory_reg_sort_cmp, NULL);
|
|
|
|
|
|
- if (!memory_access_ok(d, newmem, 0)) {
|
|
|
+ newumem = vhost_umem_alloc();
|
|
|
+ if (!newumem) {
|
|
|
kvfree(newmem);
|
|
|
- return -EFAULT;
|
|
|
+ return -ENOMEM;
|
|
|
}
|
|
|
- oldmem = d->memory;
|
|
|
- d->memory = newmem;
|
|
|
+
|
|
|
+ for (region = newmem->regions;
|
|
|
+ region < newmem->regions + mem.nregions;
|
|
|
+ region++) {
|
|
|
+ if (vhost_new_umem_range(newumem,
|
|
|
+ region->guest_phys_addr,
|
|
|
+ region->memory_size,
|
|
|
+ region->guest_phys_addr +
|
|
|
+ region->memory_size - 1,
|
|
|
+ region->userspace_addr,
|
|
|
+ VHOST_ACCESS_RW))
|
|
|
+ goto err;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (!memory_access_ok(d, newumem, 0))
|
|
|
+ goto err;
|
|
|
+
|
|
|
+ oldumem = d->umem;
|
|
|
+ d->umem = newumem;
|
|
|
|
|
|
/* All memory accesses are done under some VQ mutex. */
|
|
|
for (i = 0; i < d->nvqs; ++i) {
|
|
|
mutex_lock(&d->vqs[i]->mutex);
|
|
|
- d->vqs[i]->memory = newmem;
|
|
|
+ d->vqs[i]->umem = newumem;
|
|
|
mutex_unlock(&d->vqs[i]->mutex);
|
|
|
}
|
|
|
- kvfree(oldmem);
|
|
|
+
|
|
|
+ kvfree(newmem);
|
|
|
+ vhost_umem_clean(oldumem);
|
|
|
return 0;
|
|
|
+
|
|
|
+err:
|
|
|
+ vhost_umem_clean(newumem);
|
|
|
+ kvfree(newmem);
|
|
|
+ return -EFAULT;
|
|
|
}
|
|
|
|
|
|
long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
|
|
@@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
|
|
|
|
|
|
+/*
+ * Install a freshly allocated, empty IOTLB on @d and on each of its
+ * virtqueues (each under its own vq mutex), then free the previous one.
+ * @enabled is not used by this function.
+ *
+ * Returns 0 on success, -ENOMEM when the new IOTLB cannot be allocated.
+ */
+int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
+{
+	struct vhost_umem *fresh = vhost_umem_alloc();
+	struct vhost_umem *stale;
+	int i;
+
+	if (!fresh)
+		return -ENOMEM;
+
+	stale = d->iotlb;
+	d->iotlb = fresh;
+
+	for (i = 0; i < d->nvqs; ++i) {
+		struct vhost_virtqueue *vq = d->vqs[i];
+
+		mutex_lock(&vq->mutex);
+		vq->iotlb = fresh;
+		mutex_unlock(&vq->mutex);
+	}
+
+	vhost_umem_clean(stale);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
|
|
|
+
|
|
|
/* Caller must have device mutex */
|
|
|
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
|
|
|
{
|
|
@@ -1056,28 +1592,6 @@ done:
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
|
|
|
|
|
|
-static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
|
|
|
- __u64 addr, __u32 len)
|
|
|
-{
|
|
|
- const struct vhost_memory_region *reg;
|
|
|
- int start = 0, end = mem->nregions;
|
|
|
-
|
|
|
- while (start < end) {
|
|
|
- int slot = start + (end - start) / 2;
|
|
|
- reg = mem->regions + slot;
|
|
|
- if (addr >= reg->guest_phys_addr)
|
|
|
- end = slot;
|
|
|
- else
|
|
|
- start = slot + 1;
|
|
|
- }
|
|
|
-
|
|
|
- reg = mem->regions + start;
|
|
|
- if (addr >= reg->guest_phys_addr &&
|
|
|
- reg->guest_phys_addr + reg->memory_size > addr)
|
|
|
- return reg;
|
|
|
- return NULL;
|
|
|
-}
|
|
|
-
|
|
|
/* TODO: This is really inefficient. We need something like get_user()
|
|
|
* (instruction directly accesses the data, with an exception table entry
|
|
|
* returning -EFAULT). See Documentation/x86/exception-tables.txt.
|
|
@@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
|
|
|
static int vhost_update_used_flags(struct vhost_virtqueue *vq)
|
|
|
{
|
|
|
void __user *used;
|
|
|
- if (__put_user(cpu_to_vhost16(vq, vq->used_flags), &vq->used->flags) < 0)
|
|
|
+ if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
|
|
|
+ &vq->used->flags) < 0)
|
|
|
return -EFAULT;
|
|
|
if (unlikely(vq->log_used)) {
|
|
|
/* Make sure the flag is seen before log. */
|
|
@@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
|
|
|
|
|
|
static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
|
|
|
{
|
|
|
- if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq)))
|
|
|
+ if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
|
|
|
+ vhost_avail_event(vq)))
|
|
|
return -EFAULT;
|
|
|
if (unlikely(vq->log_used)) {
|
|
|
void __user *used;
|
|
@@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
|
|
|
if (r)
|
|
|
goto err;
|
|
|
vq->signalled_used_valid = false;
|
|
|
- if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
|
|
|
+ if (!vq->iotlb &&
|
|
|
+ !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
|
|
|
r = -EFAULT;
|
|
|
goto err;
|
|
|
}
|
|
|
- r = __get_user(last_used_idx, &vq->used->idx);
|
|
|
- if (r)
|
|
|
+ r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
|
|
|
+ if (r) {
|
|
|
+ vq_err(vq, "Can't access used idx at %p\n",
|
|
|
+ &vq->used->idx);
|
|
|
goto err;
|
|
|
+ }
|
|
|
vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
|
|
|
return 0;
|
|
|
+
|
|
|
err:
|
|
|
vq->is_le = is_le;
|
|
|
return r;
|
|
@@ -1224,36 +1745,48 @@ err:
|
|
|
EXPORT_SYMBOL_GPL(vhost_vq_init_access);
|
|
|
|
|
|
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
|
|
|
- struct iovec iov[], int iov_size)
|
|
|
+ struct iovec iov[], int iov_size, int access)
|
|
|
{
|
|
|
- const struct vhost_memory_region *reg;
|
|
|
- struct vhost_memory *mem;
|
|
|
+ const struct vhost_umem_node *node;
|
|
|
+ struct vhost_dev *dev = vq->dev;
|
|
|
+ struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
|
|
|
struct iovec *_iov;
|
|
|
u64 s = 0;
|
|
|
int ret = 0;
|
|
|
|
|
|
- mem = vq->memory;
|
|
|
while ((u64)len > s) {
|
|
|
u64 size;
|
|
|
if (unlikely(ret >= iov_size)) {
|
|
|
ret = -ENOBUFS;
|
|
|
break;
|
|
|
}
|
|
|
- reg = find_region(mem, addr, len);
|
|
|
- if (unlikely(!reg)) {
|
|
|
- ret = -EFAULT;
|
|
|
+
|
|
|
+ node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
|
|
|
+ addr, addr + len - 1);
|
|
|
+ if (node == NULL || node->start > addr) {
|
|
|
+ if (umem != dev->iotlb) {
|
|
|
+ ret = -EFAULT;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ ret = -EAGAIN;
|
|
|
+ break;
|
|
|
+ } else if (!(node->perm & access)) {
|
|
|
+ ret = -EPERM;
|
|
|
break;
|
|
|
}
|
|
|
+
|
|
|
_iov = iov + ret;
|
|
|
- size = reg->memory_size - addr + reg->guest_phys_addr;
|
|
|
+ size = node->size - addr + node->start;
|
|
|
_iov->iov_len = min((u64)len - s, size);
|
|
|
_iov->iov_base = (void __user *)(unsigned long)
|
|
|
- (reg->userspace_addr + addr - reg->guest_phys_addr);
|
|
|
+ (node->userspace_addr + addr - node->start);
|
|
|
s += size;
|
|
|
addr += size;
|
|
|
++ret;
|
|
|
}
|
|
|
|
|
|
+ if (ret == -EAGAIN)
|
|
|
+ vhost_iotlb_miss(vq, addr, access);
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
@@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
|
|
|
unsigned int i = 0, count, found = 0;
|
|
|
u32 len = vhost32_to_cpu(vq, indirect->len);
|
|
|
struct iov_iter from;
|
|
|
- int ret;
|
|
|
+ int ret, access;
|
|
|
|
|
|
/* Sanity check */
|
|
|
if (unlikely(len % sizeof desc)) {
|
|
@@ -1300,9 +1833,10 @@ static int get_indirect(struct vhost_virtqueue *vq,
|
|
|
}
|
|
|
|
|
|
ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
|
|
|
- UIO_MAXIOV);
|
|
|
+ UIO_MAXIOV, VHOST_ACCESS_RO);
|
|
|
if (unlikely(ret < 0)) {
|
|
|
- vq_err(vq, "Translation failure %d in indirect.\n", ret);
|
|
|
+ if (ret != -EAGAIN)
|
|
|
+ vq_err(vq, "Translation failure %d in indirect.\n", ret);
|
|
|
return ret;
|
|
|
}
|
|
|
iov_iter_init(&from, READ, vq->indirect, ret, len);
|
|
@@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq,
|
|
|
return -EINVAL;
|
|
|
}
|
|
|
|
|
|
+ if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
|
|
|
+ access = VHOST_ACCESS_WO;
|
|
|
+ else
|
|
|
+ access = VHOST_ACCESS_RO;
|
|
|
+
|
|
|
ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
|
|
|
vhost32_to_cpu(vq, desc.len), iov + iov_count,
|
|
|
- iov_size - iov_count);
|
|
|
+ iov_size - iov_count, access);
|
|
|
if (unlikely(ret < 0)) {
|
|
|
- vq_err(vq, "Translation failure %d indirect idx %d\n",
|
|
|
- ret, i);
|
|
|
+ if (ret != -EAGAIN)
|
|
|
+ vq_err(vq, "Translation failure %d indirect idx %d\n",
|
|
|
+ ret, i);
|
|
|
return ret;
|
|
|
}
|
|
|
/* If this is an input descriptor, increment that count. */
|
|
|
- if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
|
|
|
+ if (access == VHOST_ACCESS_WO) {
|
|
|
*in_num += ret;
|
|
|
if (unlikely(log)) {
|
|
|
log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
|
|
@@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
|
|
|
u16 last_avail_idx;
|
|
|
__virtio16 avail_idx;
|
|
|
__virtio16 ring_head;
|
|
|
- int ret;
|
|
|
+ int ret, access;
|
|
|
|
|
|
/* Check it isn't doing very strange things with descriptor numbers. */
|
|
|
last_avail_idx = vq->last_avail_idx;
|
|
|
- if (unlikely(__get_user(avail_idx, &vq->avail->idx))) {
|
|
|
+ if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) {
|
|
|
vq_err(vq, "Failed to access avail idx at %p\n",
|
|
|
&vq->avail->idx);
|
|
|
return -EFAULT;
|
|
@@ -1414,8 +1954,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
|
|
|
|
|
|
/* Grab the next descriptor number they're advertising, and increment
|
|
|
* the index we've seen. */
|
|
|
- if (unlikely(__get_user(ring_head,
|
|
|
- &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
|
|
|
+ if (unlikely(vhost_get_user(vq, ring_head,
|
|
|
+ &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
|
|
|
vq_err(vq, "Failed to read head: idx %d address %p\n",
|
|
|
last_avail_idx,
|
|
|
&vq->avail->ring[last_avail_idx % vq->num]);
|
|
@@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
|
|
|
i, vq->num, head);
|
|
|
return -EINVAL;
|
|
|
}
|
|
|
- ret = __copy_from_user(&desc, vq->desc + i, sizeof desc);
|
|
|
+ ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
|
|
|
+ sizeof desc);
|
|
|
if (unlikely(ret)) {
|
|
|
vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
|
|
|
i, vq->desc + i);
|
|
@@ -1461,22 +2002,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
|
|
|
out_num, in_num,
|
|
|
log, log_num, &desc);
|
|
|
if (unlikely(ret < 0)) {
|
|
|
- vq_err(vq, "Failure detected "
|
|
|
- "in indirect descriptor at idx %d\n", i);
|
|
|
+ if (ret != -EAGAIN)
|
|
|
+ vq_err(vq, "Failure detected "
|
|
|
+ "in indirect descriptor at idx %d\n", i);
|
|
|
return ret;
|
|
|
}
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
+ if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
|
|
|
+ access = VHOST_ACCESS_WO;
|
|
|
+ else
|
|
|
+ access = VHOST_ACCESS_RO;
|
|
|
ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
|
|
|
vhost32_to_cpu(vq, desc.len), iov + iov_count,
|
|
|
- iov_size - iov_count);
|
|
|
+ iov_size - iov_count, access);
|
|
|
if (unlikely(ret < 0)) {
|
|
|
- vq_err(vq, "Translation failure %d descriptor idx %d\n",
|
|
|
- ret, i);
|
|
|
+ if (ret != -EAGAIN)
|
|
|
+ vq_err(vq, "Translation failure %d descriptor idx %d\n",
|
|
|
+ ret, i);
|
|
|
return ret;
|
|
|
}
|
|
|
- if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
|
|
|
+ if (access == VHOST_ACCESS_WO) {
|
|
|
/* If this is an input descriptor,
|
|
|
* increment that count. */
|
|
|
*in_num += ret;
|
|
@@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
|
|
|
start = vq->last_used_idx & (vq->num - 1);
|
|
|
used = vq->used->ring + start;
|
|
|
if (count == 1) {
|
|
|
- if (__put_user(heads[0].id, &used->id)) {
|
|
|
+ if (vhost_put_user(vq, heads[0].id, &used->id)) {
|
|
|
vq_err(vq, "Failed to write used id");
|
|
|
return -EFAULT;
|
|
|
}
|
|
|
- if (__put_user(heads[0].len, &used->len)) {
|
|
|
+ if (vhost_put_user(vq, heads[0].len, &used->len)) {
|
|
|
vq_err(vq, "Failed to write used len");
|
|
|
return -EFAULT;
|
|
|
}
|
|
|
- } else if (__copy_to_user(used, heads, count * sizeof *used)) {
|
|
|
+ } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
|
|
|
vq_err(vq, "Failed to write used");
|
|
|
return -EFAULT;
|
|
|
}
|
|
@@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
|
|
|
|
|
|
/* Make sure buffer is written before we update index. */
|
|
|
smp_wmb();
|
|
|
- if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx), &vq->used->idx)) {
|
|
|
+ if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
|
|
|
+ &vq->used->idx)) {
|
|
|
vq_err(vq, "Failed to increment used idx");
|
|
|
return -EFAULT;
|
|
|
}
|
|
@@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
|
|
|
|
|
|
if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
|
|
|
__virtio16 flags;
|
|
|
- if (__get_user(flags, &vq->avail->flags)) {
|
|
|
+ if (vhost_get_user(vq, flags, &vq->avail->flags)) {
|
|
|
vq_err(vq, "Failed to get flags");
|
|
|
return true;
|
|
|
}
|
|
@@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
|
|
|
if (unlikely(!v))
|
|
|
return true;
|
|
|
|
|
|
- if (__get_user(event, vhost_used_event(vq))) {
|
|
|
+ if (vhost_get_user(vq, event, vhost_used_event(vq))) {
|
|
|
vq_err(vq, "Failed to get used event idx");
|
|
|
return true;
|
|
|
}
|
|
@@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
|
|
|
__virtio16 avail_idx;
|
|
|
int r;
|
|
|
|
|
|
- r = __get_user(avail_idx, &vq->avail->idx);
|
|
|
+ r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
|
|
|
if (r)
|
|
|
return false;
|
|
|
|
|
@@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
|
|
|
/* They could have slipped one in as we were doing that: make
|
|
|
* sure it's written, then check again. */
|
|
|
smp_mb();
|
|
|
- r = __get_user(avail_idx, &vq->avail->idx);
|
|
|
+ r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
|
|
|
if (r) {
|
|
|
vq_err(vq, "Failed to check avail idx at %p: %d\n",
|
|
|
&vq->avail->idx, r);
|
|
@@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
|
|
|
}
|
|
|
EXPORT_SYMBOL_GPL(vhost_disable_notify);
|
|
|
|
|
|
+/* Create a new message. */
|
|
|
+struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
|
|
|
+{
|
|
|
+ struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
|
|
|
+ if (!node)
|
|
|
+ return NULL;
|
|
|
+ node->vq = vq;
|
|
|
+ node->msg.type = type;
|
|
|
+ return node;
|
|
|
+}
|
|
|
+EXPORT_SYMBOL_GPL(vhost_new_msg);
|
|
|
+
|
|
|
+void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
|
|
|
+ struct vhost_msg_node *node)
|
|
|
+{
|
|
|
+ spin_lock(&dev->iotlb_lock);
|
|
|
+ list_add_tail(&node->node, head);
|
|
|
+ spin_unlock(&dev->iotlb_lock);
|
|
|
+
|
|
|
+ wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
|
|
|
+}
|
|
|
+EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
|
|
|
+
|
|
|
+struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
|
|
|
+ struct list_head *head)
|
|
|
+{
|
|
|
+ struct vhost_msg_node *node = NULL;
|
|
|
+
|
|
|
+ spin_lock(&dev->iotlb_lock);
|
|
|
+ if (!list_empty(head)) {
|
|
|
+ node = list_first_entry(head, struct vhost_msg_node,
|
|
|
+ node);
|
|
|
+ list_del(&node->node);
|
|
|
+ }
|
|
|
+ spin_unlock(&dev->iotlb_lock);
|
|
|
+
|
|
|
+ return node;
|
|
|
+}
|
|
|
+EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
|
|
|
+
|
|
|
+
|
|
|
static int __init vhost_init(void)
|
|
|
{
|
|
|
return 0;
|