-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][vector] Rename gather/scatter arguments (nfc) #153640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Rename gather/scatter arguments (nfc) #153640
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesRenames Full diff: https://github.com/llvm/llvm-project/pull/153640.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c7b83674fb009..4e1661cd1d3a2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2051,8 +2051,8 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
- Variadic<Index>:$indices,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
@@ -2074,11 +2074,11 @@ def Vector_GatherOp :
```mlir
func.func @gather_3D_to_2D(
- %base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
- %index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
+ %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+ %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
%fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
- %result = vector.gather %base[%i0, %i1, %i2]
- [%index_vec], %mask, %fall_thru : [...]
+ %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+ [%indices], %mask, %fall_thru : [...]
return %result : vector<2x3xf32>
}
```
@@ -2086,7 +2086,7 @@ def Vector_GatherOp :
The indexing semantics are then,
```
- result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
+ result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
else pass_thru[i,j]
```
The index into `base` only varies in the innermost ((k-1)-th) dimension.
@@ -2111,16 +2111,16 @@ def Vector_GatherOp :
let extraClassDeclaration = [{
ShapedType getBaseType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getPassThruVectorType() { return getPassThru().getType(); }
VectorType getVectorType() { return getResult().getType(); }
}];
let assemblyFormat =
- "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
- "type($index_vec) `,` type($mask) `,` type($pass_thru) "
+ "type($indices) `,` type($mask) `,` type($pass_thru) "
"`into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -2129,8 +2129,8 @@ def Vector_GatherOp :
def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
- Variadic<Index>:$indices,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore)> {
@@ -2179,15 +2179,15 @@ def Vector_ScatterOp :
let extraClassDeclaration = [{
MemRefType getMemRefType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];
let assemblyFormat =
- "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
- "type($index_vec) `,` type($mask) `,` type($valueToStore)";
+ "type($indices) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f9e2a01dbf969..3c9364270d9c0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -306,11 +306,11 @@ class VectorGatherOpConversion
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
+ base, ptr, adaptor.getIndices(), vType);
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -362,10 +362,10 @@ class VectorScatterOpConversion
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+ adaptor.getBase(), ptr, adaptor.getIndices(), vType);
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 74e48b59b6460..385cc28d0188c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5647,7 +5647,7 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != baseType.getRank())
+ if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
@@ -5719,11 +5719,11 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
if (!isa<MemRefType>(op.getBase().getType()))
return rewriter.notifyMatchFailure(op, "base must be of memref type");
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
- op.getIndices(), op.getMask(),
+ op.getOffsets(), op.getMask(),
op.getPassThru());
return success();
}
@@ -5747,7 +5747,7 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
+ if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
@@ -5782,11 +5782,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
- op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 66196194b0585..546099ca975b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
- gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
gatherOp.getPassThru());
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55f87679..4f97df0c36991 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -64,7 +64,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
Location loc = op.getLoc();
- Value indexVec = op.getIndexVec();
+ Value indexVec = op.getIndices();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
@@ -83,7 +83,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
Value subGather = vector::GatherOp::create(
- rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
+ rewriter, loc, subTy, op.getBase(), op.getOffsets(), indexSubVec,
maskSubVec, passThruSubVec);
result =
vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
@@ -158,18 +158,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
- VectorType vType = op.getIndexVec().getType();
+ VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
Value newIdxs =
- arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+ arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = vector::GatherOp::create(
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
- op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+ op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);
return success();
@@ -212,8 +212,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
- op.getIndexVec());
- auto baseOffsets = llvm::to_vector(op.getIndices());
+ op.getIndices());
+ auto baseOffsets = llvm::to_vector(op.getOffsets());
Value lastBaseOffset = baseOffsets.back();
Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f01a85f1..5617b067d249e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
// Replace the `vector.mask` operation.
rewriter.replaceOpWithNewOp<GatherOp>(
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
- gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
passthru);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abecfacd04..e8ecb0c0be846 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// decomposed shape from each of the index, mask, and pass-through
// vectors.
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+ loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
auto slicedGather = vector::GatherOp::create(
- rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+ rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
indexSubVec, maskSubVec, passThruSubVec);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Scanning through the Vector ops, this seems to be better aligned with other ops now. Maybe not perfect as transfer_read/transfer_write use indices
which might be more like offsets
of other ops, but that's orthogonal to this PR.
Thanks!
Renames `indices` as `offsets` and `index_vec` as `indices`.
a718798
to
798b079
Compare
Renames
indices
asoffsets
andindex_vec
asindices
. This is primarily to make clearer distinction between the arguments.