MXNet R Tutorial on Callback Functions
This vignette gives users a guideline for using and writing callback functions, which can be very useful in model training.
This tutorial is written in Rmarkdown.
- You can view the hosted version of the tutorial in the MXNet R Document.
- You can find the Rmarkdown source here.
Model training example
Let’s begin with a small example. We can build and train a model using the following code:
library(mxnet)
data(BostonHousing, package="mlbench")
train.ind = seq(1, 506, 3)
train.x = data.matrix(BostonHousing[train.ind, -14])
train.y = BostonHousing[train.ind, 14]
test.x = data.matrix(BostonHousing[-train.ind, -14])
test.y = BostonHousing[-train.ind, 14]
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=1)
lro <- mx.symbol.LinearRegressionOutput(fc1)
mx.set.seed(0)
model <- mx.model.FeedForward.create(
lro, X=train.x, y=train.y,
eval.data=list(data=test.x, label=test.y),
ctx=mx.cpu(), num.round=10, array.batch.size=20,
learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse)
## Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-rmse=16.063282524034
## [1] Validation-rmse=10.1766446093622
## [2] Train-rmse=12.2792375712573
## [2] Validation-rmse=12.4331776190813
## [3] Train-rmse=11.1984634005885
## [3] Validation-rmse=10.3303041888193
## [4] Train-rmse=10.2645236892904
## [4] Validation-rmse=8.42760407903415
## [5] Train-rmse=9.49711005504284
## [5] Validation-rmse=8.44557808483234
## [6] Train-rmse=9.07733734175182
## [6] Validation-rmse=8.33225500266177
## [7] Train-rmse=9.07884450847991
## [7] Validation-rmse=8.38827833418459
## [8] Train-rmse=9.10463850277417
## [8] Validation-rmse=8.37394452365264
## [9] Train-rmse=9.03977049028532
## [9] Validation-rmse=8.25927979725672
## [10] Train-rmse=8.96870685004475
## [10] Validation-rmse=8.19509291481822
In addition, we provide two optional parameters, batch.end.callback and epoch.end.callback, which offer great flexibility in model training.
How to use callback functions
Two callback functions are provided in this package:
mx.callback.save.checkpoint saves a checkpoint of the model to file every period iterations.
model <- mx.model.FeedForward.create(
lro, X=train.x, y=train.y,
eval.data=list(data=test.x, label=test.y),
ctx=mx.cpu(), num.round=10, array.batch.size=20,
learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
epoch.end.callback = mx.callback.save.checkpoint("boston"))
## Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-rmse=19.1621424021617
## [1] Validation-rmse=20.721515592165
## Model checkpoint saved to boston-0001.params
## [2] Train-rmse=13.5127391952367
## [2] Validation-rmse=14.1822123675007
## Model checkpoint saved to boston-0002.params
............
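A checkpoint written this way can be loaded back later with mx.model.load. A minimal sketch, assuming the checkpoint files written by the callback above (e.g. boston-0002.params) are in the working directory:
# Restore the model checkpoint saved after epoch 2
model <- mx.model.load("boston", iteration = 2)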
mx.callback.log.train.metric logs the training metric every period batches or epochs. You can use it either as a batch.end.callback or an epoch.end.callback.
model <- mx.model.FeedForward.create(
lro, X=train.x, y=train.y,
eval.data=list(data=test.x, label=test.y),
ctx=mx.cpu(), num.round=10, array.batch.size=20,
learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
batch.end.callback = mx.callback.log.train.metric(5))
## Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## Batch [5] Train-rmse=17.6514558545416
## [1] Train-rmse=15.2879610219001
## [1] Validation-rmse=12.3332062820921
## Batch [5] Train-rmse=11.939392828565
## [2] Train-rmse=11.4382242547217
## [2] Validation-rmse=9.91176550103181
............
You can also save the training and evaluation errors for later use by passing a reference class.
logger <- mx.metric.logger$new()
model <- mx.model.FeedForward.create(
lro, X=train.x, y=train.y,
eval.data=list(data=test.x, label=test.y),
ctx=mx.cpu(), num.round=10, array.batch.size=20,
learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
epoch.end.callback = mx.callback.log.train.metric(5, logger))
## Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-rmse=19.1083228733256
## [1] Validation-rmse=12.7150687428974
## [2] Train-rmse=15.7684378116157
## [2] Validation-rmse=14.8105319420491
............
head(logger$train)
## [1] 19.108323 15.768438 13.531470 11.386050 9.555477 9.351324
head(logger$eval)
## [1] 12.715069 14.810532 15.840361 10.898733 9.349706 9.363087
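Since logger$train and logger$eval are plain numeric vectors with one value per epoch, you can use them however you like afterwards, for example to plot the learning curves with base R. A minimal sketch, assuming the logger filled above:
# Plot the logged training and validation RMSE per epoch
epochs <- seq_along(logger$train)
plot(epochs, logger$train, type = "l", col = "blue",
     ylim = range(c(logger$train, logger$eval)),
     xlab = "Epoch", ylab = "RMSE")
lines(epochs, logger$eval, col = "red")
legend("topright", legend = c("Train", "Validation"),
       col = c("blue", "red"), lty = 1)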
How to write your own callback functions
You can find the source code for the two callback functions here; they can be used as templates:
Basically, all callback functions follow the structure below:
mx.callback.fun <- function() {
function(iteration, nbatch, env) {
}
}
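Here iteration is the current epoch, nbatch is the batch counter within the epoch, and env is an environment carrying the training state (for example env$model and env$metric, as used by the built-in callbacks below). As an illustration, a minimal hypothetical callback that only reports progress could look like the sketch below; the function name is made up for this example:
# Hypothetical example: print a message at the end of every epoch
mx.callback.print.progress <- function() {
  function(iteration, nbatch, env) {
    cat(sprintf("Finished epoch %d\n", iteration))
    return(TRUE)  # returning TRUE tells the trainer to continue
  }
}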
The mx.callback.save.checkpoint function below is stateless. It simply gets the model from the environment and saves it.
mx.callback.save.checkpoint <- function(prefix, period=1) {
function(iteration, nbatch, env) {
if (iteration %% period == 0) {
mx.model.save(env$model, prefix, iteration)
cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration))
}
return(TRUE)
}
}
The mx.callback.log.train.metric function is a little more complex. It holds a reference class and updates it during the training process.
mx.callback.log.train.metric <- function(period, logger=NULL) {
function(iteration, nbatch, env) {
if (nbatch %% period == 0 && !is.null(env$metric)) {
result <- env$metric$get(env$train.metric)
if (nbatch != 0)
cat(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value, "\n"))
if (!is.null(logger)) {
if (class(logger) != "mx.metric.logger") {
stop("Invalid mx.metric.logger.")
}
logger$train <- c(logger$train, result$value)
if (!is.null(env$eval.metric)) {
result <- env$metric$get(env$eval.metric)
if (nbatch != 0)
cat(paste0("Batch [", nbatch, "] Validation-", result$name, "=", result$value, "\n"))
logger$eval <- c(logger$eval, result$value)
}
}
}
return(TRUE)
}
}
Now you might be curious why both callback functions return(TRUE). Can we return(FALSE)? Yes! You can stop training early with return(FALSE). See the example below.
mx.callback.early.stop <- function(eval.metric) {
function(iteration, nbatch, env) {
if (!is.null(env$metric)) {
if (!is.null(eval.metric)) {
result <- env$metric$get(env$eval.metric)
if (result$value < eval.metric) {
return(FALSE)
}
}
}
return(TRUE)
}
}
model <- mx.model.FeedForward.create(
lro, X=train.x, y=train.y,
eval.data=list(data=test.x, label=test.y),
ctx=mx.cpu(), num.round=10, array.batch.size=20,
learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
epoch.end.callback = mx.callback.early.stop(10))
## Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-rmse=18.5897984387033
## [1] Validation-rmse=13.5555213820571
## [2] Train-rmse=12.5867564040256
## [2] Validation-rmse=9.76304967080928
As you can see, once the validation metric drops below the threshold we set, the training process stops early.