from fastapi import FastAPI, Depends, Query
from typing import List, Optional
from datetime import datetime, timedelta
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.db import models, database
from app.utils.datastream_api import fetch_kpi_data, parse_dotnet_date, generate_date_range, find_continuous_ranges, store_kpi_response
from app.db.models import Instrument, KPI, KPIValue
from app.db.database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware
from app.api import admin_routes
from app.schemas import KPIFilter
import os
import json

models.Base.metadata.create_all(bind=database.engine)

app = FastAPI()
app.include_router(admin_routes.router)

# Allow requests from browser on localhost
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to ["http://localhost:3000"] for stricter control
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_db():
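    """Yield a database session for the duration of a request and close it afterwards."""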
    db = database.SessionLocal()
    try:
        yield db
    finally:
        db.close()

@app.post("/fetch-kpis/")
def fetch_kpis(token: str, frequency: str, symbol: str, start_date: str, end_date: str, db: Session = Depends(get_db)):
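    """Fetch every KPI known to the database for one instrument from the DataStream API
    and persist any values that are not already stored.
    """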
    # Fetch all KPI codes from the database
    kpi_codes = [k.code for k in db.query(models.KPI).all()]
    body = {
        "TokenValue": token,
        "Properties": { "RequestAllMetadata": True },
        "DataRequests": [
            {
                "DataTypes": [{"Value": code, "Properties": {"ReturnName": True}} for code in kpi_codes],
                "Date": {
                    "Start": start_date,
                    "End": end_date,
                    "Frequency": frequency,
                    "Kind": 1
                },
                "Instrument": {
                    "Value": symbol,
                    "Properties": { "ReturnName": True }
                },
                "Tag": ""
            }
        ]
    }

    response = fetch_kpi_data(token, body)
    print("=============== Response from DataStream API ===============")
    print(response)
    data = response["DataResponses"][0]
    dates = [parse_dotnet_date(d) for d in data["Dates"]]
    instrument_symbol = data["DataTypeValues"][0]["SymbolValues"][0]["Symbol"]

    # Ensure instrument exists
    instrument = db.query(models.Instrument).filter_by(symbol=instrument_symbol).first()
    if not instrument:
        instrument = models.Instrument(symbol=instrument_symbol)
        db.add(instrument)
        db.commit()
        db.refresh(instrument)

    for entry in data["DataTypeValues"]:
        kpi_code = entry["DataType"]
        values = entry["SymbolValues"][0]["Value"]
        currency = entry["SymbolValues"][0].get("Currency")

        # Ensure KPI exists
        kpi = db.query(models.KPI).filter_by(code=kpi_code).first()
        if not kpi:
            kpi = models.KPI(code=kpi_code)
            db.add(kpi)
            db.commit()
            db.refresh(kpi)

        for date, value in zip(dates, values):
            if value is None:
                continue
            existing = db.query(models.KPIValue).filter_by(
                instrument_id=instrument.id,
                kpi_id=kpi.id,
                frequency=frequency,
                date=date
            ).first()
            if not existing:
                kpi_val = models.KPIValue(
                    instrument_id=instrument.id,
                    kpi_id=kpi.id,
                    frequency=frequency,
                    date=date,
                    value=value,
                    currency=currency
                )
                db.add(kpi_val)
    db.commit()
    return {"message": "KPI data fetched and stored."}


@app.get("/kpis/")
def get_kpis(
    symbol: str,
    frequency: str = Query("Q"),
    start_date: str = Query(...),
    end_date: str = Query(...),
    kpis: Optional[List[str]] = Query(None),
    db: Session = Depends(get_db)
):
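    """Return KPI time series for an instrument, fetching missing dates on demand.

    Dates already stored for every requested KPI are served from the database;
    gaps in the requested range are fetched from DataStream first and stored.

    Example request (illustrative symbol and KPI name):
        GET /kpis/?symbol=VOD&frequency=Q&start_date=2020-01-01&end_date=2021-12-31&kpis=EPS
    """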
    token = os.getenv("DATASTREAM_API_TOKEN")
    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end = datetime.strptime(end_date, "%Y-%m-%d").date()

    # STEP 1: Determine KPI codes to use
    if kpis:
        # Get mapping of name => code from DB
        name_code_map = {k.name or k.code: k.code for k in db.query(KPI).all()}
        print("Name to Code Mapping:", name_code_map)
        kpi_codes = [name_code_map.get(name, name) for name in kpis]  # fallback to name itself if not found
        print("KPI Codes to Fetch:", kpi_codes)

        # Ensure any new KPI codes are added to DB if missing
        for code in kpi_codes:
            if not db.query(KPI).filter_by(code=code).first():
                db.add(KPI(code=code, name=code))
        db.commit()
    else:
        kpi_codes = [k.code for k in db.query(KPI).all()]

    # STEP 2: Get instrument
    instrument = db.query(Instrument).filter_by(symbol=symbol).first()
    if not instrument:
        instrument = Instrument(symbol=symbol)
        db.add(instrument)
        db.commit()
        db.refresh(instrument)

    # STEP 3: Get existing KPIValues from DB
    kpi_ids = [
        kpi_id
        for (kpi_id,) in db.query(KPI.id).filter(KPI.code.in_(kpi_codes)).all()
    ]
    db_dates = db.query(KPIValue.date).filter(
        KPIValue.instrument_id == instrument.id,
        KPIValue.frequency == frequency,
        KPIValue.kpi_id.in_(kpi_ids),
        KPIValue.date >= start,
        KPIValue.date <= end
    ).group_by(KPIValue.date).having(func.count(KPIValue.kpi_id) == len(kpi_ids)).all()

    existing_dates = {d[0] for d in db_dates}
    print("Existing Dates in DB:", existing_dates)
    full_range = generate_date_range(frequency, start, end)
    print("Full Date Range:", full_range)
    missing_dates = sorted([d for d in full_range if d not in existing_dates])
    print("Missing Dates:", missing_dates)

    # STEP 4: If any dates are missing, fetch them from DataStream and store them
    if missing_dates:
        chunks = find_continuous_ranges(missing_dates, frequency)
        print("Chunks:", chunks)
        for chunk_start, chunk_end in chunks:
            body = {
                "TokenValue": token,
                "Properties": {"RequestAllMetadata": True},
                "DataRequests": [{
                    "DataTypes": [{"Value": code, "Properties": {"ReturnName": True}} for code in kpi_codes],
                    "Date": {
                        "Start": chunk_start.isoformat(),
                        "End": chunk_end.isoformat(),
                        "Frequency": frequency,
                        "Kind": 1
                    },
                    "Instrument": {
                        "Value": symbol,
                        "Properties": {"ReturnName": True}
                    },
                    "Tag": ""
                }]
            }
            response = fetch_kpi_data(token, body)
            store_kpi_response(db, response)

    # STEP 5: Fetch all data (now complete) from DB
    kpi_data = {}
    for code in kpi_codes:
        kpi = db.query(KPI).filter_by(code=code).first()
        if not kpi:
            continue
        records = db.query(KPIValue).filter(
            KPIValue.instrument_id == instrument.id,
            KPIValue.kpi_id == kpi.id,
            KPIValue.frequency == frequency,
            KPIValue.date >= start,
            KPIValue.date <= end
        ).order_by(KPIValue.date).all()
        kpi_data[kpi.name or kpi.code] = [
            {
                "date": str(r.date),
                "value": r.value,
                "currency": r.currency
            }
            for r in records
        ]

    return {
        "symbol": symbol,
        "frequency": frequency,
        "start_date": str(start),
        "end_date": str(end),
        "kpis": kpi_data
    }


@app.get("/kpi-options/")
def get_kpi_options(db: Session = Depends(get_db)):
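    """List all KPIs known to the database as {code, name} pairs."""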
    kpis = db.query(KPI).all()
    return [{"code": k.code, "name": k.name or k.code} for k in kpis]

@app.get("/symbol-search/")
def symbol_search(query: str = Query(...), db: Session = Depends(get_db)):
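    """Return up to 10 instruments whose symbol contains the query string (case-insensitive)."""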
    results = (
        db.query(Instrument)
        .filter(Instrument.symbol.ilike(f"%{query}%"))
        .limit(10)
        .all()
    )
    return [{"symbol": i.symbol, "name": i.name} for i in results]


def check_trend_condition(values, trend_type, quarters, threshold, direction='positive'):
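    """Dispatch to the trend check named by `trend_type`.

    `values` are KPIValue rows ordered by date; null values are dropped before the
    check. Returns True if the trend condition holds anywhere in the series.
    """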
    data_list = [v.value for v in values if v.value is not None]

    if len(data_list) < quarters:
        return False

    if trend_type == 'consecutive_growth':
        # Check for consecutive growth or decline
        return check_consecutive_growth(data_list, quarters, threshold, direction)

    elif trend_type == 'negative_to_positive':
        # Check for negative to positive transition
        return check_negative_to_positive(data_list, quarters)

    elif trend_type == 'yoy_growth':
        # Check for year-over-year growth
        return check_yoy_growth(data_list, quarters, threshold, direction)

    elif trend_type == 'post_transition_growth':
        # Check for growth after a negative to positive transition
        return check_post_transition_growth(data_list, quarters, threshold)

    elif trend_type == 'absolute_threshold':
        # Check for absolute threshold condition
        return check_absolute_threshold(data_list, threshold, direction)

    return False

def check_absolute_threshold(data_list, threshold, direction='positive'):
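    """True if any value reaches the threshold (>= threshold for 'positive', <= threshold for 'negative')."""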
    if direction == 'positive':
        return any(value is not None and value >= threshold for value in data_list)
    else:
        return any(value is not None and value <= threshold for value in data_list)

def check_consecutive_growth(data_list, quarters, threshold, direction='positive'):
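    """True if there is a run of `quarters` observations whose period-over-period growth
    is at least `threshold` percent (or at most -threshold for direction 'negative').

    Growth rates are computed between adjacent non-null observations with a non-zero base.
    """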
    if len(data_list) < quarters:
        return False

    growth_list = [
        ((curr - prev) / abs(prev)) * 100
        for prev, curr in zip(data_list, data_list[1:])
        if prev is not None and curr is not None and prev != 0
    ]

    if direction == 'negative':
        compare = lambda g: g <= -threshold
    else:
        compare = lambda g: g >= threshold

    return any(
        all(compare(g) for g in growth_list[i:i + quarters - 1])
        for i in range(len(growth_list) - quarters + 2)
    )


def check_negative_to_positive(data_list, quarters):
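    """True if, within any window of `quarters` points, the first half contains a negative
    value and the second half is entirely positive.
    """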
    if len(data_list) < quarters:
        return False

    for i in range(len(data_list) - quarters + 1):
        segment = data_list[i:i + quarters]
        half = quarters // 2
        if any(x is not None and x < 0 for x in segment[:half]) and all(
            x is not None and x > 0 for x in segment[half:]
        ):
            return True
    return False


def check_yoy_growth(data_list, quarters, threshold, direction='positive'):
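    """True once `quarters` consecutive year-over-year comparisons (current vs. four periods
    earlier) meet the threshold; comparisons with a missing or zero base are skipped without
    resetting the streak.
    """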
    if len(data_list) < quarters + 4:
        return False

    count = 0
    for i in range(4, len(data_list)):
        if data_list[i] is not None and data_list[i - 4] is not None and data_list[i - 4] != 0:
            growth = ((data_list[i] - data_list[i - 4]) / abs(data_list[i - 4])) * 100
            if direction == 'negative' and growth <= -threshold:
                count += 1
            elif direction != 'negative' and growth >= threshold:
                count += 1
            else:
                count = 0

            if count >= quarters:
                return True
    return False


def check_post_transition_growth(data_list, quarters_after=2, threshold=0):
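    """True if there is a negative-to-positive transition and the `quarters_after` growth
    steps from that point onward all stay at or above `threshold` percent.
    """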
    # Scan every index where the transition and the following window still fit.
    for i in range(len(data_list) - quarters_after):
        if data_list[i] is not None and data_list[i] < 0 and data_list[i + 1] is not None and data_list[i + 1] > 0:
            growth_ok = True
            for j in range(i + 1, i + 1 + quarters_after):
                if data_list[j] is None or data_list[j - 1] is None or data_list[j - 1] == 0:
                    growth_ok = False
                    break
                growth = ((data_list[j] - data_list[j - 1]) / abs(data_list[j - 1])) * 100
                if growth < threshold:
                    growth_ok = False
                    break
            if growth_ok:
                return True
    return False


def apply_advanced_filter(db, instrument_id, kpi_code, filter_data, frequency, start_date, end_date):
    """Apply trend logic for advanced filters"""
    kpi = db.query(models.KPI).filter_by(code=kpi_code).first()
    if not kpi:
        return False, None

    values = db.query(models.KPIValue).filter(
        models.KPIValue.instrument_id == instrument_id,
        models.KPIValue.kpi_id == kpi.id,
        models.KPIValue.frequency == frequency,
        models.KPIValue.date >= start_date,
        models.KPIValue.date <= end_date
    ).order_by(models.KPIValue.date.asc()).all()

    if not values or len(values) < 2:
        return False, None

    trend_type = filter_data.get("trend")
    quarters = int(filter_data.get("quarters", 2))
    threshold = float(filter_data.get("threshold", 0))
    direction = filter_data.get("direction", "positive")

    passed = check_trend_condition(values, trend_type, quarters, threshold, direction)
    if passed:
        latest_value = max(values, key=lambda v: v.date)
        return True, latest_value
    return False, None


@app.get("/filter-stocks/")
def filter_stocks_get(
    frequency: str,
    start_date: str,
    end_date: str,
    logical_operator: str = Query("AND"),
    filters: List[str] = Query(...),
    db: Session = Depends(get_db)
):
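    """Screen all instruments against a list of JSON-encoded filters.

    Each filter is either a basic comparison, e.g. (illustrative KPI code)
        {"kpi_code": "EPS", "operator": "gt", "value": 1.0}
    or an advanced trend filter, e.g.
        {"kpi_code": "EPS", "trend": "consecutive_growth", "quarters": 4,
         "threshold": 5, "direction": "positive"}
    Filters are combined with AND (all must pass) or OR (at least one must pass).
    """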
    parsed_filters = [json.loads(f) for f in filters]
    print('logical_operator ===>', logical_operator)
    instruments = db.query(models.Instrument).all()
    matching_results = []

    for instrument in instruments:
        instrument_kpi_data = []
        filter_pass_count = 0

        for f in parsed_filters:
            kpi_code = f["kpi_code"]
            if "trend" in f:  # advanced filter
                passed, kpi_value_obj = apply_advanced_filter(
                    db, instrument.id, kpi_code, f, frequency, start_date, end_date
                )
                if passed:
                    filter_pass_count += 1
                    instrument_kpi_data.append({
                        "kpi_code": kpi_code,
                        "value": kpi_value_obj.value,
                        "date": str(kpi_value_obj.date)
                    })
            else:
                # Basic KPI filter
                operator = f["operator"]
                value = float(f["value"])

                kpi = db.query(models.KPI).filter_by(code=kpi_code).first()
                if not kpi:
                    continue

                kpi_values_query = db.query(models.KPIValue).filter(
                    models.KPIValue.instrument_id == instrument.id,
                    models.KPIValue.kpi_id == kpi.id,
                    models.KPIValue.frequency == frequency,
                    models.KPIValue.date >= start_date,
                    models.KPIValue.date <= end_date
                )

                # Apply operator filter
                if operator == "gt":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value > value)
                elif operator == "gte":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value >= value)
                elif operator == "lt":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value < value)
                elif operator == "lte":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value <= value)
                elif operator == "eq":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value == value)
                elif operator == "ne":
                    kpi_values_query = kpi_values_query.filter(models.KPIValue.value != value)
                else:
                    continue  # Invalid operator

                kpi_value_obj = kpi_values_query.order_by(models.KPIValue.date.desc()).first()
                if kpi_value_obj:
                    filter_pass_count += 1
                    instrument_kpi_data.append({
                        "kpi_code": kpi_code,
                        "value": kpi_value_obj.value,
                        "date": str(kpi_value_obj.date)
                    })

        # Combine filters using AND / OR logic
        if logical_operator.upper() == "AND" and filter_pass_count == len(parsed_filters):
            matching_results.append({
                "symbol": instrument.symbol,
                "kpi_values": instrument_kpi_data
            })
        elif logical_operator.upper() == "OR" and filter_pass_count > 0:
            matching_results.append({
                "symbol": instrument.symbol,
                "kpi_values": instrument_kpi_data
            })
    print('matching_results', matching_results)
    return {"symbols": matching_results}


@app.get("/routes")
def list_routes():
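    """Return the paths of all registered routes (handy for debugging)."""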
    return [route.path for route in app.routes]
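

# Optional local entry point: a minimal sketch assuming this module lives at
# app/main.py and that uvicorn is installed; adjust the import string to your layout.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)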