charithaintc
Contributor

This PR adds support for distributing vector.multi_reduction. Currently only 2D-to-1D reductions (column or row reductions) are supported, and the pattern assumes that the inner dimension of the source vector is distributed among lanes (each lane owns columns of the source vector).

  • Col Reduce: Each lane already owns the data it needs, so each lane can run its own vector.reduction over its column data.
  • Row Reduce: Requires shuffling data with neighboring lanes. The pattern simply rewrites the multi_reduction in terms of vector.reduction ops.

The PR also changes vector.shape_cast distribution to take into account the distributed type of the shape cast source when it is given by a DistributionMapFn.
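
To make the column case concrete, here is a minimal host-side sketch of what a single lane computes. It is not the pattern from this PR: `laneColumnReduce` is a hypothetical helper, plain `std::vector` stands in for MLIR vectors, the reduction kind is fixed to add, and the sketch assumes that lane `l` owns a contiguous block of columns.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model of a column reduction (reduction dim 0) over a
// rows x cols source whose columns are split evenly across lanes.
// Assumption (not taken from the PR): lane `lane` owns the contiguous
// columns [lane * colsPerLane, (lane + 1) * colsPerLane).
std::vector<float> laneColumnReduce(const std::vector<std::vector<float>> &src,
                                    const std::vector<float> &acc,
                                    std::size_t warpSize, std::size_t lane) {
  const std::size_t cols = acc.size();
  assert(warpSize != 0 && cols % warpSize == 0 &&
         "columns must divide evenly among lanes");
  const std::size_t colsPerLane = cols / warpSize;
  std::vector<float> result(colsPerLane, 0.0f);
  for (std::size_t i = 0; i < colsPerLane; ++i) {
    const std::size_t col = lane * colsPerLane + i;
    float sum = acc[col]; // start from this column's accumulator
    for (const auto &row : src)
      sum += row[col]; // purely local: no other lane's data is read
    result[i] = sum;
  }
  return result;
}
```

With a 32x64 source and 32 lanes, each lane ends up with two results, which mirrors the `vector<2xf32>` per-lane result type in the column-reduce test of this PR.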

@llvmbot
Member

llvmbot commented Aug 19, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Full diff: https://github.com/llvm/llvm-project/pull/154438.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+161-23)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+111)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..6410a895fc9ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
 #include <utility>
 
 using namespace mlir;
@@ -977,44 +983,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
     unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, the distributed source type is the
+    // same as the source type. If the result is distributed, we need to compute
+    // the distributed source type according to the following rules:
+    // 1. If the source type is yielded from the warp op, we can use the
+    //    matching warp result type as the distributed source type.
+    // 2. If the source type is not yielded from the warp op, we need
+    //    to compute the distributed source type based on the distribution map
+    //    and the warp size.
+    if (isResultDistributed) {
+      // Check if the source is yielded from the warp op.
+      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      auto *it =
+          llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+            return operand.get() == oldCastOp.getSource();
+          });
+
+      if (it != yieldOp->getOpOperands().end()) {
+        // If the source is yielded from the warp op, we can use the matching
+        // warp result type as the distributed source type.
+        distributedSourceType =
+            cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+      } else {
+        // If the source is not yielded from the warp op, we need to compute
+        // the distributed source type based on the distribution map and the
+        // warp size.
+        AffineMap map = distributionMapFn(oldCastOp.getSource());
+        distributedSourceType =
+            getDistributedType(sourceType, map, warpOp.getWarpSize());
+        if (!distributedSourceType)
+          return rewriter.notifyMatchFailure(
+              oldCastOp,
+              "cannot compute distributed source type for shape cast");
+      }
     }
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value newCast = vector::ShapeCastOp::create(
-        rewriter, oldCastOp.getLoc(), castResultType,
+        rewriter, oldCastOp.getLoc(), distributedResultType,
         newWarpOp->getResult(newRetIndices[0]));
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1996,6 +2033,107 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
+  VectorMultiDimReductionDistribution(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
+      : WarpDistributionPattern(context, benefit) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+
+    // Create a constant vector to store the result of the reduction per lane.
+    TypedAttr zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+
+    // Col reduction.
+    if (reductionDims[0] == 0) {
+      // Source vector must be distributable to lanes in the col dimension.
+      if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source vector dimension must be divisible by warp size.");
+      // Compute source distributed type.
+      SmallVector<int64_t> shape(sourceType.getShape());
+      shape[1] = shape[1] / warpOp.getWarpSize();
+      auto sourceDistributedType = VectorType::get(shape, elementType);
+
+      // Yield the source and acc vectors from the WarpOp.
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistributedType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      // For each column owned by a lane, extract the column (of size nRows x
+      // 1), shape cast to 1D (nRows), do a vector.reduction, and insert the
+      // result back into the result vector.
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the new reduction op.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // For row reductions, we simply rewrite the MultiReductionOp in terms of
+    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+    // pattern.
+    rewriter.setInsertionPointAfter(reductionOp);
+    int nRows = sourceType.getShape()[0];
+    // For each row of the source, extract the row vector, do a reduction, and
+    // insert the result back into the result.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                        rowReduce, result, i);
+    }
+    // Replace the warp op result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -2017,15 +2155,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
   patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
+           WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+           WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
           patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
-  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
-                               benefit);
+  patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(),
+                                                distributionMapFn, benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..bf70fbbd27244 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
   return %r : f32
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP:     %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP:     %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP:     gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP:   %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP:   %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP:   %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP:   %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP:   %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP:   %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP:   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP:   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP:   return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %acc = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+    gpu.yield %1 : vector<64xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL:  func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP:    %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP:    %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP:    %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP:    %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP:    %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP:    %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP:    %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP:      %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP:      gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP:    }
+// CHECK-PROP:    %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:    %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:    %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP:    %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:    %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP:    %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:    %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP:    %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:    %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP:    %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:    %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP:    %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP:    %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:    %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:    %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP:    %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:    %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP:    %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:    %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP:    %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:    %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP:    %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:    %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP:    %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP:    %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP:    return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+  %zero = arith.constant dense<0.0> : vector<2xf32>
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<2x32xf32>)
+    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+    gpu.yield %1 : vector<2xf32>
+  }
+  return %r : vector<2xf32>
+}
+
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_duplicate_yield(
@@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 // CHECK-PROP:   %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
 // CHECK-PROP:   return %[[CAST]] : vector<4xf32>
 
+// -----
+func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
+    %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
+    gpu.yield %3 : vector<32x64xf32>
+  }
+  return %r : vector<32x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
+// CHECK-PROP:    %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
+// CHECK-PROP:    %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
+// CHECK-PROP:    return %[[CAST]] : vector<32x2xf32>
+
+// -----
+func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
+    %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
+    %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
+    gpu.yield %3 : vector<8x4x2xf32>
+  }
+  return %r : vector<8x4x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
+// CHECK-PROP:    %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
+// CHECK-PROP:    %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
+// CHECK-PROP:    return %[[CAST]] : vector<8x4x2xf32>
+
 // -----
 
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {

@charithaintc
Contributor Author

cc @Garra1980

Contributor

@adam-smnk adam-smnk left a comment

General direction looks good to me.
I'll let others have a closer look at the distribution logic.

@charithaintc
Contributor Author

General direction looks good to me. I'll let others have a closer look at the distribution logic.

Hi @adam-smnk, thanks for the review. I have addressed the concerns. Please let me know if you have any additional ones.

I also properly documented the restrictions in this version.

// distributed source type according to following rules:
// 1. If the source type is yielded from the warp op, we can use the
// matching warp result type as the distributed source type.
// 2. If the source type is not yielded from the warp op, we need
Contributor

Could you give an example? Can these two rules ever conflict with each other?

Contributor Author

I added two examples, one for each case, as comments.

For both row and col reductions we assume that each lane owns columns of the source vector (i.e. in xegpu layouts this would be [1, 16]). The reduction logic is built on that assumption.
Given this layout:
Col reduction is easy: each lane just reduces its own data.
Row reduction needs to shuffle data with neighbors and do a tree-like reduce (aka butterfly reduction with shuffles).

However, the source layout can also be [16, 1]. This case is not supported because the vector distribution infra currently does not let me express such a layout (it always starts distributing the vector from the innermost dim). I am working on a proposal to improve this.

// case each lane owns its portion of the result (i.e. result is also
// distributed).
// 3. If reduction dim == 1, its a row reduction that require cross lanes
// shuffles. In this case result is not distributed and broadcasted instead.
Contributor

not distributed but broadcasted instead?

Contributor Author

Yes. For row reductions each lane owns the whole reduced result. Example:

%r = vector.multi_reduction <add>, %src, %acc [1] : vector<8x16xf32> to vector<8xf32>

In this case each lane owns a column of the 8x16xf32 source (i.e. 8x1xf32). Each lane does a shuffle followed by an add, log2(16) = 4 times, to get the final result. At the end every lane holds the final 8x1xf32 result, meaning that the result is effectively broadcast to each lane.
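
The shuffle-and-add steps described above can be simulated on the host. The following is only a sketch under stated assumptions: `butterflyAllReduce` is a hypothetical name, each vector element stands for the scalar held by one lane for a single row, and an XOR lane-pairing is assumed (the same pairing the `gpu.shuffle xor` lines in the row-reduce test suggest).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simulates an XOR-shuffle ("butterfly") add reduction across a warp.
// lanes[l] is the value currently held by lane l; the number of lanes
// must be a power of two. Returns the per-lane values after the last step.
std::vector<float> butterflyAllReduce(std::vector<float> lanes) {
  const std::size_t n = lanes.size();
  assert(n != 0 && (n & (n - 1)) == 0 && "warp size must be a power of two");
  // Offsets 1, 2, 4, ... give log2(n) steps in total.
  for (std::size_t offset = 1; offset < n; offset <<= 1) {
    std::vector<float> next(n);
    for (std::size_t l = 0; l < n; ++l)
      next[l] = lanes[l] + lanes[l ^ offset]; // add the partner lane's value
    lanes = next;
  }
  return lanes;
}
```

For 16 lanes the loop runs four times, and afterwards every lane holds the same total. That is the sense in which the row-reduction result is broadcast rather than distributed.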

Contributor

Sorry for the confusion. I meant that the comment should read "not distributed but broadcasted instead" rather than "not distributed and broadcasted instead".

reductionOp.getSource(), i);
Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
reductionOp.getAcc(), i);
Value rowReduce = vector::ReductionOp::create(
Contributor

I am curious: how does the ShuffleOp get inserted?

Contributor Author

This is lowered progressively. Here we lower it to a bunch of vector.reduction ops. Then the WarpOpReduction pattern kicks in and does the actual distribution to shuffle ops.

WarpOpReduction is free to use any reduction strategy (specified by distributedReductionFn). Currently, by default, it uses the one defined here:
https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp#L566
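
As a rough illustration of the first step of that progressive lowering, the sketch below rewrites a row reduction as one 1-D reduction per row. The names `reduceAdd` and `rowReduceViaReductions` are hypothetical, plain `std::vector` stands in for MLIR vectors, and only the add kind is modeled. How each 1-D reduction is then spread across lanes is a separate concern and is left out here.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Stand-in for a 1-D add reduction that takes an accumulator.
float reduceAdd(const std::vector<float> &v, float acc) {
  return std::accumulate(v.begin(), v.end(), acc);
}

// Row reduction (reduction dim 1) expressed as one 1-D reduction per row:
// extract row i and accumulator element i, reduce, insert at position i.
std::vector<float>
rowReduceViaReductions(const std::vector<std::vector<float>> &src,
                       const std::vector<float> &acc) {
  assert(src.size() == acc.size() && "one accumulator element per row");
  std::vector<float> result(src.size(), 0.0f);
  for (std::size_t i = 0; i < src.size(); ++i)
    result[i] = reduceAdd(src[i], acc[i]);
  return result;
}
```

Keeping the decomposition separate from the cross-lane strategy means the per-row reductions can be distributed by whichever reduction function is plugged in later.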

Contributor

@chencha3 chencha3 left a comment

Overall, it looks good to me. It might be helpful to have someone familiar with the distribution mechanism review it for accuracy.

@charithaintc
Contributor Author

@chencha3 @adam-smnk I have addressed the comments. Please let me know if you have any other concerns. If not, please consider approving the PR.
