-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][spirv] Add support for SPV_ARM_graph extension - part 3 #156845
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][spirv] Add support for SPV_ARM_graph extension - part 3 #156845
Conversation
This is the third patch to add support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources. The part 3 implementation includes: - ABI lowering support for graph entry points via `LowerABIAttributesPass`. - Tests covering ABI handling. Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 Signed-off-by: Davide Grohmann <davide.grohmann@arm.com> Change-Id: I31896806a3e3a856530149ffd919b8568d5b6208
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Davide Grohmann (davidegrohmann) ChangesThis is the third patch to add support for the The part 3 implementation includes:
Graphs currently support only Spec: KhronosGroup/SPIRV-Registry#346 Full diff: https://github.com/llvm/llvm-project/pull/156845.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526491971..44c86bc8777e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
- Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+ Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
NamedAttribute attribute) {
- return op->emitError("cannot attach SPIR-V attributes to region result");
+ if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
+ return verifyRegionAttribute(
+ op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
+ return op->emitError(
+ "cannot attach SPIR-V attributes to region result which is "
+ "not part of a spirv::GraphARMOp type");
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 3911ec08fcc27..91aa0e3823a31 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -85,10 +85,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
abiInfo.getBinding());
}
+/// Creates a global variable for an argument or result based on the ABI info.
+static spirv::GlobalVariableOp
+createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
+ unsigned index, bool isArg,
+ spirv::InterfaceVarABIAttr abiInfo) {
+ auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
+ if (!spirvModule)
+ return nullptr;
+
+ OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+ builder.setInsertionPoint(graphOp.getOperation());
+ std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") +
+ std::to_string(index);
+
+ Type varType = isArg ? graphOp.getFunctionType().getInput(index)
+ : graphOp.getFunctionType().getResult(index);
+
+ auto pointerType = spirv::PointerType::get(
+ varType,
+ abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
+
+ return builder.create<spirv::GlobalVariableOp>(
+ graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+}
+
/// Gets the global variables that need to be specified as interface variable
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
static LogicalResult
-getInterfaceVariables(spirv::FuncOp funcOp,
+getInterfaceVariables(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<Attribute> &interfaceVars) {
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
if (!module) {
@@ -224,6 +250,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// A pattern to convert graph signature according to interface variable ABI
+/// attributes.
+///
+/// Specifically, this pattern creates global variables according to interface
+/// variable ABI attributes attached to graph arguments and results.
+class ProcessGraphInterfaceVarABI final
+ : public OpConversionPattern<spirv::GraphARMOp> {
+public:
+ using OpConversionPattern<spirv::GraphARMOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Pass to implement the ABI information specified as attributes.
class LowerABIAttributesPass final
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +338,65 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
return success();
}
+LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
+ spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Non-entry point graphs are not handled.
+ if (!graphOp.getEntryPoint().value_or(false))
+ return failure();
+
+ TypeConverter::SignatureConversion signatureConverter(
+ graphOp.getFunctionType().getNumInputs());
+
+ StringRef attrName = spirv::getInterfaceVarABIAttrName();
+ SmallVector<Attribute, 4> interfaceVars;
+
+ // Convert arguments.
+ unsigned numInputs = graphOp.getFunctionType().getNumInputs();
+ unsigned numResults = graphOp.getFunctionType().getNumResults();
+ for (unsigned index = 0; index < numInputs; ++index) {
+ auto abiInfo =
+ graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
+ if (!abiInfo)
+ return failure();
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+ rewriter, graphOp, index, true, abiInfo);
+ if (!var)
+ return failure();
+ interfaceVars.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+ }
+
+ for (unsigned index = 0; index < numResults; ++index) {
+ auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+ index, attrName);
+ if (!abiInfo)
+ return failure();
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+ rewriter, graphOp, index, false, abiInfo);
+ if (!var)
+ return failure();
+ interfaceVars.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+ }
+
+ // Update signature.
+ rewriter.modifyOpInPlace(graphOp, [&] {
+ for (unsigned index = 0; index < numInputs; ++index) {
+ graphOp.removeArgAttr(index, attrName);
+ }
+ for (unsigned index = 0; index < numResults; ++index) {
+ graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
+ }
+ });
+
+ OpBuilder::InsertionGuard insertionGuard(rewriter);
+ rewriter.setInsertionPoint(graphOp);
+ rewriter.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
+ interfaceVars);
+ return success();
+}
+
void LowerABIAttributesPass::runOnOperation() {
// Uses the signature conversion methodology of the dialect conversion
// framework to implement the conversion.
@@ -323,6 +423,7 @@ void LowerABIAttributesPass::runOnOperation() {
RewritePatternSet patterns(context);
patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
+ patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);
ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +434,17 @@ void LowerABIAttributesPass::runOnOperation() {
return false;
return true;
});
+ target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
+ StringRef attrName = spirv::getInterfaceVarABIAttrName();
+ for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
+ if (op.getArgAttr(i, attrName))
+ return false;
+ for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
+ if (op.getResultAttr(i, attrName))
+ return false;
+ return true;
+ });
+
// All other SPIR-V ops are legal.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return op->getDialect()->getNamespace() ==
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 10fbcf06eb052..63dea6af83556 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -101,6 +101,14 @@ func.func @interface_var(
// -----
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
+ !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.resource_limits
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index f3a3218e5aec0..04667c828bbd1 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {
// -----
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: spirv.module
+spirv.module Logical Vulkan {
+ // CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+ // CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+
+ // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+ // CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
+ spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+ -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+ }
+} // end spirv.module
+
+} // end module
+
+// -----
+
module {
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
spirv.module Logical GLSL450 {}
|
This is the third patch to add support for the
SPV_ARM_graph
SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a newGraph
abstraction for expressing dataflow computations over full resources.The part 3 implementation includes:
LowerABIAttributesPass
.Graphs currently support only
SPV_ARM_tensors
, but are designed to generalize to other resource types, such as images.Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947