diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index bcc423a634148..130ed9083848b 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2047,8 +2047,8 @@ def Vector_GatherOp : DeclareOpInterfaceMethods ]>, Arguments<(ins Arg, "", [MemRead]>:$base, - Variadic:$indices, - VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec, + Variadic:$offsets, + VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, AnyVectorOfNonZeroRank:$pass_thru, ConfinedAttr, @@ -2072,11 +2072,11 @@ def Vector_GatherOp : ```mlir func.func @gather_3D_to_2D( - %base: memref, %i0: index, %i1: index, %i2: index, - %index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>, + %base: memref, %ofs_0: index, %ofs_1: index, %ofs_2: index, + %indices: vector<2x3xi32>, %mask: vector<2x3xi1>, %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> { - %result = vector.gather %base[%i0, %i1, %i2] - [%index_vec], %mask, %fall_thru : [...] + %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2] + [%indices], %mask, %fall_thru : [...] return %result : vector<2x3xf32> } ``` @@ -2084,7 +2084,7 @@ def Vector_GatherOp : The indexing semantics are then, ``` - result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]] + result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]] else pass_thru[i,j] ``` The index into `base` only varies in the innermost ((k-1)-th) dimension. @@ -2118,16 +2118,16 @@ def Vector_GatherOp : let extraClassDeclaration = [{ ShapedType getBaseType() { return getBase().getType(); } - VectorType getIndexVectorType() { return getIndexVec().getType(); } + VectorType getIndexVectorType() { return getIndices().getType(); } VectorType getMaskVectorType() { return getMask().getType(); } VectorType getPassThruVectorType() { return getPassThru().getType(); } VectorType getVectorType() { return getResult().getType(); } }]; let assemblyFormat = - "$base `[` $indices `]` `[` $index_vec `]` `,` " + "$base `[` $offsets `]` `[` $indices `]` `,` " "$mask `,` $pass_thru attr-dict `:` type($base) `,` " - "type($index_vec) `,` type($mask) `,` type($pass_thru) " + "type($indices) `,` type($mask) `,` type($pass_thru) " "`into` type($result)"; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -2150,8 +2150,8 @@ def Vector_GatherOp : def Vector_ScatterOp : Vector_Op<"scatter">, Arguments<(ins Arg:$base, - Variadic:$indices, - VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec, + Variadic:$offsets, + VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, AnyVectorOfNonZeroRank:$valueToStore, ConfinedAttr, @@ -2207,15 +2207,15 @@ def Vector_ScatterOp : let extraClassDeclaration = [{ MemRefType getMemRefType() { return getBase().getType(); } - VectorType getIndexVectorType() { return getIndexVec().getType(); } + VectorType getIndexVectorType() { return getIndices().getType(); } VectorType getMaskVectorType() { return getMask().getType(); } VectorType getVectorType() { return getValueToStore().getType(); } }]; let assemblyFormat = - "$base `[` $indices `]` `[` $index_vec `]` `,` " + "$base `[` $offsets `]` `[` $indices `]` `,` " "$mask `,` $valueToStore attr-dict `:` type($base) `,` " - "type($index_vec) `,` type($mask) `,` type($valueToStore)"; + "type($indices) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index afc3d1b12ac0d..1ff7d5dad378e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -306,11 +306,11 @@ class VectorGatherOpConversion // Resolve address. Value ptr = getStridedElementPtr(rewriter, loc, memRefType, - adaptor.getBase(), adaptor.getIndices()); + adaptor.getBase(), adaptor.getOffsets()); Value base = adaptor.getBase(); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - base, ptr, adaptor.getIndexVec(), vType); + base, ptr, adaptor.getIndices(), vType); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp( @@ -362,10 +362,10 @@ class VectorScatterOpConversion // Resolve address. Value ptr = getStridedElementPtr(rewriter, loc, memRefType, - adaptor.getBase(), adaptor.getIndices()); + adaptor.getBase(), adaptor.getOffsets()); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - adaptor.getBase(), ptr, adaptor.getIndexVec(), vType); + adaptor.getBase(), ptr, adaptor.getIndices(), vType); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2b2581d353673..bc93339a68ed3 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5782,7 +5782,7 @@ LogicalResult GatherOp::verify() { if (resVType.getElementType() != baseType.getElementType()) return emitOpError("base and result element type should match"); - if (llvm::size(getIndices()) != baseType.getRank()) + if (llvm::size(getOffsets()) != baseType.getRank()) return emitOpError("requires ") << baseType.getRank() << " indices"; if (resVType.getShape() != indVType.getShape()) return emitOpError("expected result dim to match indices dim"); @@ -5854,11 +5854,11 @@ class FoldContiguousGather final : public OpRewritePattern { if (!isa(op.getBase().getType())) return rewriter.notifyMatchFailure(op, "base must be of memref type"); - if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), op.getBase(), - op.getIndices(), op.getMask(), + op.getOffsets(), op.getMask(), op.getPassThru()); return success(); } @@ -5882,7 +5882,7 @@ LogicalResult ScatterOp::verify() { if (valueVType.getElementType() != memType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getIndices()) != memType.getRank()) + if (llvm::size(getOffsets()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); @@ -5917,11 +5917,11 @@ class FoldContiguousScatter final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { - if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure(); rewriter.replaceOpWithNewOp( - op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore()); + op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore()); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 66196194b0585..546099ca975b7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -162,7 +162,7 @@ struct GatherOpInterface return failure(); replaceOpWithNewBufferizedOp( rewriter, gatherOp, gatherOp.getVectorType(), *buffer, - gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(), + gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(), gatherOp.getPassThru()); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 90f21c53246b0..983018934a85c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -54,7 +54,7 @@ struct UnrollGather : OpRewritePattern { LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { - Value indexVec = op.getIndexVec(); + Value indexVec = op.getIndices(); Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); @@ -69,7 +69,7 @@ struct UnrollGather : OpRewritePattern { Value passThruSubVec = vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), - op.getIndices(), indexSubVec, maskSubVec, + op.getOffsets(), indexSubVec, maskSubVec, passThruSubVec); }; @@ -141,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern { // 2. Generate new gather indices that will model the // strided access. IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim); - VectorType vType = op.getIndexVec().getType(); + VectorType vType = op.getIndices().getType(); Value mulCst = arith::ConstantOp::create( rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride)); Value newIdxs = - arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst); + arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst); // 3. Create an updated gather op with the collapsed input memref and the // updated indices. Value newGather = vector::GatherOp::create( rewriter, op.getLoc(), op.getResult().getType(), collapsed, - op.getIndices(), newIdxs, op.getMask(), op.getPassThru()); + op.getOffsets(), newIdxs, op.getMask(), op.getPassThru()); rewriter.replaceOp(op, newGather); return success(); @@ -195,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern { Value indexVec = rewriter.createOrFold( loc, op.getIndexVectorType().clone(rewriter.getIndexType()), - op.getIndexVec()); - auto baseOffsets = llvm::to_vector(op.getIndices()); + op.getIndices()); + auto baseOffsets = llvm::to_vector(op.getOffsets()); Value lastBaseOffset = baseOffsets.back(); Value result = op.getPassThru(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 45ef7f01a85f1..5617b067d249e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern { // Replace the `vector.mask` operation. rewriter.replaceOpWithNewOp( maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), - gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), + gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(), passthru); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 501abecfacd04..e8ecb0c0be846 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern { // decomposed shape from each of the index, mask, and pass-through // vectors. Value indexSubVec = rewriter.createOrFold( - loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); + loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides); Value maskSubVec = rewriter.createOrFold( loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); Value passThruSubVec = @@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern { loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); auto slicedGather = vector::GatherOp::create( - rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), + rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(), indexSubVec, maskSubVec, passThruSubVec); result = rewriter.createOrFold(