torch (version 0.8.1)

nn_flatten: Flattens a contiguous range of dims into a tensor.

Description

For use with nn_sequential.
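As a rough sketch of how this fits into nn_sequential, the model below places nn_flatten between a convolutional layer and a linear layer. The layer sizes (4 output channels, a 3 x 3 kernel, 10 output features) are assumptions chosen for illustration and do not come from the package documentation.

```r
library(torch)

# nn_flatten() bridges the 4-D convolutional output and the 2-D input
# that nn_linear() expects.
model <- nn_sequential(
  nn_conv2d(in_channels = 1, out_channels = 4, kernel_size = 3), # (N, 1, 5, 5) -> (N, 4, 3, 3)
  nn_relu(),
  nn_flatten(),                                                  # (N, 4, 3, 3) -> (N, 36)
  nn_linear(in_features = 36, out_features = 10)                 # (N, 36) -> (N, 10)
)

x <- torch_randn(32, 1, 5, 5)
out <- model(x)
out$shape
```

Without the flatten step, the linear layer would receive a 4-D tensor whose last dimension (3) does not match in_features.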

Usage

nn_flatten(start_dim = 2, end_dim = -1)

Arguments

start_dim

First dimension to flatten (default = 2). Dimensions are 1-indexed, so the default leaves the first (batch) dimension untouched.

end_dim

Last dimension to flatten (default = -1, meaning the last dimension of the input).
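The effect of the two arguments can be illustrated with a small sketch. The tensor shape (2, 3, 4, 5) and the variable names are arbitrary choices for the example.

```r
library(torch)

x <- torch_randn(2, 3, 4, 5)

# Default: keep dimension 1, flatten dimensions 2 through the last one.
flat_default <- nn_flatten()(x)
flat_default$shape  # 2 60

# start_dim = 1 flattens everything into a 1-D tensor.
flat_all <- nn_flatten(start_dim = 1)(x)
flat_all$shape  # 120

# Flatten only dimensions 2 and 3; dimensions 1 and 4 are left untouched.
flat_mid <- nn_flatten(start_dim = 2, end_dim = 3)(x)
flat_mid$shape  # 2 12 5
```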

Shape

  • Input: (*, S_start,..., S_i, ..., S_end, *), where S_i is the size at dimension i and * means any number of dimensions including none.

  • Output: (*, S_start*...*S_i*...*S_end, *).
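The shape rule above can be checked numerically. This is a minimal sketch: the input shape and the values of start_dim and end_dim are assumptions picked for the example, and expected is a helper variable introduced here, not part of the package.

```r
library(torch)

x <- torch_randn(2, 3, 4, 5)
start_dim <- 2
end_dim <- 3

in_shape <- x$shape

# Leading dims are kept, the flattened range collapses to its product,
# and trailing dims are kept.
expected <- c(
  in_shape[seq_len(start_dim - 1)],
  prod(in_shape[start_dim:end_dim]),
  in_shape[-seq_len(end_dim)]
)

out <- nn_flatten(start_dim = start_dim, end_dim = end_dim)(x)

# The number of elements is unchanged; only the shape differs.
all(out$shape == expected)
out$numel() == x$numel()
```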

See Also

nn_unflatten

Examples

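library(torch)
if (torch_is_installed()) {
  # A batch of 32 single-channel 5 x 5 inputs
  input <- torch_randn(32, 1, 5, 5)
  # With the defaults, every dimension after the first is flattened
  m <- nn_flatten()
  m(input) # result has shape 32 x 25
}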
