Skip to content

Commit 613ec4c

Browse files
authored
[mlir][vector] Rename gather/scatter arguments (nfc) (#153640)
Renames `indices` as `offsets` and `index_vec` as `indices`. This is primarily to make clearer distinction between the arguments.
1 parent 9899567 commit 613ec4c

File tree

7 files changed

+36
-36
lines changed

7 files changed

+36
-36
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,8 +2047,8 @@ def Vector_GatherOp :
20472047
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
20482048
]>,
20492049
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
2050-
Variadic<Index>:$indices,
2051-
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
2050+
Variadic<Index>:$offsets,
2051+
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
20522052
VectorOfNonZeroRankOf<[I1]>:$mask,
20532053
AnyVectorOfNonZeroRank:$pass_thru,
20542054
ConfinedAttr<OptionalAttr<I64Attr>,
@@ -2072,19 +2072,19 @@ def Vector_GatherOp :
20722072

20732073
```mlir
20742074
func.func @gather_3D_to_2D(
2075-
%base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
2076-
%index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
2075+
%base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
2076+
%indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
20772077
%fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
2078-
%result = vector.gather %base[%i0, %i1, %i2]
2079-
[%index_vec], %mask, %fall_thru : [...]
2078+
%result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
2079+
[%indices], %mask, %fall_thru : [...]
20802080
return %result : vector<2x3xf32>
20812081
}
20822082
```
20832083

20842084
The indexing semantics are then,
20852085

20862086
```
2087-
result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
2087+
result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
20882088
else pass_thru[i,j]
20892089
```
20902090
The index into `base` only varies in the innermost ((k-1)-th) dimension.
@@ -2118,16 +2118,16 @@ def Vector_GatherOp :
21182118

21192119
let extraClassDeclaration = [{
21202120
ShapedType getBaseType() { return getBase().getType(); }
2121-
VectorType getIndexVectorType() { return getIndexVec().getType(); }
2121+
VectorType getIndexVectorType() { return getIndices().getType(); }
21222122
VectorType getMaskVectorType() { return getMask().getType(); }
21232123
VectorType getPassThruVectorType() { return getPassThru().getType(); }
21242124
VectorType getVectorType() { return getResult().getType(); }
21252125
}];
21262126

21272127
let assemblyFormat =
2128-
"$base `[` $indices `]` `[` $index_vec `]` `,` "
2128+
"$base `[` $offsets `]` `[` $indices `]` `,` "
21292129
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
2130-
"type($index_vec) `,` type($mask) `,` type($pass_thru) "
2130+
"type($indices) `,` type($mask) `,` type($pass_thru) "
21312131
"`into` type($result)";
21322132
let hasCanonicalizer = 1;
21332133
let hasVerifier = 1;
@@ -2150,8 +2150,8 @@ def Vector_GatherOp :
21502150
def Vector_ScatterOp :
21512151
Vector_Op<"scatter">,
21522152
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
2153-
Variadic<Index>:$indices,
2154-
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
2153+
Variadic<Index>:$offsets,
2154+
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
21552155
VectorOfNonZeroRankOf<[I1]>:$mask,
21562156
AnyVectorOfNonZeroRank:$valueToStore,
21572157
ConfinedAttr<OptionalAttr<I64Attr>,
@@ -2207,15 +2207,15 @@ def Vector_ScatterOp :
22072207

22082208
let extraClassDeclaration = [{
22092209
MemRefType getMemRefType() { return getBase().getType(); }
2210-
VectorType getIndexVectorType() { return getIndexVec().getType(); }
2210+
VectorType getIndexVectorType() { return getIndices().getType(); }
22112211
VectorType getMaskVectorType() { return getMask().getType(); }
22122212
VectorType getVectorType() { return getValueToStore().getType(); }
22132213
}];
22142214

22152215
let assemblyFormat =
2216-
"$base `[` $indices `]` `[` $index_vec `]` `,` "
2216+
"$base `[` $offsets `]` `[` $indices `]` `,` "
22172217
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
2218-
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
2218+
"type($indices) `,` type($mask) `,` type($valueToStore)";
22192219
let hasCanonicalizer = 1;
22202220
let hasVerifier = 1;
22212221

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,11 @@ class VectorGatherOpConversion
306306

307307
// Resolve address.
308308
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
309-
adaptor.getBase(), adaptor.getIndices());
309+
adaptor.getBase(), adaptor.getOffsets());
310310
Value base = adaptor.getBase();
311311
Value ptrs =
312312
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
313-
base, ptr, adaptor.getIndexVec(), vType);
313+
base, ptr, adaptor.getIndices(), vType);
314314

315315
// Replace with the gather intrinsic.
316316
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -362,10 +362,10 @@ class VectorScatterOpConversion
362362

363363
// Resolve address.
364364
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
365-
adaptor.getBase(), adaptor.getIndices());
365+
adaptor.getBase(), adaptor.getOffsets());
366366
Value ptrs =
367367
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
368-
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
368+
adaptor.getBase(), ptr, adaptor.getIndices(), vType);
369369

370370
// Replace with the scatter intrinsic.
371371
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5782,7 +5782,7 @@ LogicalResult GatherOp::verify() {
57825782

57835783
if (resVType.getElementType() != baseType.getElementType())
57845784
return emitOpError("base and result element type should match");
5785-
if (llvm::size(getIndices()) != baseType.getRank())
5785+
if (llvm::size(getOffsets()) != baseType.getRank())
57865786
return emitOpError("requires ") << baseType.getRank() << " indices";
57875787
if (resVType.getShape() != indVType.getShape())
57885788
return emitOpError("expected result dim to match indices dim");
@@ -5854,11 +5854,11 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
58545854
if (!isa<MemRefType>(op.getBase().getType()))
58555855
return rewriter.notifyMatchFailure(op, "base must be of memref type");
58565856

5857-
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5857+
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
58585858
return failure();
58595859

58605860
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5861-
op.getIndices(), op.getMask(),
5861+
op.getOffsets(), op.getMask(),
58625862
op.getPassThru());
58635863
return success();
58645864
}
@@ -5882,7 +5882,7 @@ LogicalResult ScatterOp::verify() {
58825882

58835883
if (valueVType.getElementType() != memType.getElementType())
58845884
return emitOpError("base and valueToStore element type should match");
5885-
if (llvm::size(getIndices()) != memType.getRank())
5885+
if (llvm::size(getOffsets()) != memType.getRank())
58865886
return emitOpError("requires ") << memType.getRank() << " indices";
58875887
if (valueVType.getShape() != indVType.getShape())
58885888
return emitOpError("expected valueToStore dim to match indices dim");
@@ -5917,11 +5917,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
59175917
using OpRewritePattern::OpRewritePattern;
59185918
LogicalResult matchAndRewrite(ScatterOp op,
59195919
PatternRewriter &rewriter) const override {
5920-
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5920+
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
59215921
return failure();
59225922

59235923
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5924-
op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5924+
op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
59255925
return success();
59265926
}
59275927
};

mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ struct GatherOpInterface
162162
return failure();
163163
replaceOpWithNewBufferizedOp<vector::GatherOp>(
164164
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
165-
gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
165+
gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
166166
gatherOp.getPassThru());
167167
return success();
168168
}

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
5454

5555
LogicalResult matchAndRewrite(vector::GatherOp op,
5656
PatternRewriter &rewriter) const override {
57-
Value indexVec = op.getIndexVec();
57+
Value indexVec = op.getIndices();
5858
Value maskVec = op.getMask();
5959
Value passThruVec = op.getPassThru();
6060

@@ -69,7 +69,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
6969
Value passThruSubVec =
7070
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
7171
return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
72-
op.getIndices(), indexSubVec, maskSubVec,
72+
op.getOffsets(), indexSubVec, maskSubVec,
7373
passThruSubVec);
7474
};
7575

@@ -141,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
141141
// 2. Generate new gather indices that will model the
142142
// strided access.
143143
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
144-
VectorType vType = op.getIndexVec().getType();
144+
VectorType vType = op.getIndices().getType();
145145
Value mulCst = arith::ConstantOp::create(
146146
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
147147

148148
Value newIdxs =
149-
arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
149+
arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
150150

151151
// 3. Create an updated gather op with the collapsed input memref and the
152152
// updated indices.
153153
Value newGather = vector::GatherOp::create(
154154
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
155-
op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
155+
op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
156156
rewriter.replaceOp(op, newGather);
157157

158158
return success();
@@ -195,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
195195

196196
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
197197
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
198-
op.getIndexVec());
199-
auto baseOffsets = llvm::to_vector(op.getIndices());
198+
op.getIndices());
199+
auto baseOffsets = llvm::to_vector(op.getOffsets());
200200
Value lastBaseOffset = baseOffsets.back();
201201

202202
Value result = op.getPassThru();

mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
269269
// Replace the `vector.mask` operation.
270270
rewriter.replaceOpWithNewOp<GatherOp>(
271271
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
272-
gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
272+
gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
273273
passthru);
274274
return success();
275275
}

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,15 +640,15 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
640640
// decomposed shape from each of the index, mask, and pass-through
641641
// vectors.
642642
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
643-
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
643+
loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
644644
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
645645
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
646646
Value passThruSubVec =
647647
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
648648
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
649649
strides);
650650
auto slicedGather = vector::GatherOp::create(
651-
rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
651+
rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
652652
indexSubVec, maskSubVec, passThruSubVec);
653653

654654
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(

0 commit comments

Comments
 (0)