# bnsp에는 가중치를 추가하지 않는다.
    subgroup_bpsn_weight = (
        (overall * W_a)
        + (any_subgroup_binary.astype(int) * W_a)
        + (bpsn_binary.astype(int) * W_a)
    ) * FINAL_LOSS_WEIGHT
    
    # subgroup_bpsn_weight를 배치마다 적용하기 위해 targets에 병합
    # preds 인덱스 순서: y_pred(확률값), y_sub(subtypes)
    # targets 인덱스 순서: y_true(확률값), y_sub(subtypes), subgroup_bpsn_weight
    def loss_func(preds, targets): 
        # general auc를 위한 original loss(확률)
        bce_loss_original = nn.BCEWithLogitsLoss()(
            preds[:, :1], targets[:, :1]
        )  
        # subgroup과 bpsn에 가중치를 반영한 weighted loss(binary)
        targets_binary = (targets[:, :1] >= 0.5).to(preds.dtype)
        bce_loss_weighted = nn.BCEWithLogitsLoss(weight=targets[:, -1:])(
            preds[:, :1], targets_binary
        )    
        # subtype attribute들을 활용한 loss
        bce_loss_sub = nn.BCEWithLogitsLoss()(preds[:, 2:], targets[:, 2:-1])
        
        return bce_loss_weighted + bce_loss_original + bce_loss_sub
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.