import cv2
import numpy as np

from app.lib.handwriting.errors import PipelineInputError
from app.lib.handwriting.io_utils import dump_image


def _percentile_bounds(indices):
    """Return a robust ``(low, high)`` span of the values in ``indices``.

    Returns None when ``indices`` is empty. With fewer than 20 values the exact
    minimum and maximum are returned. Otherwise the sorted values at positions
    ``int(n * 0.05)`` and ``int(n * 0.95)`` are used, so a few stray outliers do
    not stretch the span.
    """
    count = indices.size
    if not count:
        return None

    ordered = np.sort(indices)
    if count < 20:
        lo_pos, hi_pos = 0, count - 1
    else:
        lo_pos, hi_pos = int(count * 0.05), int(count * 0.95)
    return int(ordered[lo_pos]), int(ordered[hi_pos])


def _bounds_from_projection(line_mask, axis_size, is_rows=True):
    """Estimate the ``(start, end)`` extent of line pixels along one image axis.

    ``line_mask`` is collapsed into a profile of non-zero pixel counts per row
    (``is_rows=True``, summing along axis 1) or per column (``is_rows=False``,
    summing along axis 0). Positions whose count exceeds ``max(10, 10% of the
    peak count)`` are kept. If none qualify, any position with a non-zero count
    is accepted instead. The kept positions are reduced to a span with
    ``_percentile_bounds``.

    Falls back to the full span ``(0, axis_size - 1)`` when the profile is empty
    or contains no line pixels at all.
    """
    full_span = (0, axis_size - 1)
    profile = (line_mask > 0).sum(axis=1 if is_rows else 0)
    if profile.size == 0:
        return full_span

    # Counts are never negative, so the peak is simply the maximum.
    peak = float(profile.max())
    cutoff = max(10.0, peak * 0.10)

    hits = np.where(profile > cutoff)[0]
    if hits.size == 0:
        # No strong lines: accept any position that has at least one line pixel.
        hits = np.where(profile > 0)[0]
        if hits.size == 0:
            return full_span

    span = _percentile_bounds(hits)
    return span if span is not None else (int(hits[0]), int(hits[-1]))


_GRID_NOT_FOUND_MESSAGE = (
    "Could not detect the grid. Please upload a clear photo of the full template "
    "with visible grid lines and all four markers."
)

# Padding (in pixels) added on every side of the detected grid before cropping.
_ROI_PAD = 3


def _extract_line_masks(rectified_gray):
    """Return ``(horiz, vert)`` masks that keep only long horizontal / vertical runs.

    The image is binarised with an inverted mean adaptive threshold, so pixels
    darker than their local mean (minus 15) become 255. The result is then
    morphologically opened with a 1-pixel-thick rectangular kernel, about 1/20 of
    the image size and at least 25 px long. Only runs at least that long survive.
    """
    h, w = rectified_gray.shape[:2]
    thr = cv2.adaptiveThreshold(
        rectified_gray,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY_INV,
        31,
        15,
    )
    horiz_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(25, w // 20), 1))
    vert_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(25, h // 20)))
    # MORPH_OPEN == one erode followed by one dilate with the same kernel. The
    # previous `cv2.erode(thr, k, 1)` passed 1 as `dst`, not as `iterations`.
    horiz = cv2.morphologyEx(thr, cv2.MORPH_OPEN, horiz_kernel)
    vert = cv2.morphologyEx(thr, cv2.MORPH_OPEN, vert_kernel)
    return horiz, vert


def _largest_grid_contour_rect(grid_mask, h, w):
    """Return ``(x, y, w, h)`` of the largest plausible grid contour, or None.

    The mask is first closed so that a grid with faint or broken lines still forms
    a single external contour. Contours smaller than 5% of the page area are
    ignored, and so are contours larger than 98% (presumably the page border —
    confirm).
    """
    close_kernel = cv2.getStructuringElement(
        cv2.MORPH_RECT,
        (max(15, w // 80), max(15, h // 80)),
    )
    closed = cv2.morphologyEx(grid_mask, cv2.MORPH_CLOSE, close_kernel)
    contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    page_area = float(h * w)
    min_area, max_area = page_area * 0.05, page_area * 0.98
    candidates = [c for c in contours if min_area <= cv2.contourArea(c) <= max_area]
    if not candidates:
        return None
    return cv2.boundingRect(max(candidates, key=cv2.contourArea))


def _projection_grid_rect(horiz, vert, h, w):
    """Return ``(x, y, w, h)`` derived from row/column projections of the line masks.

    Raises:
        PipelineInputError: if the projections do not give a non-degenerate box.
    """
    top, bottom = _bounds_from_projection(horiz, h, is_rows=True)
    left, right = _bounds_from_projection(vert, w, is_rows=False)
    if bottom <= top or right <= left:
        raise PipelineInputError(_GRID_NOT_FOUND_MESSAGE)
    return left, top, right - left + 1, bottom - top + 1


def _snap_rows_to_strong_lines(horiz, gy, gh):
    """Move the top/bottom edges onto the first/last strong horizontal line.

    A row counts as a strong line when its line-pixel count exceeds
    ``max(10, 35% of the strongest row)``. The top edge moves to 2 px above the
    first strong line if that line lies more than 5 px above the current bottom.
    The bottom edge moves to 2 px below the last strong line if that line lies
    more than 5 px below the (possibly updated) top. Either edge may move inward
    or outward. An edge that is not moved keeps its original position.

    Returns:
        The updated ``(gy, gh)``, with ``gh >= 1``.
    """
    row_counts = np.sum(horiz > 0, axis=1)
    if row_counts.size == 0:
        return gy, gh
    line_thresh = max(10.0, float(row_counts.max()) * 0.35)
    rows = np.flatnonzero(row_counts > line_thresh)
    if rows.size == 0:
        return gy, gh

    first_line, last_line = int(rows[0]), int(rows[-1])
    # Track the exclusive bottom edge explicitly. Moving `gy` alone previously
    # left `gh` stale, which shifted the bottom edge by the same amount.
    bottom = gy + gh
    if first_line + 5 < bottom:
        gy = max(first_line - 2, 0)
    if last_line - 5 > gy:
        bottom = last_line + 2
    return gy, max(bottom - gy, 1)


def detect_grid_roi(rectified_gray, out_dir):
    """Locate the handwriting grid in a rectified page image and crop it.

    The outer contour of the horizontal/vertical line mask is tried first. If no
    contour of plausible size exists, row/column projections of the line masks
    are used instead. The top/bottom edges are then snapped to strong horizontal
    lines, and the box is padded by ``_ROI_PAD`` pixels and clamped to the image.

    Args:
        rectified_gray: Grayscale page image. ``cv2.adaptiveThreshold`` requires
            an 8-bit single-channel image.
        out_dir: Passed to ``dump_image`` together with the intermediate
            ``a6_grid_mask.jpg`` and final ``a7_grid_cropped.jpg`` images.

    Returns:
        ``(grid_crop, (x, y, w, h))``. ``grid_crop`` is a NumPy slice (view) of
        ``rectified_gray``, and the tuple is the crop rectangle in page pixels.

    Raises:
        PipelineInputError: if the image is None/empty or no grid can be found.
    """
    if rectified_gray is None or rectified_gray.size == 0:
        raise PipelineInputError(_GRID_NOT_FOUND_MESSAGE)
    h, w = rectified_gray.shape[:2]

    horiz, vert = _extract_line_masks(rectified_gray)
    grid_mask = cv2.bitwise_or(horiz, vert)
    dump_image(out_dir, "a6_grid_mask.jpg", grid_mask)

    # Contours are preferred; projections are only the fallback.
    rect = _largest_grid_contour_rect(grid_mask, h, w)
    if rect is None:
        rect = _projection_grid_rect(horiz, vert, h, w)
    gx, gy, gw, gh = rect

    # Snap top/bottom to strong horizontal grid lines to drop header/footer.
    gy, gh = _snap_rows_to_strong_lines(horiz, gy, gh)

    # Pad each edge by _ROI_PAD and clamp it independently, so clamping at the
    # left/top border no longer inflates the right/bottom padding.
    x0 = max(gx - _ROI_PAD, 0)
    y0 = max(gy - _ROI_PAD, 0)
    x1 = min(gx + gw + _ROI_PAD, w)
    y1 = min(gy + gh + _ROI_PAD, h)

    grid_crop = rectified_gray[y0:y1, x0:x1]
    dump_image(out_dir, "a7_grid_cropped.jpg", grid_crop)

    return grid_crop, (x0, y0, x1 - x0, y1 - y0)
