import asyncio
import json
import math
import os
import time
import traceback
from datetime import datetime, timezone

import aiohttp

from mempool import fetch_mempool, normalize_mempool
from engine import detect_signal
from logger import append_log

# Binance.US REST base URL and the single symbol queried by every request below.
BINANCE_API = "https://api.binance.us"
SYMBOL = "BTCUSDT"

# Rolling windows, in minutes, used by compute_multi_tf_stats.
# fetch_binance_trades_v6 only pulls the last 15 minutes of trades, so a
# window longer than 15 would silently be computed on truncated data.
WINDOWS = [1, 5, 15]
# JSON state files, relative to the current working directory.
ROLLING_CVD_FILE = "rolling_cvd_v6.json"
POSITION_STATE_FILE = "position_state.json"


# =====================================================
# POSITION STATE
# =====================================================

def load_position_state(path=None):
    """Load the persisted position state.

    Args:
        path: JSON file to read; defaults to ``POSITION_STATE_FILE``.

    Returns:
        A dict that always contains ``has_position`` and ``buy_price``.
        A missing, unreadable, malformed, or non-object file yields the flat
        state ``{"has_position": False, "buy_price": None}``.
    """
    if path is None:
        path = POSITION_STATE_FILE
    flat = {"has_position": False, "buy_price": None}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return flat
    except (OSError, ValueError) as e:
        # A corrupt file means we may be forgetting an open position; make
        # that visible instead of silently reporting "flat".
        print(f"[POSITION STATE ERROR] Could not read {path}: {e}")
        return flat
    if not isinstance(data, dict):
        return flat
    # Merge over the defaults so callers can index both keys even if the
    # file is missing one of them.
    return {**flat, **data}


def save_position_state(has_position, buy_price=None, path=None):
    """Persist the position state as JSON.

    The data is written to a sibling ``.tmp`` file and then swapped in with
    ``os.replace``. A crash mid-write therefore cannot leave truncated JSON
    behind, which the loader would otherwise read as "no position".

    Args:
        has_position: Whether a position is currently open.
        buy_price: Entry price of the open position, or None when flat.
        path: Destination file; defaults to ``POSITION_STATE_FILE``.
    """
    if path is None:
        path = POSITION_STATE_FILE
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump({"has_position": has_position, "buy_price": buy_price}, f, indent=2)
    os.replace(tmp_path, path)


# =====================================================
# HTTP
# =====================================================

async def fetch_json(session, url, params=None, timeout=8):
    """GET ``url`` and decode the JSON body, returning None on any failure.

    Args:
        session: Open ``aiohttp.ClientSession``.
        url: Absolute URL to request.
        params: Optional query-string parameters.
        timeout: Total request timeout in seconds, or an
            ``aiohttp.ClientTimeout`` instance.

    Returns:
        The decoded JSON value. Returns None if the request, the HTTP status
        check, or JSON decoding fails; the error is printed in that case.
    """
    if not isinstance(timeout, aiohttp.ClientTimeout):
        # ClientTimeout is the documented type for the per-request timeout;
        # a bare number is only accepted for backward compatibility.
        timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with session.get(url, params=params, timeout=timeout) as r:
            r.raise_for_status()
            return await r.json()
    except Exception as e:
        # Include the exception type: asyncio.TimeoutError has an empty
        # str(), so printing only {e} gave no clue on timeouts.
        print(f"[HTTP ERROR] {url} -> {type(e).__name__}: {e}")
        return None


# =====================================================
# BINANCE FETCHERS
# =====================================================

async def fetch_binance_trades_v6(session):
    """Fetch every aggregate trade for SYMBOL from the last 15 minutes.

    Pages through ``/api/v3/aggTrades`` 1000 trades at a time using a
    ``startTime`` cursor. Each page resumes *at* the last timestamp seen,
    not one millisecond past it. That way, trades sharing that millisecond
    that spilled onto the next page are not skipped. The resulting overlap
    is removed by de-duplicating on the aggregate-trade id ``a``.

    Returns:
        A list of ``{"ts", "side", "qty", "price"}`` dicts sorted by ``ts``
        (seconds). ``side`` is ``"sell"`` when the trade's ``m`` flag is true,
        else ``"buy"``. If a page request fails, the trades collected so far
        are returned.
    """
    url = f"{BINANCE_API}/api/v3/aggTrades"
    page_limit = 1000
    now_ms = int(time.time() * 1000)
    start_ms = now_ms - 15 * 60 * 1000
    end_ms = now_ms

    out = []
    seen_ids = set()
    cursor = start_ms

    while True:
        params = {"symbol": SYMBOL, "startTime": cursor, "endTime": end_ms, "limit": page_limit}
        data = await fetch_json(session, url, params=params)
        # fetch_json returns None on failure; anything other than a list is
        # not a trade page either.
        if not data or not isinstance(data, list):
            break

        for t in data:
            trade_id = t["a"]
            if trade_id in seen_ids:
                continue  # already collected from the overlapping previous page
            seen_ids.add(trade_id)
            out.append({
                "ts": t["T"] / 1000,
                "side": "sell" if t["m"] else "buy",
                "qty": float(t["q"]),
                "price": float(t["p"]),
            })

        if len(data) < page_limit:
            break

        last_ts = max(d["T"] for d in data)
        # Resume at last_ts so same-millisecond trades on the next page are
        # kept. If the whole page sat in one millisecond the cursor would not
        # move, so step past it to guarantee the loop makes progress.
        cursor = last_ts if last_ts > cursor else last_ts + 1
        if cursor >= end_ms:
            break

    out.sort(key=lambda x: x["ts"])
    return out


async def fetch_binance_orderbook(session):
    """Return the raw ``/api/v3/depth`` snapshot (100 levels) for SYMBOL, or None on failure."""
    endpoint = BINANCE_API + "/api/v3/depth"
    query = {"symbol": SYMBOL, "limit": 100}
    snapshot = await fetch_json(session, endpoint, query)
    return snapshot


async def fetch_binance_24h(session):
    """Return the raw ``/api/v3/ticker/24hr`` payload for SYMBOL, or None on failure."""
    endpoint = BINANCE_API + "/api/v3/ticker/24hr"
    stats = await fetch_json(session, endpoint, params={"symbol": SYMBOL})
    return stats


# =====================================================
# ORDERBOOK HELPERS
# =====================================================

def build_orderbook_info(raw, depth_liq_range=0.002):
    """Summarize a depth snapshot into mid price and bid/ask imbalance ratios.

    Args:
        raw: Decoded depth response. ``bids`` and ``asks`` hold
            ``[price, qty, ...]`` entries (strings or numbers).
        depth_liq_range: Half-width of the band around mid, as a fraction
            of mid, used for ``depth_ratio`` (default 0.2%).

    Returns:
        A dict with three keys:

        - ``mid``: midpoint of best bid and best ask.
        - ``depth_ratio``: bid notional / ask notional inside the band.
        - ``best_level_ratio``: best-bid size / best-ask size.

        Both ratios are ``inf`` when their denominator is zero. Returns None
        when ``raw`` is missing or malformed, or when either side of the
        book is empty.
    """
    if not isinstance(raw, dict) or "bids" not in raw or "asks" not in raw:
        return None

    try:
        bids = [(float(p), float(q)) for p, q, *_ in raw["bids"]]
        asks = [(float(p), float(q)) for p, q, *_ in raw["asks"]]
    except (TypeError, ValueError):
        return None

    # An empty side has no best level; indexing it would raise IndexError and
    # abort the whole cycle instead of returning the None the caller handles.
    if not bids or not asks:
        return None

    bids.sort(key=lambda x: x[0], reverse=True)
    asks.sort(key=lambda x: x[0])

    best_bid, bid_sz = bids[0]
    best_ask, ask_sz = asks[0]
    mid = (best_bid + best_ask) / 2

    low = mid * (1 - depth_liq_range)
    high = mid * (1 + depth_liq_range)

    # Notional (price * qty) resting within the band on each side.
    bid_liq = sum(p * q for p, q in bids if p >= low)
    ask_liq = sum(p * q for p, q in asks if p <= high)

    depth_ratio = bid_liq / ask_liq if ask_liq > 0 else float("inf")
    best_level_ratio = bid_sz / ask_sz if ask_sz > 0 else float("inf")

    return {
        "mid": mid,
        "depth_ratio": depth_ratio,
        "best_level_ratio": best_level_ratio,
    }


# =====================================================
# MULTI-TF STATS, VOLUME, WHALES
# =====================================================

def compute_multi_tf_stats(trades):
    """Aggregate buy/sell trade quantity over each window in WINDOWS.

    Returns a dict keyed by window length in minutes. Each value holds:

    - ``buy``, ``sell``, ``total``: summed ``qty`` of trades whose ``ts``
      falls in the last ``w`` minutes.
    - ``obi``: buy/sell, or ``inf`` when sell is 0.
    - ``buy_ratio``: buy/total, or 0.5 when the window has no volume.
    """
    reference = time.time()
    stats = {}
    for minutes in WINDOWS:
        cutoff = reference - minutes * 60
        in_window = [trade for trade in trades if trade["ts"] >= cutoff]
        buy_qty = sum(trade["qty"] for trade in in_window if trade["side"] == "buy")
        sell_qty = sum(trade["qty"] for trade in in_window if trade["side"] == "sell")
        total_qty = buy_qty + sell_qty
        stats[minutes] = {
            "buy": buy_qty,
            "sell": sell_qty,
            "total": total_qty,
            "obi": float("inf") if not sell_qty else buy_qty / sell_qty,
            "buy_ratio": 0.5 if not total_qty else buy_qty / total_qty,
        }
    return stats


def classify_volume_regime(ratio_5m):
    """Label a 5-minute volume ratio (actual / 24h-average baseline).

    Bands, lower bound inclusive: [0, 0.5) LOW, [0.5, 1.5) NORMAL,
    [1.5, 3.0) HIGH, and anything else EXTREME.
    """
    for ceiling, label in ((0.5, "LOW"), (1.5, "NORMAL"), (3.0, "HIGH")):
        if ratio_5m < ceiling:
            return label
    return "EXTREME"


def classify_volatility(prices):
    """Bucket the RMS of consecutive simple returns of ``prices``.

    The RMS is taken around zero (no mean subtraction). With fewer than two
    prices there are no returns, which is reported as "QUIET".
    """
    if len(prices) < 2:
        return "QUIET"
    step_returns = [(curr - prev) / prev for prev, curr in zip(prices, prices[1:])]
    rms = math.sqrt(sum(x * x for x in step_returns) / len(step_returns))
    for upper, label in ((0.0005, "QUIET"), (0.001, "NORMAL"), (0.002, "MID"), (0.004, "HIGH")):
        if rms < upper:
            return label
    return "EXTREME"


def calc_whale_bias_usd(trades, price):
    """Net USD notional of large trades (qty >= 5) in the last 5 minutes.

    Buys count positive and sells negative. Every trade is valued at the
    supplied ``price`` rather than its own execution price. Returns 0 when
    ``price`` is not positive.
    """
    if price <= 0:
        return 0
    cutoff = time.time() - 300
    window = [t for t in trades if t["ts"] >= cutoff]

    def notional(side):
        return sum(t["qty"] * price for t in window if t["side"] == side and t["qty"] >= 5)

    return notional("buy") - notional("sell")


# =====================================================
# CVD ROLLING
# =====================================================

def _load_cvd_state(path=None):
    """Load the rolling-CVD state file.

    Args:
        path: JSON file to read; defaults to ``ROLLING_CVD_FILE``.

    Returns:
        The decoded dict. Returns ``{"entries": [], "cvd": 0}`` when the file
        is missing, unreadable, or not a JSON object.
    """
    if path is None:
        path = ROLLING_CVD_FILE
    empty = {"entries": [], "cvd": 0}
    try:
        # ``with`` closes the handle deterministically instead of leaving it
        # to the garbage collector.
        with open(path, "r", encoding="utf-8") as f:
            state = json.load(f)
    except FileNotFoundError:
        return empty
    except (OSError, ValueError) as e:
        print(f"[CVD STATE ERROR] Could not read {path}: {e}")
        return empty
    # A non-object payload would crash the caller's ``.get``; treat as empty.
    return state if isinstance(state, dict) else empty


def _save_cvd_state(s, path=None):
    """Write the rolling-CVD state ``s`` as JSON, replacing the file atomically.

    Args:
        s: JSON-serializable state dict.
        path: Destination file; defaults to ``ROLLING_CVD_FILE``.
    """
    if path is None:
        path = ROLLING_CVD_FILE
    tmp_path = f"{path}.tmp"
    # ``with`` guarantees the data is flushed and the handle closed.
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(s, f, indent=2)
    # Swap in the complete file so a crash mid-write cannot leave truncated
    # JSON behind.
    os.replace(tmp_path, path)


def update_rolling_cvd_v6(trades):
    """Merge a freshly fetched trade batch into the persisted rolling CVD series.

    Each cycle re-fetches an overlapping window of trades, so appending every
    batch blindly would count each trade once per cycle that saw it. Instead,
    the batch is treated as authoritative for its own time span: stored
    entries whose ``ts`` lies within [earliest, latest] of the batch are
    replaced by the batch. Entries older than 4 hours are pruned.

    NOTE(review): CVD values are no longer inflated by duplicates, so any
    thresholds tuned against the old numbers (e.g. in detect_signal) may
    need re-tuning.

    Args:
        trades: Trade dicts with ``ts`` (seconds), ``side`` and ``qty``.

    Returns:
        The saved state, a dict with three keys:

        - ``entries``: each entry has ``ts``, ``delta`` and running ``cvd``.
        - ``cvd``: sum of all retained deltas.
        - ``cvd_delta_5m``: sum of the deltas whose ``ts`` is within the
          last 5 minutes.
    """
    s = _load_cvd_state()
    now = time.time()
    entries = [e for e in s.get("entries", []) if e["ts"] >= now - 4 * 3600]

    if trades:
        batch_start = min(t["ts"] for t in trades)
        batch_end = max(t["ts"] for t in trades)
        # Drop what the new batch re-covers so no trade is counted twice.
        entries = [e for e in entries if not batch_start <= e["ts"] <= batch_end]
        for t in trades:
            delta = t["qty"] if t["side"] == "buy" else -t["qty"]
            entries.append({"ts": t["ts"], "delta": delta})

    entries.sort(key=lambda x: x["ts"])

    cvd = 0
    for e in entries:
        cvd += e["delta"]
        e["cvd"] = cvd

    # Sum the window's deltas directly. Differencing the running CVD at the
    # window's first and last entries would omit the first entry's delta.
    cvd_delta_5m = sum(e["delta"] for e in entries if e["ts"] >= now - 300)

    s = {"entries": entries, "cvd": cvd, "cvd_delta_5m": cvd_delta_5m}
    _save_cvd_state(s)
    return s


# =====================================================
# BUILD STATE for ENGINE
# =====================================================

async def fetch_state_v6(session):
    """Fetch market data and build the state dict passed to ``detect_signal``.

    Returns None, so the caller skips the cycle, when any of these cannot be
    obtained:

    - trades;
    - a usable order book;
    - the 24h ticker volume.

    Side effect: the rolling CVD state file is updated via
    ``update_rolling_cvd_v6``.
    """
    trades = await fetch_binance_trades_v6(session)
    if not trades:
        return None

    orderbook_raw = await fetch_binance_orderbook(session)
    ob = build_orderbook_info(orderbook_raw)
    if not ob:
        return None

    # The 24h volume is the baseline for every volume ratio. A fake fallback
    # of 1 made all ratios enormous and forced an "EXTREME" volume regime
    # whenever the ticker request failed, so skip the cycle instead.
    ticker = await fetch_binance_24h(session)
    try:
        vol24 = float(ticker["volume"])
    except (TypeError, KeyError, ValueError):
        print("[STATE ERROR] 24h ticker volume unavailable")
        return None

    price = ob["mid"]

    cvd_s = update_rolling_cvd_v6(trades)
    cvd_delta_5m = cvd_s["cvd_delta_5m"]

    mf = compute_multi_tf_stats(trades)
    stats_1m = mf[1]
    stats_5m = mf[5]
    stats_15m = mf[15]

    # Percent move between the first and last trade of the last 5 minutes
    # (trades arrive sorted by ts).
    now = time.time()
    trades_5 = [t for t in trades if t["ts"] >= now - 300]
    if trades_5:
        first_price = trades_5[0]["price"]
        price_change_5m = (trades_5[-1]["price"] - first_price) / first_price * 100
    else:
        price_change_5m = 0

    # Average per-minute volume over 24h, scaled to each window length.
    avg_1m = vol24 / 1440
    avg_5m = avg_1m * 5
    avg_15m = avg_1m * 15

    ratio_1m = stats_1m["total"] / avg_1m if avg_1m > 0 else 0
    ratio_5m = stats_5m["total"] / avg_5m if avg_5m > 0 else 0
    ratio_15m = stats_15m["total"] / avg_15m if avg_15m > 0 else 0

    vol_price = classify_volatility([t["price"] for t in trades])
    vol_regime_vol = classify_volume_regime(ratio_5m)

    whale_bias = calc_whale_bias_usd(trades, price)

    pos = load_position_state()

    return {
        "price": price,
        "buy_ratio": stats_5m["buy_ratio"],
        "smooth": None,
        "cvd": cvd_delta_5m,
        "whale": whale_bias,
        "depth_ratio": ob["depth_ratio"],
        "best_level_ratio": ob["best_level_ratio"],

        "vol_price": vol_price,
        "vol_regime_vol": vol_regime_vol,

        # Not computed in this version; fixed placeholders.
        "spoof": {"bid_spoof": False, "ask_spoof": False},
        "iceberg": {"iceberg_buy": False, "iceberg_sell": False},

        # .get keeps a malformed position file from raising KeyError here.
        "has_position": pos.get("has_position", False),
        "buy_price": pos.get("buy_price"),

        "mf_stats": mf,
        "ratio_1m": ratio_1m,
        "ratio_5m": ratio_5m,
        "ratio_15m": ratio_15m,
        "price_change_pct_5m": price_change_5m,
    }


# =====================================================
# CYCLE
# =====================================================

async def run_cycle():
    """Run one fetch -> decide -> persist -> log iteration.

    The steps, in order:

    1. Fetch mempool data and the market state.
    2. Ask ``detect_signal`` for a decision.
    3. Persist the position on BUY/SELL.
    4. Append a structured log record and print a one-line summary.

    Returns early, after printing an error, when the market state cannot be
    built.
    """
    async with aiohttp.ClientSession() as session:
        mempool_raw = await fetch_mempool(session)
        f_mem, f_fee, f_block = normalize_mempool(mempool_raw)

        state = await fetch_state_v6(session)
        if not state:
            print("[STATE ERROR] Failed to build state")
            return

        state["onchain"] = (f_mem, f_fee, f_block)

        decision = detect_signal(state)
        action = decision["action"]

        # Persist position changes before logging.
        if action == "BUY":
            save_position_state(True, state["price"])
            print(f">>> BUY @ {state['price']:.2f}")
        elif action == "SELL":
            save_position_state(False)
            sell_pnl = decision.get("pnl_pct", 0)
            print(f">>> SELL (PnL {sell_pnl:+.2f}%)")

        # Build the log record; key order is kept stable for the log output.
        record = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "action": action,
            "reason": decision["reason"],
        }
        for score_key in ("trend_score", "trend_score_base", "net_bias", "net_bias_base"):
            record[score_key] = decision.get(score_key, 0)
        record["state"] = state
        record["mempool"] = mempool_raw
        record["onchain_factors"] = {
            "f_mempool": f_mem,
            "f_fee": f_fee,
            "f_block_tx": f_block,
        }
        if "pnl_pct" in decision:
            record["pnl_pct"] = decision["pnl_pct"]

        append_log(record)

        summary_parts = [
            f"[{record['timestamp']}] → {action}",
            f"T={record['trend_score']:.1f}",
            f"B={record['net_bias']:.1f}",
            f"price={state['price']:.2f}",
        ]
        print(" | ".join(summary_parts))


async def main(interval_s=10):
    """Run trading cycles forever, sleeping ``interval_s`` seconds between them.

    Any exception raised by a cycle is printed with its full traceback, and
    the loop continues, so one bad cycle cannot stop the bot.
    """
    cycle = 0
    while True:
        cycle += 1
        print("=" * 60)
        print(f"Cycle #{cycle} – BOT V10")
        print("=" * 60)

        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start = time.monotonic()

        try:
            await run_cycle()
        except Exception as e:
            # Keep running, but show where it failed: str(e) alone is often
            # just a key name (KeyError) or empty (TimeoutError).
            print(f"[CYCLE ERROR] {type(e).__name__}: {e}")
            traceback.print_exc()

        print(f"[Cycle {cycle}] done in {time.monotonic() - start:.2f}s\n")
        await asyncio.sleep(interval_s)


if __name__ == "__main__":
    # Script entry point: run the endless cycle loop on a fresh event loop.
    entry = main()
    asyncio.run(entry)
