더북(TheBook)

서브 클래스 타깃

이 솔루션 내용 중 가장 색다르면서도 재미있는 부분으로, 주어진 정답 값을 0~1 사이의 값이 아닌 11개의 서브 클래스로 변환해 사용한 것입니다. 데이터셋의 정답이 각 클래스에 해당하는 숫자 사이에 있을 경우 해당 클래스를 적용합니다. 다만, 이렇게 하면 테스트 데이터의 정답을 구할 때 0~1의 확률이 나오지 않으므로 모델의 출력인 길이 11의 벡터와 각 클래스에 해당하는 값을 요소별(Element-wise)로 곱한 뒤 그 합을 예측값으로 사용합니다.

cls_vals = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]
for cls, v in enumerate(cls_vals):
    if target >= v:
        target = cls

# 11개 클래스의 확률 분포
toxic_logits = toxic_logits[0] # (N, 11)
y_pred = sum([p*v for p, v in zip(toxic_logits, cls_vals)])
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.