"""ASTM E1381 (low-level framing) + E1394 (record) helpers.

ASTM is the protocol Maglumi / Mispa and most serial analyzers speak. It is a
simplex, stop-and-wait, frame-by-frame ACKed protocol:

    sender → ENQ
    recv   → ACK
    sender → <STX> FN text <ETX> C1 C2 <CR><LF>     (a frame)
    recv   → ACK   (or NAK to retransmit)
    ...    (more frames; FN runs 1..7, then wraps to 0, i.e. modulo 8)
    sender → EOT

This module provides the codec (control bytes, checksum, frame builder) and a
`Receiver` state machine that consumes inbound bytes, validates each frame,
tells the caller which control byte to send back, and assembles the full
message (records joined by CR). Higher layers (drivers) parse the E1394 records.
"""

from __future__ import annotations

from typing import List, Optional, Tuple

# ASTM E1381 transmission-control characters (plain ASCII control codes).
STX = 0x02  # start of text: opens every frame
ETX = 0x03  # end of text: closes the last frame of a message
EOT = 0x04  # end of transmission: sender ends the transfer
ENQ = 0x05  # enquiry: sender asks to start a transfer
ACK = 0x06  # acknowledge: frame / enquiry accepted
NAK = 0x15  # negative acknowledge: ask the sender to retransmit
ETB = 0x17  # end of transmission block: closes an intermediate frame
LF = 0x0A
CR = 0x0D

# Two-byte trailer that ends every frame after the checksum.
CRLF = b"\r\n"


def checksum(frame_body: bytes) -> bytes:
    """Return the two-character ASTM checksum of *frame_body*.

    *frame_body* should cover every byte from the frame number up to and
    including the ETX/ETB terminator. The checksum is the sum of those bytes
    modulo 256, rendered as two uppercase hexadecimal ASCII characters.
    """
    low_byte = sum(frame_body) % 256
    return format(low_byte, "02X").encode("ascii")


def build_frames(text: str, max_len: int = 240) -> List[bytes]:
    """Encode an ASTM message body (records joined by CR) into framed bytes.

    Each frame is ``STX <FN> <data> ETX|ETB <C1><C2> CR LF``. The encoded text
    is cut into chunks of at most ``max_len`` bytes. Every chunk except the
    last is terminated with ETB and the last with ETX. An empty ``text`` still
    yields one (empty) ETX frame.

    Frame numbers follow E1381: the first frame is ``1``, each following frame
    increments by one, and after ``7`` the number wraps to ``0`` (modulo 8).

    Args:
        text: message text. Characters outside Latin-1 are replaced by ``?``.
        max_len: maximum number of text bytes per frame (default 240).

    Returns:
        The frames, in transmission order.

    Raises:
        ValueError: if ``max_len`` is less than 1.
    """
    if max_len < 1:
        # A zero step would crash ``range``; a negative one would silently
        # drop all text and emit a single empty frame.
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    raw = text.encode("latin-1", errors="replace")
    chunks = [raw[i:i + max_len] for i in range(0, len(raw), max_len)] or [b""]
    last_idx = len(chunks) - 1
    frames: List[bytes] = []
    for idx, chunk in enumerate(chunks):
        fn = (idx + 1) % 8  # 1, 2, ..., 7, 0, 1, ...
        term = ETX if idx == last_idx else ETB
        body = str(fn).encode("ascii") + chunk + bytes([term])
        frames.append(bytes([STX]) + body + checksum(body) + CRLF)
    return frames


class Receiver:
    """Stateful inbound-byte consumer for one ASTM transfer.

    Feed bytes via `feed`; it returns a list of (response_byte, message) tuples:
    `response_byte` is the control byte to write back (ACK/NAK/None) and
    `message` is the assembled record text when an EOT completes a transfer
    (else None).

    A record may span several frames: every frame but the last ends with ETB
    and the final one with ETX. The pieces are concatenated (no separator)
    into one record. A single trailing CR on the completed record (the record
    terminator) is dropped, so that joining records with CR does not create
    empty records. A record that is still open (ETB seen, no ETX yet) when
    EOT arrives is discarded rather than delivered truncated.

    NOTE(review): retransmitted frames (same frame number as the previously
    accepted frame, resent because our ACK was lost) are not de-duplicated
    here; E1381 asks the receiver to ACK and discard them.
    """

    def __init__(self) -> None:
        self._buf = bytearray()          # bytes of the frame being read, after its STX
        self._records: List[str] = []    # completed records of the current transfer
        self._partial: List[str] = []    # pieces of a record continued by ETB frames
        self._in_frame = False

    def feed(self, data: bytes) -> List[Tuple[Optional[int], Optional[str]]]:
        """Consume `data` and return, in order, one tuple per completed event.

        ENQ -> (ACK, None); a valid frame -> (ACK, None); a frame failing
        validation -> (NAK, None); EOT -> (None, message or None).
        """
        out: List[Tuple[Optional[int], Optional[str]]] = []
        for b in data:
            if self._in_frame and b in (STX, ENQ, EOT):
                # These control bytes cannot occur in frame text or in the hex
                # checksum, so the current frame was cut short (e.g. its CR LF
                # was lost). Drop it and treat the byte as a fresh event so the
                # receiver cannot get stuck swallowing everything into _buf.
                self._in_frame = False

            if not self._in_frame:
                if b == ENQ:
                    self._reset_transfer()
                    out.append((ACK, None))
                elif b == STX:
                    self._in_frame = True
                    self._buf = bytearray()
                elif b == EOT:
                    out.append((None, self._finish_transfer()))
                # ignore stray CR/LF/ACK between frames
                continue

            # inside a frame: collect until CR LF after the checksum
            self._buf.append(b)
            if b == LF and len(self._buf) >= 2 and self._buf[-2] == CR:
                self._in_frame = False
                out.append(self._accept_frame(bytes(self._buf)))
        return out

    def _reset_transfer(self) -> None:
        """Forget every record, complete or partial, of the current transfer."""
        self._records = []
        self._partial = []

    def _finish_transfer(self) -> Optional[str]:
        """Handle EOT: return the CR-joined records (None if none) and reset."""
        # Any open self._partial is dropped on purpose: its closing ETX frame
        # never arrived, so delivering it would hand drivers a truncated record.
        msg = "\r".join(self._records) if self._records else None
        self._reset_transfer()
        return msg

    def _accept_frame(self, frame: bytes) -> Tuple[Optional[int], Optional[str]]:
        """Validate one complete frame, store its text, and return ACK or NAK."""
        parsed = self._parse_frame(frame)
        if parsed is None:
            return NAK, None
        text, is_last = parsed
        self._partial.append(text)
        if is_last:
            record = "".join(self._partial)
            self._partial = []
            if record.endswith("\r"):
                record = record[:-1]  # record terminator; CR is re-added by the join
            if record:
                self._records.append(record)
        return ACK, None

    @staticmethod
    def _parse_frame(frame: bytes) -> Optional[Tuple[str, bool]]:
        """Validate <FN data ETX/ETB C1C2 CR LF>.

        Returns (data_text, is_last), where is_last is True for an ETX frame
        and False for an ETB (continued) frame. Returns None if the frame is
        malformed or its checksum does not match.
        """
        # The first ETX/ETB ends the text (control bytes are not valid data).
        term_idx = next((i for i, ch in enumerate(frame) if ch in (ETX, ETB)), None)
        # Need the FN byte before the terminator and two checksum chars after it.
        if term_idx is None or term_idx < 1 or term_idx + 2 >= len(frame):
            return None
        body = frame[:term_idx + 1]               # FN .. ETX/ETB inclusive
        got = frame[term_idx + 1:term_idx + 3]    # two checksum chars
        # checksum() emits uppercase hex; tolerate lowercase from lax senders.
        if checksum(body) != got.upper():
            return None
        data = frame[1:term_idx]                  # skip FN, exclude terminator
        return data.decode("latin-1"), frame[term_idx] == ETX

    @staticmethod
    def _decode_frame(frame: bytes) -> Tuple[bool, str]:
        """Validate a frame and return (ok, data_text); kept for compatibility."""
        parsed = Receiver._parse_frame(frame)
        if parsed is None:
            return False, ""
        return True, parsed[0]
