[mlir][vector] Add support for vector.multi_reduction and vector.shape_cast distribution. #154438
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Charitha Saumya (charithaintc)
Changes
This PR adds support for vector.multi_reduction distribution. The PR also includes changes in vector.shape_cast distribution.
Full diff: https://github.com/llvm/llvm-project/pull/154438.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..6410a895fc9ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
#include <utility>
using namespace mlir;
@@ -977,44 +983,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
/// Pattern to move shape cast out of the warp op. shape cast is basically a
/// no-op for warp distribution; we need to handle the shape though.
struct WarpOpShapeCast : public WarpDistributionPattern {
- using Base::Base;
+
+ WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();
-
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
- auto castDistributedType =
+ VectorType sourceType = oldCastOp.getSourceVectorType();
+ VectorType distributedResultType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
- VectorType castOriginalType = oldCastOp.getSourceVectorType();
- VectorType castResultType = castDistributedType;
-
- // We expect the distributed type to have a smaller rank than the original
- // type. Prepend with size-one dimensions to make them the same.
- unsigned castDistributedRank = castDistributedType.getRank();
- unsigned castOriginalRank = castOriginalType.getRank();
- if (castDistributedRank < castOriginalRank) {
- SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
- llvm::append_range(shape, castDistributedType.getShape());
- castDistributedType =
- VectorType::get(shape, castDistributedType.getElementType());
+ VectorType distributedSourceType = sourceType;
+ bool isResultDistributed = distributedResultType.getNumElements() <
+ oldCastOp.getResultVectorType().getNumElements();
+
+ // If the result is not distributed, the distributed source type is the same
+ // as the source type. If the result is distributed, we need to compute the
+ // distributed source type according to following rules:
+ // 1. If the source type is yielded from the warp op, we can use the
+ // matching warp result type as the distributed source type.
+ // 2. If the source type is not yielded from the warp op, we need
+ // to compute the distributed source type based on the distribution map
+ // and the warp size.
+ if (isResultDistributed) {
+ // Check if the source is yielded from the warp op.
+ gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ auto *it =
+ llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+ return operand.get() == oldCastOp.getSource();
+ });
+
+ if (it != yieldOp->getOpOperands().end()) {
+ // If the source is yielded from the warp op, we can use the matching
+ // warp result type as the distributed source type.
+ distributedSourceType =
+ cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+ } else {
+ // If the source is not yielded from the warp op, we need to compute
+ // the distributed source type based on the distribution map and the
+ // warp size.
+ AffineMap map = distributionMapFn(oldCastOp.getSource());
+ distributedSourceType =
+ getDistributedType(sourceType, map, warpOp.getWarpSize());
+ if (!distributedSourceType)
+ return rewriter.notifyMatchFailure(
+ oldCastOp,
+ "cannot compute distributed source type for shape cast");
+ }
}
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+ rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value newCast = vector::ShapeCastOp::create(
- rewriter, oldCastOp.getLoc(), castResultType,
+ rewriter, oldCastOp.getLoc(), distributedResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
return success();
}
+
+private:
+ DistributionMapFn distributionMapFn;
};
/// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1996,6 +2033,107 @@ struct WarpOpReduction : public WarpDistributionPattern {
DistributedReductionFn distributedReductionFn;
};
+struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
+ VectorMultiDimReductionDistribution(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : WarpDistributionPattern(context, benefit) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+ if (!yieldOperand)
+ return failure();
+ auto reductionOp =
+ cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandNumber = yieldOperand->getOperandNumber();
+ VectorType sourceType = reductionOp.getSourceVectorType();
+ VectorType distributedResultType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ Type elementType = distributedResultType.getElementType();
+ // Only 2D vectors are supported.
+ if (sourceType.getRank() != 2)
+ return rewriter.notifyMatchFailure(warpOp,
+ "Only 2D reductions are supported.");
+ ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+ // Only 1 reduction dimension supported.
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Only 1 reduction dimension is supported.");
+
+ // Create a constant vector to store the result of the reduction per lane.
+ TypedAttr zeroAttr =
+ rewriter.getZeroAttr(distributedResultType.getElementType());
+ Value result = arith::ConstantOp::create(
+ rewriter, reductionOp->getLoc(), distributedResultType,
+ DenseElementsAttr::get(distributedResultType, zeroAttr));
+
+ // Col reduction.
+ if (reductionDims[0] == 0) {
+ // Source vector must be distributable to lanes in the col dimension.
+ if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Source vector dimension must be divisible by warp size.");
+ // Compute source distributed type.
+ SmallVector<int64_t> shape(sourceType.getShape());
+ shape[1] = shape[1] / warpOp.getWarpSize();
+ auto sourceDistributedType = VectorType::get(shape, elementType);
+
+ // Yield the source and acc vectors from the WarpOp.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+ {sourceDistributedType, distributedResultType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ int nCols = sourceDistributedType.getShape()[1];
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ Value acc = newWarpOp.getResult(newRetIndices[1]);
+ // For each column owned by a lane, extract the column (of size nRows x
+ // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
+ // result back to the result vector.
+ for (int i = 0; i < nCols; ++i) {
+ Value col = vector::ExtractStridedSliceOp::create(
+ rewriter, reductionOp.getLoc(), source, {0, i},
+ {sourceDistributedType.getShape()[0], 1}, {1, 1});
+ col = vector::ShapeCastOp::create(
+ rewriter, reductionOp.getLoc(),
+ VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+ col);
+ Value accCol =
+ vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+ Value colReduce = vector::ReductionOp::create(
+ rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+ result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+ colReduce, result, i);
+ }
+ // Replace the warp op result with the new reduction op.
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+ return success();
+ }
+ // For row reductions, we simply rewrite the MultiReductionOp in terms of
+ // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+ // pattern.
+ rewriter.setInsertionPointAfter(reductionOp);
+ int nRows = sourceType.getShape()[0];
+ // For each row of the source, extract the row vector, do a reduction and,
+ // insert the result back to the result.
+ for (int i = 0; i < nRows; ++i) {
+ Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+ reductionOp.getSource(), i);
+ Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+ reductionOp.getAcc(), i);
+ Value rowReduce = vector::ReductionOp::create(
+ rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+ result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+ rowReduce, result, i);
+ }
+ // Replace the warp op result with the final result.
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -2017,15 +2155,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
patterns
- .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+ .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
+ WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+ WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+ WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
- patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
- benefit);
+ patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(),
+ distributionMapFn, benefit);
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..bf70fbbd27244 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
return %r : f32
}
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<32x64xf32>)
+ %acc = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+ gpu.yield %1 : vector<64xf32>
+ }
+ return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+ %zero = arith.constant dense<0.0> : vector<2xf32>
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<2x32xf32>)
+ %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+ gpu.yield %1 : vector<2xf32>
+ }
+ return %r : vector<2xf32>
+}
+
// -----
// CHECK-PROP-LABEL: func @warp_duplicate_yield(
@@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
// CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
// CHECK-PROP: return %[[CAST]] : vector<4xf32>
+// -----
+func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
+ %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
+ %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
+ gpu.yield %3 : vector<32x64xf32>
+ }
+ return %r : vector<32x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<32x2xf32>
+
+// -----
+func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
+ %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
+ %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
+ gpu.yield %3 : vector<8x4x2xf32>
+ }
+ return %r : vector<8x4x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<8x4x2xf32>
+
// -----
func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
cc @Garra1980
General direction looks good to me.
I'll let others have a closer look at the distribution logic.
Hi @adam-smnk, thanks for the reviews. I have addressed the concerns and properly documented the restrictions in this version. Please let me know if you have any additional concerns.
// distributed source type according to following rules:
// 1. If the source type is yielded from the warp op, we can use the
// matching warp result type as the distributed source type.
// 2. If the source type is not yielded from the warp op, we need
Could you give an example? Can these two rules ever conflict?
I added two examples, one for each of the two cases, as comments.
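The two rules can be sketched in plain C++ on shapes alone. This is a simplified model, not the MLIR implementation: Shape, distributeShape, and pickDistributedSourceShape are hypothetical names, and the distribution map is assumed to select a single dimension that the warp size divides evenly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in for rule 2: divide the dimension selected by the
// distribution map by the warp size. Returns std::nullopt when that is not
// possible, mirroring the match failure the pattern reports when no
// distributed source type can be computed.
std::optional<Shape> distributeShape(Shape shape, std::size_t distributedDim,
                                     int64_t warpSize) {
  assert(warpSize > 0 && "warp size must be positive");
  if (distributedDim >= shape.size() || shape[distributedDim] % warpSize != 0)
    return std::nullopt;
  shape[distributedDim] /= warpSize;
  return shape;
}

// Rule 1 is checked first: if the shape cast source is itself yielded by the
// warp op, reuse the shape of the matching warp result. Otherwise fall back
// to rule 2.
std::optional<Shape>
pickDistributedSourceShape(const Shape &sourceShape,
                           const std::optional<Shape> &yieldedResultShape,
                           std::size_t distributedDim, int64_t warpSize) {
  if (yieldedResultShape)
    return yieldedResultShape;
  return distributeShape(sourceShape, distributedDim, warpSize);
}
```

With a warp size of 32 and a map that selects dimension 0, a 64x32 source becomes 2x32 per lane, which matches the vector<2x32xf32> read expected by the warp_propagate_shape_cast_2d_to_2d test.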
For both row and column reductions we assume that each lane owns columns of the source vector (in xegpu layout terms this is [1, 16]). The reduction logic follows from that layout:
Column reduction is easy: each lane just reduces its own data.
Row reduction needs to shuffle data with neighboring lanes and do a tree-like reduce (a butterfly reduction with shuffles).
However, the source layout can also be [16, 1]. This case is not supported because the vector distribution infrastructure currently cannot express such a layout (it always starts distributing the vector from the innermost dimension). I am working on a proposal to improve this.
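To make the column case concrete, here is a small host-side sketch of what one lane computes. It is an illustration only: laneColumnReduce and Matrix are hypothetical names, the reduction kind is fixed to add, and the blocked assignment of columns to lanes is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Add-reduce the columns owned by `lane` over all rows, starting from the
// matching accumulator entries. Assumes (hypothetically) a blocked layout in
// which lane l owns columns [l * k, (l + 1) * k) with k = cols / warpSize.
std::vector<float> laneColumnReduce(const Matrix &source,
                                    const std::vector<float> &acc,
                                    std::size_t lane, std::size_t warpSize) {
  assert(!source.empty() && warpSize > 0 && lane < warpSize);
  const std::size_t cols = source.front().size();
  assert(cols == acc.size() && cols % warpSize == 0 &&
         "columns must be divisible by the warp size");
  const std::size_t colsPerLane = cols / warpSize;
  std::vector<float> result(colsPerLane, 0.0f);
  for (std::size_t i = 0; i < colsPerLane; ++i) {
    const std::size_t col = lane * colsPerLane + i;
    float sum = acc[col];
    // No cross-lane traffic: every element of this column lives in this lane.
    for (const std::vector<float> &row : source)
      sum += row[col];
    result[i] = sum;
  }
  return result;
}
```

Each lane returns only its own slice of the result, so the result stays distributed, as in the vector<64xf32> to vector<2xf32> column test in this PR.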
// case each lane owns its portion of the result (i.e. result is also
// distributed).
// 3. If reduction dim == 1, its a row reduction that require cross lanes
// shuffles. In this case result is not distributed and broadcasted instead.
not distributed but broadcasted instead?
Yes. For row reductions each lane owns the whole reduced result. Example:
%r = vector.multi_reduction <add>, %src, %acc [1] : vector<8x16xf32> to vector<8xf32>
In this case each lane owns one column of the 8x16xf32 source (i.e. 8x1xf32). Each lane does a shuffle followed by an add log2(16) = 4 times to get the final result. At the end every lane holds the final 8x1xf32 result, meaning that the result is effectively broadcast to each lane.
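The shuffle-and-add steps can be simulated on the host for a single row. This is a sketch of the general XOR-shuffle (butterfly) technique, not the code the pattern emits; butterflyAllReduce is a hypothetical name and the lane values are modeled as one array.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simulates an XOR-shuffle (butterfly) add-reduction across lanes.
// lanes[i] is the value held by lane i; the lane count must be a power of two.
std::vector<float> butterflyAllReduce(std::vector<float> lanes) {
  const std::size_t n = lanes.size();
  assert(n > 0 && (n & (n - 1)) == 0 && "lane count must be a power of two");
  for (std::size_t offset = 1; offset < n; offset <<= 1) {
    // Shuffle: every lane reads the value of its partner lane (lane ^ offset).
    std::vector<float> shuffled(n);
    for (std::size_t lane = 0; lane < n; ++lane)
      shuffled[lane] = lanes[lane ^ offset];
    // Add: every lane combines its own value with the shuffled one.
    for (std::size_t lane = 0; lane < n; ++lane)
      lanes[lane] += shuffled[lane];
  }
  return lanes;
}
```

For 16 lanes the loop runs four times (offsets 1, 2, 4, 8), and lanes holding 1 through 16 all end up with 136, which is why the result counts as broadcast rather than distributed.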
Sorry for the confusion. I meant that the comment should read "not distributed but broadcasted instead" rather than "not distributed and broadcasted instead".
reductionOp.getSource(), i);
Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
reductionOp.getAcc(), i);
Value rowReduce = vector::ReductionOp::create(
I am curious: how is the ShuffleOp inserted?
This is lowered progressively. Here we lower it to a bunch of vector.reduction ops. Then the WarpOpReduction pattern kicks in and does the actual distribution to shuffle ops. WarpOpReduction is free to use any reduction strategy (specified by distributedReductionFn). Currently it uses by default the one defined here:
https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp#L566
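The two stages can be modeled on the host as follows. This is only a sketch of the idea: rowReduce, ReductionFn, and sumReduce are hypothetical names, with ReductionFn standing in for whatever strategy distributedReductionFn supplies and sumReduce being a plain sequential sum chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

using Matrix = std::vector<std::vector<float>>;
// Hypothetical stand-in for a pluggable 1D reduction strategy.
using ReductionFn = std::function<float(const std::vector<float> &)>;

// Illustrative strategy: a plain sequential sum. A distributed strategy would
// use cross-lane shuffles instead.
float sumReduce(const std::vector<float> &values) {
  return std::accumulate(values.begin(), values.end(), 0.0f);
}

// Stage 1: unroll the 2D row reduction into one 1D reduction per row.
// Stage 2: hand each 1D reduction to the strategy, then add the accumulator.
std::vector<float> rowReduce(const Matrix &source,
                             const std::vector<float> &acc,
                             const ReductionFn &reduceFn) {
  assert(source.size() == acc.size() && "one accumulator entry per row");
  std::vector<float> result(source.size(), 0.0f);
  for (std::size_t i = 0; i < source.size(); ++i)
    result[i] = reduceFn(source[i]) + acc[i];
  return result;
}
```

Keeping the strategy behind a function parameter mirrors the separation described above: the unrolling step does not need to know how each 1D reduction is eventually distributed.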
Overall, it looks good to me. It might be helpful to have someone familiar with the distribution mechanism review it for accuracy.
@chencha3 @adam-smnk I have addressed the comments. Please let me know if there are any other concerns. If not, please consider approving the PR.
This PR adds support for vector.multi_reduction distribution. Currently only 2D to 1D reductions are supported (column/row reductions), and the pattern assumes that the inner dimension of the source vector is distributed among lanes (each lane owns columns of the source vector).
For column reductions, each lane does a vector.reduction of the column data it owns.
For row reductions, the pattern rewrites multi_reduction in terms of reduction ops.
The PR also includes changes in vector.shape_cast distribution to consider the distributed type of the shape cast source if given by a DistributionMapFn.