Skip to content

Commit 47283be

Browse files
committed
vhost: move memory pointer to VQs
commit 2ae7669 ("vhost: replace rcu with mutex") replaced rcu sync for memory accesses with VQ mutex lock/unlock. This is correct since all accesses are under VQ mutex, but incomplete: we still do useless rcu lock/unlock operations, and someone might copy this code into some other context where this won't be right. This use of RCU is also non-standard and hard to understand. Let's copy the pointer to each VQ structure; this way the access rules become straightforward, and there's no need for RCU anymore. Reported-by: Eric Dumazet <eric.dumazet@gmail.com> Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
1 parent ea16c51 commit 47283be

File tree

5 files changed

+33
-42
lines changed

5 files changed

+33
-42
lines changed

drivers/vhost/net.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ static void handle_tx(struct vhost_net *net)
374374
% UIO_MAXIOV == nvq->done_idx))
375375
break;
376376

377-
head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
377+
head = vhost_get_vq_desc(vq, vq->iov,
378378
ARRAY_SIZE(vq->iov),
379379
&out, &in,
380380
NULL, NULL);
@@ -506,7 +506,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
506506
r = -ENOBUFS;
507507
goto err;
508508
}
509-
r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
509+
r = vhost_get_vq_desc(vq, vq->iov + seg,
510510
ARRAY_SIZE(vq->iov) - seg, &out,
511511
&in, log, log_num);
512512
if (unlikely(r < 0))

drivers/vhost/scsi.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
606606

607607
again:
608608
vhost_disable_notify(&vs->dev, vq);
609-
head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
609+
head = vhost_get_vq_desc(vq, vq->iov,
610610
ARRAY_SIZE(vq->iov), &out, &in,
611611
NULL, NULL);
612612
if (head < 0) {
@@ -945,7 +945,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
945945
vhost_disable_notify(&vs->dev, vq);
946946

947947
for (;;) {
948-
head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
948+
head = vhost_get_vq_desc(vq, vq->iov,
949949
ARRAY_SIZE(vq->iov), &out, &in,
950950
NULL, NULL);
951951
pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",

drivers/vhost/test.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)
5353
vhost_disable_notify(&n->dev, vq);
5454

5555
for (;;) {
56-
head = vhost_get_vq_desc(&n->dev, vq, vq->iov,
56+
head = vhost_get_vq_desc(vq, vq->iov,
5757
ARRAY_SIZE(vq->iov),
5858
&out, &in,
5959
NULL, NULL);

drivers/vhost/vhost.c

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include <linux/mmu_context.h>
1919
#include <linux/miscdevice.h>
2020
#include <linux/mutex.h>
21-
#include <linux/rcupdate.h>
2221
#include <linux/poll.h>
2322
#include <linux/file.h>
2423
#include <linux/highmem.h>
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
199198
vq->call_ctx = NULL;
200199
vq->call = NULL;
201200
vq->log_ctx = NULL;
201+
vq->memory = NULL;
202202
}
203203

204204
static int vhost_worker(void *data)
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
416416
/* Caller should have device mutex */
417417
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
418418
{
419+
int i;
420+
419421
vhost_dev_cleanup(dev, true);
420422

421423
/* Restore memory to default empty mapping. */
422424
memory->nregions = 0;
423-
RCU_INIT_POINTER(dev->memory, memory);
425+
dev->memory = memory;
426+
/* We don't need VQ locks below since vhost_dev_cleanup makes sure
427+
* VQs aren't running.
428+
*/
429+
for (i = 0; i < dev->nvqs; ++i)
430+
dev->vqs[i]->memory = memory;
424431
}
425432
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
426433

@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
463470
fput(dev->log_file);
464471
dev->log_file = NULL;
465472
/* No one will access memory at this point */
466-
kfree(rcu_dereference_protected(dev->memory,
467-
locked ==
468-
lockdep_is_held(&dev->mutex)));
469-
RCU_INIT_POINTER(dev->memory, NULL);
473+
kfree(dev->memory);
474+
dev->memory = NULL;
470475
WARN_ON(!list_empty(&dev->work_list));
471476
if (dev->worker) {
472477
kthread_stop(dev->worker);
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
558563
/* Caller should have device mutex but not vq mutex */
559564
int vhost_log_access_ok(struct vhost_dev *dev)
560565
{
561-
struct vhost_memory *mp;
562-
563-
mp = rcu_dereference_protected(dev->memory,
564-
lockdep_is_held(&dev->mutex));
565-
return memory_access_ok(dev, mp, 1);
566+
return memory_access_ok(dev, dev->memory, 1);
566567
}
567568
EXPORT_SYMBOL_GPL(vhost_log_access_ok);
568569

@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
571572
static int vq_log_access_ok(struct vhost_virtqueue *vq,
572573
void __user *log_base)
573574
{
574-
struct vhost_memory *mp;
575575
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
576576

577-
mp = rcu_dereference_protected(vq->dev->memory,
578-
lockdep_is_held(&vq->mutex));
579-
return vq_memory_access_ok(log_base, mp,
577+
return vq_memory_access_ok(log_base, vq->memory,
580578
vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
581579
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
582580
sizeof *vq->used +
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
619617
kfree(newmem);
620618
return -EFAULT;
621619
}
622-
oldmem = rcu_dereference_protected(d->memory,
623-
lockdep_is_held(&d->mutex));
624-
rcu_assign_pointer(d->memory, newmem);
620+
oldmem = d->memory;
621+
d->memory = newmem;
625622

626-
/* All memory accesses are done under some VQ mutex.
627-
* So below is a faster equivalent of synchronize_rcu()
628-
*/
623+
/* All memory accesses are done under some VQ mutex. */
629624
for (i = 0; i < d->nvqs; ++i) {
630625
mutex_lock(&d->vqs[i]->mutex);
626+
d->vqs[i]->memory = newmem;
631627
mutex_unlock(&d->vqs[i]->mutex);
632628
}
633629
kfree(oldmem);
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
10541050
}
10551051
EXPORT_SYMBOL_GPL(vhost_init_used);
10561052

1057-
static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
1053+
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
10581054
struct iovec iov[], int iov_size)
10591055
{
10601056
const struct vhost_memory_region *reg;
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
10631059
u64 s = 0;
10641060
int ret = 0;
10651061

1066-
rcu_read_lock();
1067-
1068-
mem = rcu_dereference(dev->memory);
1062+
mem = vq->memory;
10691063
while ((u64)len > s) {
10701064
u64 size;
10711065
if (unlikely(ret >= iov_size)) {
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
10871081
++ret;
10881082
}
10891083

1090-
rcu_read_unlock();
10911084
return ret;
10921085
}
10931086

@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
11121105
return next;
11131106
}
11141107

1115-
static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1108+
static int get_indirect(struct vhost_virtqueue *vq,
11161109
struct iovec iov[], unsigned int iov_size,
11171110
unsigned int *out_num, unsigned int *in_num,
11181111
struct vhost_log *log, unsigned int *log_num,
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
11311124
return -EINVAL;
11321125
}
11331126

1134-
ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
1127+
ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
11351128
UIO_MAXIOV);
11361129
if (unlikely(ret < 0)) {
11371130
vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
11711164
return -EINVAL;
11721165
}
11731166

1174-
ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
1167+
ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
11751168
iov_size - iov_count);
11761169
if (unlikely(ret < 0)) {
11771170
vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
12081201
* This function returns the descriptor number found, or vq->num (which is
12091202
* never a valid descriptor number) if none was found. A negative code is
12101203
* returned on error. */
1211-
int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1204+
int vhost_get_vq_desc(struct vhost_virtqueue *vq,
12121205
struct iovec iov[], unsigned int iov_size,
12131206
unsigned int *out_num, unsigned int *in_num,
12141207
struct vhost_log *log, unsigned int *log_num)
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
12821275
return -EFAULT;
12831276
}
12841277
if (desc.flags & VRING_DESC_F_INDIRECT) {
1285-
ret = get_indirect(dev, vq, iov, iov_size,
1278+
ret = get_indirect(vq, iov, iov_size,
12861279
out_num, in_num,
12871280
log, log_num, &desc);
12881281
if (unlikely(ret < 0)) {
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
12931286
continue;
12941287
}
12951288

1296-
ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
1289+
ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
12971290
iov_size - iov_count);
12981291
if (unlikely(ret < 0)) {
12991292
vq_err(vq, "Translation failure %d descriptor idx %d\n",

drivers/vhost/vhost.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ struct vhost_virtqueue {
104104
struct iovec *indirect;
105105
struct vring_used_elem *heads;
106106
/* Protected by virtqueue mutex. */
107+
struct vhost_memory *memory;
107108
void *private_data;
108109
unsigned acked_features;
109110
/* Log write descriptors */
@@ -112,10 +113,7 @@ struct vhost_virtqueue {
112113
};
113114

114115
struct vhost_dev {
115-
/* Readers use RCU to access memory table pointer
116-
* log base pointer and features.
117-
* Writers use mutex below.*/
118-
struct vhost_memory __rcu *memory;
116+
struct vhost_memory *memory;
119117
struct mm_struct *mm;
120118
struct mutex mutex;
121119
struct vhost_virtqueue **vqs;
@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
140138
int vhost_vq_access_ok(struct vhost_virtqueue *vq);
141139
int vhost_log_access_ok(struct vhost_dev *);
142140

143-
int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
141+
int vhost_get_vq_desc(struct vhost_virtqueue *,
144142
struct iovec iov[], unsigned int iov_count,
145143
unsigned int *out_num, unsigned int *in_num,
146144
struct vhost_log *log, unsigned int *log_num);

0 commit comments

Comments
 (0)