import glob
import json
import logging
import os
import random
import time
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Dict, List, Optional, Tuple

import gevent
import requests
from locust import FastHttpUser, task


class VectorSearchLoadTestUser(FastHttpUser):
    """Locust user that fires ANN queries at a Databricks vector-search index.

    Input queries are loaded progressively by a background greenlet (shared
    per worker process) so the test can start issuing requests before every
    input file has been parsed.  Each user fetches a client-credentials
    OAuth token scoped to the index and refreshes it before it expires.
    """

    # Shared (class-level) pool of loaded queries, guarded by _lock.
    search_queries: List[Dict[str, Any]] = []
    # True while the background loader greenlet is scheduled or running.
    data_loading: bool = False
    # OAuth state; assigned per-instance in get_oauth().
    oauth: Optional[str] = None
    expiration: Optional[datetime] = None
    _lock = Lock()

    # ---------------------------
    # Token management
    # ---------------------------
    def _fetch_oauth_token(self, token_lifetime: timedelta = timedelta(minutes=55)) -> Tuple[str, datetime]:
        """Request a client-credentials token scoped to reading the index.

        Returns:
            (access_token, local_expiry). The expiry is computed locally from
            ``token_lifetime`` rather than read from the token response —
            TODO confirm this matches the server-side ``expires_in``.

        Raises:
            requests.HTTPError: on a non-200 token response.
            ValueError: if the response contains no access token.
        """
        authorization_details = json.dumps([
            {
                "type": "unity_catalog_permission",
                "securable_type": "table",
                "securable_object_name": self.index_name,
                "operation": "ReadVectorIndex",
            },
        ])
        payload = {
            "grant_type": "client_credentials",
            "scope": "all-apis",
            "authorization_details": authorization_details,
        }

        url = f"{self.workspace_url}/oidc/v1/token"
        response = requests.post(
            url=url,
            auth=(self.CLIENT_ID, self.CLIENT_SECRET),
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data=payload,
            timeout=30,  # fail fast instead of hanging this user forever
        )
        if response.status_code != 200:
            logging.error(f"OAuth token request failed: {response.status_code} - {response.text}")
            response.raise_for_status()

        token_data = response.json()
        access_token = token_data.get("access_token")
        if not access_token:
            raise ValueError("Failed to get access token")
        return access_token, datetime.now() + token_lifetime

    def get_oauth(self) -> str:
        """Return valid OAuth token, refreshing if needed."""
        if not self.oauth or datetime.now() > self.expiration:
            logging.info("Fetching new OAuth token...")
            self.oauth, self.expiration = self._fetch_oauth_token()
        return self.oauth

    # ---------------------------
    # Progressive data loading
    # ---------------------------
    @classmethod
    def _load_queries_streaming(cls):
        """Load queries file by file, appending to the shared pool as we go.

        Honors ``INPUT_FILE`` (single file) first, then globs
        ``INPUT_FOLDER``/``INPUT_FILE_PATTERN`` with fallbacks.  Always
        clears ``data_loading`` on exit, even if loading raises.
        """
        cls.data_loading = True
        try:
            input_file = os.environ.get("INPUT_FILE")
            if input_file and os.path.exists(input_file):
                logging.info(f"Loading single file {input_file}")
                with open(input_file, "r") as f:
                    data = json.load(f)
                    with cls._lock:
                        cls.search_queries.extend(data if isinstance(data, list) else [data])
                return

            input_folder = os.environ.get("INPUT_FOLDER", ".")
            input_pattern = os.environ.get("INPUT_FILE_PATTERN", "vector_search_input*.json")
            search_path = os.path.join(input_folder, input_pattern)
            input_files = glob.glob(search_path) or glob.glob("vector_search_input*.json") or glob.glob("*.json")

            for file_path in input_files:
                try:
                    with open(file_path, "r") as f:
                        data = json.load(f)
                        with cls._lock:
                            cls.search_queries.extend(data if isinstance(data, list) else [data])
                    logging.info(f"Loaded queries from {file_path}, total so far: {len(cls.search_queries)}")
                    gevent.sleep(0)  # yield control so workers can use loaded queries
                except Exception as e:
                    logging.warning(f"Could not load {file_path}: {e}")

            logging.info(f"Finished loading all queries: {len(cls.search_queries)}")
        finally:
            # Guarantee the flag is reset so a failed load can be retried.
            cls.data_loading = False

    @classmethod
    def start_background_loader(cls):
        """Spawn the loader greenlet at most once per worker process."""
        with cls._lock:
            if cls.data_loading or cls.search_queries:
                return
            # Mark as loading *before* spawning: the greenlet may not run
            # until the next gevent yield, so concurrent on_start calls
            # would otherwise spawn duplicate loaders.
            cls.data_loading = True
        gevent.spawn(cls._load_queries_streaming)

    # ---------------------------
    # Locust lifecycle
    # ---------------------------
    def on_start(self):
        """Initialize env vars and kick off background loader."""
        self.CLIENT_ID = os.environ.get("CLIENT_ID")
        self.CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
        self.index_name = os.environ.get("VECTOR_SEARCH_INDEX_NAME")
        self.workspace_url = os.environ.get("DATABRICKS_WORKSPACE_URL")
        self.index_query_path = os.environ.get("VECTOR_SEARCH_INDEX_PATH")

        self.start_background_loader()

    def get_next_query(self) -> Dict[str, Any]:
        """Return a random loaded query, or a dummy fallback if none loaded yet."""
        with self._lock:
            if self.search_queries:
                return random.choice(self.search_queries)

        # fallback only if nothing loaded yet
        return {"query_vector": [0.0] * 5, "num_results": 1}

    # ---------------------------
    # Locust tasks
    # ---------------------------
    @task
    def query_vector_search_ann(self):
        """Issue one ANN query, backing off on 429/499 per Retry-After."""
        headers = {"Authorization": f"Bearer {self.get_oauth()}"}
        search_input = self.get_next_query()

        if "query_vector" not in search_input:
            logging.error("No query_vector in input")
            return

        ann_payload = {"num_results": search_input.get("num_results", 10)}
        ann_payload["query_vector"] = search_input["query_vector"]
        if "columns" in search_input and search_input["columns"] != ["*"]:
            ann_payload["columns"] = search_input["columns"]
        if "filters" in search_input:
            ann_payload["filters"] = search_input["filters"]
        if "score_threshold" in search_input:
            ann_payload["score_threshold"] = search_input["score_threshold"]

        with self.client.get(
            self.index_query_path,
            headers=headers,
            json=ann_payload,
            catch_response=True,
            name="vector_search_query",
        ) as response:
            if response.status_code == 200:
                response.success()
            elif response.status_code in (429, 499):
                # If the Retry-After header is set, back off for that long.
                retry_after = response.headers.get("Retry-After") or response.headers.get("retry-after") or "3"
                try:
                    delay = float(retry_after)
                except ValueError:
                    # Retry-After may also be an HTTP-date; use a fixed fallback.
                    delay = 3.0
                # gevent.sleep yields only this user, never the whole worker.
                gevent.sleep(delay)
                response.failure(f"Query failed: {response.status_code} {response.headers} {response.text}")
            else:
                response.failure(f"Query failed: {response.status_code} {response.headers} {response.text}")
