import calendar
import io
import json
import logging
import math
import os
import traceback
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from typing import List, Optional, Dict, Any

import numpy as np
import pandas as pd
import pydatastream as ds
import requests
from fastapi import FastAPI, Request, Depends, Query, UploadFile, File, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
from scipy.stats import t
from sqlalchemy import and_, func, or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, sessionmaker, relationship
from starlette.middleware.base import BaseHTTPMiddleware

from app.api import admin_routes
from app.db import models, database, crud
from app.db.database import SessionLocal
from app.db.models import Instrument, KPI, KPIValue, MarketExchnges, Market, Sector, IndexPriceValue, InstrumentPriceValue
from app.middleware import ComprehensiveAPILoggingMiddleware
from app.schemas import KPIFilter
from app.utils.datastream_api import fetch_kpi_data, parse_dotnet_date, generate_date_range, find_continuous_ranges, store_kpi_response

models.Base.metadata.create_all(bind=database.engine)

app = FastAPI()
app.add_middleware(ComprehensiveAPILoggingMiddleware)
app.include_router(admin_routes.router)

# Allow requests from browser on localhost
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to ["http://localhost:3000"] for stricter control
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

username = os.getenv("DATASTREAM_API_USERNAME")
password = os.getenv("DATASTREAM_API_PASSWORD")

def get_db():
    db = database.SessionLocal()
    try:
        yield db
    finally:
        db.close()

API_BASE = "https://product.datastream.com/DSWSClient/V1/DSService.svc/rest"

def get_token():
    response = requests.get(f"{API_BASE}/Token", params={
        "username": username,
        "password": password
    })
    response.raise_for_status()
    data = response.json()
    token = data.get("TokenValue")
    if not token:
        raise ValueError(f"Failed to fetch token: {data}")
    return token
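
# The DSWS REST /Token endpoint replies with a JSON object; only its
# "TokenValue" field is consumed here. Illustrative response shape (field
# names beyond TokenValue are assumptions):
#   {"TokenValue": "abc123...", "TokenExpiry": "/Date(1700000000000+0000)/"}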

def get_datastream():
    try:
        return ds.Datastream(username=username, password=password)
    except Exception:
        return None

@app.post("/fetch-kpis/")
def fetch_kpis(token: str, frequency: str, symbol: str, start_date: str, end_date: str, db: Session = Depends(get_db)):
    # Fetch all KPI codes from the database
    kpi_codes = [k.code for k in db.query(models.KPI).all()]
    body = {
        "TokenValue": token,
        "Properties": { "RequestAllMetadata": True },
        "DataRequests": [
            {
                "DataTypes": [{"Value": code, "Properties": {"ReturnName": True}} for code in kpi_codes],
                "Date": {
                    "Start": start_date,
                    "End": end_date,
                    "Frequency": frequency,
                    "Kind": 1
                },
                "Instrument": {
                    "Value": symbol,
                    "Properties": { "ReturnName": True }
                },
                "Tag": ""
            }
        ]
    }

    response = fetch_kpi_data(token, body)
    print("=============== Response from DataStream API ===============")
    print(response)
    data = response["DataResponses"][0]
    dates = [parse_dotnet_date(d) for d in data["Dates"]]
    instrument_symbol = data["DataTypeValues"][0]["SymbolValues"][0]["Symbol"]

    # Ensure instrument exists
    instrument = db.query(models.Instrument).filter_by(symbol=instrument_symbol).first()
    if not instrument:
        instrument = models.Instrument(symbol=instrument_symbol)
        db.add(instrument)
        db.commit()
        db.refresh(instrument)

    for entry in data["DataTypeValues"]:
        kpi_code = entry["DataType"]
        values = entry["SymbolValues"][0]["Value"]
        currency = entry["SymbolValues"][0].get("Currency")

        # Ensure KPI exists
        kpi = db.query(models.KPI).filter_by(code=kpi_code).first()
        if not kpi:
            kpi = models.KPI(code=kpi_code)
            db.add(kpi)
            db.commit()
            db.refresh(kpi)

        for date, value in zip(dates, values):
            if value is None:
                continue
            existing = db.query(models.KPIValue).filter_by(
                instrument_id=instrument.id,
                kpi_id=kpi.id,
                frequency=frequency,
                date=date
            ).first()
            if not existing:
                kpi_val = models.KPIValue(
                    instrument_id=instrument.id,
                    kpi_id=kpi.id,
                    frequency=frequency,
                    date=date,
                    value=value,
                    currency=currency
                )
                db.add(kpi_val)
    db.commit()
    return {"message": "KPI data fetched and stored."}

@app.get("/kpis/")
def get_kpis(
    symbol: str,
    frequency: str = Query("Q"),
    start_date: str = Query(...),
    end_date: str = Query(...),
    kpis: Optional[List[str]] = Query(None),
    db: Session = Depends(get_db)
):
    token = get_token()
    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end = datetime.strptime(end_date, "%Y-%m-%d").date()

    # --- Frequency Mapping (frontend → external API) ---
    freq_map = {"D": "Daily", "M": "Monthly", "Q": "Quarterly", "Y": "Yearly"}
    external_freq = freq_map.get(frequency, "Quarterly")  # default fallback

    # STEP 1: Determine KPI codes to use
    if kpis:
        # Get mapping of name => code from DB
        name_code_map = {k.name or k.code: k.code for k in db.query(KPI).all()}
        print("Name to Code Mapping:", name_code_map)
        kpi_codes = [name_code_map.get(name, name) for name in kpis]  # fallback to name itself if not found
        print("KPI Codes to Fetch:", kpi_codes)

        # Ensure any new KPI codes are added to DB if missing
        for code in kpi_codes:
            if not db.query(KPI).filter_by(code=code).first():
                db.add(KPI(code=code, name=code))
        db.commit()
    else:
        kpi_codes = [k.code for k in db.query(KPI).all()]

    # STEP 2: Get instrument
    instrument = db.query(Instrument).filter_by(symbol=symbol).first()
    if not instrument:
        instrument = Instrument(symbol=symbol)
        db.add(instrument)
        db.commit()
        db.refresh(instrument)

    # STEP 3: Get existing KPIValues from DB
    kpi_ids = [
        kpi_id
        for kpi_id in (
            db.query(KPI.id).filter_by(code=code).scalar() for code in kpi_codes
        )
        if kpi_id is not None
    ]
    db_dates = db.query(KPIValue.date).filter(
        KPIValue.instrument_id == instrument.id,
        KPIValue.frequency == frequency,
        KPIValue.kpi_id.in_(kpi_ids),
        KPIValue.date >= start,
        KPIValue.date <= end
    ).group_by(KPIValue.date).having(func.count(KPIValue.kpi_id) == len(kpi_ids)).all()

    existing_dates = {d[0] for d in db_dates}
    print("Existing Dates in DB:", existing_dates)
    full_range = generate_date_range(frequency, start, end)
    print("Full Date Range:", full_range)
    missing_dates = sorted([d for d in full_range if d not in existing_dates])
    print("Missing Dates:", missing_dates)

    # STEP 4: If any missing dates, call fetch-kpis internally
    if missing_dates:
        chunks = find_continuous_ranges(missing_dates, frequency)
        print("Chunks:", chunks)
        for chunk_start, chunk_end in chunks:
            body = {
                "TokenValue": token,
                "Properties": {"RequestAllMetadata": True},
                "DataRequests": [{
                    "DataTypes": [{"Value": code, "Properties": {"ReturnName": True}} for code in kpi_codes],
                    "Date": {
                        "Start": chunk_start.isoformat(),
                        "End": chunk_end.isoformat(),
                        "Frequency": external_freq,
                        "Kind": 1
                    },
                    "Instrument": {
                        "Value": symbol,
                        "Properties": {"ReturnName": True}
                    },
                    "Tag": ""
                }]
            }
            response = fetch_kpi_data(token, body)
            store_kpi_response(db, response)

    # STEP 5: Fetch all data (now complete) from DB
    kpi_data = {}
    for code in kpi_codes:
        kpi = db.query(KPI).filter_by(code=code).first()
        if not kpi:
            continue
        records = db.query(KPIValue).filter(
            KPIValue.instrument_id == instrument.id,
            KPIValue.kpi_id == kpi.id,
            KPIValue.frequency == frequency,
            KPIValue.date >= start,
            KPIValue.date <= end
        ).order_by(KPIValue.date).all()
        kpi_data[kpi.name or kpi.code] = [
            {
                "date": str(r.date),
                "value": r.value,
                "currency": r.currency
            }
            for r in records
        ]

    response = {
        "symbol": symbol,
        "frequency": frequency,
        "start_date": str(start),
        "end_date": str(end),
        "kpis": kpi_data,
    }

    return response
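
# Illustrative request/response (hypothetical symbol and values):
#   GET /kpis/?symbol=VOD&frequency=Q&start_date=2022-01-01
#       &end_date=2022-12-31&kpis=EPS
# ->
#   {"symbol": "VOD", "frequency": "Q", "start_date": "2022-01-01",
#    "end_date": "2022-12-31",
#    "kpis": {"EPS": [{"date": "2022-03-31", "value": 1.52, "currency": "E"},
#                     ...]}}
# Only date ranges missing from the local KPIValue table trigger a DSWS call.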

@app.get("/stocks_details/")
def get_stocks():
    db = SessionLocal()
    instruments = db.query(Instrument).all()
    result = [
        {
            "id": inst.id,
            "symbol": inst.symbol,
            "name": inst.name,
            "market_id": inst.market_id,
            "sector_id": inst.sector_id,
            "market_name": inst.market.name if inst.market else None,
            "sector_name": inst.sector.name if inst.sector else None,
        }
        for inst in instruments
    ]
    db.close()
    return result

@app.get("/market_list/")
def get_market_list():
    db = SessionLocal()
    markets = db.query(Market).all()
    result = [
        {
            "id": market.id,
            "name": market.name,

        }
        for market in markets
    ]
    db.close()
    return result


@app.get("/sector_list/")
def get_sector_list():
    db = SessionLocal()
    sectors = db.query(Sector).all()
    result = [
        {
            "id": sector.id,
            "name": sector.name,

        }
        for sector in sectors
    ]
    db.close()
    return result

def get_stock_kpi_calculation(
    frequency,
    start_date,
    end_date,
    kpis,
    symbols,
    db: Session = Depends(get_db)
):
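    """
    DB-first KPI collection for a list of symbols. This helper is called
    directly (not registered as a route), so `db` is always passed in
    explicitly despite the Depends default. Dates missing from KPIValue are
    fetched from DSWS and stored, derived KPIs in `calculation_list` are
    computed from their dependency series, and the per-symbol data is
    returned together with KPI metadata: the dependency map, the NA set,
    the name->code map, and the lower/higher-is-better sets used for ranking.
    """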
    # token = os.getenv("DATASTREAM_API_TOKEN")

    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end = datetime.strptime(end_date, "%Y-%m-%d").date()

    # Special KPI categories
    lower = {"DIO", "EVToEBITDA", "DebtEquityRatio", "DSO", "OperatingMargin", "PE", "EVToSales"}  # Lower is better

    higher = {"EPS", "EPSDiluted", "ReturnOnAssets", "CurrentRatio", "QuickRatio", "BookValuePerShare", "MarketCap",
              "EnterpriseValue", "POUT", "NetProfitMargin", "EBITDAMargin", "EPSGrowthRate", "InterestCoverage",
              "ReturnOnEquity", "PEGRatio", "OCFPerShare", "DividendYield"}  # Higher is better

    NA_KPIS = {"Beta", "Price52WeekHigh", "Price52WeekLow", "Alpha", "PTBV"}  # Skip


    KPI_DEPENDENCIES = {
        "NetDebtToEBITDA": ["NetDebt", "EBITDA"],
        "DPO": ["COGS", "AverageInventory"],
        "CashConversionCycle": ["DIO", "DSO", "COGS", "AverageInventory"],
        "FCF": ["fcf", "capex"],
        "EBITMargin": ["ebitda", "revenue"],
        "AssetTurnover": ["revenue", "total_assets"],
        "InventoryTurnover": ["cogs", "inventory"],
        "TangibleBVPS": ["equity", "intangible_assets", "goodwill", "shares_outstanding"],
        "ForwardPE": ["dwfc", "expected_eps"],
        "PriceToSales": ["price_per_share", "sales_per_share"],
        "GrossMargin": ["gross_income", "revenue"],
        "FCFYield": ["fcf", "capex", "market_cap"],
        "ShortInterestRatio": ["sid", "uvo"],
        "SharpeRatio": ["msdpd", "ry", "vol"],
        "InsiderOwnershipPct": ["shares_owned_by_insiders", "shares_outstanding"],
        "InstitutionalOwnershipPct": ["noshou", "wc05475"]
    }

    # Categorize calculated KPIs based on financial interpretation
    # Lower is better (debt ratios, time periods, risk metrics)
    calculated_lower = {"NetDebtToEBITDA", "DPO", "CashConversionCycle", "ForwardPE", "PriceToSales"}

    # Higher is better (margins, turnover ratios, yields, cash flow)
    calculated_higher = {"EBITMargin", "AssetTurnover", "InventoryTurnover", "TangibleBVPS", "FCF",
                         "GrossMargin", "FCFYield", "InsiderOwnershipPct", "InstitutionalOwnershipPct"}

    # Context-dependent or special interpretation (may vary based on strategy)
    calculated_na = {"SharpeRatio", "ShortInterestRatio"}

    lower.update(calculated_lower)
    higher.update(calculated_higher)
    NA_KPIS.update(calculated_na)

    calculation_list = []
    # STEP 1: Determine KPI codes to use
    if kpis:
        name_code_map = {k.name or k.code: k.code for k in db.query(KPI).all()}
        kpi_codes = [name_code_map.get(name, name) for name in kpis]

        for code in kpi_codes[:]:  # iterate on copy since we modify list
            if code in KPI_DEPENDENCIES:
                calculation_list.append(code)
                kpi_codes.remove(code)
                for dep in KPI_DEPENDENCIES[code]:
                    if dep in name_code_map:
                        kpi_codes.append(name_code_map[dep])

    else:
        all_kpis = db.query(KPI).all()
        kpi_codes = [k.code for k in all_kpis]
        name_code_map = {k.name or k.code: k.code for k in all_kpis}

    stock_wise_kpi_data = {}

    if kpis:
        for symbol in symbols:
            instrument = db.query(Instrument).filter_by(symbol=symbol).first()

            kpi_ids = [
                kpi_id
                for kpi_id in (
                    db.query(KPI.id).filter_by(code=code).scalar() for code in kpi_codes
                )
                if kpi_id is not None
            ]
            db_dates = db.query(KPIValue.date).filter(
                KPIValue.instrument_id == instrument.id,
                KPIValue.frequency == frequency,
                KPIValue.kpi_id.in_(kpi_ids),
                KPIValue.date >= start,
                KPIValue.date <= end
            ).group_by(KPIValue.date).having(func.count(KPIValue.kpi_id) == len(kpi_ids)).all()

            existing_dates = {d[0] for d in db_dates}
            full_range = generate_date_range(frequency, start, end)
            missing_dates = sorted([d for d in full_range if d not in existing_dates])

            if missing_dates:
                token = get_token()
                chunks = find_continuous_ranges(missing_dates, frequency)
                for chunk_start, chunk_end in chunks:
                    body = {
                        "TokenValue": token,
                        "Properties": {"RequestAllMetadata": True},
                        "DataRequests": [{
                            "DataTypes": [{"Value": code, "Properties": {"ReturnName": True}} for code in kpi_codes],
                            "Date": {
                                "Start": chunk_start.isoformat(),
                                "End": chunk_end.isoformat(),
                                "Frequency": frequency,
                                "Kind": 1
                            },
                            "Instrument": {
                                "Value": symbol,
                                "Properties": {"ReturnName": True}
                            },
                            "Tag": ""
                        }]
                    }
                    response = fetch_kpi_data(token, body)
                    store_kpi_response(db, response)

            kpi_data = {}
            for code in kpi_codes:
                kpi = db.query(KPI).filter_by(code=code).first()
                if not kpi:
                    continue
                records = db.query(KPIValue).filter(
                    KPIValue.instrument_id == instrument.id,
                    KPIValue.kpi_id == kpi.id,
                    KPIValue.frequency == frequency,
                    KPIValue.date >= start,
                    KPIValue.date <= end
                ).order_by(KPIValue.date).all()
                if records:
                    kpi_data[kpi.name or kpi.code] = [
                        {
                            "date": str(r.date),
                            "value": r.value,
                            "currency": r.currency
                        }
                        for r in records
                    ]
                else:
                    kpi_data[kpi.name or kpi.code] = None
            print("kpi_data", kpi_data)
            kpi_data = dynamic_calculate_kpis(kpi_data, name_code_map, symbol, start_date, end_date, calculation_list,
                                              frequency)
            stock_wise_kpi_data.update({symbol: kpi_data})

    print("stock_wise_kpi_data", stock_wise_kpi_data)
    return stock_wise_kpi_data, KPI_DEPENDENCIES, NA_KPIS, name_code_map, lower, higher


@app.get("/stock_performance_calculation/")
def get_stock_performance_calculation(
    # symbol: str,
    frequency: str = Query("Q"),
    start_date: str = Query(...),
    end_date: str = Query(...),
    kpis: Optional[List[str]] = Query(None),
    symbols: Optional[List[str]] = Query(None),
    db: Session = Depends(get_db)
):

    if not symbols:
        raise HTTPException(status_code=400, detail="At least one symbol is required")
    # Split comma-separated entries and de-duplicate in one pass
    symbols = list({s.strip() for item in symbols for s in item.split(",")})
    response_data = {}
    stock_wise_kpi_data, KPI_DEPENDENCIES, NA_KPIS, name_code_map, lower, higher = get_stock_kpi_calculation(frequency, start_date, end_date, kpis, symbols, db)

    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end = datetime.strptime(end_date, "%Y-%m-%d").date()
    # ------------------ Stock + Index prices (DB-first) ------------------


    for symbol in symbols:
        instrument = db.query(Instrument).filter_by(symbol=symbol).first()
        if not instrument or not instrument.market_exchange:
            raise HTTPException(status_code=404, detail=f"Unknown symbol or missing market exchange: {symbol}")

        # --- Index prices ---
        index_symbol = instrument.market_exchange.index_code
        index_name = instrument.market_exchange.index_name

        index_records = db.query(IndexPriceValue).filter(
            IndexPriceValue.market_exchange_for_id == instrument.market_exchange.id,
            IndexPriceValue.date >= start,
            IndexPriceValue.date <= end
        ).order_by(IndexPriceValue.date).all()

        if index_records:
            df_index = pd.DataFrame(
                [{"date": r.date, "value": r.value} for r in index_records]
            ).set_index("date")
        else:
            DS = get_datastream()
            if DS is None:
                raise HTTPException(status_code=503, detail="Datastream connection unavailable")
            df_index = DS.fetch(index_symbol, ["PI"], date_from=start_date, date_to=end_date).dropna()
            if not df_index.empty:
                for d, val in df_index.itertuples():
                    db.add(IndexPriceValue(
                        market_exchange_for_id=instrument.market_exchange.id,
                        date=d,
                        value=val
                    ))
                db.commit()

        if df_index.empty:
            raise HTTPException(status_code=404, detail=f"No index data found for {index_symbol}")

        index_start = df_index.iloc[0, 0]
        index_end = df_index.iloc[-1, 0]
        index_return = (index_end - index_start) / index_start * 100

        # --- Stock prices ---
        stock_records = db.query(InstrumentPriceValue).filter(
            InstrumentPriceValue.instrument_for_id == instrument.id,
            InstrumentPriceValue.date >= start,
            InstrumentPriceValue.date <= end
        ).order_by(InstrumentPriceValue.date).all()

        if stock_records:
            df_stock = pd.DataFrame(
                [{"date": r.date, "value": r.value} for r in stock_records]
            ).set_index("date")
        else:
            DS = get_datastream()
            if DS is None:
                raise HTTPException(status_code=503, detail="Datastream connection unavailable")
            df_stock = DS.fetch(symbol, ["P"], date_from=start_date, date_to=end_date).dropna()
            if not df_stock.empty:
                for d, val in df_stock.itertuples():
                    db.add(InstrumentPriceValue(
                        instrument_for_id=instrument.id,
                        date=d,
                        value=val
                    ))
                db.commit()

        performance = []
        if df_stock.empty:
            performance.append({
                "Ticker": symbol,
                "Start Price": None,
                "End Price": None,
                "Return %": None,
                "Total Score %": None
            })
        else:
            start_price = df_stock.iloc[0, 0]
            end_price = df_stock.iloc[-1, 0]
            stock_return = (end_price - start_price) / start_price * 100
            excess_return = stock_return - index_return

            all_data = {
                "Ticker": symbol,
                "Start Price": round(start_price, 2),
                "End Price": round(end_price, 2),
                "Return %": round(stock_return, 2),
                "Index Name": index_name,
                "Index Start": round(index_start, 2),
                "Index End": round(index_end, 2),
                "Index Return %": round(index_return, 2),
                "Total Score %": round(excess_return, 2)
            }

            kpi_data = {
                k: v for k, v in (stock_wise_kpi_data.get(symbol, {}) or {}).items()
                if k in kpis
            }
            # Get all dependency codes to exclude from final response
            all_dependency_codes = set()
            for deps in KPI_DEPENDENCIES.values():
                for dep in deps:
                    if dep in name_code_map.values():
                        for name, code in name_code_map.items():
                            if code == dep:
                                all_dependency_codes.add(name)
                                break

            # --- KPI Calculation with absolute change ---
            for kpi_name, records in kpi_data.items():
                # Skip dependency KPIs - only show calculated/derived KPIs
                if kpi_name in all_dependency_codes:
                    continue

                if kpi_name in NA_KPIS:
                    all_data[f"{kpi_name} Score %"] = "NA"
                    all_data[f"{kpi_name} Score % Rank"] = "NA"
                    continue

                if not records:
                    all_data[f"{kpi_name} Score %"] = "NA"
                    all_data[f"{kpi_name} Score % Rank"] = "NA"
                    continue
                df_kpi = pd.DataFrame(records)
                df_kpi["date"] = pd.to_datetime(df_kpi["date"])
                df_kpi = df_kpi.sort_values("date").reset_index(drop=True)

                start_val = df_kpi["value"].iloc[0]
                end_val = df_kpi["value"].iloc[-1]

                if start_val is None or end_val is None:
                    continue

                abs_change = end_val - start_val
                kpi_score = round(abs_change, 2)

                all_data[f"{kpi_name} Score %"] = kpi_score
            print("kpi_data-----", kpi_data)
            performance.append(all_data)

        perf_df = pd.DataFrame(performance).fillna(0)

        calculation_response = {
            "symbol": symbol,
            "frequency": frequency,
            "start_date": str(start),
            "end_date": str(end),
            "stock_with_performance": perf_df.to_dict(orient="records")
        }

        response_data.update({symbol: calculation_response})
    # --- Step 2: Ranking ---
    rows = []
    for symbol, details in response_data.items():
        rows.extend(details["stock_with_performance"])
    perf_df = pd.DataFrame(rows)

    # Rank only non-NA KPI columns
    for col in perf_df.columns:
        if col.endswith("Score %") and col not in ["Total Score %"]:
            perf_df[col] = pd.to_numeric(perf_df[col], errors="coerce")
            if perf_df[col].isna().all():
                continue
            kpi_name = col.replace(" Score %", "")
            if kpi_name in lower:
                perf_df[f"{col} Rank"] = perf_df[col].rank(ascending=True, method="dense").astype("Int64")
            elif kpi_name in higher:
                perf_df[f"{col} Rank"] = perf_df[col].rank(ascending=False, method="dense").astype("Int64")
            else:
                perf_df[f"{col} Rank"] = perf_df[col].rank(ascending=False, method="dense").astype("Int64")

    perf_df["Rank"] = perf_df["Total Score %"].rank(ascending=False, method="dense").astype("Int64")

    for _, row in perf_df.iterrows():
        ticker = row["Ticker"]
        for sym, details in response_data.items():
            if details["stock_with_performance"][0]["Ticker"] == ticker:
                details["stock_with_performance"][0].update(row.to_dict())


    def clean_nans(obj):
        if isinstance(obj, dict):
            return {k: clean_nans(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [clean_nans(i) for i in obj]
        elif isinstance(obj, float) and math.isnan(obj):
            return None  # or 0 if you prefer
        return obj

    result = {
        "stock_performance_data": response_data,
        "stock_wise_kpi_data": filter_dependency_kpis_from_response(stock_wise_kpi_data, KPI_DEPENDENCIES, name_code_map)
    }

    result = clean_nans(result)  # ✅ Clean NaN → None before returning
    print(result)
    return result
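
# Illustrative request (hypothetical symbols): symbols may be repeated or
# comma-separated; both forms are normalised above:
#   GET /stock_performance_calculation/?frequency=Q&start_date=2022-01-01
#       &end_date=2022-12-31&kpis=NetDebtToEBITDA&symbols=VOD,BARC&symbols=HSBA
# The response holds per-symbol performance rows (price return, index return,
# excess return as "Total Score %", per-KPI scores and dense ranks) plus the
# filtered per-KPI time series.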


@app.get("/download_stock_performance_excel/")
def download_stock_performance_excel(
    frequency: str = Query("Q"),
    start_date: str = Query(...),
    end_date: str = Query(...),
    kpis: Optional[List[str]] = Query(None),
    symbols: Optional[List[str]] = Query(None),
    db: Session = Depends(get_db)
):
    # ✅ First call your existing function to reuse logic
    result = get_stock_performance_calculation(
        frequency=frequency,
        start_date=start_date,
        end_date=end_date,
        kpis=kpis,
        symbols=symbols,
        db=db
    )

    # Extract only stock_performance_data
    stock_performance_data = result["stock_performance_data"]

    # Flatten into rows for Excel
    rows = []
    for symbol, details in stock_performance_data.items():
        rows.extend(details["stock_with_performance"])

    df = pd.DataFrame(rows)

    # Write to Excel in-memory
    output = io.BytesIO()
    with pd.ExcelWriter(output, engine="openpyxl") as writer:
        df.to_excel(writer, sheet_name="StockPerformance", index=False)

    output.seek(0)

    # Return as downloadable file
    headers = {
        "Content-Disposition": "attachment; filename=stock_performance.xlsx"
    }
    return StreamingResponse(output, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", headers=headers)
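
# Illustrative usage (hypothetical symbols): takes the same query parameters
# as /stock_performance_calculation/ and returns the flattened rows as an
# .xlsx attachment:
#   GET /download_stock_performance_excel/?frequency=Q&start_date=2022-01-01
#       &end_date=2022-12-31&kpis=GrossMargin&symbols=VOD,BARC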

def filter_dependency_kpis_from_response(stock_wise_kpi_data, kpi_dependencies, name_code_map):
    """
    Filter out dependency KPIs from the response, keeping only calculated/derived KPIs
    """
    # Get all dependency codes to exclude
    all_dependency_codes = set()
    for deps in kpi_dependencies.values():
        for dep in deps:
            if dep in name_code_map.values():
                # Find the name for this code
                for name, code in name_code_map.items():
                    if code == dep:
                        all_dependency_codes.add(name)
                        break

    filtered_data = {}
    for symbol, kpi_data in stock_wise_kpi_data.items():
        filtered_kpi_data = {}
        for kpi_name, records in kpi_data.items():
            # Only include non-dependency KPIs (i.e., calculated/derived KPIs)
            if kpi_name not in all_dependency_codes:
                filtered_kpi_data[kpi_name] = records
        filtered_data[symbol] = filtered_kpi_data

    return filtered_data


@app.get("/kpi-options/")
def get_kpi_options(db: Session = Depends(get_db)):
    kpis = db.query(KPI).all()
    return [{"code": k.code, "name": k.name or k.code} for k in kpis]

@app.get("/symbol-search/")
def symbol_search(query: str = Query(...), db: Session = Depends(get_db)):
    results = (
        db.query(Instrument)
        .filter(Instrument.symbol.ilike(f"%{query}%"))
        .limit(10)
        .all()
    )
    return [{"symbol": i.symbol, "name": i.name} for i in results]


def check_trend_condition(values, trend_type, quarters, threshold, direction='positive'):
    data_list = [v.value for v in values if v.value is not None]

    if len(data_list) < quarters:
        return False

    if trend_type == 'consecutive_growth':
        # Check for consecutive growth or decline
        return check_consecutive_growth(data_list, quarters, threshold, direction)

    elif trend_type == 'negative_to_positive':
        # Check for negative to positive transition
        return check_negative_to_positive(data_list, quarters)

    elif trend_type == 'yoy_growth':
        # Check for year-over-year growth
        return check_yoy_growth(data_list, quarters, threshold, direction)

    elif trend_type == 'post_transition_growth':
        # Check for growth after a negative to positive transition
        return check_post_transition_growth(data_list, quarters, threshold)

    elif trend_type == 'absolute_threshold':
        # Check for absolute threshold condition
        return check_absolute_threshold(data_list, threshold, direction)

    return False

def check_absolute_threshold(data_list, threshold, direction='positive'):
    if direction == 'positive':
        return any(value is not None and value >= threshold for value in data_list)
    else:
        return any(value is not None and value <= threshold for value in data_list)

def check_consecutive_growth(data_list, quarters, threshold, direction='positive'):
    if len(data_list) < quarters:
        return False

    growth_list = [
        ((curr - prev) / abs(prev)) * 100
        for prev, curr in zip(data_list, data_list[1:])
        if prev is not None and curr is not None and prev != 0
    ]

    if direction == 'negative':
        compare = lambda g: g <= -threshold
    else:
        compare = lambda g: g >= threshold

    return any(
        all(compare(g) for g in growth_list[i:i + quarters - 1])
        for i in range(len(growth_list) - quarters + 2)
    )
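
# Worked example (hypothetical data): data_list = [100, 105, 111, 118],
# quarters = 3, threshold = 4, direction = 'positive'. The period-over-period
# growth list is roughly [5.0, 5.7, 6.3]; both windows of quarters-1 = 2
# consecutive growth rates stay at or above 4%, so the function returns True.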


def check_negative_to_positive(data_list, quarters):
    if len(data_list) < quarters:
        return False

    for i in range(len(data_list) - quarters + 1):
        segment = data_list[i:i + quarters]
        half = quarters // 2
        if any(x is not None and x < 0 for x in segment[:half]) and all(
            x is not None and x > 0 for x in segment[half:]
        ):
            return True
    return False


def check_yoy_growth(data_list, quarters, threshold, direction='positive'):
    if len(data_list) < quarters + 4:
        return False

    count = 0
    for i in range(4, len(data_list)):
        if data_list[i] is not None and data_list[i - 4] is not None and data_list[i - 4] != 0:
            growth = ((data_list[i] - data_list[i - 4]) / abs(data_list[i - 4])) * 100
            if direction == 'negative' and growth <= -threshold:
                count += 1
            elif direction != 'negative' and growth >= threshold:
                count += 1
            else:
                count = 0

            if count >= quarters:
                return True
    return False
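
# Worked example (hypothetical data): data_list =
# [100, 100, 100, 100, 110, 112, 113, 115], quarters = 2, threshold = 5.
# YoY growth at indices 4..7 is [10%, 12%, 13%, 15%]; the count of
# consecutive readings >= 5% reaches 2 at index 5, so the function
# returns True.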


def check_post_transition_growth(data_list, quarters_after=2, threshold=0):
    for i in range(len(data_list) - quarters_after - 1):
        if data_list[i] is not None and data_list[i] < 0 and data_list[i + 1] is not None and data_list[i + 1] > 0:
            growth_ok = True
            for j in range(i + 1, i + 1 + quarters_after):
                if data_list[j] is None or data_list[j - 1] is None or data_list[j - 1] == 0:
                    growth_ok = False
                    break
                growth = ((data_list[j] - data_list[j - 1]) / abs(data_list[j - 1])) * 100
                if growth < threshold:
                    growth_ok = False
                    break
            if growth_ok:
                return True
    return False
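
# Worked example (hypothetical data): data_list = [-5, 4, 6, 9],
# quarters_after = 2, threshold = 0. The series flips negative -> positive
# between indices 0 and 1, and the next two period-over-period growth rates
# (180% and 50%) are both >= 0, so the function returns True.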


def apply_advanced_filter(db, instrument_id, kpi_code, filter_data, frequency, start_date, end_date):
    """Apply trend logic for advanced filters"""
    kpi = db.query(models.KPI).filter_by(code=kpi_code).first()
    if not kpi:
        return False, None

    values = db.query(models.KPIValue).filter(
        models.KPIValue.instrument_id == instrument_id,
        models.KPIValue.kpi_id == kpi.id,
        models.KPIValue.frequency == frequency,
        models.KPIValue.date >= start_date,
        models.KPIValue.date <= end_date
    ).order_by(models.KPIValue.date.asc()).all()

    if not values or len(values) < 2:
        return False, None

    trend_type = filter_data.get("trend")
    quarters = int(filter_data.get("quarters", 2))
    threshold = float(filter_data.get("threshold", 0))
    direction = filter_data.get("direction", "positive")

    passed = check_trend_condition(values, trend_type, quarters, threshold, direction)
    if passed:
        latest_value = max(values, key=lambda v: v.date)
        return True, latest_value
    return False, None


@app.get("/filter-stocks/")
def filter_stocks_get(
    frequency: str,
    start_date: str,
    end_date: str,
    logical_operator: str,
    filters: List[str] = Query(...),
    db: Session = Depends(get_db)
):
    parsed_filters = [json.loads(f) for f in filters]
    print('logical_operator ===>', logical_operator)
    instruments = db.query(models.Instrument).all()
    matching_results = []

    for instrument in instruments:
        instrument_kpi_data = []
        filter_pass_count = 0

        for f in parsed_filters:
            kpi_code = f["kpi_code"]
            if "trend" in f:  # advanced filter
                passed, kpi_value_obj = apply_advanced_filter(
                    db, instrument.id, kpi_code, f, frequency, start_date, end_date
                )
                if passed:
                    filter_pass_count += 1
                    instrument_kpi_data.append({
                        "kpi_code": kpi_code,
                        "value": kpi_value_obj.value,
                        "date": str(kpi_value_obj.date)
                    })
            else:
                # Basic KPI filter
                operator = f["operator"]
                value = float(f["value"])

                kpi = db.query(models.KPI).filter_by(code=kpi_code).first()
                if not kpi:
                    continue

                kpi_values_query = db.query(models.KPIValue).filter(
                    models.KPIValue.instrument_id == instrument.id,
                    models.KPIValue.kpi_id == kpi.id,
                    models.KPIValue.frequency == frequency,
                    models.KPIValue.date >= start_date,
                    models.KPIValue.date <= end_date
                )

                # Apply operator filter
                if operator == "gt":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value > value)
                elif operator == "gte":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value >= value)
                elif operator == "lt":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value < value)
                elif operator == "lte":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value <= value)
                elif operator == "eq":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value == value)
                elif operator == "ne":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value != value)
                else:
                    continue  # Invalid operator

                kpi_value_obj = kpi_values_query.order_by(models.KPIValue.date.desc()).first()
                if kpi_value_obj:
                    filter_pass_count += 1
                    instrument_kpi_data.append({
                        "kpi_code": kpi_code,
                        "value": kpi_value_obj.value,
                        "date": str(kpi_value_obj.date)
                    })

        # Combine filters using AND / OR logic
        op = logical_operator.upper()
        if (op == "AND" and filter_pass_count == len(parsed_filters)) or (op == "OR" and filter_pass_count > 0):
            matching_results.append({
                "symbol": instrument.symbol,
                "kpi_values": instrument_kpi_data
            })
    print('matching_results', matching_results)
    return {"symbols": matching_results}


@app.get("/routes")
def list_routes():
    return [route.path for route in app.routes]

# ---------- Helper utilities ----------
def read_excel_to_df(uploaded_file: UploadFile) -> pd.DataFrame:
    contents = uploaded_file.file.read()
    try:
        df = pd.read_excel(io.BytesIO(contents), engine="openpyxl")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not read Excel file: {e}")
    return df


def safe_str(value):
    if value is None:
        return None
    s = str(value).strip()
    if s.lower() in {"", "nan", "none"}:
        return None
    return s

@app.post("/api/upload/market-exchanges")
def upload_market_exchanges(file: UploadFile = File(...), db: Session = Depends(get_db)):
    df = read_excel_to_df(file)
    expected = {"index_code", "index_name", "exchange_name", "market_name"}

    lower_map = {c.lower(): c for c in df.columns}
    missing = expected - set(lower_map.keys())
    if missing:
        raise HTTPException(status_code=400, detail=f"Missing columns: {missing}")

    df = df.rename(columns={lower_map[k]: k for k in expected})
    created, skipped, errors = 0, 0, []

    for idx, row in df.iterrows():
        try:
            code = safe_str(row.get("index_code"))
            if not code:
                skipped += 1
                continue
            existing = db.query(MarketExchnges).filter_by(index_code=code).first()
            if existing:
                # Backfill any blank fields on the existing row, then move on.
                if not existing.market_name:
                    existing.market_name = safe_str(row.get("market_name"))
                if not existing.echange_name:  # model field is spelled "echange_name"
                    existing.echange_name = safe_str(row.get("exchange_name"))
                if not existing.index_name:
                    existing.index_name = safe_str(row.get("index_name"))
                db.flush()  # already attached to the session; no db.add() needed
                continue
            mx = MarketExchnges(
                index_code=code,
                index_name=safe_str(row.get("index_name")),
                echange_name=safe_str(row.get("exchange_name")),
                market_name=safe_str(row.get("market_name"))
            )
            db.add(mx)
            db.flush()
            created += 1
        except Exception as e:
            db.rollback()
            errors.append({"row": int(idx), "error": str(e)})
    db.commit()
    return {"success": True, "created": created, "skipped": skipped, "errors": errors}

# --- Upload Instruments ---
@app.post("/api/upload/instruments")
def upload_instruments(file: UploadFile = File(...), db: Session = Depends(get_db)):
    df = read_excel_to_df(file)
    expected = {"market_index_code", "short_code", "symbol", "ric", "ticker", "name", "sector", "market"}

    lower_map = {c.lower(): c for c in df.columns}
    missing = expected - set(lower_map.keys())
    if missing:
        raise HTTPException(status_code=400, detail=f"Missing columns (case-insensitive): {missing}")
    df = df.rename(columns={lower_map[k]: k for k in expected})

    inserted, skipped, updated = 0, 0, 0

    for _, row in df.iterrows():
        market_index_code = safe_str(row.get("market_index_code"))
        symbol = safe_str(row.get("symbol"))
        short_code = safe_str(row.get("short_code"))
        ric = safe_str(row.get("ric"))
        ticker = safe_str(row.get("ticker"))
        name = safe_str(row.get("name"))
        sector = safe_str(row.get("sector"))
        market = safe_str(row.get("market"))

        if not market_index_code or not symbol or symbol.lower() == "nan":
            skipped += 1
            print("updated", skipped, "inst.symbol", symbol, "name", "name")
            continue

        # find or create market exchange
        mx = db.query(MarketExchnges).filter_by(index_code=market_index_code).first()
        if not mx:
            mx = MarketExchnges(index_code=market_index_code, index_name="", echange_name="", market_name="")
            db.add(mx)
            db.flush()

        se = None
        if sector:
            se = db.query(Sector).filter_by(name=sector).first()
            if not se:
                se = Sector(name=sector)
                db.add(se)
                db.flush()

        ma = None
        if market:
            ma = db.query(Market).filter_by(name=market).first()
            if not ma:
                ma = Market(name=market)
                db.add(ma)
                db.flush()

        # Check if the instrument exists, matching on symbol, RIC, or ticker
        conditions = [Instrument.symbol == symbol]
        if ric:
            conditions.append(Instrument.ric == ric)
        if ticker:
            conditions.append(Instrument.ticker == ticker)
        inst = db.query(Instrument).filter(or_(*conditions)).first()

        if inst:
            inst.market_exchange_id = mx.id

            if se:  # only update if sector provided
                inst.sector_id = se.id
            # else: keep existing

            if ma:  # only update if market provided
                inst.market_id = ma.id
            # else: keep existing

            if short_code:
                try:
                    inst.short_code = int(short_code)
                except ValueError:
                    inst.short_code = None
            if ric:
                inst.ric = ric
            if ticker:
                inst.ticker = ticker
            if name:
                inst.name = name
            updated += 1
            print("updated", updated, "inst.symbol", inst.symbol)
        else:
            try:
                short_code_int = int(short_code) if short_code else None
            except ValueError:
                short_code_int = None
            inst = Instrument(
                market_exchange_id=mx.id,
                short_code=short_code_int,
                symbol=symbol,
                ric=ric,
                ticker=ticker,
                name=name,
                sector_id=se.id if se else None,
                market_id=ma.id if ma else None
            )
            db.add(inst)
            try:
                db.flush()
                inserted += 1
            except IntegrityError:
                db.rollback()
                skipped += 1
                continue

    db.commit()
    response = {"success": True, "created": inserted, "updated": updated, "skipped": skipped}
    print(response)
    return response

@app.get("/market-exchanges")
def list_market_exchanges(db: Session = Depends(get_db)):
    rows = db.query(MarketExchnges).all()
    payload = []
    for r in rows:
        payload.append({
            "id": r.id,
            "index_code": r.index_code,
            "index_name": r.index_name,
            "echange_name": r.echange_name,
            "market_name": r.market_name
        })
    return JSONResponse(payload)


@app.get("/instruments")
def list_instruments(db: Session = Depends(get_db)):
    rows = db.query(Instrument).all()
    payload = []
    for r in rows:
        payload.append({
            "id": r.id,
            "market_exchange_id": r.market_exchange_id if r.market_exchange_id else None,
            "market_index_code": r.market_exchange.index_code if r.market_exchange else None,
            "short_code": r.short_code,
            "symbol": r.symbol,
            "ric": r.ric,
            "ticker": r.ticker,
            "name": r.name,
            "sector_name": r.sector.name if r.sector else None,
            "market_name": r.market.name if r.market else None
        })
    return JSONResponse(payload)

# ---- List endpoints for preview ----
@app.get("/markets")
def list_markets(db: Session = Depends(get_db)):
    rows = db.query(Market).all()
    return JSONResponse([{"id": r.id, "name": r.name} for r in rows])

@app.get("/sectors")
def list_sectors(db: Session = Depends(get_db)):
    rows = db.query(Sector).all()
    return JSONResponse([{"id": r.id, "name": r.name} for r in rows])


@app.post("/api/upload/market")
def upload_market(file: UploadFile = File(...), db: Session = Depends(get_db)):
    try:
        contents = file.file.read()
        df = pd.read_excel(io.BytesIO(contents), engine="openpyxl")
    except Exception as e:
        return JSONResponse({"success": False, "error": f"Failed to read Excel: {e}"}, status_code=400)

    cols = {c.lower().strip(): c for c in df.columns}
    print("cols", cols)
    if "name" not in cols:
        return JSONResponse({"success": False, "error": "Excel file must contain a 'name' column."}, status_code=400)

    name_col = cols["name"]
    created, skipped, errors = 0, 0, []
    for idx, val in df[name_col].items():
        try:
            if pd.isna(val):
                skipped += 1
                continue
            name = str(val).strip()
            if not name:
                skipped += 1
                continue
            if crud.get_market_by_name(db, name):
                skipped += 1
                continue
            crud.create_market(db, name)
            created += 1
        except Exception as e:
            errors.append({"row": int(idx), "error": str(e)})
    return {"success": True, "created": created, "skipped": skipped, "errors": errors}

@app.post("/api/upload/sector")
def upload_sector(file: UploadFile = File(...), db: Session = Depends(get_db)):
    try:
        contents = file.file.read()
        df = pd.read_excel(io.BytesIO(contents), engine="openpyxl")
    except Exception as e:
        return JSONResponse({"success": False, "error": f"Failed to read Excel: {e}"}, status_code=400)

    cols = {c.lower().strip(): c for c in df.columns}
    if "name" not in cols:
        return JSONResponse({"success": False, "error": "Excel file must contain a 'name' column."}, status_code=400)

    name_col = cols["name"]
    created, skipped, errors = 0, 0, []
    for idx, val in df[name_col].items():
        try:
            if pd.isna(val):
                skipped += 1
                continue
            name = str(val).strip()
            if not name:
                skipped += 1
                continue
            if crud.get_sector_by_name(db, name):
                skipped += 1
                continue
            crud.create_sector(db, name)
            created += 1
        except Exception as e:
            errors.append({"row": int(idx), "error": str(e)})
    return {"success": True, "created": created, "skipped": skipped, "errors": errors}


from decimal import Decimal, DivisionByZero, InvalidOperation, getcontext

getcontext().prec = 10

def safe_divide(numerator, denominator):
    try:
        return numerator / denominator if denominator != 0 else None
    except (DivisionByZero, InvalidOperation):
        return None
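
# e.g. safe_divide(Decimal("10"), Decimal("4")) -> Decimal("2.5");
#      safe_divide(Decimal("10"), Decimal("0")) -> None (guarded explicitly,
#      with DivisionByZero/InvalidOperation caught as a fallback).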

def to_decimal(value):
    try:
        if value is None or value == '':
            return None
        return Decimal(str(value))
    except (InvalidOperation, ValueError):
        return None

def calc_net_debt_to_ebitda(data):
    net_debt = data.get('WC18199')
    ebitda = data.get('WC18198')
    if None in (net_debt, ebitda) or ebitda == 0:
        return None
    return safe_divide(net_debt, ebitda)

def calc_dpo(data):
    accounts_payable = data.get('WC03040')
    cogs = data.get('WC01051')
    if None in (accounts_payable, cogs) or cogs == 0:
        return None
    return safe_divide(accounts_payable, cogs) * 365


def calc_ccc(data):
    days = 365
    dio = data.get('WC08126', 0)
    dso = data.get('WC08131')
    accounts_payable = data.get('WC03040')
    cogs = data.get('WC01051')
    if None in (dio, dso, accounts_payable, cogs) or cogs == 0:
        return None
    dpo = safe_divide(accounts_payable, safe_divide(cogs, days))
    if dpo is None:
        return None
    return dio + dso - dpo
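
# Cash Conversion Cycle = DIO + DSO - DPO, with DPO derived here as
# AccountsPayable / (COGS / 365). Worldscope items: WC08126 = DIO,
# WC08131 = DSO, WC03040 = accounts payable, WC01051 = COGS (the same
# mapping as kpi_data_mappings further below).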


def calc_ebitda_margin(data):
    ebitda = data.get('WC18198')
    revenue = data.get('WC01001')
    if None in (ebitda, revenue) or revenue == 0:
        return None
    return safe_divide(ebitda, revenue) * 100


def calc_fcf(data):
    fcf = data.get('WC04860')
    capex = data.get('WC04601')
    if None in (fcf, capex):
        return None
    return fcf - capex


def calc_asset_turnover(data):
    revenue = data.get('WC01001')
    total_assets = data.get('WC02999')
    if None in (revenue, total_assets) or total_assets == 0:
        return None
    return safe_divide(revenue, total_assets)


def calc_inventory_turnover(data):
    cogs = data.get('WC01051')
    inventory = data.get('WC02101')
    if None in (cogs, inventory) or inventory == 0:
        return None
    return safe_divide(cogs, inventory)


def calc_tangible_bvps(data):
    equity = data.get('WC03995')
    intangible_assets = data.get('WC02649')
    goodwill = data.get('WC18280R', 0)
    shares_outstanding = data.get('WC05301')
    if None in (equity, intangible_assets, goodwill, shares_outstanding) or shares_outstanding == 0:
        return None
    numerator = equity - intangible_assets - goodwill
    return safe_divide(numerator, shares_outstanding)


def calc_forward_pe(data):
    dwfc = data.get('DWFC')
    expected_eps = data.get('WC07250')
    if None in (dwfc, expected_eps) or expected_eps == 0:
        return None
    return safe_divide(dwfc, expected_eps)


def calc_price_to_sales(data):
    price_per_share = data.get('P')
    sales_per_share = data.get('WC05508')
    if None in (price_per_share, sales_per_share) or sales_per_share == 0:
        return None
    return safe_divide(price_per_share, sales_per_share)


def calc_gross_margin(data):
    gross_income = data.get('WC01100')
    revenue = data.get('WC01001')
    if None in (gross_income, revenue) or revenue == 0:
        return None
    return safe_divide(gross_income, revenue) * 100


def calc_fcf_yield(data):
    fcf = data.get('WC04860')
    capex = data.get('WC04601')
    market_cap = data.get('WC08001')
    if None in (fcf, capex, market_cap) or market_cap == 0:
        return None
    free_cash_flow = fcf - capex
    return safe_divide(free_cash_flow, market_cap) * 100


def calc_short_interest_ratio(data):
    sid = data.get('SID')
    uvo = data.get('UVO')
    if None in (sid, uvo) or uvo == 0:
        return None
    return safe_divide(sid * 1000, uvo)


def calc_sharpe_ratio(data):
    msdpd = data.get('MSDPD')
    ry = data.get('RY')
    vol = data.get('VOL')
    if None in (msdpd, ry, vol) or vol == 0:
        return None
    return safe_divide(msdpd - ry, vol)


def calc_institutional_ownership_pct(data):
    noshou = data.get('NOSHOU')
    wc05475 = data.get('WC05475')
    if None in (noshou, wc05475) or noshou == 0:
        return None
    return safe_divide((noshou - wc05475) * 100, noshou)


def calc_insider_ownership_pct(data):
    shares_owned_by_insiders = data.get('VO')
    shares_outstanding = data.get('DWCF')
    if None in (shares_owned_by_insiders, shares_outstanding) or shares_outstanding == 0:
        return None
    return safe_divide(shares_owned_by_insiders, shares_outstanding) * 100


def calculate_kpi_value(period_key, all_data):
    """
    Calculate KPI values for the given period
    """
    current_data = all_data
    results = {}

    calc_map = {
        'FCFYield': lambda: calc_fcf_yield(current_data),
        'NetDebtToEBITDA': lambda: calc_net_debt_to_ebitda(current_data),
        'DPO': lambda: calc_dpo(current_data),
        'CashConversionCycle': lambda: calc_ccc(current_data),
        'EBITMargin': lambda: calc_ebitda_margin(current_data),
        'FCF': lambda: calc_fcf(current_data),
        'AssetTurnover': lambda: calc_asset_turnover(current_data),
        'InventoryTurnover': lambda: calc_inventory_turnover(current_data),
        'TangibleBVPS': lambda: calc_tangible_bvps(current_data),
        'ForwardPE': lambda: calc_forward_pe(current_data),
        'PriceToSales': lambda: calc_price_to_sales(current_data),
        'GrossMargin': lambda: calc_gross_margin(current_data),
        'ShortInterestRatio': lambda: calc_short_interest_ratio(current_data),
        'SharpeRatio': lambda: calc_sharpe_ratio(current_data),
        'InsiderOwnershipPct': lambda: calc_insider_ownership_pct(current_data),
        'InstitutionalOwnershipPct': lambda: calc_institutional_ownership_pct(current_data),
    }
    for kpi_name, func in calc_map.items():
        try:
            val = func()
            if val is not None:
                results[kpi_name] = val
        except Exception as e:
            print(f"Error calculating {kpi_name}: {e}")
            continue
    print(f'results for period {period_key} ===> {results}')
    return results


def save_stock_data(stock, kpi, period_key, value, kpi_data, symbol, frequency, date=None):
    if value is None:
        return
    if frequency == "D":
        year, month, day = period_key
        date = f"{year:04d}-{month:02d}-{day:02d}"
    elif frequency == "M":
        year, month = period_key
        date = f"{year:04d}-{month:02d}-01"
    elif frequency == "Y":
        year = period_key[0]
        date = f"{year:04d}-01-01"
    else:  # "Q" and any unknown frequency default to quarterly
        year, quarter = period_key
        date = f"{year:04d}-{3 * (quarter - 1) + 1:02d}-01"

    print('saving stock data:', stock, kpi, period_key, value, date, f'frequency: {frequency}')

    if kpi not in kpi_data:
        kpi_data[kpi] = []

    kpi_data[kpi].append({
        "date": date,
        "value": float(value),
        "currency": ""
    })

    return kpi_data
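
# e.g. (hypothetical call) frequency="Q" with period_key=(2023, 3) yields the
# quarter-start date "2023-07-01"; frequency="M" with period_key=(2023, 11)
# yields "2023-11-01".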


def get_kpi_calculation_data(DS, kpi_code_data, start_date, end_date, frequency, symbol, kpi_data_list):
    # Fetch each dependency series from Datastream and normalise it into the
    # same {date, value, currency} record shape used elsewhere.
    for code in kpi_code_data.values():
        try:
            df = DS.fetch(symbol, [code], date_from=start_date, date_to=end_date, freq=frequency).dropna()
            if df.empty:
                continue
            kpi_data_list[code] = [
                {
                    "date": date.strftime("%Y-%m-%d"),
                    "value": float(row[code]),
                    "currency": ""
                }
                for date, row in df.iterrows()
            ]
        except Exception as e:
            print(f"Error fetching {code} for {symbol}: {e}")
            continue
    return kpi_data_list

def dynamic_calculate_kpis(kpi_data, name_code_map, symbol, start_date, end_date, calculation_list, frequency):
    # Local import in case Decimal is not imported at module level; harmless if it is.
    from decimal import Decimal

    quarterly_data = {}

    # pydatastream's fetch accepts the single-letter frequency ("D"/"M"/"Q"/"Y") directly,
    # so no frontend-to-external mapping is needed here.

    # KPI data mapping for calculations
    kpi_data_mappings = {
        'NetDebtToEBITDA': {"NetDebt": "WC18199", "EBITDA": "WC18198"},
        'DPO': {"AccountsPayable": "WC03040", "CostOfGoodsSold": "WC01051"},
        'CashConversionCycle': {"DSO": "WC08131", "DIO": "WC08126", "AccountsPayable": "WC03040","CostOfGoodsSold": "WC01051"},
        'FCF': {"fcf": "WC04860", "capex": "WC04601"},
        'EBITMargin': {"ebitda": "WC18198", "revenue": "WC01001"},
        'AssetTurnover': {"revenue": "WC01001", "total_assets": "WC02999"},
        'InventoryTurnover': {"cogs": "WC01051", "inventory": "WC02101"},
        'TangibleBVPS': {"equity": "WC03995", "intangible_assets": "WC02649", "goodwill": "WC18280R","shares_outstanding": "WC05301"},
        'ForwardPE': {"dwfc": "DWFC", "expected_eps": "WC07250"},
        'PriceToSales': {"price_per_share": "P", "sales_per_share": "WC05508"},
        'GrossMargin': {"gross_income": "WC01100", "revenue": "WC01001"},
        'FCFYield': {"fcf": "WC04860", "capex": "WC04601", "market_cap": "WC08001"},
        'ShortInterestRatio': {"sid": "SID", "uvo": "UVO"},
        'SharpeRatio': {"msdpd": "MSDPD", "ry": "RY", "vol": "VOL"},
        'InsiderOwnershipPct': {"shares_owned_by_insiders": "VO", "shares_outstanding": "DWCF"},
        'InstitutionalOwnershipPct': {"noshou": "NOSHOU", "wc05475": "WC05475"}
    }
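    # The values above are Datastream datatype mnemonics; the WC-prefixed items come from
    # the Worldscope fundamentals database. Availability depends on your Datastream entitlement.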


    kpi_data_list = {}

    # Fetch required data for each calculation KPI, reusing one Datastream client
    DS = get_datastream()
    for kpi_name in calculation_list:
        if kpi_name in kpi_data_mappings:
            kpi_code_data = kpi_data_mappings[kpi_name]
            kpi_data_list = get_kpi_calculation_data(DS, kpi_code_data, start_date, end_date, frequency, symbol,
                                                     kpi_data_list)

    print("kpi_data_list:", kpi_data_list)

    # Structure quarterly data
    for kpi_name, entries in kpi_data_list.items():
        for entry in entries:
            dt = datetime.strptime(entry["date"], "%Y-%m-%d")
            # Create period key based on frequency
            if frequency == "D":
                period_key = (dt.year, dt.month, dt.day)
            elif frequency == "M":
                period_key = (dt.year, dt.month)
            elif frequency == "Q":
                quarter = (dt.month - 1) // 3 + 1
                period_key = (dt.year, quarter)
            elif frequency == "Y":
                period_key = (dt.year,)
            else:  # Default to quarterly
                quarter = (dt.month - 1) // 3 + 1
                period_key = (dt.year, quarter)

            if period_key not in quarterly_data:
                quarterly_data[period_key] = {}

            quarterly_data[period_key][kpi_name] = Decimal(str(entry["value"]))

    print("structured_data:", quarterly_data)

    # Calculate KPIs for each period
    for period_key, values in quarterly_data.items():
        calculated = calculate_kpi_value(period_key, values)
        for kpi_name, value in calculated.items():
            save_stock_data(symbol, kpi_name, period_key, value, kpi_data, symbol, frequency)

    print("Final kpi_data====>", kpi_data)
    return kpi_data

@app.post("/stock_quarterly_performance_calculation")
def get_stock_quarterly_performance_calculation(
        symbols: Optional[List[str]] = Query(None),
        start_period: str = Query(...),
        end_period: str = Query(...),
        result: Optional[Dict[str, Any]] = Body(None),
        kpis: Optional[List[str]] = Query(None),
        db: Session = Depends(get_db)
):
    print("kpis received ===>", kpis)
    if not symbols:
        raise HTTPException(status_code=422, detail="At least one symbol is required")
    # Split comma-separated entries and de-duplicate
    symbols = list({s.strip() for item in symbols for s in item.split(",")})
    print("symbols", symbols)


    quarterly_results = stock_quarterly_performance_calculation(
        symbols=symbols, start_period=start_period, end_period=end_period, result=result, kpis=kpis, db=db
    )

    return quarterly_results

def quarter_to_dates(start_quarter: str, end_quarter: str):
    quarter_map = {"Q1": 1, "Q2": 4, "Q3": 7, "Q4": 10}
    q_start, year_start = start_quarter.split("-")
    year_start = int(year_start)
    month_start = quarter_map[q_start]
    start_date = datetime(year_start, month_start, 1).strftime("%Y-%m-%d")

    q_end, year_end = end_quarter.split("-")
    year_end = int(year_end)
    month_end = quarter_map[q_end] + 2
    last_day = calendar.monthrange(year_end, month_end)[1]
    end_date = datetime(year_end, month_end, last_day).strftime("%Y-%m-%d")
    return start_date, end_date
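# e.g. quarter_to_dates("Q2-2024", "Q3-2024") -> ("2024-04-01", "2024-09-30")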

def generate_quarter_range(start_period, end_period):
    start_q, start_y = start_period.split('-')
    end_q, end_y = end_period.split('-')
    start_q, start_y, end_q, end_y = int(start_q[1:]), int(start_y), int(end_q[1:]), int(end_y)

    quarters = []
    year, quarter = start_y, start_q

    while (year < end_y) or (year == end_y and quarter <= end_q):
        quarters.append(f"Q{quarter}-{year}")
        quarter += 1
        if quarter > 4:
            quarter = 1
            year += 1

    return quarters
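# e.g. generate_quarter_range("Q3-2023", "Q1-2024") -> ["Q3-2023", "Q4-2023", "Q1-2024"]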

# Helper function to get the previous quarter
def get_previous_quarter(quarter: str) -> str:
    q, year = quarter.split("-")
    q_num = int(q[1])
    year = int(year)
    prev_q_num = q_num - 1
    prev_year = year
    if prev_q_num == 0:
        prev_q_num = 4
        prev_year -= 1
    return f"Q{prev_q_num}-{prev_year}"


def stock_quarterly_performance_calculation(symbols, start_period, end_period, result, kpis, db: Session):
    # Plain helper (not a route): the session is passed in explicitly, since a
    # Depends(...) default is only resolved by FastAPI's dependency injection.
    query = db.query(Instrument).filter(Instrument.symbol.in_(symbols))
    results = query.all()
    stocks = [inst.symbol for inst in results]
    index_symbols = list({inst.market_exchange.index_code for inst in results if inst.market_exchange})

    # Download from one quarter before start_period so the first quarter's return can be computed
    prev_quarter = get_previous_quarter(start_period)

    # --- Helper: Get start/end dates for a quarter ---
    def new_quarter_to_dates(qtr_label):
        q, year = qtr_label.split("-")
        year = int(year)
        q_num = int(q[-1])
        if q_num == 1:
            return f"{year}-01-01", f"{year}-03-31"
        elif q_num == 2:
            return f"{year}-04-01", f"{year}-06-30"
        elif q_num == 3:
            return f"{year}-07-01", f"{year}-09-30"
        elif q_num == 4:
            return f"{year}-10-01", f"{year}-12-31"
        raise ValueError(f"Unrecognised quarter label: {qtr_label}")

    # --- Define actual date range for fetching data ---
    prev_start, _ = new_quarter_to_dates(prev_quarter)  # one quarter before start
    _, end_date = new_quarter_to_dates(end_period)

    # --- Final start_date used for fetching ---
    start_date = prev_start

    period_list = generate_quarter_range(start_period, end_period)
    period_data = []
    frequency = "Q"
    stock_wise_kpi_data, KPI_DEPENDENCIES, NA_KPIS, name_code_map, lower, higher = get_stock_kpi_calculation(frequency, start_date, end_date, kpis, symbols, db)

    df_stocks = {}
    for symbol in stocks:
        instrument = db.query(Instrument).filter_by(symbol=symbol).first()
        if instrument is None:
            continue

        stock_records = db.query(InstrumentPriceValue).filter(
            InstrumentPriceValue.instrument_for_id == instrument.id,
            InstrumentPriceValue.date >= start_date,
            InstrumentPriceValue.date <= end_date
        ).order_by(InstrumentPriceValue.date).all()

        if stock_records:
            df = pd.DataFrame(
                [{"date": r.date, "P": r.value} for r in stock_records]
            ).set_index("date")
            df.index = pd.to_datetime(df.index)
            df.index.name = None
        else:
            # --- Step 2: If DB empty → fetch from Datastream ---
            DS = get_datastream()
            if DS is None:
                continue
            df = DS.fetch(symbol, ["P"], date_from=start_date, date_to=end_date).dropna()
            if not df.empty:
                # --- Step 3: Store back into DB ---
                for d, val in df.itertuples():
                    db.add(InstrumentPriceValue(
                        instrument_for_id=instrument.id,
                        date=d,
                        value=val
                    ))
                db.commit()

        if not df.empty:
            df_stocks[symbol] = df

    # --- Fetch indexes ---
    df_indexes = {}
    for symbol in index_symbols:
        market_exchange = db.query(MarketExchnges).filter_by(index_code=symbol).first()
        if market_exchange is None:
            continue
        # --- Step 1: Try DB first ---
        index_records = db.query(IndexPriceValue).filter(
            IndexPriceValue.market_exchange_for_id == market_exchange.id,
            IndexPriceValue.date >= start_date,
            IndexPriceValue.date <= end_date
        ).order_by(IndexPriceValue.date).all()

        if index_records:
            df = pd.DataFrame(
                [{"date": r.date, "PI": r.value} for r in index_records]
            ).set_index("date")
            df.index = pd.to_datetime(df.index)
            df.index.name = None
        else:
            # --- Step 2: If DB empty → fetch from Datastream ---
            DS = get_datastream()
            if DS is None:
                continue
            df = DS.fetch(market_exchange.index_code, ["PI"], date_from=start_date, date_to=end_date).dropna()
            if not df.empty:
                # --- Step 3: Save back to DB ---
                for d, val in df.itertuples():
                    db.add(IndexPriceValue(
                        market_exchange_for_id=market_exchange.id,
                        date=d,
                        value=val
                    ))
                db.commit()
        if not df.empty:
            df_indexes[market_exchange.index_code] = df

    # --- Resample quarterly ---
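    # NB: the "QE" (quarter-end) alias requires pandas >= 2.2; on older pandas use "Q" instead.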
    for symbol, df in df_stocks.items():
        df_q = df.resample("QE").last()
        df_q[f"{symbol}_Return"] = df_q.iloc[:, 0].pct_change()
        df_q.index = df_q.index.to_period("Q")
        df_stocks[symbol] = df_q

    for symbol, df in df_indexes.items():
        df_q = df.resample("QE").last()
        df_q[f"{symbol}_Return"] = df_q.iloc[:, 0].pct_change()
        df_q.index = df_q.index.to_period("Q")
        df_indexes[symbol] = df_q

    def period_change(df, col, start_period, end_period):
        try:
            start_val = df.loc[start_period, col]
            end_val = df.loc[end_period, col]
            return (end_val - start_val) / start_val
        except Exception:
            return None

    def convert_quarter_format(q_str: str) -> str:
        q, year = q_str.split("-")
        return f"{year}{q}"

    start_fmt, end_fmt = convert_quarter_format(start_period), convert_quarter_format(end_period)

    quarterly_results = {}
    for inst in results:
        stock_symbol = inst.symbol
        index_symbol = inst.market_exchange.index_code
        index_name = inst.market_exchange.index_name

        stock_change = None
        index_change = None
        if stock_symbol in df_stocks:
            stock_change = period_change(df_stocks[stock_symbol], "P", start_fmt, end_fmt)
        if index_symbol in df_indexes:
            index_change = period_change(df_indexes[index_symbol], "PI", start_fmt, end_fmt)

        if stock_change is not None:
            stock_change = round(stock_change * 100, 2)
        if index_change is not None:
            index_change = round(index_change * 100, 2)

        quarterly_results[stock_symbol] = {
            "Quarterly Stock Change %": stock_change,
            "Quarterly Index Change %": index_change,
            "Index Name": index_name
        }
    # Attach the quarterly change figures to the matching rows in the incoming payload
    for stock_data in result['stock_performance_data'].values():
        for row in stock_data.get('stock_with_performance', []):
            ticker = row.get('Ticker')
            if ticker in quarterly_results:
                row.update(quarterly_results[ticker])

    # --- Statistical Calculations ---
    stats_results = {}
    for inst in results:
        stock_symbol = inst.symbol
        index_symbol = inst.market_exchange.index_code

        df_stock = df_stocks.get(stock_symbol)
        df_index = df_indexes.get(index_symbol)

        if df_stock is not None and df_index is not None:
            # Calculate quarterly returns
            stock_returns = df_stock[f"{stock_symbol}_Return"].dropna()
            index_returns = df_index[f"{index_symbol}_Return"].dropna()
            print("Stock Returns:\n", stock_returns)
            print("Index Returns:\n", index_returns)
            # Align returns by quarter
            aligned_returns = pd.concat([stock_returns, index_returns], axis=1, join='inner')
            aligned_returns.columns = ['stock_return', 'index_return']
            aligned_returns.index = aligned_returns.index.to_timestamp()

            # Filter for the specified quarter range
            quarters = generate_quarter_range(start_period, end_period)
            formatted_quarters = [q.replace("-", " ") for q in quarters]
            aligned_returns["Quarter"] = [f"Q{idx.quarter} {idx.year}" for idx in aligned_returns.index]
            aligned_returns = aligned_returns[aligned_returns["Quarter"].isin(formatted_quarters)]

            # Mean quarterly returns
            stock_q_mean = aligned_returns['stock_return'].mean() if not aligned_returns.empty else None
            index_q_mean = aligned_returns['index_return'].mean() if not aligned_returns.empty else None
            print("stock_q_mean:", stock_q_mean, "---> index_q_mean:", index_q_mean)
            # Variance and covariance
            if len(aligned_returns) > 1:
                stock_variance = aligned_returns['stock_return'].var(ddof=1)
                print("stock_variance ===>", stock_variance)
                index_variance = aligned_returns['index_return'].var(ddof=1)
                print("index_variance ===>", index_variance)
                covariance = aligned_returns[['stock_return', 'index_return']].cov().iloc[0, 1]
                print("covariance ===>", covariance)
                stock_std = math.sqrt(stock_variance)
                print("stock_std",stock_std)
                index_std = math.sqrt(index_variance)
                print("index_std",index_std)
            else:
                stock_variance = index_variance = covariance = None
                stock_std = index_std = None  # otherwise undefined when too few aligned quarters

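            # Regression sketch, as implemented below: beta = Cov(stock, index) / Var(index);
            # residuals are taken against beta * index_return (no intercept term), then
            # SE(beta) = sqrt(residual_variance / sum((x_i - x_bar)^2)) and t = beta / SE(beta),
            # tested against a t-distribution with n - 2 degrees of freedom.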
            # Beta
            beta = covariance / index_variance if covariance is not None and index_variance not in (None, 0) else None
            print("beta ===>", beta)
            # Regression diagnostics
            t_stat = se_slope = sum_sq_x = residual_variance = None
            residuals = []
            if beta is not None and len(aligned_returns) > 2:
                stock_q_values = aligned_returns['stock_return'].values
                index_q_values = aligned_returns['index_return'].values
                sum_sq_x = sum((x - index_q_mean) ** 2 for x in index_q_values) if index_q_mean is not None else None

                if sum_sq_x is not None and sum_sq_x != 0:
                    for i in range(len(stock_q_values)):
                        predicted = beta * index_q_values[i]
                        resid = stock_q_values[i] - predicted
                        residuals.append(resid)
                    residuals_sum = sum(r ** 2 for r in residuals)
                    n = len(residuals)
                    if n > 2:
                        residual_variance = residuals_sum / (n - 2)
                        se_slope = math.sqrt(residual_variance / sum_sq_x) if sum_sq_x != 0 else None
                        t_stat = beta / se_slope if se_slope and se_slope != 0 else None

            # p-value and significance
            p_value = significant = None
            if t_stat is not None and len(aligned_returns) > 2:
                dof = len(aligned_returns) - 2  # avoid shadowing the DataFrame name `df`
                p_value = 2 * (1 - t.cdf(abs(t_stat), dof))
                significant = "Yes" if p_value <= 0.05 else "No"

            # Excess Return (α), Tracking Error (TE), Information Ratio (IR)
            alpha = tracking_error = info_ratio = None
            if residuals:
                n = len(residuals)
                alpha = sum(residuals) / n
                if n > 1:
                    tracking_error = math.sqrt(sum((r - alpha) ** 2 for r in residuals) / (n - 1))
                    info_ratio = alpha / tracking_error if tracking_error != 0 else None

            stats_results[stock_symbol] = {
                # "Mean Quarterly Return": round(stock_q_mean, 4) if stock_q_mean is not None else None,
                "Stock_Variance": round(stock_variance, 4) if stock_variance is not None else None,
                "stock_std": round(stock_std,4) if stock_std is not None else None,
                # "index_std": round(index_std,4) if index_std is not None else None
                # "Index_Variance": round(index_variance, 4) if index_variance is not None else None,
                "residual_variance": round(residual_variance, 4) if residual_variance is not None else None,
                "se_slope": round(se_slope, 4) if se_slope is not None else None,
                "Covariance": round(covariance, 4) if covariance is not None else None,
                "Beta": round(beta, 4) if beta is not None else None,
                "Alpha": round(alpha, 4) if alpha is not None else None,
                "Tracking Error": round(tracking_error, 4) if tracking_error is not None else None,
                "Information Ratio": round(info_ratio, 4) if info_ratio is not None else None,
                "t-statistic": round(t_stat, 4) if t_stat is not None else None,
                "p-value": round(p_value, 4) if p_value is not None else None,
                "Significant": significant
            }
            print("stats_results:", stats_results)

    # --- Formatted Quarterly Summary Table ---
    def date_to_quarter(date_str):
        date_obj = pd.to_datetime(date_str)
        q = (date_obj.month - 1) // 3 + 1
        return f"Q{q} {date_obj.year}"

    quarters = generate_quarter_range(start_period, end_period)
    formatted_quarters = [q.replace("-", " ") for q in quarters]
    table = {"Quarter": formatted_quarters}

    # Store stock-specific data temporarily to add in correct order
    stock_data_dict = {}

    def quarter_column_values(df_merge, col):
        # For each displayed quarter, pull the column's value (rounded to 4 dp, None if absent)
        values = []
        for q in formatted_quarters:
            val = df_merge.loc[df_merge["Quarter"] == q, col]
            val = val.values[0] if not val.empty else np.nan
            values.append(round(val, 4) if pd.notna(val) else None)
        return values

    for stock_symbol in stocks:
        instrument = db.query(Instrument).filter_by(symbol=stock_symbol).first()
        index_symbol = instrument.market_exchange.index_code if instrument.market_exchange else None

        stock_columns = {}

        # --- Get stock quarterly prices ---
        df_stock = df_stocks.get(stock_symbol)
        if df_stock is not None:
            df_stock_q = df_stock.copy()
            df_stock_q.index = df_stock_q.index.to_timestamp()
            df_stock_q["Quarter"] = [date_to_quarter(d) for d in df_stock_q.index]
            df_stock_q = df_stock_q[["Quarter", "P"]].rename(columns={"P": f"{stock_symbol} Price"})
        else:
            df_stock_q = pd.DataFrame(columns=["Quarter", f"{stock_symbol} Price"])

        # --- Get index quarterly values ---
        df_index = df_indexes.get(index_symbol)
        if df_index is not None:
            df_index_q = df_index.copy()
            df_index_q.index = df_index_q.index.to_timestamp()
            df_index_q["Quarter"] = [date_to_quarter(d) for d in df_index_q.index]
            df_index_q = df_index_q[["Quarter", "PI"]].rename(columns={"PI": f"{stock_symbol} Index"})
        else:
            df_index_q = pd.DataFrame(columns=["Quarter", f"{stock_symbol} Index"])

        # --- Merge price & index ---
        df_merge = pd.merge(df_stock_q, df_index_q, on="Quarter", how="outer")

        # Add Price column
        col = f"{stock_symbol} Price"
        if col in df_merge.columns:
            stock_columns[col] = quarter_column_values(df_merge, col)

        # Add Index column
        col = f"{stock_symbol} Index"
        if col in df_merge.columns:
            stock_columns[col] = quarter_column_values(df_merge, col)

        # --- Calculate and add Stock Performance %, Index Performance %, and Excess Return % ---
        if df_stock is not None:
            # Get all returns for this stock up to each quarter
            stock_returns_full = df_stock[f"{stock_symbol}_Return"].dropna()
            stock_returns_full.index = stock_returns_full.index.to_timestamp()

            stock_perf_values = []
            index_perf_values = []
            excess_return_values = []

            # Index returns are identical on every pass, so compute them once outside the loop
            index_returns_full = None
            if df_index is not None:
                index_returns_full = df_index[f"{index_symbol}_Return"].dropna()
                index_returns_full.index = index_returns_full.index.to_timestamp()

            for q in formatted_quarters:
                # Parse quarter to get date range
                q_parts = q.split()
                q_num = int(q_parts[0][1])
                q_year = int(q_parts[1])

                # Get all returns from start_period up to current quarter
                q_date = pd.Timestamp(year=q_year, month=q_num * 3, day=1)
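                # Period.to_timestamp() above yields quarter *start* dates, so comparing
                # against the first day of the quarter's final month keeps the current
                # quarter inside the cumulative window.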

                # Filter returns from start to current quarter
                stock_returns_subset = stock_returns_full[stock_returns_full.index <= q_date]

                # Calculate cumulative return for stock
                if len(stock_returns_subset) > 0:
                    cumulative_stock_return = (1 + stock_returns_subset).prod() - 1
                    stock_perf_values.append(round(cumulative_stock_return * 100, 2))
                else:
                    stock_perf_values.append(None)

                # Calculate cumulative return for index
                if index_returns_full is not None:
                    index_returns_subset = index_returns_full[index_returns_full.index <= q_date]

                    if len(index_returns_subset) > 0:
                        cumulative_index_return = (1 + index_returns_subset).prod() - 1
                        index_perf_values.append(round(cumulative_index_return * 100, 2))
                    else:
                        index_perf_values.append(None)

                    # Calculate excess return
                    if stock_perf_values[-1] is not None and index_perf_values[-1] is not None:
                        excess = stock_perf_values[-1] - index_perf_values[-1]
                        excess_return_values.append(round(excess, 2))
                    else:
                        excess_return_values.append(None)
                else:
                    index_perf_values.append(None)
                    excess_return_values.append(None)

            stock_columns[f"{stock_symbol} Stock Performance %"] = stock_perf_values
            if df_index is not None:
                stock_columns[f"{stock_symbol} Index Performance %"] = index_perf_values
                stock_columns[f"{stock_symbol} Excess Return %"] = excess_return_values

        # --- Add KPI columns ---
        kpi_data = stock_wise_kpi_data.get(stock_symbol, {})
        for kpi_name, entries in kpi_data.items():
            kpi_df = pd.DataFrame(entries)
            kpi_df["Quarter"] = kpi_df["date"].apply(date_to_quarter)
            kpi_df = kpi_df[["Quarter", "value"]].rename(columns={"value": f"{stock_symbol} {kpi_name}"})
            df_merge = pd.merge(df_merge, kpi_df, on="Quarter", how="outer")

            col = f"{stock_symbol} {kpi_name}"
            col_values = []
            for q in formatted_quarters:
                val = df_merge.loc[df_merge["Quarter"] == q, col]
                val = val.values[0] if not val.empty else np.nan
                col_values.append(round(val, 4) if pd.notna(val) else None)
            stock_columns[col] = col_values

        # --- Add quarterly returns ---
        if df_stock is not None:
            stock_returns = df_stock[f"{stock_symbol}_Return"].dropna()
            stock_returns.index = stock_returns.index.to_timestamp()
            stock_returns_df = pd.DataFrame({
                "Quarter": [f"Q{idx.quarter} {idx.year}" for idx in stock_returns.index],
                f"{stock_symbol} Return": stock_returns.values
            })
            df_merge = pd.merge(df_merge, stock_returns_df, on="Quarter", how="outer")

            col = f"{stock_symbol} Return"
            col_values = []
            for q in formatted_quarters:
                val = df_merge.loc[df_merge["Quarter"] == q, col]
                val = val.values[0] if not val.empty else np.nan
                col_values.append(round(val, 4) if pd.notna(val) else None)
            stock_columns[col] = col_values

        if df_index is not None:
            index_returns = df_index[f"{index_symbol}_Return"].dropna()
            index_returns.index = index_returns.index.to_timestamp()
            index_returns_df = pd.DataFrame({
                "Quarter": [f"Q{idx.quarter} {idx.year}" for idx in index_returns.index],
                f"{stock_symbol} Index Return": index_returns.values
            })
            df_merge = pd.merge(df_merge, index_returns_df, on="Quarter", how="outer")

            col = f"{stock_symbol} Index Return"
            col_values = []
            for q in formatted_quarters:
                val = df_merge.loc[df_merge["Quarter"] == q, col]
                val = val.values[0] if not val.empty else np.nan
                col_values.append(round(val, 4) if pd.notna(val) else None)
            stock_columns[col] = col_values

        # --- Add statistical metrics ---
        if stock_symbol in stats_results:
            for stat_name, stat_value in stats_results[stock_symbol].items():
                stock_columns[f"{stock_symbol} {stat_name}"] = [stat_value] * len(formatted_quarters)

        stock_data_dict[stock_symbol] = stock_columns


    # --- Filter stocks based on positive overall excess return ---
    positive_stocks = []

    for stock_symbol in stocks:
        stock_cols = stock_data_dict.get(stock_symbol, {})
        # Find the "Excess Return %" column for this stock
        excess_cols = [col for col in stock_cols.keys() if col.endswith("Excess Return %")]
        if excess_cols:
            excess_col = excess_cols[0]
            excess_values = [v for v in stock_cols[excess_col] if v is not None]

            if excess_values:
                total_excess = sum(excess_values)
                avg_excess = total_excess / len(excess_values)

                # Keep stock if overall average excess return is positive
                if avg_excess > 0:
                    positive_stocks.append(stock_symbol)

    # --- Build final table only for positive stocks ---
    for stock_symbol in positive_stocks:
        for col_name, col_values in stock_data_dict[stock_symbol].items():
            table[col_name] = col_values

    # Convert to DataFrame
    df_summary = pd.DataFrame(table)

    # Print in terminal as a table
    print("\n===== Quarterly Performance Summary (Positive Excess Return Only) =====")
    print(df_summary.to_string(index=False))

    # Add to result (convert to JSON serializable)
    result["quarterly_performance_summary"] = df_summary.fillna("").to_dict(orient="records")
    total_stocks = len(symbols)
    selected_kpis = len(kpis or [])  # kpis may be None when no KPI filter was sent
    message = excel_sheet_manager(total_stocks=total_stocks, selected_kpis=selected_kpis)
    result["warning_message"] = message
    print(result)
    return result


def excel_sheet_manager(total_stocks, selected_kpis, fixed_columns=20, excel_limit=16384):
    total_columns_per_stock = selected_kpis + fixed_columns
    total_columns_needed = total_stocks * total_columns_per_stock

    if total_columns_needed <= excel_limit:
        return None

    # Work out how many stocks fit per sheet and how many sheets are needed
    max_stocks_per_sheet = excel_limit // total_columns_per_stock
    sheets_needed = (total_stocks + max_stocks_per_sheet - 1) // max_stocks_per_sheet
    kpi_suggestion = max(1, (excel_limit // total_stocks) - fixed_columns)
    stock_suggestion = excel_limit // total_columns_per_stock

    return (
        f"⚠️ You have selected {selected_kpis} KPIs and {total_stocks} stocks. Based on Excel's 16,384-column limit, this configuration requires approximately {sheets_needed} sheets; each sheet can contain up to {max_stocks_per_sheet} stocks.\n\n"
        f"🧭 Please reduce the number of stocks or KPIs for the best experience.\n\n"
        f"💡 Suggested approaches:\n"
        f"1️⃣ Reduce KPIs per stock to ≤ {kpi_suggestion} to fit all {total_stocks} stocks in one sheet.\n"
        f"2️⃣ Keep KPIs = {selected_kpis} but limit stocks to ≤ {stock_suggestion} per sheet."
    )
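
# Worked example: 500 stocks × (20 KPIs + 20 fixed columns) = 20,000 columns > 16,384,
# so max_stocks_per_sheet = 16384 // 40 = 409 and sheets_needed = ceil(500 / 409) = 2.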

@app.post("/download_stock_analysis_excel/")
def download_stock_analysis_excel(result: dict):

    # ---- Convert JSON data to DataFrame ----
    df = pd.DataFrame(result["quarterly_performance_summary"])

    # ---- Set the actual Excel column limit ----
    MAX_COLUMNS_PER_SHEET = 16384
    all_columns = list(df.columns)
    base_columns = ["Quarter"]

    # ---- Identify stock blocks (group ending in "Significant") ----
    stock_blocks = []
    temp_block = []

    for col in all_columns[1:]:  # skip 'Quarter'
        temp_block.append(col)
        if col.endswith("Significant"):
            stock_blocks.append(temp_block)
            temp_block = []

    # Handle any leftover columns (safety check)
    if temp_block:
        stock_blocks.append(temp_block)

    # ---- Allocate stock blocks to sheets ----
    sheets = []
    current_sheet_cols = base_columns.copy()
    for block in stock_blocks:
        # Check if this block fits into current sheet
        if len(current_sheet_cols) + len(block) <= MAX_COLUMNS_PER_SHEET:
            current_sheet_cols.extend(block)
        else:
            # Start new sheet
            sheets.append(current_sheet_cols)
            current_sheet_cols = base_columns.copy() + block

    # Add last sheet
    if current_sheet_cols:
        sheets.append(current_sheet_cols)

    # ---- Write Excel file ----
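    # NB: delete=False leaves the temp file on disk after the response is sent; if disk
    # usage matters, schedule cleanup (e.g. via FastAPI's BackgroundTasks) after serving it.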
    with NamedTemporaryFile(delete=False, suffix=".xlsx") as tmp:
        with pd.ExcelWriter(tmp.name, engine="openpyxl") as writer:
            for i, cols in enumerate(sheets, start=1):
                sheet_name = f"Sheet{i}"
                df[cols].to_excel(writer, index=False, sheet_name=sheet_name)
        file_path = tmp.name

    # ---- Return Excel file ----
    return FileResponse(
        path=file_path,
        filename="Quarterly_Performance_Summary2.xlsx",
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    )
