import collections
import logging
import math
import os
import threading
import time
from contextlib import contextmanager
from typing import Deque, Dict, Iterable, Iterator, List, Optional, Tuple

import httpx
from mistralai import Mistral

# Module-level logger, named after this module so handlers/filters can target it.
logger = logging.getLogger(__name__)


def _load_default_api_keys() -> List[str]:
    """
    Load API keys from environment or fallback to embedded defaults.

    The ``MISTRAL_OCR_API_KEYS`` environment variable can contain a comma or
    whitespace separated list (spaces, tabs and newlines all count as
    separators, and may be mixed with commas). Empty entries are dropped.

    Returns:
        The keys parsed from the environment when the variable is set to a
        non-empty string (possibly an empty list if it only contains
        separators), otherwise the embedded default keys.
    """
    env_keys = os.getenv("MISTRAL_OCR_API_KEYS")
    if env_keys:
        # Turning commas into spaces lets a bare str.split() handle every
        # documented separator at once; it also strips each token and never
        # yields empty strings.
        return env_keys.replace(",", " ").split()

    # NOTE: keep the existing defaults so behaviour matches prior implementation.
    # SECURITY: these are live-looking credentials committed to source control.
    # They should be rotated and supplied via MISTRAL_OCR_API_KEYS (or a secret
    # store) instead; they are left in place here only to preserve behaviour.
    return [
        "tGD9yzMaoe101FqGH1ZsBhw69h5O8kIn",
        "ZVPOtlLH7zHohsMU5xNDe6JN2AGqu4ZF",
        "eI4E6Ig5AWjLlpEfKuXMuwy0SFffEjIs",
        "yLBrW3ApdfQ9YXBZB7KfPKsdwJgVn2Cb",
        "vaMpOtfXx9JWYnErrrUtjiIDx3zkXYIQ",
        "LkWmMNiXpUUA3DWMj8HFLL8IRybxXFmK",
        "Z8Nx4DnwsmgcXH6aKqhsufOXmSJ06kun",
        "ZDRbkkBzIs1TzOKagXPMJKoWNoRnntBq",
        "V47OcoVvI6b03ub6ShrJHzbnlK1ZJzPg",
    ]


def _env_bool(name: str, default: bool = False) -> bool:
    """
    Interpret the environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise the value is
    trimmed and lower-cased, and only "1", "true", "yes", "y" and "on" count
    as true; any other text (including the empty string) is false.
    """
    raw = os.getenv(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in ("1", "true", "yes", "y", "on")
    return default


def _env_int(name: str, default: int) -> int:
    """
    Read the environment variable *name* as an integer.

    Returns *default* when the variable is unset, or when its value cannot be
    parsed by ``int()`` (in which case a warning is logged).
    """
    raw = os.getenv(name)
    if raw is None:
        return default

    try:
        parsed = int(raw)
    except ValueError:
        logger.warning("Invalid %s=%r, using default %s", name, raw, default)
        parsed = default
    return parsed


def _env_float(name: str, default: Optional[float]) -> Optional[float]:
    """
    Read the environment variable *name* as a positive, finite float.

    Returns:
        * *default* when the variable is unset or its value is not usable
          (not parseable by ``float()``, or NaN / infinite); a warning is
          logged in the latter case.
        * ``None`` when the value parses to a number that is zero or negative,
          which callers treat as "disabled".
        * The parsed value otherwise.
    """
    value = os.getenv(name)
    if value is None:
        return default
    try:
        parsed = float(value)
    except ValueError:
        logger.warning("Invalid %s=%r, using default %s", name, value, default)
        return default
    # float() happily accepts "nan" and "inf". Neither is a usable duration:
    # they slip past the `<= 0` check below and later blow up when handed to
    # threading.Condition.wait(), so treat them like any other invalid value.
    if not math.isfinite(parsed):
        logger.warning("Invalid %s=%r, using default %s", name, value, default)
        return default
    if parsed <= 0:
        return None
    return parsed


def _load_mistral_proxy() -> Optional[str]:
    """
    Return the OCR-specific proxy URL, if one is configured.

    ``MISTRAL_OCR_PROXY`` is consulted first, then ``MISTRAL_OCR_HTTPS_PROXY``.
    The first variable holding non-blank text wins and is returned with
    surrounding whitespace removed; ``None`` is returned when neither is set
    to anything but whitespace.
    """
    raw_values = (
        os.getenv(variable)
        for variable in ("MISTRAL_OCR_PROXY", "MISTRAL_OCR_HTTPS_PROXY")
    )
    cleaned = (raw.strip() for raw in raw_values if raw)
    # Blank-after-strip entries are skipped so the next variable gets a chance.
    return next((proxy for proxy in cleaned if proxy), None)


def _create_mistral_client(api_key: str) -> Mistral:
    """
    Build a Mistral SDK client for *api_key* with explicitly configured HTTP transports.

    Configuration is read from the environment on every call:

    * ``MISTRAL_OCR_TRUST_ENV`` (default false) - whether the httpx clients may
      pick up settings from the process environment.
    * ``MISTRAL_OCR_TIMEOUT_MS`` (default 120000) - overall timeout. The httpx
      timeout is this value in seconds, floored at 1 second, with the connect
      phase capped at 10 seconds. The raw millisecond value is also forwarded
      to the SDK unchanged (i.e. without the 1 second floor).
    * ``MISTRAL_OCR_PROXY`` / ``MISTRAL_OCR_HTTPS_PROXY`` - optional proxy URL
      applied to both the sync and async httpx clients.
    """
    # Mistral's SDK creates httpx clients with trust_env=True by default, which
    # inherits HTTP(S)_PROXY/ALL_PROXY from the host. OCR must be isolated from
    # those process-wide proxy settings unless explicitly opted in.
    honour_host_env = _env_bool("MISTRAL_OCR_TRUST_ENV", False)

    timeout_ms = _env_int("MISTRAL_OCR_TIMEOUT_MS", 120_000)
    total_seconds = max(1.0, timeout_ms / 1000.0)
    connect_seconds = min(10.0, total_seconds)

    # Sync and async transports take the exact same options.
    http_options = {
        "timeout": httpx.Timeout(total_seconds, connect=connect_seconds),
        "trust_env": honour_host_env,
    }
    proxy_url = _load_mistral_proxy()
    if proxy_url:
        http_options["proxy"] = proxy_url

    sync_http = httpx.Client(**http_options)
    async_http = httpx.AsyncClient(**http_options)
    return Mistral(
        api_key=api_key,
        client=sync_http,
        async_client=async_http,
        timeout_ms=timeout_ms,
    )


class _KeyState:
    """
    Mutable per-API-key bookkeeping used by the balancer.

    Holds the client built for the key, the monotonic timestamp until which
    the key is cooling down, and the current count of consecutive failures.
    """

    __slots__ = ("client", "cooldown_until", "error_count")

    def __init__(self, client: Mistral) -> None:
        # A new key starts out healthy: no failures recorded, and a cooldown
        # deadline of 0.0 which is never in the future.
        self.error_count: int = 0
        self.cooldown_until: float = 0.0
        self.client = client


class MistralOCRBalancer:
    """
    Thread-safe balancer that multiplexes Mistral OCR clients across API keys.

    Each API key can be leased concurrently up to `per_key_concurrency`. When a key
    repeatedly fails it is temporarily cooled down to avoid hammering the same credential.

    A lease is identified by a ``(key_index, slot_index)`` token. Idle tokens sit in
    ``_available`` and leased ones in ``_in_use``; after construction both containers
    and the per-key state are only modified while holding ``_condition``. Usage
    counters live in ``_stats`` and are guarded by the separate ``_stats_lock``.
    """

    def __init__(
        self,
        api_keys: Iterable[str],
        *,
        per_key_concurrency: int = 1,
        max_errors_before_cooldown: int = 3,
        cooldown_seconds: float = 5.0,
        default_timeout: Optional[float] = 30.0,
    ) -> None:
        """
        Create one client per API key and ``per_key_concurrency`` lease slots for each.

        Args:
            api_keys: Keys to balance across; blank entries are ignored.
            per_key_concurrency: Number of simultaneous leases allowed per key.
            max_errors_before_cooldown: Consecutive failed releases that trigger
                a cooldown for a key (values below 1 are treated as 1).
            cooldown_seconds: Cooldown length in seconds (negative is treated as 0).
            default_timeout: Seconds `acquire` waits when no timeout is passed;
                ``None`` means wait indefinitely.

        Raises:
            ValueError: If no usable key is supplied or ``per_key_concurrency < 1``.
        """
        keys = [key.strip() for key in api_keys if key and key.strip()]
        if not keys:
            raise ValueError("At least one Mistral OCR API key is required.")

        if per_key_concurrency < 1:
            raise ValueError("per_key_concurrency must be >= 1.")

        self._default_timeout = default_timeout
        self._per_key_concurrency = per_key_concurrency
        self._max_errors_before_cooldown = max(1, int(max_errors_before_cooldown))
        self._cooldown_seconds = max(0.0, float(cooldown_seconds))

        self._condition = threading.Condition()
        self._states: List[_KeyState] = []
        self._available: Deque[Tuple[int, int]] = collections.deque()
        self._in_use: Dict[Tuple[int, int], _KeyState] = {}

        for idx, api_key in enumerate(keys):
            state = _KeyState(_create_mistral_client(api_key))
            self._states.append(state)
            for slot in range(per_key_concurrency):
                self._available.append((idx, slot))

        self._stats_lock = threading.Lock()
        self._stats = {
            "acquired": 0,
            "released": 0,
            "failures": 0,
            "cooldowns": 0,
        }

    @property
    def default_timeout(self) -> Optional[float]:
        """Timeout in seconds used by `acquire` when none is given (``None`` = no limit)."""
        return self._default_timeout

    @property
    def key_count(self) -> int:
        """Number of API keys managed by this balancer."""
        return len(self._states)

    @property
    def total_slots(self) -> int:
        """Total number of lease slots (keys multiplied by per-key concurrency)."""
        return len(self._states) * self._per_key_concurrency

    def _now(self) -> float:
        """Return the monotonic clock reading used for deadlines and cooldowns."""
        return time.monotonic()

    def _next_cooldown(self, now: float) -> Optional[float]:
        """
        Return the seconds until the soonest active cooldown ends, or ``None``
        when no key is cooling down at *now*.
        """
        cooldowns = [
            state.cooldown_until - now
            for state in self._states
            if state.cooldown_until > now
        ]
        if not cooldowns:
            return None
        return max(0.0, min(cooldowns))

    def acquire(self, timeout: Optional[float] = None) -> Tuple[Tuple[int, int], Mistral]:
        """
        Acquire a Mistral client lease. Blocks until a key is available or timeout is reached.
        Returns a token that must be released.

        Args:
            timeout: Maximum seconds to wait. ``None`` falls back to the
                balancer's default timeout; if that is also ``None`` the call
                waits indefinitely.

        Returns:
            ``(token, client)`` where *token* must later be passed to `release`.

        Raises:
            TimeoutError: If no slot on a non-cooling key became free in time.
        """
        deadline: Optional[float] = None
        if timeout is None:
            timeout = self._default_timeout
        if timeout is not None:
            deadline = self._now() + timeout

        with self._condition:
            while True:
                now = self._now()
                # Examine each idle token exactly once, rotating skipped ones to the back.
                for _ in range(len(self._available)):
                    token = self._available.popleft()
                    state = self._states[token[0]]
                    if state.cooldown_until > now:
                        # Key still cooling down; put it back to the tail and continue.
                        self._available.append(token)
                        continue

                    self._in_use[token] = state
                    with self._stats_lock:
                        self._stats["acquired"] += 1
                    return token, state.client

                if deadline is not None and now >= deadline:
                    raise TimeoutError("Timed out waiting for an available Mistral OCR client.")

                # Sleep until whichever comes first: a release notification, the
                # next cooldown expiry (nobody notifies for that), or our deadline.
                wait_time = None
                cooldown_wait = self._next_cooldown(now)
                if cooldown_wait is not None:
                    wait_time = cooldown_wait
                if deadline is not None:
                    remaining = deadline - now
                    wait_time = remaining if wait_time is None else min(wait_time, remaining)
                    # Defensive: never hand a non-positive timeout to wait().
                    if wait_time <= 0:
                        raise TimeoutError(
                            "Timed out waiting for an available Mistral OCR client."
                        )

                self._condition.wait(timeout=wait_time)

    def release(self, token: Tuple[int, int], *, success: bool) -> None:
        """
        Release a previously acquired token. Mark as success or failure so the balancer can
        adapt error counters and cooldowns.

        A success resets the key's consecutive-error counter. A failure increments
        it, and once it reaches the configured threshold the counter is reset and
        the key is put into cooldown. Releasing an unknown (or already released)
        token only logs a warning.
        """
        self._return_token(token, success=bool(success))

    def _return_token(self, token: Tuple[int, int], *, success: Optional[bool]) -> None:
        """
        Put *token* back into the idle pool and wake waiting threads.

        ``success`` is ``True``/``False`` as for `release`; ``None`` returns the
        slot without touching the key's error counter or cooldown (used when a
        lease is abandoned for reasons unrelated to the key, e.g. cancellation).
        """
        with self._condition:
            state = self._in_use.pop(token, None)
            if state is None:
                logger.warning("Attempted to release unknown Mistral OCR token %s", token)
                return

            now = self._now()
            if success is None:
                pass  # Neutral outcome: the key's health record stays as it was.
            elif success:
                state.error_count = 0
            else:
                state.error_count += 1
                with self._stats_lock:
                    self._stats["failures"] += 1
                if state.error_count >= self._max_errors_before_cooldown:
                    state.error_count = 0
                    state.cooldown_until = max(state.cooldown_until, now + self._cooldown_seconds)
                    with self._stats_lock:
                        self._stats["cooldowns"] += 1

            self._available.append(token)
            with self._stats_lock:
                self._stats["released"] += 1
            # Wake every waiter, not just one. This release may have started a
            # cooldown, and waiters that went to sleep before it existed have no
            # timer for its expiry; if the single thread woken by notify() then
            # timed out, the rest would never re-check and could sleep past the
            # moment the token becomes usable (forever, with no timeout).
            self._condition.notify_all()

    @contextmanager
    def lease(self, timeout: Optional[float] = None) -> Iterator[Mistral]:
        """
        Context manager that yields a client and ensures the lease is released.

        The lease is released as a success when the ``with`` body completes, and
        as a failure when it raises an ``Exception``. If the body is interrupted
        by a non-``Exception`` ``BaseException`` (``KeyboardInterrupt``,
        ``GeneratorExit``, task cancellation, ...) the slot is still returned,
        but without counting for or against the key. The exception always
        propagates to the caller.

        Raises:
            TimeoutError: If no client could be acquired within *timeout*.
        """
        token, client = self.acquire(timeout=timeout)
        try:
            yield client
        except Exception:
            self.release(token, success=False)
            raise
        except BaseException:
            # Not the key's fault, but the slot must not leak: without this the
            # token would stay in `_in_use` forever and capacity would shrink.
            self._return_token(token, success=None)
            raise
        else:
            self.release(token, success=True)

    def stats(self) -> Dict[str, int]:
        """
        Return a shallow copy of usage statistics for observability and debugging.
        """
        with self._stats_lock:
            return dict(self._stats)


# Process-wide balancer instance; remains None until get_shared_balancer() builds it.
_shared_balancer: Optional[MistralOCRBalancer] = None
# Held by get_shared_balancer() while it checks for and constructs the instance.
_shared_lock = threading.Lock()


def get_shared_balancer() -> MistralOCRBalancer:
    """
    Lazily construct a process-wide balancer so multiple modules can share connections.

    The first call builds the balancer from the environment:

    * ``MISTRAL_OCR_API_KEYS`` (via `_load_default_api_keys`)
    * ``MISTRAL_OCR_PER_KEY_CONCURRENCY`` (default 1; values below 1 are
      replaced by 1 with a warning)
    * ``MISTRAL_OCR_MAX_ERRORS_BEFORE_COOLDOWN`` (default 3)
    * ``MISTRAL_OCR_COOLDOWN_SEC`` (default 5.0; zero or negative disables cooldowns)
    * ``MISTRAL_OCR_LEASE_TIMEOUT_SEC`` (default 30.0; zero or negative means
      leases wait without a time limit)

    Later calls return the same instance.

    Raises:
        ValueError: If no usable API key is configured.
    """
    global _shared_balancer
    # Fast path: once published, the instance is returned without locking.
    if _shared_balancer is not None:
        return _shared_balancer

    with _shared_lock:
        # Re-check under the lock: another thread may have built it meanwhile.
        if _shared_balancer is None:
            api_keys = _load_default_api_keys()
            per_key_concurrency = _env_int("MISTRAL_OCR_PER_KEY_CONCURRENCY", 1)
            if per_key_concurrency < 1:
                # The balancer rejects values below 1; since nothing is cached on
                # failure, that would make every call raise. Degrade to the
                # default instead, like other invalid env settings do.
                logger.warning(
                    "Invalid %s=%r, using default %s",
                    "MISTRAL_OCR_PER_KEY_CONCURRENCY",
                    per_key_concurrency,
                    1,
                )
                per_key_concurrency = 1
            max_errors = _env_int("MISTRAL_OCR_MAX_ERRORS_BEFORE_COOLDOWN", 3)
            cooldown_seconds = _env_float("MISTRAL_OCR_COOLDOWN_SEC", 5.0)
            if cooldown_seconds is None:
                # _env_float reports a non-positive setting as None: no cooldown.
                cooldown_seconds = 0.0
            lease_timeout = _env_float("MISTRAL_OCR_LEASE_TIMEOUT_SEC", 30.0)
            _shared_balancer = MistralOCRBalancer(
                api_keys,
                per_key_concurrency=per_key_concurrency,
                max_errors_before_cooldown=max_errors,
                cooldown_seconds=cooldown_seconds,
                default_timeout=lease_timeout,
            )
            logger.info(
                "Initialized Mistral OCR balancer with %d API key(s), %d total slot(s), proxy=%s",
                _shared_balancer.key_count,
                _shared_balancer.total_slots,
                bool(_load_mistral_proxy()),
            )
    return _shared_balancer
