더북(TheBook)
# Subgroup
weights[(subgroup_bool)] += 0.25
# toxic_logits size: (batch_size x 11)
# ident_logits size: (batch_size x 9)
toxic_logits, ident_logits = model(input_ids, segment_ids, input_mask)
toxic_loss = (
    F.cross_entropy(toxic_logits, toxic_target, reduction="none") * weight
).mean()
ident_loss = F.binary_cross_entropy_with_logits(ident_logits, ident_target)
loss = 0.75 * toxic_loss + 0.25 * ident_loss

앞서 서브 클래스 타깃을 적용하기 위해 toxic_logits에는 11개 서브 클래스에 대한 크로스 엔트로피 함수를 적용했습니다. 또한, Subgroup에 해당하는 ident_logits에는 이진 크로스 엔트로피 함수를 적용하고 있습니다.

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.