bnlearn (version 4.1.1)

naive.bayes: Naive Bayes classifiers

Description

Create, fit and perform predictions with naive Bayes and Tree-Augmented naive Bayes (TAN) classifiers.

Usage

naive.bayes(x, training, explanatory)
# S3 method for bn.naive
predict(object, data, prior, ..., prob = FALSE, debug = FALSE)

tree.bayes(x, training, explanatory, whitelist = NULL, blacklist = NULL,
  mi = NULL, root = NULL, debug = FALSE)
# S3 method for bn.tan
predict(object, data, prior, ..., prob = FALSE, debug = FALSE)

Arguments

training

a character string, the label of the training variable.

explanatory

a vector of character strings, the labels of the explanatory variables.

object

an object of class bn.naive, either fitted or not.

x, data

a data frame containing the variables in the model, which must all be factors.

prior

a numeric vector, the prior distribution for the training variable. It is automatically normalized if not already so. The default prior is the probability distribution of the training variable in object.

whitelist

a data frame with two columns (optionally labeled "from" and "to"), containing a set of arcs to be included in the graph.

blacklist

a data frame with two columns (optionally labeled "from" and "to"), containing a set of arcs not to be included in the graph.

mi

a character string, the estimator used for the mutual information coefficients for the Chow-Liu algorithm in TAN. Possible values are mi (discrete mutual information) and mi-g (Gaussian mutual information).

root

a character string, the label of the explanatory variable to be used as the root of the tree in the TAN classifier.

...

extra arguments from the generic method (currently ignored).

prob

a boolean value. If TRUE the posterior probabilities used for prediction are attached to the predicted values as an attribute called prob.

debug

a boolean value. If TRUE a lot of debugging output is printed; otherwise the function is completely silent.
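
The structure-related arguments of tree.bayes can be sketched as follows. This is only an illustration under assumptions: the choice of B as root and the blacklisted pair B, C are arbitrary, and both directions are blacklisted on the assumption that the Chow-Liu tree is built as an undirected graph before its arcs are oriented.

```r
library(bnlearn)
data(learning.test)

# Illustrative blacklist (an arbitrary choice): keep B and C unconnected.
# Both directions are listed, assuming the tree is undirected when learned.
bl = data.frame(from = c("B", "C"), to = c("C", "B"),
                stringsAsFactors = FALSE)

# Use B as the root of the tree spanning the explanatory variables.
tan = tree.bayes(learning.test, "A", blacklist = bl, root = "B")

arcs(tan)          # arc set of the resulting TAN structure
parents(tan, "B")  # parents of the root node
```

If root behaves as described above, B should have the training variable as its only parent, since it is the root of the tree over the explanatory variables.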

Value

naive.bayes returns an object of class c("bn.naive", "bn"), which behaves like a normal bn object unless passed to predict. tree.bayes returns an object of class c("bn.tan", "bn"), which again behaves like a normal bn object unless passed to predict.
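
A short sketch of what these return values look like in practice, using the learning.test data set shipped with bnlearn. The variable names nb and tan are arbitrary.

```r
library(bnlearn)
data(learning.test)

nb  = naive.bayes(learning.test, "A")
tan = tree.bayes(learning.test, "A")

class(nb)
class(tan)

# Both are still bn objects, so the usual accessors apply to them.
nodes(nb)
nrow(arcs(nb))
nrow(arcs(tan))
```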

predict returns a factor with the same levels as the training variable from data. If prob = TRUE, the posterior probabilities used for prediction are attached to the predicted values as an attribute called prob.
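
As a minimal sketch, the prob attribute can be retrieved with attr(). The prior values below are made up for illustration and deliberately left unnormalized, since predict is documented to normalize them.

```r
library(bnlearn)
data(learning.test)

nb = naive.bayes(learning.test, "A")
fitted = bn.fit(nb, learning.test)

# Illustrative, unnormalized prior over the three levels of A.
pred = predict(fitted, learning.test, prior = c(2, 1, 1), prob = TRUE)

levels(pred)

# Posterior probabilities attached to the predicted values.
post = attr(pred, "prob")
dim(post)
str(post)
```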

Details

The naive.bayes function creates the star-shaped Bayesian network form of a naive Bayes classifier; the training variable (the one holding the group each observation belongs to) is at the center of the star, and it has an outgoing arc for each explanatory variable.
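
The star shape can be checked directly from the arc set. The following is a small sketch using learning.test, with A as the training variable.

```r
library(bnlearn)
data(learning.test)

nb = naive.bayes(learning.test, "A")
a = arcs(nb)
a

# Every arc should leave the training variable...
all(a[, "from"] == "A")
# ...and reach each explanatory variable exactly once.
setequal(a[, "to"], setdiff(names(learning.test), "A"))
nrow(a) == ncol(learning.test) - 1
```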

If data is specified, explanatory will be ignored and the labels of the explanatory variables will be extracted from the data.

predict performs a supervised classification of the observations by assigning them to the group with the maximum posterior probability.
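
To make the maximum posterior rule concrete, here is a hand-rolled sketch in base R. It is not how bnlearn implements predict: the toy data and the helper nb_posterior are hypothetical, and the conditional probabilities are plain relative frequencies with no smoothing.

```r
# Toy data, made up for illustration: one class variable and two factors.
toy = data.frame(
  class = factor(c("y", "y", "y", "n", "n", "n")),
  f1    = factor(c("a", "a", "b", "b", "b", "a")),
  f2    = factor(c("u", "v", "u", "v", "v", "u"))
)

# Hypothetical helper: naive Bayes posterior for a single observation.
nb_posterior = function(data, target, obs) {
  classes = levels(data[[target]])
  score = sapply(classes, function(k) {
    rows = data[data[[target]] == k, , drop = FALSE]
    # Class prior times the product of per-variable conditional frequencies.
    prior = nrow(rows) / nrow(data)
    lik = sapply(names(obs), function(v) mean(rows[[v]] == obs[[v]]))
    prior * prod(lik)
  })
  # Normalize so that the scores sum to one.
  score / sum(score)
}

post = nb_posterior(toy, "class", list(f1 = "a", f2 = "u"))
post

# The observation is assigned to the class with the largest posterior.
names(post)[which.max(post)]
```

For the observation f1 = "a", f2 = "u", the unnormalized scores are 1/2 * 2/3 * 2/3 for class y and 1/2 * 1/3 * 1/3 for class n, so the posterior is 0.8 for y and 0.2 for n, and the observation is assigned to y.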

References

Borgelt C, Kruse R, Steinbrecher M (2009). Graphical Models: Representations for Learning, Reasoning and Data Mining. Wiley, 2nd edition.

Friedman N, Geiger D, Goldszmidt M (1997). "Bayesian Network Classifiers". Machine Learning, 29(2–3), 131–163.

Examples

data(learning.test)
# this is an in-sample prediction with naive Bayes (parameter learning
# is performed implicitly during the prediction).
bn = naive.bayes(learning.test, "A")
pred = predict(bn, learning.test)
table(pred, learning.test[, "A"])

# this is an in-sample prediction with TAN (parameter learning is
# performed explicitly with bn.fit).
tan = tree.bayes(learning.test, "A")
fitted = bn.fit(tan, learning.test, method = "bayes")
pred = predict(fitted, learning.test)
table(pred, learning.test[, "A"])

# this is an out-of-sample prediction, from a training set to a separate
# test set.
training.set = learning.test[1:4000, ]
test.set = learning.test[4001:5000, ]
bn = naive.bayes(training.set, "A")
fitted = bn.fit(bn, training.set)
pred = predict(fitted, test.set)
table(pred, test.set[, "A"])
