import json
import os

import cv2

from app.lib.handwriting.io_utils import dump_image, ensure_dir


def load_template(template_path):
    """Load a grid template description from a JSON file.

    Args:
        template_path: Path to the JSON template file.

    Returns:
        The parsed JSON document (whatever ``json.load`` produces for the
        file's top-level value).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of the platform locale so templates
    # containing non-ASCII characters load the same way on every OS.
    with open(template_path, "r", encoding="utf-8") as f:
        return json.load(f)


def extract_glyphs(grid_img, template, out_dir, grid_color=None):
    """Cut a grid image into per-character cells and dump each glyph.

    The image is divided evenly into ``rows`` x ``cols`` cells. For every
    entry of ``template["grid"]["cells"]`` the matching cell is cropped,
    trimmed by a 12% margin on each side, cleaned of grid lines and handed to
    ``dump_image`` three times (raw crop, "nolines" and "binary"; the last
    two receive the same array).

    Args:
        grid_img: Image array of the grid; only ``shape[:2]`` and 2-D slicing
            are used here. It is passed to ``remove_grid_lines`` when
            ``grid_color`` is None.
        template: Mapping with a ``"grid"`` entry holding ``"rows"``,
            ``"cols"`` and ``"cells"``; each cell has ``"char"`` and 1-based
            ``"row"`` / ``"col"`` indices.
        out_dir: Output directory passed to ``dump_image``; a ``glyphs``
            sub-directory is created under it with ``ensure_dir``.
        grid_color: Optional colour image sliced with the same coordinates
            as ``grid_img``. When given, line removal is colour-based.
            NOTE(review): presumably the same size as ``grid_img`` -- confirm
            with the caller.

    Returns:
        None.
    """
    grid = template["grid"]
    rows = grid["rows"]
    cols = grid["cols"]
    cells = grid["cells"]

    img_h, img_w = grid_img.shape[:2]
    step_x = img_w / float(cols)
    step_y = img_h / float(rows)

    ensure_dir(os.path.join(out_dir, "glyphs"))

    # Fraction of the cell trimmed from every side to stay clear of the
    # grid lines that border it.
    pad_ratio = 0.12

    for cell in cells:
        ch = cell["char"]
        # Template indices are 1-based; convert to 0-based grid positions.
        row_idx = cell["row"] - 1
        col_idx = cell["col"] - 1

        left = int(col_idx * step_x)
        top = int(row_idx * step_y)
        right = int((col_idx + 1) * step_x)
        bottom = int((row_idx + 1) * step_y)

        span_x = right - left
        span_y = bottom - top
        margin_x = int(span_x * pad_ratio)
        margin_y = int(span_y * pad_ratio)
        # The same inner window is applied to the grey and the colour crop.
        inner_rows = slice(margin_y, span_y - margin_y)
        inner_cols = slice(margin_x, span_x - margin_x)

        cell_crop = grid_img[top:bottom, left:right]
        trimmed = cell_crop[inner_rows, inner_cols]
        if trimmed.size == 0:
            # Trimming left nothing; fall back to the untrimmed crop.
            trimmed = cell_crop
        dump_image(out_dir, f"glyphs/glyph_{ch}_raw.jpg", trimmed)

        if grid_color is None:
            cleaned = remove_grid_lines(trimmed)
        else:
            color_crop = grid_color[top:bottom, left:right].copy()
            trimmed_color = color_crop[inner_rows, inner_cols]
            if trimmed_color.size == 0:
                trimmed_color = color_crop
            cleaned = remove_grid_lines_color(trimmed_color)

        # NOTE(review): both outputs are the same binarized array; only the
        # file name/extension differs.
        dump_image(out_dir, f"glyphs/glyph_{ch}_nolines.jpg", cleaned)
        dump_image(out_dir, f"glyphs/glyph_{ch}_binary.png", cleaned)


def remove_grid_lines(cell_img):
    """Binarize a cell and erase long horizontal/vertical strokes.

    The cell is thresholded with an inverted adaptive-mean threshold
    (block size 25, constant 12). Morphological opening with a wide 1-pixel
    high kernel and a tall 1-pixel wide kernel isolates long straight runs,
    which are then subtracted from the thresholded image.

    Args:
        cell_img: Cell image array.
            NOTE(review): cv2.adaptiveThreshold expects a single-channel
            8-bit image -- confirm callers pass grayscale.

    Returns:
        The thresholded image with the detected line pixels removed.
    """
    height, width = cell_img.shape[:2]

    binary = cv2.adaptiveThreshold(
        cell_img,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY_INV,
        25,
        12,
    )

    # A run only counts as a grid line if it covers about two thirds of the
    # cell, and never fewer than 30 pixels.
    horiz_len = max(30, int(width / 1.5))
    vert_len = max(30, int(height / 1.5))
    kernels = (
        cv2.getStructuringElement(cv2.MORPH_RECT, (horiz_len, 1)),
        cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_len)),
    )

    horiz_lines, vert_lines = [
        cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel) for kernel in kernels
    ]

    return cv2.subtract(binary, cv2.bitwise_or(horiz_lines, vert_lines))


def remove_grid_lines_color(cell_bgr):
    """Whiten pale-blue grid pixels by colour, then binarize the cell.

    A pixel is treated as grid line when its blue channel exceeds red by
    more than 8, exceeds green by more than 4, and is itself above 150.
    Those pixels are painted white on a copy of the input, which is then
    converted to grayscale and thresholded (inverted binary, Otsu).

    Args:
        cell_bgr: Three-channel cell image in BGR channel order. The input
            array is not modified.

    Returns:
        The binarized single-channel image produced by ``cv2.threshold``.
    """
    blue, green, red = cv2.split(cell_bgr)

    # Widen to int before subtracting so channel differences can go negative
    # instead of wrapping around in a narrow unsigned dtype.
    blue_wide = blue.astype(int)
    bluer_than_red = blue_wide - red.astype(int) > 8
    bluer_than_green = blue_wide - green.astype(int) > 4
    bright_blue = blue > 150
    grid_pixels = bluer_than_red & bluer_than_green & bright_blue

    whitened = cell_bgr.copy()
    whitened[grid_pixels] = (255, 255, 255)

    grayscale = cv2.cvtColor(whitened, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )[1]
    return binary
