import os
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import joblib
from ultralytics import YOLO
from Helpers.models import MultiHeadResNet101

from pathlib import Path
# ---------------------------------------------------------------------------
# Model and preprocessing setup (runs once at import time).
# NOTE: torch.load and joblib.load both unpickle their input, so only point
# them at trusted files.
# ---------------------------------------------------------------------------

# Object detector.
detection_model = YOLO('Models/Yolov11m/best.pt')

# Attribute classifier: use the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
attributes_model = MultiHeadResNet101(pretrained=False, requires_grad=False)

# Restore the trained weights, then move the network to the chosen device and
# switch it to evaluation mode.
checkpoint = torch.load('Models/ResNet101/best.pth', map_location=device)
attributes_model.load_state_dict(checkpoint['model_state_dict'])
attributes_model.to(device)
attributes_model.eval()

# Label encodings for the three attribute heads, loaded in the order
# material, size, orientation.
num_list_material, num_list_size, num_list_orientation = map(
    joblib.load,
    (
        'Helpers/num_list_material.pkl',
        'Helpers/num_list_size.pkl',
        'Helpers/num_list_orientation.pkl',
    ),
)

# Per-channel normalisation statistics applied after ToTensor().
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing for classifier inputs: array -> PIL -> 224x224 -> normalised
# tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

def _invert_label_map(label_to_index):
    """Return an ``index -> label`` dict built from a ``label -> index`` mapping.

    If several labels share an index, the first one in iteration order wins,
    which is the label ``list(values).index(...)`` would have selected.
    Assumes the index values are hashable numbers -- TODO confirm against the
    pickled encodings.
    """
    index_to_label = {}
    for label, index in label_to_index.items():
        index_to_label.setdefault(index, label)
    return index_to_label


def _lookup_label(index_to_label, index, attribute):
    """Translate a predicted class index to its label.

    Raises:
        ValueError: if *index* is not present in the encoding (the same
            exception type the previous ``list.index`` lookup raised).
    """
    try:
        return index_to_label[index]
    except KeyError:
        raise ValueError(
            f"{index} is not a known {attribute} class index"
        ) from None


def _clamp_bbox(bbox, image_width, image_height):
    """Truncate *bbox* to integer pixels clipped to the image bounds.

    Returns ``(x1, y1, x2, y2)``, or ``None`` when the clipped box has no area.
    """
    x1, y1, x2, y2 = map(int, bbox)
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(image_width, x2), min(image_height, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def _predict_attribute_indices(crop):
    """Classify one crop and return the argmax index of each of the three heads.

    The heads are unpacked in the order material, size, orientation.
    """
    # NOTE(review): crops come from cv2.imread and are therefore BGR, while
    # ToPILImage/Normalize treat the channels positionally. Confirm the
    # classifier was trained with the same channel order before changing this.
    batch = transform(crop).unsqueeze(0).to(device)
    with torch.no_grad():
        material_out, size_out, orientation_out = attributes_model(batch)
    # Softmax is monotonic, so the argmax of the raw scores is the same class
    # the softmax probabilities would select.
    return tuple(
        int(head.argmax())
        for head in (material_out, size_out, orientation_out)
    )


def process_image(image_path):
    """Detect objects in an image and classify the attributes of each one.

    Args:
        image_path: Path of the image file, as accepted by ``cv2.imread``.

    Returns:
        A dict ``{'detections': [...], 'path': image_path}``. Each detection
        is a dict with keys ``'bbox'`` (the detector's ``[x1, y1, x2, y2]``
        floats, unclamped), ``'label'``, ``'confidence'``, ``'material'``,
        ``'size'`` and ``'orientation'``. The list is empty when the image
        cannot be read.

    Raises:
        ValueError: if the classifier predicts an index that is missing from
            one of the label encodings.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"Failed to read image: {image_path}")
        return {'detections': [], 'path': image_path}

    image_height, image_width = image.shape[:2]

    # Invert the encodings once per call instead of rebuilding key/value
    # lists and scanning them for every detection.
    material_names = _invert_label_map(num_list_material)
    size_names = _invert_label_map(num_list_size)
    orientation_names = _invert_label_map(num_list_orientation)

    # Run detection on the array that was just decoded: this avoids reading
    # the file a second time and guarantees the box coordinates refer to the
    # exact pixels cropped below.
    results = detection_model(image, conf=0.70)

    detections = []
    for result in results:
        for box in result.boxes:
            bbox = box.xyxy[0].tolist()  # [x1, y1, x2, y2]
            label = detection_model.model.names[int(box.cls[0])]
            confidence = float(box.conf[0])

            clamped = _clamp_bbox(bbox, image_width, image_height)
            if clamped is None:
                print(f"Invalid bbox coordinates: {bbox}")
                continue
            x1, y1, x2, y2 = clamped

            # A box that passed the check above always has a non-empty crop,
            # so no separate empty-crop test is needed.
            material_idx, size_idx, orientation_idx = (
                _predict_attribute_indices(image[y1:y2, x1:x2])
            )

            detections.append({
                'bbox': bbox,
                'label': label,
                'confidence': confidence,
                'material': _lookup_label(
                    material_names, material_idx, 'material'),
                'size': _lookup_label(size_names, size_idx, 'size'),
                'orientation': _lookup_label(
                    orientation_names, orientation_idx, 'orientation'),
            })

    return {'detections': detections, 'path': image_path}

def visualize(image_path, results):
    """Draw detections on an image and save the result in ``Results_images/``.

    Args:
        image_path: Path of the source image; the annotated copy is saved
            under ``Results_images/`` using the same file name.
        results: Dict with a ``'detections'`` list whose items provide
            ``'bbox'``, ``'label'``, ``'confidence'``, ``'material'``,
            ``'size'`` and ``'orientation'``.

    Returns:
        None. Read or write failures are reported on stdout, not raised.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None on failure; drawing on or writing None
        # would raise inside OpenCV.
        print(f"Cannot visualize, failed to read image: {image_path}")
        return

    os.makedirs("Results_images", exist_ok=True)

    box_color = (0, 255, 0)  # Green (BGR)
    box_thickness = 2
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    font_thickness = 1
    text_color = (0, 255, 255)  # Yellow (BGR)

    for detection in results['detections']:
        x1, y1, x2, y2 = map(int, detection['bbox'])  # Convert float to int
        cv2.rectangle(image, (x1, y1), (x2, y2), box_color, box_thickness)

        text = f"{detection['label']} ({detection['confidence']:.2f})"
        additional_text = (
            f"{detection['material']}, {detection['size']}, "
            f"{detection['orientation']}"
        )

        # Captions normally sit above the box (attributes on top, label just
        # below). If the box is too close to the top of the frame for both
        # lines to fit, draw them just inside the box instead.
        (_, text_height), _ = cv2.getTextSize(
            additional_text, font, font_scale, font_thickness)
        if y1 - 30 - text_height >= 0:
            attributes_y, label_y = y1 - 30, y1 - 10
        else:
            attributes_y = y1 + text_height + 10
            label_y = attributes_y + 20

        cv2.putText(image, text, (x1, label_y), font, font_scale,
                    text_color, font_thickness, cv2.LINE_AA)
        cv2.putText(image, additional_text, (x1, attributes_y), font,
                    font_scale, text_color, font_thickness, cv2.LINE_AA)

    # Use the file name itself: splitting on '/' and taking element 1 picks a
    # directory for nested inputs and fails on paths without a '/' separator.
    output_path = os.path.join('Results_images', os.path.basename(image_path))
    if not cv2.imwrite(output_path, image):
        print(f"Failed to write image: {output_path}")


if __name__ == "__main__":
    # Recursively pick up every .jpg under Test_images, run detection and
    # attribute classification on it, then save an annotated copy.
    for image_file in Path('Test_images').glob('**/*.jpg'):
        image_file_str = str(image_file)
        visualize(image_file_str, process_image(image_file_str))

