Skip to content

Conversation

adam-smnk
Copy link
Contributor

Adds a utility getter to warp_execute_on_lane_0 which simplifies access to the op's terminator.

Uses are refactored to utilize the new terminator getter.

Adds a utility getter to `warp_execute_on_lane_0` which simplifies
access to op's terminator.

Uses are refactored to utilize the new terminator getter.
@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-gpu

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds a utility getter to warp_execute_on_lane_0 which simplifies access to the op's terminator.

Uses are refactored to utilize the new terminator getter.


Full diff: https://github.com/llvm/llvm-project/pull/154729.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+3)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+5-2)
  • (modified) mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-8)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+4-8)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f946bb731e2ca..a5c3a92f1b7a5 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -3209,6 +3209,9 @@ def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0",
     bool isDefinedOutsideOfRegion(Value value) {
       return !getRegion().isAncestor(value.getParentRegion());
     }
+
+    /// Get the terminator of the warp region.
+    gpu::YieldOp getTerminator();
   }];
 }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 2503ccb6a2cfe..cc77aa6711c42 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2486,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
   if (getArgs().size() != getWarpRegion().getNumArguments())
     return emitOpError(
         "expected same number op arguments and block arguments.");
-  auto yield =
-      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+  gpu::YieldOp yield = getTerminator();
   if (yield.getNumOperands() != getNumResults())
     return emitOpError(
         "expected same number of yield operands and return values.");
@@ -2511,6 +2510,10 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
+gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
+  return cast<gpu::YieldOp>(getBody()->getTerminator());
+}
+
 //===----------------------------------------------------------------------===//
 // GPU KernelMetadataAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index be71bd02fc43b..88f531f394765 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -56,8 +56,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
     SmallVector<size_t> &indices) const {
   SmallVector<Type> types(warpOp.getResultTypes().begin(),
                           warpOp.getResultTypes().end());
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  gpu::YieldOp yield = warpOp.getTerminator();
   SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                  yield.getOperands().end());
   llvm::SmallDenseMap<Value, unsigned> indexLookup;
@@ -89,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
 OpOperand *WarpDistributionPattern::getWarpResult(
     WarpExecuteOnLane0Op warpOp,
     llvm::function_ref<bool(Operation *)> fn) const {
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  gpu::YieldOp yield = warpOp.getTerminator();
   for (OpOperand &yieldOperand : yield->getOpOperands()) {
     Value yieldValues = yieldOperand.get();
     Operation *definedOp = yieldValues.getDefiningOp();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..60aa0e9bae64a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
     if (!writeOp)
@@ -846,8 +845,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
     newYieldValues.reserve(warpOp->getNumResults());
     DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
     DenseMap<OpResult, int64_t> dedupResultPositionMap;
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
 
     // Some values may be yielded multiple times and correspond to multiple
     // results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +899,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Value valForwarded;
     unsigned resultIndex;
     for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1705,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto warpOpYield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
     // Only pick up `ForOp` if it is the last op in the region.
     Operation *lastNode = warpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..8e47968609d32 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -336,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -449,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
       // Make sure the same load op is the last operation in the warp op body.
       // This ensure that load op is not sinked earlier violating any barrier
       // synchronizations.
-      auto yield = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      gpu::YieldOp yield = warpOp.getTerminator();
       return yield->getPrevNode() == op;
     });
 
@@ -752,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -794,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);

@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2025

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds a utility getter to warp_execute_on_lane_0 which simplifies access to the op's terminator.

Uses are refactored to utilize the new terminator getter.


Full diff: https://github.com/llvm/llvm-project/pull/154729.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+3)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+5-2)
  • (modified) mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-8)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+4-8)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f946bb731e2ca..a5c3a92f1b7a5 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -3209,6 +3209,9 @@ def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0",
     bool isDefinedOutsideOfRegion(Value value) {
       return !getRegion().isAncestor(value.getParentRegion());
     }
+
+    /// Get the terminator of the warp region.
+    gpu::YieldOp getTerminator();
   }];
 }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 2503ccb6a2cfe..cc77aa6711c42 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2486,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
   if (getArgs().size() != getWarpRegion().getNumArguments())
     return emitOpError(
         "expected same number op arguments and block arguments.");
-  auto yield =
-      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+  gpu::YieldOp yield = getTerminator();
   if (yield.getNumOperands() != getNumResults())
     return emitOpError(
         "expected same number of yield operands and return values.");
@@ -2511,6 +2510,10 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
+gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
+  return cast<gpu::YieldOp>(getBody()->getTerminator());
+}
+
 //===----------------------------------------------------------------------===//
 // GPU KernelMetadataAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index be71bd02fc43b..88f531f394765 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -56,8 +56,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
     SmallVector<size_t> &indices) const {
   SmallVector<Type> types(warpOp.getResultTypes().begin(),
                           warpOp.getResultTypes().end());
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  gpu::YieldOp yield = warpOp.getTerminator();
   SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                  yield.getOperands().end());
   llvm::SmallDenseMap<Value, unsigned> indexLookup;
@@ -89,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
 OpOperand *WarpDistributionPattern::getWarpResult(
     WarpExecuteOnLane0Op warpOp,
     llvm::function_ref<bool(Operation *)> fn) const {
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  gpu::YieldOp yield = warpOp.getTerminator();
   for (OpOperand &yieldOperand : yield->getOpOperands()) {
     Value yieldValues = yieldOperand.get();
     Operation *definedOp = yieldValues.getDefiningOp();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..60aa0e9bae64a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
     if (!writeOp)
@@ -846,8 +845,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
     newYieldValues.reserve(warpOp->getNumResults());
     DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
     DenseMap<OpResult, int64_t> dedupResultPositionMap;
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
 
     // Some values may be yielded multiple times and correspond to multiple
     // results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +899,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Value valForwarded;
     unsigned resultIndex;
     for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1705,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto warpOpYield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
     // Only pick up `ForOp` if it is the last op in the region.
     Operation *lastNode = warpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..8e47968609d32 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -336,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -449,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
       // Make sure the same load op is the last operation in the warp op body.
       // This ensure that load op is not sinked earlier violating any barrier
       // synchronizations.
-      auto yield = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      gpu::YieldOp yield = warpOp.getTerminator();
       return yield->getPrevNode() == op;
     });
 
@@ -752,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -794,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);

Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. very useful change.

Existing code uses cast<gpu::YieldOp>(warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); to get the terminator. But this approach uses cast<gpu::YieldOp>(getBody()->getTerminator()). Wondering if they are the same or could be different in come cases?

@@ -2511,6 +2510,10 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
return cast<gpu::YieldOp>(getBody()->getTerminator());
Copy link
Contributor

@charithaintc charithaintc Aug 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question out of curiosity, is the body always guaranteed to have a terminator? who ensures that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warp executes's SingleBlockImplicitTerminator<"gpu::YieldOp"> trait ensure that there's only one block and must be terminated by the specific terminator op.
In textual representation like:

gpu.warp_execute_on_lane_0(%laneid)[32] {
    %c0 = arith.constant 0 : index
    %v = "test.dummy_op"() : () -> (vector<4xf32>)
    %v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
    vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
  }

the terminator still exists but warp's printer omits it (see WarpExecuteOnLane0Op::print).

getBody() is an API provided by the SingleBlock trait and under the hood it gets the block in the same way as the more verbose version: op -> region -> block. It should be identical.

@adam-smnk adam-smnk merged commit 533ddcd into llvm:main Aug 22, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants