Skip to content

Conversation

davidegrohmann
Copy link
Contributor

This is the third patch to add support for the SPV_ARM_graph SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new Graph abstraction for expressing dataflow computations over full resources.

The part 3 implementation includes:

  • ABI lowering support for graph entry points via LowerABIAttributesPass.
  • Tests covering ABI handling.

Graphs currently support only SPV_ARM_tensors, but are designed to generalize to other resource types, such as images.

Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

This is the third patch to add support for the `SPV_ARM_graph` SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
`Graph` abstraction for expressing dataflow computations over full
resources.

The part 3 implementation includes:

- ABI lowering support for graph entry points via `LowerABIAttributesPass`.
- Tests covering ABI handling.

Graphs currently support only `SPV_ARM_tensors`, but are designed to
generalize to other resource types, such as images.

Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Change-Id: I31896806a3e3a856530149ffd919b8568d5b6208
@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Davide Grohmann (davidegrohmann)

Changes

This is the third patch to add support for the SPV_ARM_graph SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new Graph abstraction for expressing dataflow computations over full resources.

The part 3 implementation includes:

  • ABI lowering support for graph entry points via LowerABIAttributesPass.
  • Tests covering ABI handling.

Graphs currently support only SPV_ARM_tensors, but are designed to generalize to other resource types, such as images.

Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947


Full diff: https://github.com/llvm/llvm-project/pull/156845.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+7-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+113-1)
  • (modified) mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir (+8)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir (+22)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526491971..44c86bc8777e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
-    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+    Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
     NamedAttribute attribute) {
-  return op->emitError("cannot attach SPIR-V attributes to region result");
+  if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
+    return verifyRegionAttribute(
+        op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
+  return op->emitError(
+      "cannot attach SPIR-V attributes to region result which is "
+      "not part of a spirv::GraphARMOp type");
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 3911ec08fcc27..91aa0e3823a31 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -85,10 +85,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
                                          abiInfo.getBinding());
 }
 
+/// Creates a global variable for an argument or result based on the ABI info.
+static spirv::GlobalVariableOp
+createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
+                                  unsigned index, bool isArg,
+                                  spirv::InterfaceVarABIAttr abiInfo) {
+  auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
+  if (!spirvModule)
+    return nullptr;
+
+  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+  builder.setInsertionPoint(graphOp.getOperation());
+  std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") +
+                        std::to_string(index);
+
+  Type varType = isArg ? graphOp.getFunctionType().getInput(index)
+                       : graphOp.getFunctionType().getResult(index);
+
+  auto pointerType = spirv::PointerType::get(
+      varType,
+      abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
+
+  return builder.create<spirv::GlobalVariableOp>(
+      graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
+      abiInfo.getBinding());
+}
+
 /// Gets the global variables that need to be specified as interface variable
 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
 static LogicalResult
-getInterfaceVariables(spirv::FuncOp funcOp,
+getInterfaceVariables(mlir::FunctionOpInterface funcOp,
                       SmallVectorImpl<Attribute> &interfaceVars) {
   auto module = funcOp->getParentOfType<spirv::ModuleOp>();
   if (!module) {
@@ -224,6 +250,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// A pattern to convert graph signature according to interface variable ABI
+/// attributes.
+///
+/// Specifically, this pattern creates global variables according to interface
+/// variable ABI attributes attached to graph arguments and results.
+class ProcessGraphInterfaceVarABI final
+    : public OpConversionPattern<spirv::GraphARMOp> {
+public:
+  using OpConversionPattern<spirv::GraphARMOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pass to implement the ABI information specified as attributes.
 class LowerABIAttributesPass final
     : public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +338,65 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
   return success();
 }
 
+LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
+    spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Non-entry point graphs are not handled.
+  if (!graphOp.getEntryPoint().value_or(false))
+    return failure();
+
+  TypeConverter::SignatureConversion signatureConverter(
+      graphOp.getFunctionType().getNumInputs());
+
+  StringRef attrName = spirv::getInterfaceVarABIAttrName();
+  SmallVector<Attribute, 4> interfaceVars;
+
+  // Convert arguments.
+  unsigned numInputs = graphOp.getFunctionType().getNumInputs();
+  unsigned numResults = graphOp.getFunctionType().getNumResults();
+  for (unsigned index = 0; index < numInputs; ++index) {
+    auto abiInfo =
+        graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
+    if (!abiInfo)
+      return failure();
+    spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+        rewriter, graphOp, index, true, abiInfo);
+    if (!var)
+      return failure();
+    interfaceVars.push_back(
+        SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+  }
+
+  for (unsigned index = 0; index < numResults; ++index) {
+    auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+        index, attrName);
+    if (!abiInfo)
+      return failure();
+    spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+        rewriter, graphOp, index, false, abiInfo);
+    if (!var)
+      return failure();
+    interfaceVars.push_back(
+        SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+  }
+
+  // Update signature.
+  rewriter.modifyOpInPlace(graphOp, [&] {
+    for (unsigned index = 0; index < numInputs; ++index) {
+      graphOp.removeArgAttr(index, attrName);
+    }
+    for (unsigned index = 0; index < numResults; ++index) {
+      graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
+    }
+  });
+
+  OpBuilder::InsertionGuard insertionGuard(rewriter);
+  rewriter.setInsertionPoint(graphOp);
+  rewriter.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
+                                               interfaceVars);
+  return success();
+}
+
 void LowerABIAttributesPass::runOnOperation() {
   // Uses the signature conversion methodology of the dialect conversion
   // framework to implement the conversion.
@@ -323,6 +423,7 @@ void LowerABIAttributesPass::runOnOperation() {
 
   RewritePatternSet patterns(context);
   patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
+  patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);
 
   ConversionTarget target(*context);
   // "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +434,17 @@ void LowerABIAttributesPass::runOnOperation() {
         return false;
     return true;
   });
+  target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
+    StringRef attrName = spirv::getInterfaceVarABIAttrName();
+    for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
+      if (op.getArgAttr(i, attrName))
+        return false;
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
+      if (op.getResultAttr(i, attrName))
+        return false;
+    return true;
+  });
+
   // All other SPIR-V ops are legal.
   target.markUnknownOpDynamicallyLegal([](Operation *op) {
     return op->getDialect()->getNamespace() ==
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 10fbcf06eb052..63dea6af83556 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -101,6 +101,14 @@ func.func @interface_var(
 
 // -----
 
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
+    !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.resource_limits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index f3a3218e5aec0..04667c828bbd1 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+module attributes {
+  spirv.target_env = #spirv.target_env<
+     #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: spirv.module
+spirv.module Logical Vulkan {
+  //  CHECK-DAG:    spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+  //  CHECK-DAG:    spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+
+  //      CHECK:    spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  //      CHECK:    spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
+  spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+                  -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+} // end spirv.module
+
+} // end module
+
+// -----
+
 module {
 // expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
 spirv.module Logical GLSL450 {}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants