"""Serial (RS-232 / COM) transport running an ASTM exchange — for analyzers
like Maglumi / Mispa that talk ASTM over a COM port rather than TCP/HL7.

Bidirectional: when the analyzer sends a host query (Q record) with a tube
barcode, we ask the cloud which tests are ordered for it and reply with the
order (O records). When it sends results (R records) we enqueue them like the
TCP path. Frame-level ACK/NAK is handled via `astm.Receiver`.

pyserial is an optional dependency; if it's missing, the device surfaces a clear
error instead of crashing the whole middleware.
"""

from __future__ import annotations

import logging
import threading
import time
from typing import Any, Callable, Dict, List, Optional

from . import astm
from .buffer import Outbox
from .drivers import get_driver

# pyserial is optional (see module docstring). The constant lookups stay inside the
# try, and the handler catches Exception rather than only ImportError. So a `serial`
# module that imports but lacks these attributes also ends up with HAVE_SERIAL = False,
# instead of breaking import of this module.
try:
    import serial  # type: ignore
    # Config value -> pyserial constant maps for parity, stop bits and data bits.
    # They exist only when pyserial loaded; SerialDevice.start() checks HAVE_SERIAL
    # before it reads them.
    _PARITY = {"none": serial.PARITY_NONE, "odd": serial.PARITY_ODD, "even": serial.PARITY_EVEN}
    _STOP = {1: serial.STOPBITS_ONE, 2: serial.STOPBITS_TWO}
    _BYTES = {7: serial.SEVENBITS, 8: serial.EIGHTBITS}
    HAVE_SERIAL = True
except Exception:  # noqa: BLE001
    serial = None  # type: ignore
    HAVE_SERIAL = False

# NOTE(review): fixed logger name "serial" rather than __name__; confirm the app's
# logging configuration relies on this name before changing it.
log = logging.getLogger("serial")

# Order provider: (machine_id, barcode) -> list[analyzer_code]. Set by the app.
# An empty (or falsy) return is treated as "no order" for that barcode.
OrderProvider = Callable[[int, str], List[str]]


class SerialDevice:
    """Owns one COM port and the thread that runs its ASTM exchange.

    ``start()`` opens the port and spawns a daemon reader thread; ``stop()``
    signals that thread and closes the port. ``start()`` may be called again
    after ``stop()``: each run gets its own stop event and port handle, so a
    reader thread left over from a previous run never touches the new port.

    Each complete ASTM message is parsed by the configured driver and archived
    to the outbox as a ``comm_log`` entry. Host queries are then answered with an
    order (or a "no order" message), and results are enqueued on the outbox.
    """

    # ASTM E1381 / CLSI LIS1-A: a NAKed frame is retransmitted; the message is
    # aborted after six unsuccessful attempts.
    _MAX_FRAME_ATTEMPTS = 6

    def __init__(self, device_cfg: Any, outbox: Outbox, order_provider: Optional[OrderProvider] = None) -> None:
        self.cfg = device_cfg
        self.outbox = outbox
        self.driver = get_driver(device_cfg.driver)
        self.order_provider = order_provider
        self._port: "serial.Serial | None" = None
        self._thread: threading.Thread | None = None
        # Replaced by a fresh Event on every successful start(), so a reader
        # thread from an earlier run keeps watching its own (already set) event.
        self._stop = threading.Event()
        self._listening = False
        self._bind_error: str | None = None
        # Guards _stat: written by the reader thread, read by stats().
        self._lock = threading.Lock()
        self._stat: Dict[str, Any] = {
            "connections_open": 0, "messages": 0, "results": 0, "errors": 0,
            "last_message_at": None, "last_barcode": None, "last_peer": None, "last_error": None,
        }

    # ----- lifecycle ------------------------------------------------------
    @staticmethod
    def _resolve_com_port(cfg: Any) -> Optional[str]:
        """Return the COM port to open, or None if none is configured.

        A ``com_port`` in the optional ``cfg.serial`` sub-config takes precedence
        over ``cfg.com_port``. Empty values count as "not configured".
        """
        sub = getattr(cfg, "serial", None) or {}
        if isinstance(sub, dict):
            override = sub.get("com_port")
        else:
            override = getattr(sub, "com_port", None)
        return override or getattr(cfg, "com_port", None) or None

    def start(self) -> None:
        """Open the configured COM port and start the reader thread.

        Failures (pyserial missing, no port configured, port cannot be opened)
        are stored in ``bind_error`` and logged rather than raised. Calling
        ``start()`` while a run is active does nothing.
        """
        if self._thread is not None and self._thread.is_alive() and not self._stop.is_set():
            log.debug("[%s] already running", self.cfg.name)
            return
        if not HAVE_SERIAL:
            self._bind_error = "pyserial not installed (pip install pyserial)"
            log.error("[%s] %s", self.cfg.name, self._bind_error)
            return
        port_name = self._resolve_com_port(self.cfg)
        if not port_name:
            # serial.Serial(port=None) would return an *unopened* port and the
            # reader would then fail on every read, so refuse up front.
            self._bind_error = "no COM port configured"
            log.error("[%s] %s", self.cfg.name, self._bind_error)
            return
        try:
            port = serial.Serial(
                port=port_name,
                baudrate=int(self.cfg.baud_rate or 9600),
                bytesize=_BYTES.get(int(self.cfg.data_bits or 8), serial.EIGHTBITS),
                parity=_PARITY.get(str(self.cfg.parity or "none").lower(), serial.PARITY_NONE),
                stopbits=_STOP.get(int(self.cfg.stop_bits or 1), serial.STOPBITS_ONE),
                timeout=1,  # read() returns at least once a second so stop is noticed
            )
        except Exception as exc:  # noqa: BLE001
            self._bind_error = str(exc)
            log.error("[%s] cannot open %s — %s", self.cfg.name, port_name, exc)
            return
        stop = threading.Event()
        self._stop = stop
        self._port = port
        self._bind_error = None
        self._listening = True
        self._thread = threading.Thread(
            target=self._loop, args=(stop, port), name=f"serial-{self.cfg.name}", daemon=True,
        )
        self._thread.start()
        log.info("[%s] listening on %s @ %s [%s]", self.cfg.name, port_name,
                 self.cfg.baud_rate, self.cfg.driver)

    def stop(self) -> None:
        """Signal the reader thread to exit and close the port; safe to call repeatedly."""
        self._stop.set()
        self._listening = False
        port, self._port = self._port, None
        if port is not None:
            try:
                port.close()
            except Exception as exc:  # noqa: BLE001
                log.debug("[%s] error closing port: %s", self.cfg.name, exc)

    # ----- live status ----------------------------------------------------
    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of the counters plus static device info."""
        with self._lock:
            s = dict(self._stat)
        s.update({
            "name": self.cfg.name, "machine_id": self.cfg.machine_id, "driver": self.cfg.driver,
            "listen_host": "COM", "listen_port": self._resolve_com_port(self.cfg),
            "enabled": self.cfg.enabled, "listening": self._listening,
            "bind_error": self._bind_error, "relay": None,
        })
        return s

    def _bump(self, **changes: Any) -> None:
        """Add to the numeric counters and overwrite the other stat fields, under the lock."""
        with self._lock:
            for k, v in changes.items():
                if k in ("messages", "results", "errors", "connections_open"):
                    self._stat[k] = self._stat.get(k, 0) + v
                else:
                    self._stat[k] = v

    # ----- ASTM exchange --------------------------------------------------
    def _loop(self, stop: Optional[threading.Event] = None, port: Any = None) -> None:
        """Reader thread: feed bytes from *port* to an ASTM receiver until *stop* is set.

        *stop* and *port* default to the instance's current ones; ``start()``
        passes them explicitly so this run stays bound to its own event and
        handle. Errors are counted and logged instead of ending the thread.
        """
        stop = self._stop if stop is None else stop
        port = self._port if port is None else port
        if port is None:
            return
        rx = astm.Receiver()
        try:
            while not stop.is_set():
                try:
                    data = port.read(256)
                except Exception as exc:  # noqa: BLE001
                    if stop.is_set():
                        break  # stop() closed the port under us; not an error
                    self._bump(errors=1, last_error=str(exc))
                    log.warning("[%s] read error: %s", self.cfg.name, exc)
                    stop.wait(1)  # back off, but wake immediately on stop()
                    continue
                if not data:
                    continue
                try:
                    self._process_chunk(rx, port, data)
                except Exception as exc:  # noqa: BLE001
                    # Message-handling and ACK-write errors are caught inside
                    # _process_chunk; what lands here comes from the receiver.
                    # Restart from a clean receiver state.
                    self._bump(errors=1, last_error=f"receive: {exc}")
                    log.exception("[%s] ASTM receive error", self.cfg.name)
                    rx = astm.Receiver()
        finally:
            # Only the current run may report the device as no longer listening.
            if stop is self._stop:
                self._listening = False

    def _process_chunk(self, rx: Any, port: Any, data: bytes) -> None:
        """Feed *data* to *rx*, writing each frame response and handling each complete message."""
        for response, message in rx.feed(data):
            if response is not None:
                try:
                    port.write(bytes([response]))
                except Exception as exc:  # noqa: BLE001
                    self._bump(errors=1, last_error=f"ack write: {exc}")
                    log.warning("[%s] cannot write frame response: %s", self.cfg.name, exc)
            if message:
                try:
                    self._handle_message(message)
                except Exception as exc:  # noqa: BLE001
                    # One bad message must not stop the listener.
                    self._bump(errors=1, last_error=f"handle: {exc}")
                    log.exception("[%s] failed to handle ASTM message", self.cfg.name)

    def _handle_message(self, message: str) -> None:
        """Archive a complete ASTM message, then answer it (query) or enqueue it (results)."""
        self._bump(messages=1, last_message_at=time.time())
        parsed = self.driver.parse(message, self.cfg) or {}
        self._archive_raw(message, parsed)
        ptype = parsed.get("type")
        if ptype == "query":
            self._answer_query(parsed)
        elif ptype == "result":
            self._queue_results(parsed)

    def _archive_raw(self, message: str, parsed: Dict[str, Any]) -> None:
        """Best-effort: enqueue the full raw message as a ``comm_log`` audit entry."""
        try:
            summary = "%s | %s | %d result(s)" % (
                parsed.get("type"), parsed.get("barcode") or parsed.get("query_barcode") or "-",
                len(parsed.get("results") or []))
            self.outbox.enqueue(self.cfg.machine_id, {
                "_kind": "comm_log", "machine_id": self.cfg.machine_id,
                "raw_data": message, "parsed_summary": summary, "protocol": "astm",
            })
        except Exception as exc:  # noqa: BLE001
            # Archiving must not prevent the reply / result queueing that follows.
            log.warning("[%s] could not archive raw message: %s", self.cfg.name, exc)

    def _answer_query(self, parsed: Dict[str, Any]) -> None:
        """Reply to a host query with the ordered tests for its barcode, or a no-order message."""
        barcode = parsed.get("query_barcode") or ""
        log.info("[%s] host query for %s", self.cfg.name, barcode or "?")
        codes: List[str] = []
        if self.order_provider and barcode:
            try:
                codes = self.order_provider(self.cfg.machine_id, barcode) or []
            except Exception as exc:  # noqa: BLE001
                # codes stays empty, so a no-order reply is still sent below.
                self._bump(errors=1, last_error=f"order lookup: {exc}")
                log.warning("[%s] order lookup failed for %s: %s", self.cfg.name, barcode, exc)
        text = self.driver.build_order(barcode, codes) if codes else self.driver.build_no_order(barcode)
        self._send_message(text)

    def _queue_results(self, parsed: Dict[str, Any]) -> None:
        """Enqueue one outbox record per parsed result, tagged with the message's barcode."""
        barcode = parsed.get("barcode") or ""
        results = parsed.get("results") or []
        for r in results:
            self.outbox.enqueue(self.cfg.machine_id, {
                "machine_id": self.cfg.machine_id, "barcode": barcode,
                "test_code": r["test_code"], "raw_result": r["raw_result"],
                "parsed_result": r["raw_result"], "unit": r.get("unit") or None,
                "raw_data": {"flag": r.get("flag"), "value_type": r.get("value_type"), "astm": r.get("raw_obx")},
            })
        self._bump(results=len(results), last_barcode=barcode or None)
        log.info("[%s] queued %d result(s) for %s", self.cfg.name, len(results), barcode or "?")

    def _send_message(self, text: str) -> None:
        """Send one ASTM message (ENQ → frames, each ACKed → EOT) over the port.

        Called from the reader thread (via _handle_message), so the replies read
        here are not consumed by _loop concurrently. A NAKed frame is
        retransmitted up to ``_MAX_FRAME_ATTEMPTS`` times. A timeout or exhausted
        retries abort the message. On every non-exceptional path EOT is sent to
        release the line.
        """
        port = self._port
        if not port:
            return
        try:
            port.write(bytes([astm.ENQ]))
            if not self._wait_for(port, astm.ACK):
                log.warning("[%s] no ACK after ENQ", self.cfg.name)
                port.write(bytes([astm.EOT]))
                return
            # all() short-circuits: stop at the first frame that is not accepted.
            delivered = all(self._send_frame(port, frame) for frame in astm.build_frames(text))
            port.write(bytes([astm.EOT]))
            if delivered:
                log.info("[%s] sent order reply", self.cfg.name)
        except Exception as exc:  # noqa: BLE001
            self._bump(errors=1, last_error=f"send: {exc}")
            log.warning("[%s] send failed: %s", self.cfg.name, exc)

    def _send_frame(self, port: Any, frame: Any) -> bool:
        """Write one frame, retransmitting on NAK; return True once it is ACKed."""
        for attempt in range(1, self._MAX_FRAME_ATTEMPTS + 1):
            port.write(frame)
            reply = self._await_ack(port)
            if reply is True:
                return True
            if reply is None:
                log.warning("[%s] no ACK after frame", self.cfg.name)
                return False
            log.warning("[%s] NAK for frame (attempt %d/%d)",
                        self.cfg.name, attempt, self._MAX_FRAME_ATTEMPTS)
        log.warning("[%s] frame rejected %d times; aborting message",
                    self.cfg.name, self._MAX_FRAME_ATTEMPTS)
        return False

    @staticmethod
    def _await_ack(port: "serial.Serial", timeout: float = 5.0) -> Optional[bool]:
        """Wait for the reply to a frame: True on ACK, False on NAK, None on timeout.

        Other bytes are ignored, as in `_wait_for`.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            b = port.read(1)
            if not b:
                continue
            if b[0] == astm.ACK:
                return True
            if b[0] == astm.NAK:
                return False
        return None

    @staticmethod
    def _wait_for(port: "serial.Serial", want: int, timeout: float = 5.0) -> bool:
        """Return True once byte *want* is read; False on NAK or after *timeout* seconds."""
        # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            b = port.read(1)
            if not b:
                continue
            if b[0] == want:
                return True
            if b[0] == astm.NAK:
                return False
        return False
