import glob
import logging
import os
import subprocess

import cv2
import numpy as np

from app.lib.handwriting.config import CLEAN_VECTOR_DIR, MIN_DOT_AREA, VECT_GLYPH_DIR
from app.lib.handwriting.io_utils import ensure_dir


def _pbm_black_count(pbm_path):
    """Count the pure-black pixels of the image stored at ``pbm_path``.

    The file is decoded as a single-channel grayscale image and a pixel is
    counted only when its value is exactly 0.

    Returns:
        The number of zero-valued pixels as an ``int``, or ``None`` when
        OpenCV could not read the file.
    """
    pixels = cv2.imread(pbm_path, cv2.IMREAD_GRAYSCALE)
    if pixels is not None:
        return int(np.count_nonzero(pixels == 0))
    return None


def _svg_has_path(svg_path):
    """Report whether the file at ``svg_path`` contains a ``<path`` element.

    The file is decoded as UTF-8 with undecodable bytes replaced, so the
    answer does not depend on the platform's default encoding and a stray
    byte cannot make a populated SVG look empty.

    Args:
        svg_path: Path of the SVG file to inspect.

    Returns:
        ``True`` if the text ``<path`` occurs anywhere in the file, ``False``
        if it does not or if the file cannot be opened or read.
    """
    try:
        with open(svg_path, "r", encoding="utf-8", errors="replace") as f:
            # "<path" contains no newline, so scanning line by line finds the
            # same matches as scanning the whole text, and stops early.
            return any("<path" in line for line in f)
    except (OSError, ValueError):
        # Missing/unreadable file, or a path the OS rejects: treat as "empty".
        return False


def ensure_binary_polarity(binary):
    """Invert ``binary`` when it is mostly bright.

    If the mean pixel value is above 127 the bitwise complement is returned;
    otherwise the input array itself is returned unchanged.
    """
    mostly_bright = np.mean(binary) > 127
    return cv2.bitwise_not(binary) if mostly_bright else binary


def remove_small_noise(binary):
    """Drop connected components whose area is below ``MIN_DOT_AREA``.

    Args:
        binary: Single-channel image passed to OpenCV's connected-component
            labelling.

    Returns:
        A new array with the shape and dtype of ``binary`` in which pixels
        belonging to a kept component are 255 and all other pixels are 0.
    """
    # Connectivity is named explicitly: per the documented Python signature
    # the second positional slot is the optional ``labels`` output array.
    # 8 is also OpenCV's default, so the labelling itself is unchanged.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # One boolean per label, then a single lookup through the label image,
    # instead of building a full-size ``labels == i`` mask for every component.
    keep = stats[:, cv2.CC_STAT_AREA] >= MIN_DOT_AREA
    keep[0] = False  # label 0 is the background and is never kept

    out = np.zeros_like(binary)
    out[keep[labels]] = 255
    return out


def safe_crop(binary):
    """Trim ``binary`` to the bounding rectangle of its non-zero pixels.

    When the image has no non-zero pixel, ``binary`` is returned as is;
    otherwise the result is the corresponding slice of ``binary``.
    """
    nonzero_points = cv2.findNonZero(binary)
    if nonzero_points is None:
        return binary

    left, top, width, height = cv2.boundingRect(nonzero_points)
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return binary[rows, cols]


def preprocess_vector(binary, *, small_component_area=80, thick_stroke_width=6.0):
    """Clean a glyph mask before vectorization.

    Steps, in order:

    1. Normalize polarity with ``ensure_binary_polarity``.
    2. Split the 8-connected components into "protected" ones (area below
       ``small_component_area``) and "main strokes" (everything else).
    3. Estimate stroke thickness as twice the mean L2 distance-transform
       value over main-stroke pixels; if it exceeds ``thick_stroke_width``,
       erode the main strokes once with a 2x2 elliptical kernel.
    4. Apply one morphological close (same kernel) to the main strokes.
    5. OR the untouched protected components back in.

    Args:
        binary: Single-channel glyph image.
        small_component_area: Components with fewer pixels than this bypass
            the erode/close steps and are copied through unchanged.
        thick_stroke_width: Thickness estimate above which the main strokes
            are eroded once.

    Returns:
        The cleaned mask, with 255 on kept pixels and 0 elsewhere.
    """
    binary = ensure_binary_polarity(binary)

    # Connectivity is named explicitly: per the documented Python signature
    # the second positional slot is the optional ``labels`` output array.
    # 8 is also OpenCV's default, so the labelling itself is unchanged.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Classify every label once, then map the label image through the two
    # lookup tables instead of building a ``labels == i`` mask per component.
    is_small = stats[:, cv2.CC_STAT_AREA] < small_component_area
    is_large = ~is_small
    is_small[0] = False  # label 0 is the background: belongs to neither set
    is_large[0] = False

    protected = np.zeros_like(binary)
    protected[is_small[labels]] = 255
    main_strokes = np.zeros_like(binary)
    main_strokes[is_large[labels]] = 255

    if np.count_nonzero(main_strokes) > 0:
        dist = cv2.distanceTransform(main_strokes, cv2.DIST_L2, 5)
        mean_stroke = np.mean(dist[main_strokes > 0]) * 2
    else:
        # No main strokes: nothing to measure, and no erosion is wanted.
        mean_stroke = 0

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))

    if mean_stroke > thick_stroke_width:
        main_strokes = cv2.erode(main_strokes, kernel, iterations=1)

    main_strokes = cv2.morphologyEx(
        main_strokes, cv2.MORPH_CLOSE, kernel, iterations=1
    )

    return cv2.bitwise_or(main_strokes, protected)


def clean_extracted_glyphs(out_dir):
    """Run ``preprocess_vector`` over every extracted glyph binary.

    Reads each ``*_binary.png`` in ``<out_dir>/glyphs`` as grayscale, cleans
    it, and writes the result under the same file name into
    ``<out_dir>/<CLEAN_VECTOR_DIR>``. Files are processed in sorted order so
    runs are deterministic.

    Files that cannot be read are skipped and failed writes do not abort the
    run, but both are logged as warnings rather than passing silently.

    Args:
        out_dir: Root output directory containing the ``glyphs`` folder.

    Raises:
        RuntimeError: If no ``*_binary.png`` file is found.
    """
    src_dir = os.path.join(out_dir, "glyphs")
    out_vector = os.path.join(out_dir, CLEAN_VECTOR_DIR)
    ensure_dir(out_vector)

    files = sorted(glob.glob(os.path.join(src_dir, "*_binary.png")))
    if not files:
        raise RuntimeError("No glyph binaries found to clean.")

    log = logging.getLogger(__name__)
    for path in files:
        binary = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if binary is None:
            log.warning("Skipping unreadable glyph image: %s", path)
            continue

        target = os.path.join(out_vector, os.path.basename(path))
        # imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(target, preprocess_vector(binary)):
            log.warning("Failed to write cleaned glyph: %s", target)


def png_to_pbm_for_potrace(png_path, pbm_path):
    """Write an inverted, thresholded copy of ``png_path`` to ``pbm_path``.

    The PNG is read as grayscale. Pixels above 127 become 0 and all other
    pixels become 255, i.e. the image is binarized and inverted in one step.

    Args:
        png_path: Source image path.
        pbm_path: Destination path handed to ``cv2.imwrite``.

    Returns:
        ``True`` only if the source was read and the output was written;
        ``False`` if the source is unreadable or the write failed.
    """
    img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False

    # THRESH_BINARY_INV is the same as THRESH_BINARY followed by bitwise_not.
    _, bw_inv = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)

    # Propagate the write result so callers do not run potrace on a file
    # that was never produced.
    return bool(cv2.imwrite(pbm_path, bw_inv))


def _turdsize_for(black_count):
    """Pick potrace's ``--turdsize`` from a glyph's black-pixel count."""
    # Keep tiny punctuation dots; filter noise more aggressively on large
    # glyphs. An unknown count falls back to the large-glyph setting.
    if black_count is None:
        return 20
    if black_count < 80:
        return 2
    if black_count < 200:
        return 5
    return 20


def _run_potrace(pbm_path, svg_path, turdsize):
    """Trace ``pbm_path`` into ``svg_path`` using the given ``--turdsize``."""
    cmd = [
        "potrace",
        pbm_path,
        "-s",
        "-o",
        svg_path,
        "--opttolerance",
        "0.80",
        "--alphamax",
        "1.0",
        "--turdsize",
        str(turdsize),
        "--turnpolicy",
        "black",
    ]
    subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)


def vectorize_glyphs_with_potrace(out_dir):
    """Convert every cleaned glyph PNG to SVG by running ``potrace``.

    For each ``*_binary.png`` in ``<out_dir>/<CLEAN_VECTOR_DIR>`` (sorted),
    a ``.pbm`` and a ``.svg`` with the same stem are written into
    ``<out_dir>/<VECT_GLYPH_DIR>``. The speckle threshold is chosen from the
    number of black pixels in the PBM; if the resulting SVG contains no
    ``<path`` element, potrace is run again with ``--turdsize 1``.

    Glyphs whose PBM cannot be produced are skipped with a warning.

    Args:
        out_dir: Root output directory containing the cleaned glyphs.

    Raises:
        RuntimeError: If no cleaned glyph image is found.
        subprocess.CalledProcessError: If ``potrace`` exits with a non-zero
            status.
    """
    src_dir = os.path.join(out_dir, CLEAN_VECTOR_DIR)
    out_dir_svg = os.path.join(out_dir, VECT_GLYPH_DIR)
    ensure_dir(out_dir_svg)

    files = sorted(glob.glob(os.path.join(src_dir, "*_binary.png")))
    if not files:
        raise RuntimeError("No cleaned glyphs found to vectorize.")

    log = logging.getLogger(__name__)
    for png_path in files:
        stem = os.path.splitext(os.path.basename(png_path))[0]
        pbm_path = os.path.join(out_dir_svg, f"{stem}.pbm")
        svg_path = os.path.join(out_dir_svg, f"{stem}.svg")

        if not png_to_pbm_for_potrace(png_path, pbm_path):
            log.warning("Skipping glyph, could not create PBM from: %s", png_path)
            continue

        _run_potrace(pbm_path, svg_path, _turdsize_for(_pbm_black_count(pbm_path)))

        # If SVG is empty, retry with a minimal turdsize to preserve tiny marks.
        if not _svg_has_path(svg_path):
            _run_potrace(pbm_path, svg_path, 1)
            if not _svg_has_path(svg_path):
                log.warning("potrace produced no path for: %s", png_path)
