# mempool_extended.py - Extended mempool/on-chain metrics
# Compatible with Python 3.9+

import aiohttp
import math
from typing import Optional, Dict

# Base URL of the mempool.space REST API. Endpoint paths such as "/mempool",
# "/v1/fees/recommended" and "/blocks" are appended to it by
# fetch_mempool_extended().
MEMPOOL_API_BASE: str = "https://mempool.space/api"


def _average_block_interval(blocks, max_intervals=10):
    """
    Return the mean positive gap, in seconds, between consecutive blocks.

    ``blocks`` is assumed newest-first, so each gap is
    ``newer["timestamp"] - older["timestamp"]``. At most ``max_intervals``
    gaps from the head of the list are considered. Gaps that are zero or
    negative are skipped, as are pairs where either block has no
    ``timestamp``. Returns ``None`` when no usable gap exists.
    """
    gaps = []
    # zip stops at the shorter sequence, giving min(max_intervals, len - 1) pairs.
    for newer, older in zip(blocks, blocks[1:max_intervals + 1]):
        newer_ts = newer.get("timestamp")
        older_ts = older.get("timestamp")
        # Treating a missing timestamp as 0 would create a ~1.7e9 s "gap"
        # that passes the positivity check and wrecks the average.
        if newer_ts is None or older_ts is None:
            continue
        gap = newer_ts - older_ts
        if gap > 0:
            gaps.append(gap)
    return float(sum(gaps) / len(gaps)) if gaps else None


async def fetch_mempool_extended(session: aiohttp.ClientSession):
    """
    Fetch extended mempool metrics from the mempool.space REST API.

    Three endpoints are queried independently, so a failure in one does not
    discard data gathered from the others:
    - ``/mempool``: mempool status (tx count, vsize)
    - ``/v1/fees/recommended``: fee recommendations (low, medium, high)
    - ``/blocks``: recent blocks (latest block tx count, avg block time)

    Args:
        session: An open aiohttp client session used for every request.

    Returns:
        dict with every metric key always present. A metric whose request or
        parsing fails keeps its fallback value (0.0, or 600.0 seconds for
        ``avg_block_time``). Errors are printed, never raised.

        ``unconfirmed_value`` and ``pending_over_1btc`` are placeholders:
        nothing here populates them, so they stay 0.0 / 0.
    """
    # Total time budget per request; ClientTimeout is aiohttp's documented type.
    timeout = aiohttp.ClientTimeout(total=6)

    result = {
        "tx_count": 0.0,
        "mempool_vsize": 0.0,
        "mempool_vsize_mb": 0.0,
        "fee_low": 0.0,
        "fee_medium": 0.0,
        "fee_high": 0.0,
        "block_tx_count": 0.0,
        "avg_block_time": 600.0,  # Default 10 minutes
        "mempool_growth_rate": 0.0,
        "unconfirmed_value": 0.0,  # placeholder, never fetched
        "pending_over_1btc": 0,  # placeholder, never fetched
    }

    # ─────────────────────────────────────────────────────────
    # 1. GET MEMPOOL STATUS
    # ─────────────────────────────────────────────────────────
    try:
        async with session.get(f"{MEMPOOL_API_BASE}/mempool", timeout=timeout) as res:
            res.raise_for_status()
            mempool = await res.json()
        result["tx_count"] = float(mempool.get("count", 0))
        result["mempool_vsize"] = float(mempool.get("vsize", 0))
        # Divides by 2**20. NOTE(review): mempool.space's own "MvB" unit is
        # 10**6 vB; kept as-is because downstream thresholds consume this
        # value — confirm before switching.
        result["mempool_vsize_mb"] = result["mempool_vsize"] / (1024 * 1024)
    except Exception as e:
        print(f"[MEMPOOL ERROR] /mempool → {e}")

    # ─────────────────────────────────────────────────────────
    # 2. GET FEE RECOMMENDATIONS
    # ─────────────────────────────────────────────────────────
    try:
        async with session.get(f"{MEMPOOL_API_BASE}/v1/fees/recommended", timeout=timeout) as res:
            res.raise_for_status()
            fees = await res.json()
        result["fee_low"] = float(fees.get("hourFee", 0))
        result["fee_medium"] = float(fees.get("halfHourFee", 0))
        result["fee_high"] = float(fees.get("fastestFee", 0))
    except Exception as e:
        print(f"[MEMPOOL ERROR] /v1/fees/recommended → {e}")

    # ─────────────────────────────────────────────────────────
    # 3. GET RECENT BLOCKS (for avg block time and tx count)
    # ─────────────────────────────────────────────────────────
    try:
        async with session.get(f"{MEMPOOL_API_BASE}/blocks", timeout=timeout) as res:
            res.raise_for_status()
            blocks = await res.json()
        if isinstance(blocks, list) and blocks:
            # Element 0 is treated as the latest block.
            result["block_tx_count"] = float(blocks[0].get("tx_count", 0))
            avg_interval = _average_block_interval(blocks)
            if avg_interval is not None:
                result["avg_block_time"] = avg_interval
    except Exception as e:
        print(f"[MEMPOOL ERROR] /blocks → {e}")

    # ─────────────────────────────────────────────────────────
    # 4. MEMPOOL BACKLOG ESTIMATE
    # ─────────────────────────────────────────────────────────
    # A real growth rate needs samples over time (update_mempool_growth keeps
    # one). This single-snapshot proxy is how far the mempool exceeds the
    # latest block's tx count, in percent: 0 = one block's worth, 100 = two,
    # negative = less than one. It is NOT a tx/min rate despite the key name.
    if result["block_tx_count"] > 0:
        blocks_worth = result["tx_count"] / result["block_tx_count"]
        result["mempool_growth_rate"] = (blocks_worth - 1) * 100

    return result


def normalize_mempool_extended(m: Optional[Dict]):
    """
    Normalize extended mempool data into standardized metrics.

    Args:
        m: Metrics dict such as the one returned by fetch_mempool_extended.
           Missing keys fall back to defaults. ``None`` or an empty dict
           yields the all-default result.

    Returns:
        dict with the same keys on every path:
        - Raw values as floats (``pending_over_1btc`` as int). The input key
          ``mempool_growth_rate`` is exposed as ``growth_rate``.
        - tanh-normalized factors. ``f_mempool``, ``f_fee`` and ``f_block``
          are 0.0 for non-positive input and otherwise lie in (0, 1).
          ``f_block_time`` lies in (-1, 1): negative when the average block
          time is under 600 s and positive when it is over.
    """
    # Route the "no data" case through the same computation so both paths
    # return an identical key set and value types (defaults give every factor 0).
    if not m:
        m = {}

    # Extract raw values
    tx_count = m.get("tx_count", 0)
    mempool_vsize_mb = m.get("mempool_vsize_mb", 0)
    fee_low = m.get("fee_low", 0)
    fee_medium = m.get("fee_medium", 0)
    fee_high = m.get("fee_high", 0)
    block_tx = m.get("block_tx_count", 0)
    avg_block_time = m.get("avg_block_time", 600)
    growth_rate = m.get("mempool_growth_rate", 0)
    unconfirmed_value = m.get("unconfirmed_value", 0)
    pending_over_1btc = m.get("pending_over_1btc", 0)

    # ─────────────────────────────────────────────────────────
    # NORMALIZE (using tanh)
    # ─────────────────────────────────────────────────────────

    # Mempool congestion (based on vsize)
    # Normal: ~10-50 MB, High: 100+ MB, Extreme: 200+ MB
    f_mempool = math.tanh(mempool_vsize_mb / 100) if mempool_vsize_mb > 0 else 0.0

    # Fee pressure (based on high priority fee)
    # Normal: 1-10 sat/vB, High: 20-50, Extreme: 100+
    f_fee = math.tanh(fee_high / 50) if fee_high > 0 else 0.0

    # Block fullness (based on tx count)
    # Normal block: 1500-2500 tx, Full: 3000+
    f_block = math.tanh(block_tx / 2500) if block_tx > 0 else 0.0

    # Block time factor (signed: longer blocks = more congestion)
    # Normal: 600s (10 min), Fast: <480s, Slow: >720s
    f_block_time = math.tanh((avg_block_time - 600) / 300) if avg_block_time > 0 else 0.0

    return {
        # Raw values
        "tx_count": float(tx_count),
        "mempool_vsize_mb": float(mempool_vsize_mb),
        "fee_low": float(fee_low),
        "fee_medium": float(fee_medium),
        "fee_high": float(fee_high),
        "block_tx_count": float(block_tx),
        "avg_block_time": float(avg_block_time),
        "growth_rate": float(growth_rate),
        "unconfirmed_value": float(unconfirmed_value),
        "pending_over_1btc": int(pending_over_1btc),

        # Normalized factors (0-1, except f_block_time which is -1..1)
        "f_mempool": float(f_mempool),
        "f_fee": float(f_fee),
        "f_block": float(f_block),
        "f_block_time": float(f_block_time),
    }


# ─────────────────────────────────────────────────────────
# HELPER: Previous mempool state for growth rate
# ─────────────────────────────────────────────────────────

# Last sample recorded by update_mempool_growth(), which reads both fields and
# then overwrites them on every call. A timestamp of 0 means "no sample yet".
_previous_mempool_state = dict(
    tx_count=0,
    timestamp=0,
)


def update_mempool_growth(current_tx_count, *, now=None):
    """
    Record a mempool sample and return the growth rate since the previous one.

    Compares ``current_tx_count`` with the sample stored in the module-level
    ``_previous_mempool_state`` and then overwrites that sample with the
    current count and time.

    Args:
        current_tx_count: Current number of mempool transactions.
        now: Observation time in seconds on the ``time.time()`` clock.
            Defaults to the current time. Pass the moment the data was
            fetched so that processing delay does not skew the rate.

    Returns:
        float: Transactions per minute added (negative if the mempool
        shrank). 0.0 on the first call, or when ``now`` is not later than
        the stored timestamp.
    """
    import time

    current_time = time.time() if now is None else now
    # The dict is mutated in place, never rebound, so no `global` is needed.
    state = _previous_mempool_state
    prev_tx = state["tx_count"]
    prev_time = state["timestamp"]

    growth_rate = 0.0
    # timestamp 0 marks "no previous sample"; also ignore a clock that went backwards.
    if prev_time > 0 and current_time > prev_time:
        elapsed_minutes = (current_time - prev_time) / 60
        if elapsed_minutes > 0:
            growth_rate = (current_tx_count - prev_tx) / elapsed_minutes

    # Update state
    state["tx_count"] = current_tx_count
    state["timestamp"] = current_time

    return float(growth_rate)