Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by::

    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
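The three formulas amount to a loop over every position (i, j, k) of index: the value stored in index replaces the coordinate named by dim, and the other two coordinates are copied unchanged. The sketch below is a hypothetical pure-Python reference (gather_3d is an illustrative name, not part of any library) that works on rectangular nested lists instead of tensors and assumes every entry of index is a valid non-negative position along dim.

```python
from typing import List

Nested3D = List[List[List[int]]]


def gather_3d(inp: Nested3D, dim: int, index: Nested3D) -> Nested3D:
    """Hypothetical reference for 3-D gather on nested lists (not a library API)."""
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D input")
    out: Nested3D = []
    # Iterate over index, so the output has exactly the shape of index.
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, idx in enumerate(row):
                # idx replaces the coordinate selected by dim.
                if dim == 0:
                    out_row.append(inp[idx][j][k])
                elif dim == 1:
                    out_row.append(inp[i][idx][k])
                else:
                    out_row.append(inp[i][j][idx])
            out_plane.append(out_row)
        out.append(out_plane)
    return out


# Gather along the last axis of a 1 x 2 x 2 input.
print(gather_3d([[[1, 2], [3, 4]]], 2, [[[0, 0], [1, 0]]]))  # [[[1, 1], [4, 3]]]
```

Because the loops run over index rather than input, repeating a position (as in [0, 0]) simply copies the same element twice.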
If input is an \(n\)-dimensional tensor with size \((x_0, x_1, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = i, then index must be an \(n\)-dimensional tensor with size \((x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\) where \(y \geq 1\), and out will have the same size as index.
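The shape rule can be checked mechanically: index must match input in every dimension except dim, where its size y only needs to be at least 1, and the result takes the shape of index. The helper below is a hypothetical sketch (gather_output_shape is not a PyTorch function) that encodes only the rule as stated here; it does not claim that any particular library version enforces exactly this check.

```python
from typing import Sequence, Tuple


def gather_output_shape(input_shape: Sequence[int], dim: int,
                        index_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical check of index_shape against the rule; returns out's shape."""
    n = len(input_shape)
    if len(index_shape) != n:
        raise ValueError("index must have the same number of dimensions as input")
    if not 0 <= dim < n:
        raise ValueError("dim out of range")
    for d, (x_d, size_d) in enumerate(zip(input_shape, index_shape)):
        if d == dim:
            # Along dim, the index size y is free as long as y >= 1.
            if size_d < 1:
                raise ValueError("index size along dim must be at least 1")
        elif size_d != x_d:
            raise ValueError(f"index size {size_d} != input size {x_d} at dim {d}")
    # out has the same size as index.
    return tuple(index_shape)
```

For example, with an input of size (4, 5, 6) and dim = 1, an index of size (4, 2, 6) is accepted and yields an output of size (4, 2, 6), while an index of size (4, 2, 3) is rejected because it differs from the input in dimension 2.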