@@ -18,7 +18,6 @@
 #include <linux/mmu_context.h>
 #include <linux/miscdevice.h>
 #include <linux/mutex.h>
-#include <linux/rcupdate.h>
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/highmem.h>
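With the memory table no longer published via RCU, the rcupdate.h include can go. The matching header change in drivers/vhost/vhost.h is not part of this excerpt; a sketch of what it presumably looks like, with field placement assumed:

/* Sketch of the companion vhost.h change (assumed, not shown in this diff):
 * each VQ caches its own plain pointer, and the device-level pointer loses
 * its __rcu annotation.
 */
struct vhost_virtqueue {
	struct vhost_dev *dev;
	/* ... existing fields ... */
	struct vhost_memory *memory;	/* new: read under vq->mutex */
};

struct vhost_dev {
	/* was: struct vhost_memory __rcu *memory; */
	struct vhost_memory *memory;	/* accessed under dev->mutex */
	/* ... existing fields ... */
};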
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
+	vq->memory = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 /* Caller should have device mutex */
 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
 {
+	int i;
+
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
 	memory->nregions = 0;
-	RCU_INIT_POINTER(dev->memory, memory);
+	dev->memory = memory;
+	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
+	 * VQs aren't running.
+	 */
+	for (i = 0; i < dev->nvqs; ++i)
+		dev->vqs[i]->memory = memory;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 	fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kfree(rcu_dereference_protected(dev->memory,
-					locked ==
-					lockdep_is_held(&dev->mutex)));
-	RCU_INIT_POINTER(dev->memory, NULL);
+	kfree(dev->memory);
+	dev->memory = NULL;
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	struct vhost_memory *mp;
-
-	mp = rcu_dereference_protected(dev->memory,
-				       lockdep_is_held(&dev->mutex));
-	return memory_access_ok(dev, mp, 1);
+	return memory_access_ok(dev, dev->memory, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
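The rcu_dereference_protected() calls with their lockdep annotations disappear because the locking rule is now simple enough to state directly. A summary of the discipline this diff establishes (inferred from the hunks, not text taken from the patch):

/*
 * Locking rules after this change (inferred from the diff):
 *
 *   dev->memory  - read and written under dev->mutex
 *   vq->memory   - written under both dev->mutex and vq->mutex
 *                  (or while the VQ provably isn't running),
 *                  read under vq->mutex
 */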
@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 static int vq_log_access_ok(struct vhost_virtqueue *vq,
 			    void __user *log_base)
 {
-	struct vhost_memory *mp;
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	mp = rcu_dereference_protected(vq->dev->memory,
-				       lockdep_is_held(&vq->mutex));
-	return vq_memory_access_ok(log_base, mp,
+	return vq_memory_access_ok(log_base, vq->memory,
 				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		kfree(newmem);
 		return -EFAULT;
 	}
-	oldmem = rcu_dereference_protected(d->memory,
-					   lockdep_is_held(&d->mutex));
-	rcu_assign_pointer(d->memory, newmem);
+	oldmem = d->memory;
+	d->memory = newmem;
 
-	/* All memory accesses are done under some VQ mutex.
-	 * So below is a faster equivalent of synchronize_rcu()
-	 */
+	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
+		d->vqs[i]->memory = newmem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
 	kfree(oldmem);
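Cycling every vq->mutex after publishing newmem is what makes the final kfree(oldmem) safe: each reader holds some VQ mutex while it uses the table, so once the loop has taken and released all of them, no reader can still hold a reference to oldmem. The loop thus plays the role synchronize_rcu() used to play, which is exactly what the deleted comment said. A minimal standalone sketch of the pattern, with an illustrative function name:

/* Writer side: publish a new table, then drain readers lock by lock. */
static void publish_memory(struct vhost_dev *d, struct vhost_memory *newmem)
{
	struct vhost_memory *oldmem = d->memory;	/* d->mutex held */
	int i;

	d->memory = newmem;
	for (i = 0; i < d->nvqs; ++i) {
		mutex_lock(&d->vqs[i]->mutex);	/* waits out any reader of oldmem */
		d->vqs[i]->memory = newmem;	/* hand the VQ its own pointer */
		mutex_unlock(&d->vqs[i]->mutex);
	}
	kfree(oldmem);	/* no VQ can still observe oldmem */
}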
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_init_used);
 
-static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
 	const struct vhost_memory_region *reg;
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 	u64 s = 0;
 	int ret = 0;
 
-	rcu_read_lock();
-
-	mem = rcu_dereference(dev->memory);
+	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 		++ret;
 	}
 
-	rcu_read_unlock();
 	return ret;
 }
 
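On the reader side the rcu_read_lock()/rcu_read_unlock() bracket disappears entirely: translate_desc() is only reached from vhost_get_vq_desc(), which runs with vq->mutex held, so a plain load of vq->memory stays stable for the whole translation. The shape of the reader after this patch, reduced to the essentials (a sketch, not the full function):

/* Reader side: vq->mutex is held by the caller, so vq->memory is stable. */
static int translate_desc_sketch(struct vhost_virtqueue *vq, u64 addr, u32 len)
{
	struct vhost_memory *mem = vq->memory;	/* plain load, no RCU */

	/* ... walk mem->regions[] to fill the iovec, as the real
	 * translate_desc() does ...
	 */
	return 0;
}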
@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
 	return next;
 }
 
-static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+static int get_indirect(struct vhost_virtqueue *vq,
 			struct iovec iov[], unsigned int iov_size,
 			unsigned int *out_num, unsigned int *in_num,
 			struct vhost_log *log, unsigned int *log_num,
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 		return -EINVAL;
 	}
 
-	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
+	ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
 			     UIO_MAXIOV);
 	if (unlikely(ret < 0)) {
 		vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EINVAL;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
  * This function returns the descriptor number found, or vq->num (which is
  * never a valid descriptor number) if none was found. A negative code is
  * returned on error. */
-int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 		      struct iovec iov[], unsigned int iov_size,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num)
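Dropping the struct vhost_dev * parameter from vhost_get_vq_desc() (and from get_indirect()/translate_desc() above) means every caller has to be updated in the same patch; the in-tree users are the vhost-net, vhost-scsi and vhost-test drivers. An illustrative call site after the change, with local variable names assumed rather than taken from this diff:

/* e.g. in a handle_tx()-style loop (names illustrative): */
head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
			 &out, &in, NULL, NULL);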
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 		return -EFAULT;
 	}
 	if (desc.flags & VRING_DESC_F_INDIRECT) {
-		ret = get_indirect(dev, vq, iov, iov_size,
+		ret = get_indirect(vq, iov, iov_size,
 				   out_num, in_num,
 				   log, log_num, &desc);
 		if (unlikely(ret < 0)) {
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			continue;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d descriptor idx %d\n",