-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] Avoid resolving callable outside the analysis scope in DeadCodeAnalysis (NFC) #155088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…eAnalysis (NFC) We are using the symbol table machinery to lookup for a callable, but when the analysis scope if a function, such lookup will resolve outside of the scope. This can lead to race-condition issues since other passes may operate in parallel on the sibling functions. The callable would be discarded right after the lookup (we check the analysis scope), so avoiding the lookup is NFC. Fix llvm#154948
@llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesWe are using the symbol table machinery to lookup for a callable, but when the analysis scope if a function, such lookup will resolve outside of the scope. This can lead to race-condition issues since other passes may operate in parallel on the sibling functions. Fix #154948 Full diff: https://github.com/llvm/llvm-project/pull/155088.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 2250db823b551..c7c405e1423cb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
/// considered an external callable.
Operation *analysisScope;
+ /// Whether the analysis scope has a symbol table. This is used to avoid
+ /// resolving callables outside the analysis scope.
+ /// It is updated when recursing into a region in case where the top-level
+ /// operation does not have a symbol table, but one is encountered in a nested
+ /// region.
+ bool hasSymbolTable = false;
+
/// A symbol table used for O(1) symbol lookups during simplification.
SymbolTableCollection symbolTable;
};
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 9424eff3e6b6f..131c49c44171b 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -22,6 +22,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
+ hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
LDBG() << "[init] Processing symbol table op: "
<< OpWithFlags(symTable, OpPrintingFlags().skipRegions());
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
return failure();
}
// Recurse on nested operations.
- for (Region ®ion : op->getRegions()) {
- LDBG() << "[init] Recursing into region of op: "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- for (Operation &nestedOp : region.getOps()) {
- LDBG() << "[init] Recursing into nested op: "
- << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
- if (failed(initializeRecursively(&nestedOp)))
- return failure();
+ if (op->getNumRegions()) {
+ // If we haven't seen a symbol table yet, check if the current operation
+ // has one. If so, update the flag to allow for resolving callables in
+ // nested regions.
+ bool savedHasSymbolTable = hasSymbolTable;
+ auto restoreHasSymbolTable =
+ llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
+ if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
+ hasSymbolTable = true;
+
+ for (Region ®ion : op->getRegions()) {
+ LDBG() << "[init] Recursing into region of op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ for (Operation &nestedOp : region.getOps()) {
+ LDBG() << "[init] Recursing into nested op: "
+ << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
+ if (failed(initializeRecursively(&nestedOp)))
+ return failure();
+ }
}
}
LDBG() << "[init] Finished initializeRecursively for op: "
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
LDBG() << "visitCallOperation: "
<< OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
- Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+
+ Operation *callableOp = nullptr;
+ if (hasSymbolTable)
+ callableOp = call.resolveCallableInTable(&symbolTable);
+ else
+ LDBG()
+ << "No symbol table present in analysis scope, can't resolve callable";
// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [this](Operation *op) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @joker-eph ! I tested this against the original error and this solves it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I approved a little bit too quickly. I re-ran again and as I mentioned on the original issue, the error moves to a different place.
I am attaching the traces from TSan.
Read of size 8 at 0x722800004078 by thread T8: #0 mlir::Attribute::getDialect() const /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Attributes.h:59:12 (libIREECompiler.so+0x7dc2799) (BuildId: 366a285014486d45)
#1 mlir::Attribute::getContext() const /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/Attributes.cpp:37:53 (libIREECompiler.so+0x7dc2799)
#2 mlir::Operation::getContext() /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:216:48 (libIREECompiler.so+0xfe88741) (BuildId: 366a285014486d45)
#3 mlir::RegisteredOperationName::Model<mlir::iree_compiler::IREE::Util::FuncOp>::getInherentAttr(mlir::Operation*, llvm::StringRef) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/OperationSupport.h:570:56 (libIREECompiler.so+0xfe88741)
#4 mlir::OperationName::getInherentAttr(mlir::Operation*, llvm::StringRef) const /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/OperationSupport.h:393:23 (libIREECompiler.so+0x7eaf7c4) (BuildId: 366a285014486d45)
#5 mlir::Operation::getInherentAttr(llvm::StringRef) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/Operation.cpp:341:20 (libIREECompiler.so+0x7eaf7c4)
#6 mlir::Operation::getAttr(mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:536:51 (libIREECompiler.so+0x7ed8b86) (BuildId: 366a285014486d45)
#7 mlir::StringAttr mlir::Operation::getAttrOfType<mlir::StringAttr>(mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:551:46 (libIREECompiler.so+0x7ed8b86)
#8 getNameIfSymbol(mlir::Operation*, mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:31:14 (libIREECompiler.so+0x7ed8b86)
#9 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:395:9 (libIREECompiler.so+0x7ed8b86)
#10 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&)::$_0::operator()(mlir::Operation*, mlir::StringAttr) const /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:446:12 (libIREECompiler.so+0x7ee03b0) (BuildId: 366a285014486d45)
#11 mlir::Operation* llvm::function_ref<mlir::Operation* (mlir::Operation*, mlir::StringAttr)>::callback_fn<mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&)::$_0>(long, mlir::Operation*, mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:12 (libIREECompiler.so+0x7ee03b0)
#12 llvm::function_ref<mlir::Operation* (mlir::Operation*, mlir::StringAttr)>::operator()(mlir::Operation*, mlir::StringAttr) const /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69:12 (libIREECompiler.so+0x7ed8e9a) (BuildId: 366a285014486d45)
#13 lookupSymbolInImpl(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&, llvm::function_ref<mlir::Operation* (mlir::Operation*, mlir::StringAttr)>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:416:19 (libIREECompiler.so+0x7ed8e9a)
#14 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:448:10 (libIREECompiler.so+0x7ed9438) (BuildId: 366a285014486d45)
#15 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:402:14 (libIREECompiler.so+0x7ed9438)
#16 mlir::SymbolTable::lookupNearestSymbolFrom(mlir::Operation*, mlir::SymbolRefAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:462:26 (libIREECompiler.so+0x7ed9438)
#17 mlir::call_interface_impl::resolveCallable(mlir::CallOpInterface, mlir::SymbolTableCollection*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Interfaces/CallInterfaces.cpp:197:10 (libIREECompiler.so+0x10791d36) (BuildId: 366a285014486d45)
#18 mlir::detail::CallOpInterfaceTrait<mlir::iree_compiler::IREE::Util::CallOp>::resolveCallable() /home/eochoalo/code/iree-build/llvm-project/tools/mlir/include/mlir/Interfaces/CallInterfaces.h.inc:252:14 (libIREECompiler.so+0xfe81bad) (BuildId: 366a285014486d45)
#19 mlir::detail::CallOpInterfaceInterfaceTraits::Model<mlir::iree_compiler::IREE::Util::CallOp>::resolveCallable(mlir::detail::CallOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*) /home/eochoalo/code/iree-build/llvm-project/tools/mlir/include/mlir/Interfaces/CallInterfaces.h.inc:445:56 (libIREECompiler.so+0xfe81bad)
#20 mlir::CallOpInterface::resolveCallable() /home/eochoalo/code/iree-build/llvm-project/tools/mlir/include/mlir/Interfaces/CallInterfaces.cpp.inc:90:14 (libIREECompiler.so+0x107921f0) (BuildId: 366a285014486d45)
#21 mlir::dataflow::AbstractSparseForwardDataFlowAnalysis::visitCallOperation(mlir::CallOpInterface, llvm::ArrayRef<mlir::dataflow::AbstractSparseLattice const*>, llvm::ArrayRef<mlir::dataflow::AbstractSparseLattice*>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp:232:53 (libIREECompiler.so+0x10756f68) (BuildId: 366a285014486d45)
#22 mlir::dataflow::AbstractSparseForwardDataFlowAnalysis::visitOperation(mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp:148:12 (libIREECompiler.so+0x107561b7) (BuildId: 366a285014486d45)
#23 mlir::dataflow::AbstractSparseForwardDataFlowAnalysis::visit(mlir::ProgramPoint*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp:106:12 (libIREECompiler.so+0x10756dd2) (BuildId: 366a285014486d45)
#24 mlir::DataFlowSolver::initializeAndRun(mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlowFramework.cpp:132:26 (libIREECompiler.so+0x10715b59) (BuildId: 366a285014486d45)
#25 (anonymous namespace)::ArithUnsignedWhenEquivalentPass::runOnOperation() /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp:130:23 (libIREECompiler.so+0xe7141b3) (BuildId: 366a285014486d45)
...
Previous write of size 8 at 0x722800004078 by thread T5:
#0 mlir::Operation::setLoc(mlir::Location) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:226:40 (libIREECompiler.so+0x1005670b) (BuildId: 366a285014486d45)
#1 (anonymous namespace)::ModifyOperationRewrite::rollback() /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:683:9 (libIREECompiler.so+0x1005670b)
#2 mlir::ConversionPatternRewriter::cancelOpModification(mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2239:10 (libIREECompiler.so+0x10035a89) (BuildId: 366a285014486d45)
#3 (anonymous namespace)::OperationLegalizer::legalizeWithFold(mlir::Operation*, mlir::ConversionPatternRewriter&) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2558:14 (libIREECompiler.so+0x100442bb) (BuildId: 366a285014486d45)
#4 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2484:19 (libIREECompiler.so+0x10036ff0) (BuildId: 366a285014486d45)
#5 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3108:26 (libIREECompiler.so+0x1003624f) (BuildId: 366a285014486d45)
#6 mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3207:16 (libIREECompiler.so+0x10037855) (BuildId: 366a285014486d45)
#7 applyConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig, (anonymous namespace)::OpConversionMode)::$_0::operator()() const /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3918:30 (libIREECompiler.so+0x1004aace) (BuildId: 366a285014486d45)
#8 void llvm::function_ref<void ()>::callback_fn<applyConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig, (anonymous namespace)::OpConversionMode)::$_0>(long) /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:12 (libIREECompiler.so+0x1004aace)
#9 llvm::function_ref<void ()>::operator()() const /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69:12 (libIREECompiler.so+0x1003e62a) (BuildId: 366a285014486d45)
#10 void mlir::MLIRContext::executeAction<ApplyConversionAction>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280:7 (libIREECompiler.so+0x1003e62a)
#11 applyConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig, (anonymous namespace)::OpConversionMode) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3915:8 (libIREECompiler.so+0x1003e62a)
#12 mlir::applyPartialConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3931:10 (libIREECompiler.so+0x1003e709) (BuildId: 366a285014486d45)
#13 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3938:10 (libIREECompiler.so+0x1003e709)
#14 (anonymous namespace)::LowerAffine::runOnOperation() /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp:564:16 (libIREECompiler.so+0xc812528) (BuildId: 366a285014486d45)
Let me try to produce a test case.
Thanks, this occurs with this resolve call: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L231-L232 |
We are using the symbol table machinery to lookup for a callable, but when the analysis scope if a function, such lookup will resolve outside of the scope. This can lead to race-condition issues since other passes may operate in parallel on the sibling functions.
The callable would be discarded right after the lookup (we check the analysis scope), so avoiding the lookup is NFC.
Fix #154948