# bnsp에는 가중치를 추가하지 않는다.
subgroup_bpsn_weight = (
(overall * W_a)
+ (any_subgroup_binary.astype(int) * W_a)
+ (bpsn_binary.astype(int) * W_a)
) * FINAL_LOSS_WEIGHT
# subgroup_bpsn_weight를 배치마다 적용하기 위해 targets에 병합
# preds 인덱스 순서: y_pred(확률값), y_sub(subtypes)
# targets 인덱스 순서: y_true(확률값), y_sub(subtypes), subgroup_bpsn_weight
def loss_func(preds, targets):
# general auc를 위한 original loss(확률)
bce_loss_original = nn.BCEWithLogitsLoss()(
preds[:, :1], targets[:, :1]
)
# subgroup과 bpsn에 가중치를 반영한 weighted loss(binary)
targets_binary = (targets[:, :1] >= 0.5).to(preds.dtype)
bce_loss_weighted = nn.BCEWithLogitsLoss(weight=targets[:, -1:])(
preds[:, :1], targets_binary
)
# subtype attribute들을 활용한 loss
bce_loss_sub = nn.BCEWithLogitsLoss()(preds[:, 2:], targets[:, 2:-1])
return bce_loss_weighted + bce_loss_original + bce_loss_sub