torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
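The rule above can be checked element by element. The following is a minimal sketch for the dim == 1 case; the tensor shapes and the names inp, index, and expected are illustrative choices, not part of the API.

```python
import torch

torch.manual_seed(0)

# Illustrative 3-D input of shape (2, 3, 4).
inp = torch.arange(24).reshape(2, 3, 4)
# Index values select along dim=1, so they are drawn from [0, 3).
index = torch.randint(0, 3, (2, 2, 4))

dim1_out = torch.gather(inp, 1, index)

# Reference result: apply the dim == 1 rule to every output position.
expected = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            expected[i, j, k] = inp[i, index[i, j, k], k]

formula_matches = torch.equal(dim1_out, expected)
print(formula_matches)
```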

input and index must have the same number of dimensions, and index.size(d) <= input.size(d) must hold for every dimension d != dim. out has the same shape as index. Note that input and index do not broadcast against each other. When index is empty, the result is an empty tensor with the same shape as index, and no further error checking is performed.
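A short sketch of these shape rules, using made-up tensors (src, idx, and empty_idx are illustrative names): the index may be smaller than the input in the non-gathered dimension, and the output always takes the shape of the index.

```python
import torch

src = torch.arange(12).reshape(3, 4)          # shape (3, 4)
# Gathering along dim=1: idx has 2 rows (<= 3 rows of src),
# and its values index into the 4 columns of src.
idx = torch.tensor([[0, 3, 1],
                    [2, 2, 0]])               # shape (2, 3)

picked = torch.gather(src, 1, idx)
print(picked.shape)                           # same shape as idx

# An empty index yields an empty output of the same shape.
empty_idx = torch.empty((0, 3), dtype=torch.long)
empty_out = torch.gather(src, 1, empty_idx)
print(empty_out.shape)
```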

Parameters
  • input (Tensor) – the source tensor

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to gather

Keyword Arguments
  • sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.

  • out (Tensor, optional) – the destination tensor
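As a hedged illustration of sparse_grad, the sketch below gathers from a small floating-point tensor and inspects the layout of the resulting gradient. The values and variable names are illustrative; the dense view is used only to make the gradient easy to read.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
index = torch.tensor([[0, 0], [1, 0]])

# With sparse_grad=True, the gradient w.r.t. x is a sparse tensor.
gathered = torch.gather(x, 1, index, sparse_grad=True)
gathered.sum().backward()

grad_is_sparse = x.grad.is_sparse
# Each entry counts how many times that element of x was gathered.
dense_grad = x.grad.to_dense()
print(grad_is_sparse)
print(dense_grad)
```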

Example:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
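One common pattern, sketched here under assumed data, is picking a single value per row. Because index must have the same number of dimensions as input, a 1-D tensor of column choices needs an extra axis first. The names scores and chosen are hypothetical.

```python
import torch

scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
chosen = torch.tensor([1, 0])  # hypothetical column choice for each row

# Add a trailing axis so the index is 2-D like scores, then drop it again.
per_row = torch.gather(scores, 1, chosen.unsqueeze(1)).squeeze(1)
print(per_row)
```

The same values could also be obtained with advanced indexing, scores[torch.arange(2), chosen]; gather keeps the selection expressed as a single call along one axis.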