import torch.nn as nn

# custom loss function for multi-head multi-category classification
def loss_fn(outputs, targets, weights=None):
    """Compute one cross-entropy loss per head and combine them into a scalar.

    Each ``(output, target)`` pair is passed to ``nn.CrossEntropyLoss`` with
    its default settings, so every pair must be in a form that criterion
    accepts.

    Args:
        outputs: Sequence with one prediction tensor per classification head.
        targets: Sequence with one target tensor per head, in the same order
            as ``outputs``.
        weights: Optional sequence with one weight per head (plain numbers or
            scalar tensors). When omitted, all heads count equally and the
            combined loss is the plain mean of the per-head losses. When
            given, the combined loss is the weighted sum of the per-head
            losses divided by the sum of the weights, so ``(0.2, 0.3, 0.5)``
            and ``(2, 3, 5)`` give the same result. The weights must not sum
            to zero. Callers wanting dynamic weighting can pass different
            weights on every call.

    Returns:
        A tuple ``(combined, l1, l2, ...)``: the combined loss followed by
        the individual loss of each head, in input order. For three heads
        this is ``(combined, l1, l2, l3)``.

    Raises:
        ValueError: If no heads are given, or if ``outputs``, ``targets`` and
            ``weights`` (when provided) do not all have the same length.
    """
    outputs = tuple(outputs)
    targets = tuple(targets)
    if not outputs:
        raise ValueError("loss_fn needs at least one head")
    if len(outputs) != len(targets):
        raise ValueError(
            f"got {len(outputs)} outputs but {len(targets)} targets"
        )

    # The criterion holds no per-call state, so one instance serves all heads.
    criterion = nn.CrossEntropyLoss()
    losses = tuple(
        criterion(output, target) for output, target in zip(outputs, targets)
    )

    if weights is None:
        # Equal weighting: plain mean of the per-head losses.
        combined = sum(losses[1:], losses[0]) / len(losses)
    else:
        weights = tuple(weights)
        if len(weights) != len(losses):
            raise ValueError(
                f"got {len(weights)} weights for {len(losses)} heads"
            )
        weighted = [weight * loss for weight, loss in zip(weights, losses)]
        # Divide by the weight sum so the scale of the weights is irrelevant.
        combined = sum(weighted[1:], weighted[0]) / sum(weights)

    return (combined, *losses)
