훈련 데이터셋을 이용하여 모델을 학습시키고, 검증 데이터셋을 이용하여 모델 성능을 확인해 봅니다.

    코드 7-37 모델 학습 및 성능 확인

    seq_dim = 28
    loss_list = []
    iter = 0
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader): ------ 훈련 데이터셋을 이용한 모델 학습
            if torch.cuda.is_available(): ------ GPU 사용 유무 확인
                images = Variable(images.view(-1, seq_dim, input_dim).cuda()) ------ ①
                labels = Variable(labels.cuda())
            else: ------ GPU를 사용하지 않기 때문에 else 구문이 실행
                images = Variable(images.view(-1, seq_dim, input_dim))
                labels = Variable(labels)
    
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels) ------ 손실 함수를 이용하여 오차 계산
    
            if torch.cuda.is_available():
                loss.cuda()
    
            loss.backward()
            optimizer.step() ------ 파라미터 업데이트
            loss_list.append(loss.item())
            iter += 1
    
            if iter % 500 == 0: ------ 정확도(accuracy) 계산
                correct = 0
                total = 0
                for images, labels in valid_loader: ------ 검증 데이터셋을 이용한 모델 성능 검증
    
                    if torch.cuda.is_available():
                        images = Variable(images.view(-1, seq_dim, input_dim).cuda())
                    else:
                        images = Variable(images.view(-1, seq_dim, input_dim))
    
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1) ------ 모델을 통과한 결과의 최댓값으로부터 예측 결과 가져오기
    
                    total += labels.size(0) ------ 총 레이블 수
                    if torch.cuda.is_available():
                        correct += (predicted.cpu() == labels.cpu()).sum()
                    else:
                        correct += (predicted == labels).sum()
    
                accuracy = 100 * correct / total
                print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.