import os
import json
import random
import glob
import logging
import requests
import gevent
from datetime import datetime, timedelta
from typing import List, Dict, Any, Tuple
from locust import FastHttpUser, task, events
from threading import Lock
from collections import defaultdict
import time


# Global stats collector for debug_info metrics
class DebugInfoStats:
    """Thread-safe accumulator for numeric debug_info timing samples."""

    def __init__(self):
        self._lock = Lock()
        # metric name -> accumulated numeric samples
        self.metrics = defaultdict(list)
        # Cap on samples retained per metric (per worker) for percentile calc.
        self.window_size = 10000

    def record(self, debug_info: Dict[str, Any]):
        """Store every numeric field of *debug_info* under its key."""
        with self._lock:
            for name, sample in debug_info.items():
                if not isinstance(sample, (int, float)):
                    continue
                bucket = self.metrics[name]
                bucket.append(sample)
                # Drop oldest samples once the window is exceeded.
                if len(bucket) > self.window_size:
                    self.metrics[name] = bucket[-self.window_size:]

    def merge(self, other_metrics: Dict[str, list]):
        """Fold sample lists reported by another process into this collector."""
        with self._lock:
            for name, samples in other_metrics.items():
                self.metrics[name].extend(samples)

    def get_raw_metrics(self) -> Dict[str, list]:
        """Drain all samples: return a copy and clear local state.

        Clearing guarantees the same values are never reported twice.
        """
        with self._lock:
            drained = {name: samples[:] for name, samples in self.metrics.items()}
            self.metrics.clear()
            return drained

    def get_stats(self) -> Dict[str, Dict[str, float]]:
        """Summarize each metric: min/max/avg/median/p95/p99/count.

        Percentiles fall back to the max when the sample count is too small
        (p95 needs >= 20 samples, p99 needs >= 100).
        """
        with self._lock:
            summary = {}
            for name, samples in self.metrics.items():
                if not samples:
                    continue
                ordered = sorted(samples)
                n = len(ordered)
                summary[name] = {
                    "min": ordered[0],
                    "max": ordered[-1],
                    "avg": sum(ordered) / n,
                    "median": ordered[n // 2],
                    "p95": ordered[int(n * 0.95)] if n >= 20 else ordered[-1],
                    "p99": ordered[int(n * 0.99)] if n >= 100 else ordered[-1],
                    "count": n,
                }
            return summary

    def reset(self):
        """Discard every collected sample."""
        with self._lock:
            self.metrics.clear()


# Global instance
# Per-process singleton: users record into it; workers drain it via
# get_raw_metrics() on each report-to-master cycle, and the master merges
# worker payloads back into its own copy.
debug_stats = DebugInfoStats()


# History collector for time-series debug stats
class DebugInfoHistory:
    """Time-ordered log of debug_info stat snapshots, used for the history CSV."""

    def __init__(self):
        self._lock = Lock()
        # Each entry: {"timestamp": ..., "user_count": ..., "<metric>_<stat>": ...}
        self.history = []

    def record_snapshot(self, user_count: int, stats: Dict[str, Dict[str, float]]):
        """Append one snapshot of *stats*, flattened and tagged with time/load."""
        snapshot = {
            "timestamp": datetime.now().isoformat(),
            "user_count": user_count,
        }
        # Flatten the nested stats dict into "<metric>_<stat>" columns.
        for metric_name, metric_stats in stats.items():
            snapshot.update(
                {f"{metric_name}_{stat}": value for stat, value in metric_stats.items()}
            )
        with self._lock:
            self.history.append(snapshot)

    def get_history(self) -> List[Dict]:
        """Return a shallow copy of all recorded snapshots."""
        with self._lock:
            return self.history[:]

    def reset(self):
        """Discard all recorded snapshots."""
        with self._lock:
            self.history.clear()


debug_history = DebugInfoHistory()


@events.report_to_master.add_listener
def on_report_to_master(client_id, data):
    """Send raw debug metrics to master in distributed mode.

    Fires on each worker before a stats payload is shipped. get_raw_metrics()
    drains (copies then clears) the local samples, so every value reaches the
    master exactly once.
    """
    # Get raw values and clear local to avoid double-counting
    data["debug_raw_metrics"] = debug_stats.get_raw_metrics()


@events.worker_report.add_listener  
def on_worker_report(client_id, data):
    """Aggregate worker debug metrics on master.

    Fires on the master for each worker report; merges the raw sample lists
    attached by on_report_to_master into the master's collector.
    """
    if "debug_raw_metrics" in data:
        debug_stats.merge(data["debug_raw_metrics"])


# Periodic history recorder (started on test_start)
# Handle to the background greenlet so on_test_stop can kill it;
# None when no recorder is running.
_history_greenlet = None

def _record_history_periodically(environment):
    """Background greenlet that records stats every 2 seconds.

    Loops until killed by on_test_stop. get_stats() does not drain the
    collector, so each snapshot summarizes all samples seen so far.
    """
    while True:
        gevent.sleep(2)
        current_stats = debug_stats.get_stats()
        # Skip empty snapshots (nothing recorded or merged yet).
        if current_stats:
            user_count = environment.runner.user_count if environment.runner else 0
            debug_history.record_snapshot(user_count, current_stats)


@events.test_start.add_listener
def on_test_start(environment, **kwargs):
    """Start the history recorder greenlet on master."""
    global _history_greenlet
    # Only run on master (or local mode)
    # NOTE(review): workers are detected by the presence of a worker_index
    # attribute on the runner — presumably absent on master/local runners;
    # confirm against the locust runner API.
    if not hasattr(environment.runner, 'worker_index'):
        _history_greenlet = gevent.spawn(_record_history_periodically, environment)


def _log_debug_summary(stats):
    """Write the aggregated debug_info table to the log (ms units)."""
    logging.info("=" * 60)
    logging.info("DEBUG_INFO TIMING SUMMARY (ms)")
    logging.info("=" * 60)
    for metric, values in sorted(stats.items()):
        logging.info(
            f"{metric:20s} | avg: {values['avg']:7.1f} | "
            f"p50: {values['median']:7.1f} | p95: {values['p95']:7.1f} | "
            f"p99: {values['p99']:7.1f} | min: {values['min']:7.1f} | max: {values['max']:7.1f}"
        )
    logging.info("=" * 60)


def _export_stats_csv(stats):
    """Export per-metric summary stats to CSV (path from DEBUG_STATS_CSV).

    Best-effort: failures are logged, never raised.
    """
    csv_path = os.environ.get("DEBUG_STATS_CSV", "debug_info_stats.csv")
    try:
        import csv
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["metric", "count", "min", "max", "avg", "median", "p95", "p99"])
            for metric, values in sorted(stats.items()):
                writer.writerow([
                    metric,
                    values["count"],
                    round(values["min"], 2),
                    round(values["max"], 2),
                    round(values["avg"], 2),
                    round(values["median"], 2),
                    round(values["p95"], 2),
                    round(values["p99"], 2),
                ])
        logging.info(f"Debug stats exported to {csv_path}")
    except Exception as e:
        logging.error(f"Failed to export debug stats CSV: {e}")


def _export_history_csv():
    """Export the time-series snapshots to CSV (path from DEBUG_STATS_HISTORY_CSV).

    No-op when no history was recorded. Best-effort: failures are logged.
    """
    history = debug_history.get_history()
    if not history:
        return
    history_csv_path = os.environ.get("DEBUG_STATS_HISTORY_CSV", "debug_info_stats_history.csv")
    try:
        import csv
        # Union of keys across entries: later snapshots may add metrics.
        all_keys = set()
        for entry in history:
            all_keys.update(entry.keys())
        fieldnames = ["timestamp", "user_count"] + sorted(
            k for k in all_keys if k not in ("timestamp", "user_count")
        )

        with open(history_csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for entry in history:
                writer.writerow(entry)
        logging.info(f"Debug stats history exported to {history_csv_path}")
    except Exception as e:
        logging.error(f"Failed to export debug stats history CSV: {e}")


def _print_debug_summary(stats):
    """Print the final stats table to stdout for quick inspection."""
    print("\n")
    print("=" * 80)
    print("DEBUG_INFO TIMING STATS")
    print("=" * 80)
    print(f"{'Metric':<25} {'Count':>8} {'Min':>10} {'Avg':>10} {'Median':>10} {'P95':>10} {'P99':>10} {'Max':>10}")
    print("-" * 80)
    for metric, values in sorted(stats.items()):
        print(
            f"{metric:<25} {values['count']:>8} {values['min']:>10.1f} {values['avg']:>10.1f} "
            f"{values['median']:>10.1f} {values['p95']:>10.1f} {values['p99']:>10.1f} {values['max']:>10.1f}"
        )
    print("=" * 80)
    print("\n")


@events.test_stop.add_listener  
def on_test_stop(environment, **kwargs):
    """Print debug_info summary and export to CSV when test stops (only on master)."""
    global _history_greenlet

    # Stop the history greenlet
    if _history_greenlet:
        _history_greenlet.kill()
        _history_greenlet = None

    # In distributed mode, only master should write the files.
    if hasattr(environment.runner, 'worker_index'):
        return  # Skip on workers

    stats = debug_stats.get_stats()
    if stats:
        _log_debug_summary(stats)
        _export_stats_csv(stats)
        _print_debug_summary(stats)

    # Bug fix: history export was previously nested under `if stats:`, so
    # recorded history could be silently dropped when the final stats
    # snapshot happened to be empty. Export it unconditionally.
    _export_history_csv()


class VectorSearchLoadTestUser(FastHttpUser):
    """Locust user that issues vector-search queries.

    Authenticates via an OAuth client-credentials token (cached per user),
    draws query payloads from a class-level pool filled by a background
    greenlet, and applies exponential backoff with jitter when rate limited.
    """

    # Shared query pool, populated once per process by a background greenlet.
    search_queries: List[Dict[str, Any]] = []
    data_loading: bool = False
    # Cached bearer token and its expiry; None until the first fetch.
    oauth: str = None
    expiration: datetime = None
    _lock = Lock()

    # Rate limit backoff settings (per worker); overridden from env in on_start.
    consecutive_429s = 0
    base_wait_time = 0.100  # seconds
    max_wait_time = 30  # seconds
    backoff_multiplier = 2
    jitter_range = 0.3  # 30% jitter

    def calculate_backoff_time(self) -> float:
        """Calculate wait time based on consecutive 429s with jitter.

        Exponential backoff capped at max_wait_time, with proportional
        +/- jitter_range jitter to avoid synchronized retries.
        """
        wait_time = min(
            self.base_wait_time * (self.backoff_multiplier ** self.consecutive_429s),
            self.max_wait_time
        )
        jitter = random.uniform(-wait_time * self.jitter_range, wait_time * self.jitter_range)
        return max(0, wait_time + jitter)

    def _fetch_oauth_token(self, token_lifetime: timedelta = timedelta(minutes=55)) -> Tuple[str, datetime]:
        """Request a scoped OAuth token via the client-credentials flow.

        Returns:
            (access_token, expiry) — the expiry is computed locally as
            now + token_lifetime rather than taken from the token response.

        Raises:
            requests.HTTPError: on a non-200 token response.
            ValueError: if the response body contains no access_token.
        """
        # Restrict the token to read access on the target index only.
        deets = json.dumps([
            {
                "type": "unity_catalog_permission",
                "securable_type": "table",
                "securable_object_name": self.index_name,
                "operation": "ReadVectorIndex",
            },
        ])
        payload = {
            'grant_type': 'client_credentials',
            'scope': 'all-apis',
            'authorization_details': deets
        }

        url = f"{self.workspace_url}/oidc/v1/token"
        response = requests.post(
            url=url,
            auth=(self.CLIENT_ID, self.CLIENT_SECRET),
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data=payload,
        )
        if response.status_code != 200:
            logging.error(f"OAuth token request failed: {response.status_code} - {response.text}")
            response.raise_for_status()

        token_data = response.json()
        access_token = token_data.get("access_token")
        if not access_token:
            raise ValueError("Failed to get access token")
        return access_token, datetime.now() + token_lifetime

    def get_oauth(self) -> str:
        """Return valid OAuth token, refreshing if needed.

        Refreshes 5 minutes before the recorded expiry to avoid using a
        token that dies mid-request.
        """
        if not self.oauth or datetime.now() > (self.expiration - timedelta(minutes=5)):
            logging.info("Fetching new OAuth token...")
            self.oauth, self.expiration = self._fetch_oauth_token()
        return self.oauth

    @classmethod
    def _load_queries_streaming(cls):
        """Load queries file by file, appending as we go.

        Runs in a background greenlet so users can start before all data is
        loaded. data_loading is always reset — even on failure — so a later
        start_background_loader() call can retry.
        """
        cls.data_loading = True
        try:
            # Single-file mode takes precedence when INPUT_FILE is set.
            input_file = os.environ.get("INPUT_FILE")
            if input_file and os.path.exists(input_file):
                logging.info(f"Loading single file {input_file}")
                with open(input_file, "r") as f:
                    data = json.load(f)
                    with cls._lock:
                        cls.search_queries.extend(data if isinstance(data, list) else [data])
                return

            input_folder = os.environ.get("INPUT_FOLDER", ".")
            input_pattern = os.environ.get("INPUT_FILE_PATTERN", "vector_search_input*.json")
            search_path = os.path.join(input_folder, input_pattern)
            # Fall back to progressively broader patterns in the cwd.
            input_files = glob.glob(search_path) or glob.glob("vector_search_input*.json") or glob.glob("*.json")

            for file_path in input_files:
                try:
                    with open(file_path, "r") as f:
                        data = json.load(f)
                        with cls._lock:
                            cls.search_queries.extend(data if isinstance(data, list) else [data])
                    logging.info(f"Loaded queries from {file_path}, total so far: {len(cls.search_queries)}")
                    gevent.sleep(0)  # yield so running users aren't starved
                except Exception as e:
                    logging.warning(f"Could not load {file_path}: {e}")

            logging.info(f"Finished loading all queries: {len(cls.search_queries)}")
        finally:
            # Bug fix: previously an exception in the single-file path (e.g.
            # invalid JSON) left data_loading stuck at True forever, which
            # permanently blocked start_background_loader from retrying.
            cls.data_loading = False

    @classmethod
    def start_background_loader(cls):
        """Spawn the query loader once; no-op if loading or already loaded."""
        if not cls.data_loading and not cls.search_queries:
            gevent.spawn(cls._load_queries_streaming)

    def on_start(self):
        """Initialize env vars and kick off background loader."""
        self.CLIENT_ID = os.environ.get("CLIENT_ID")
        self.CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
        self.index_name = os.environ.get("VECTOR_SEARCH_INDEX_NAME")
        self.workspace_url = os.environ.get("DATABRICKS_WORKSPACE_URL")
        self.index_query_path = os.environ.get("VECTOR_SEARCH_INDEX_PATH")

        # Per-instance backoff tuning; note env defaults ("3"/"60") differ
        # from the class-level defaults above, which only apply pre-on_start.
        self.base_wait_time = float(os.environ.get("RATE_LIMIT_BASE_WAIT", "3"))
        self.max_wait_time = float(os.environ.get("RATE_LIMIT_MAX_WAIT", "60"))
        self.backoff_multiplier = float(os.environ.get("RATE_LIMIT_MULTIPLIER", "2"))
        self.jitter_range = float(os.environ.get("RATE_LIMIT_JITTER", "0.3"))

        self.start_background_loader()

    def get_next_query(self) -> Dict[str, Any]:
        """Return a random query if the pool has any, else a dummy fallback."""
        with self._lock:
            if self.search_queries:
                return random.choice(self.search_queries)
        return {"query_vector": [0.0] * 5, "num_results": 1}

    @task
    def query_vector_search_ann(self):
        """Issue one ANN query and record debug_info timings from the response."""
        headers = {"Authorization": f"Bearer {self.get_oauth()}"}
        payload = self.get_next_query()

        with self.client.get(
            self.index_query_path,
            headers=headers,
            json=payload,
            catch_response=True,
            name="vector_search_query",
        ) as response:
            if response.status_code == 200:
                # Parse response and extract debug_info
                try:
                    data = response.json()
                    if "debug_info" in data:
                        debug_info = data["debug_info"]

                        # Record to our custom stats collector
                        debug_stats.record(debug_info)

                        # Log individual request debug info at DEBUG level
                        logging.debug(
                            f"debug_info: response_time={debug_info.get('response_time')}ms "
                            f"ann_time={debug_info.get('ann_time')}ms "
                            f"embedding_gen_time={debug_info.get('embedding_gen_time')}ms "
                            f"reranker_time={debug_info.get('reranker_time')}ms"
                        )
                except json.JSONDecodeError:
                    logging.warning("Failed to parse response JSON for debug_info")

                response.success()
                self.consecutive_429s = 0

            elif response.status_code in [429, 499]:
                self.consecutive_429s += 1
                retry_after = response.headers.get("Retry-After") or response.headers.get("retry-after")

                wait_time = None
                if retry_after:
                    try:
                        # Bug fix: int(retry_after) raised ValueError when the
                        # server sent an HTTP-date or fractional Retry-After
                        # (RFC 9110 allows an HTTP-date form), crashing the task.
                        wait_time = float(retry_after)
                    except ValueError:
                        wait_time = None
                if wait_time is None:
                    wait_time = self.calculate_backoff_time()

                logging.warning(
                    f"Rate limited (consecutive: {self.consecutive_429s}), "
                    f"waiting {wait_time:.2f}s before next request"
                )
                gevent.sleep(wait_time)
                response.failure(f"Rate limited: {response.status_code}")

            else:
                response.failure(f"Query failed: {response.status_code} {response.headers} {response.text}")
                self.consecutive_429s = 0