MNIST 데이터셋은 많은 딥러닝 프레임워크에서도 예제로 널리 사용하고 있으며, 텐서플로도 MNIST 데이터셋을 쉽게 불러올 수 있는 인터페이스를 제공합니다. 텐서플로에서 MNIST 데이터셋을 이용하여 필기체 숫자 인식 학습을 수행하는 파이썬 소스 코드를 코드 16-2에 나타냈습니다. 코드 16-2에 나타난 mnist_cnn.py 소스 파일은 내려받은 예제 파일 중 ch16/mnist_cnn 폴더에서 확인할 수 있습니다.

    코드 16-2 텐서플로를 이용한 필기체 숫자 인식 학습(mnist_cnn.py 파일) [ch16/mnist_cnn]

    01    import tensorflow as tf
    02    from tensorflow.examples.tutorials.mnist import input_data
    03    from tensorflow.python.framework import graph_util
    04    from tensorflow.python.platform import gfile
    05     
    06    tf.logging.set_verbosity(tf.logging.ERROR)
    07     
    08    mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)
    09     
    10    #
    11    # hyper parameters
    12    #
    13    learning_rate = 0.001
    14    training_epochs = 20
    15    batch_size = 100
    16     
    17    #
    18    # Model configuration
    19    #
    20    X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='data')
    21    Y = tf.placeholder(tf.float32, [None, 10])
    22     
    23    conv1 = tf.layers.conv2d(X, 32, [3, 3], padding="same", activation=tf.nn.relu)
    24    pool1 = tf.layers.max_pooling2d(conv1, [2, 2], strides=2, padding="same")
    25     
    26    conv2 = tf.layers.conv2d(pool1, 64, [3, 3], padding="same", activation=tf.nn.relu)
    27    pool2 = tf.layers.max_pooling2d(conv2, [2, 2], strides=2, padding="same")
    28     
    29    flat3 = tf.contrib.layers.flatten(pool2)
    30    dense3 = tf.layers.dense(flat3, 256, activation=tf.nn.relu)
    31     
    32    logits = tf.layers.dense(dense3, 10, activation=None)
    33    final_tensor = tf.nn.softmax(logits, name='prob')
    34     
    35    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits))
    36    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    37     
    38    #
    39    # Training
    40    #
    41    with tf.Session() as sess:
    42        sess.run(tf.global_variables_initializer())
    43        total_batch = int(mnist.train.num_examples / batch_size)
    44     
    45        print('Start learning!')
    46        for epoch in range(training_epochs):
    47            total_cost = 0
    48     
    49            for i in range(total_batch):
    50                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    51                batch_xs = batch_xs.reshape(-1, 28, 28, 1)
    52                _, cost_val = sess.run([optimizer, cost], feed_dict={
    53                                       X: batch_xs, Y: batch_ys})
    54                total_cost += cost_val
    55     
    56            print('Epoch:', '%04d' % (epoch + 1), 'Avg. cost = ',
    57                  '{:.4f}'.format(total_cost/total_batch))
    58     
    59        print('Learning finished!')
    60     
    61        # Freeze variables and save pb file
    62        output_graph_def = graph_util.convert_variables_to_constants(
    63            sess, sess.graph_def, ['prob'])
    64        with gfile.FastGFile('./mnist_cnn.pb', 'wb') as f:
    65            f.write(output_graph_def.SerializeToString())
    66     
    67        print('Save done!')

     

    1~4행 프로그램 동작에 필요한 파이썬 패키지를 포함시킵니다.

    8행 MNIST 데이터셋을 인터넷에서 내려받아 MNIST_data 폴더에 저장합니다.

    20~36행 두 개의 컨볼루션 레이어와 하나의 완전 연결 레이어로 이루어진 네트워크를 구성합니다.

    41~57행 딥러닝 학습을 수행합니다.

    62~65행 학습 결과를 mnist_cnn.pb 파일로 저장합니다.

    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.