Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2058,39 +2058,52 @@ def Vector_GatherOp :
Results<(outs AnyVectorOfNonZeroRank:$result)> {

let summary = [{
gathers elements from memory or ranked tensor into a vector as defined by an
index vector and a mask vector
Gathers elements from memory or ranked tensor into a vector as defined by an
index vector and a mask vector.
}];

let description = [{
The gather operation returns an n-D vector whose elements are either loaded
from memory or ranked tensor, or taken from a pass-through vector, depending
from a k-D memref or tensor, or taken from an n-D pass-through vector, depending
on the values of an n-D mask vector.
If a mask bit is set, the corresponding result element is defined by the base
with indices and the n-D index vector (each index is a 1-D offset on the base).
Otherwise, the corresponding element is taken from the n-D pass-through vector.
Informally the semantics are:

If a mask bit is set, the corresponding result element is taken from `base`
at an index defined by k indices and n-D `index_vec`. Otherwise, the element
is taken from the pass-through vector. As an example, suppose that `base` is
3-D and the result is 2-D:

```mlir
func.func @gather_3D_to_2D(
%base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
%index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
%fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
%result = vector.gather %base[%i0, %i1, %i2]
[%index_vec], %mask, %fall_thru : [...]
return %result : vector<2x3xf32>
}
```
result[0] := if mask[0] then base[index[0]] else pass_thru[0]
result[1] := if mask[1] then base[index[1]] else pass_thru[1]
etc.

The indexing semantics are then,

```
result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
else pass_thru[i,j]
Comment on lines +2089 to +2090
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be written as:

    result[i,j] := if mask[i,j] then base[i0, i1, i2] + index_vec[i,j]
                   else pass_thru[i,j]

As in, base[i0, i1, i2] provides the base address and then index_vec[i,j] is the "element" index, similarly to how pointer arithmetic works in C.

I wanted to bring it up to make sure that our interpretations are consistent. If that's the case, then I would consider rephrasing:

The index into `base` only varies in the innermost ((k-1)-th) dimension.

(which assumes one interpretation) as

The index vector defines the indices from the base address as defined by the offsets.

This is a bit tricky/nuanced though, as Tensors have no notion of "base address" 😅

Taking a step back, we should probably rename the input arguments as:

  • index -> offsets
  • index_vec -> indices

Have you thought about it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space Thanks for the feedback, and apologies for landing this faster than necessary. Let me know if you think this can be improved further and I'll definitely make a follow-up PR.

With respect to

result[i,j] := if mask[i,j] then base[i0, i1, i2] + index_vec[i,j]
                   else pass_thru[i,j]

I find interpreting base as a pointer less clear.

This is a bit tricky/nuanced though, as Tensors have no notion of "base address" 😅

Exactly!

I'll add that memrefs can be strided (see this test so should strides be included?

Another subtle difference is what 'out of bounds' means. Current lowering ends up as vector.loads of single elements

[...]
%foo = vector.load %base[%i, %j] : memref<100x100xf32>, vector<1xf32>
[...]

There is nothing in the vector.load definition about out-of-bounds, but I assume the natural definition there would be that if %j excedes 99 above, it's out of bounds a UB. Which I think is more inline with the current definition of adding index_vec[i,j] to i2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space Thanks for the feedback, and apologies for landing this faster than necessary.

No worries - this was in review for two days and two reviewers approved it, so it’s totally expected that you landed it. But since post-commit reviews are a thing in LLVM, and this is interesting... 😅

I find interpreting base as a pointer less clear.

Fair enough!

I'll add that memrefs can be strided (see this test so should strides be included?

Hm, not at the vector.gather nor Vector level, no. Are they?

Another subtle difference is what 'out of bounds' means.

UB sounds about right. Masks should take care of "out-of-bounds". If they don't, it would be a UB, yes. Admittedly, we haven't paid that much attention to gathers/scatters - performance is not great and we try to avoid them.

Copy link
Contributor Author

@newling newling Aug 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add that memrefs can be strided (see this test so should strides be included?

Hm, not at the vector.gather nor Vector level, no. Are they?

I meant, in that test (copied below)

%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : 
   memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>

if the explanation of vector.gather was pointer based:

    result[i,j] := if mask[i,j] then base[i0, i1, i2] + index_vec[i,j]
                   else pass_thru[i,j]

then we should probably include a stride in the above. i.e. it should + index_vec[i,j] should be + stride * index_vec[i,j]

This is one reason why I think the pointer-based definition of gather a bit less clear. With the tensor-based definition (just add index_vec[i,j] to another index, not a pointer) this question doesn't come up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for not replying earlier, I was OOO.

then we should probably include a stride in the above. i.e. it should + index_vec[i,j] should be + stride * index_vec[i,j]

I view this differently. To me, base[i0, i1, i2] + index_vec[i,j] is the vector abstraction and that's all we care about here. Later, these vector level indices are interpreted at either the memref or tensor abstraction levels - that's when "stride" would matter. Put differently, I agree that the actual meaning of this will depend on what base is:

base[i0, i1, i2] + index_vec[i,j]

However, to me that should not be a concern at the vector level.

Anyway, this is a side point - it's obviously totally fine to see things differently. Your change is a much appreciated improvement, lets leave it as is.

```
The index into `base` only varies in the innermost ((k-1)-th) dimension.

If a mask bit is set and the corresponding index is out-of-bounds for the
given base, the behavior is undefined. If a mask bit is not set, the value
comes from the pass-through vector regardless of the index, and the index is
allowed to be out-of-bounds.

The gather operation can be used directly where applicable, or can be used
during progressively lowering to bring other memory operations closer to
hardware ISA support for a gather.

Examples:

```mlir
// 1-D memref gathered to 2-D vector.
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
: memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>

// 2-D memref gathered to 1-D vector.
%1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
Expand Down