① argmax는 이미 여러 차례 알아보았습니다. argmax는 출력된 열 중에서 가장 큰 값을 반환합니다. 예를 들어 다음과 같이 사용할 수 있습니다.
import torch a = torch.randn(4, 4) argmax = torch.argmax(a) print(a) print(argmax)
다음과 같이 4×4 형태의 임의의 텐서를 생성한 후 그중 가장 큰 값(2.2177)을 갖는 인덱스(13)를 반환합니다.
tensor([[ 0.6908, -0.2365, 0.3776, -0.7609], [ 1.0446, -0.2636, 0.4735, 1.6540], [-1.4138, 0.3427, 0.4632, -0.4743], [-0.2170, 2.2177, 0.2166, -0.2508]]) tensor(13)
따라서 softmax(logits) 값 중에서 가장 큰 값을 반환합니다.
② 데이터셋에 대한 오차가 감소할 때마다 모델을 저장하여 가장 낮은 오차를 갖는 모델로 학습을 마치도록 합니다.
사전 학습된 버트 모델에서 파라미터(옵티마이저와 학습률)를 미세 조정 후 모델을 학습시킵니다.