ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [딥러닝] MNIST 텐서플로우 예제 쉽게 따라하기
    딥러닝 공부하기 2023. 12. 19. 12:27

    MNIST는 딥러닝 예제로 유명하죠? MNIST 데이터는 0~9까지의 손글씨 데이터인데, 이미지를 학습하여 어떤 숫자인지 맞추는 것입니다. 28*28픽셀로 구성된 흑백 손글씨로, 학습 데이터는 6만 개, 테스트 데이터는 1만 개로 이루어져있습니다. 

    딥러닝 분류 예제로 많이 활용하는 MNIST 데이터로 딥러닝 학습 순서를 따라해보며 딥러닝을 더 쉽고 정확하게 이해하며 실습해보도록 하겠습니다.

     

     

    1. 우선 텐서플로우 1버전을 사용하도록 하겠습니다.

    import numpy as np
    import matplotlib.pyplot as plt
    
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
    tf.logging.set_verbosity(tf.logging.ERROR)

     

    2. MNIST데이터를 불러와 글씨가 써진 데이터는 reshape를 통해 1행 784열로 만들어서 학습에 맞는 형태로 바꿔줍니다. 정답인 라벨 데이터는 to_categorical로 원핫벡터로 바꿔주어 나중에 손실함수 계산할 수 있도록 데이터를 준비해줍니다.

    from tensorflow.keras import datasets, utils
    
    (train_data, train_label), (test_data, test_label) = datasets.mnist.load_data()
    
    train_data = train_data.reshape(60000, -1) / 255.0
    test_data = test_data.reshape(10000, 784) / 255.0
    
    train_label = utils.to_categorical(train_label) # 0~9 -> one-hot vector
    test_label = utils.to_categorical(test_label) # 0~9 -> one-hot vector

     

    3. 입력층, 은닉층, 출력층의 개수와 형태, 활성화 함수, dropout 준비 등 모델을 설계하고 구축하고 정의합니다. 

    # X, Y를 placeholder 로 담는 이유는 학습 과정일 때는 train 데이터가, 채점 과정일 때는 test 데이터가 들어가면서 동적으로 입력받기 위함힙니다.
    # 보통 행의 수는 바뀔수 있어서 보통은 따로 지정하지 않고 열의 수는 지정해주는 편입니다.
    
    # dropout을 tf.placeholder 로 미리 담아냄으로써 학습 과정일 때는 활성화하고, 채점할 때는 꺼놓을 수 있도록, 
    # 즉 모델의 입력 데이터를 동적으로 제공하기 위해 사용합니다. 
    
    X = tf.placeholder(tf.float32, [None, 784])
    Y = tf.placeholder(tf.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)
    
    W1 = tf.Variable(tf.random_normal([784, 256], stddev=0.01))
    L1 = tf.nn.relu(tf.matmul(X, W1))
    L1 = tf.nn.dropout(L1, keep_prob) # (Dropout을 적용할 layer, 살릴 비율)
    
    W2 = tf.Variable(tf.random_normal([256, 256], stddev=0.01))
    L2 = tf.nn.relu(tf.matmul(L1, W2))
    L2 = tf.nn.dropout(L2, keep_prob) 
    
    # output layer에는 dropout을 적용하지 않습니다!!!
    W3 = tf.Variable(tf.random_normal([256, 10], stddev=0.01))
    model = tf.matmul(L2, W3)

     

     

    4. 손실 함수, 옵티마이저의 종류를 설정하여 기준을 설계하고 정의합니다.

    cost = tf.losses.softmax_cross_entropy(Y, model) # for Classification, "cross-entropy" after "softmax" 한 번에 계산해줌, cost = tf.losses.mean_squared_error(Y, model) for Regression
    optimizer = tf.train.AdamOptimizer(0.001).minimize(cost) # Select optimizer & connect with cost function (recommended start : "Adam")

     

    5. 모델 학습하기

    # Initialize all global variables (Parameter Theta)
    init = tf.global_variables_initializer() 
    sess = tf.Session()
    sess.run(init)
    
    # Gradient descent를 적용하기 전까지 한번에 밀어넣는 데이터의 수 지정 (Batch size == 하나의 데이터 덩어리 내 데이터 수)
    batch_size = 100  # 한번에 밀어넣는 데이터가 100개, 그래서 총 100*784
    total_batch = int(len(train_data) / batch_size)  # 6만 장의 사진을 100개씩
    # iteration 횟수 ==  Gradient descent 적용 횟수 == 15*600
    # 세타들이 모두 15*600=9000번씩 바뀜
    
    for epoch in range(15):  # epoch  6만 개를 총 15번 
    
        total_cost = 0 # 매 epoch 마다의 평균 에러 값 계산을 위해 활용됩니다.
        batch_idx = 0 # 매 batch 마다 꺼낼 데이터의 시작 index 값 지정을 위해 활용됩니다.
    
        for i in range(total_batch): # iterate over # of batches
    
            # Training data(60000장)에서 batch_size(100개) 만큼 순서대로 꺼내어 학습에 활용해줍니다.
            batch_x = train_data[ batch_idx : batch_idx + batch_size ]
            batch_y = train_label[ batch_idx : batch_idx + batch_size ]
    
            sess.run(optimizer, feed_dict={X: batch_x, 
                                           Y: batch_y, 
                                           keep_prob: 0.8}) # 살릴 비율 지정, node 중 80%만 유지하고 20%를 train 시마다 off
            
            
            
            # 이번 batch를 기준으로 계산이 끝난 Cross-entropy 값을 total_cost에 더해줍니다. (epoch 종료 후 평균을 냅니다.)
            batch_cost = sess.run(cost, feed_dict={X: batch_x, 
                                                   Y: batch_y, 
                                                   keep_prob: 0.8}) # 살릴 비율 지정, node 중 80%만 유지하고 20%를 train 시마다 off
    
            total_cost = total_cost + batch_cost
    
            # 다음 for loop에서 꺼낼 데이터의 시작 index 번호를 batch_size(100) 만큼 증가시킵니다.
            batch_idx += batch_size
    
        # (이번 epoch가 종료되었을 시점의) training data 기준 Cross-entropy 값을 계산합니다.
        # 전체 epoch 동안 600개의 배치당 cross entropy 에러 값 모아서 합쳐낸 것을 600으로 나눠서 평균내기
        training_cost = total_cost / total_batch
    #     training_cost = sess.run(cost, feed_dict={X: batch_x, Y: batch_y}) 
    
        # (이번 epoch가 종료되었을 시점의) test data 기준 Cross-entropy 값을 계산합니다.
        test_cost = sess.run(cost, feed_dict={X: test_data, Y: test_label}) 
    
    
        print('Epoch: {}'.format(epoch + 1), 
              '|| Avg. Training cost = {:.3f}'.format(training_cost), 
              '|| Current Test cost = {:.3f}'.format(test_cost))
    
    print('Learning process is completed!')

     

    7. 모델 테스트 및 정확도 계산

    is_correct = tf.equal(tf.argmax(model, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32)) 
    
    # 10,000건의 Test data 전체에 대해 모델의 정확도를 계산합니다.
    # 살릴 비율 지정, 정확도를 측정하는 Test 단계에서는 전체 Node를 살려줘야 합니다.
    print('정확도:', sess.run(accuracy,
                            feed_dict={X: test_data,
                                       Y: test_label,
                                       keep_prob: 1}))

     

    8. 예측값 시각화해보기

    # 모델의 예측값을 labels에 저장
    labels = sess.run(tf.argmax(model, 1),
                      feed_dict={X: test_data,
                                 Y: test_label,
                                 keep_prob: 1}) 
    print(labels)
    fig = plt.figure(figsize=(10, 10))
    
    for i in range(10):
        subplot = fig.add_subplot(2, 5, i + 1)
        subplot.set_xticks([])
        subplot.set_yticks([])
        subplot.set_title('%d' % labels[i])
        subplot.imshow(test_data[i].reshape((28, 28)),
                       cmap=plt.cm.gray_r)
                       
    # 상단의 번호가 예측된 숫자, 아래의 이미지가 실제 데이터(이미지 내 숫자)
    plt.show()

     

     

Designed by Tistory.