neuralnet (version 1.44.2)

predict.nn: Neural network prediction

Description

Computes predictions from an artificial neural network of class nn, as produced by neuralnet().

Usage

# S3 method for nn
predict(object, newdata, rep = 1, all.units = FALSE, ...)

Arguments

object

Neural network of class nn.

newdata

New data of class data.frame or matrix, containing the covariates the network was trained on.

rep

Integer index of the neural network repetition to use for prediction.

all.units

Logical. If TRUE, return the output of all units instead of the final output only.

...

Further arguments passed to or from other methods.
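
As a minimal sketch of how rep selects among trained repetitions: the code below assumes the network was fit with rep = 3 in neuralnet(), and reads the number of available repetitions from nn$weights, since repetitions that fail to converge may not be kept. The seed, the variable names n_rep, pred_first and pred_last, and the choice of iris columns are illustrative assumptions, not part of the package documentation.

```r
library(neuralnet)

set.seed(1)  # arbitrary seed, only so this sketch is repeatable

# Train three repetitions of the same small network
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width,
                iris, linear.output = FALSE, rep = 3)

# Number of repetitions actually stored in the fitted object
n_rep <- length(nn$weights)

# Predict with the first and with the last stored repetition
pred_first <- predict(nn, iris, rep = 1)
pred_last  <- predict(nn, iris, rep = n_rep)

# Each result has one row per observation and one column per output unit
dim(pred_first)
dim(pred_last)
```

Because each repetition starts from different random weights, predictions from different repetitions will generally differ slightly.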

Value

Matrix of predictions, with one row per observation in newdata and one column per output unit. If all.units = TRUE, a list of matrices containing the output of every unit is returned instead.
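
The two return shapes can be compared with a short sketch. It assumes a single-output network trained on the full iris data with an arbitrary seed; the names pred and units are illustrative. In the package version inspected for this sketch, the last list element appears to hold the final output, but treat that ordering as an assumption to verify against your installed version.

```r
library(neuralnet)

set.seed(1)  # arbitrary seed, only so this sketch is repeatable

nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width,
                iris, linear.output = FALSE)

# Default: a matrix with one column per output unit
pred <- predict(nn, iris)

# all.units = TRUE: a list of matrices instead of a single matrix
units <- predict(nn, iris, all.units = TRUE)

is.matrix(pred)
is.list(units)

# Number of columns (units) in each list element
sapply(units, ncol)

# Compare the last list element with the default prediction
all.equal(units[[length(units)]], pred, check.attributes = FALSE)
```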

Examples

# NOT RUN {
library(neuralnet)

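# Split data into 2/3 training and 1/3 test rows
# (the seed value is arbitrary and only makes the split repeatable)
set.seed(42)
train_idx <- sample(nrow(iris), 2/3 * nrow(iris))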
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]

# Binary classification: setosa vs. the rest, thresholding the output at 0.5
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
table(iris_test$Species == "setosa", pred[, 1] > 0.5)

# Multiclass classification: one output unit per species, pick the largest
nn <- neuralnet((Species == "setosa") + (Species == "versicolor") + (Species == "virginica")
                 ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
table(iris_test$Species, apply(pred, 1, which.max))

# }
