
Conversation

@banach-space banach-space commented Aug 14, 2025

Renames `indices` to `offsets` and `index_vec` to `indices`. This is primarily to make the distinction between the two arguments clearer.
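To make the renamed arguments concrete, here is a hedged Python sketch (not the MLIR implementation) of the documented `vector.gather` indexing semantics: `offsets` (formerly `indices`) are scalar positions selecting the starting element of `base`, while `indices` (formerly `index_vec`) is the per-element vector added along the innermost dimension.

```python
def gather_3d_to_2d(base, offsets, indices, mask, pass_thru):
    """Model of vector.gather on a rank-3 base producing a rank-2 result.

    offsets:  three scalar offsets into `base` (one per base dimension)
    indices:  2-D list of per-element offsets along the innermost dim
    mask:     2-D list of booleans; masked-off lanes take `pass_thru`
    """
    i0, i1, i2 = offsets
    rows, cols = len(indices), len(indices[0])
    result = [[None] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if mask[i][j]:
                # result[i,j] := base[i0, i1, i2 + indices[i,j]]
                result[i][j] = base[i0][i1][i2 + indices[i][j]]
            else:
                result[i][j] = pass_thru[i][j]
    return result
```

Note that only the innermost index varies per lane; `offsets` contribute one scalar per base dimension, which matches the op's documented semantics.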

llvmbot commented Aug 14, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Renames `indices` to `offsets` and `index_vec` to `indices`.


Full diff: https://github.com/llvm/llvm-project/pull/153640.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+15-15)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+7-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c7b83674fb009..4e1661cd1d3a2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2051,8 +2051,8 @@ def Vector_GatherOp :
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
-               Variadic<Index>:$indices,
-               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               Variadic<Index>:$offsets,
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$pass_thru)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {
@@ -2074,11 +2074,11 @@ def Vector_GatherOp :
 
     ```mlir
     func.func @gather_3D_to_2D(
-        %base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
-        %index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
+        %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+        %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
         %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
-            %result = vector.gather %base[%i0, %i1, %i2]
-                                   [%index_vec], %mask, %fall_thru : [...]
+            %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+                                   [%indices], %mask, %fall_thru : [...]
             return %result : vector<2x3xf32>
     }
     ```
@@ -2086,7 +2086,7 @@ def Vector_GatherOp :
     The indexing semantics are then,
 
     ```
-    result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
+    result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
                    else pass_thru[i,j]
     ```
     The index into `base` only varies in the innermost ((k-1)-th) dimension.
@@ -2111,16 +2111,16 @@ def Vector_GatherOp :
 
   let extraClassDeclaration = [{
     ShapedType getBaseType() { return getBase().getType(); }
-    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getIndexVectorType() { return getIndices().getType(); }
     VectorType getMaskVectorType() { return getMask().getType(); }
     VectorType getPassThruVectorType() { return getPassThru().getType(); }
     VectorType getVectorType() { return getResult().getType(); }
   }];
 
   let assemblyFormat =
-    "$base `[` $indices `]` `[` $index_vec `]` `,` "
+    "$base `[` $offsets `]` `[` $indices `]` `,` "
     "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
-    "type($index_vec)  `,` type($mask) `,` type($pass_thru) "
+    "type($indices)  `,` type($mask) `,` type($pass_thru) "
     "`into` type($result)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -2129,8 +2129,8 @@ def Vector_GatherOp :
 def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
-               Variadic<Index>:$indices,
-               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               Variadic<Index>:$offsets,
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$valueToStore)> {
 
@@ -2179,15 +2179,15 @@ def Vector_ScatterOp :
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() { return getBase().getType(); }
-    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getIndexVectorType() { return getIndices().getType(); }
     VectorType getMaskVectorType() { return getMask().getType(); }
     VectorType getVectorType() { return getValueToStore().getType(); }
   }];
 
   let assemblyFormat =
-      "$base `[` $indices `]` `[` $index_vec `]` `,` "
+      "$base `[` $offsets `]` `[` $indices `]` `,` "
       "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
-      "type($index_vec)  `,` type($mask) `,` type($valueToStore)";
+      "type($indices)  `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f9e2a01dbf969..3c9364270d9c0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -306,11 +306,11 @@ class VectorGatherOpConversion
 
     // Resolve address.
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
-                                     adaptor.getBase(), adaptor.getIndices());
+                                     adaptor.getBase(), adaptor.getOffsets());
     Value base = adaptor.getBase();
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       base, ptr, adaptor.getIndexVec(), vType);
+                       base, ptr, adaptor.getIndices(), vType);
 
     // Replace with the gather intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -362,10 +362,10 @@ class VectorScatterOpConversion
 
     // Resolve address.
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
-                                     adaptor.getBase(), adaptor.getIndices());
+                                     adaptor.getBase(), adaptor.getOffsets());
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+                       adaptor.getBase(), ptr, adaptor.getIndices(), vType);
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 74e48b59b6460..385cc28d0188c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5647,7 +5647,7 @@ LogicalResult GatherOp::verify() {
 
   if (resVType.getElementType() != baseType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != baseType.getRank())
+  if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
@@ -5719,11 +5719,11 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
     if (!isa<MemRefType>(op.getBase().getType()))
       return rewriter.notifyMatchFailure(op, "base must be of memref type");
 
-    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 
     rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
-                                              op.getIndices(), op.getMask(),
+                                              op.getOffsets(), op.getMask(),
                                               op.getPassThru());
     return success();
   }
@@ -5747,7 +5747,7 @@ LogicalResult ScatterOp::verify() {
 
   if (valueVType.getElementType() != memType.getElementType())
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
+  if (llvm::size(getOffsets()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
@@ -5782,11 +5782,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 
     rewriter.replaceOpWithNewOp<MaskedStoreOp>(
-        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 66196194b0585..546099ca975b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
       return failure();
     replaceOpWithNewBufferizedOp<vector::GatherOp>(
         rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
-        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+        gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
         gatherOp.getPassThru());
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55f87679..4f97df0c36991 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -64,7 +64,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
       return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
 
     Location loc = op.getLoc();
-    Value indexVec = op.getIndexVec();
+    Value indexVec = op.getIndices();
     Value maskVec = op.getMask();
     Value passThruVec = op.getPassThru();
 
@@ -83,7 +83,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
       Value passThruSubVec =
           vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
       Value subGather = vector::GatherOp::create(
-          rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
+          rewriter, loc, subTy, op.getBase(), op.getOffsets(), indexSubVec,
           maskSubVec, passThruSubVec);
       result =
           vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
@@ -158,18 +158,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
     // 2. Generate new gather indices that will model the
     // strided access.
     IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
-    VectorType vType = op.getIndexVec().getType();
+    VectorType vType = op.getIndices().getType();
     Value mulCst = arith::ConstantOp::create(
         rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
 
     Value newIdxs =
-        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
 
     // 3. Create an updated gather op with the collapsed input memref and the
     // updated indices.
     Value newGather = vector::GatherOp::create(
         rewriter, op.getLoc(), op.getResult().getType(), collapsed,
-        op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+        op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
     rewriter.replaceOp(op, newGather);
 
     return success();
@@ -212,8 +212,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
-        op.getIndexVec());
-    auto baseOffsets = llvm::to_vector(op.getIndices());
+        op.getIndices());
+    auto baseOffsets = llvm::to_vector(op.getOffsets());
     Value lastBaseOffset = baseOffsets.back();
 
     Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f01a85f1..5617b067d249e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
     // Replace the `vector.mask` operation.
     rewriter.replaceOpWithNewOp<GatherOp>(
         maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
-        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+        gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
         passthru);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abecfacd04..e8ecb0c0be846 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
       // decomposed shape from each of the index, mask, and pass-through
       // vectors.
       Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
       Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
       Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
               loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
               strides);
       auto slicedGather = vector::GatherOp::create(
-          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
           indexSubVec, maskSubVec, passThruSubVec);
 
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
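The `FoldContiguousGather`/`FoldContiguousScatter` patterns touched above fold a gather or scatter into a masked load or store when its `indices` vector is a zero-based contiguous sequence, since such lanes read contiguous memory starting at `offsets`. A minimal Python sketch of that check (the real `isZeroBasedContiguousSeq` operates on constant MLIR vectors; this models only the idea):

```python
def is_zero_based_contiguous_seq(indices):
    """True iff indices is exactly [0, 1, 2, ...] of any length.

    A gather with such indices touches consecutive elements starting
    at the base offsets, so it is equivalent to a masked load there.
    """
    return all(v == i for i, v in enumerate(indices))
```

When the check succeeds, the patterns rebuild the op as `MaskedLoadOp`/`MaskedStoreOp` using `getOffsets()` as the load/store position, which is exactly where the old `getIndices()` call sites had to be renamed.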

@banach-space banach-space changed the title [mlir][vector] Rename gather/scatter arguments [mlir][vector] Rename gather/scatter arguments (nfc) Aug 14, 2025

@newling newling left a comment

LGTM. Scanning through the Vector ops, this now seems better aligned with the other ops. Maybe not perfect, as transfer_read/transfer_write use `indices` in a way that is arguably closer to the `offsets` of other ops, but that's orthogonal to this PR.

Thanks!

Renames `indices` to `offsets` and `index_vec` to `indices`.
@banach-space banach-space force-pushed the andrzej/vector/gather_scatter_rename branch from a718798 to 798b079 Compare August 25, 2025 15:25
@banach-space banach-space merged commit 613ec4c into llvm:main Aug 25, 2025
9 checks passed
@banach-space banach-space deleted the andrzej/vector/gather_scatter_rename branch August 25, 2025 17:59