import asyncio
import json
import math
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import pandas as pd
import pandas_ta as ta
from binance.client import Client
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from openai import OpenAI

# Binance and OpenAI credentials are read from environment variables.
# Missing Binance variables default to "" and the Binance client is still built
# with them; a missing or blank OPENAI_API_KEY leaves openai_client as None,
# which the /analyze endpoint reports as HTTP 503.
api_key = os.environ.get('binance_api', '')
api_secret = os.environ.get('binance_secret', '')
openai_api_key = os.environ.get('OPENAI_API_KEY', '').strip()

client = Client(api_key, api_secret)
openai_client = OpenAI(api_key=openai_api_key) if openai_api_key else None

app = FastAPI()

# Supported kline timeframes: request key -> python-binance interval constant.
TIMEFRAME_INTERVALS = {
    '1m': Client.KLINE_INTERVAL_1MINUTE,
    '3m': Client.KLINE_INTERVAL_3MINUTE,
    '5m': Client.KLINE_INTERVAL_5MINUTE,
    '15m': Client.KLINE_INTERVAL_15MINUTE,
    '1h': Client.KLINE_INTERVAL_1HOUR,
    '4h': Client.KLINE_INTERVAL_4HOUR,
    '1d': Client.KLINE_INTERVAL_1DAY,
    '1w': Client.KLINE_INTERVAL_1WEEK
}
TIMEFRAMES = list(TIMEFRAME_INTERVALS.keys())
# Candles requested per timeframe. Indicators are computed over all FETCH_LIMIT
# candles, but only the last DISPLAY_ROWS rows are returned (see _process_klines).
FETCH_LIMIT = 500
DISPLAY_ROWS = 100
# Append-only JSONL history of every analysis, plus a standalone copy of the newest one.
ANALYSIS_LOG_PATH = Path("data/ai_analysis_history.jsonl")
LAST_ANALYSIS_FILE = Path("data/last_analysis.json")

# Serve the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")


class ConnectionManager:
    """Registry of connected WebSocket clients with single-client and broadcast sends."""

    def __init__(self):
        # Clients in connection order; broadcast() sends to each of them.
        self.active_connections: list["WebSocket"] = []

    async def connect(self, websocket: "WebSocket"):
        """Accept the WebSocket handshake and register *websocket* for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: "WebSocket"):
        """Unregister *websocket*; does nothing if it is not registered."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: "WebSocket"):
        """Send *message* to one client; send errors propagate to the caller."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every registered client.

        Iterates over a snapshot of the registry because clients may connect or
        disconnect while a send is awaited. A failing send does not stop delivery
        to the remaining clients: the error is printed and that connection is
        unregistered as stale.
        """
        stale: list["WebSocket"] = []
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception as exc:
                print(f"Broadcast error: {exc}")
                stale.append(connection)
        for connection in stale:
            self.disconnect(connection)


# Shared registry of connected WebSocket clients (used by the broadcaster and /ws).
manager = ConnectionManager()

# Indicators a client may request for /analyze: key -> description inserted into the prompt.
ALLOWED_INDICATORS = {
    "ema": "EMA 20/50/200 (Trendrichtung)",
    "rsi": "RSI 14 (Momentum)",
    "macd": "MACD 12/26/9 (Signal & Histogramm)",
    "vwap": "VWAP (durchschnittlicher Preis)",
    "volume": "Volumen, Quote-Volumen & Taker Flow"
}
DEFAULT_INDICATORS = ["ema", "rsi", "macd"]
# Models a client may request: OpenAI model id -> display name. Unknown ids fall back to DEFAULT_MODEL.
ALLOWED_MODELS = {
    "gpt-4o": "GPT-4o",
    "gpt-4o-mini": "GPT-4o Mini",
    "gpt-4.1-mini": "GPT-4.1 Mini"
}
DEFAULT_MODEL = "gpt-4o"
# In-memory copy of the newest analysis entry (None until one is loaded or created).
last_analysis_cache: dict[str, Any] | None = None


# Leading list markers: "•" (with or without following whitespace) or a run of
# dashes followed by whitespace ("- item"). A dash glued to a digit ("-5%") is a
# minus sign, not a bullet, and is kept.
_BULLET_PREFIX_RE = re.compile(r"^(?:•\s*|-+\s+)+")


def sanitize_analysis_text(text: str) -> str:
    """Turn Markdown-ish model output into plain paragraphs.

    Removes every "*" and "#" character, strips surrounding whitespace and leading
    bullet markers from each line, drops blank lines and lines consisting only of
    markers (e.g. "---" rules), and joins the remaining lines with blank lines.

    Args:
        text: Raw model output. Non-string input is tolerated.

    Returns:
        The cleaned text, or "" for non-string input.
    """
    if not isinstance(text, str):
        return ""
    cleaned = text.translate(str.maketrans("", "", "*#"))
    lines = []
    for raw_line in cleaned.splitlines():
        line = raw_line.strip()
        if not line.strip("-• "):
            # Blank line, lone bullet, or horizontal rule such as "---".
            continue
        line = _BULLET_PREFIX_RE.sub("", line).strip()
        if line:
            lines.append(line)
    return "\n\n".join(lines)


def _section_text(value: Any) -> str:
    """Return display text for a section value.

    A string is returned stripped; the non-blank strings of a list are stripped and
    joined with "; "; anything else yields "".
    """
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, list):
        return "; ".join(item.strip() for item in value if isinstance(item, str) and item.strip())
    return ""


def format_sections_text(sections: dict[str, Any]) -> str:
    """Render the model's structured JSON answer as readable German text.

    Recognised keys: "trend" (dict with kurzfristig/kurz and mittelfristig/mittel,
    or a string), "indikatoren" (list of {name|indikator, aussage|bewertung} dicts
    or strings), "risiken" and "hinweise" (string or list of strings), "empfehlung"
    (dict with richtung and begruendung/grund, or a string) and "sicherheit" /
    "sicherheit_prozent" (a number or string; a trailing "%" is not doubled).
    Blank values are skipped, so an empty "hinweise" produces no dangling label.

    Returns:
        Paragraphs separated by blank lines, or "" when *sections* is not a dict
        or contains nothing renderable.
    """
    if not isinstance(sections, dict):
        return ""
    parts: list[str] = []

    trend = sections.get("trend")
    if isinstance(trend, dict):
        short = trend.get("kurzfristig") or trend.get("kurz")
        mid = trend.get("mittelfristig") or trend.get("mittel")
        segment = []
        if short:
            segment.append(f"Kurzer Trend: {short}")
        if mid:
            segment.append(f"Mittel- bis Langfristig: {mid}")
        if segment:
            parts.append("\n".join(segment))
    elif isinstance(trend, str) and trend.strip():
        parts.append(f"Trend: {trend.strip()}")

    indikatoren = sections.get("indikatoren")
    if isinstance(indikatoren, list):
        statements = []
        for item in indikatoren:
            if isinstance(item, dict):
                name = item.get("name") or item.get("indikator")
                note = item.get("aussage") or item.get("bewertung")
                if name and note:
                    statements.append(f"{name}: {note}")
                elif note:
                    statements.append(note)
            elif isinstance(item, str):
                statements.append(item)
        if statements:
            parts.append("Signale:\n" + "\n".join(statements))

    risks = _section_text(sections.get("risiken"))
    if risks:
        parts.append(f"Risiken: {risks}")

    hints = _section_text(sections.get("hinweise"))
    if hints:
        parts.append(f"Hinweise: {hints}")

    recommendation = sections.get("empfehlung")
    if isinstance(recommendation, dict):
        direction = _section_text(recommendation.get("richtung"))
        reason = _section_text(recommendation.get("begruendung") or recommendation.get("grund"))
        recommendation_text = " – ".join(text for text in (direction, reason) if text)
    else:
        recommendation_text = _section_text(recommendation)
    if recommendation_text:
        parts.append(f"Empfehlung: {recommendation_text}")

    # Explicit None/"" checks instead of `or`, so a legitimate 0 % is still shown.
    security = sections.get("sicherheit")
    if security is None or security == "":
        security = sections.get("sicherheit_prozent")
    if security is not None:
        security_text = str(security).strip().rstrip("%").rstrip()
        if security_text:
            parts.append(f"Eingeschätzte Sicherheit: {security_text}%")

    return "\n\n".join(parts)


def append_analysis_log(entry: dict):
    """Record *entry* as the newest analysis.

    Appends it as one JSON line to ANALYSIS_LOG_PATH (creating the parent
    directory as needed); any failure there is printed, not raised. Regardless of
    that outcome, the entry becomes the in-memory ``last_analysis_cache`` and is
    persisted via save_last_analysis_file.
    """
    global last_analysis_cache
    try:
        log_path = ANALYSIS_LOG_PATH
        log_path.parent.mkdir(parents=True, exist_ok=True)
        with log_path.open("a", encoding="utf-8") as handle:
            handle.write(f"{json.dumps(entry, ensure_ascii=False)}\n")
    except Exception as exc:
        print(f"Analysis log error: {exc}")

    last_analysis_cache = entry
    save_last_analysis_file(entry)


def read_last_analysis_entry():
    """Return the newest decodable entry of the JSONL analysis history, or None.

    Lines are scanned from the end; blank lines and lines that are not valid JSON
    are skipped. Returns None when the file is missing or has no valid line; read
    errors are printed and also yield None.
    """
    if not ANALYSIS_LOG_PATH.exists():
        return None
    try:
        with ANALYSIS_LOG_PATH.open("r", encoding="utf-8") as handle:
            content = handle.read()
        for candidate in reversed(content.splitlines()):
            candidate = candidate.strip()
            if candidate:
                try:
                    return json.loads(candidate)
                except json.JSONDecodeError:
                    pass
    except Exception as exc:
        print(f"Analysis read error: {exc}")
    return None


def load_last_analysis_from_file():
    """Return the decoded JSON content of LAST_ANALYSIS_FILE, or None.

    None is returned when the file does not exist or cannot be read or parsed;
    in the latter cases the error is printed rather than raised.
    """
    if LAST_ANALYSIS_FILE.exists():
        try:
            with LAST_ANALYSIS_FILE.open("r", encoding="utf-8") as handle:
                return json.load(handle)
        except Exception as exc:
            print(f"Load last analysis error: {exc}")
    return None


def save_last_analysis_file(entry: dict):
    """Overwrite LAST_ANALYSIS_FILE with *entry* as indented UTF-8 JSON.

    The parent directory is created when missing. This is best effort: any error
    is printed instead of raised.
    """
    try:
        target_path = LAST_ANALYSIS_FILE
        target_path.parent.mkdir(parents=True, exist_ok=True)
        with target_path.open("w", encoding="utf-8") as handle:
            json.dump(entry, handle, ensure_ascii=False, indent=2)
    except Exception as exc:
        print(f"Save last analysis error: {exc}")


# Prime the cache from the persisted newest analysis (None if missing or unreadable).
last_analysis_cache = load_last_analysis_from_file()


def _process_klines(klines):
    """
    Convert raw Binance klines into indicator-enriched, JSON-safe records.

    ``klines`` is the list returned by ``client.get_klines``; each item is unpacked
    into the 12 columns named below. The function adds a VWAP that is cumulative
    over the whole fetched window, plus RSI 14, MACD 12/26/9 and EMA 20/50/200 via
    pandas_ta. It returns the last DISPLAY_ROWS rows as a list of dicts in which
    NaN values (e.g. indicator warm-up rows) are replaced by None.
    """
    df = pd.DataFrame(klines, columns=[
        'open_time', 'open', 'high', 'low', 'close', 'volume',
        'close_time', 'quote_asset_volume', 'number_of_trades',
        'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
    ])

    float_columns = [
        'open', 'high', 'low', 'close', 'volume',
        'quote_asset_volume', 'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume',
    ]
    df['open_time'] = df['open_time'].astype(int)
    df[float_columns] = df[float_columns].astype(float)
    df['number_of_trades'] = df['number_of_trades'].astype(int)

    # VWAP, cumulative from the first fetched candle. Zero cumulative volume is
    # masked to NaN with .where(), which keeps the column float64 (replacing 0 with
    # None would force an object column). Masked values become None below.
    typical_price = (df['high'] + df['low'] + df['close']) / 3
    cum_volume = df['volume'].cumsum()
    df['VWAP'] = (typical_price * df['volume']).cumsum() / cum_volume.where(cum_volume != 0)

    # Indicators appended by pandas_ta as RSI_14, MACD_12_26_9/MACDh_12_26_9/MACDs_12_26_9, EMA_<n>.
    df.ta.rsi(append=True)
    df.ta.macd(append=True)
    df.ta.ema(length=20, append=True)
    df.ta.ema(length=50, append=True)
    df.ta.ema(length=200, append=True)

    output_columns = [
        'open_time', 'open', 'high', 'low', 'close', 'volume',
        'RSI_14',
        'MACD_12_26_9', 'MACDh_12_26_9', 'MACDs_12_26_9',
        'EMA_20', 'EMA_50', 'EMA_200',
        'VWAP',
        'quote_asset_volume', 'number_of_trades',
        'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume'
    ]
    # reindex instead of df[[...]]: if pandas_ta did not append an indicator column
    # (presumably when there are fewer candles than the indicator length — TODO
    # confirm for the installed pandas_ta version), that column comes back as NaN
    # (-> None) instead of raising KeyError.
    output_df = df.reindex(columns=output_columns).tail(DISPLAY_ROWS)

    return [
        {
            key: None if isinstance(value, float) and math.isnan(value) else value
            for key, value in row.items()
        }
        for row in output_df.to_dict(orient='records')
    ]


def _fetch_timeframe_records(timeframe: str):
    """Download the latest FETCH_LIMIT BTCUSDT klines for *timeframe* and return processed records.

    Blocking (synchronous Binance REST call); callers run it in a worker thread.
    Raises KeyError for a timeframe not in TIMEFRAME_INTERVALS.
    """
    interval = TIMEFRAME_INTERVALS[timeframe]
    raw_klines = client.get_klines(symbol='BTCUSDT', interval=interval, limit=FETCH_LIMIT)
    return _process_klines(raw_klines)


async def get_latest_data():
    """
    Fetch all timeframes concurrently, each in its own worker thread.

    Returns a JSON string mapping every key of TIMEFRAMES to its list of records.
    An exception from any single fetch propagates to the caller.
    """
    results = await asyncio.gather(
        *(asyncio.to_thread(_fetch_timeframe_records, timeframe) for timeframe in TIMEFRAMES)
    )
    return json.dumps(dict(zip(TIMEFRAMES, results)))


async def get_current_price():
    """Return the BTCUSDT ticker snapshot with numeric fields converted to float.

    The blocking Binance call runs in a worker thread. Any field that is missing or
    not parseable as a number is returned as None.
    """
    ticker = await asyncio.to_thread(client.get_ticker, symbol='BTCUSDT')

    def as_float(raw):
        try:
            return float(raw)
        except (TypeError, ValueError):
            return None

    return {
        'symbol': 'BTCUSDT',
        'price': as_float(ticker.get('lastPrice')),
        'priceChange': as_float(ticker.get('priceChange')),
        'priceChangePercent': as_float(ticker.get('priceChangePercent')),
    }


async def data_broadcaster(interval_seconds: float = 60):
    """
    Periodically fetch market data and broadcast it to all WebSocket clients.

    Runs forever. While no client is connected, the Binance fetch is skipped
    entirely. Nobody would receive it, and each new client gets a fresh snapshot
    from websocket_endpoint on connect anyway. Errors in a cycle are printed and
    the loop continues.

    Args:
        interval_seconds: Pause between cycles (default 60 s).
    """
    while True:
        if manager.active_connections:
            try:
                message = await get_latest_data()
                await manager.broadcast(message)
            except Exception as e:
                print(f"Error in data_broadcaster: {e}")

        await asyncio.sleep(interval_seconds)


@app.on_event("startup")
async def startup_event():
    """Start the background data broadcaster when the application starts.

    The task is stored on ``app.state`` because the event loop keeps only weak
    references to tasks. A task whose result from create_task() is discarded can
    be garbage-collected mid-execution.
    """
    app.state.broadcaster_task = asyncio.create_task(data_broadcaster())


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve static/index.html.

    The file is decoded explicitly as UTF-8 rather than with the platform default
    encoding, which can garble non-ASCII text (e.g. German umlauts) on some hosts.
    """
    html = Path("static/index.html").read_text(encoding="utf-8")
    return HTMLResponse(content=html)


@app.get("/data")
async def get_data_endpoint():
    """Return the records of every timeframe, decoded from get_latest_data()'s JSON string."""
    return json.loads(await get_latest_data())


@app.get("/price")
async def price_endpoint():
    """Return the current BTCUSDT ticker snapshot produced by get_current_price()."""
    snapshot = await get_current_price()
    return snapshot


@app.get("/analysis/last")
async def last_analysis_endpoint():
    """Return the newest stored analysis.

    Serves the in-memory cache when it is populated. Otherwise it falls back to the
    last valid entry of the JSONL history and caches that entry. Raises HTTP 404
    when neither source has an analysis.
    """
    global last_analysis_cache
    if not last_analysis_cache:
        entry = read_last_analysis_entry()
        if not entry:
            raise HTTPException(status_code=404, detail="No cached analysis available.")
        last_analysis_cache = entry
    return last_analysis_cache


def _parse_bool_flag(value: Any) -> bool:
    """Interpret a JSON flag.

    Strings count as true only for "1"/"true"/"yes"/"on" (case- and
    whitespace-insensitive); any other value uses Python truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)


def _select_indicators(raw_indicators: Any) -> list[str]:
    """Return the allowed indicator keys from a request list, or the defaults.

    Keys are lower-cased and de-duplicated, and kept in request order. A copy of
    DEFAULT_INDICATORS is returned when none are usable.
    """
    selected: list[str] = []
    if isinstance(raw_indicators, list):
        for item in raw_indicators:
            if not isinstance(item, str):
                continue
            key = item.lower()
            if key in ALLOWED_INDICATORS and key not in selected:
                selected.append(key)
    return selected or DEFAULT_INDICATORS.copy()


def _parse_depth(raw_depth: Any, default: int = 20) -> int:
    """Return the requested candle count clamped to 5..100; unparsable input uses *default*."""
    try:
        depth = int(raw_depth)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: the stdlib JSON parser accepts "Infinity", and int(inf) overflows.
        depth = default
    return max(5, min(depth, 100))


def _build_analysis_row(row: dict[str, Any], indicators: list[str]) -> dict[str, Any]:
    """Reduce one processed kline record to OHLCV plus the values of the selected indicators."""
    entry = {
        "time": row["open_time"],
        "open": row["open"],
        "high": row["high"],
        "low": row["low"],
        "close": row["close"],
        "volume": row["volume"],
    }
    if "ema" in indicators:
        entry.update({
            "ema20": row.get("EMA_20"),
            "ema50": row.get("EMA_50"),
            "ema200": row.get("EMA_200"),
        })
    if "rsi" in indicators:
        entry["rsi"] = row.get("RSI_14")
    if "macd" in indicators:
        entry.update({
            "macd": row.get("MACD_12_26_9"),
            "macd_signal": row.get("MACDs_12_26_9"),
            "macd_hist": row.get("MACDh_12_26_9"),
        })
    if "vwap" in indicators:
        entry["vwap"] = row.get("VWAP")
    if "volume" in indicators:
        entry.update({
            "quote_volume": row.get("quote_asset_volume"),
            "taker_buy_base": row.get("taker_buy_base_asset_volume"),
            "taker_buy_quote": row.get("taker_buy_quote_asset_volume"),
            "trades": row.get("number_of_trades"),
        })
    return entry


def _build_analysis_prompt(timeframes: list[str], indicators: list[str], depth: int, news_focus: bool) -> str:
    """Build the German analysis instructions.

    The prompt ends with a "Daten:" line, so the caller can append the serialized
    candle data directly. The JSON schema includes an "empfehlung" field, so the
    requested trade recommendation fits the "JSON only" instruction.
    """
    timeframe_text = ", ".join(timeframes)
    indicator_text = ", ".join(ALLOWED_INDICATORS[ind] for ind in indicators if ind in ALLOWED_INDICATORS)
    if not indicator_text:
        indicator_text = "Standard-Kerzeninformationen (Open, High, Low, Close, Volumen)"
    depth_text = f"Die Daten enthalten pro Timeframe die letzten {depth} Kerzen."
    news_clause = (
        "Füge außerdem eine kurze Einordnung hinzu, ob aktuelle Makro-, On-Chain- oder News-Impulse die kurzfristige Kursreaktion beeinflussen könnten. "
        "Wenn du keine Fakten hast, schreibe explizit, dass dir dazu keine sicheren Hinweise vorliegen."
    ) if news_focus else ""

    return (
        "Du bist ein technischer Analyst für Bitcoin (BTCUSDT). Andere Coins oder Märkte dürfen nicht erwähnt werden.\n"
        f"Arbeite ausschließlich mit den folgenden Timeframes: {timeframe_text}.\n"
        f"{depth_text}\n"
        f"Berücksichtige für deine Signale diese Kennzahlen: {indicator_text}.\n"
        f"{news_clause}\n"
        "Strukturiere deine Antwort logisch.\n"
        "Liefere deine Antwort ausschließlich als JSON mit exakt folgendem Schema (keine zusätzliche Erklärung oder Text außerhalb des JSON):\n"
        "{\n"
        '  "trend": {\n'
        '    "kurzfristig": "Beschreibung des kurzfristigen Trends",\n'
        '    "mittelfristig": "Beschreibung des mittelfristigen Trends"\n'
        "  },\n"
        '  "indikatoren": [\n'
        '    {"name": "EMA", "aussage": "Kernaussage zu diesem Indikator"},\n'
        '    {"name": "RSI", "aussage": "Kernaussage"}\n'
        "  ],\n"
        '  "risiken": "Entscheidende Risiken oder kritische Preiszonen",\n'
        '  "hinweise": "Makro-/News-/On-Chain-Hinweise oder leer, falls keine",\n'
        '  "empfehlung": {\n'
        '    "richtung": "bullish | neutral | bearish",\n'
        '    "begruendung": "Kurze Begründung der Empfehlung"\n'
        "  },\n"
        '  "sicherheit": 72\n'
        "}\n"
        "Nutze die verfügbaren Daten, um jede Eigenschaft informativ zu füllen. "
        'Gib unter "empfehlung" deine eigene Handelsempfehlung (bullish, neutral oder bearish) mit Begründung an '
        'und unter "sicherheit" deine Confidence-Einschätzung in Prozent (0-100).\n\n'
        "Daten:\n"
    )


def _parse_model_json(raw_output: Any) -> dict[str, Any] | None:
    """Parse model output as a JSON object, tolerating a surrounding Markdown code fence.

    Returns None for non-string or empty output, invalid JSON, or JSON that is not an object.
    """
    if not isinstance(raw_output, str):
        return None
    text = raw_output.strip()
    if text.startswith("```"):
        # Models sometimes wrap the JSON in ```json ... ``` despite the instructions.
        text = text.partition("\n")[2].rstrip()
        if text.endswith("```"):
            text = text[:-3]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


@app.post("/analyze")
async def analyze_endpoint(request: Request):
    """Run an OpenAI technical analysis of BTCUSDT on the requested timeframes.

    JSON body (all fields optional):
      timeframes  list of TIMEFRAMES keys (a single string is also accepted);
                  falls back to "timeframe" (default "1m"). Duplicates are dropped.
      indicators  list of ALLOWED_INDICATORS keys; defaults to DEFAULT_INDICATORS.
      model       key of ALLOWED_MODELS; anything else uses DEFAULT_MODEL.
      depth       candles per timeframe, clamped to 5..100 (default 20).
      news_focus  bool or "1"/"true"/"yes"/"on" to request a news/macro note.

    Returns the formatted analysis text, the model used, the parsed JSON sections
    (or None) and token usage. The same data is appended to the analysis log.

    Raises HTTPException with:
      503  no OpenAI key is configured.
      400  the body is not a JSON object, or a timeframe is unknown.
      404  no candle data is available.
      500  the OpenAI call fails.
    """
    if not openai_client:
        raise HTTPException(status_code=503, detail="OpenAI API key nicht konfiguriert.")

    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise HTTPException(status_code=400, detail="Invalid JSON body.") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body.")

    # Several timeframes are allowed; a bare string must not be iterated character by character.
    timeframes = body.get("timeframes")
    if isinstance(timeframes, str):
        timeframes = [timeframes]
    if not timeframes:
        timeframes = [body.get("timeframe", "1m")]
    if not isinstance(timeframes, list):
        raise HTTPException(status_code=400, detail=f"Invalid timeframe: {timeframes}")
    for tf in timeframes:
        if tf not in TIMEFRAMES:
            raise HTTPException(status_code=400, detail=f"Invalid timeframe: {tf}")
    # All entries are now known strings, so order-preserving de-duplication is safe.
    timeframes = list(dict.fromkeys(timeframes))

    selected_indicators = _select_indicators(body.get("indicators"))

    requested_model = body.get("model", DEFAULT_MODEL)
    # isinstance guard: a JSON list/object would be unhashable in the dict lookup.
    if isinstance(requested_model, str) and requested_model in ALLOWED_MODELS:
        model = requested_model
    else:
        model = DEFAULT_MODEL

    depth_value = _parse_depth(body.get("depth", 20))
    news_focus = _parse_bool_flag(body.get("news_focus", False))

    all_data = json.loads(await get_latest_data())
    combined = {}
    for tf in timeframes:
        data = all_data.get(tf, [])
        if data:
            combined[tf] = [_build_analysis_row(row, selected_indicators) for row in data[-depth_value:]]

    if not combined:
        raise HTTPException(status_code=404, detail="No data to analyze.")

    # Only timeframes that actually have data are named in the prompt.
    prompt = _build_analysis_prompt(list(combined), selected_indicators, depth_value, news_focus)
    input_text = prompt + json.dumps(combined)

    try:
        # The OpenAI client is synchronous: run it in a worker thread (as the Binance
        # calls are) so the event loop keeps serving WebSockets and other requests.
        # 800 is an upper bound, not a target (billing follows tokens actually
        # generated); 400 risked truncating the German JSON answer mid-object.
        response = await asyncio.to_thread(
            openai_client.responses.create,
            model=model,
            input=input_text,
            max_output_tokens=800,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OpenAI Fehler: {e}") from e

    raw_output = getattr(response, "output_text", None)
    structured_sections = _parse_model_json(raw_output)
    if structured_sections:
        analysis_text = format_sections_text(structured_sections) or sanitize_analysis_text(
            json.dumps(structured_sections, ensure_ascii=False)
        )
    else:
        fallback_text = raw_output
        if not fallback_text:
            try:
                fallback_text = json.dumps(response.model_dump(), indent=2)
            except Exception:
                fallback_text = "Keine lesbare Analyse erzeugt."
        analysis_text = sanitize_analysis_text(fallback_text)

    usage = getattr(response, "usage", None)
    usage_info = {
        "input_tokens": getattr(usage, "input_tokens", None) if usage else None,
        "output_tokens": getattr(usage, "output_tokens", None) if usage else None,
    }

    log_entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "timeframes": timeframes,
        "indicators": selected_indicators,
        "depth": depth_value,
        "news_focus": news_focus,
        "model": model,
        "analysis": analysis_text,
        "sections": structured_sections,
        "usage": usage_info,
    }
    append_analysis_log(log_entry)

    return {
        "analysis": analysis_text,
        "model": model,
        "sections": structured_sections,
        "usage": dict(usage_info),
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a client, push a full data snapshot immediately, then hold the socket open.

    Later updates arrive via data_broadcaster -> manager.broadcast. Incoming text
    messages are read and ignored; the read loop ends with WebSocketDisconnect when
    the client leaves. Unregistering happens in ``finally`` so the socket is removed
    on every exit path, including task cancellation, not just the two handled
    exception types.
    """
    await manager.connect(websocket)
    try:
        initial_data = await get_latest_data()
        await manager.send_personal_message(initial_data, websocket)

        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        manager.disconnect(websocket)
