import math

import cv2
import numpy as np

from app.lib.handwriting.config import DEFAULT_ARUCO_DICT, DEFAULT_ARUCO_IDS
from app.lib.handwriting.errors import PipelineInputError
from app.lib.handwriting.io_utils import dump_image


def _dist(a, b):
    return math.hypot(b[0] - a[0], b[1] - a[1])


def get_aruco_dictionary(dict_name):
    """Return the predefined OpenCV ArUco dictionary called ``dict_name``.

    The ``DICT_`` prefix is optional: ``"4X4_50"`` and ``"DICT_4X4_50"`` resolve
    to the same ``cv2.aruco`` constant.

    Raises:
        ValueError: If ``cv2.aruco`` has no attribute with the resolved name.
    """
    full_name = dict_name if dict_name.startswith("DICT_") else f"DICT_{dict_name}"
    try:
        dict_id = getattr(cv2.aruco, full_name)
    except AttributeError:
        raise ValueError(f"Unknown ArUco dictionary: {full_name}") from None
    return cv2.aruco.getPredefinedDictionary(dict_id)


def _detect_markers(gray, dictionary):
    """Run ArUco detection with whichever OpenCV ArUco API is installed.

    Returns the raw ``(corners, ids)`` pair produced by ``detectMarkers``;
    ``ids`` is ``None`` when nothing was detected.
    """
    if hasattr(cv2.aruco, "ArucoDetector"):
        # Object-oriented detector API (newer OpenCV releases).
        detector = cv2.aruco.ArucoDetector(dictionary, cv2.aruco.DetectorParameters())
        corners, ids, _rejected = detector.detectMarkers(gray)
    else:
        # Legacy module-level API.
        params = cv2.aruco.DetectorParameters_create()
        corners, ids, _rejected = cv2.aruco.detectMarkers(
            gray, dictionary, parameters=params
        )
    return corners, ids


def _quadrant_of(cx, cy, w, h):
    """Return the image quadrant ("tl", "tr", "br" or "bl") containing (cx, cy)."""
    top = cy < h / 2
    if cx < w / 2:
        return "tl" if top else "bl"
    return "tr" if top else "br"


def _assign_quadrants(entries, expected_ids, w, h):
    """Map marker entries to page corners, giving expected IDs priority.

    Markers whose ID appears in ``expected_ids`` go to their fixed corner. Any
    other marker is placed by image quadrant, but only into a corner that no
    expected-ID marker has claimed.
    """
    id_to_corner = dict(zip(expected_ids, ("tl", "tr", "br", "bl")))
    quadrants = {"tl": None, "tr": None, "br": None, "bl": None}

    claimed = set()
    for entry in entries:
        corner = id_to_corner.get(entry["id"])
        if corner is not None:
            quadrants[corner] = entry
            claimed.add(corner)

    for entry in entries:
        if entry["id"] in id_to_corner:
            continue
        cx, cy = entry["center"]
        corner = _quadrant_of(cx, cy, w, h)
        # Previously a stray marker could overwrite a correctly identified one
        # depending on detection order; now it may only fill an unclaimed corner.
        if corner not in claimed:
            quadrants[corner] = entry
    return quadrants


def find_aruco_markers(
    image,
    dict_name=DEFAULT_ARUCO_DICT,
    expected_ids=DEFAULT_ARUCO_IDS,
    verbose=False,
    out_dir=None,
):
    """Detect the four page-corner ArUco markers and return them TL, TR, BR, BL.

    Args:
        image: Input photo. A 3-channel BGR image is converted to grayscale;
            a 2-D (single-channel) image is used as-is.
        dict_name: ArUco dictionary name, with or without the ``DICT_`` prefix.
        expected_ids: Marker IDs for the TL, TR, BR and BL corners, in that
            order. Detected markers with other IDs are assigned by image
            quadrant to any corner not already claimed by an expected ID.
        verbose: When true, write the debug overlay even if ``out_dir`` is unset
            (it then goes to the current directory).
        out_dir: Directory for the debug overlay ``a4_aruco_markers.jpg``.

    Returns:
        A list of four dicts with keys ``"id"``, ``"corners"`` (the marker's
        corner array) and ``"center"`` (mean x, mean y), ordered TL, TR, BR, BL.

    Raises:
        PipelineInputError: If no markers are detected, or a corner stays empty.
        ValueError: If ``dict_name`` is not a known ArUco dictionary.
    """
    gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    h, w = gray.shape

    dictionary = get_aruco_dictionary(dict_name)
    corners, raw_ids = _detect_markers(gray, dictionary)

    if raw_ids is None or len(raw_ids) == 0:
        raise PipelineInputError(
            "Could not detect the markers. Make sure all four markers are visible "
            "and the photo is clear."
        )

    ids = raw_ids.flatten().tolist()
    entries = []
    for marker_corners, marker_id in zip(corners, ids):
        pts = marker_corners[0]
        center = (float(np.mean(pts[:, 0])), float(np.mean(pts[:, 1])))
        entries.append({"id": marker_id, "corners": pts, "center": center})

    quadrants = _assign_quadrants(entries, expected_ids, w, h)
    if any(v is None for v in quadrants.values()):
        raise PipelineInputError(
            "Could not detect all four markers. Make sure the full template is visible."
        )

    if verbose or out_dir:
        debug = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        # Reuse the detector's own id array rather than rebuilding one from the list.
        cv2.aruco.drawDetectedMarkers(debug, corners, raw_ids)
        for name, info in quadrants.items():
            cx, cy = info["center"]
            cv2.putText(
                debug,
                name.upper(),
                (int(cx), int(cy)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 0, 255),
                2,
            )
        dump_image(out_dir or ".", "a4_aruco_markers.jpg", debug)

    return [quadrants[corner] for corner in ("tl", "tr", "br", "bl")]


def warp_perspective_using_markers(gray_img, markers, padding_ratio=0.03):
    """Warp ``gray_img`` so the four marker centres become an upright rectangle.

    Args:
        gray_img: Image to warp; it is passed straight to ``cv2.warpPerspective``.
        markers: Exactly four dicts ordered TL, TR, BR, BL, each with a
            ``"center"`` (x, y) pair.
        padding_ratio: Fraction of the rectangle's width/height kept as a margin
            on each side.

    Returns:
        The warped image. It is the longer of each pair of opposite
        centre-to-centre distances (truncated to int) plus twice the margin in
        each dimension.
    """
    corners = [tuple(float(v) for v in m["center"]) for m in markers]
    top_left, top_right, bottom_right, bottom_left = corners
    src = np.array(
        [top_left, top_right, bottom_right, bottom_left], dtype="float32"
    )

    # Use the longer of each pair of opposite edges as the target size.
    span_w = int(
        max(_dist(bottom_left, bottom_right), _dist(top_left, top_right))
    )
    span_h = int(
        max(_dist(top_left, bottom_left), _dist(top_right, bottom_right))
    )

    margin_x = int(span_w * padding_ratio)
    margin_y = int(span_h * padding_ratio)
    right = margin_x + span_w - 1
    bottom = margin_y + span_h - 1

    dst = np.array(
        [
            [margin_x, margin_y],
            [right, margin_y],
            [right, bottom],
            [margin_x, bottom],
        ],
        dtype="float32",
    )

    transform = cv2.getPerspectiveTransform(src, dst)
    output_size = (span_w + 2 * margin_x, span_h + 2 * margin_y)
    return cv2.warpPerspective(gray_img, transform, output_size)


def rectify_page(gray_img, out_dir):
    """Locate the corner markers on a grayscale page photo and deskew the page.

    The grayscale image is expanded to BGR for marker detection. ``out_dir`` is
    forwarded to ``find_aruco_markers``, which may write its debug overlay there.
    The rectified page is dumped as ``a5_rectified_page.jpg``.

    Raises:
        PipelineInputError: If four markers are not found.
    """
    bgr = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2BGR)
    markers = find_aruco_markers(
        bgr,
        dict_name=DEFAULT_ARUCO_DICT,
        expected_ids=DEFAULT_ARUCO_IDS,
        verbose=False,
        out_dir=out_dir,
    )
    # Defensive: find_aruco_markers should already return exactly four markers.
    if len(markers or ()) != 4:
        raise PipelineInputError(
            "Could not detect all four markers. Make sure the full template is visible."
        )

    rectified = warp_perspective_using_markers(gray_img, markers, padding_ratio=0.03)
    dump_image(out_dir, "a5_rectified_page.jpg", rectified)
    return rectified


def rectify_page_color(color_img, out_dir):
    """Locate the corner markers on a BGR page photo and deskew it in colour.

    ``out_dir`` is accepted for signature parity with ``rectify_page`` but is
    not used: no debug images are written by this function.

    Raises:
        PipelineInputError: If four markers are not found.
    """
    markers = find_aruco_markers(
        color_img,
        dict_name=DEFAULT_ARUCO_DICT,
        expected_ids=DEFAULT_ARUCO_IDS,
        verbose=False,
    )
    # Defensive: find_aruco_markers should already return exactly four markers.
    if len(markers or ()) != 4:
        raise PipelineInputError(
            "Could not detect all four markers. Make sure the full template is visible."
        )
    return warp_perspective_using_markers(color_img, markers, padding_ratio=0.03)
