import base64
import io
import json
import math
import os
import re
import time
from pathlib import Path
from typing import Any

import fitz
import httpx
from PIL import Image, ImageDraw

# Base URL of the SiliconFlow API (call_model appends /chat/completions) and the model under test.
BASE_URL = "https://api.siliconflow.cn/v1"
MODEL = "PaddlePaddle/PaddleOCR-VL-1.5"
# Each run gets its own directory suffixed with the local start timestamp.
_RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")
OUT_DIR = Path(f"/root/knowledge/service/task_executor/tmp/paddle_ocr_vl_eval_{_RUN_STAMP}")
OUT_DIR.mkdir(exist_ok=True, parents=True)

# One row per evaluated page: (name, pdf path, 1-based page number, free-text note about the page).
_SAMPLE_ROWS = (
    (
        "plain_text_emotion_p1",
        "/root/knowledge/storage/industry/13/knowledge/65b432ed855eba966be1439f55a38c9c.pdf",
        1,
        "ordinary text page",
    ),
    (
        "scan_exam_many_images_p80",
        "/root/knowledge/storage/industry/1/knowledge/d45ef8492af228c9aa4436132a15dfa6.pdf",
        80,
        "scanned/image-heavy exam page",
    ),
    (
        "construction_layout_p46",
        "/root/knowledge/storage/industry/9/knowledge/73d3f96f82d539b6df4c731ea772c96d.pdf",
        46,
        "mixed engineering page with images/drawings/text",
    ),
    (
        "construction_layout_p68",
        "/root/knowledge/storage/industry/9/knowledge/73d3f96f82d539b6df4c731ea772c96d.pdf",
        68,
        "mixed engineering page with layout/table potential",
    ),
)
# Expanded to dicts because the main loop spreads each sample into its output record.
SAMPLES = [
    {"name": name, "pdf": pdf, "page": page, "note": note}
    for name, pdf, page, note in _SAMPLE_ROWS
]


def load_key():
    text = Path("/etc/yunwo/task-executor-ocr.env").read_text(encoding="utf-8", errors="ignore")
    raw = ""
    for line in text.splitlines():
        if line.strip().startswith("SILICONFLOW_DEEPSEEK_OCR_API_KEYS="):
            raw = line.split("=", 1)[1].strip()
            break
    keys = [p.strip() for p in re.split(r"[,;\s]+", raw) if p.strip()]
    if not keys:
        raise RuntimeError("no key")
    return keys[0]

# API key resolved once at import time (first entry returned by load_key); call_model sends it
# as the Bearer token. Importing this script fails here if the env file cannot be read or holds no key.
KEY = load_key()

# Prompt for the plain page-to-markdown transcription pass.
MARKDOWN_PROMPT = (
    "<image>\n"
    "<|grounding|>Convert this document page to markdown. "
    "Preserve reading order, headings, paragraphs, lists, tables and formulas. "
    "Return markdown only."
)
# Prompt for the layout pass: the reply is parsed by parse_jsonish and its "blocks"
# (type / text / bbox) are drawn and cropped by annotate_and_crop.
LAYOUT_PROMPT = (
    "<image>\n"
    "<|grounding|>Analyze this document page layout. "
    "Return strict JSON only, no markdown fences. "
    'Schema: {"page_markdown": string, "blocks": [{"type": '
    '"title|paragraph|table|figure|chart|image|header|footer|formula|list", '
    '"text": string, "bbox": [x1,y1,x2,y2]}]}. '
    "Use pixel coordinates in the provided image. "
    "Include figure/image/chart/table blocks even when text is empty."
)


def render_page(pdf_path: str, page_no: int, dpi=160, max_edge=2400, jpeg_quality=86):
    """Rasterize one PDF page and encode it as the JPEG payload sent to the model.

    Args:
        pdf_path: path of the PDF file.
        page_no: 1-based page number.
        dpi: render resolution.
        max_edge: the longer side of the rendered image is downscaled to at most this many pixels.
        jpeg_quality: JPEG quality used for the encoded payload.

    Returns:
        (img, jpeg_bytes, info): the (possibly downscaled) RGB PIL image, its JPEG encoding,
        and a dict with the image size plus PDF-side counts for the page (text characters,
        embedded images, vector drawings).

    Raises:
        IndexError: if page_no is outside 1..page_count. Without this check, page_no=0 would
            silently render the last page through a negative index.
    """
    # The context manager closes the document even if rendering raises.
    with fitz.open(pdf_path) as doc:
        if not 1 <= page_no <= doc.page_count:
            raise IndexError(f"page {page_no} out of range 1..{doc.page_count} for {pdf_path}")
        page = doc[page_no - 1]
        zoom = dpi / 72  # PDF page geometry is expressed in 72-per-inch points
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)
        img = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
        # Collected while the document is still open; Page objects are unusable after close.
        text_chars = len(page.get_text("text") or "")
        image_count = len(page.get_images(full=True))
        drawing_count = len(page.get_drawings())
    # keep same envelope as production
    if max(img.size) > max_edge:
        scale = max_edge / max(img.size)
        # max(1, ...) guards against a 0-px side on pages with extreme aspect ratios.
        img = img.resize((max(1, int(img.width * scale)), max(1, int(img.height * scale))))
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=jpeg_quality, optimize=True)
    info = {
        "width": img.width,
        "height": img.height,
        "pdf_text_chars": text_chars,
        "pdf_image_count": image_count,
        "pdf_drawing_count": drawing_count,
    }
    return img, buf.getvalue(), info


def _chat_message_content(data: Any):
    """Return ``choices[0].message.content`` from a chat-completions response body.

    Missing choices/message/content yield "". A structure that is not the expected
    dict/list shape, or a non-string content value, yields None so the caller can fall
    back to the raw body text.
    """
    if not isinstance(data, dict):
        return None
    choices = data.get("choices") or [{}]
    if not isinstance(choices, list):
        return None
    first = choices[0]
    if not isinstance(first, dict):
        return None
    message = first.get("message") or {}
    if not isinstance(message, dict):
        return None
    content = message.get("content") or ""
    return content if isinstance(content, str) else None


def call_model(image_bytes: bytes, prompt: str, max_tokens=4096):
    """Send one JPEG page image plus *prompt* to ``BASE_URL/chat/completions`` using MODEL.

    The image is sent inline as a base64 data URL with ``detail="high"``, and the request uses
    temperature 0 without streaming.

    Returns a dict with:
        status: the HTTP status code, or None if no response was received.
        elapsed_sec: wall-clock seconds spent on the request, rounded to 3 decimals.
        raw: the decoded JSON body; otherwise the first 1000 characters of the body or error text.
        content: the assistant message text on success; otherwise the raw body (HTTP error or
            unexpected shape) or the transport error text. It is always a str.

    Transport failures (timeouts, connection errors) are returned rather than raised, so a
    single failing request does not abort the whole evaluation run.
    """
    image_url = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode("ascii")
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": image_url, "detail": "high"}},
            {"type": "text", "text": prompt},
        ]}],
        "stream": False,
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    headers = {"Authorization": f"Bearer {KEY}", "Content-Type": "application/json"}
    started = time.perf_counter()  # monotonic clock for measuring durations
    try:
        with httpx.Client(timeout=httpx.Timeout(120, connect=10), trust_env=False, http2=False) as client:
            resp = client.post(f"{BASE_URL}/chat/completions", headers=headers, json=payload)
    except httpx.HTTPError as exc:
        message = f"{type(exc).__name__}: {exc}"
        return {
            "status": None,
            "elapsed_sec": round(time.perf_counter() - started, 3),
            "raw": message[:1000],
            "content": message,
        }
    elapsed = time.perf_counter() - started
    text = resp.text
    try:
        data = resp.json()
    except ValueError:  # body is not JSON (json.JSONDecodeError subclasses ValueError)
        data = None
    content = text
    if resp.status_code < 400:
        extracted = _chat_message_content(data)
        if extracted is not None:
            content = extracted
    return {"status": resp.status_code, "elapsed_sec": round(elapsed, 3), "raw": data if data is not None else text[:1000], "content": content}


def strip_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (```` ``` ````, optional language tag) from *text*.

    The result is whitespace-stripped, and None is treated as "". The following are handled:
      * the usual multi-line form, with the fence on its own first and last lines;
      * a closing fence glued to the last content line (``...}```` ``);
      * a single-line reply such as ```` ```json {"a": 1}``` ````. Previously the whole
        reply was erased in this case.
    Text that does not start with a fence is returned unchanged apart from stripping.
    """
    s = (text or "").strip()
    if not s.startswith("```"):
        return s
    lines = s.splitlines()
    if len(lines) == 1:
        # Opening fence, optional tag, payload and optional closing fence all on one line.
        m = re.fullmatch(r"```[\w+-]*\s*(.*?)\s*(?:```)?", s, re.S)
        return m.group(1).strip() if m else ""
    body = lines[1:]  # drop the opening fence line (it may carry a language tag)
    last = body[-1].rstrip()
    if last.endswith("```"):
        remainder = last[:-3]
        body = body[:-1] + ([remainder] if remainder.strip() else [])
    return "\n".join(body).strip()


def parse_jsonish(text: str) -> Any:
    """Best-effort decode of a model reply that is supposed to contain JSON.

    First the whole reply is parsed, after strip_fence. If that fails, the first JSON value
    starting at the first "{" is decoded with ``raw_decode``. This tolerates prose before
    the object and trailing text after it, even when the trailing text contains braces.
    The old greedy ``\\{.*\\}`` match failed in that case. Anything the greedy match could
    parse is also parsed here.

    Returns:
        The decoded value (usually a dict), or None if neither attempt succeeds, e.g. for a
        truncated reply.
    """
    s = strip_fence(text)
    try:
        return json.loads(s)
    except (ValueError, RecursionError):  # json.JSONDecodeError subclasses ValueError
        pass
    start = s.find("{")
    if start < 0:
        return None
    try:
        value, _end = json.JSONDecoder().raw_decode(s, start)
    except (ValueError, RecursionError):
        return None
    return value


def normalize_bbox(bbox, width, height, min_size=5):
    """Convert a model-reported box into integer pixel coordinates inside the image.

    The model is asked for pixels but may not comply, so the coordinate convention is
    guessed per box:
      * every |coord| <= 1: fractions of the image size (as pixels such a box would be
        smaller than the default min_size and dropped anyway);
      * every |coord| <= 1000 on an image with a side larger than 1000 px: a 0-1000 grid;
      * otherwise: pixels.
    Because the guess is made per box, a small pixel box near the top-left of a large image
    can be misread as 0-1000 normalized.

    Args:
        bbox: [x1, y1, x2, y2, ...] as a list or tuple; items past the fourth are ignored.
        width: image width in pixels.
        height: image height in pixels.
        min_size: boxes narrower or shorter than this (in pixels, after clamping) are rejected.

    Returns:
        [x1, y1, x2, y2] as ints, ordered so that x1 <= x2 and y1 <= y2, and clamped to the
        image. Returns None for malformed or non-finite input and for too-small boxes.
    """
    if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
        return None
    try:
        x1, y1, x2, y2 = (float(v) for v in bbox[:4])
    except (TypeError, ValueError, OverflowError):
        return None
    # json.loads accepts NaN/Infinity; they would otherwise slip through the clamping below.
    if not all(math.isfinite(v) for v in (x1, y1, x2, y2)):
        return None
    extent = max(abs(x1), abs(y1), abs(x2), abs(y2))
    if extent <= 1.0:
        x1, x2 = x1 * width, x2 * width
        y1, y2 = y1 * height, y2 * height
    elif extent <= 1000 and (width > 1000 or height > 1000):
        x1, x2 = x1 * width / 1000, x2 * width / 1000
        y1, y2 = y1 * height / 1000, y2 * height / 1000
    x1, x2 = sorted((max(0, min(width, x1)), max(0, min(width, x2))))
    y1, y2 = sorted((max(0, min(height, y1)), max(0, min(height, y2))))
    if x2 - x1 < min_size or y2 - y1 < min_size:
        return None
    return [int(x1), int(y1), int(x2), int(y2)]


def annotate_and_crop(sample_dir: Path, img: Image.Image, parsed: Any, max_blocks=80):
    """Draw the layout blocks from *parsed* onto a copy of *img* and crop the visual ones.

    Boxes come from ``parsed["blocks"][i]["bbox"]`` via normalize_bbox; blocks without a
    usable box are skipped. Figure, image, chart and table blocks are outlined in red and
    saved as ``crop_<idx>_<type>.jpg``. All other blocks are outlined in blue. The annotated
    page is written to ``layout_annotated.jpg`` in *sample_dir*, even when there are no blocks.

    Args:
        sample_dir: output directory for the crops and the annotated image.
        img: the page image the model saw; box coordinates refer to it.
        parsed: result of parse_jsonish. Anything other than a dict with a "blocks" list
            yields no blocks.
        max_blocks: only the first this-many blocks are drawn or cropped.

    Returns:
        A dict with keys block_count (all blocks, including those past max_blocks),
        crop_count, crops (list of paths) and annotated (path).
    """
    visual_types = {"figure", "image", "chart", "table"}
    blocks = parsed.get("blocks") if isinstance(parsed, dict) else None
    if not isinstance(blocks, list):
        blocks = []
    annotated = img.copy()
    draw = ImageDraw.Draw(annotated)
    crops = []
    for idx, block in enumerate(blocks[:max_blocks]):
        if not isinstance(block, dict):
            continue
        bbox = normalize_bbox(block.get("bbox"), img.width, img.height)
        if not bbox:
            continue
        # Matching is case/whitespace-insensitive so that "Table" or " figure" still counts as visual.
        typ = str(block.get("type") or "block").strip().lower() or "block"
        is_visual = typ in visual_types
        color = "red" if is_visual else "blue"
        draw.rectangle(bbox, outline=color, width=3)
        draw.text((bbox[0] + 2, max(0, bbox[1] - 14)), f"{idx}:{typ}", fill=color)
        if is_visual:
            # typ is one of visual_types here, so the model-supplied text is safe in a file name.
            crop_path = sample_dir / f"crop_{idx}_{typ}.jpg"
            img.crop(tuple(bbox)).save(crop_path, quality=92)
            crops.append(str(crop_path))
    annotated_path = sample_dir / "layout_annotated.jpg"
    annotated.save(annotated_path, quality=92)
    return {"block_count": len(blocks), "crop_count": len(crops), "crops": crops, "annotated": str(annotated_path)}

def _evaluate_sample(sample, sample_dir):
    """Run both prompts against one sample page and write its artifacts into *sample_dir*.

    Writes page.jpg, markdown.txt and layout_raw.txt, plus the layout annotation and crops
    produced by annotate_and_crop. Returns the record for summary.json; record.json is not
    written here. Exceptions propagate to the caller.
    """
    img, image_bytes, render_info = render_page(sample["pdf"], sample["page"])
    image_path = sample_dir / "page.jpg"
    img.save(image_path, quality=92)
    markdown = call_model(image_bytes, MARKDOWN_PROMPT, max_tokens=4096)
    layout = call_model(image_bytes, LAYOUT_PROMPT, max_tokens=4096)
    markdown_text = markdown["content"] or ""
    layout_text = layout["content"] or ""
    # Persist the raw model outputs before post-processing so they survive a later failure.
    (sample_dir / "markdown.txt").write_text(markdown_text, encoding="utf-8")
    (sample_dir / "layout_raw.txt").write_text(layout_text, encoding="utf-8")
    parsed = parse_jsonish(layout_text)
    crop_info = annotate_and_crop(sample_dir, img, parsed)
    return {
        **sample,
        "out_dir": str(sample_dir),
        "image_path": str(image_path),
        "render_info": render_info,
        "markdown_status": markdown["status"],
        "markdown_elapsed_sec": markdown["elapsed_sec"],
        "markdown_chars": len(markdown_text),
        "markdown_preview": markdown_text[:800],
        "layout_status": layout["status"],
        "layout_elapsed_sec": layout["elapsed_sec"],
        "layout_chars": len(layout_text),
        "layout_parse_ok": parsed is not None,
        "layout_preview": layout_text[:1200],
        "crop_info": crop_info,
    }


summary = []
for sample in SAMPLES:
    sample_dir = OUT_DIR / sample["name"]
    sample_dir.mkdir(parents=True, exist_ok=True)
    try:
        record = _evaluate_sample(sample, sample_dir)
    except Exception as exc:
        # Batch boundary: one bad sample (missing PDF, bad page, post-processing error) is
        # recorded, and the remaining samples and summary.json are still produced.
        record = {**sample, "out_dir": str(sample_dir), "error": f"{type(exc).__name__}: {exc}"}
    (sample_dir / "record.json").write_text(json.dumps(record, ensure_ascii=False, indent=2), encoding="utf-8")
    summary.append(record)

summary_path = OUT_DIR / "summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps({"out_dir": str(OUT_DIR), "summary": summary}, ensure_ascii=False, indent=2))
