#!/usr/bin/env python3
"""
Probe Mistral OCR without touching task/database state.

Examples:
  python mistral_ocr_probe.py --source temp-url --transport sdk
  python mistral_ocr_probe.py --source upload --transport sdk
  python mistral_ocr_probe.py --source temp-url --transport rest
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, Optional


# Base URL of the Mistral HTTP API; ocr_rest() appends "/ocr" to this value.
API_BASE = "https://api.mistral.ai/v1"


def _split_keys(value: str) -> list[str]:
    return [token.strip() for token in value.replace("\n", ",").replace(" ", ",").split(",") if token.strip()]


def load_api_key() -> str:
    """Return the first Mistral API key that can be discovered.

    Lookup order:
      1. The environment variables ``MISTRAL_API_KEY``,
         ``MISTRAL_OCR_API_KEY`` and ``MISTRAL_OCR_API_KEYS``, in that order.
         A variable may hold several keys; the first one is used.
      2. ``mistral_ocr_balancer._load_default_api_keys()``, when that module
         can be imported and returns a non-empty result.

    Returns:
        A single API key.

    Raises:
        RuntimeError: If neither source yields a key.
    """
    env_names = ("MISTRAL_API_KEY", "MISTRAL_OCR_API_KEY", "MISTRAL_OCR_API_KEYS")
    for env_name in env_names:
        raw = os.getenv(env_name)
        if not raw:
            continue
        candidates = _split_keys(raw)
        if candidates:
            return candidates[0]

    # Best-effort fallback: the balancer module is optional, so any failure
    # while importing or calling it is ignored and reported as "no key" below.
    try:
        import mistral_ocr_balancer

        fallback_keys = mistral_ocr_balancer._load_default_api_keys()
        if fallback_keys:
            return fallback_keys[0]
    except Exception:
        pass

    raise RuntimeError(
        "No Mistral API key found. Set MISTRAL_API_KEY or MISTRAL_OCR_API_KEYS."
    )


def load_temp_server_url() -> str:
    """Resolve the temp-file server base URL, ending in exactly one "/".

    Lookup order:
      1. Environment variable ``TASK_EXECUTOR_TEMP_FILE_SERVER_URL``.
      2. Attribute ``tempfileserverurl`` of the importable module ``value``.
      3. A built-in default URL.

    Returns:
        The base URL with trailing slashes collapsed to a single "/".
    """
    fallback = "https://tempfileserver.knowledge.yunwoai.com/"

    configured = os.getenv("TASK_EXECUTOR_TEMP_FILE_SERVER_URL")
    if configured:
        return f"{configured.rstrip('/')}/"

    # Best-effort: the ``value`` module is optional; any failure while
    # importing or reading it falls through to the default.
    try:
        import value as task_value

        module_url = getattr(task_value, "tempfileserverurl", "")
        if module_url:
            return f"{str(module_url).rstrip('/')}/"
    except Exception:
        pass

    return fallback


def env_bool(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; otherwise ``True`` only when the
        value, stripped and lower-cased, is one of ``1``, ``true``, ``yes``,
        ``y`` or ``on``.
    """
    raw = os.getenv(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in {"1", "true", "yes", "y", "on"}
    return default


def load_mistral_proxy() -> Optional[str]:
    """Return the proxy URL configured for Mistral requests, if any.

    Checks ``MISTRAL_OCR_PROXY`` first, then ``MISTRAL_OCR_HTTPS_PROXY``.
    Unset, empty and whitespace-only values are skipped.

    Returns:
        The first usable value with surrounding whitespace removed, or
        ``None`` when neither variable holds one.
    """
    raw_values = (
        os.getenv(env_name)
        for env_name in ("MISTRAL_OCR_PROXY", "MISTRAL_OCR_HTTPS_PROXY")
    )
    for raw in raw_values:
        cleaned = raw.strip() if raw else ""
        if cleaned:
            return cleaned
    return None


def pdf_escape(text: str) -> str:
    """Backslash-escape ``\\``, ``(`` and ``)`` in *text*.

    The result is meant to be placed between the parentheses of a PDF
    literal string, where these three characters are special.
    """
    # A single translate pass maps each special character to its escaped form.
    escapes = str.maketrans({"\\": "\\\\", "(": "\\(", ")": "\\)"})
    return text.translate(escapes)


def write_probe_pdf(path: Path, text: str) -> None:
    """Write a minimal single-page PDF that shows *text* to *path*.

    The file is assembled by hand from five objects (catalog, page tree,
    page, Helvetica font, content stream) followed by a cross-reference
    table and trailer.

    Args:
        path: Destination file; it is overwritten if it already exists.
        text: Text to draw. It is encoded as ASCII, so non-ASCII characters
            raise ``UnicodeEncodeError``.
    """
    # Content stream: 24pt font F1, positioned at (72, 720), one text run.
    stream = f"BT /F1 24 Tf 72 720 Td ({pdf_escape(text)}) Tj ET\n".encode("ascii")
    stream_object = b"".join(
        (
            b"<< /Length ",
            str(len(stream)).encode("ascii"),
            b" >>\nstream\n",
            stream,
            b"endstream",
        )
    )
    bodies = (
        b"<< /Type /Catalog /Pages 2 0 R >>",
        b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>",
        b"<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] "
        b"/Resources << /Font << /F1 4 0 R >> >> /Contents 5 0 R >>",
        b"<< /Type /Font /Subtype /Type1 /BaseFont /Helvetica >>",
        stream_object,
    )
    # The xref table has one extra entry: the free-list head (object 0).
    entry_count = len(bodies) + 1

    # Header line plus a comment line of non-ASCII bytes.
    buffer = bytearray(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n")
    object_offsets = []
    for number, body in enumerate(bodies, start=1):
        # Record where each object starts; the xref table needs byte offsets.
        object_offsets.append(len(buffer))
        buffer += f"{number} 0 obj\n".encode("ascii") + body + b"\nendobj\n"

    xref_start = len(buffer)
    buffer += f"xref\n0 {entry_count}\n".encode("ascii")
    buffer += b"0000000000 65535 f \n"
    buffer += b"".join(
        f"{offset:010d} 00000 n \n".encode("ascii") for offset in object_offsets
    )
    buffer += (
        f"trailer\n<< /Size {entry_count} /Root 1 0 R >>\n"
        f"startxref\n{xref_start}\n%%EOF\n"
    ).encode("ascii")
    path.write_bytes(bytes(buffer))


def make_default_pdf() -> Path:
    """Generate the built-in probe PDF in the system temp directory.

    The file name combines the current process id and the current Unix time
    in whole seconds.

    Returns:
        Path of the PDF that was written.
    """
    file_name = f"mistral_ocr_probe_{os.getpid()}_{int(time.time())}.pdf"
    target = Path(tempfile.gettempdir(), file_name)
    write_probe_pdf(target, "Mistral OCR probe 2026-06-12")
    return target


def to_jsonable(value: Any) -> Any:
    """Recursively convert *value* into JSON-serializable primitives.

    Conversion rules, applied in order:
      * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as is.
      * Lists and tuples become lists of converted items.
      * Dicts become dicts with ``str`` keys and converted values.
      * Objects exposing ``model_dump()`` or, failing that, ``dict()`` are
        converted from the result of that call.
      * Objects with a ``__dict__`` are converted from ``vars(value)``.
      * Anything else is replaced by its ``repr()``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (list, tuple)):
        return list(map(to_jsonable, value))
    if isinstance(value, dict):
        return {str(name): to_jsonable(item) for name, item in value.items()}
    # ``model_dump`` is tried before ``dict`` for objects that offer both.
    for exporter in ("model_dump", "dict"):
        if hasattr(value, exporter):
            return to_jsonable(getattr(value, exporter)())
    if hasattr(value, "__dict__"):
        return to_jsonable(vars(value))
    return repr(value)


def summarize_response(response: Any) -> Dict[str, Any]:
    """Condense an OCR response into a small, printable summary.

    Args:
        response: An SDK response object or an already-decoded JSON payload;
            it is normalised with ``to_jsonable`` first.

    Returns:
        A dict with:
          * ``ok``: always ``True``.
          * ``page_count``: number of entries in the ``pages`` list, or 0
            when ``pages`` is missing, ``null`` or not a list.
          * ``first_markdown_preview``: first 300 characters of the first
            page's ``markdown`` field, or ``""`` when unavailable.
          * ``usage_info``: the payload's ``usage_info`` value, or ``None``.
    """
    data = to_jsonable(response)
    payload = data if isinstance(data, dict) else {}

    # A payload with ``"pages": null`` (or any non-list value) must not make
    # len() or indexing blow up; treat it as "no pages".
    pages = payload.get("pages")
    if not isinstance(pages, list):
        pages = []

    first_markdown = ""
    if pages and isinstance(pages[0], dict):
        markdown = pages[0].get("markdown")
        # Avoid previewing a missing/null markdown field as the text "None".
        if markdown is not None:
            first_markdown = str(markdown)[:300]

    return {
        "ok": True,
        "page_count": len(pages),
        "first_markdown_preview": first_markdown,
        "usage_info": payload.get("usage_info"),
    }


def error_summary(exc: BaseException) -> Dict[str, Any]:
    """Describe *exc* as a JSON-friendly dict.

    The result always has ``ok`` (``False``), ``error_type`` (the exception
    class as ``module.name``) and ``error`` (``str(exc)``). Optional
    attributes on the exception add further keys when they are not ``None``:

      * ``status_code`` -> ``status_code`` (stringified, max 1000 chars)
      * ``response`` -> ``response_status_code`` plus ``response_text``
        (first 1000 chars of ``response.text``), or ``response`` holding the
        ``repr()`` when the text cannot be read
      * ``body`` -> ``body`` (stringified, max 1000 chars)
    """
    cls = exc.__class__
    summary: Dict[str, Any] = {
        "ok": False,
        "error_type": f"{cls.__module__}.{cls.__name__}",
        "error": str(exc),
    }

    status_code = getattr(exc, "status_code", None)
    if status_code is not None:
        summary["status_code"] = str(status_code)[:1000]

    response = getattr(exc, "response", None)
    if response is not None:
        summary["response_status_code"] = getattr(response, "status_code", None)
        try:
            summary["response_text"] = response.text[:1000]
        except Exception:
            # No usable ``text`` attribute: fall back to the object's repr.
            summary["response"] = repr(response)

    body = getattr(exc, "body", None)
    if body is not None:
        summary["body"] = str(body)[:1000]

    return summary


def make_client(api_key: str) -> Any:
    """Build a ``mistralai.Mistral`` client configured from the environment.

    Environment inputs:
      * ``MISTRAL_OCR_PROXY`` / ``MISTRAL_OCR_HTTPS_PROXY``: optional proxy.
      * ``MISTRAL_OCR_TRUST_ENV``: passed to httpx as ``trust_env``
        (default ``False``).
      * ``MISTRAL_OCR_TIMEOUT_MS``: request timeout in milliseconds
        (default 120000). The httpx timeout is at least 1 second and the
        connect timeout is capped at 10 seconds.

    Args:
        api_key: Mistral API key.

    Returns:
        A ``Mistral`` instance. If constructing it with custom httpx clients
        raises ``TypeError``, a plain ``Mistral(api_key)`` is returned.

    Raises:
        ValueError: If ``MISTRAL_OCR_TIMEOUT_MS`` is not numeric.
    """
    import httpx
    from mistralai import Mistral

    proxy_url = load_mistral_proxy()
    honor_env = env_bool("MISTRAL_OCR_TRUST_ENV", False)
    timeout_ms = int(float(os.getenv("MISTRAL_OCR_TIMEOUT_MS", "120000")))
    total_seconds = max(1.0, timeout_ms / 1000.0)
    http_timeout = httpx.Timeout(total_seconds, connect=min(10.0, total_seconds))

    # Sync and async clients take identical settings.
    shared_kwargs = {"timeout": http_timeout, "trust_env": honor_env}
    if proxy_url:
        shared_kwargs["proxy"] = proxy_url

    try:
        return Mistral(
            api_key=api_key,
            client=httpx.Client(**shared_kwargs),
            async_client=httpx.AsyncClient(**shared_kwargs),
            timeout_ms=timeout_ms,
        )
    except TypeError:
        # NOTE(review): presumably for SDK/httpx versions that reject these
        # keyword arguments - confirm which versions need this fallback.
        return Mistral(api_key)


def upload_and_get_signed_url(api_key: str, pdf_path: Path) -> str:
    """Upload *pdf_path* to Mistral and return a signed URL for it.

    Args:
        api_key: Mistral API key used to build the client.
        pdf_path: Local PDF to upload with purpose ``"ocr"``.

    Returns:
        The ``url`` attribute of the signed-URL response.

    Raises:
        RuntimeError: If the upload response has no ``id`` or the signed-URL
            response has no ``url``.
    """
    client = make_client(api_key)

    with pdf_path.open("rb") as handle:
        upload_request = {"file_name": pdf_path.name, "content": handle}
        uploaded = client.files.upload(file=upload_request, purpose="ocr")

    file_id = getattr(uploaded, "id", None)
    if not file_id:
        raise RuntimeError(f"Mistral file upload did not return an id: {uploaded!r}")

    signed = client.files.get_signed_url(file_id=file_id)
    url = getattr(signed, "url", None)
    if url:
        return url
    raise RuntimeError(f"Mistral signed URL response did not include url: {signed!r}")


def preflight_url(url: str, timeout: float) -> Dict[str, Any]:
    """Check that *url* is reachable by fetching its first 128 bytes.

    Sends a GET with ``Range: bytes=0-127``, following redirects.

    Args:
        url: URL to probe.
        timeout: httpx timeout in seconds.

    Returns:
        On a completed request: ``ok`` (status below 400), ``status_code``,
        ``content_type``, ``content_length`` and ``bytes_read``. If anything
        raises, the ``error_summary`` of that exception instead.
    """
    import httpx

    try:
        with httpx.Client(timeout=timeout, follow_redirects=True) as http:
            reply = http.get(url, headers={"Range": "bytes=0-127"})
            status = reply.status_code
            headers = reply.headers
            return {
                "ok": status < 400,
                "status_code": status,
                "content_type": headers.get("content-type"),
                "content_length": headers.get("content-length"),
                "bytes_read": len(reply.content),
            }
    except Exception as exc:
        return error_summary(exc)


def ocr_sdk(
    api_key: str,
    document: Dict[str, Any],
    model: str,
    include_image_base64: bool,
) -> Dict[str, Any]:
    """Run OCR through the Mistral SDK and summarize the response.

    Args:
        api_key: Mistral API key.
        document: Document descriptor passed as ``document``.
        model: OCR model name.
        include_image_base64: When true, ``include_image_base64=True`` is
            sent; if that call raises ``TypeError`` the request is retried
            once without the flag.

    Returns:
        The ``summarize_response`` summary of the SDK response.
    """
    client = make_client(api_key)
    request: Dict[str, Any] = {"model": model, "document": document}

    if not include_image_base64:
        return summarize_response(client.ocr.process(**request))

    try:
        return summarize_response(
            client.ocr.process(**request, include_image_base64=True)
        )
    except TypeError:
        # Retry without the optional flag.
        return summarize_response(client.ocr.process(**request))


def ocr_rest(
    api_key: str,
    document: Dict[str, Any],
    model: str,
    include_image_base64: bool,
    timeout: float,
) -> Dict[str, Any]:
    """Run OCR via a direct HTTP POST to ``{API_BASE}/ocr``.

    Args:
        api_key: Mistral API key, sent as a Bearer token.
        document: Document descriptor placed in the JSON payload.
        model: OCR model name.
        include_image_base64: Adds ``include_image_base64: true`` to the
            payload when set.
        timeout: httpx timeout in seconds.

    Returns:
        For HTTP status >= 400: ``ok`` (``False``), ``status_code`` and the
        first 2000 characters of the body as ``response_text``. Otherwise
        the ``summarize_response`` summary of the decoded JSON body.
    """
    import httpx

    payload: Dict[str, Any] = {"model": model, "document": document}
    if include_image_base64:
        payload["include_image_base64"] = True

    client_kwargs = {
        "timeout": timeout,
        "trust_env": env_bool("MISTRAL_OCR_TRUST_ENV", False),
    }
    proxy_url = load_mistral_proxy()
    if proxy_url:
        client_kwargs["proxy"] = proxy_url

    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    with httpx.Client(**client_kwargs) as http:
        reply = http.post(f"{API_BASE}/ocr", headers=request_headers, json=payload)

    if reply.status_code < 400:
        return summarize_response(reply.json())
    return {
        "ok": False,
        "status_code": reply.status_code,
        "response_text": reply.text[:2000],
    }


def build_document(args: argparse.Namespace, api_key: str, pdf_path: Path) -> Dict[str, Any]:
    """Build the ``document`` descriptor for the OCR request.

    Args:
        args: Parsed CLI arguments; ``args.source`` selects the strategy and
            ``args.url`` is required for ``--source=url``.
        api_key: Mistral API key (only used for ``--source=upload``).
        pdf_path: Local PDF (used for ``temp-url`` and ``upload``).

    Returns:
        ``{"type": "document_url", "document_url": <url>}`` where the URL is
          * ``args.url`` for ``url``;
          * the temp-file server base URL plus *pdf_path* relative to the
            system temp directory for ``temp-url``;
          * a signed URL of the uploaded file for ``upload``.

    Raises:
        RuntimeError: If ``--url`` is missing for ``--source=url`` or the
            source is not one of the supported values.
    """
    source = args.source

    if source == "url":
        if not args.url:
            raise RuntimeError("--url is required when --source=url")
        return {"type": "document_url", "document_url": args.url}

    if source == "temp-url":
        relative_path = os.path.relpath(str(pdf_path), start=tempfile.gettempdir())
        # os.path.relpath uses the platform separator ("\\" on Windows), but
        # a URL path must use "/". This is a no-op where os.sep is already "/".
        url_path = relative_path.replace(os.sep, "/")
        return {
            "type": "document_url",
            "document_url": load_temp_server_url() + url_path,
        }

    if source == "upload":
        signed_url = upload_and_get_signed_url(api_key, pdf_path)
        return {"type": "document_url", "document_url": signed_url}

    raise RuntimeError(f"Unsupported source: {source}")


def parse_args() -> argparse.Namespace:
    """Parse the probe's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``source``, ``transport``, ``pdf``, ``url``,
        ``model``, ``timeout``, ``proxy``, ``include_image_base64``,
        ``keep`` and ``skip_url_preflight``.
    """
    parser = argparse.ArgumentParser(description="Probe Mistral OCR availability.")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--source", {"choices": ("temp-url", "upload", "url"), "default": "temp-url"}),
        ("--transport", {"choices": ("sdk", "rest"), "default": "sdk"}),
        ("--pdf", {"help": "Existing PDF path. If omitted, a tiny probe PDF is generated."}),
        ("--url", {"help": "Document URL for --source=url."}),
        ("--model", {"default": "mistral-ocr-latest"}),
        ("--timeout", {"type": float, "default": 60.0}),
        ("--proxy", {"help": "HTTP CONNECT proxy for Mistral API requests."}),
        ("--include-image-base64", {"action": "store_true"}),
        ("--keep", {"action": "store_true", "help": "Keep generated probe PDF."}),
        ("--skip-url-preflight", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main() -> int:
    """Run one OCR probe and print a JSON report to stdout.

    Steps: parse arguments, export ``--proxy`` as ``MISTRAL_OCR_PROXY``,
    pick or generate the PDF, build the document descriptor, optionally
    preflight its URL, run OCR over the chosen transport and print the
    report.

    Returns:
        0 when the OCR summary reports ``ok``, otherwise 2.

    Raises:
        Errors from ``load_api_key`` or ``build_document`` propagate to the
        caller; OCR errors are captured in the report instead. A generated
        PDF is removed in either case unless ``--keep`` was given.
    """
    args = parse_args()
    if args.proxy:
        # The client helpers read the proxy from the environment.
        os.environ["MISTRAL_OCR_PROXY"] = args.proxy

    pdf_path = Path(args.pdf).resolve() if args.pdf else make_default_pdf()
    generated_pdf = not args.pdf

    started = time.time()
    try:
        api_key = load_api_key()
        document = build_document(args, api_key, pdf_path)
        doc_type = document.get("type")
        output: Dict[str, Any] = {
            "source": args.source,
            "transport": args.transport,
            "model": args.model,
            "pdf_path": str(pdf_path),
            "document": {
                "type": doc_type,
                # Only the first 120 characters of the URL are reported.
                "document_url_prefix": str(document.get("document_url", ""))[:120],
            },
            "proxy": bool(load_mistral_proxy()),
            "preflight": None,
            "ocr": None,
        }

        wants_preflight = (
            doc_type in ("document_url", "image_url") and not args.skip_url_preflight
        )
        if wants_preflight:
            output["preflight"] = preflight_url(
                str(document.get("document_url")), args.timeout
            )

        try:
            if args.transport == "sdk":
                ocr_result = ocr_sdk(
                    api_key, document, args.model, args.include_image_base64
                )
            else:
                ocr_result = ocr_rest(
                    api_key,
                    document,
                    args.model,
                    args.include_image_base64,
                    args.timeout,
                )
        except Exception as exc:
            # An OCR failure is part of the report, not a crash.
            ocr_result = error_summary(exc)
        output["ocr"] = ocr_result

        output["elapsed_sec"] = round(time.time() - started, 3)
        print(json.dumps(output, ensure_ascii=False, indent=2))

        ocr = output["ocr"]
        return 0 if ocr and ocr.get("ok") else 2
    finally:
        if generated_pdf and not args.keep:
            # Best-effort cleanup of the generated probe file.
            try:
                pdf_path.unlink(missing_ok=True)
            except Exception:
                pass


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
