MNIST 데이터셋은 많은 딥러닝 프레임워크에서도 예제로 널리 사용하고 있으며, 텐서플로도 MNIST 데이터셋을 쉽게 불러올 수 있는 인터페이스를 제공합니다. 텐서플로에서 MNIST 데이터셋을 이용하여 필기체 숫자 인식 학습을 수행하는 파이썬 소스 코드를 코드 16-2에 나타냈습니다. 코드 16-2에 나타난 mnist_cnn.py 소스 파일은 내려받은 예제 파일 중 ch16/mnist_cnn 폴더에서 확인할 수 있습니다.
코드 16-2 텐서플로를 이용한 필기체 숫자 인식 학습(mnist_cnn.py 파일) [ch16/mnist_cnn]
01 import tensorflow as tf 02 from tensorflow.examples.tutorials.mnist import input_data 03 from tensorflow.python.framework import graph_util 04 from tensorflow.python.platform import gfile 05 06 tf.logging.set_verbosity(tf.logging.ERROR) 07 08 mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True) 09 10 # 11 # hyper parameters 12 # 13 learning_rate = 0.001 14 training_epochs = 20 15 batch_size = 100 16 17 # 18 # Model configuration 19 # 20 X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='data') 21 Y = tf.placeholder(tf.float32, [None, 10]) 22 23 conv1 = tf.layers.conv2d(X, 32, [3, 3], padding="same", activation=tf.nn.relu) 24 pool1 = tf.layers.max_pooling2d(conv1, [2, 2], strides=2, padding="same") 25 26 conv2 = tf.layers.conv2d(pool1, 64, [3, 3], padding="same", activation=tf.nn.relu) 27 pool2 = tf.layers.max_pooling2d(conv2, [2, 2], strides=2, padding="same") 28 29 flat3 = tf.contrib.layers.flatten(pool2) 30 dense3 = tf.layers.dense(flat3, 256, activation=tf.nn.relu) 31 32 logits = tf.layers.dense(dense3, 10, activation=None) 33 final_tensor = tf.nn.softmax(logits, name='prob') 34 35 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits)) 36 optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost) 37 38 # 39 # Training 40 # 41 with tf.Session() as sess: 42 sess.run(tf.global_variables_initializer()) 43 total_batch = int(mnist.train.num_examples / batch_size) 44 45 print('Start learning!') 46 for epoch in range(training_epochs): 47 total_cost = 0 48 49 for i in range(total_batch): 50 batch_xs, batch_ys = mnist.train.next_batch(batch_size) 51 batch_xs = batch_xs.reshape(-1, 28, 28, 1) 52 _, cost_val = sess.run([optimizer, cost], feed_dict={ 53 X: batch_xs, Y: batch_ys}) 54 total_cost += cost_val 55 56 print('Epoch:', '%04d' % (epoch + 1), 'Avg. cost = ', 57 '{:.4f}'.format(total_cost/total_batch)) 58 59 print('Learning finished!') 60 61 # Freeze variables and save pb file 62 output_graph_def = graph_util.convert_variables_to_constants( 63 sess, sess.graph_def, ['prob']) 64 with gfile.FastGFile('./mnist_cnn.pb', 'wb') as f: 65 f.write(output_graph_def.SerializeToString()) 66 67 print('Save done!')
• 1~4행 프로그램 동작에 필요한 파이썬 패키지를 포함시킵니다.
• 8행 MNIST 데이터셋을 인터넷에서 내려받아 MNIST_data 폴더에 저장합니다.
• 20~36행 두 개의 컨볼루션 레이어와 하나의 완전 연결 레이어로 이루어진 네트워크를 구성합니다.
• 41~57행 딥러닝 학습을 수행합니다.
• 62~65행 학습 결과를 mnist_cnn.pb 파일로 저장합니다.