import traceback

from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.responses import JSONResponse
from starlette import status

from app.middleware.request_id import REQUEST_ID_HEADER
from app.services.notifications.slack_notifier import notify_error
from app.utils.env import is_dev_env, is_prod_env, is_staging_env
from app.utils.logger import get_logger
from app.utils.response import error_payload

# Module-level logger, named after this module.
logger = get_logger(__name__)


def _get_request_id(request: Request) -> str | None:
    """Return the correlation id for *request*, or None if none is known.

    A truthy ``request.state.request_id`` wins. Otherwise the value of the
    ``REQUEST_ID_HEADER`` header on the incoming request is returned (which
    may be None when the header is absent).
    """
    state_id = getattr(request.state, "request_id", None)
    return state_id if state_id else request.headers.get(REQUEST_ID_HEADER)


def _with_request_id_header(request: Request, response: JSONResponse) -> JSONResponse:
    """Copy the request's correlation id onto *response* and return it.

    The header is only set when ``_get_request_id`` yields a truthy id;
    the same response object is returned either way (mutated in place).
    """
    rid = _get_request_id(request)
    if not rid:
        return response
    response.headers[REQUEST_ID_HEADER] = rid
    return response


async def _notify_slack(
    request: Request,
    message: str,
    status_code: int,
    error: Exception | None = None,
) -> None:
    """Report a failed request through ``notify_error`` (best-effort).

    Builds a context dict describing the request and, if given, the error,
    then awaits ``notify_error`` with it. Any exception raised while
    delivering the notification is logged and swallowed. An exception
    handler calls this right before returning its response, and a broken
    notifier must not replace that response with a framework-level failure.

    Args:
        request: The request whose handling failed.
        message: Fallback title/message, used when *error* is None or its
            string form is empty.
        status_code: HTTP status code of the response being returned.
        error: The exception being handled, if any.
    """
    if error is not None:
        error_name = type(error).__name__
        error_message = str(error)
        # Format the traceback carried by *error* itself. traceback.format_exc()
        # reads the ambient sys.exc_info(), which only matches *error* while
        # we happen to be inside the framework's `except` block.
        error_stack = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
    else:
        error_name = None
        error_message = None
        error_stack = None
    slack_message = error_message or message
    context = {
        "title": slack_message,
        "message": slack_message,
        "status_code": status_code,
        "method": request.method,
        "path": request.url.path,
        "request_id": _get_request_id(request),
        "user_agent": request.headers.get("user-agent"),
        "client_ip": request.client.host if request.client else None,
        "error_name": error_name,
        "error_message": error_message,
        "error_stack": error_stack,
    }
    try:
        await notify_error(context)
    except Exception:
        # Deliberately best-effort: keep the caller's error response intact.
        logger.exception("Failed to send error notification")


async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Render an ``HTTPException`` as the project's JSON error payload.

    - The message is ``str(exc.detail)``. For 5xx statuses in prod or
      staging it is replaced by the generic "Internal Server Error".
    - 5xx errors are also reported via ``_notify_slack``.
    - Headers attached to the exception (``exc.headers``, e.g.
      ``WWW-Authenticate`` on 401 or ``Retry-After`` on 429) are forwarded
      on the response.
    - The request id header is added when one is known.
    """
    is_server_error = exc.status_code >= 500
    message = str(exc.detail)
    if (is_prod_env() or is_staging_env()) and is_server_error:
        message = "Internal Server Error"
    if is_server_error:
        # Skip logging here to avoid duplicate logs; service layer logs with stack trace.
        await _notify_slack(request, message, exc.status_code, exc)
    payload = error_payload(message)
    # Forward headers carried by the exception, as FastAPI's default handler
    # does; dropping them breaks e.g. auth challenges on 401 responses.
    response = JSONResponse(
        status_code=exc.status_code,
        content=payload,
        headers=getattr(exc, "headers", None),
    )
    return _with_request_id_header(request, response)


async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler: log, notify, and return a generic 500 JSON response.

    The client always sees "Internal Server Error". Only in the dev
    environment is the formatted stack trace of *exc* included in the
    payload's details (under ``"stack"``).
    """
    error_details = None
    if is_dev_env():
        # Use exc's own traceback rather than traceback.format_exc(), so the
        # stack shown always belongs to the exception being handled.
        error_details = {
            "stack": "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )
        }
    logger.exception("Unhandled exception: %s", exc)
    await _notify_slack(
        request, "Internal Server Error", status.HTTP_500_INTERNAL_SERVER_ERROR, exc
    )
    payload = error_payload("Internal Server Error", error_details)
    response = JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=payload
    )
    return _with_request_id_header(request, response)


async def request_validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Render request validation failures as a 422 JSON error payload.

    The payload message is "Validation Error" and its details hold the
    validation errors under ``"errors"``. The errors are passed through
    ``jsonable_encoder`` first. Entries produced by Pydantic can contain
    values ``json.dumps`` rejects, such as exception objects under ``ctx``
    or raw ``bytes`` input. Without encoding, rendering the ``JSONResponse``
    would raise and the client would get a 500 instead of a 422.
    """
    payload = error_payload(
        "Validation Error",
        {"errors": jsonable_encoder(exc.errors())},
    )
    response = JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=payload
    )
    return _with_request_id_header(request, response)
