@@ -115,34 +115,35 @@ static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
 }
 
 /* Account for a new mmu notifier in an ib_ucontext. */
-static void ib_ucontext_notifier_start_account(struct ib_ucontext *context)
+static void
+ib_ucontext_notifier_start_account(struct ib_ucontext_per_mm *per_mm)
 {
-	atomic_inc(&context->notifier_count);
+	atomic_inc(&per_mm->notifier_count);
 }
 
 /* Account for a terminating mmu notifier in an ib_ucontext.
  *
  * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since
  * the function takes the semaphore itself. */
-static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
+static void ib_ucontext_notifier_end_account(struct ib_ucontext_per_mm *per_mm)
 {
-	int zero_notifiers = atomic_dec_and_test(&context->notifier_count);
+	int zero_notifiers = atomic_dec_and_test(&per_mm->notifier_count);
 
 	if (zero_notifiers &&
-	    !list_empty(&context->no_private_counters)) {
+	    !list_empty(&per_mm->no_private_counters)) {
 		/* No currently running mmu notifiers. Now is the chance to
 		 * add private accounting to all previously added umems. */
 		struct ib_umem_odp *odp_data, *next;
 
 		/* Prevent concurrent mmu notifiers from working on the
 		 * no_private_counters list. */
-		down_write(&context->umem_rwsem);
+		down_write(&per_mm->umem_rwsem);
 
 		/* Read the notifier_count again, with the umem_rwsem
 		 * semaphore taken for write. */
-		if (!atomic_read(&context->notifier_count)) {
+		if (!atomic_read(&per_mm->notifier_count)) {
 			list_for_each_entry_safe(odp_data, next,
-						 &context->no_private_counters,
+						 &per_mm->no_private_counters,
 						 no_private_counters) {
 				mutex_lock(&odp_data->umem_mutex);
 				odp_data->mn_counters_active = true;
@@ -152,7 +153,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
 			}
 		}
 
-		up_write(&context->umem_rwsem);
+		up_write(&per_mm->umem_rwsem);
 	}
 }
 
@@ -179,19 +180,20 @@ static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
 static void ib_umem_notifier_release(struct mmu_notifier *mn,
 				     struct mm_struct *mm)
 {
-	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+	struct ib_ucontext_per_mm *per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);
 
-	if (!context->invalidate_range)
+	if (!per_mm->context->invalidate_range)
 		return;
 
-	ib_ucontext_notifier_start_account(context);
-	down_read(&context->umem_rwsem);
-	rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
+	ib_ucontext_notifier_start_account(per_mm);
+	down_read(&per_mm->umem_rwsem);
+	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0,
 				      ULLONG_MAX,
 				      ib_umem_notifier_release_trampoline,
 				      true,
 				      NULL);
-	up_read(&context->umem_rwsem);
+	up_read(&per_mm->umem_rwsem);
 }
 
 static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start,
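Note how the release, invalidate_range_start and invalidate_range_end callbacks in this patch all recover their per_mm from the embedded mmu_notifier via container_of(). The following is a hypothetical, self-contained userspace sketch of that pattern only; the demo struct, its fields, demo_callback() and the simplified container_of macro are illustrative stand-ins, not the kernel definitions.

#include <stddef.h>
#include <stdio.h>

/* Simplified stand-in for the kernel's container_of(): walk back from a
 * pointer to a member to the structure that embeds it. */
#define container_of(ptr, type, member) \
	((type *)((char *)(ptr) - offsetof(type, member)))

struct mmu_notifier { int dummy; };		/* placeholder for the demo */

struct ib_ucontext_per_mm_demo {
	int odp_mrs_count;
	struct mmu_notifier mn;			/* embedded, handed to callbacks */
};

/* A notifier callback only receives the embedded mmu_notifier... */
static void demo_callback(struct mmu_notifier *mn)
{
	/* ...and recovers the surrounding per-mm state from it. */
	struct ib_ucontext_per_mm_demo *per_mm =
		container_of(mn, struct ib_ucontext_per_mm_demo, mn);

	printf("odp_mrs_count = %d\n", per_mm->odp_mrs_count);
}

int main(void)
{
	struct ib_ucontext_per_mm_demo demo = { .odp_mrs_count = 1 };

	demo_callback(&demo.mn);	/* prints: odp_mrs_count = 1 */
	return 0;
}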
@@ -217,23 +219,24 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
 						    unsigned long end,
 						    bool blockable)
 {
-	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+	struct ib_ucontext_per_mm *per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);
 	int ret;
 
-	if (!context->invalidate_range)
+	if (!per_mm->context->invalidate_range)
 		return 0;
 
 	if (blockable)
-		down_read(&context->umem_rwsem);
-	else if (!down_read_trylock(&context->umem_rwsem))
+		down_read(&per_mm->umem_rwsem);
+	else if (!down_read_trylock(&per_mm->umem_rwsem))
 		return -EAGAIN;
 
-	ib_ucontext_notifier_start_account(context);
-	ret = rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
+	ib_ucontext_notifier_start_account(per_mm);
+	ret = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
 					    end,
 					    invalidate_range_start_trampoline,
 					    blockable, NULL);
-	up_read(&context->umem_rwsem);
+	up_read(&per_mm->umem_rwsem);
 
 	return ret;
 }
@@ -250,22 +253,23 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 						  unsigned long start,
 						  unsigned long end)
 {
-	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+	struct ib_ucontext_per_mm *per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);
 
-	if (!context->invalidate_range)
+	if (!per_mm->context->invalidate_range)
 		return;
 
 	/*
 	 * TODO: we currently bail out if there is any sleepable work to be done
 	 * in ib_umem_notifier_invalidate_range_start so we shouldn't really block
 	 * here. But this is ugly and fragile.
 	 */
-	down_read(&context->umem_rwsem);
-	rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
+	down_read(&per_mm->umem_rwsem);
+	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
 				      end,
 				      invalidate_range_end_trampoline, true, NULL);
-	up_read(&context->umem_rwsem);
-	ib_ucontext_notifier_end_account(context);
+	up_read(&per_mm->umem_rwsem);
+	ib_ucontext_notifier_end_account(per_mm);
 }
 
 static const struct mmu_notifier_ops ib_umem_notifiers = {
@@ -277,6 +281,7 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 				      unsigned long addr, size_t size)
 {
+	struct ib_ucontext_per_mm *per_mm;
 	struct ib_umem_odp *odp_data;
 	struct ib_umem *umem;
 	int pages = size >> PAGE_SHIFT;
@@ -292,6 +297,7 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 	umem->page_shift = PAGE_SHIFT;
 	umem->writable = 1;
 	umem->is_odp = 1;
+	odp_data->per_mm = per_mm = &context->per_mm;
 
 	mutex_init(&odp_data->umem_mutex);
 	init_completion(&odp_data->notifier_completion);
@@ -310,15 +316,15 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 		goto out_page_list;
 	}
 
-	down_write(&context->umem_rwsem);
-	context->odp_mrs_count++;
-	rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree);
-	if (likely(!atomic_read(&context->notifier_count)))
+	down_write(&per_mm->umem_rwsem);
+	per_mm->odp_mrs_count++;
+	rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree);
+	if (likely(!atomic_read(&per_mm->notifier_count)))
 		odp_data->mn_counters_active = true;
 	else
 		list_add(&odp_data->no_private_counters,
-			 &context->no_private_counters);
-	up_write(&context->umem_rwsem);
+			 &per_mm->no_private_counters);
+	up_write(&per_mm->umem_rwsem);
 
 	return odp_data;
 
@@ -334,6 +340,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
 	struct ib_ucontext *context = umem_odp->umem.context;
 	struct ib_umem *umem = &umem_odp->umem;
+	struct ib_ucontext_per_mm *per_mm;
 	int ret_val;
 	struct pid *our_pid;
 	struct mm_struct *mm = get_task_mm(current);
@@ -396,36 +403,38 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 	 * notification before the "current" task (and MM) is
 	 * destroyed. We use the umem_rwsem semaphore to synchronize.
 	 */
-	down_write(&context->umem_rwsem);
-	context->odp_mrs_count++;
+	umem_odp->per_mm = per_mm = &context->per_mm;
+
+	down_write(&per_mm->umem_rwsem);
+	per_mm->odp_mrs_count++;
 	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
 		rbt_ib_umem_insert(&umem_odp->interval_tree,
-				   &context->umem_tree);
-	if (likely(!atomic_read(&context->notifier_count)) ||
-	    context->odp_mrs_count == 1)
+				   &per_mm->umem_tree);
+	if (likely(!atomic_read(&per_mm->notifier_count)) ||
+	    per_mm->odp_mrs_count == 1)
 		umem_odp->mn_counters_active = true;
 	else
 		list_add(&umem_odp->no_private_counters,
-			 &context->no_private_counters);
-	downgrade_write(&context->umem_rwsem);
+			 &per_mm->no_private_counters);
+	downgrade_write(&per_mm->umem_rwsem);
 
-	if (context->odp_mrs_count == 1) {
+	if (per_mm->odp_mrs_count == 1) {
 		/*
 		 * Note that at this point, no MMU notifier is running
-		 * for this context!
+		 * for this per_mm!
 		 */
-		atomic_set(&context->notifier_count, 0);
-		INIT_HLIST_NODE(&context->mn.hlist);
-		context->mn.ops = &ib_umem_notifiers;
-		ret_val = mmu_notifier_register(&context->mn, mm);
+		atomic_set(&per_mm->notifier_count, 0);
+		INIT_HLIST_NODE(&per_mm->mn.hlist);
+		per_mm->mn.ops = &ib_umem_notifiers;
+		ret_val = mmu_notifier_register(&per_mm->mn, mm);
 		if (ret_val) {
 			pr_err("Failed to register mmu_notifier %d\n", ret_val);
 			ret_val = -EBUSY;
 			goto out_mutex;
 		}
 	}
 
-	up_read(&context->umem_rwsem);
+	up_read(&per_mm->umem_rwsem);
 
 	/*
 	 * Note that doing an mmput can cause a notifier for the relevant mm.
@@ -437,7 +446,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 	return 0;
 
 out_mutex:
-	up_read(&context->umem_rwsem);
+	up_read(&per_mm->umem_rwsem);
 	vfree(umem_odp->dma_list);
 out_page_list:
 	vfree(umem_odp->page_list);
@@ -449,7 +458,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
 	struct ib_umem *umem = &umem_odp->umem;
-	struct ib_ucontext *context = umem->context;
+	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 
 	/*
 	 * Ensure that no more pages are mapped in the umem.
@@ -460,11 +469,11 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
 				    ib_umem_end(umem));
 
-	down_write(&context->umem_rwsem);
+	down_write(&per_mm->umem_rwsem);
 	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
 		rbt_ib_umem_remove(&umem_odp->interval_tree,
-				   &context->umem_tree);
-	context->odp_mrs_count--;
+				   &per_mm->umem_tree);
+	per_mm->odp_mrs_count--;
 	if (!umem_odp->mn_counters_active) {
 		list_del(&umem_odp->no_private_counters);
 		complete_all(&umem_odp->notifier_completion);
@@ -477,13 +486,13 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 	 * that since we are doing it atomically, no other user could register
 	 * and unregister while we do the check.
 	 */
-	downgrade_write(&context->umem_rwsem);
-	if (!context->odp_mrs_count) {
+	downgrade_write(&per_mm->umem_rwsem);
+	if (!per_mm->odp_mrs_count) {
 		struct task_struct *owning_process = NULL;
 		struct mm_struct *owning_mm = NULL;
 
-		owning_process = get_pid_task(context->tgid,
-					      PIDTYPE_PID);
+		owning_process =
+			get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
 		if (owning_process == NULL)
 			/*
 			 * The process is already dead, notifier were removed
@@ -498,15 +507,15 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 			 * removed already.
 			 */
 			goto out_put_task;
-		mmu_notifier_unregister(&context->mn, owning_mm);
+		mmu_notifier_unregister(&per_mm->mn, owning_mm);
 
 		mmput(owning_mm);
 
 out_put_task:
 		put_task_struct(owning_process);
 	}
 out:
-	up_read(&context->umem_rwsem);
+	up_read(&per_mm->umem_rwsem);
 
 	vfree(umem_odp->dma_list);
 	vfree(umem_odp->page_list);
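For orientation, here is a minimal sketch of the struct ib_ucontext_per_mm fields this patch relies on, reconstructed purely from the accesses in the hunks above; the field types, ordering and comments are assumptions, and the authoritative definition in the RDMA ODP headers may carry additional members. At this stage of the series the structure is still embedded in struct ib_ucontext, which is why ib_alloc_odp_umem() and ib_umem_odp_get() take &context->per_mm.

#include <linux/mmu_notifier.h>
#include <linux/rbtree.h>
#include <linux/rwsem.h>
#include <linux/types.h>

/* Sketch only: inferred from the usages in this diff, not the authoritative
 * definition. */
struct ib_ucontext_per_mm {
	struct ib_ucontext *context;		/* per_mm->context->invalidate_range */
	struct rb_root_cached umem_tree;	/* interval tree of ODP umems (type assumed) */
	struct rw_semaphore umem_rwsem;		/* protects umem_tree and the counters below */
	atomic_t notifier_count;		/* number of running mmu notifiers */
	struct list_head no_private_counters;	/* umems still awaiting private counters */
	int odp_mrs_count;			/* ODP MRs registered against this mm */
	struct mmu_notifier mn;			/* recovered via container_of() in callbacks */
};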