#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
后处理修复脚本 — 修复 convert_to_json.py 产物的已知质量问题
============================================================
修复项：
  1. stem末尾残留括号 (（ / ()
  2. 选项中含 \xa0 分隔符导致多个选项合并为一个
  3. 华秀题库混合格式：后半段 huaxiu 格式被 standard 解析器误处理
  4. 盲医考 E选项题（5选项）答案引用修复
"""

import json
import re
import glob
import os
from pathlib import Path

# Directory containing this script; the raw .txt question banks are read from here.
SRC_DIR = Path(__file__).parent
# JSON question files that Phase 1 repairs in place and Phases 2-3 rewrite.
JSON_DIR = SRC_DIR / "standardized_json"

# Running totals: questions modified (updated by process_json_file) and
# JSON files modified during Phase 1 (updated by main).
fixed_count = 0
fixed_files = 0


def fix_trailing_paren(stem: str) -> str:
    """Drop dangling opening parens (ASCII "(" or full-width "（") from the end of *stem*.

    Trailing whitespace is removed too, including whitespace between stacked
    parens, so "题干（ ( " becomes "题干". Balanced parentheticals are untouched.
    """
    trimmed = stem.rstrip()
    while trimmed.endswith(("（", "(")):
        trimmed = trimmed[:-1].rstrip()
    return trimmed


def fix_merged_options(options: dict) -> dict:
    """
    Split option values that swallowed one or more of the following options.

    Example: {'D': '补虚泻实     E．辨证、辨病'} -> {'D': '补虚泻实', 'E': '辨证、辨病'}

    A value is split when any of three heuristics fires:
      1. the distinct (upper-cased) letters A-E found inside it -- after
         whitespace, glued to a CJK character, or at the very end -- form a
         run of at least two consecutive letters (e.g. B, C, D);
      2. the value is longer than 8 characters and contains
         whitespace + letter + one of ".．、：";
      3. the value is longer than 8 characters and contains
         whitespace + letter + whitespace + CJK character.
    A value that is not split but contains a non-breaking space (U+00A0)
    has it replaced with a regular space.

    NOTE(review): assumes every option value is a str (re.findall is called on it).

    Returns:
        A new dict if anything changed, otherwise the very same ``options``
        object -- callers use an identity check to detect changes.
    """
    new_opts = {}
    changed = False

    for key, val in options.items():
        should_split = False

        # Detector 1: the value contains a run of option letters, e.g. "xxx B xxx C xxx D".
        # Whitespace + letter + separator. (In Python 3 str patterns \s already
        # matches U+00A0, so the explicit \xa0 is redundant but harmless.)
        seq = re.findall(
            r'[\s\xa0]([A-Ea-e])[.．、：:\s\xa0]',
            val
        )
        # A CJK character directly followed by a letter (no space required before it).
        seq += re.findall(r'[一-鿿]([A-Ea-e])(?=[\s\xa0.．、：一-鿿])', val)
        # A lone letter at the very end of the value.
        seq += re.findall(r'[\s\xa0]([A-Ea-e])$', val)

        # De-duplicate in first-seen order (grouped by pattern, not by text
        # position), then require strictly consecutive letters.
        unique = list(dict.fromkeys(s.upper() for s in seq))
        if len(unique) >= 2 and all(
            ord(unique[i+1]) - ord(unique[i]) == 1
            for i in range(len(unique)-1)
        ):
            should_split = True

        # Detector 2: a following option embedded in this value, e.g. "补虚泻实     E．辨证、辨病".
        # Condition: whitespace + letter + separator somewhere after the start.
        if not should_split and len(val) > 8:
            embedded = re.search(
                r'[\s\xa0\t]([A-Ea-e])[.．、：]',
                val
            )
            if embedded:
                should_split = True

        # Detector 3: whitespace + letter + whitespace + CJK, e.g. "血热妄行  B 阴虚内热",
        # meaning two options were merged into one value.
        if not should_split and len(val) > 8:
            embedded2 = re.search(
                r'\s([A-Ea-e])\s[一-鿿]',
                val
            )
            if embedded2:
                should_split = True

        if should_split:
            # Split at letter boundaries. A split point is either
            #   - a whitespace run (\s includes the full-width space U+3000)
            #     + letter + separators, e.g. "xxx B.xxx", or
            #   - a CJK character directly followed by a letter, e.g. "增大B 右室" or "大D心包".
            # The captured letter lands at odd indices of `parts`.
            parts = re.split(
                r'(?:\s+|(?<=[一-鿿]))([A-Ea-e])(?:[.．、：:\s]+|(?=[一-鿿])|$)',
                val
            )
            if len(parts) >= 3:
                # Text before the first split point stays under the original key.
                new_opts[key] = parts[0].rstrip()
                i = 1
                while i + 1 < len(parts):
                    k = parts[i].upper()
                    v = parts[i + 1].strip()
                    # Keys already present win over split-derived ones. A key that
                    # appears LATER in `options` will still overwrite this value when
                    # the outer loop reaches it.
                    if k not in new_opts:
                        new_opts[k] = v
                    i += 2
                changed = True
            else:
                # Nothing to split on: keep the value verbatim (any U+00A0 included).
                new_opts[key] = val
        elif '\xa0' in val:
            new_opts[key] = val.replace('\xa0', ' ')
            changed = True
        else:
            new_opts[key] = val

    return new_opts if changed else options


def fix_e_in_stem(stem: str, options: dict) -> tuple:
    """
    Strip a stray "E" option marker that leaked into a question stem.

    Examples:
        "仰卧位颈椎旋转扳法多用于(  E" -> "仰卧位颈椎旋转扳法多用于"
        "某患者...即要（  E"            -> "某患者...即要"

    Steps, in order:
      1. Remove a trailing opening paren followed by E ("(E", "（ E", ...).
         Otherwise, remove a trailing whitespace + "E", unless that E directly
         follows (possibly across whitespace) a letter A-D.
      2. Remove one "(E)" / "（E）" group anywhere in the remaining stem.

    ``options`` is accepted for interface compatibility but is never read.
    This function does NOT split an E option out of the D option.

    Returns:
        ``(new_stem, True)`` when something was removed, otherwise
        ``(stem, False)`` with the original string returned untouched.
    """
    text = stem.rstrip()
    modified = False

    tail = re.search(r'[（(]\s*E\s*$', text)
    if tail is not None:
        text = text[:tail.start()].rstrip()
        modified = True
    elif re.search(r'\sE$', text) and not re.search(r'[A-Da-d]\s*E$', text):
        # A bare trailing " E" that is not part of an "... D E" letter run.
        text = re.sub(r'\s*E$', '', text).rstrip()
        modified = True

    inner = re.search(r'[（(]\s*E\s*[）)]', text)
    if inner is not None:
        text = (text[:inner.start()] + text[inner.end():]).strip()
        modified = True

    if not modified:
        return stem, False
    return text, True


def fix_answer_ref(answer: list, options: dict) -> list:
    """
    Normalize answer entries so they reference existing option keys.

    An entry that is already an option key is kept as-is. Otherwise the entry
    is normalized: surrounding whitespace is removed, a leading digit prefix
    (as in the huaxiu-style "2B") is dropped, and the result is upper-cased.
    The normalized form is used when it names an existing option. Entries that
    still cannot be resolved are kept unchanged so no information is lost.

    Args:
        answer: list of answer entries (normally option letters).
        options: mapping of option key -> option text.

    Returns:
        A new list; ``answer`` itself is not mutated. An empty ``answer`` is
        returned as the same object.
    """
    if not answer:
        return answer

    fixed = []
    for a in answer:
        if a in options:
            fixed.append(a)
            continue
        # e.g. " 2b " -> "B"
        candidate = re.sub(r'^\d+', '', str(a).strip()).upper()
        fixed.append(candidate if candidate in options else a)
    return fixed


def _read_source_text(filepath: Path) -> "str | None":
    """Return the decoded text of *filepath*, or None if no candidate encoding fits."""
    # "utf-8-sig" is tried first. It decodes BOM-less UTF-8 exactly like "utf-8",
    # but it also removes a leading BOM. When plain "utf-8" ran first, the BOM
    # (U+FEFF, which str.strip() does not remove) stayed glued to line 1. The
    # first question number then failed to match, and that question was lost.
    for enc in ("utf-8-sig", "gbk"):
        try:
            with open(filepath, "r", encoding=enc) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    return None


def reparse_huaxiu_file(filepath: Path) -> "list | None":
    """
    Re-parse the huaxiu question-bank source file with mixed-format support.

    The file is cut into sections wherever a line starts a new question
    numbered 1 ("1. ...", "1．...", "1、..." followed by whitespace). Each
    section is passed to parse_block, which picks the standard or the huaxiu
    parser for it.

    Args:
        filepath: path of the source .txt file.

    Returns:
        The parsed question dicts, or None if the file could not be decoded
        or no question was parsed.
    """
    raw = _read_source_text(filepath)
    if raw is None:
        return None

    # Drop "•" bullets at line starts.
    raw = re.sub(r'^•\s*', '', raw, flags=re.MULTILINE)

    all_questions = []
    block = []
    for line in raw.split("\n"):
        # Question number 1 restarts numbering, so it opens a new section.
        # Huaxiu option lines such as "1A：" never match, because a letter
        # (not a separator) follows the digit.
        if block and re.match(r'^1[.．、]\s', line.strip()):
            all_questions.extend(parse_block(block, filepath.name))
            block = []
        block.append(line)

    if block:
        all_questions.extend(parse_block(block, filepath.name))

    return all_questions or None


def parse_block(lines: list, filename: str) -> list:
    """
    Parse one section of a question bank, choosing the parser by format.

    A section with at least two lines that begin with a digit prefix, an
    option letter A-D and a separator is parsed with parse_huaxiu_block.
    That is the huaxiu option style, e.g. "2B：..." or "4D．...". Every other
    section goes to parse_standard_block.

    Args:
        lines: raw lines of the section.
        filename: source file name, passed through to the chosen parser.

    Returns:
        The question dicts produced by the chosen parser.
    """
    text = "\n".join(lines)
    # Two or more "<digits><A-D><separator>" line prefixes mark a huaxiu section.
    digit_prefixed_options = re.findall(r'^\d+[A-Da-d][、：:.．]', text, re.MULTILINE)
    if len(digit_prefixed_options) >= 2:
        return parse_huaxiu_block(lines, filename)
    return parse_standard_block(lines, filename)


def _strip_blank_paren_tail(stem: str) -> str:
    """Remove a trailing empty answer placeholder ("（ ）", "()") or dangling "(" from *stem*.

    Non-empty parentheticals such as "(mmHg)" are preserved. A trailing closing
    paren is dropped only when the stem has more closing than opening parens.
    """
    s = stem.strip()
    while True:
        prev = s
        s = re.sub(r'[（(]\s*[）)]?\s*$', '', s).rstrip()
        if s.endswith(("）", ")")):
            closes = s.count("）") + s.count(")")
            opens = s.count("（") + s.count("(")
            if closes > opens:
                s = s[:-1].rstrip()
        if s == prev:
            return s


def parse_standard_block(lines: list, filename: str) -> list:
    """
    Parse a section in the standard format (simplified parser).

    Expected layout:
      - a question line "<n><sep><stem>", where <sep> is one of 、 . ． : ： or a space;
      - option lines ("A.xxx", "B、xxx", ...) or several options on one line,
        e.g. "A.xxx B.xxx C.xxx D.xxx";
      - an answer line ("正确答案：B", "答案：B", "答：B", ...).
    Only options and answers A-D are recognised.

    A question is emitted only when an answer and at least one option were
    found before the next question line; otherwise it is skipped.

    Args:
        lines: raw lines of the section.
        filename: unused; kept for signature parity with parse_huaxiu_block.

    Returns:
        A list of single-choice question dicts with options sorted by letter.
    """
    questions = []
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        if not line:
            i += 1
            continue

        q_match = re.match(r'^(\d+)[、.．:：\s]\s*(.*?)$', line)
        if not q_match:
            i += 1
            continue

        q_num = int(q_match.group(1))
        # Only strip an empty placeholder or a dangling paren at the end.
        # A blanket rstrip("（(）)") would also eat the ")" of "(mmHg)" and
        # leave "…（ " behind for "…（ ）".
        stem = _strip_blank_paren_tail(q_match.group(2))

        options = {}
        answer = ""
        j = i + 1
        while j < len(lines):
            ol = lines[j].strip()
            if not ol:
                j += 1
                continue

            ans_match = re.match(r'(?:正确[答案]{0,2}|答|答案)[：:]?\s*([A-Da-d])', ol)
            if ans_match:
                answer = ans_match.group(1).upper()
                break

            # Several options on one line: "A.内容 B.内容 C.内容 D.内容"
            inline = re.findall(r'([A-Da-d])[.．、]\s*(.+?)(?=\s+[A-Da-d][.．、]|$)', ol)
            if len(inline) >= 3:
                for k, v in inline:
                    options[k.upper()] = v.strip().rstrip("；;，,。.")
                j += 1
                continue

            opt_match = re.match(r'^([A-Da-d])\s*[、.．:：\s]\s*(.*?)$', ol)
            if opt_match:
                options[opt_match.group(1).upper()] = opt_match.group(2).strip().rstrip("；;，,。.")
                j += 1
                continue

            # The next question starts: stop collecting for this one.
            if re.match(r'^\d+[、.．:：\s]', ol):
                break
            j += 1

        if answer and options:
            questions.append({
                "id": q_num,
                "type": "single",
                "stem": stem,
                "options": dict(sorted(options.items())),
                "answer": [answer],
            })
        i = j if j > i else i + 1
    return questions


def parse_huaxiu_block(lines: list, filename: str) -> list:
    """
    Parse a section written in the huaxiu format (simplified parser).

    The section is cut into per-question chunks at every line that starts with
    "<number><. ． or 、><whitespace>". In each chunk the first line supplies
    the question number and stem. Each following non-empty line is read as one of:

    * an answer line ("正确答案：2B", "答案：B", ...); only the letter is kept;
    * a digit-prefixed option ("2B：...", "4D．...", "4D、..."); an answer
      marker glued onto the option text is cut off, and an empty option is dropped;
    * a plain option ("A ...", "B．...").

    Chunks lacking either an answer or any option are skipped.

    Args:
        lines: raw lines of the section.
        filename: unused; kept for signature parity with parse_standard_block.

    Returns:
        A list of single-choice question dicts with options sorted by letter.
    """
    parsed = []

    for chunk in re.split(r'\n(?=\s*\d+[.．、]\s)', "\n".join(lines)):
        chunk = chunk.strip()
        if not chunk:
            continue

        head, *body = chunk.split("\n")
        header = re.match(r'^(\d+)[.．、]\s*(.*?)$', head.strip())
        if header is None:
            continue

        opts = {}
        letter = ""

        for raw_line in body:
            text = raw_line.strip()
            if not text:
                continue

            ans = re.match(r'(?:正确[答案]{0,2}|答|答案)[：:]?\s*(\d*[A-Da-d])', text)
            if ans:
                # group(1) is "<digits><letter>"; keep just the letter.
                letter = re.search(r'[A-Da-d]', ans.group(1)).group(0).upper()
                continue

            prefixed = re.match(r'^(\d*)([A-Da-d])\s*[、：:.．]\s*(.*?)$', text)
            if prefixed:
                body_text = prefixed.group(3).strip().rstrip("；;，,。.")
                # Cut off an answer marker that was glued onto the option text.
                body_text = re.sub(r'\d*[A-Da-d]正确答案[：:].*$', '', body_text).strip()
                if body_text:
                    opts[prefixed.group(2).upper()] = body_text
                continue

            plain = re.match(r'^([A-Da-d])\s*[、.．:：\s]\s*(.*?)$', text)
            if plain:
                opts[plain.group(1).upper()] = plain.group(2).strip().rstrip("；;，,。.")

        if letter and opts:
            parsed.append({
                "id": int(header.group(1)),
                "type": "single",
                "stem": header.group(2).strip(),
                "options": dict(sorted(opts.items())),
                "answer": [letter],
            })

    return parsed


def process_json_file(json_path: Path) -> bool:
    """
    Apply the in-place fixes to one JSON question file.

    For each question, in this order:
      1. strip dangling opening parens from the stem (fix_trailing_paren);
      2. strip a stray "E" marker from the stem (fix_e_in_stem);
      3. split merged option values (fix_merged_options), re-sorting the options;
      4. pass the answer through fix_answer_ref when answer and options are both present.

    The file is rewritten, and the module-level ``fixed_count`` is increased by
    the number of modified questions, only when at least one question changed.

    Returns:
        True if the file was modified, False otherwise (including when it has
        no "questions" or an empty list).
    """
    global fixed_count

    with open(json_path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    items = payload.get("questions", [])
    if not items:
        return False

    touched = 0
    for item in items:
        dirty = False

        stem_before = item.get("stem", "")
        stem_after = fix_trailing_paren(stem_before)
        if stem_after != stem_before:
            item["stem"] = stem_after
            dirty = True

        stem_no_e, had_e = fix_e_in_stem(item.get("stem", ""), item.get("options", {}))
        if had_e:
            item["stem"] = stem_no_e
            dirty = True

        opts_before = item.get("options", {})
        opts_after = fix_merged_options(opts_before)
        # fix_merged_options returns the same object when nothing changed.
        if opts_after is not opts_before:
            item["options"] = dict(sorted(opts_after.items()))
            dirty = True

        if item.get("answer") and item.get("options"):
            ans_after = fix_answer_ref(item["answer"], item["options"])
            if ans_after != item["answer"]:
                item["answer"] = ans_after
                dirty = True

        if dirty:
            touched += 1

    if not touched:
        return False

    fixed_count += touched
    with open(json_path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)
    return True


def _load_source_text(path: Path) -> "str | None":
    """Decode a source .txt file, returning None when no candidate encoding fits."""
    # "utf-8-sig" first: it reads BOM-less UTF-8 exactly like "utf-8" but also
    # drops a leading BOM, which would otherwise stay glued to the first
    # question number and keep it from matching.
    for enc in ("utf-8-sig", "gbk"):
        try:
            with open(path, "r", encoding=enc) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    return None


def _postprocess_questions(questions: list) -> None:
    """Apply the stem-paren and merged-option fixes to re-parsed questions, in place."""
    for q in questions:
        q["stem"] = fix_trailing_paren(q.get("stem", ""))
        q["options"] = dict(sorted(fix_merged_options(q.get("options", {})).items()))


def _run_phase1(json_files: list) -> None:
    """Phase 1: repair every existing JSON file in place via process_json_file()."""
    global fixed_files

    print("=" * 50)
    print("Phase 1: 后处理修复")
    print("=" * 50)
    for jf in json_files:
        before = fixed_count
        if process_json_file(jf):
            fixed_files += 1
            print(f"  修复 {jf.name}: {fixed_count - before} 处")

    print(f"\nPhase 1 完成: {fixed_files} 文件, {fixed_count} 处修复\n")


def _run_phase2() -> None:
    """Phase 2: re-parse 华秀题库.txt (mixed standard/huaxiu format) and overwrite its JSON."""
    print("=" * 50)
    print("Phase 2: 重新解析华秀题库")
    print("=" * 50)
    huaxiu_src = SRC_DIR / "华秀题库.txt"
    if not huaxiu_src.exists():
        return

    questions = reparse_huaxiu_file(huaxiu_src)
    if not questions:
        print("  华秀题库: 重新解析失败")
        return

    # This output replaces the file Phase 1 already repaired. Apply the same
    # post-processing as Phase 3, because the parsers do not clean stems
    # consistently; parse_huaxiu_block, for example, keeps dangling parens.
    _postprocess_questions(questions)

    meta = {
        "title": "华秀题库",
        "source": "华秀题库.txt",
        "totalQuestions": len(questions),
        "version": "1.1",
        "subject": "综合练习",
        "note": "re-parsed with mixed format support"
    }
    out_path = JSON_DIR / "华秀题库.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"meta": meta, "questions": questions}, f, ensure_ascii=False, indent=2)
    print(f"  华秀题库: 重新解析 {len(questions)} 题")

    # Quick quality check of the re-parsed output.
    no_opts = sum(1 for q in questions if len(q.get("options", {})) == 0)
    no_ans = sum(1 for q in questions if not q.get("answer"))
    print(f"  质量: 无选项={no_opts}, 无答案={no_ans}")


def _run_phase3() -> None:
    """Phase 3: re-parse the source files whose options were separated by U+00A0."""
    print(f"\n{'=' * 50}")
    print("Phase 3: 重新解析 \\xa0 分隔符文件")
    print("=" * 50)

    for src_name, json_name in [
        ("综合练习2.txt", "综合练习2.json"),
        ("解析2 盲人医疗按摩人员考试综合笔试模拟试题（二）.txt",
         "解析2 盲人医疗按摩人员考试综合笔试模拟试题（二）.json"),
    ]:
        src_path = SRC_DIR / src_name
        if not src_path.exists():
            continue

        raw = _load_source_text(src_path)
        if not raw:
            continue

        raw = re.sub(r'^•\s*', '', raw, flags=re.MULTILINE)
        questions = parse_standard_block(raw.split("\n"), src_name)
        if not questions:
            continue

        _postprocess_questions(questions)

        meta = {
            "title": src_name.replace(".txt", ""),
            "source": src_name,
            "totalQuestions": len(questions),
            "version": "1.1",
            "note": "re-parsed with \\xa0 fix"
        }
        out_path = JSON_DIR / json_name
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump({"meta": meta, "questions": questions}, f, ensure_ascii=False, indent=2)

        no_opts = sum(1 for q in questions if len(q.get("options", {})) < 3)
        no_ans = sum(1 for q in questions if not q.get("answer"))
        print(f"  {json_name}: {len(questions)} 题, 选项<3={no_opts}, 无答案={no_ans}")


def _question_has_issue(q: dict) -> bool:
    """Return True if *q* counts as "bad" in the final quality report.

    A question is bad when it has no answer, fewer than 3 options, a stem ending
    in a dangling "(" / "（", or an answer entry that is not an option key.
    """
    opts = q.get("options", {})
    return (
        not q.get("answer")
        or len(opts) < 3
        or q.get("stem", "").rstrip().endswith(("(", "（"))
        or any(a not in opts for a in q.get("answer", []))
    )


def _print_quality_report() -> None:
    """Print per-file and overall counts of questions that still have issues."""
    print(f"\n{'=' * 50}")
    print("最终质量报告")
    print("=" * 50)

    total_q = 0
    total_bad = 0
    for jf in sorted(JSON_DIR.glob("*.json")):
        if jf.name == "summary.json":
            continue
        with open(jf, encoding="utf-8") as f:
            qs = json.load(f).get("questions", [])
        total = len(qs)
        bad = sum(1 for q in qs if _question_has_issue(q))
        total_q += total
        total_bad += bad

        # Only files with problems are listed.
        if bad:
            pct = bad / total * 100
            status = " !!" if pct > 50 else " !"
            print(f"{status} {jf.name}: {total}题, 坏={bad} ({pct:.0f}%)")

    good = total_q - total_bad
    # Guard total_q == 0 (e.g. no JSON files) to avoid ZeroDivisionError.
    bad_pct = total_bad / total_q * 100 if total_q else 0.0
    print(f"\n总计: {total_q}题, 好={good} 坏={total_bad} ({bad_pct:.1f}%)")
    # Hard-coded "before" baseline; it is not recomputed by this script.
    print("修复前: 好=5577 坏=1617 (22.5%)")
    print(f"修复后: 好={good} 坏={total_bad} ({bad_pct:.1f}%)")


def main():
    """Run the three repair phases, then print the final quality report."""
    json_files = sorted(JSON_DIR.glob("*.json"))
    json_files = [f for f in json_files if f.name != "summary.json"]

    print(f"扫描 {len(json_files)} 个JSON文件...\n")

    _run_phase1(json_files)
    _run_phase2()
    _run_phase3()
    _print_quality_report()


if __name__ == "__main__":
    main()
