import torch.nn as nn
import torch.nn.functional as F
import pretrainedmodels

class MultiHeadResNet101(nn.Module):
    """ResNet-101 backbone shared by three independent linear classification heads.

    The backbone is built with ``pretrainedmodels``. Its ``features`` output is
    global-average-pooled and the pooled vector is fed to every head.

    Args:
        pretrained: If truthy, load ImageNet weights for the backbone;
            otherwise build it with random initialisation.
        requires_grad: If truthy, backbone parameters are trainable;
            otherwise they are frozen. The heads are always trainable.
        num_classes: Output size of each head (default 3, matching the
            previously hard-coded value).
    """

    def __init__(self, pretrained, requires_grad, num_classes=3):
        super().__init__()
        self.model = pretrainedmodels.__dict__['resnet101'](
            pretrained='imagenet' if pretrained else None
        )

        # Set requires_grad on every backbone parameter in one call. The heads
        # are created afterwards, so they keep the default requires_grad=True.
        # NOTE: this only stops gradient updates; BatchNorm running statistics
        # in the backbone still update while the module is in train() mode.
        self.model.requires_grad_(bool(requires_grad))
        if requires_grad:
            print('Training intermediate layer parameters...')
        else:
            print('Freezing intermediate layer parameters...')

        # Replace the single classifier with three heads. Their input width is
        # read from the backbone's own final linear layer (2048 for ResNet-101),
        # so it always matches the pooled feature size.
        in_features = self.model.last_linear.in_features
        self.l0 = nn.Linear(in_features, num_classes)
        self.l1 = nn.Linear(in_features, num_classes)
        self.l2 = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Compute the logits of all three heads.

        Args:
            x: 4-D image batch ``(batch, channels, height, width)``.

        Returns:
            Tuple ``(l0, l1, l2)``; each element has shape ``(batch, num_classes)``.
        """
        # Unpacking also rejects inputs that are not 4-D.
        batch, _, _, _ = x.shape
        features = self.model.features(x)
        # Pool each feature map to 1x1, then flatten to (batch, in_features).
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(batch, -1)
        return self.l0(pooled), self.l1(pooled), self.l2(pooled)

# custom loss function for multi-head multi-category classification
def loss_fn(outputs, targets):
    """Average cross-entropy over the heads of a multi-head classifier.

    Args:
        outputs: Sequence of logit tensors, one per head (for example the
            tuple returned by ``MultiHeadResNet101.forward``).
        targets: Sequence of class-index tensors, in the same order as
            ``outputs``.

    Returns:
        Tuple ``(mean_loss, loss_1, ..., loss_k)``. The first element is the
        unweighted mean of the per-head losses; the rest are the per-head
        losses in order. For three heads this is ``(mean, l1, l2, l3)``.

    Raises:
        ValueError: If ``outputs`` is empty or its length differs from
            ``targets``.
    """
    outputs = tuple(outputs)
    targets = tuple(targets)
    if not outputs or len(outputs) != len(targets):
        raise ValueError(
            f"expected a non-empty, equal number of outputs and targets, "
            f"got {len(outputs)} and {len(targets)}"
        )

    # F.cross_entropy is the functional form of nn.CrossEntropyLoss with its
    # default arguments; it avoids building a new module per head per call.
    losses = [F.cross_entropy(o, t) for o, t in zip(outputs, targets)]

    # Equal weighting of the heads. A weighted scheme (e.g. dynamic
    # per-head weights) would replace this plain mean.
    return (sum(losses) / len(losses), *losses)
