"""Confusion group resolution for similar single-byte encodings.

At runtime, loads pre-computed distinguishing byte maps from confusion.bin
and uses them to resolve statistical scoring ties between similar encodings.

Build-time computation (``compute_confusion_groups``, ``compute_distinguishing_maps``,
``serialize_confusion_data``) lives in ``scripts/confusion_training.py``.
"""

from __future__ import annotations

import importlib.resources
import struct
import threading
import warnings

from chardet.models import (
    NON_ASCII_BIGRAM_WEIGHT,
    BigramProfile,
    get_enc_index,
    score_with_profile,
)
from chardet.pipeline import DetectionResult

# Type alias for the distinguishing map structure:
# Maps (enc_a, enc_b) -> (distinguishing_byte_set, {byte_val: (cat_a, cat_b)})
DistinguishingMaps = dict[
    tuple[str, str],
    tuple[frozenset[int], dict[int, tuple[str, str]]],
]

# uint8 -> Unicode general category, inverse of the mapping in
# scripts/confusion_training.py used at serialization time.
_INT_TO_CATEGORY: dict[int, str] = {
    0: "Lu",
    1: "Ll",
    2: "Lt",
    3: "Lm",
    4: "Lo",
    5: "Mn",
    6: "Mc",
    7: "Me",
    8: "Nd",
    9: "Nl",
    10: "No",
    11: "Pc",
    12: "Pd",
    13: "Ps",
    14: "Pe",
    15: "Pi",
    16: "Pf",
    17: "Po",
    18: "Sm",
    19: "Sc",
    20: "Sk",
    21: "So",
    22: "Zs",
    23: "Zl",
    24: "Zp",
    25: "Cc",
    26: "Cf",
    27: "Cs",
    28: "Co",
    29: "Cn",
}

_CONFUSION_CACHE: DistinguishingMaps | None = None
_CONFUSION_CACHE_LOCK = threading.Lock()


def deserialize_confusion_data_from_bytes(data: bytes) -> DistinguishingMaps:
    """Load confusion group data from raw bytes.

    :param data: The raw binary content of a confusion.bin file.
    :returns: A :data:`DistinguishingMaps` dictionary keyed by encoding pairs.
    """
    result: DistinguishingMaps = {}
    offset = 0
    (num_pairs,) = struct.unpack_from("!H", data, offset)
    offset += 2

    for _ in range(num_pairs):
        (name_a_len,) = struct.unpack_from("!B", data, offset)
        offset += 1
        name_a = data[offset : offset + name_a_len].decode("utf-8")
        offset += name_a_len

        (name_b_len,) = struct.unpack_from("!B", data, offset)
        offset += 1
        name_b = data[offset : offset + name_b_len].decode("utf-8")
        offset += name_b_len

        (num_diffs,) = struct.unpack_from("!B", data, offset)
        offset += 1

        diff_bytes_list: list[int] = []
        categories: dict[int, tuple[str, str]] = {}
        for _ in range(num_diffs):
            bv, cat_a_int, cat_b_int = struct.unpack_from("!BBB", data, offset)
            offset += 3
            diff_bytes_list.append(bv)
            categories[bv] = (
                _INT_TO_CATEGORY.get(cat_a_int, "Cn"),
                _INT_TO_CATEGORY.get(cat_b_int, "Cn"),
            )
        result[(name_a, name_b)] = (frozenset(diff_bytes_list), categories)

    return result


def load_confusion_data() -> DistinguishingMaps:
    """Load confusion group data from the bundled confusion.bin file.

    :returns: A :data:`DistinguishingMaps` dictionary keyed by encoding pairs.
    """
    global _CONFUSION_CACHE  # noqa: PLW0603
    if _CONFUSION_CACHE is not None:
        return _CONFUSION_CACHE
    with _CONFUSION_CACHE_LOCK:
        if _CONFUSION_CACHE is not None:  # pragma: no cover - double-checked locking
            return _CONFUSION_CACHE
        ref = importlib.resources.files("chardet.models").joinpath("confusion.bin")
        raw = ref.read_bytes()
        if not raw:
            warnings.warn(
                "chardet confusion.bin is empty — confusion resolution disabled; "
                "reinstall chardet to fix",
                RuntimeWarning,
                stacklevel=2,
            )
            _CONFUSION_CACHE = {}
            return _CONFUSION_CACHE
        try:
            _CONFUSION_CACHE = deserialize_confusion_data_from_bytes(raw)
        except (struct.error, UnicodeDecodeError) as e:
            msg = f"corrupt confusion.bin: {e}"
            raise ValueError(msg) from e
        return _CONFUSION_CACHE


# Unicode general category preference scores for voting resolution.
# Higher scores indicate more linguistically meaningful characters.
_CATEGORY_PREFERENCE: dict[str, int] = {
    "Lu": 10,
    "Ll": 10,
    "Lt": 10,
    "Lm": 9,
    "Lo": 9,
    "Nd": 8,
    "Nl": 7,
    "No": 7,
    "Pc": 6,
    "Pd": 6,
    "Ps": 6,
    "Pe": 6,
    "Pi": 6,
    "Pf": 6,
    "Po": 6,
    "Sc": 5,
    "Sm": 5,
    "Sk": 4,
    "So": 4,
    "Zs": 3,
    "Zl": 3,
    "Zp": 3,
    "Cf": 2,
    "Cc": 1,
    "Co": 1,
    "Cs": 0,
    "Cn": 0,
    "Mn": 5,
    "Mc": 5,
    "Me": 5,
}


def resolve_by_category_voting(
    data: bytes,
    enc_a: str,
    enc_b: str,
    diff_bytes: frozenset[int],
    categories: dict[int, tuple[str, str]],
) -> str | None:
    """Resolve between two encodings using Unicode category voting.

    For each distinguishing byte present in the data, compare the Unicode
    general category under each encoding. The encoding whose interpretation
    has the higher category preference score gets a vote. The encoding with
    more votes wins.

    :param data: The raw byte data to examine.
    :param enc_a: First encoding name.
    :param enc_b: Second encoding name.
    :param diff_bytes: Byte values where the two encodings differ.
    :param categories: Mapping of byte value to ``(cat_a, cat_b)`` Unicode
        general category pairs.
    :returns: The winning encoding name, or ``None`` if tied.
    """
    votes_a = 0
    votes_b = 0
    relevant = frozenset(data) & diff_bytes
    if not relevant:
        return None
    for bv in relevant:
        cat_a, cat_b = categories[bv]
        pref_a = _CATEGORY_PREFERENCE.get(cat_a, 0)
        pref_b = _CATEGORY_PREFERENCE.get(cat_b, 0)
        if pref_a > pref_b:
            votes_a += pref_a - pref_b
        elif pref_b > pref_a:
            votes_b += pref_b - pref_a
    if votes_a > votes_b:
        return enc_a
    if votes_b > votes_a:
        return enc_b
    return None


def resolve_by_bigram_rescore(
    data: bytes,
    enc_a: str,
    enc_b: str,
    diff_bytes: frozenset[int],
) -> str | None:
    """Resolve between two encodings by re-scoring only distinguishing bigrams.

    Builds a focused bigram profile containing only bigrams where at least one
    byte is a distinguishing byte, then scores both encodings against their
    best language model.

    :param data: The raw byte data to examine.
    :param enc_a: First encoding name.
    :param enc_b: Second encoding name.
    :param diff_bytes: Byte values where the two encodings differ.
    :returns: The winning encoding name, or ``None`` if tied.
    """
    if len(data) < 2:
        return None

    freq: dict[int, int] = {}
    for i in range(len(data) - 1):
        b1 = data[i]
        b2 = data[i + 1]
        if b1 not in diff_bytes and b2 not in diff_bytes:
            continue
        idx = (b1 << 8) | b2
        weight = NON_ASCII_BIGRAM_WEIGHT if (b1 > 0x7F or b2 > 0x7F) else 1
        freq[idx] = freq.get(idx, 0) + weight

    if not freq:
        return None

    profile = BigramProfile.from_weighted_freq(freq)

    index = get_enc_index()

    best_a = 0.0
    variants_a = index.get(enc_a)
    if variants_a:
        for _, model, model_key in variants_a:
            s = score_with_profile(profile, model, model_key)
            best_a = max(best_a, s)

    best_b = 0.0
    variants_b = index.get(enc_b)
    if variants_b:
        for _, model, model_key in variants_b:
            s = score_with_profile(profile, model, model_key)
            best_b = max(best_b, s)

    if best_a > best_b:
        return enc_a
    if best_b > best_a:
        return enc_b
    return None


def _find_pair_key(
    maps: DistinguishingMaps,
    enc_a: str,
    enc_b: str,
) -> tuple[str, str] | None:
    """Find the canonical key for a pair of encodings in the confusion maps."""
    if (enc_a, enc_b) in maps:
        return (enc_a, enc_b)
    if (enc_b, enc_a) in maps:
        return (enc_b, enc_a)
    return None


def resolve_confusion_groups(
    data: bytes,
    results: list[DetectionResult],
) -> list[DetectionResult]:
    """Resolve confusion between similar encodings in the top results.

    Compares the top two results. If they form a known confusion pair,
    it determines which encoding should win by checking the
    resolve_by_bigram_rescore and resolve_by_category_voting tie-breakers
    and giving precedence to bigram re-scoring when they disagree.

    :param data: The raw byte data to examine.
    :param results: Detection results sorted by confidence descending.
    :returns: A reordered list of :class:`DetectionResult` with the winner first.
    """
    if len(results) < 2:
        return results

    top = results[0]
    second = results[1]
    if top.encoding is None or second.encoding is None:
        return results

    maps = load_confusion_data()
    pair_key = _find_pair_key(maps, top.encoding, second.encoding)
    if pair_key is None:
        return results

    diff_bytes, categories = maps[pair_key]
    enc_a, enc_b = pair_key

    cat_winner = resolve_by_category_voting(data, enc_a, enc_b, diff_bytes, categories)
    bigram_winner = resolve_by_bigram_rescore(data, enc_a, enc_b, diff_bytes)
    winner = bigram_winner if bigram_winner is not None else cat_winner

    if winner is None or winner == top.encoding:
        return results

    return [second, top, *results[2:]]
