-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][spirv] Add support for SPV_ARM_graph extension - part 2 #156665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][spirv] Add support for SPV_ARM_graph extension - part 2 #156665
Conversation
This is the second patch to add support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources. The part 2 implementation includes: - Serialization and deserialization support for: - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM` - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM` - Tests covering binary round-tripping. Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 Signed-off-by: Davide Grohmann <davide.grohmann@arm.com> Change-Id: I88a5ff0298e0d30f649798111785ea984db56515
@llvm/pr-subscribers-mlir-spirv Author: Davide Grohmann (davidegrohmann) ChangesThis is the second patch to add support for the The part 2 implementation includes:
Graphs currently support only Spec: KhronosGroup/SPIRV-Registry#346 Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156665.diff 7 Files Affected:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index ee18cf815e4a7..4c49ec868bbc8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) {
if (auto undef = getUndefType(id)) {
return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
}
+ if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantARMInfo = getGraphConstantARM(id)) {
+ IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
+ Type resultType = graphConstantARMInfo->resultType;
+ return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
+ graphConstantID);
+ }
return valueMap.lookup(id);
}
@@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeTensorARM:
+ case spirv::Opcode::OpTypeGraphARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
@@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
+ case spirv::Opcode::OpGraphConstantARM:
+ return processGraphConstantARM(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
+ case spirv::Opcode::OpGraphEntryPointARM:
+ if (deferInstructions) {
+ deferredInstructions.emplace_back(opcode, operands);
+ return success();
+ }
+ return processGraphEntryPointARM(operands);
+ case spirv::Opcode::OpGraphARM:
+ return processGraphARM(operands);
+ case spirv::Opcode::OpGraphSetOutputARM:
+ return processOpGraphSetOutputARM(operands);
+ case spirv::Opcode::OpGraphEndARM:
+ return processGraphEndARM(operands);
case spirv::Opcode::OpLabel:
return processLabel(operands);
case spirv::Opcode::OpBranch:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3625dd2eb7dd3..37b5d348b0a1c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 2) {
+ return emitError(unknownLoc,
+ "missing graph defintion in OpGraphEntryPointARM");
+ }
+
+ unsigned wordIndex = 0;
+ uint32_t grID = operands[wordIndex++];
+ if (!graphMap.count(grID)) {
+ return emitError(unknownLoc,
+ "missing graph definition/declaration with id ")
+ << grID;
+ }
+
+ spirv::GraphARMOp graphARM = graphMap[grID];
+ StringRef name = decodeStringLiteral(operands, wordIndex);
+ graphARM.setSymName(name);
+ graphARM.setEntryPoint(true);
+
+ SmallVector<Attribute, 4> interface;
+ for (int64_t size = operands.size(); wordIndex < size; wordIndex++) {
+ if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
+ interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+ } else {
+ return emitError(unknownLoc, "undefined result <id> ")
+ << operands[wordIndex] << " while decoding OpGraphEntryPoint";
+ }
+ }
+
+ // RAII guard to reset the insertion point to previous value when done.
+ OpBuilder::InsertionGuard insertionGuard(opBuilder);
+ opBuilder.setInsertionPoint(graphARM);
+ opBuilder.create<spirv::GraphEntryPointARMOp>(
+ unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+ opBuilder.getArrayAttr(interface));
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+ if (curGraph) {
+ return emitError(unknownLoc, "found graph inside graph");
+ }
+ // Get the result type.
+ if (operands.size() < 2) {
+ return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+ }
+
+ Type type = getType(operands[0]);
+ if (!type || !isa<GraphType>(type)) {
+ return emitError(unknownLoc, "unknown graph type from <id> ")
+ << operands[0];
+ }
+ auto graphType = cast<GraphType>(type);
+ if (graphType.getNumResults() <= 0) {
+ return emitError(unknownLoc, "expected at least one result");
+ }
+
+ uint32_t grID = operands[1];
+ if (graphMap.count(grID)) {
+ return emitError(unknownLoc, "duplicate graph definition/declaration");
+ }
+
+ std::string grName = getGraphSymbol(grID);
+ auto graphOp =
+ opBuilder.create<spirv::GraphARMOp>(unknownLoc, grName, graphType);
+ curGraph = graphMap[grID] = graphOp;
+ Block *entryBlock = graphOp.addEntryBlock();
+ LLVM_DEBUG({
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ logger.startLine() << "[graph] name: " << grName << "\n";
+ logger.startLine() << "[graph] type: " << graphType << "\n";
+ logger.startLine() << "[graph] ID: " << grID << "\n";
+ logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+ logger.indent();
+ });
+
+ // Parse the op argument instructions.
+ for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> operands;
+ if (failed(sliceInstruction(opcode, operands,
+ spirv::Opcode::OpGraphInputARM))) {
+ return failure();
+ }
+ if (operands.size() != 3) {
+ return emitError(unknownLoc, "expected result type, result <id> and "
+ "input index for OpGraphInputARM");
+ }
+
+ Type argDefinedType = getType(operands[0]);
+ if (!argDefinedType) {
+ return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
+ }
+
+ if (argDefinedType != argType) {
+ return emitError(unknownLoc,
+ "mismatch in argument type between graph type "
+ "definition ")
+ << graphType << " and argument type definition " << argDefinedType
+ << " at argument " << index;
+ }
+ if (getValue(operands[1])) {
+ return emitError(unknownLoc, "duplicate definition of result <id> ")
+ << operands[1];
+ }
+
+ IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
+ if (!inputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read inputIndex value from constant op ")
+ << operands[2];
+ }
+ BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
+ valueMap[operands[1]] = argValue;
+ }
+
+ graphOutputs.resize(graphType.getNumResults());
+
+ // RAII guard to reset the insertion point to the module's region after
+ // deserializing the body of this function.
+ OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+ blockMap[grID] = entryBlock;
+ if (failed(createGraphBlock(grID))) {
+ return failure();
+ }
+
+ // Process all the instructions in the graph until and including
+ // OpGraphEndARM.
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> instOperands;
+ do {
+ if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+ return failure();
+ }
+
+ if (failed(processInstruction(opcode, instOperands))) {
+ return failure();
+ }
+ } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(
+ unknownLoc,
+ "expected value id and output index for OpGraphSetOutputARM");
+ }
+
+ uint32_t id = operands[0];
+ Value value = getValue(id);
+ if (!value) {
+ return emitError(unknownLoc, "could not find result <id> ") << id;
+ }
+
+ IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+ if (!outputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read outputIndex value from constant op ")
+ << operands[1];
+ }
+ graphOutputs[outputIndexAttr.getInt()] = value;
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+ // Create GraphOutputsARM instruction.
+ opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+
+ // Process OpGraphEndARM.
+ if (!operands.empty()) {
+ return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
+ }
+
+ curBlock = nullptr;
+ curGraph = std::nullopt;
+ graphOutputs.clear();
+
+ LLVM_DEBUG({
+ logger.unindent();
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ });
+ return success();
+}
+
std::optional<std::pair<Attribute, Type>>
spirv::Deserializer::getConstant(uint32_t id) {
auto constIt = constantMap.find(id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
return funcName;
}
+std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
+ std::string graphName = nameMap.lookup(id).str();
+ if (graphName.empty()) {
+ graphName = "spirv_graph_" + std::to_string(id);
+ }
+ return graphName;
+}
+
std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
return op;
}
+std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+spirv::Deserializer::getGraphConstantARM(uint32_t id) {
+ auto graphConstIt = graphConstantMap.find(id);
+ if (graphConstIt == graphConstantMap.end())
+ return std::nullopt;
+ return graphConstIt->getSecond();
+}
+
LogicalResult
spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processMatrixType(operands);
case spirv::Opcode::OpTypeTensorARM:
return processTensorARMType(operands);
+ case spirv::Opcode::OpTypeGraphARM:
+ return processGraphTypeARM(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
+ unsigned size = operands.size();
+ if (size < 2) {
+ return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
+ "(result_id, num_inputs, (inout0_type, "
+ "inout1_type, ...))")
+ << size;
+ }
+ uint32_t numInputs = operands[1];
+ SmallVector<Type, 1> argTypes;
+ SmallVector<Type, 1> returnTypes;
+ for (unsigned i = 2; i < size; i++) {
+ Type inOutTy = getType(operands[i]);
+ if (!inOutTy) {
+ return emitError(unknownLoc,
+ "OpTypeGraphARM references undefined element type.")
+ << operands[i];
+ }
+ if (i - 2 >= numInputs) {
+ returnTypes.push_back(inOutTy);
+ } else {
+ argTypes.push_back(inOutTy);
+ }
+ }
+ typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
<< resultType;
}
+LogicalResult
+spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 3) {
+ return emitError(unknownLoc)
+ << "OpGraphConstantARM must have at least 2 operands";
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ uint32_t resultID = operands[1];
+
+ if (!dyn_cast<spirv::TensorArmType>(resultType)) {
+ return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
+ }
+
+ APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
+ Type i32Ty = opBuilder.getIntegerType(32);
+ IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
+ graphConstantMap.try_emplace(
+ resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
+ if (!curGraph) {
+ return emitError(unknownLoc, "a graph block must appear inside a graph");
+ }
+
+ // We may have forward declared this block.
+ Block *block = getOrCreateBlock(graphID);
+ LLVM_DEBUG(logger.startLine()
+ << "[block] populating block " << block << "\n");
+ // If we have seen this block, make sure it was just a forward declaration.
+ assert(block->empty() && "re-deserialize the same block!");
+
+ opBuilder.setInsertionPointToStart(block);
+ blockMap[graphID] = curBlock = block;
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index db1cc3f8d79c2..6027f1ac94c23 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -106,6 +106,13 @@ struct SpecConstOperationMaterializationInfo {
SmallVector<uint32_t> enclosedOpOperands;
};
+/// A struct that collects the info needed to materialize/emit a
+/// GraphConstantARMOp.
+struct GraphConstantARMOpMaterializationInfo {
+ Type resultType;
+ IntegerAttr graphConstantID;
+};
+
//===----------------------------------------------------------------------===//
// Deserializer Declaration
//===----------------------------------------------------------------------===//
@@ -211,9 +218,14 @@ class Deserializer {
/// exists; otherwise creates one based on the <id>.
std::string getFunctionSymbol(uint32_t id);
- /// Returns a symbol to be used for the specialization constant with the given
- /// result <id>. This tries to use the specialization constant's OpName if
+ /// Returns a symbol to be used for the graph name with the given
+ /// result <id>. This tries to use the graph's OpName if
/// exists; otherwise creates one based on the <id>.
+ std::string getGraphSymbol(uint32_t id);
+
+ /// Returns a symbol to be used for the specialization constant with the
+ /// given result <id>. This tries to use the specialization constant's
+ /// OpName if exists; otherwise creates one based on the <id>.
std::string getSpecConstantSymbol(uint32_t id);
/// Gets the specialization constant with the given result <id>.
@@ -237,6 +249,11 @@ class Deserializer {
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
TypedAttr defaultValue);
+ /// Gets the GraphConstantARM ID attribute and result type with the given
+ /// result <id>.
+ std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ getGraphConstantARM(uint32_t id);
+
/// Processes the OpVariable instructions at current `offset` into `binary`.
/// It is expected that this method is used for variables that are to be
/// defined at module scope and will be deserialized into a
@@ -306,6 +323,16 @@ class Deserializer {
LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+ LogicalResult processGraphTypeARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEntryPointARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processOpGraphSetOutputARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEndARM(ArrayRef<uint32_t> operands);
+
LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
@@ -353,6 +380,10 @@ class Deserializer {
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpGraphConstantARM instruction with the given
+ /// `operands`.
+ LogicalResult processGraphConstantARM(ArrayRef<uint32_t> operands);
+
//===--------------------------------------------------------------------===//
// Debug
//===--------------------------------------------------------------------===//
@@ -450,6 +481,9 @@ class Deserializer {
/// blocks declared as selection/loop headers are handled.
LogicalResult structurizeControlFlow();
+ /// Creates a block for graph with the given graphID.
+ LogicalResult createGraphBlock(uint32_t graphID);
+
//===--------------------------------------------------------------------===//
// Instruction
//===--------------------------------------------------------------------===//
@@ -546,6 +580,9 @@ class Deserializer {
/// The current function under construction.
std::optional<spirv::FuncOp> curFunction;
+ /// The current graph under construction.
+ std::optional<spirv::GraphARMOp> curGraph;
+
/// The current block under construction.
Block *curBlock = nullptr;
@@ -599,12 +636,19 @@ class Deserializer {
DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
specConstOperationMap;
+ // Result <id> to GraphConstantARM ID attribute and result type.
+ DenseMap<uint32_t, spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantMap;
+
// Result <id> to variable mapping.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
// Result <id> to function mapping.
DenseMap<uint32_t, spirv::FuncOp> funcMap;
+ // Result <id> to function mapping.
+ DenseMap<uint32_t, spirv::GraphARMOp> graphMap;
+
// Result <id> to block mapping.
DenseMap<uint32_t, Block *> blockMap;
@@ -668,6 +712,9 @@ class Deserializer {
/// Deserialization options.
DeserializationOptions options;
+ /// List of IDs assigned to graph outputs.
+ SmallVector<Value> graphOutputs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d62529b85b3aa..e9b180a70bb23 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -203,6 +203,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
return success();
}
+LogicalResult
+Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
+ if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
+ op.getGraphConstantIdAttr())) {
+ valueIDMap[op.getResult()] = resultID;
+ return success();
+ }
+ return failure();
+}
+
LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
auto undefType = op.getType();
auto &id = undefValIDMap[undefType];
@@ -368,6 +378,118 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
return success();
}
+LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
+ if (op.getNumResults() < 1) {
+ return op.emitError("cannot serialize graph with no return types");
+ }
+
+ LLVM_DEBUG(llvm::dbgs() <...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Davide Grohmann (davidegrohmann) ChangesThis is the second patch to add support for the The part 2 implementation includes:
Graphs currently support only Spec: KhronosGroup/SPIRV-Registry#346 Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156665.diff 7 Files Affected:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index ee18cf815e4a7..4c49ec868bbc8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) {
if (auto undef = getUndefType(id)) {
return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
}
+ if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantARMInfo = getGraphConstantARM(id)) {
+ IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
+ Type resultType = graphConstantARMInfo->resultType;
+ return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
+ graphConstantID);
+ }
return valueMap.lookup(id);
}
@@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeTensorARM:
+ case spirv::Opcode::OpTypeGraphARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
@@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
+ case spirv::Opcode::OpGraphConstantARM:
+ return processGraphConstantARM(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
+ case spirv::Opcode::OpGraphEntryPointARM:
+ if (deferInstructions) {
+ deferredInstructions.emplace_back(opcode, operands);
+ return success();
+ }
+ return processGraphEntryPointARM(operands);
+ case spirv::Opcode::OpGraphARM:
+ return processGraphARM(operands);
+ case spirv::Opcode::OpGraphSetOutputARM:
+ return processOpGraphSetOutputARM(operands);
+ case spirv::Opcode::OpGraphEndARM:
+ return processGraphEndARM(operands);
case spirv::Opcode::OpLabel:
return processLabel(operands);
case spirv::Opcode::OpBranch:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3625dd2eb7dd3..37b5d348b0a1c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 2) {
+ return emitError(unknownLoc,
+ "missing graph defintion in OpGraphEntryPointARM");
+ }
+
+ unsigned wordIndex = 0;
+ uint32_t grID = operands[wordIndex++];
+ if (!graphMap.count(grID)) {
+ return emitError(unknownLoc,
+ "missing graph definition/declaration with id ")
+ << grID;
+ }
+
+ spirv::GraphARMOp graphARM = graphMap[grID];
+ StringRef name = decodeStringLiteral(operands, wordIndex);
+ graphARM.setSymName(name);
+ graphARM.setEntryPoint(true);
+
+ SmallVector<Attribute, 4> interface;
+ for (int64_t size = operands.size(); wordIndex < size; wordIndex++) {
+ if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
+ interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+ } else {
+ return emitError(unknownLoc, "undefined result <id> ")
+ << operands[wordIndex] << " while decoding OpGraphEntryPoint";
+ }
+ }
+
+ // RAII guard to reset the insertion point to previous value when done.
+ OpBuilder::InsertionGuard insertionGuard(opBuilder);
+ opBuilder.setInsertionPoint(graphARM);
+ opBuilder.create<spirv::GraphEntryPointARMOp>(
+ unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+ opBuilder.getArrayAttr(interface));
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+ if (curGraph) {
+ return emitError(unknownLoc, "found graph inside graph");
+ }
+ // Get the result type.
+ if (operands.size() < 2) {
+ return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+ }
+
+ Type type = getType(operands[0]);
+ if (!type || !isa<GraphType>(type)) {
+ return emitError(unknownLoc, "unknown graph type from <id> ")
+ << operands[0];
+ }
+ auto graphType = cast<GraphType>(type);
+ if (graphType.getNumResults() <= 0) {
+ return emitError(unknownLoc, "expected at least one result");
+ }
+
+ uint32_t grID = operands[1];
+ if (graphMap.count(grID)) {
+ return emitError(unknownLoc, "duplicate graph definition/declaration");
+ }
+
+ std::string grName = getGraphSymbol(grID);
+ auto graphOp =
+ opBuilder.create<spirv::GraphARMOp>(unknownLoc, grName, graphType);
+ curGraph = graphMap[grID] = graphOp;
+ Block *entryBlock = graphOp.addEntryBlock();
+ LLVM_DEBUG({
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ logger.startLine() << "[graph] name: " << grName << "\n";
+ logger.startLine() << "[graph] type: " << graphType << "\n";
+ logger.startLine() << "[graph] ID: " << grID << "\n";
+ logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+ logger.indent();
+ });
+
+ // Parse the op argument instructions.
+ for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> operands;
+ if (failed(sliceInstruction(opcode, operands,
+ spirv::Opcode::OpGraphInputARM))) {
+ return failure();
+ }
+ if (operands.size() != 3) {
+ return emitError(unknownLoc, "expected result type, result <id> and "
+ "input index for OpGraphInputARM");
+ }
+
+ Type argDefinedType = getType(operands[0]);
+ if (!argDefinedType) {
+ return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
+ }
+
+ if (argDefinedType != argType) {
+ return emitError(unknownLoc,
+ "mismatch in argument type between graph type "
+ "definition ")
+ << graphType << " and argument type definition " << argDefinedType
+ << " at argument " << index;
+ }
+ if (getValue(operands[1])) {
+ return emitError(unknownLoc, "duplicate definition of result <id> ")
+ << operands[1];
+ }
+
+ IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
+ if (!inputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read inputIndex value from constant op ")
+ << operands[2];
+ }
+ BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
+ valueMap[operands[1]] = argValue;
+ }
+
+ graphOutputs.resize(graphType.getNumResults());
+
+ // RAII guard to reset the insertion point to the module's region after
+ // deserializing the body of this function.
+ OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+ blockMap[grID] = entryBlock;
+ if (failed(createGraphBlock(grID))) {
+ return failure();
+ }
+
+ // Process all the instructions in the graph until and including
+ // OpGraphEndARM.
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> instOperands;
+ do {
+ if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+ return failure();
+ }
+
+ if (failed(processInstruction(opcode, instOperands))) {
+ return failure();
+ }
+ } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(
+ unknownLoc,
+ "expected value id and output index for OpGraphSetOutputARM");
+ }
+
+ uint32_t id = operands[0];
+ Value value = getValue(id);
+ if (!value) {
+ return emitError(unknownLoc, "could not find result <id> ") << id;
+ }
+
+ IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+ if (!outputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read outputIndex value from constant op ")
+ << operands[1];
+ }
+ graphOutputs[outputIndexAttr.getInt()] = value;
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+ // Create GraphOutputsARM instruction.
+ opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+
+ // Process OpGraphEndARM.
+ if (!operands.empty()) {
+ return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
+ }
+
+ curBlock = nullptr;
+ curGraph = std::nullopt;
+ graphOutputs.clear();
+
+ LLVM_DEBUG({
+ logger.unindent();
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ });
+ return success();
+}
+
std::optional<std::pair<Attribute, Type>>
spirv::Deserializer::getConstant(uint32_t id) {
auto constIt = constantMap.find(id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
return funcName;
}
+std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
+ std::string graphName = nameMap.lookup(id).str();
+ if (graphName.empty()) {
+ graphName = "spirv_graph_" + std::to_string(id);
+ }
+ return graphName;
+}
+
std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
return op;
}
+std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+spirv::Deserializer::getGraphConstantARM(uint32_t id) {
+ auto graphConstIt = graphConstantMap.find(id);
+ if (graphConstIt == graphConstantMap.end())
+ return std::nullopt;
+ return graphConstIt->getSecond();
+}
+
LogicalResult
spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processMatrixType(operands);
case spirv::Opcode::OpTypeTensorARM:
return processTensorARMType(operands);
+ case spirv::Opcode::OpTypeGraphARM:
+ return processGraphTypeARM(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
+ unsigned size = operands.size();
+ if (size < 2) {
+ return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
+ "(result_id, num_inputs, (inout0_type, "
+ "inout1_type, ...))")
+ << size;
+ }
+ uint32_t numInputs = operands[1];
+ SmallVector<Type, 1> argTypes;
+ SmallVector<Type, 1> returnTypes;
+ for (unsigned i = 2; i < size; i++) {
+ Type inOutTy = getType(operands[i]);
+ if (!inOutTy) {
+ return emitError(unknownLoc,
+ "OpTypeGraphARM references undefined element type.")
+ << operands[i];
+ }
+ if (i - 2 >= numInputs) {
+ returnTypes.push_back(inOutTy);
+ } else {
+ argTypes.push_back(inOutTy);
+ }
+ }
+ typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
<< resultType;
}
+LogicalResult
+spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 3) {
+ return emitError(unknownLoc)
+ << "OpGraphConstantARM must have at least 2 operands";
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ uint32_t resultID = operands[1];
+
+ if (!dyn_cast<spirv::TensorArmType>(resultType)) {
+ return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
+ }
+
+ APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
+ Type i32Ty = opBuilder.getIntegerType(32);
+ IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
+ graphConstantMap.try_emplace(
+ resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
+ if (!curGraph) {
+ return emitError(unknownLoc, "a graph block must appear inside a graph");
+ }
+
+ // We may have forward declared this block.
+ Block *block = getOrCreateBlock(graphID);
+ LLVM_DEBUG(logger.startLine()
+ << "[block] populating block " << block << "\n");
+ // If we have seen this block, make sure it was just a forward declaration.
+ assert(block->empty() && "re-deserialize the same block!");
+
+ opBuilder.setInsertionPointToStart(block);
+ blockMap[graphID] = curBlock = block;
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index db1cc3f8d79c2..6027f1ac94c23 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -106,6 +106,13 @@ struct SpecConstOperationMaterializationInfo {
SmallVector<uint32_t> enclosedOpOperands;
};
+/// A struct that collects the info needed to materialize/emit a
+/// GraphConstantARMOp.
+struct GraphConstantARMOpMaterializationInfo {
+ Type resultType;
+ IntegerAttr graphConstantID;
+};
+
//===----------------------------------------------------------------------===//
// Deserializer Declaration
//===----------------------------------------------------------------------===//
@@ -211,9 +218,14 @@ class Deserializer {
/// exists; otherwise creates one based on the <id>.
std::string getFunctionSymbol(uint32_t id);
- /// Returns a symbol to be used for the specialization constant with the given
- /// result <id>. This tries to use the specialization constant's OpName if
+ /// Returns a symbol to be used for the graph name with the given
+ /// result <id>. This tries to use the graph's OpName if
/// exists; otherwise creates one based on the <id>.
+ std::string getGraphSymbol(uint32_t id);
+
+ /// Returns a symbol to be used for the specialization constant with the
+ /// given result <id>. This tries to use the specialization constant's
+ /// OpName if exists; otherwise creates one based on the <id>.
std::string getSpecConstantSymbol(uint32_t id);
/// Gets the specialization constant with the given result <id>.
@@ -237,6 +249,11 @@ class Deserializer {
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
TypedAttr defaultValue);
+ /// Gets the GraphConstantARM ID attribute and result type with the given
+ /// result <id>.
+ std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ getGraphConstantARM(uint32_t id);
+
/// Processes the OpVariable instructions at current `offset` into `binary`.
/// It is expected that this method is used for variables that are to be
/// defined at module scope and will be deserialized into a
@@ -306,6 +323,16 @@ class Deserializer {
LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+ LogicalResult processGraphTypeARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEntryPointARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processOpGraphSetOutputARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEndARM(ArrayRef<uint32_t> operands);
+
LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
@@ -353,6 +380,10 @@ class Deserializer {
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpGraphConstantARM instruction with the given
+ /// `operands`.
+ LogicalResult processGraphConstantARM(ArrayRef<uint32_t> operands);
+
//===--------------------------------------------------------------------===//
// Debug
//===--------------------------------------------------------------------===//
@@ -450,6 +481,9 @@ class Deserializer {
/// blocks declared as selection/loop headers are handled.
LogicalResult structurizeControlFlow();
+ /// Creates a block for graph with the given graphID.
+ LogicalResult createGraphBlock(uint32_t graphID);
+
//===--------------------------------------------------------------------===//
// Instruction
//===--------------------------------------------------------------------===//
@@ -546,6 +580,9 @@ class Deserializer {
/// The current function under construction.
std::optional<spirv::FuncOp> curFunction;
+ /// The current graph under construction.
+ std::optional<spirv::GraphARMOp> curGraph;
+
/// The current block under construction.
Block *curBlock = nullptr;
@@ -599,12 +636,19 @@ class Deserializer {
DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
specConstOperationMap;
+ // Result <id> to GraphConstantARM ID attribute and result type.
+ DenseMap<uint32_t, spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantMap;
+
// Result <id> to variable mapping.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
// Result <id> to function mapping.
DenseMap<uint32_t, spirv::FuncOp> funcMap;
+ // Result <id> to function mapping.
+ DenseMap<uint32_t, spirv::GraphARMOp> graphMap;
+
// Result <id> to block mapping.
DenseMap<uint32_t, Block *> blockMap;
@@ -668,6 +712,9 @@ class Deserializer {
/// Deserialization options.
DeserializationOptions options;
+ /// List of IDs assigned to graph outputs.
+ SmallVector<Value> graphOutputs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d62529b85b3aa..e9b180a70bb23 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -203,6 +203,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
return success();
}
+LogicalResult
+Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
+ if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
+ op.getGraphConstantIdAttr())) {
+ valueIDMap[op.getResult()] = resultID;
+ return success();
+ }
+ return failure();
+}
+
LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
auto undefType = op.getType();
auto &id = undefValIDMap[undefType];
@@ -368,6 +378,118 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
return success();
}
+LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
+ if (op.getNumResults() < 1) {
+ return op.emitError("cannot serialize graph with no return types");
+ }
+
+ LLVM_DEBUG(llvm::dbgs() <...
[truncated]
|
This is the second patch to add support for the
SPV_ARM_graph
SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a newGraph
abstraction for expressing dataflow computations over full resources.The part 2 implementation includes:
OpGraphARM
,OpGraphInputARM
,OpGraphSetOutputARM
,OpGraphEndARM
OpGraphEntryPointARM
,OpGraphConstantARM
,OpTypeGraphARM
Graphs currently support only
SPV_ARM_tensors
, but are designed to generalize to other resource types, such as images.Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947