더북(TheBook)
# bnsp에는 가중치를 추가하지 않는다.
subgroup_bpsn_weight = (
    (overall * W_a)
    + (any_subgroup_binary.astype(int) * W_a)
    + (bpsn_binary.astype(int) * W_a)
) * FINAL_LOSS_WEIGHT

# subgroup_bpsn_weight를 배치마다 적용하기 위해 targets에 병합
# preds 인덱스 순서: y_pred(확률값), y_sub(subtypes)
# targets 인덱스 순서: y_true(확률값), y_sub(subtypes), subgroup_bpsn_weight
def loss_func(preds, targets): 
    # general auc를 위한 original loss(확률)
    bce_loss_original = nn.BCEWithLogitsLoss()(
        preds[:, :1], targets[:, :1]
    )  
    # subgroup과 bpsn에 가중치를 반영한 weighted loss(binary)
    targets_binary = (targets[:, :1] >= 0.5).to(preds.dtype)
    bce_loss_weighted = nn.BCEWithLogitsLoss(weight=targets[:, -1:])(
        preds[:, :1], targets_binary
    )    
    # subtype attribute들을 활용한 loss
    bce_loss_sub = nn.BCEWithLogitsLoss()(preds[:, 2:], targets[:, 2:-1])
    
    return bce_loss_weighted + bce_loss_original + bce_loss_sub
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.