Skip to content

Commit b276dec

Browse files
committed
[mlir] Add a DCE pass for dead symbols.
Summary: This pass deletes all symbols that are found to be unreachable. This is done by computing the set of operations that are known to be live, propagating that liveness to other symbols, and then deleting all symbols that are not within this live set. Differential Revision: https://reviews.llvm.org/D72482
1 parent ab9e559 commit b276dec

File tree

6 files changed

+280
-5
lines changed

6 files changed

+280
-5
lines changed

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ class SymbolTable {
9393
/// with the 'OpTrait::SymbolTable' trait.
9494
static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
9595
static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
96+
/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
97+
/// by a given SymbolRefAttr. Returns failure if any of the nested references
98+
/// could not be resolved.
99+
static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
100+
SmallVectorImpl<Operation *> &symbols);
96101

97102
/// Returns the operation registered with the given symbol name within the
98103
/// closest parent operation of, or including, 'from' with the

mlir/include/mlir/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ std::unique_ptr<OpPassBase<FuncOp>> createTestLoopFusionPass();
126126
/// Creates a pass which inlines calls and callable operations as defined by the
127127
/// CallGraph.
128128
std::unique_ptr<Pass> createInlinerPass();
129+
130+
/// Creates a pass which delete symbol operations that are unreachable. This
131+
/// pass may *only* be scheduled on an operation that defines a SymbolTable.
132+
std::unique_ptr<Pass> createSymbolDCEPass();
129133
} // end namespace mlir
130134

131135
#endif // MLIR_TRANSFORMS_PASSES_H

mlir/lib/IR/SymbolTable.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,30 +230,42 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
230230
}
231231
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
232232
SymbolRefAttr symbol) {
233+
SmallVector<Operation *, 4> resolvedSymbols;
234+
if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
235+
return nullptr;
236+
return resolvedSymbols.back();
237+
}
238+
239+
LogicalResult
240+
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
241+
SmallVectorImpl<Operation *> &symbols) {
233242
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
234243

235244
// Lookup the root reference for this symbol.
236245
symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
237246
if (!symbolTableOp)
238-
return nullptr;
247+
return failure();
248+
symbols.push_back(symbolTableOp);
239249

240250
// If there are no nested references, just return the root symbol directly.
241251
ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
242252
if (nestedRefs.empty())
243-
return symbolTableOp;
253+
return success();
244254

245255
// Verify that the root is also a symbol table.
246256
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
247-
return nullptr;
257+
return failure();
248258

249259
// Otherwise, lookup each of the nested non-leaf references and ensure that
250260
// each corresponds to a valid symbol table.
251261
for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
252262
symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
253263
if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
254-
return nullptr;
264+
return failure();
265+
symbols.push_back(symbolTableOp);
255266
}
256-
return lookupSymbolIn(symbolTableOp, symbol.getLeafReference());
267+
symbols.push_back(lookupSymbolIn(symbolTableOp, symbol.getLeafReference()));
268+
return success(symbols.back());
257269
}
258270

259271
/// Returns the operation registered with the given symbol name within the

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_llvm_library(MLIRTransforms
1717
PipelineDataTransfer.cpp
1818
SimplifyAffineStructures.cpp
1919
StripDebugInfo.cpp
20+
SymbolDCE.cpp
2021
Vectorize.cpp
2122
ViewOpGraph.cpp
2223
ViewRegionGraph.cpp

mlir/lib/Transforms/SymbolDCE.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements an algorithm for eliminating symbol operations that are
10+
// known to be dead.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/Passes.h"
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
struct SymbolDCE : public OperationPass<SymbolDCE> {
21+
void runOnOperation() override;
22+
23+
/// Compute the liveness of the symbols within the given symbol table.
24+
/// `symbolTableIsHidden` is true if this symbol table is known to be
25+
/// unaccessible from operations in its parent regions.
26+
LogicalResult computeLiveness(Operation *symbolTableOp,
27+
bool symbolTableIsHidden,
28+
DenseSet<Operation *> &liveSymbols);
29+
};
30+
} // end anonymous namespace
31+
32+
void SymbolDCE::runOnOperation() {
33+
Operation *symbolTableOp = getOperation();
34+
35+
// SymbolDCE should only be run on operations that define a symbol table.
36+
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
37+
symbolTableOp->emitOpError()
38+
<< " was scheduled to run under SymbolDCE, but does not define a "
39+
"symbol table";
40+
return signalPassFailure();
41+
}
42+
43+
// A flag that signals if the top level symbol table is hidden, i.e. not
44+
// accessible from parent scopes.
45+
bool symbolTableIsHidden = true;
46+
if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
47+
symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
48+
SymbolTable::Visibility::Private;
49+
}
50+
51+
// Compute the set of live symbols within the symbol table.
52+
DenseSet<Operation *> liveSymbols;
53+
if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols)))
54+
return signalPassFailure();
55+
56+
// After computing the liveness, delete all of the symbols that were found to
57+
// be dead.
58+
symbolTableOp->walk([&](Operation *nestedSymbolTable) {
59+
if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
60+
return;
61+
for (auto &block : nestedSymbolTable->getRegion(0)) {
62+
for (Operation &op :
63+
llvm::make_early_inc_range(block.without_terminator())) {
64+
if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
65+
op.erase();
66+
}
67+
}
68+
});
69+
}
70+
71+
/// Compute the liveness of the symbols within the given symbol table.
72+
/// `symbolTableIsHidden` is true if this symbol table is known to be
73+
/// unaccessible from operations in its parent regions.
74+
LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
75+
bool symbolTableIsHidden,
76+
DenseSet<Operation *> &liveSymbols) {
77+
// A worklist of live operations to propagate uses from.
78+
SmallVector<Operation *, 16> worklist;
79+
80+
// Walk the symbols within the current symbol table, marking the symbols that
81+
// are known to be live.
82+
for (auto &block : symbolTableOp->getRegion(0)) {
83+
for (Operation &op : block.without_terminator()) {
84+
// Always add non symbol operations to the worklist.
85+
if (!SymbolTable::isSymbol(&op)) {
86+
worklist.push_back(&op);
87+
continue;
88+
}
89+
90+
// Check the visibility to see if this symbol may be referenced
91+
// externally.
92+
SymbolTable::Visibility visibility =
93+
SymbolTable::getSymbolVisibility(&op);
94+
95+
// Private symbols are always initially considered dead.
96+
if (visibility == mlir::SymbolTable::Visibility::Private)
97+
continue;
98+
// We only include nested visibility here if the symbol table isn't
99+
// hidden.
100+
if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
101+
continue;
102+
103+
// TODO(riverriddle) Add hooks here to allow symbols to provide additional
104+
// information, e.g. linkage can be used to drop some symbols that may
105+
// otherwise be considered "live".
106+
if (liveSymbols.insert(&op).second)
107+
worklist.push_back(&op);
108+
}
109+
}
110+
111+
// Process the set of symbols that were known to be live, adding new symbols
112+
// that are referenced within.
113+
while (!worklist.empty()) {
114+
Operation *op = worklist.pop_back_val();
115+
116+
// If this is a symbol table, recursively compute its liveness.
117+
if (op->hasTrait<OpTrait::SymbolTable>()) {
118+
// The internal symbol table is hidden if the parent is, if its not a
119+
// symbol, or if it is a private symbol.
120+
bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
121+
SymbolTable::getSymbolVisibility(op) ==
122+
SymbolTable::Visibility::Private;
123+
if (failed(computeLiveness(op, symbolIsHidden, liveSymbols)))
124+
return failure();
125+
}
126+
127+
// Collect the uses held by this operation.
128+
Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
129+
if (!uses) {
130+
return op->emitError()
131+
<< "operation contains potentially unknown symbol table, "
132+
"meaning that we can't reliable compute symbol uses";
133+
}
134+
135+
SmallVector<Operation *, 4> resolvedSymbols;
136+
for (const SymbolTable::SymbolUse &use : *uses) {
137+
// Lookup the symbols referenced by this use.
138+
resolvedSymbols.clear();
139+
if (failed(SymbolTable::lookupSymbolIn(
140+
op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) {
141+
return use.getUser()->emitError()
142+
<< "unable to resolve reference to symbol "
143+
<< use.getSymbolRef();
144+
}
145+
146+
// Mark each of the resolved symbols as live.
147+
for (Operation *resolvedSymbol : resolvedSymbols)
148+
if (liveSymbols.insert(resolvedSymbol).second)
149+
worklist.push_back(resolvedSymbol);
150+
}
151+
}
152+
153+
return success();
154+
}
155+
156+
std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
157+
return std::make_unique<SymbolDCE>();
158+
}
159+
160+
static PassRegistration<SymbolDCE> pass("symbol-dce", "Eliminate dead symbols");

mlir/test/IR/test-symbol-dce.mlir

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: mlir-opt %s -symbol-dce -split-input-file -verify-diagnostics | FileCheck %s
2+
// RUN: mlir-opt %s -pass-pipeline="module(symbol-dce)" -split-input-file | FileCheck %s --check-prefix=NESTED
3+
4+
// Check that trivially dead and trivially live non-nested cases are handled.
5+
6+
// CHECK-LABEL: module attributes {test.simple}
7+
module attributes {test.simple} {
8+
// CHECK-NOT: func @dead_private_function
9+
func @dead_private_function() attributes { sym_visibility = "nested" }
10+
11+
// CHECK-NOT: func @dead_nested_function
12+
func @dead_nested_function() attributes { sym_visibility = "nested" }
13+
14+
// CHECK: func @live_private_function
15+
func @live_private_function() attributes { sym_visibility = "nested" }
16+
17+
// CHECK: func @live_nested_function
18+
func @live_nested_function() attributes { sym_visibility = "nested" }
19+
20+
// CHECK: func @public_function
21+
func @public_function() {
22+
"foo.return"() {uses = [@live_private_function, @live_nested_function]} : () -> ()
23+
}
24+
25+
// CHECK: func @public_function_explicit
26+
func @public_function_explicit() attributes { sym_visibility = "public" }
27+
}
28+
29+
// -----
30+
31+
// Check that we don't DCE nested symbols if they are used.
32+
// CHECK-LABEL: module attributes {test.nested}
33+
module attributes {test.nested} {
34+
// CHECK: module @public_module
35+
module @public_module {
36+
// CHECK-NOT: func @dead_nested_function
37+
func @dead_nested_function() attributes { sym_visibility = "nested" }
38+
39+
// CHECK: func @private_function
40+
func @private_function() attributes { sym_visibility = "private" }
41+
42+
// CHECK: func @nested_function
43+
func @nested_function() attributes { sym_visibility = "nested" } {
44+
"foo.return"() {uses = [@private_function]} : () -> ()
45+
}
46+
}
47+
48+
"live.user"() {uses = [@public_module::@nested_function]} : () -> ()
49+
}
50+
51+
// -----
52+
53+
// Check that we don't DCE symbols if we can't prove that the top-level symbol
54+
// table that we are running on is hidden from above.
55+
// NESTED-LABEL: module attributes {test.no_dce_non_hidden_parent}
56+
module attributes {test.no_dce_non_hidden_parent} {
57+
// NESTED: module @public_module
58+
module @public_module {
59+
// NESTED: func @nested_function
60+
func @nested_function() attributes { sym_visibility = "nested" }
61+
}
62+
// NESTED: module @nested_module
63+
module @nested_module attributes { sym_visibility = "nested" } {
64+
// NESTED: func @nested_function
65+
func @nested_function() attributes { sym_visibility = "nested" }
66+
}
67+
68+
// Only private modules can be assumed to be hidden.
69+
// NESTED: module @private_module
70+
module @private_module attributes { sym_visibility = "private" } {
71+
// NESTED-NOT: func @nested_function
72+
func @nested_function() attributes { sym_visibility = "nested" }
73+
}
74+
75+
"live.user"() {uses = [@nested_module, @private_module]} : () -> ()
76+
}
77+
78+
// -----
79+
80+
module {
81+
func @private_symbol() attributes { sym_visibility = "private" }
82+
83+
// expected-error@+1 {{contains potentially unknown symbol table}}
84+
"foo.possibly_unknown_symbol_table"() ({
85+
}) : () -> ()
86+
}
87+
88+
// -----
89+
90+
module {
91+
// expected-error@+1 {{unable to resolve reference to symbol}}
92+
"live.user"() {uses = [@unknown_symbol]} : () -> ()
93+
}

0 commit comments

Comments
 (0)