@@ -247,23 +247,35 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
247
247
void OperatorEntry::deregisterKernel_ (
248
248
const c10::Dispatcher& dispatcher,
249
249
std::optional<DispatchKey> dispatch_key,
250
- AnnotatedKernelContainerIterator kernel
251
- ) {
252
- // Redirect catchAll deregistrations to CompositeImplicitAutograd.
253
- DispatchKey dk = dispatch_key.has_value () ? *dispatch_key : DispatchKey::CompositeImplicitAutograd;
254
- auto found = kernels_.find (dk);
255
- TORCH_INTERNAL_ASSERT (found != kernels_.end (), " Tried to deregister a kernel for dispatch key " , toString (dispatch_key), " but there are no kernels registered for this dispatch key. The operator is " , toString (name_));
256
- auto & k = found->second ;
250
+ AnnotatedKernelContainerIterator kernel) {
251
+ // Redirect catchAll deregistrations to CompositeImplicitAutograd.
252
+ DispatchKey dk = dispatch_key.has_value ()
253
+ ? *dispatch_key
254
+ : DispatchKey::CompositeImplicitAutograd;
255
+ auto found = kernels_.find (dk);
256
+ TORCH_INTERNAL_ASSERT (
257
+ found != kernels_.end (),
258
+ " Tried to deregister a kernel for dispatch key " ,
259
+ toString (dispatch_key),
260
+ " but there are no kernels registered for this dispatch key. The operator is " ,
261
+ toString (name_));
262
+ auto & k = found->second ;
257
263
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
258
- // We are about to remove the array from the map, no need to do anything.
264
+ // We are about to remove the array from the map, no need to do anything.
259
265
#else
260
- k.erase (kernel);
266
+ k.erase (kernel);
261
267
#endif
262
- if (k.empty ()) {
263
- // the invariant says we don't want empty lists but instead remove the list from the map
264
- kernels_.erase (found);
265
- }
266
- updateDispatchTable_ (dispatcher, dk);
268
+ if (k.empty ()) {
269
+ // the invariant says we don't want empty lists but instead remove the list
270
+ // from the map
271
+ kernels_.erase (found);
272
+ }
273
+ // The KernelFunction object in dispatchTable
274
+ // (1) s a copy of the object in kernels_
275
+ // (2) is the one that gets returned by getComputedKernelForDispatchKey
276
+ // so we cannot call invalidate tokens on kernel here and need to invalidate
277
+ // the tokens on that one instead.
278
+ updateDispatchTable_ (dispatcher, dk, /* invalidate_tokens=*/ true );
267
279
}
268
280
269
281
void OperatorEntry::updateFallback (const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
@@ -305,6 +317,17 @@ bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
305
317
return dispatchTable_[dispatch_ix].isValid ();
306
318
}
307
319
320
+ SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey (
321
+ DispatchKey k) const {
322
+ TORCH_CHECK (
323
+ !isAliasDispatchKey (k),
324
+ " Alias keys do not have runtime kernel registrations." );
325
+ const auto dispatch_ix = getDispatchTableIndexForDispatchKey (k);
326
+ TORCH_CHECK (dispatchTable_[dispatch_ix].isValid ())
327
+
328
+ return SafeKernelFunction (&dispatchTable_[dispatch_ix]);
329
+ }
330
+
308
331
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey (DispatchKey dispatch_key) const {
309
332
auto kern_it = kernels_.find (dispatch_key);
310
333
if (kern_it != kernels_.end ()) {
@@ -450,46 +473,63 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
450
473
// dispatch keys (e.g. runtime keys and their associated autograd keys,
451
474
// or alias keys and their associated keysets).
452
475
// This function should be considered a private helper for updateDispatchTable_()
453
- void OperatorEntry::updateDispatchTableEntry_ (const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
476
+ void OperatorEntry::updateDispatchTableEntry_ (
477
+ const c10::Dispatcher& dispatcher,
478
+ DispatchKey dispatch_key,
479
+ bool invalidateTokens) {
454
480
const auto dispatch_ix = getDispatchTableIndexForDispatchKey (dispatch_key);
455
481
if (C10_UNLIKELY (dispatch_ix == -1 )) {
456
482
return ;
457
483
}
458
- dispatchTable_[dispatch_ix] = computeDispatchTableEntry (dispatcher, dispatch_key);
459
- dispatchKeyExtractor_.setOperatorHasFallthroughForKey (dispatch_key, dispatchTable_[dispatch_ix].isFallthrough ());
484
+ // Invalidate tokens for the old dispatch table entry before replacing it
485
+ if (invalidateTokens) {
486
+ dispatchTable_[dispatch_ix].invalidateTokens ();
487
+ }
488
+ dispatchTable_[dispatch_ix] =
489
+ computeDispatchTableEntry (dispatcher, dispatch_key);
490
+ dispatchKeyExtractor_.setOperatorHasFallthroughForKey (
491
+ dispatch_key, dispatchTable_[dispatch_ix].isFallthrough ());
460
492
}
461
493
462
494
// synchronizes the dispatch table entries for a given dispatch key *and its
463
495
// associated keys* with the current state of kernel registrations in the
464
496
// dispatcher.
465
497
// After a kernel has been registered to a dispatch key, a call to this
466
498
// function will synchronize the dispatcher state. See e.g. registerKernel()
467
- void OperatorEntry::updateDispatchTable_ (const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
468
- // Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
469
- // See Note [Undefined in dispatchTable_]
499
+ void OperatorEntry::updateDispatchTable_ (
500
+ const c10::Dispatcher& dispatcher,
501
+ DispatchKey dispatch_key,
502
+ bool invalidate_tokens) {
503
+ // Handle Undefined separately since it isn't a runtime key but we have an
504
+ // entry in dispatchTable_. See Note [Undefined in dispatchTable_]
470
505
if (dispatch_key == DispatchKey::Undefined) {
471
- updateDispatchTableEntry_ (dispatcher, dispatch_key);
506
+ updateDispatchTableEntry_ (dispatcher, dispatch_key, invalidate_tokens );
472
507
return ;
473
508
}
474
509
for (auto k : c10::getRuntimeDispatchKeySet (dispatch_key)) {
475
- updateDispatchTableEntry_ (dispatcher, k);
476
- }
477
- // Registration to CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined.
478
- // We cannot do this above since Undefined cannot be represented in DispatchKeySet.
479
- if (dispatch_key == DispatchKey::CompositeImplicitAutograd
480
- || dispatch_key == DispatchKey::CompositeExplicitAutograd
481
- || dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
482
- updateDispatchTableEntry_ (dispatcher, DispatchKey::Undefined);
510
+ updateDispatchTableEntry_ (dispatcher, k, invalidate_tokens);
511
+ }
512
+ // Registration to CompositeExplicitAutogradNonFunctional,
513
+ // CompositeExplicitAutograd and CompositeImplicitAutograd should be populated
514
+ // to Undefined. We cannot do this above since Undefined cannot be represented
515
+ // in DispatchKeySet.
516
+ if (dispatch_key == DispatchKey::CompositeImplicitAutograd ||
517
+ dispatch_key == DispatchKey::CompositeExplicitAutograd ||
518
+ dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
519
+ updateDispatchTableEntry_ (
520
+ dispatcher, DispatchKey::Undefined, invalidate_tokens);
483
521
}
484
522
// Note [Refresh Runtime Autograd entries in dispatchTable_]
485
- // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
486
- // In theory, we should only have to check if the given runtime key has "dense" functionality,
487
- // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
488
- // However, there are some backends that should be included in this set that don't have the dense key set.
489
- // E.g. DispatchKey::Meta, DispatchKey::MAIA.
523
+ // Registering to backend key might affect computed entry at its Autograd
524
+ // backend key due to (2.1) & (2.3). In theory, we should only have to check
525
+ // if the given runtime key has "dense" functionality, e.g. DispatchKey::CPU
526
+ // (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
527
+ // However, there are some backends that should be included in this set that
528
+ // don't have the dense key set. E.g. DispatchKey::Meta, DispatchKey::MAIA.
490
529
if (c10::isBackendDispatchKey (dispatch_key)) {
491
- DispatchKey autograd_key = getAutogradKeyFromBackend (toBackendComponent (dispatch_key));
492
- updateDispatchTableEntry_ (dispatcher, autograd_key);
530
+ DispatchKey autograd_key =
531
+ getAutogradKeyFromBackend (toBackendComponent (dispatch_key));
532
+ updateDispatchTableEntry_ (dispatcher, autograd_key, invalidate_tokens);
493
533
}
494
534
}
495
535
0 commit comments