import json
import logging
import traceback

from fastapi import Request
from fastapi import logger
from starlette.middleware.base import BaseHTTPMiddleware

import app.db.database as database
import app.db.models as models


class ComprehensiveAPILoggingMiddleware(BaseHTTPMiddleware):
    """Record every HTTP request and its outcome as a ``models.ApiLog`` row.

    For each request the middleware stores the client IP, user agent, path,
    method, cleaned query parameters and (for POST/PUT/PATCH) the request
    body. It then updates the same row with the response status code, or
    with status 500 and the formatted traceback when the downstream app
    raises.

    Writing the log row is best-effort. A database failure is logged and
    rolled back, but it never replaces the response sent to the client and
    never masks the downstream handler's own exception.

    NOTE(review): the session calls here (``SessionLocal()``, ``commit()``)
    are not awaited, so they run synchronously on the event loop and add
    their latency to every request. Consider
    ``starlette.concurrency.run_in_threadpool`` or an async session.
    """

    # HTTP methods whose request body is captured in the log.
    _BODY_METHODS = frozenset({"POST", "PUT", "PATCH"})

    # ``from fastapi import logger`` yields the ``fastapi.logger`` *module*,
    # which has no ``.error`` method, so use a real Logger instead.
    _logger = logging.getLogger(__name__)

    # Human-readable descriptions used by ``_infer_action``, keyed by path.
    _ACTIONS = {
        "/kpis/": "Fetching KPI time series data",
        "/stock_performance_calculation/": "Calculating stock performance & rankings",
        "/download_stock_performance_excel/": "Downloading performance Excel",
        "/filter-stocks/": "Applying advanced KPI filters",
        "/upload/instruments": "Uploading instrument list",
        "/upload/market-exchanges": "Uploading market/exchange data",
        "/symbol-search/": "Searching for stock symbols",
        "/stock_quarterly_performance_calculation": "Running quarterly analysis",
    }

    async def dispatch(self, request: Request, call_next):
        """Log *request*, forward it to ``call_next`` and record the outcome.

        Args:
            request: The incoming request.
            call_next: Starlette callback that runs the rest of the app.

        Returns:
            The downstream response, unchanged.

        Raises:
            Whatever ``call_next`` raises; it is re-raised after being logged.
        """
        ip = self._client_ip(request)
        endpoint = request.url.path
        query_params = self._clean_query_params(request.query_params)

        context = {
            "query_params": query_params,
            "body": await self._read_body(request),
            # Always lists, even for a single value or a missing parameter.
            "selected_stocks": self._as_list(query_params.get("symbols")),
            "selected_kpis": self._as_list(query_params.get("kpis")),
        }

        db = database.SessionLocal()
        try:
            log_entry = models.ApiLog(
                endpoint=endpoint,
                method=request.method,
                ip_address=ip,
                user_agent=request.headers.get("user-agent"),
                request_data=context,
            )
            db.add(log_entry)
            # If the initial insert failed (and was rolled back) there is no
            # row to update later, so the status updates below are skipped.
            logged = self._safe_commit(db, endpoint)

            try:
                response = await call_next(request)
            except Exception:
                error_msg = traceback.format_exc()
                self._logger.error(
                    "API Error: %s | IP: %s | Error: %s", endpoint, ip, error_msg
                )
                if logged:
                    log_entry.response_status = 500
                    log_entry.error_message = error_msg
                    self._safe_commit(db, endpoint)
                raise

            if logged:
                log_entry.response_status = response.status_code
                self._safe_commit(db, endpoint)
            return response
        finally:
            db.close()

    def _safe_commit(self, db, endpoint: str) -> bool:
        """Commit *db*; on failure log, roll back and return False.

        Audit logging must not break request handling, so commit errors are
        logged with their traceback and swallowed instead of propagating.
        """
        try:
            db.commit()
        except Exception:
            self._logger.exception("Failed to write API log for %s", endpoint)
            try:
                db.rollback()
            except Exception:
                self._logger.exception("Rollback after failed API log commit failed")
            return False
        return True

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the best-guess client IP, honouring common proxy headers.

        NOTE(review): ``X-Forwarded-For`` / ``X-Real-IP`` are set by the
        client unless a trusted reverse proxy overwrites them, so the value
        is spoofable when the app is reachable directly.
        """
        ip = (
            request.headers.get("x-forwarded-for")
            or request.headers.get("x-real-ip")
            or (request.client.host if request.client else "unknown")
        )
        # X-Forwarded-For may list several hops ("client, proxy1, ...");
        # the first entry is the originating client.
        return ip.split(",", 1)[0].strip() or "unknown"

    @staticmethod
    def _clean_query_params(raw_params) -> dict:
        """Strip query values and expand comma-separated ones into lists.

        A value containing a comma becomes a list of its non-empty, stripped
        parts with duplicates removed (first occurrence wins). Any other
        value is stored as a stripped string.
        """
        cleaned = {}
        for key, value in raw_params.items():
            if "," in value:
                parts = (part.strip() for part in value.split(","))
                # dict.fromkeys de-duplicates while preserving insertion order.
                cleaned[key] = list(dict.fromkeys(part for part in parts if part))
            else:
                cleaned[key] = value.strip()
        return cleaned

    @staticmethod
    def _as_list(value) -> list:
        """Normalise a cleaned query value (None, str or list) to a list."""
        if value is None:
            return []
        if isinstance(value, list):
            return value
        return [value] if value else []

    async def _read_body(self, request: Request):
        """Return the request body for logging, or None if not captured.

        JSON bodies are parsed. Other text bodies are returned as strings.
        Multipart bodies (file uploads) are not read, so they are neither
        buffered in memory nor stored. Any read/decode failure yields
        ``"<unreadable>"``.
        """
        if request.method not in self._BODY_METHODS:
            return None
        content_type = request.headers.get("content-type", "").lower()
        if content_type.startswith("multipart/"):
            return "<multipart body omitted>"
        try:
            body_bytes = await request.body()
            if not body_bytes:
                return None
            decoded = body_bytes.decode()
            try:
                return json.loads(decoded)
            except json.JSONDecodeError:
                return decoded
        except Exception:
            return "<unreadable>"

    def _infer_action(self, endpoint: str, params: dict, body) -> str:
        """Generate a human-readable description of the user's action.

        Only the path (anything before ``?``) is used; *params* and *body*
        are accepted for future use but currently ignored.
        """
        base = endpoint.partition("?")[0]
        return self._ACTIONS.get(base, f"Unknown action on {base}")
