diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt index 846547ff131e3..d34bab2b76162 100644 --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -1,4 +1,7 @@ add_mlir_interface(SymbolInterfaces) +set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td) +mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls) +mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs) add_mlir_interface(RegionKindInterface) set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td) diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index bbfa30815bd4a..9793c8c1752fc 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -220,6 +220,24 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> { ]; } +def SymbolUserAttrInterface : AttrInterface<"SymbolUserAttrInterface"> { + let description = [{ + This interface describes an attribute that may use a `Symbol`. This + interface allows for users of symbols to hook into verification and other + symbol related utilities that are either costly or otherwise disallowed + within a traditional operation. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<"Verify the symbol uses held by this attribute of this operation.", + "::llvm::LogicalResult", "verifySymbolUses", + (ins "::mlir::Operation *":$op, + "::mlir::SymbolTableCollection &":$symbolTable) + >, + ]; +} + //===----------------------------------------------------------------------===// // Symbol Traits //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index e4622354b8980..a174062d8d019 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -499,5 +499,6 @@ ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, /// Include the generated symbol interfaces. #include "mlir/IR/SymbolInterfaces.h.inc" +#include "mlir/IR/SymbolInterfacesAttrInterface.h.inc" #endif // MLIR_IR_SYMBOLTABLE_H diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 87b47992905e0..cf074b967d039 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -511,7 +511,14 @@ LogicalResult detail::verifySymbolTable(Operation *op) { SymbolTableCollection symbolTable; auto verifySymbolUserFn = [&](Operation *op) -> std::optional { if (SymbolUserOpInterface user = dyn_cast(op)) - return WalkResult(user.verifySymbolUses(symbolTable)); + if (failed(user.verifySymbolUses(symbolTable))) + return WalkResult::interrupt(); + for (auto &attr : op->getAttrs()) { + if (auto user = dyn_cast(attr.getValue())) { + if (failed(user.verifySymbolUses(op, symbolTable))) + return WalkResult::interrupt(); + } + } return WalkResult::advance(); }; @@ -1132,3 +1139,4 @@ ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser, /// Include the generated symbol interfaces. #include "mlir/IR/SymbolInterfaces.cpp.inc" +#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc" diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index f00c31003b185..25e148ac9f8cc 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -82,6 +82,8 @@ exports_files(glob(["include/**/*.td"])) tbl_outs = { "include/mlir/IR/" + name + ".h.inc": ["-gen-op-interface-decls"], "include/mlir/IR/" + name + ".cpp.inc": ["-gen-op-interface-defs"], + "include/mlir/IR/" + name + "AttrInterface.h.inc": ["-gen-attr-interface-decls"], + "include/mlir/IR/" + name + "AttrInterface.cpp.inc": ["-gen-attr-interface-defs"], }, tblgen = ":mlir-tblgen", td_file = "include/mlir/IR/" + name + ".td",