# multi_coin_collector.py - Multi-Coin Data Collection Bot via Binance WebSocket
# Pure data collection - NO TRADING LOGIC
# Collection: Real-time via WebSocket + periodic REST API calls
# Compatible with Python 3.9+

import time
import asyncio
import aiohttp
import websockets
import json
import requests
import numpy as np
from collections import deque, defaultdict
from datetime import datetime
from typing import Optional, Dict, List
import traceback

# Import configuration
from config import (
    SYMBOLS,
    REST_INTERVAL,
    WS_MAX_RETRIES,
    WS_RETRY_DELAY,
    DATA_DIR_BASE,
    WHALE_THRESHOLD_USD,
    API_TIMEOUT,
    MEMPOOL_SYMBOLS,
    SUMMARY_INTERVAL_RECORDS,
    PRICE_WINDOW_SIZE,
    PRICE_HISTORY_SIZE,
    VOLUME_WINDOW_SIZE,
    TRADES_WINDOW_SIZE,
    CVD_HISTORY_SIZE,
    WHALE_HISTORY_SIZE,
    ORDERBOOK_LIMIT,
    AGGTRADES_LIMIT,
)

from indicators_extended import (
    calc_all_price_changes,
    calc_all_ratios,
    calc_volatility,
    calc_smooth_multi_tf,
    calc_cvd_extended,
    calc_whale_extended,
    calc_orderbook_extended,
    calc_volume_extended,
    calc_market_microstructure,
    calc_composite_indicators,
)
from mempool_extended import fetch_mempool_extended, normalize_mempool_extended
from logger_daily_multi import append_daily_log, print_daily_summary

# ================================================================
# GLOBAL STATE (PER SYMBOL)
# ================================================================

class SymbolData:
    """Rolling per-symbol state shared by the WebSocket handler and the collector."""

    def __init__(self, symbol: str):
        """Create empty, size-bounded windows and zeroed counters for `symbol`."""
        self.symbol = symbol

        # --- bounded rolling windows (sizes come from config) ---
        # Raw prices; appended by the WebSocket handler and by collect_cycle.
        self.price_window = deque(maxlen=PRICE_WINDOW_SIZE)
        # (wall-clock time, price) tuples.
        self.price_history = deque(maxlen=PRICE_HISTORY_SIZE)
        # (wall-clock time, summed trade quantity of one cycle) tuples.
        self.volume_window = deque(maxlen=VOLUME_WINDOW_SIZE)
        # Individual trade dicts gathered across cycles.
        self.trades_window = deque(maxlen=TRADES_WINDOW_SIZE)
        # (wall-clock time, CVD value) tuples.
        self.cvd_history = deque(maxlen=CVD_HISTORY_SIZE)
        # (wall-clock time, whale stats dict) tuples.
        self.whale_history = deque(maxlen=WHALE_HISTORY_SIZE)

        # --- latest observations and counters ---
        self.current_price = 0.0      # price of the most recent streamed trade
        self.last_trade_time = 0      # "T" field of the most recent streamed trade
        self.total_records = 0        # number of records written by collect_cycle
        self.last_summary_time = time.time()  # not read in the visible code

        # --- WebSocket trade buffer ---
        # Trades accumulate here between collection cycles; collect_cycle
        # copies and clears it.
        self.ws_trades_buffer = []
        self.ws_last_flush = time.time()  # not read in the visible code

# Global state dictionary: maps a symbol name (e.g. "BTCUSDT") to its
# SymbolData container. It starts empty and is filled by init_symbol_data().
SYMBOL_DATA: Dict[str, SymbolData] = {}

def init_symbol_data():
    """Create a fresh SymbolData entry in SYMBOL_DATA for every configured symbol.

    An existing entry for the same symbol is replaced.
    """
    # Only item assignment is performed on the module-level dict, so no
    # `global` declaration is required.
    for symbol in SYMBOLS:
        container = SymbolData(symbol)
        SYMBOL_DATA[symbol] = container
        print(f"✅ Initialized data structures for {symbol}")

# ================================================================
# BINANCE REST API HELPERS
# ================================================================

def get_aggtrades(symbol: str, limit: int = None):
    """Fetch recent aggregate trades for `symbol` from the Binance REST API.

    Args:
        symbol: Trading pair, e.g. "BTCUSDT".
        limit: Maximum number of trades to request; falls back to
            AGGTRADES_LIMIT when None.

    Returns:
        A list of dicts with keys "p" (price, float), "q" (quantity, float),
        "T" (timestamp, int) and "isBuyerMaker" (bool). An empty list is
        returned when the request or JSON decoding fails.
    """
    limit = AGGTRADES_LIMIT if limit is None else limit
    url = f"https://api.binance.com/api/v3/aggTrades?symbol={symbol}&limit={limit}"

    try:
        response = requests.get(url, timeout=API_TIMEOUT)
        response.raise_for_status()
        raw_trades = response.json()
    except Exception as e:
        # Best-effort helper: report the failure and let the caller carry on.
        print(f"❌ [{symbol}] Error fetching aggTrades: {e}")
        return []

    # Normalise the string fields to native types and rename "m".
    return [
        {
            "p": float(item["p"]),
            "q": float(item["q"]),
            "T": int(item["T"]),
            "isBuyerMaker": bool(item["m"]),
        }
        for item in raw_trades
    ]


def get_orderbook(symbol: str, limit: int = None):
    """Fetch an order-book depth snapshot for `symbol` from the Binance REST API.

    Args:
        symbol: Trading pair, e.g. "BTCUSDT".
        limit: Depth limit to request; falls back to ORDERBOOK_LIMIT when None.

    Returns:
        The decoded JSON response, or {"bids": [], "asks": []} when the
        request or JSON decoding fails.
    """
    limit = ORDERBOOK_LIMIT if limit is None else limit
    url = f"https://api.binance.com/api/v3/depth?symbol={symbol}&limit={limit}"

    try:
        response = requests.get(url, timeout=API_TIMEOUT)
        response.raise_for_status()
        snapshot = response.json()
    except Exception as e:
        # Best-effort helper: report the failure and hand back an empty book.
        print(f"❌ [{symbol}] Error fetching orderbook: {e}")
        return {"bids": [], "asks": []}

    return snapshot


def get_klines(symbol: str, interval: str = "15m", limit: int = 100):
    """Fetch kline/candlestick data for `symbol` from the Binance REST API.

    Args:
        symbol: Trading pair, e.g. "BTCUSDT".
        interval: Candle interval string passed through to the API.
        limit: Number of candles to request.

    Returns:
        The decoded JSON response, or an empty list when the request or
        JSON decoding fails.
    """
    url = "https://api.binance.com/api/v3/klines"
    params = {"symbol": symbol, "interval": interval, "limit": limit}
    try:
        # Use the configured API_TIMEOUT, like the other REST helpers in this
        # file, instead of a hard-coded value.
        r = requests.get(url, params=params, timeout=API_TIMEOUT)
        r.raise_for_status()
        return r.json()
    except Exception as e:
        # Best-effort helper: report the failure and let the caller carry on.
        print(f"❌ [{symbol}] Error fetching klines: {e}")
        return []


# ================================================================
# CVD & WHALE CALCULATIONS
# ================================================================

def calc_cvd_from_trades(trades):
    """Return the volume delta of `trades` in millions of quote units, clamped to [-10, 10].

    Each trade is a dict with "p" (price), "q" (quantity) and "isBuyerMaker".
    A trade whose buyer is the maker counts as sell volume; every other trade
    counts as buy volume.
    """
    bought_usd = 0.0
    sold_usd = 0.0

    for trade in trades:
        notional = trade["p"] * trade["q"]
        if not trade["isBuyerMaker"]:
            bought_usd += notional
        else:
            sold_usd += notional

    delta_millions = (bought_usd - sold_usd) / 1_000_000

    # Clamp in two steps: cap at +10 first, then floor at -10.
    capped = min(10.0, delta_millions)
    clamped = max(-10.0, capped)
    return float(clamped)


def calc_whale_from_trades(trades, threshold_usd=None):
    """Summarise the large ("whale") trades found in `trades`.

    Args:
        trades: Iterable of dicts with "p" (price), "q" (quantity) and
            "isBuyerMaker".
        threshold_usd: Minimum price * quantity for a trade to count as a
            whale trade; falls back to WHALE_THRESHOLD_USD when None.

    Returns:
        A dict with "whale_total", "whale_buy" and "whale_sell" (in millions),
        "whale_count", and "whale_largest" / "whale_avg" (not scaled).
        A buyer-maker trade counts toward the sell side.
    """
    min_usd = WHALE_THRESHOLD_USD if threshold_usd is None else threshold_usd

    sizes_usd = []
    bought_usd = 0.0
    sold_usd = 0.0

    for trade in trades:
        notional = trade["p"] * trade["q"]
        if not notional >= min_usd:
            continue

        sizes_usd.append(notional)
        if trade["isBuyerMaker"]:
            sold_usd += notional
        else:
            bought_usd += notional

    if sizes_usd:
        largest = max(sizes_usd)
        average = sum(sizes_usd) / len(sizes_usd)
    else:
        largest = 0.0
        average = 0.0

    return {
        "whale_total": float((bought_usd + sold_usd) / 1_000_000),
        "whale_buy": float(bought_usd / 1_000_000),
        "whale_sell": float(sold_usd / 1_000_000),
        "whale_count": len(sizes_usd),
        "whale_largest": float(largest),
        "whale_avg": float(average),
    }


# ================================================================
# HTF TREND (24h)
# ================================================================

def get_htf_trend(symbol: str):
    """Estimate the long-term trend of `symbol` from 96 fifteen-minute candles.

    Returns:
        A `(daily_change_pct, daily_smooth)` tuple of floats.
        `daily_change_pct` is the percentage change from the first to the last
        close. `daily_smooth` is the tanh of the linear-regression slope of the
        closes, scaled by 0.01% of the mean close. `(0.0, 0.0)` is returned
        when fewer than five candles are available.
    """
    candles = get_klines(symbol=symbol, interval="15m", limit=96)

    insufficient = not candles or len(candles) < 5
    if insufficient:
        return 0.0, 0.0

    # Index 4 of each candle is used as the close price.
    close_prices = [float(candle[4]) for candle in candles]
    opening, latest = close_prices[0], close_prices[-1]
    change_pct = (latest - opening) / opening * 100.0

    y = np.array(close_prices)
    x = np.arange(len(close_prices))

    try:
        slope = np.polyfit(x, y, 1)[0]
        # The small epsilon keeps the denominator away from zero.
        smooth = np.tanh(slope / (y.mean() * 0.0001 + 1e-8))
    except Exception:
        smooth = 0.0

    return float(change_pct), float(smooth)


# ================================================================
# WEBSOCKET HANDLER
# ================================================================

async def handle_websocket_stream():
    """Consume the combined aggTrade stream for all symbols and update SYMBOL_DATA.

    Every received trade is appended to the symbol's WebSocket buffer and
    price windows. On a connection failure the function waits WS_RETRY_DELAY
    seconds and reconnects; it returns after WS_MAX_RETRIES consecutive
    failures. The failure counter is reset on each successful connection.
    """
    # Combined-stream URL, e.g.
    # wss://stream.binance.com:9443/stream?streams=bnbusdt@aggTrade/btcusdt@aggTrade
    streams = "/".join(f"{symbol.lower()}@aggTrade" for symbol in SYMBOLS)
    ws_url = f"wss://stream.binance.com:9443/stream?streams={streams}"

    print(f"\n🌐 Connecting to Binance WebSocket...")
    print(f"📡 Streams: {streams}")

    max_retries = WS_MAX_RETRIES
    retry_count = 0

    while retry_count < max_retries:
        try:
            async with websockets.connect(ws_url, ping_interval=20, ping_timeout=10) as ws:
                print(f"✅ WebSocket connected!")
                retry_count = 0  # a successful connection clears the failure streak

                async for message in ws:
                    try:
                        data = json.loads(message)

                        # Only combined-stream envelopes are of interest.
                        if "stream" not in data or "data" not in data:
                            continue

                        # "bnbusdt@aggTrade" -> "BNBUSDT"
                        symbol = data["stream"].split("@")[0].upper()
                        if symbol not in SYMBOL_DATA:
                            continue

                        trade_data = data["data"]
                        trade = {
                            "p": float(trade_data["p"]),            # price
                            "q": float(trade_data["q"]),            # quantity
                            "T": int(trade_data["T"]),              # timestamp
                            "isBuyerMaker": bool(trade_data["m"]),  # is buyer maker
                        }

                        sd = SYMBOL_DATA[symbol]
                        sd.ws_trades_buffer.append(trade)
                        sd.current_price = trade["p"]
                        sd.last_trade_time = trade["T"]

                        # Keep the rolling price series current.
                        sd.price_window.append(trade["p"])
                        sd.price_history.append((time.time(), trade["p"]))

                    except Exception as e:
                        # A single bad message must not tear down the stream.
                        print(f"❌ Error processing WebSocket message: {e}")

        except Exception as e:
            retry_count += 1
            if isinstance(e, websockets.exceptions.ConnectionClosed):
                print(f"⚠️  WebSocket connection closed. Retry {retry_count}/{max_retries} in {WS_RETRY_DELAY}s...")
            else:
                print(f"❌ WebSocket error: {e}")
                print(f"⚠️  Retry {retry_count}/{max_retries} in {WS_RETRY_DELAY}s...")
            await asyncio.sleep(WS_RETRY_DELAY)

    print(f"❌ WebSocket failed after {max_retries} retries")


# ================================================================
# DATA COLLECTION CYCLE (PER SYMBOL)
# ================================================================

async def collect_cycle(session: aiohttp.ClientSession, symbol: str):
    """Run one data collection cycle for `symbol` and append the record to its daily log.

    Steps: take the trades buffered from the WebSocket (or fetch them over
    REST when the buffer is empty), fetch the order book, compute all
    indicators, optionally fetch mempool data, build the flat record and
    hand it to append_daily_log.

    The blocking `requests`-based helpers run in worker threads via
    asyncio.to_thread so the event loop (and with it the WebSocket reader)
    keeps running while they wait on the network.

    Args:
        session: Shared aiohttp session, passed to fetch_mempool_extended.
        symbol: Trading pair; must be a key of SYMBOL_DATA.

    Returns:
        None. The cycle is skipped when no trades are available.
    """
    sd = SYMBOL_DATA[symbol]

    cycle_start = time.time()
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12. It is
    # kept because an aware datetime would change the stored timestamp format.
    timestamp = datetime.utcnow().isoformat()

    # Print header every N records
    if sd.total_records % SUMMARY_INTERVAL_RECORDS == 0:
        print(f"\n{'='*80}")
        print(f"🔄 [{symbol}] Collection Cycle #{sd.total_records + 1} | {timestamp}")
        print(f"{'='*80}")

    # ---------------------------------------------------------
    # 1. GET TRADES (WebSocket buffer, REST API as fallback)
    # ---------------------------------------------------------
    if sd.ws_trades_buffer:
        # Copy and clear without an intervening await, so no streamed trade
        # can slip in between the two operations.
        trades = sd.ws_trades_buffer.copy()
        sd.ws_trades_buffer.clear()
        print(f"📊 [{symbol}] Using {len(trades)} trades from WebSocket buffer")
    else:
        # The helper's default limit is the configured AGGTRADES_LIMIT.
        trades = await asyncio.to_thread(get_aggtrades, symbol)
        print(f"📊 [{symbol}] Fetched {len(trades)} trades from REST API")

    if not trades:
        print(f"⚠️  [{symbol}] No trades data")
        return

    # Current price
    price = trades[-1]["p"]
    sd.price_window.append(price)
    sd.price_history.append((time.time(), price))
    sd.trades_window.extend(trades)

    # ---------------------------------------------------------
    # 2. FETCH ORDERBOOK
    # ---------------------------------------------------------
    # The helper's default limit is the configured ORDERBOOK_LIMIT.
    orderbook = await asyncio.to_thread(get_orderbook, symbol)

    # ---------------------------------------------------------
    # 3. CALCULATE ALL INDICATORS
    # ---------------------------------------------------------

    # Price changes
    price_changes = calc_all_price_changes(price, sd.price_history)

    # Smooth (momentum)
    smooth_values = calc_smooth_multi_tf(sd.price_window, sd.price_history)

    # Volatility
    volatility = calc_volatility(sd.price_history)

    # Money flow ratios
    ratios = calc_all_ratios(trades, sd.trades_window)

    # CVD
    cvd_current = calc_cvd_from_trades(trades)
    sd.cvd_history.append((time.time(), cvd_current))
    cvd_extended = calc_cvd_extended(cvd_current, sd.cvd_history)

    # Whale (the helper's default threshold is the configured WHALE_THRESHOLD_USD)
    whale_current = calc_whale_from_trades(trades)
    sd.whale_history.append((time.time(), whale_current))
    whale_extended = calc_whale_extended(whale_current, sd.whale_history)

    # Orderbook
    orderbook_metrics = calc_orderbook_extended(orderbook, price)

    # Volume
    volume_current = sum(t["q"] for t in trades)
    sd.volume_window.append((time.time(), volume_current))
    volume_metrics = calc_volume_extended(trades, sd.volume_window)

    # Market microstructure
    microstructure = calc_market_microstructure(trades, sd.trades_window)

    # HTF trend (blocking REST call, so it runs in a worker thread)
    daily_trend_pct, daily_smooth = await asyncio.to_thread(get_htf_trend, symbol)

    # ---------------------------------------------------------
    # 4. MEMPOOL DATA (only for configured symbols)
    # ---------------------------------------------------------
    if symbol in MEMPOOL_SYMBOLS:
        try:
            mempool_data = await fetch_mempool_extended(session)
            mempool_metrics = normalize_mempool_extended(mempool_data)
        except Exception as e:
            print(f"⚠️  [{symbol}] Mempool error: {e}")
            mempool_metrics = normalize_mempool_extended(None)
    else:
        # Symbol is not in MEMPOOL_SYMBOLS: fill the mempool fields from
        # normalize_mempool_extended(None).
        mempool_metrics = normalize_mempool_extended(None)

    # ---------------------------------------------------------
    # 5. COMPOSITE INDICATORS
    # ---------------------------------------------------------
    composite = calc_composite_indicators(
        smooth_values['2m'],
        ratios['1m'],
        cvd_extended['current'],
        whale_extended['total'],
        volume_metrics['change_rate'],
        orderbook_metrics['depth_ratio']
    )

    # ---------------------------------------------------------
    # 6. TEMPORAL FEATURES
    # ---------------------------------------------------------
    now = datetime.utcnow()
    temporal = {
        "hour_of_day": now.hour,
        "day_of_week": now.weekday(),
        "is_market_hours": 13 <= now.hour <= 21,
    }

    # Seconds since the most recent cycle that saw at least one whale trade;
    # stays 0.0 when no such cycle is in the history.
    time_since_whale = 0.0
    if sd.whale_history:
        for ts, whale_info in reversed(sd.whale_history):
            if whale_info['whale_count'] > 0:
                time_since_whale = time.time() - ts
                break
    temporal["time_since_last_whale"] = time_since_whale

    # ---------------------------------------------------------
    # 7. BUILD COMPLETE DATA RECORD
    # ---------------------------------------------------------
    data_record = {
        "symbol": symbol,
        "timestamp": timestamp,
        "unix_time": time.time(),

        # PRICE
        "price": price,

        # PRICE CHANGES
        "price_change_1m": price_changes['1m'],
        "price_change_3m": price_changes['3m'],
        "price_change_5m": price_changes['5m'],
        "price_change_10m": price_changes['10m'],
        "price_change_15m": price_changes['15m'],
        "price_change_20m": price_changes['20m'],
        "price_change_30m": price_changes['30m'],

        # SMOOTH (MOMENTUM)
        "smooth_2m": smooth_values['2m'],
        "smooth_5m": smooth_values['5m'],
        "smooth_10m": smooth_values['10m'],
        "smooth_15m": smooth_values['15m'],

        # VOLATILITY
        "volatility_1m": volatility['1m'],
        "volatility_5m": volatility['5m'],
        "volatility_high_low_range_5m": volatility['high_low_range_5m'],

        # MONEY FLOW RATIOS
        "ratio_1m": ratios['1m'],
        "ratio_5m": ratios['5m'],
        "ratio_15m": ratios['15m'],
        "ratio_20m": ratios['20m'],
        "ratio_30m": ratios['30m'],
        "ratio_mf_acceleration": ratios['mf_acceleration'],

        # CVD
        "cvd_current": cvd_extended['current'],
        "cvd_5m": cvd_extended['5m'],
        "cvd_15m": cvd_extended['15m'],
        "cvd_change_rate": cvd_extended['change_rate'],
        "cvd_acceleration": cvd_extended['acceleration'],

        # WHALE
        "whale_total": whale_extended['total'],
        "whale_buy": whale_extended['buy'],
        "whale_sell": whale_extended['sell'],
        "whale_count": whale_extended['count'],
        "whale_largest": whale_extended['largest'],
        "whale_avg": whale_extended['avg'],
        "whale_buy_sell_ratio": whale_extended['buy_sell_ratio'],

        # ORDERBOOK
        "depth_ratio": orderbook_metrics['depth_ratio'],
        "best_level_ratio": orderbook_metrics['best_level_ratio'],
        "spread": orderbook_metrics['spread'],
        "spread_pct": orderbook_metrics['spread_pct'],
        "imbalance_5": orderbook_metrics['imbalance_5'],
        "imbalance_20": orderbook_metrics['imbalance_20'],
        "bid_wall_size": orderbook_metrics['bid_wall_size'],
        "ask_wall_size": orderbook_metrics['ask_wall_size'],
        "wall_distance_pct": orderbook_metrics['wall_distance_pct'],
        "liquidity_bid_pressure": orderbook_metrics['liquidity_bid_pressure'],
        "liquidity_ask_pressure": orderbook_metrics['liquidity_ask_pressure'],
        "liquidity_shift": orderbook_metrics['liquidity_shift'],
        "impact_buy_price": orderbook_metrics['impact_buy_price'],
        "impact_sell_price": orderbook_metrics['impact_sell_price'],
        "orderbook_entropy": orderbook_metrics['orderbook_entropy'],
        "liquidity_gap_bids": orderbook_metrics['liquidity_gap_bids'],
        "liquidity_gap_asks": orderbook_metrics['liquidity_gap_asks'],
        "spoof_buy_score": orderbook_metrics['spoof_buy_score'],
        "spoof_sell_score": orderbook_metrics['spoof_sell_score'],
        "iceberg_score": orderbook_metrics['iceberg_score'],

        # VOLUME
        "volume_1m": volume_metrics['1m'],
        "volume_5m": volume_metrics['5m'],
        "volume_15m": volume_metrics['15m'],
        "volume_avg_1m": volume_metrics['avg_1m'],
        "volume_change_rate": volume_metrics['change_rate'],
        "volume_acceleration": volume_metrics['acceleration'],
        "volume_avg_trade_size": volume_metrics['avg_trade_size'],
        "volume_trade_frequency": volume_metrics['trade_frequency'],
        "volume_large_trade_frequency": volume_metrics['large_trade_frequency'],

        # MARKET MICROSTRUCTURE
        "trades_per_minute": microstructure['trades_per_minute'],
        "buy_trades_count": microstructure['buy_trades_count'],
        "sell_trades_count": microstructure['sell_trades_count'],
        "avg_time_between_trades": microstructure['avg_time_between_trades'],
        "aggressive_buy_pct": microstructure['aggressive_buy_pct'],
        "aggressive_sell_pct": microstructure['aggressive_sell_pct'],
        "delta_strength": microstructure['delta_strength'],
        "vwap_drift": microstructure['vwap_drift'],
        "micro_trend_3s": microstructure['micro_trend_3s'],
        "micro_trend_7s": microstructure['micro_trend_7s'],
        "micro_trend_12s": microstructure['micro_trend_12s'],
        "variance_ratio": microstructure['variance_ratio'],
        "breakout_energy": microstructure['breakout_energy'],

        # HTF TREND
        "daily_trend_pct": daily_trend_pct,
        "daily_smooth": daily_smooth,

        # MEMPOOL (real data only for symbols in MEMPOOL_SYMBOLS)
        "tx_count": mempool_metrics['tx_count'],
        "mempool_vsize_mb": mempool_metrics['mempool_vsize_mb'],
        "fee_low": mempool_metrics['fee_low'],
        "fee_medium": mempool_metrics['fee_medium'],
        "fee_high": mempool_metrics['fee_high'],
        "block_tx_count": mempool_metrics['block_tx_count'],
        "avg_block_time": mempool_metrics['avg_block_time'],
        "growth_rate": mempool_metrics['growth_rate'],
        "unconfirmed_value": mempool_metrics['unconfirmed_value'],
        "pending_over_1btc": mempool_metrics['pending_over_1btc'],
        "f_mempool": mempool_metrics['f_mempool'],
        "f_fee": mempool_metrics['f_fee'],
        "f_block": mempool_metrics['f_block'],
        "f_block_time": mempool_metrics.get('f_block_time', 0),

        # COMPOSITE INDICATORS
        "momentum_strength": composite['momentum_strength'],
        "volume_price_trend": composite['volume_price_trend'],
        "buying_pressure": composite['buying_pressure'],
        "selling_pressure": composite['selling_pressure'],
        "market_sentiment": composite['market_sentiment'],

        # TEMPORAL
        "hour_of_day": temporal['hour_of_day'],
        "day_of_week": temporal['day_of_week'],
        "is_market_hours": temporal['is_market_hours'],
        "time_since_last_whale": temporal['time_since_last_whale'],
    }

    # ---------------------------------------------------------
    # 8. SAVE TO DAILY LOG (symbol-specific directory)
    # ---------------------------------------------------------
    symbol_dir = f"{DATA_DIR_BASE}/{symbol.lower()}"
    append_daily_log(data_record, symbol=symbol, directory=symbol_dir)

    sd.total_records += 1
    cycle_time = time.time() - cycle_start

    # Short summary
    print(f"✅ [{symbol}] #{sd.total_records} | ${price:,.2f} | Ratio: {ratios['1m']:.2f} | CVD: {cvd_extended['current']:.2f} | {cycle_time:.2f}s")


# ================================================================
# PERIODIC REST API COLLECTOR
# ================================================================

async def periodic_collector(session: aiohttp.ClientSession):
    """Run collect_cycle for every symbol, then sleep REST_INTERVAL seconds, forever.

    A failure in one symbol's cycle is printed with its traceback and does
    not stop the remaining symbols or later rounds.

    Args:
        session: Shared aiohttp session forwarded to collect_cycle.
    """
    print(f"\n🔄 Starting periodic REST API collector...")
    print(f"⏱️  Interval: {REST_INTERVAL} seconds")
    print(f"📊 Symbols: {', '.join(SYMBOLS)}")

    while True:
        try:
            # One pass over all symbols, isolating failures per symbol.
            for symbol in SYMBOLS:
                try:
                    await collect_cycle(session, symbol)
                except Exception as e:
                    print(f"❌ [{symbol}] Collection error: {e}")
                    traceback.print_exc()

            # Pause until the next round.
            await asyncio.sleep(REST_INTERVAL)

        except KeyboardInterrupt:
            print("\n⛔ Collector stopped by user")
            return
        except Exception as e:
            # Anything unexpected at round level: report, wait, try again.
            print(f"❌ Periodic collector error: {e}")
            traceback.print_exc()
            await asyncio.sleep(REST_INTERVAL)


# ================================================================
# MAIN LOOP
# ================================================================

async def main():
    """Start the WebSocket handler and the periodic collector and run them together.

    Prints a banner, initialises the per-symbol state, then runs both tasks
    concurrently inside one shared aiohttp session. On shutdown both tasks
    are cancelled and a final per-symbol summary is printed.

    Raises:
        asyncio.CancelledError: Re-raised after the summary when this
            coroutine itself was cancelled, so the event loop runner can
            finish its own shutdown handling.
    """
    print("\n" + "🤖" * 40)
    print("MULTI-COIN DATA COLLECTOR - BINANCE WEBSOCKET")
    print("=" * 80)
    print(f"📊 Tracking: {', '.join(SYMBOLS)}")
    print(f"⏱️  Collection interval: {REST_INTERVAL} seconds")
    print(f"💾 Data directory: {DATA_DIR_BASE}/")
    print("🌐 Using WebSocket for real-time trades")
    print("=" * 80)
    print("🤖" * 40 + "\n")

    # Initialize data structures
    init_symbol_data()

    async with aiohttp.ClientSession() as session:
        # Create tasks
        ws_task = asyncio.create_task(handle_websocket_stream())
        collector_task = asyncio.create_task(periodic_collector(session))

        # Run both concurrently
        try:
            await asyncio.gather(ws_task, collector_task)
        except (KeyboardInterrupt, asyncio.CancelledError) as exc:
            # Under asyncio.run a Ctrl+C normally reaches this coroutine as a
            # cancellation rather than as KeyboardInterrupt, so both are
            # handled here; otherwise the summary below would never print.
            print("\n⛔ Shutting down...")
            ws_task.cancel()
            collector_task.cancel()

            # Print final summary for all symbols
            for symbol in SYMBOLS:
                print(f"\n{'='*60}")
                print(f"📊 {symbol} - Final Summary")
                print(f"{'='*60}")
                sd = SYMBOL_DATA[symbol]
                print(f"Total records: {sd.total_records}")
                print(f"Last price: ${sd.current_price:,.2f}")
                print(f"{'='*60}")

            # Cancellation must keep propagating; a KeyboardInterrupt caught
            # directly is swallowed as before.
            if isinstance(exc, asyncio.CancelledError):
                raise


# Script entry point: run the async main() until it finishes or is interrupted.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C: print a farewell instead of letting the traceback show.
        print("\n👋 Goodbye!")
