// React and MUI imports
import React, { useRef, useState, useEffect } from 'react';
import { Backdrop, CircularProgress, Box, Paper, Typography, useTheme, } from '@mui/material';

import * as tf from '@tensorflow/tfjs';

// Konva
import Konva from 'konva';
import { Stage, Layer, Image, Line } from 'react-konva';
import useImage from 'use-image';
import simplify from 'simplify-js';
import contour2D from 'contour-2d';
import pack from 'ndarray-pack';
import { Image as ImageJS } from 'image-js';

// API
import { useSelector } from "react-redux";

// Import toolbar
import ImageOptionBar from './ImageOptionBar';

// Image Layer with Memo to avoid re-rendering
const ImageLayer = React.memo(({ image, imageRef, brightness, contrast }) => {
    return (
        <Layer>
            <Image
                ref={imageRef}
                image={image}
                filters={[Konva.Filters.Brighten, Konva.Filters.Contrast]}
                brightness={brightness}
                contrast={contrast}
            />
        </Layer>
    );
});

// Last numeric ID handed out by generateId; kept so IDs are strictly increasing
let lastGeneratedId = 0;

/**
 * Generate a unique ID based on the current time.
 *
 * The ID is the current epoch time in milliseconds as a decimal string. When
 * the clock has not advanced past the previously returned value (several calls
 * within the same millisecond), the previous value plus one is returned
 * instead, so two calls never yield the same ID within one module instance.
 *
 * @returns {string} Decimal string ID.
 */
const generateId = () => {
    lastGeneratedId = Math.max(Date.now(), lastGeneratedId + 1);
    return lastGeneratedId.toString();
};

/**
 * Convert a hex color string to a CSS `rgba()` string.
 *
 * Accepts 3-digit shorthand (`#abc`) and 6-digit (`#aabbcc`) colors, with or
 * without the leading `#`, in any letter case.
 *
 * @param {string} hex - Hex color to convert.
 * @param {number} [alpha=0.5] - Alpha component, inserted into the result as-is.
 * @returns {string} Color formatted as `rgba(r, g, b, alpha)`.
 * @throws {TypeError} If `hex` is not a string.
 * @throws {Error} If `hex` is not a 3- or 6-digit hexadecimal color.
 */
function hexToRGBA(hex, alpha = 0.5) {
    if (typeof hex !== 'string') {
        throw new TypeError("Invalid hex color: " + String(hex));
    }

    // Remove the hash at the start if it's there
    const digits = hex.replace(/^#/, '');

    // Validate both the length and the characters: parseInt would otherwise
    // turn non-hex input into NaN or a silently truncated number
    if (!/^(?:[0-9a-f]{3}|[0-9a-f]{6})$/i.test(digits)) {
        throw new Error("Invalid hex color: " + digits);
    }

    // Expand the shorthand form (abc -> aabbcc) so one parsing path handles both
    const fullHex = digits.length === 3
        ? [...digits].map((digit) => digit + digit).join('')
        : digits;

    const [r, g, b] = [0, 2, 4].map((start) => parseInt(fullHex.substring(start, start + 2), 16));

    // Return the rgba() formatted color
    return `rgba(${r}, ${g}, ${b}, ${alpha})`;
}

/**
 * Image Editor component: a Konva stage showing an image with freehand
 * annotation drawing, inspection/deletion of shapes and model-assisted
 * annotation, plus the toolbar that drives it.
 *
 * Props (as used in this component):
 * - accessToken: passed to useImage as the image source.
 * - annotations: array of shapes; mirrored into state.shapes whenever it changes.
 * - setAnnotations: called with the full new annotations array after shapes are added or deleted.
 * - labels: forwarded to the toolbar; also indexed by (predicted class - 1) to
 *   read `text` and `color` for model-generated shapes.
 * - inputModel: model object copied into local state; its executeAsync() is called by the magic assistant.
 * - modelReady: forwarded to the toolbar.
 * - annotationSelected: called with a shape id (or '' to clear) when a shape is
 *   clicked in INSPECTION mode; also forwarded to the toolbar.
 */
const ImageEditor = ( { accessToken, annotations, setAnnotations, labels, inputModel, modelReady, annotationSelected } ) => {
    const theme = useTheme();
    const [image] = useImage(accessToken, 'Anonymous'); // Anonymous to use CORS with konva
    // Current user from the persisted redux slice; user.email is stored as the author of drawn shapes
    const user = useSelector((state) => state.persistedReducer.user);
    const [model, setModel] = useState(inputModel);

    // Keep the local model in sync every time the input model prop changes
    useEffect(() => {
        setModel(inputModel);
    }, [inputModel]);

    // Refs for the stage, the container box used to size the stage, and the image node
    const stageRef = useRef(null);
    const paperRef = useRef(null);
    const imageRef = useRef(null);

    // Stage size (placeholder values until the container is measured)
    const [stageSize, setStageSize] = useState({ width: 500, height: 100 });

    // Editor initial state.
    // Note: `originalDimensions` is not part of this object; it is added to the
    // state once the image has loaded.
    const initialState = {
        scale: 1,
        position: { x: 0, y: 0 },
        drawing: false,       // true between mouse down and mouse up while in DRAWING mode
        lines: [],            // in-progress freehand strokes, each { points: [x0, y0, x1, y1, ...] }
        shapes: [],           // finished annotations rendered as closed lines
        currentMode: 'NONE',  // modes checked in this component: 'PANNING', 'DRAWING', 'INSPECTION', 'DELETE'
        brightness: 0,
        contrast: 0,
        strokeSize: 5,
        selectedLabel: { 'text': '', 'color': '#000000' },
        tooltipOpen: false,
        tooltipContent : '',
        tooltipPosition: { x: 0, y: 0 },
        highlightedShapeId: null,
        toolbarDisabled: true, // switched to false once the image has loaded
    };
    const [state, setState] = useState(initialState);

    // Shallow-merge a partial update into the editor state.
    // Uses the functional form of setState so consecutive calls build on each other.
    const updateState = (newValues) => {
        const mergeIntoPrevious = (prevState) => Object.assign({}, prevState, newValues);
        setState(mergeIntoPrevious);
    };

    // Fit the image into the stage and center it once the image is available.
    // stageSize is read here but deliberately left out of the dependency list
    // (see the eslint-disable below), so this does not re-run on resize.
    useEffect(() => {
        if (image && stageRef.current) {
            // Determine scale to fit the image in the canvas, maintaining aspect ratio
            const scaleX = stageSize.width / image.width;
            const scaleY = stageSize.height / image.height;
            const scaleToFit = Math.min(scaleX, scaleY);

            // Calculate centered position based on scaled image size
            const centeredX = (stageSize.width - (image.width * scaleToFit)) / 2;
            const centeredY = (stageSize.height - (image.height * scaleToFit)) / 2;

            // originalDimensions is read later by the download and magic-assistant
            // functions to render the stage at the image's native size
            updateState({
                scale: scaleToFit,
                position: { x: centeredX, y: centeredY },
                originalDimensions: { width: image.width, height: image.height },
            });

            // Apply transformations to stage for immediate visual update
            stageRef.current.scale({ x: scaleToFit, y: scaleToFit });
            stageRef.current.position({ x: centeredX, y: centeredY });
            stageRef.current.batchDraw();
        }
    }, [image, stageRef]); // eslint-disable-line react-hooks/exhaustive-deps

    // Cache the image node when the image is loaded (for the brightness and contrast filters)
    useEffect(() => {
        if (image) {
            imageRef.current.cache();
            imageRef.current.getLayer().batchDraw();

            // Enable the toolbar
            updateState({ toolbarDisabled: false });
        }
    }, [image]);

    // Size the stage to its container on mount and on every window resize
    useEffect(() => {
        const updateStageSize = () => {
            if (paperRef.current) {
                setStageSize({
                    width: paperRef.current.offsetWidth,
                    height: paperRef.current.offsetHeight,
                });
            }
        };

        // Update stage size initially
        updateStageSize();
        window.addEventListener('resize', updateStageSize);

        // Cleanup the event listener on component unmount
        return () => window.removeEventListener('resize', updateStageSize);
    }, []);

    // Mirror the annotations prop into state.shapes whenever it changes
    // (an empty annotations array clears the shapes)
    useEffect(() => {
        if (annotations.length > 0) {
            updateState({ shapes: annotations });
        }
        else{
            updateState({ shapes: [] });
        }

    }, [annotations]);

    /**
     * Zoom in and out with the mouse wheel, keeping the point under the pointer fixed.
     *
     * Only active in PANNING mode. Scrolling down (deltaY > 0) zooms out by a
     * factor of 0.9, scrolling up zooms in by 1.1. Events with no vertical
     * movement (deltaY === 0, e.g. a purely horizontal scroll) do not zoom.
     *
     * @param {object} e - Konva wheel event; `e.evt` is the native event and
     *   `e.target.getStage()` returns the stage.
     */
    const handleWheel = (e) => {
        if (state.currentMode !== 'PANNING') return; // Only zoom if panning is enabled

        e.evt.preventDefault();

        const { deltaY } = e.evt;
        // No vertical movement means there is no zoom direction; treating it as
        // "zoom in" would make horizontal scrolling change the scale
        if (deltaY === 0) return;

        const stage = e.target.getStage();
        const pointer = stage.getPointerPosition();
        // Without a pointer position there is no anchor to zoom around
        if (!pointer) return;

        const scaleBy = deltaY > 0 ? 0.9 : 1.1;
        const oldScale = stage.scaleX();

        // Point under the pointer, expressed in unscaled stage coordinates
        const mousePointTo = {
            x: (pointer.x - stage.x()) / oldScale,
            y: (pointer.y - stage.y()) / oldScale,
        };

        const newScale = oldScale * scaleBy;

        // Commit scale and position together so no render sees the new scale
        // combined with the old position
        updateState({
            scale: newScale,
            position: {
                x: pointer.x - mousePointTo.x * newScale,
                y: pointer.y - mousePointTo.y * newScale,
            },
        });
    };

    // Mouse down handler: in DRAWING mode, flag that a stroke is in progress and
    // start a new line at the pointer position
    const handleMouseDown = (e) => {
        // Nothing to do outside drawing mode
        if (state.currentMode !== 'DRAWING') {
            return;
        }

        updateState({ drawing: true });

        const stage = stageRef.current.getStage();
        const pointer = stage.getPointerPosition();

        // Undo the stage's translation and zoom so the point is stored in
        // unscaled stage coordinates (the X scale is used for both axes)
        const toStageSpace = (value, offset) => (value - offset) / stage.scaleX();
        const startX = toStageSpace(pointer.x, stage.x());
        const startY = toStageSpace(pointer.y, stage.y());

        const newLine = { points: [startX, startY] };
        updateState({ lines: state.lines.concat([newLine]) });
    };

    /**
     * Mouse move handler: while a stroke is being drawn, append the pointer
     * position to the last line on the next animation frame.
     *
     * The new point is added through a functional state update that copies the
     * last line instead of mutating the state held by React. The update is
     * skipped when the stage is no longer mounted, the pointer position is
     * unavailable, or there is no line to extend.
     *
     * @param {object} e - Konva mouse event (unused; the position is read from the stage).
     */
    const handleMouseMove = (e) => {
        if (!state.drawing || state.currentMode !== 'DRAWING') return;

        requestAnimationFrame(() => {
            // The component may have unmounted before this frame runs
            const stageNode = stageRef.current;
            if (!stageNode) return;

            const stage = stageNode.getStage();
            const point = stage.getPointerPosition();
            if (!point) return;

            // Adjust the point based on the stage's transformation
            // (the X scale is used for both axes)
            const scale = stage.scaleX();
            const x = (point.x - stage.x()) / scale;
            const y = (point.y - stage.y()) / scale;

            setState((prevState) => {
                const lastIndex = prevState.lines.length - 1;
                if (lastIndex < 0) return prevState;

                // Replace the last line with an extended copy; earlier lines are reused as-is
                const lastLine = prevState.lines[lastIndex];
                const extendedLine = { ...lastLine, points: [...lastLine.points, x, y] };

                return {
                    ...prevState,
                    lines: [...prevState.lines.slice(0, lastIndex), extendedLine],
                };
            });
        });
    };

    // Mouse up handler: the current stroke is finished, so clear the drawing flag
    const handleMouseUp = () => {
        const stopDrawing = { drawing: false };
        updateState(stopDrawing);
    };

    /**
     * Finalize the drawing: merge all in-progress strokes into one closed,
     * simplified polygon and store it as a new annotation.
     *
     * The points of every stroke are concatenated, simplified with simplify-js
     * and closed by repeating the first point at the end when needed. The new
     * shape is appended to both state.shapes and the annotations (through
     * setAnnotations), and the strokes are cleared.
     *
     * If the resulting outline has no extent at all (e.g. a single click), no
     * shape is created, because it would render nothing and could not be
     * clicked to delete it; the strokes are still cleared.
     */
    const finalizeShape = () => {
        if (state.lines.length === 0) return;

        // Accumulate all points from the drawn lines in {x, y} format for simplify-js
        const allPoints = state.lines.flatMap(({ points }) => {
            const vertices = [];
            for (let i = 0; i + 1 < points.length; i += 2) {
                vertices.push({ x: points[i], y: points[i + 1] });
            }
            return vertices;
        });

        // Simplify the points
        const tolerance = 1;
        const simplifiedPoints = simplify(allPoints, tolerance, true);

        // Ensure the shape is closed by adding the start point at the end
        const firstPoint = simplifiedPoints[0];
        const lastPoint = simplifiedPoints[simplifiedPoints.length - 1];
        if (simplifiedPoints.length > 1 &&
            (firstPoint.x !== lastPoint.x || firstPoint.y !== lastPoint.y)) {
            simplifiedPoints.push(firstPoint);
        }

        // Bounding box of the outline, computed in one pass (no spread of a
        // potentially long array into Math.min / Math.max)
        let minX = Infinity;
        let minY = Infinity;
        let maxX = -Infinity;
        let maxY = -Infinity;
        for (const { x, y } of simplifiedPoints) {
            if (x < minX) minX = x;
            if (x > maxX) maxX = x;
            if (y < minY) minY = y;
            if (y > maxY) maxY = y;
        }

        // Zero extent in both directions: there is nothing to draw or click
        const hasExtent = maxX > minX || maxY > minY;
        if (!hasExtent) {
            updateState({ lines: [] });
            return;
        }

        // Create the new shape with the simplified points
        const newShape = {
            id: generateId(),
            text: state.selectedLabel.text,
            stroke: state.selectedLabel.color,
            strokeWidth: state.strokeSize,
            fill: hexToRGBA(state.selectedLabel.color, 0.5),
            // Flat [x0, y0, x1, y1, ...] format expected by Konva
            points: simplifiedPoints.flatMap((p) => [p.x, p.y]),
            tags: [],
            notes: '',
            createdAt: new Date().toISOString(),
            author: user.email,
            width: Math.round(maxX - minX),
            height: Math.round(maxY - minY),
            bbox: [minX, minY, maxX, maxY],
        };

        // Add the new shape and clear the strokes in a single state update
        updateState({ shapes: [...state.shapes, newShape], lines: [] });

        // Add the new shape to the annotations array on currentImage
        setAnnotations([...annotations, newShape]);
    };

    // Shape click handler: what a click does depends on the current mode
    const handleShapeClick = (shapeId) => {
        switch (state.currentMode) {
            // Inspection mode: toggle the highlight on the clicked shape and
            // report the selection ('' when the highlight is removed)
            case 'INSPECTION': {
                const wasHighlighted = shapeId === state.highlightedShapeId;
                updateState({ highlightedShapeId: wasHighlighted ? null : shapeId });
                annotationSelected(wasHighlighted ? '' : shapeId);
                break;
            }

            // Delete mode: drop the clicked shape from both the local shapes
            // and the annotations array on currentImage
            case 'DELETE': {
                const isOtherShape = (item) => item.id !== shapeId;
                updateState({
                    shapes: state.shapes.filter(isOtherShape),
                    highlightedShapeId: null,
                });
                setAnnotations(annotations.filter(isOtherShape));
                break;
            }

            default:
                break;
        }
    };

    /**
     * Download the stage (image plus annotations) as 'AnotiaImage.png'.
     *
     * The stage is temporarily resized to the image's original dimensions with
     * scale 1 and no offset, exported to a data URL, and then restored. The
     * restore runs in a `finally` block so the editor view is put back even if
     * the export throws. Does nothing when the stage is not mounted or the
     * image dimensions are not known yet.
     */
    const downloadImage = () => {
        const stage = stageRef.current;
        const originalDimensions = state.originalDimensions;
        if (!stage || !originalDimensions) return;

        // Store the current stage settings
        const oldDimensions = { width: stage.width(), height: stage.height() };
        const oldScale = { x: stage.scaleX(), y: stage.scaleY() };
        const oldPosition = { x: stage.x(), y: stage.y() };

        try {
            // Temporarily adjust the stage to match the original image size
            stage.width(originalDimensions.width);
            stage.height(originalDimensions.height);
            stage.scaleX(1);
            stage.scaleY(1);
            stage.x(0);
            stage.y(0);
            stage.draw(); // Make sure the stage is redrawn with these settings

            // Export the image through a temporary link element
            const dataURL = stage.toDataURL();
            const link = document.createElement('a');
            link.href = dataURL;
            link.download = 'AnotiaImage.png';
            document.body.appendChild(link);
            link.click();
            document.body.removeChild(link);
        } finally {
            // Revert the stage to its original settings
            stage.width(oldDimensions.width);
            stage.height(oldDimensions.height);
            stage.scaleX(oldScale.x);
            stage.scaleY(oldScale.y);
            stage.x(oldPosition.x);
            stage.y(oldPosition.y);
            stage.batchDraw(); // Redraw the stage with the original settings
        }
    };

    /**
     * Magic Assistant: run the loaded model on the current image and add one
     * polygon annotation per detection whose score is above 0.5.
     *
     * Phases:
     *  1. Render the stage at the image's original size and read the image
     *     node's pixels into a tensor (the stage view is always restored).
     *  2. Run the model and sort its output tensors by the structure of their
     *     content (scores, classes, boxes, image info, masks).
     *  3. For each kept detection: resize its mask to the bounding box, trace
     *     the contour, smooth it and build a Konva-ready shape.
     *
     * Every tensor created here (the input and all model outputs) is disposed
     * when the function finishes, whether it succeeds or throws.
     *
     * @returns {Promise<number|undefined>} -1 when no model is loaded, 0 after
     *   the predictions were added, undefined when the stage is not mounted.
     */
    const magicAssistant = async () => {
        // If the model is null, return an error code
        if (!model) {
            return -1;
        }

        const stage = stageRef.current;
        if (!stage) {
            return undefined;
        }

        // Phase 1 helper. Capture the image pixels as an int32 tensor with a leading batch dimension
        const captureInputTensor = () => {
            // Capture the current stage settings
            const oldDimensions = { width: stage.width(), height: stage.height() };
            const oldScale = { x: stage.scaleX(), y: stage.scaleY() };
            const oldPosition = { x: stage.x(), y: stage.y() };

            try {
                // Temporarily adjust the stage to match the original image size
                stage.width(state.originalDimensions.width);
                stage.height(state.originalDimensions.height);
                stage.scaleX(1);
                stage.scaleY(1);
                stage.x(0);
                stage.y(0);
                stage.draw(); // Make sure the stage is redrawn with these settings

                // tidy() disposes the intermediate tensors (pixels, cast) and keeps only the result
                return tf.tidy(() => (
                    tf.browser.fromPixels(imageRef.current.toCanvas())
                        .cast('int32')
                        .expandDims(0)
                ));
            } finally {
                // Revert the stage to its original settings
                stage.width(oldDimensions.width);
                stage.height(oldDimensions.height);
                stage.scaleX(oldScale.x);
                stage.scaleY(oldScale.y);
                stage.x(oldPosition.x);
                stage.y(oldPosition.y);
                stage.batchDraw(); // Redraw the stage with the original settings
            }
        };

        // Moving average over flat [x0, y0, x1, y1, ...] points to smooth a contour
        const smoothPoints = (points, windowSize = 1) => {
            const smoothed = [];
            for (let i = 0; i < points.length; i += 2) {
                let avgX = points[i];
                let avgY = points[i + 1];
                let count = 1;
                for (let j = 2; j <= windowSize * 2; j += 2) {
                    if (i - j >= 0) {
                        avgX += points[i - j];
                        avgY += points[i + 1 - j];
                        count++;
                    }
                    if (i + j < points.length) {
                        avgX += points[i + j];
                        avgY += points[i + 1 + j];
                        count++;
                    }
                }
                smoothed.push(avgX / count, avgY / count);
            }
            return smoothed;
        };

        // Rebuild one detection's mask at bounding-box size and turn its contour
        // into a Konva shape; returns null when the detection cannot be drawn
        const predictionToShape = (prediction) => {
            const { mask, box } = prediction;
            const [ymin, xmin, ymax, xmax] = box;

            // Calculate the dimensions of the bbox (in integer)
            const bboxWidth = Math.round(xmax - xmin);
            const bboxHeight = Math.round(ymax - ymin);

            // A box smaller than one pixel (or with non-finite size) cannot hold a mask
            if (!(bboxWidth >= 1 && bboxHeight >= 1)) {
                return null;
            }

            // Predicted classes are 1-based indices into the labels array
            const label = labels[prediction.class - 1];
            if (!label) {
                console.warn(`Magic assistant: no label for predicted class ${prediction.class}; detection skipped`);
                return null;
            }

            // Convert mask to binary
            const maskThreshold = 0.4;
            const binaryMask = mask.map(row => row.map(value => (value > maskThreshold ? 1 : 0)));

            // Convert the binary mask to an Image, using the mask's own size
            // (masks are accepted from 28 columns upwards)
            const maskImage = new ImageJS(mask[0].length, mask.length, binaryMask.flat(), { kind: 'GREY' });

            // Resize the image to the size of the bounding box
            const resizedImage = maskImage.resize({ width: bboxWidth, height: bboxHeight });

            // Convert the resized image back to a 2D array
            const resizedMask = [];
            for (let i = 0; i < bboxHeight; i++) {
                resizedMask[i] = [];
                for (let j = 0; j < bboxWidth; j++) {
                    resizedMask[i][j] = resizedImage.getPixelXY(j, i)[0];
                }
            }

            // Compute the contours of the resized mask
            const contours = contour2D(pack(resizedMask));

            // Transform contours to Konva points format, offset by the bounding box origin
            let points = [];
            contours.forEach(contour => {
                contour.forEach(([x, y]) => {
                    points.push(xmin + x, ymin + y);
                });
            });

            // Smooth the points
            points = smoothPoints(points);

            return {
                // Random suffix: several shapes are created within the same millisecond
                id: generateId() + Math.random().toString(36).substring(2),
                text: label.text,
                stroke: label.color,
                strokeWidth: state.strokeSize,
                fill: hexToRGBA(label.color, 0.5),
                points: points,
                closed: true,
                tags: [],
                notes: '',
                createdAt: new Date().toISOString(),
            };
        };

        let inputTensor = null;
        let rawOutputs = null;

        try {
            // Phase 1. Capture the image
            inputTensor = captureInputTensor();

            // Phase 2. Run the model; executeAsync may return one tensor or an array of tensors
            rawOutputs = await model.executeAsync(inputTensor);
            const outputs = Array.isArray(rawOutputs) ? rawOutputs : [rawOutputs];

            let scores;
            let classes;
            let masks;
            let boxes;
            let imageInfo;

            // Check the content to assign the right tensor to the right variable.
            // NOTE(review): scores and classes are told apart only by whether the
            // first value is an integer, so a first score of exactly 0 or 1 would
            // be taken for classes - confirm against the model's output signature.
            for (const output of outputs) {
                const tensorContent = await output.array();
                const first = tensorContent[0];

                if (!Array.isArray(first)) {
                    // Scalars / flat vectors (e.g. the number of detections) are not used
                    continue;
                }

                // Array of integers (classes)
                if (Number.isInteger(first[0])) {
                    classes = first;
                }

                // Array of floats (scores)
                if (!Array.isArray(first[0]) && !Number.isInteger(first[0])) {
                    scores = first;
                }

                if (Array.isArray(first[0])) {
                    // Array of arrays of length 4 (boxes)
                    if (first[0].length === 4) {
                        boxes = first;
                    }

                    // Array of arrays of length 2 (imageInfo)
                    if (first[0].length === 2) {
                        imageInfo = first;
                    }

                    // Array of arrays of length 28 or higher (masks)
                    if (first[0].length >= 28) {
                        masks = first;
                    }
                }
            }

            if (!imageInfo || !scores) {
                throw new Error('Magic assistant: the model output does not contain scores and image info');
            }

            // Scale factor to map the boxes back to the original image size
            // (presumably row 2 of imageInfo holds [heightScale, widthScale] - verify against the model)
            const scaleFactor = { height: imageInfo[2][0], width: imageInfo[2][1] };

            // Keep only the predictions with score > 0.5
            const threshold = 0.5;
            const filteredPredictions = [];

            for (let i = 0; i < scores.length; i++) {
                if (scores[i] > threshold) {
                    if (!boxes || !classes || !masks) {
                        throw new Error('Magic assistant: the model output does not contain boxes, classes and masks');
                    }

                    // Scale the points to the real dimensions of the image
                    const xmin = boxes[i][1] / scaleFactor.width;
                    const ymin = boxes[i][0] / scaleFactor.height;
                    const xmax = boxes[i][3] / scaleFactor.width;
                    const ymax = boxes[i][2] / scaleFactor.height;

                    filteredPredictions.push({
                        class: classes[i],
                        score: scores[i],
                        mask: masks[i],
                        box: [ymin, xmin, ymax, xmax],
                    });
                }
            }

            // Phase 3. Turn the predictions into shapes, dropping the ones that cannot be drawn
            const konvaShapes = filteredPredictions
                .map(predictionToShape)
                .filter((shape) => shape !== null);

            // Add the new shapes to the annotations array on currentImage
            setAnnotations([...annotations, ...konvaShapes]);

            // Add the new shapes to the editor state
            updateState({ shapes: [...state.shapes, ...konvaShapes] });

            return 0;
        } finally {
            // Free the input tensor and every model output, even on failure
            if (inputTensor) {
                inputTensor.dispose();
            }
            if (rawOutputs) {
                tf.dispose(rawOutputs);
            }
        }
    };

    // Return the component
    return (
        <Paper
            elevation={3}
            sx={{
                display: 'flex',
                flexDirection:'column',
                flexGrow:1,
                overflow: 'hidden',
                backgroundColor:theme.palette.secondary[100],
                border: "1px solid #ccc",
                borderRadius: '10px'
            }}
        >
            {/* Image Option Bar */}
            <ImageOptionBar
                state={state}
                updateState={updateState}
                downloadImage={downloadImage}
                finalizeShape={finalizeShape}
                labels={labels}
                magicAssistant={magicAssistant}
                modelReady={modelReady}
                annotationSelected={annotationSelected}
            />

            <Box ref={paperRef} sx={{flexGrow:1}}>

                <Stage
                    ref={stageRef}
                    width={stageSize.width}
                    height={stageSize.height}
                    scaleX={state.scale}
                    scaleY={state.scale}
                    x={state.position.x}
                    y={state.position.y}
                    draggable={state.currentMode === 'PANNING'}
                    onWheel={handleWheel}
                    onMouseDown={handleMouseDown}
                    onMousemove={handleMouseMove}
                    onMouseup={handleMouseUp}
                    onDragEnd={() => {updateState({ position: { x: stageRef.current.x(), y: stageRef.current.y() } });}}
                >
                    <ImageLayer image={image} imageRef={imageRef} stageSize={stageSize} brightness={state.brightness} contrast={state.contrast} />

                    <Layer>
                        {state.lines.map((line, i) => (
                            <Line
                                key={i}
                                points={line.points}
                                stroke={state.selectedLabel.color}
                                strokeWidth={state.strokeSize}
                                tension={0.0}
                                lineCap="round"
                                globalCompositeOperation="source-over"
                            />
                        ))}
                    </Layer>

                    <Layer>
                        {state.shapes.map((shape) => (
                            <Line
                                key={shape.id}
                                points={shape.points}
                                stroke={state.currentMode === 'INSPECTION' ? (state.highlightedShapeId === shape.id ? shape.stroke : theme.palette.primary[200]) : shape.stroke}
                                strokeWidth={shape.strokeWidth}
                                closed={true}
                                onClick={() => handleShapeClick(shape.id)}
                                onMouseEnter={(e) => {
                                    if (state.currentMode === 'INSPECTION') {
                                        updateState({ tooltipOpen: true })
                                        updateState({ tooltipContent:
                                            <Typography variant="h1" style={{ fontSize: '1rem' }}>
                                                {shape.text || 'Unknown Pathology'}
                                                {shape.confidence ? `Confidence: ${shape.confidence}%` : ''}
                                            </Typography>
                                        });
                                        updateState({ tooltipPosition: { x: e.evt.clientX, y: e.evt.clientY } })
                                    }
                                }}
                                onMouseLeave={() => {
                                    if (state.currentMode === 'INSPECTION') {
                                        updateState({ tooltipOpen: false });
                                    }
                                }}
                                fill={state.currentMode === 'INSPECTION' ? (state.highlightedShapeId === shape.id ? shape.fill : hexToRGBA(theme.palette.primary[200])) : shape.fill}
                            />
                        ))}
                    </Layer>

                </Stage>

            </Box>

        </Paper>
    );
}

export default ImageEditor;