diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index b4d696444cc44..cfe3e800484ce 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -185,6 +185,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { InterfaceMethod<"Check the availability of workgroup level layouts", "bool", "isForWorkgroup">, + InterfaceMethod<"Check the availability of subgroup level layouts", + "bool", + "isForSubgroup">, InterfaceMethod<"Get the rank of attribute", "int64_t", "getRank">, @@ -197,14 +200,26 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { return 0; }], [{}]>, InterfaceMethod<"Get the SgLayout field of the attribute as integer array", - "std::optional>", + "SmallVector", "getSgLayoutAsInt">, InterfaceMethod<"Get the SgData field of the attribute as integer array", - "std::optional>", + "SmallVector", "getSgDataAsInt">, + InterfaceMethod<"Get the InstData field of the attribute as integer array", + "SmallVector", + "getInstDataAsInt">, + InterfaceMethod<"Get the LaneLayout field of the attribute as integer array", + "SmallVector", + "getLaneLayoutAsInt">, + InterfaceMethod<"Get the LaneData field of the attribute as integer array", + "SmallVector", + "getLaneDataAsInt">, InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData", "xegpu::DistributeLayoutAttr", "dropSgLayoutAndData">, + InterfaceMethod<"Derive a new layout by dropping InstData", + "xegpu::DistributeLayoutAttr", + "dropInstData">, InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional indices based on the effective subgroup layout.}], "FailureOr>", @@ -376,16 +391,34 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { getLaneLayout(), getLaneData(), getOrder()); } - std::optional> getSgLayoutAsInt() const { + SmallVector getSgLayoutAsInt() const { if (DenseI32ArrayAttr layout = getSgLayout()) return llvm::to_vector_of(layout.asArrayRef()); - return std::nullopt; + return {}; } - std::optional> getSgDataAsInt() const { + SmallVector getSgDataAsInt() const { if (DenseI32ArrayAttr data = getSgData()) return llvm::to_vector_of(data.asArrayRef()); - return std::nullopt; + return {}; + } + + SmallVector getInstDataAsInt() const { + if (DenseI32ArrayAttr inst = getInstData()) + return llvm::to_vector_of(inst.asArrayRef()); + return {}; + } + + SmallVector getLaneLayoutAsInt() const { + if (DenseI32ArrayAttr layout = getLaneLayout()) + return llvm::to_vector_of(layout.asArrayRef()); + return {}; + } + + SmallVector getLaneDataAsInt() const { + if (DenseI32ArrayAttr data = getLaneData()) + return llvm::to_vector_of(data.asArrayRef()); + return {}; } /// Delinearizes a linear subgroup ID into its multidimensional indices @@ -466,26 +499,67 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Returns the SgLayout of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. - std::optional> getSgLayoutAsInt() const { + SmallVector getSgLayoutAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - if (auto layout = parent.getSgLayoutAsInt()) { + auto layout = parent.getSgLayoutAsInt(); + if (layout.size()) { ArrayRef dims = attr.getDims().asArrayRef(); - return XeGPUDialect::slice(llvm::ArrayRef(*layout), dims); + return XeGPUDialect::slice(ArrayRef(layout), dims); } - return std::nullopt; + return {}; } /// Returns the SgData of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. - std::optional> getSgDataAsInt() const { + SmallVector getSgDataAsInt() const { + SliceAttr attr = flatten(); + auto parent = dyn_cast(attr.getParent()); + auto data = parent.getSgDataAsInt(); + if (data.size()) { + ArrayRef dims = attr.getDims().asArrayRef(); + return XeGPUDialect::slice(ArrayRef(data), dims); + } + return {}; + } + + /// Returns the InstData of the attribute, computed by applying + /// the slice dimensions to the underlying LayoutAttr. + SmallVector getInstDataAsInt() const { + SliceAttr attr = flatten(); + auto parent = dyn_cast(attr.getParent()); + auto inst = parent.getInstDataAsInt(); + if (inst.size()) { + ArrayRef dims = attr.getDims().asArrayRef(); + return XeGPUDialect::slice(llvm::ArrayRef(inst), dims); + } + return {}; + } + + /// Returns the LaneLayout of the attribute, computed by applying + /// the slice dimensions to the underlying LayoutAttr. + SmallVector getLaneLayoutAsInt() const { + SliceAttr attr = flatten(); + auto parent = dyn_cast(attr.getParent()); + auto layout = parent.getLaneLayoutAsInt(); + if (layout.size()) { + ArrayRef dims = attr.getDims().asArrayRef(); + return XeGPUDialect::slice(llvm::ArrayRef(layout), dims); + } + return {}; + } + + /// Returns the LaneData of the attribute, computed by applying + /// the slice dimensions to the underlying LayoutAttr. + SmallVector getLaneDataAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - if (auto data = parent.getSgDataAsInt()) { + auto data = parent.getLaneDataAsInt(); + if (data.size()) { ArrayRef dims = attr.getDims().asArrayRef(); - return XeGPUDialect::slice(llvm::ArrayRef(*data), dims); + return XeGPUDialect::slice(llvm::ArrayRef(data), dims); } - return std::nullopt; + return {}; } SliceAttr dropSgLayoutAndData() { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td index 76d58e5ea2424..c173b93face98 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td @@ -40,7 +40,7 @@ def XeGPU_Dialect : Dialect { let extraClassDeclaration = [{ /// Checks if the given shape can be evenly distributed based on the layout /// and data factors provided by the LayoutAttr. - static bool isEvenlyDistributable(llvm::ArrayRef shape, xegpu::LayoutAttr attr); + static bool isEvenlyDistributable(llvm::ArrayRef shape, xegpu::DistributeLayoutAttr attr); /// drops/slices the shape in the specified dims, and return the rest. e.g., /// for shape = [32, 64, 8], dims = [0, 2], it will return [64] diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index ab471a1f33ef9..2f6671c5e37cc 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1162,8 +1162,8 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou the IR is lowered to WI level because that is the end result of all distributions. }]; let arguments = (ins XeGPU_VectorType: $source, - XeGPU_LayoutAttr: $input_layout, - XeGPU_LayoutAttr: $target_layout); + DistributeLayoutAttr: $input_layout, + DistributeLayoutAttr: $target_layout); let results = (outs XeGPU_VectorType: $result); let assemblyFormat = [{ $source prop-dict attr-dict `:` type($source) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index b2b2d3ab85231..bad734dbfd9f0 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_ #define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_ +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" namespace mlir { @@ -21,6 +22,7 @@ class ValueRange; class TypeConverter; namespace xegpu { +class DistributeLayoutAttr; class LayoutAttr; class TensorDescType; } // namespace xegpu @@ -60,22 +62,33 @@ FailureOr getDistributedVectorType(xegpu::TensorDescType tdescTy); FailureOr getDistributedVectorType(VectorType originalType, LayoutAttr layout); -/// Return the attribute name for the OpOperand to attach LayoutAttr +/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr std::string getLayoutName(const OpOperand &operand); -/// Return the attribute name for the OpResult to attach LayoutAttr +/// Return the attribute name for the OpResult to attach DistributeLayoutAttr std::string getLayoutName(const OpResult result); -/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType -/// values, the LayoutAttr is extracted from the TensorDescType itself. For -/// other values, it is obtained from the attributes of the defining operation. -/// Returns nullptr if no LayoutAttr is found. -LayoutAttr getLayoutAttr(const Value value); +/// Retrieves the DistributeLayoutAttr associated with a given Value. For +/// TensorDescType values, the DistributeLayoutAttr is extracted from the +/// TensorDescType itself. For other values, it is obtained from the attributes +/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is +/// found. +DistributeLayoutAttr getDistributeLayoutAttr(const Value value); -/// Retrieves the LayoutAttr associated with a given OpOperand. It will -/// first check the operand_layout_{id} of the owner operation. If not found, -/// it will check the operand itself and its defining op. -LayoutAttr getLayoutAttr(const OpOperand &opr); +template +AttrTy getDistributeLayoutAttrOfType(const Value value) { + return dyn_cast_if_present(getDistributeLayoutAttr(value)); +} + +/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It +/// will first check the operand_layout_{id} of the owner operation. If not +/// found, it will check the operand itself and its defining op. +DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr); + +template +AttrTy getDistributeLayoutAttrOfType(const OpOperand &opr) { + return dyn_cast_if_present(getDistributeLayoutAttr(opr)); +} /// Removes the LayoutAttr for a given OpOperand or OpResult if it exists. template >> void removeLayoutAttr(const T &operandOrResult); -/// Removes the LayoutAttr for each OpOperand and OpResult of the given -/// operation if they exist. If the operation contains regions, it is also +/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the +/// given operation if they exist. If the operation contains regions, it is also /// applied recursively to the contained operations void removeLayoutAttrs(Operation *op); -/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching +/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching /// it to the owner's dictionary attributes template || std::is_same_v>> -void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout); - -/// Set the LayoutAttr for each OpOperand and OpResult of the given operation. -/// If the operation contains regions, it is also applied recursively to the -/// contained operations -void setLayoutAttrs(Operation *op, - function_ref getLayoutImpl); +void setDistributeLayoutAttr(const T &operandOrResult, + const DistributeLayoutAttr layout); + +/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given +/// operation. If the operation contains regions, it is also applied recursively +/// to the contained operations +void setDistributeLayoutAttrs( + Operation *op, function_ref getLayoutImpl); /// Extract a set of small vectors from a value with a given shape using /// vector.extract_stride_slice diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index a2d708be0e937..7f3be7f91c56b 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -91,7 +91,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc, // Checks if the given shape can be evenly distributed based on the layout // and data factors provided by the LayoutAttr. bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef shape, - xegpu::LayoutAttr attr) { + xegpu::DistributeLayoutAttr attr) { assert(attr && "Layout attribute is missing."); // Checks whether the given shape can be evenly distributed using the @@ -104,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef shape, // smaller than `layout[i] * data[i]`, allowing multiple compute units to // share the data. auto tryDistribute = [&](llvm::ArrayRef shape, - DenseI32ArrayAttr layout, DenseI32ArrayAttr data, + SmallVector layout, + SmallVector data, bool rr = true) -> optional> { llvm::SmallVector newShape(shape); - if (layout) { - auto vec = llvm::to_vector_of(layout.asArrayRef()); - if (vec.size() != shape.size()) + if (layout.size()) { + if (layout.size() != shape.size()) return std::nullopt; - auto ratio = computeShapeRatio(shape, vec); + auto ratio = computeShapeRatio(shape, layout); if (!ratio.has_value()) return std::nullopt; newShape = ratio.value(); } - if (data) { - auto vec = llvm::to_vector_of(data.asArrayRef()); - if (vec.size() != shape.size()) + if (data.size()) { + if (data.size() != shape.size()) return std::nullopt; - auto ratio = computeShapeRatio(newShape, vec); + auto ratio = computeShapeRatio(newShape, data); if (!ratio.has_value() && rr) - ratio = computeShapeRatio(vec, newShape); + ratio = computeShapeRatio(data, newShape); if (!ratio.has_value()) return std::nullopt; // if data is not null, we always return it for next phase. - newShape = vec; + newShape = data; } return newShape; }; // check the sgLayout and sgData auto maybeSgShape = - tryDistribute(shape, attr.getSgLayout(), attr.getSgData()); + tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt()); if (!maybeSgShape) return false; auto sgShape = maybeSgShape.value(); // check InstData, it neither have layout nor need round-robin auto maybeInstShape = - tryDistribute(sgShape, nullptr, attr.getInstData(), false); + tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false); if (!maybeInstShape) return false; auto instShape = maybeInstShape.value(); // check LaneLayout and LaneData - auto maybeLaneShape = - tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false); + auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(), + attr.getLaneDataAsInt(), false); return maybeLaneShape.has_value(); } @@ -283,7 +282,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, if (!hasDefaultOrder()) return mlir::emitError(loc, "order attribute is currently not supported."); - auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value { + auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value { return builder.createOrFold(loc, d); }); @@ -299,14 +298,14 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, if (!isForWorkgroup()) return failure(); - SmallVector sgLayout = getSgLayoutAsInt().value(); - SmallVector sgShape; - if (auto maybeSgShape = getSgDataAsInt()) - sgShape = maybeSgShape.value(); - else if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); - else - return failure(); + SmallVector sgLayout = getSgLayoutAsInt(); + SmallVector sgShape = getSgDataAsInt(); + if (sgShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + } // delinearize Ids auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); @@ -386,14 +385,14 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, if (!isForWorkgroup()) return failure(); - SmallVector sgLayout = getSgLayoutAsInt().value(); - SmallVector sgShape; - if (auto maybeSgShape = getSgDataAsInt()) - sgShape = maybeSgShape.value(); - else if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); - else - return failure(); + SmallVector sgLayout = getSgLayoutAsInt(); + SmallVector sgShape = getSgDataAsInt(); + if (sgShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + } // delinearize Ids auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index b3144e4c1e55d..9ee002ede7838 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -84,9 +84,10 @@ struct ConvertLayoutOpPattern using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override { - xegpu::LayoutAttr input_layout = op.getInputLayoutAttr(); - xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr(); - if (!input_layout.getInstData() || !target_layout.getInstData()) + xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr(); + xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr(); + if (input_layout.getInstDataAsInt().empty() || + target_layout.getInstDataAsInt().empty()) return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp."); input_layout = input_layout.dropInstData(); @@ -140,10 +141,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { else value = (Value)operandOrResult; - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(operandOrResult); if (layout && layout.isForSubgroup()) { - if (auto inst_data = layout.getInstData()) - return llvm::to_vector_of(inst_data.asArrayRef()); + if (!layout.getInstDataAsInt().empty()) + return layout.getInstDataAsInt(); if (auto type = dyn_cast(value.getType())) return llvm::to_vector(type.getShape()); @@ -204,12 +206,14 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { // skip the op if any of its operands or results has workgroup level layouts bool hasWgLayoutOperands = llvm::any_of(op->getOpOperands(), [](OpOperand &opr) { - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(opr); return layout && layout.isForWorkgroup(); }); bool hasWgLayoutResults = llvm::any_of(op->getOpResults(), [](OpResult result) { - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(result); return layout && layout.isForWorkgroup(); }); if (hasWgLayoutOperands || hasWgLayoutResults) { @@ -220,8 +224,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { auto isUnrollable = [](Value value, ArrayRef tileShape) { Type valTy = value.getType(); if (auto tdescTy = dyn_cast(valTy)) { - xegpu::LayoutAttr layout = tdescTy.getLayoutAttr(); - return layout && layout.getInstData(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); + return layout && !layout.getInstDataAsInt().empty(); } auto shapedType = dyn_cast(valTy); return shapedType && !llvm::equal(tileShape, shapedType.getShape()); @@ -247,7 +251,8 @@ void XeGPUBlockingPass::runOnOperation() { // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr. // This ensures that the LayoutAttr remains accessible even if the defining // operation is replaced. - xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); }); + xegpu::setDistributeLayoutAttrs( + op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); }); auto getTileShapeAndCount = [](llvm::ArrayRef shape, xegpu::LayoutAttr layout) { @@ -377,7 +382,7 @@ void XeGPUBlockingPass::runOnOperation() { if (auto layout = op->getAttrOfType(name)) { op->removeAttr(name); if (!isa(op)) - xegpu::setLayoutAttr(result, layout.dropInstData()); + xegpu::setDistributeLayoutAttr(result, layout.dropInstData()); } } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index bef88042fc663..5cb47b2accd68 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, } // If the result is a vector type, add a temporary layout attribute to the // op. - xegpu::setLayoutAttr(result, layout); + xegpu::setDistributeLayoutAttr(result, layout); } return success(); } @@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder, // If the type is a vector type and this region argument is an OpResult, // set the layout attribute on the OpResult. if (auto result = dyn_cast(successorInput)) - xegpu::setLayoutAttr(result, successorOperandLayout); + xegpu::setDistributeLayoutAttr(result, successorOperandLayout); } } return success(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 2088c3c7fc5ec..e48e2180197ec 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -841,14 +841,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (!isa(operand.get().getType())) continue; - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand); + auto layout = + xegpu::getDistributeLayoutAttrOfType(operand); if (!layout) { op->emitError("Could not find layout attribute for operand ") << operand.getOperandNumber() << " of operation " << op->getName(); signalPassFailure(); return; } - xegpu::setLayoutAttr(operand, layout); + xegpu::setDistributeLayoutAttr(operand, layout); } }); // Step 2: Move all operations of a GPU function inside @@ -882,7 +883,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (vecRank == 0) return AffineMap::get(val.getContext()); // Get the layout of the vector type. - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); + // TODO: support more layout types + auto layout = xegpu::getDistributeLayoutAttrOfType(val); // If no layout is specified, assume the inner most dimension is distributed // for now. if (!layout) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 93b4efcd125ec..0b7fe81facfce 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -52,9 +52,9 @@ getSgShapeAndCount(ArrayRef shape, int count = 1; SmallVector sgShape(shape); if (layout && layout.isForWorkgroup()) { - SmallVector sgLayout = layout.getSgLayoutAsInt().value(); - if (auto maybeSgData = layout.getSgDataAsInt()) - sgShape = *maybeSgData; + SmallVector sgLayout = layout.getSgLayoutAsInt(); + if (!layout.getSgDataAsInt().empty()) + sgShape = layout.getSgDataAsInt(); else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) sgShape = *maybeDerivedSgData; SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); @@ -406,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern { if (resultTy.getRank() != 2) return failure(); - auto originalLayout = xegpu::getLayoutAttr(op.getResult()); + auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult()); if (!originalLayout) return failure(); @@ -429,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern { VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands); - xegpu::setLayoutAttr(cast(tmpC), - originalLayout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(cast(tmpC), + originalLayout.dropSgLayoutAndData()); newDpasOps.push_back(tmpC); } @@ -470,8 +470,9 @@ struct WgToSgVectorBroadcastOp VectorType resultType = op.getResult().getType(); ArrayRef wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) return failure(); // TODO: Currently only supports cases where the source and result ranks @@ -486,10 +487,8 @@ struct WgToSgVectorBroadcastOp VectorType::get(sgShape, resultType.getElementType()); // Check if the output layout is distributable - SmallVector sgLayout; - if (auto sgLayoutAttr = layout.getSgLayout()) - sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); - else + SmallVector sgLayout = layout.getSgLayoutAsInt(); + if (sgLayout.empty()) return failure(); if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) @@ -508,8 +507,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - xegpu::setLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -535,8 +534,9 @@ struct WgToSgElementwiseOp : public ConversionPattern { ArrayRef wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op->getResult(0)); + if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; @@ -611,8 +611,9 @@ struct WgToSgConvertLayoutOp LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - xegpu::LayoutAttr input = op.getInputLayout(); - xegpu::LayoutAttr target = op.getTargetLayout(); + // TODO: currently, we only support LayoutAttr + auto input = dyn_cast(op.getInputLayout()); + auto target = dyn_cast(op.getTargetLayout()); if (!input || !target || !input.isForWorkgroup() || !target.isForWorkgroup()) @@ -737,8 +738,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern { if (!vecAttr || !vecAttr.isSplat() || !vecType) return failure(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) return failure(); ArrayRef wgShape = vecType.getShape(); @@ -755,7 +757,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); if (auto newLayout = layout.dropSgLayoutAndData()) - xegpu::setLayoutAttr(cstOp->getResult(0), newLayout); + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); SmallVector newConsts(count, cstOp); rewriter.replaceOpWithMultiple(op, {newConsts}); @@ -928,7 +930,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { }); target.addDynamicallyLegalOp([=](xegpu::DpasOp op) -> bool { - auto layout = xegpu::getLayoutAttr(op.getResult()); + auto layout = xegpu::getDistributeLayoutAttr(op.getResult()); return isLegal(layout); }); @@ -947,12 +949,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() { auto vecType = dyn_cast(op.getType()); if (!vecType) return true; - return isLegal(xegpu::getLayoutAttr(op.getResult())); + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); }); target.addDynamicallyLegalOp( [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getLayoutAttr(op.getResult())); + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); }); target.addDynamicallyLegalOp( @@ -980,7 +982,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { } } - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op->getResult(0)); return isLegal(layout); }); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 6835f64ad8ef7..cac1ffe4d3bc3 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -114,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) { return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str(); } -xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { +xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { if (!value) return nullptr; @@ -132,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { // for LoadNdOp, the layout is stored in the tensor descriptor if (auto loadNd = dyn_cast(defOp)) - return getLayoutAttr(loadNd.getTensorDesc()); + return getDistributeLayoutAttr(loadNd.getTensorDesc()); std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) - return defOp->getAttrOfType(layoutName); + return defOp->getAttrOfType(layoutName); } if (auto arg = dyn_cast(value)) { @@ -144,49 +144,51 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { if (auto loop = dyn_cast(parentOp)) { OpOperand *tiedInit = loop.getTiedLoopInit(arg); if (tiedInit) - return getLayoutAttr(tiedInit->get()); + return getDistributeLayoutAttr(tiedInit->get()); } } return nullptr; } -xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) { +xegpu::DistributeLayoutAttr +xegpu::getDistributeLayoutAttr(const OpOperand &opr) { Operation *op = opr.getOwner(); std::string layoutName = xegpu::getLayoutName(opr); if (op->hasAttr(layoutName)) - return op->getAttrOfType(layoutName); - return getLayoutAttr(opr.get()); + return op->getAttrOfType(layoutName); + return getDistributeLayoutAttr(opr.get()); } template -void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) { +void xegpu::setDistributeLayoutAttr(const T &operandOrResult, + const DistributeLayoutAttr layout) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (layout && !owner->hasAttrOfType(name)) + if (layout && !owner->hasAttrOfType(name)) owner->setAttr(name, layout); } // Explicit instantiation for OpResult -template void -xegpu::setLayoutAttr(const mlir::OpResult &result, - const mlir::xegpu::LayoutAttr layout); +template void xegpu::setDistributeLayoutAttr( + const mlir::OpResult &result, + const mlir::xegpu::DistributeLayoutAttr layout); // Explicit instantiation for OpOperand -template void -xegpu::setLayoutAttr(const mlir::OpOperand &operand, - const mlir::xegpu::LayoutAttr layout); +template void xegpu::setDistributeLayoutAttr( + const mlir::OpOperand &operand, + const mlir::xegpu::DistributeLayoutAttr layout); -void xegpu::setLayoutAttrs(Operation *op, - function_ref getLayoutImpl) { +void xegpu::setDistributeLayoutAttrs( + Operation *op, function_ref getLayoutImpl) { op->walk([&](Operation *nestOp) { for (OpOperand &opr : nestOp->getOpOperands()) { auto layout = getLayoutImpl(opr.get()); - setLayoutAttr(opr, layout); + setDistributeLayoutAttr(opr, layout); } for (OpResult result : nestOp->getOpResults()) { auto layout = getLayoutImpl(result); - setLayoutAttr(result, layout); + setDistributeLayoutAttr(result, layout); } }); } @@ -195,7 +197,7 @@ template void xegpu::removeLayoutAttr(const T &operandOrResult) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (owner->hasAttrOfType(name)) + if (owner->hasAttrOfType(name)) owner->removeAttr(name); } @@ -306,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( if (!inputTy || !resultTy) return WalkResult::skip(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(input); if (!layout) return WalkResult::skip(); @@ -344,7 +347,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( } { // perform the conversion from RankedTensorType to VectorType based on the - // LayoutAttr + // DistributeLayoutAttr // Handle the UnrealizedConversionCastOp introduced by the first step. // For vector->RankedTensorType, it will simply forward the inputs.