두 번째 단계로 ARIMA() 함수를 사용한 예측을 진행해 보겠습니다. 먼저 필요한 라이브러리를 호출합니다.

    코드 7-2 statsmodels 라이브러리를 이용한 sales 데이터셋 예측

    import numpy as np
    from pandas import read_csv
    from pandas import datetime
    from matplotlib import pyplot
    from statsmodels.tsa.arima_model import ARIMA
    from sklearn.metrics import mean_squared_error
    
    def parser(x):
        return datetime.strptime('199'+x, '%Y-%m')
    
    series = read_csv('../chap7/data/sales.csv', header=0, parse_dates=[0], index_col=0,
                      squeeze=True, date_parser=parser)
    X = series.values
    X = np.nan_to_num(X)
    size = int(len(X) * 0.66)
    train, test = X[0:size], X[size:len(X)] ------ train과 test로 데이터셋 분리
    history = [x for x in train]
    predictions = list()
    for t in range(len(test)): ------ test 데이터셋의 길이(13번)만큼 반복하여 수행
        model = ARIMA(history, order=(5,1,0)) ------ ARIMA() 함수 호출
        model_fit = model.fit(disp=0)
        output = model_fit.forecast() ------ forecast() 메서드를 사용하여 예측 수행
        yhat = output[0] ------ 모델 출력 결과를 yhat에 저장
        predictions.append(yhat)
        obs = test[t]
        history.append(obs)
        print('predicted=%f, expected=%f' % (yhat, obs)) ------ 모델 실행 결과를 predicted로 출력하고, test로 분리해 둔 데이터를 expected로 사용하여 출력
    error = mean_squared_error(test, predictions) ------ 손실 함수로 평균 제곱 오차 사용
    print('Test MSE: %.3f' % error)
    pyplot.plot(test)
    pyplot.plot(predictions, color='red')
    pyplot.show()
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.