
torch (version 0.8.1)

nn_module: Base class for all neural network modules.

Description

Your own models should also subclass this class. nn_module() does this by default, since its inherit argument defaults to nn_Module.

Usage

nn_module(
  classname = NULL,
  inherit = nn_Module,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame()
)

Arguments

classname

an optional name for the module

inherit

an optional module to inherit from

...

method implementations (for example initialize and forward), given as named functions

private

passed to R6::R6Class().

active

passed to R6::R6Class().

parent_env

passed to R6::R6Class().
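The arguments can be combined as in the following minimal sketch. It assumes torch is installed, and base_net, clamped_net and the private clamp() helper are hypothetical names. The child module passes the base generator to inherit, so it reuses that module's initialize() and overrides forward(). It reaches the parent method through R6's super$ and a private helper through private$.

```r
library(torch)

# Hypothetical base module: a single linear layer
base_net <- nn_module(
  "BaseNet",
  initialize = function(d_in, d_out) {
    self$fc <- nn_linear(d_in, d_out)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Hypothetical child module: reuses BaseNet's initialize() and overrides forward()
clamped_net <- nn_module(
  "ClampedNet",
  inherit = base_net,
  forward = function(x) {
    # super$ is standard R6 access to the parent class's methods
    private$clamp(super$forward(x))
  },
  # private members are forwarded to R6::R6Class() and reached through private$
  private = list(
    clamp = function(x) torch_clamp(x, min = -1, max = 1)
  )
)

m <- clamped_net(3, 2)     # d_in = 3, d_out = 2, handled by the inherited initialize()
y <- m(torch_randn(5, 3))  # 5 x 2 output, every value clamped to [-1, 1]
```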

Initialize

The initialize function is called whenever a new instance of the nn_module is created. We use the initialize function to define the submodules and parameters of the module. For example:

initialize = function(input_size, output_size) {
  self$conv1 <- nn_conv2d(input_size, output_size, 5)
  self$conv2 <- nn_conv2d(output_size, output_size, 5)
}

The initialize function can have any number of parameters. All objects assigned to self$ will be available for other methods that you implement. Tensors wrapped with nn_parameter() or nn_buffer() and submodules are automatically tracked when assigned to self$.
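The following is a minimal sketch of that tracking, assuming torch is installed. scale_shift and its fields are hypothetical names. The parameter and the buffer show up in the module's parameters and buffers fields, while the plain numeric n is stored on self but not tracked. Arguments given to the generator, here n, are passed on to initialize().

```r
library(torch)

# Hypothetical module: a learnable scale plus a fixed offset
scale_shift <- nn_module(
  "ScaleShift",
  initialize = function(n) {
    self$scale <- nn_parameter(torch_ones(n))  # tracked as a trainable parameter
    self$offset <- nn_buffer(torch_zeros(n))   # tracked as module state, not trained
    self$n <- n                                # plain R value: usable, but not tracked
  },
  forward = function(x) {
    x * self$scale + self$offset
  }
)

m <- scale_shift(3)
names(m$parameters)  # "scale"
names(m$buffers)     # "offset"
```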

The initialize function is optional if the module you are defining doesn't have weights, submodules or buffers.
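As an example, the following sketch of a weight-free module only needs forward(). relu_double is a hypothetical name. The instance is created with no arguments and reports no parameters.

```r
library(torch)

# Hypothetical module with no weights, submodules or buffers, so no initialize()
relu_double <- nn_module(
  "ReluDouble",
  forward = function(x) {
    nnf_relu(x) * 2
  }
)

f <- relu_double()                  # no constructor arguments needed
out <- f(torch_tensor(c(-1, 0.5)))  # values 0 and 1
length(f$parameters)                # 0
```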

Forward

The forward method is called whenever an instance of nn_module is called. It usually implements the computation that the module performs with the weights and submodules defined in the initialize function.

For example:

forward = function(input) {
  input <- self$conv1(input)
  input <- nnf_relu(input)
  input <- self$conv2(input)
  input <- nnf_relu(input)
  input
}
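The following sketch puts the initialize and forward snippets above into one module to show the calling convention. conv_net and the input sizes are illustrative choices, and torch is assumed to be installed. Arguments to the generator go to initialize(), and calling the resulting instance on a tensor runs forward().

```r
library(torch)

# The initialize and forward snippets above, combined into one module
conv_net <- nn_module(
  "ConvNet",
  initialize = function(input_size, output_size) {
    self$conv1 <- nn_conv2d(input_size, output_size, 5)
    self$conv2 <- nn_conv2d(output_size, output_size, 5)
  },
  forward = function(input) {
    input <- self$conv1(input)
    input <- nnf_relu(input)
    input <- self$conv2(input)
    input <- nnf_relu(input)
    input
  }
)

net <- conv_net(input_size = 1, output_size = 8)  # arguments go to initialize()
x <- torch_randn(1, 1, 28, 28)                    # one 28x28 single-channel image
y <- net(x)                                       # calling the instance runs forward()
y$shape  # 1 8 20 20: each unpadded 5x5 convolution shrinks height and width by 4
```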

The forward function can use the self$training attribute to perform different computations depending on whether the model is in training mode. You would do this, for example, if you were implementing a dropout module.
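The following is a minimal sketch of that pattern, assuming torch is installed. my_dropout is a hypothetical, hand-written module for illustration only; torch provides nn_dropout() for real use. The $eval() and $train() methods switch self$training, and forward() branches on it.

```r
library(torch)

# Hypothetical hand-written dropout, for illustration only
my_dropout <- nn_module(
  "MyDropout",
  initialize = function(p = 0.5) {
    self$p <- p
  },
  forward = function(x) {
    if (!self$training) {
      return(x)  # evaluation mode: pass the input through unchanged
    }
    # training mode: keep each element with probability 1 - p and rescale survivors
    keep <- (torch_rand_like(x) >= self$p)$to(dtype = x$dtype)
    x * keep / (1 - self$p)
  }
)

d <- my_dropout(p = 0.5)
x <- torch_ones(10)

d$eval()   # sets self$training to FALSE
d(x)       # identical to x
d$train()  # sets self$training back to TRUE
d(x)       # each element is 0 or 2
```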

Details

Modules can also contain other modules, which allows them to be nested in a tree structure. Submodules are assigned as regular attributes.
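The following sketch shows such a tree, with hypothetical block and outer_net modules, and assumes torch is installed. Submodules assigned in initialize() appear in the children field. Their parameters are listed under dotted names that follow the path to each tensor.

```r
library(torch)

# Hypothetical inner module: one linear layer followed by ReLU
block <- nn_module(
  "Block",
  initialize = function(d) {
    self$fc <- nn_linear(d, d)
  },
  forward = function(x) {
    nnf_relu(self$fc(x))
  }
)

# Hypothetical outer module holding two Block instances as submodules
outer_net <- nn_module(
  "OuterNet",
  initialize = function(d) {
    self$block1 <- block(d)
    self$block2 <- block(d)
  },
  forward = function(x) {
    self$block2(self$block1(x))
  }
)

model <- outer_net(4)
names(model$children)    # "block1" "block2"
names(model$parameters)  # e.g. "block1.fc.weight", "block1.fc.bias", ...
out <- model(torch_randn(2, 4))
```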

To create a new nn_module, you typically implement the initialize and forward methods.

Examples

if (torch_is_installed()) {
  model <- nn_module(
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 20, 5)
      self$conv2 <- nn_conv2d(20, 20, 5)
    },
    forward = function(input) {
      input <- self$conv1(input)
      input <- nnf_relu(input)
      input <- self$conv2(input)
      input <- nnf_relu(input)
      input
    }
  )
}
