더북(TheBook)

MNIST 데이터셋은 많은 딥러닝 프레임워크에서도 예제로 널리 사용하고 있으며, 텐서플로도 MNIST 데이터셋을 쉽게 불러올 수 있는 인터페이스를 제공합니다. 텐서플로에서 MNIST 데이터셋을 이용하여 필기체 숫자 인식 학습을 수행하는 파이썬 소스 코드를 코드 16-2에 나타냈습니다. 코드 16-2에 나타난 mnist_cnn.py 소스 파일은 내려받은 예제 파일 중 ch16/mnist_cnn 폴더에서 확인할 수 있습니다.

코드 16-2 텐서플로를 이용한 필기체 숫자 인식 학습(mnist_cnn.py 파일) [ch16/mnist_cnn]

01    import tensorflow as tf
02    from tensorflow.examples.tutorials.mnist import input_data
03    from tensorflow.python.framework import graph_util
04    from tensorflow.python.platform import gfile
05     
06    tf.logging.set_verbosity(tf.logging.ERROR)
07     
08    mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)
09     
10    #
11    # hyper parameters
12    #
13    learning_rate = 0.001
14    training_epochs = 20
15    batch_size = 100
16     
17    #
18    # Model configuration
19    #
20    X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='data')
21    Y = tf.placeholder(tf.float32, [None, 10])
22     
23    conv1 = tf.layers.conv2d(X, 32, [3, 3], padding="same", activation=tf.nn.relu)
24    pool1 = tf.layers.max_pooling2d(conv1, [2, 2], strides=2, padding="same")
25     
26    conv2 = tf.layers.conv2d(pool1, 64, [3, 3], padding="same", activation=tf.nn.relu)
27    pool2 = tf.layers.max_pooling2d(conv2, [2, 2], strides=2, padding="same")
28     
29    flat3 = tf.contrib.layers.flatten(pool2)
30    dense3 = tf.layers.dense(flat3, 256, activation=tf.nn.relu)
31     
32    logits = tf.layers.dense(dense3, 10, activation=None)
33    final_tensor = tf.nn.softmax(logits, name='prob')
34     
35    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits))
36    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
37     
38    #
39    # Training
40    #
41    with tf.Session() as sess:
42        sess.run(tf.global_variables_initializer())
43        total_batch = int(mnist.train.num_examples / batch_size)
44     
45        print('Start learning!')
46        for epoch in range(training_epochs):
47            total_cost = 0
48     
49            for i in range(total_batch):
50                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
51                batch_xs = batch_xs.reshape(-1, 28, 28, 1)
52                _, cost_val = sess.run([optimizer, cost], feed_dict={
53                                       X: batch_xs, Y: batch_ys})
54                total_cost += cost_val
55     
56            print('Epoch:', '%04d' % (epoch + 1), 'Avg. cost = ',
57                  '{:.4f}'.format(total_cost/total_batch))
58     
59        print('Learning finished!')
60     
61        # Freeze variables and save pb file
62        output_graph_def = graph_util.convert_variables_to_constants(
63            sess, sess.graph_def, ['prob'])
64        with gfile.FastGFile('./mnist_cnn.pb', 'wb') as f:
65            f.write(output_graph_def.SerializeToString())
66     
67        print('Save done!')

 

1~4행 프로그램 동작에 필요한 파이썬 패키지를 포함시킵니다.

8행 MNIST 데이터셋을 인터넷에서 내려받아 MNIST_data 폴더에 저장합니다.

20~36행 두 개의 컨볼루션 레이어와 하나의 완전 연결 레이어로 이루어진 네트워크를 구성합니다.

41~57행 딥러닝 학습을 수행합니다.

62~65행 학습 결과를 mnist_cnn.pb 파일로 저장합니다.

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.