4. Training the CNN built in step 3

    >>> tr.x <- t(train.x)   # build the training set: one image per column
    >>> dim(tr.x) <- c(28, 28, 1, ncol(tr.x))   # reshape to 28 x 28 x 1 x n
    >>> ts.x <- t(test.x)    # build the test set the same way
    >>> dim(ts.x) <- c(28, 28, 1, ncol(ts.x))
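To see what `t()` followed by `dim<-` does, here is a small base-R sketch on a toy matrix. It assumes, as the code above implies, that `train.x` holds one image per row with 784 pixel columns; the names `toy.x` and `arr` are hypothetical. R fills arrays in column-major order, so after the transpose each column (one image) becomes one 28 x 28 x 1 slice along the fourth dimension.

```r
# Hypothetical stand-in for train.x: 2 images (rows) x 784 pixels (columns)
toy.x <- matrix(seq_len(2 * 784), nrow = 2, byrow = TRUE)

# Transpose so each column is one image, then reshape to 28 x 28 x 1 x n
arr <- t(toy.x)
dim(arr) <- c(28, 28, 1, ncol(arr))

print(dim(arr))
# Reading the second slice back in column-major order recovers row 2 of toy.x
stopifnot(all(as.vector(arr[, , 1, 2]) == toy.x[2, ]))
```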
    
    >>> logger.epoc <- mx.callback.log.train.metric(100)
    >>> logger.batch <- mx.metric.logger$new()   # stores per-batch train/validation accuracy
    >>> mx.set.seed(42)      # set the random number seed
    >>> ti <- proc.time()    # start the timer
    >>> model <- mx.model.FeedForward.create(lenet, X=tr.x, y=train.y,
    >>>     eval.data=list(data=ts.x, label=test.y),   # held-out data for validation accuracy
    >>>     ctx=mx.cpu(),            # train on the CPU
    >>>     num.round=20,            # number of epochs
    >>>     array.batch.size=100,    # mini-batch size
    >>>     learning.rate=0.05,
    >>>     momentum=0.9,
    >>>     wd=0.00001,              # weight decay
    >>>     eval.metric=mx.metric.accuracy,
    >>>     epoch.end.callback=logger.epoc,
    >>>     batch.end.callback=mx.callback.log.train.metric(1, logger.batch))
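The `learning.rate`, `momentum` and `wd` arguments configure stochastic gradient descent with momentum and weight decay. As a hedged summary of MXNet's SGD optimizer (gradient rescaling and clipping omitted), each parameter $w$ with mini-batch gradient $g$ and momentum buffer $v$ is updated roughly as follows, with $\eta = 0.05$, $\mu = 0.9$ and $\lambda = 10^{-5}$ in this run:

```latex
\begin{aligned}
v_{t+1} &= \mu\, v_t - \eta\,\bigl(g_t + \lambda\, w_t\bigr) \\
w_{t+1} &= w_t + v_{t+1}
\end{aligned}
```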
    
            …
    Batch [597] Validation-accuracy=0.991400007009506
    Batch [598] Train-accuracy=1
    Batch [598] Validation-accuracy=0.991400007009506
    Batch [599] Train-accuracy=1
    Batch [599] Validation-accuracy=0.991400007009506
    Batch [600] Train-accuracy=1
    Batch [600] Validation-accuracy=0.991400007009506
    [20] Train-accuracy=1
    [20] Validation-accuracy=0.991500006914139
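The batch counter in the log runs up to 600 in each epoch. Assuming `train.x` is the standard MNIST training split of 60,000 images (its size is not stated in this section), that is exactly what `array.batch.size=100` implies, and 20 rounds then amount to 12,000 parameter updates. The variable names below are illustrative:

```r
n.train    <- 60000  # assumption: standard MNIST training-set size
batch.size <- 100    # array.batch.size in the call above
num.round  <- 20     # num.round in the call above

# Number of mini-batches per epoch and total number of SGD updates
batches.per.epoch <- ceiling(n.train / batch.size)
total.updates     <- batches.per.epoch * num.round

cat("batches per epoch:", batches.per.epoch, "\n")
cat("total updates:", total.updates, "\n")
```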
    
    >>> te <- proc.time()
    >>> print(te-ti)
       user  system  elapsed
    2670.46  520.37  1489.68
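The difference of two `proc.time()` calls reports CPU time (`user`, `system`) and wall-clock time (`elapsed`) in seconds. Because user plus system time exceeds the elapsed time here, the CPU backend must have run work on several threads in parallel. A quick calculation with the numbers printed above:

```r
# Timings copied from the output above, in seconds
user    <- 2670.46
system  <- 520.37
elapsed <- 1489.68

cpu.total <- user + system          # total CPU time summed over all threads
avg.cores <- cpu.total / elapsed    # average number of cores kept busy
minutes   <- elapsed / 60           # wall-clock training time in minutes

cat(sprintf("CPU time: %.2f s, wall clock: %.1f min, average cores busy: %.2f\n",
            cpu.total, minutes, avg.cores))
```

So training took roughly 25 minutes of wall-clock time while keeping, on average, a little more than two cores busy.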
    >>> mx.model.save(model, "mnistModel", 1)   # writes mnistModel-symbol.json and mnistModel-0001.params
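The saved files can be reloaded later with `mx.model.load()` using the same prefix and iteration number. This is a hedged sketch: the MXNet calls appear only as comments, following the pattern in MXNet's R tutorials, where `predict()` on a FeedForward model returns a matrix with one row per class and one column per image. So that the block runs without MXNet, a small hypothetical probability matrix `probs` and label vector `true.label` stand in for real data to show how predicted labels and accuracy are derived.

```r
# With MXNet installed, the real steps would be:
#   model <- mx.model.load("mnistModel", 1)
#   probs <- predict(model, ts.x)        # 10 x n matrix of class probabilities
# Hypothetical stand-in: 3 classes (rows) x 4 images (columns)
probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.8, 0.1,
                  0.2, 0.3, 0.5,
                  0.6, 0.3, 0.1),
                nrow = 3)
true.label <- c(0, 1, 2, 1)             # hypothetical ground-truth digits

# Most probable class per image; subtract 1 because digit labels start at 0
pred.label <- max.col(t(probs)) - 1
accuracy   <- mean(pred.label == true.label)

print(pred.label)
print(accuracy)
```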