더북(TheBook)

4. 3에서 구성한 CNN을 학습시키기

>>> tr.x <- t(train.x)   # train셋 구성
>>> dim(tr.x) <- c(28, 28, 1, ncol(tr.x))
>>> ts.x <- t(test.x)    # test셋 구성
>>> dim(ts.x) <- c(28, 28, 1, ncol(ts.x))

>>> logger.epoc <- mx.callback.log.train.metric(100)
>>> logger.batch <- mx.metric.logger$new()
>>> mx.set.seed(42)      # 난수 생성의 시드를 지정함
>>> ti <- proc.time()
>>> model <- mx.model.FeedForward.create(lenet, X=tr.x, y=train.y,
>>>     eval.data=list(data=ts.x, label=test.y),
>>>     ctx=mx.cpu(),
>>>     num.round=20,
>>>     array.batch.size=100,
>>>     learning.rate=0.05,
>>>     momentum=0.9,
>>>     wd=0.00001,
>>>     eval.metric=mx.metric.accuracy,
>>>     epoch.end.callback=logger.epoc,
>>>     batch.end.callback=mx.callback.log.train.metric(1, logger.batch))

        …
Batch [597] Validation-accuracy=0.991400007009506
Batch [598] Train-accuracy=1
Batch [598] Validation-accuracy=0.991400007009506
Batch [599] Train-accuracy=1
Batch [599] Validation-accuracy=0.991400007009506
Batch [600] Train-accuracy=1
Batch [600] Validation-accuracy=0.991400007009506
[20] Train-accuracy=1
[20] Validation-accuracy=0.991500006914139

>>> te <- proc.time()
>>> print(te-ti)
 사용자  시스템  elapsed
2670.46  520.37  1489.68
>>> mx.model.save(model, "mnistModel", 1)
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.