"""TCP server that plays the analyzer's "LIS workstation" — one per device.

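Listens on the configured port, accepts the analyzer connection, reads
MLLP-framed HL7 messages, hands each to the device driver, enqueues the parsed
results to the outbox, and (optionally) replies with an HL7 ACK. When a relay
is configured, the raw bytes are also forwarded to a downstream host, and that
host's replies (including its ACKs) are passed back to the analyzer.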

Each server keeps thread-safe runtime counters so the web admin UI can show
live status (connections, messages, results, last barcode, errors).
"""

from __future__ import annotations

import logging
import socket
import socketserver
import threading
import time
from typing import Any, Dict

from . import mllp
from .buffer import Outbox
from .drivers import get_driver

log = logging.getLogger("server")


class DeviceServer:
    """Threaded TCP listener for a single analyzer device.

    Binds ``cfg.listen_host:cfg.listen_port``, hands every HL7 message received
    on an accepted connection to the driver named by ``cfg.driver``, and queues
    the outcome on the shared outbox. Runtime counters live in a lock-protected
    dict and are exposed through :meth:`stats`.
    """

    # ``_stat`` keys that ``_bump`` adds to; every other key is overwritten.
    _COUNTERS = frozenset({"messages", "results", "errors", "connections_open"})

    def __init__(self, device_cfg: Any, outbox: Outbox) -> None:
        self.cfg = device_cfg
        self.outbox = outbox
        self.driver = get_driver(device_cfg.driver)
        self._server: _ThreadingTCPServer | None = None
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()  # guards self._stat
        self._listening = False
        self._bind_error: str | None = None
        self._stat: Dict[str, Any] = {
            "connections_open": 0,
            "messages": 0,
            "results": 0,
            "errors": 0,
            "last_message_at": None,  # epoch seconds
            "last_barcode": None,
            "last_peer": None,
            "last_error": None,
        }

    # ----- lifecycle ------------------------------------------------------
    def start(self) -> None:
        """Bind the listening socket and serve it on a daemon thread.

        A bind failure is logged and recorded as ``bind_error`` (see
        :meth:`stats`) instead of being raised. Calling this while already
        serving is a no-op. Otherwise a second bind of our own port would
        record a misleading ``bind_error`` while the first server keeps running.
        """
        if self._server is not None:
            log.warning("'%s' is already listening — start() ignored", self.cfg.name)
            return
        handler = _make_handler(self)
        try:
            server = _ThreadingTCPServer((self.cfg.listen_host, self.cfg.listen_port), handler)
        except OSError as exc:
            self._bind_error = str(exc)
            log.error("could not bind '%s' on %s:%s — %s", self.cfg.name,
                      self.cfg.listen_host, self.cfg.listen_port, exc)
            return
        self._server = server
        self._thread = threading.Thread(
            target=server.serve_forever, name=f"srv-{self.cfg.name}", daemon=True
        )
        self._thread.start()
        self._listening = True
        self._bind_error = None
        log.info(
            "listening for '%s' (machine_id=%s) on %s:%s [%s]",
            self.cfg.name, self.cfg.machine_id, self.cfg.listen_host,
            self.cfg.listen_port, self.cfg.driver,
        )

    def stop(self) -> None:
        """Stop accepting connections and release the listening socket.

        Safe to call when :meth:`start` failed or was never called. Handler
        threads of connections that are still open are not joined here.
        """
        self._listening = False
        server, self._server = self._server, None
        if server is not None:
            server.shutdown()  # returns once serve_forever() has exited
            server.server_close()
        if self._thread is not None:
            self._thread.join(timeout=5)
            self._thread = None

    # ----- live status ----------------------------------------------------
    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of the runtime counters merged with config fields."""
        with self._lock:
            s = dict(self._stat)
        s.update({
            "name": self.cfg.name,
            "machine_id": self.cfg.machine_id,
            "driver": self.cfg.driver,
            "listen_host": self.cfg.listen_host,
            "listen_port": self.cfg.listen_port,
            "enabled": self.cfg.enabled,
            "listening": self._listening,
            "bind_error": self._bind_error,
            "relay": (f"{self.cfg.forward_host}:{self.cfg.forward_port}"
                      if self.cfg.relay_enabled else None),
        })
        return s

    def _bump(self, **changes: Any) -> None:
        """Atomically apply *changes* to ``_stat``.

        Keys in ``_COUNTERS`` are incremented by the given amount. Any other key
        is set to the given value.
        """
        with self._lock:
            for key, val in changes.items():
                if key in self._COUNTERS:
                    self._stat[key] = self._stat.get(key, 0) + val
                else:
                    self._stat[key] = val

    # ----- message handling ----------------------------------------------
    def handle_message(self, raw: str) -> str | None:
        """Process one de-framed HL7 message.

        Parses *raw* with the device driver and archives it as a ``comm_log``
        record. For ``result`` messages it queues one outbox row per result.

        Returns:
            The ACK built by the driver when ``cfg.send_ack`` is set and
            building it succeeds; otherwise ``None``.

        Raises:
            Whatever ``driver.parse`` raises for an unparseable message.
        """
        parsed = self.driver.parse(raw, self.cfg)
        ptype = parsed.get("type")
        self._bump(messages=1, last_message_at=time.time())
        self._archive_raw(raw, parsed)

        if ptype == "result":
            self._queue_results(parsed)
        elif ptype == "query":
            log.info(
                "[%s] received query for barcode %s (bidirectional handling: app)",
                self.cfg.name,
                parsed.get("barcode") or parsed.get("query_barcode") or "?",
            )
        else:
            log.debug("[%s] non-result message ignored (type=%s)", self.cfg.name, ptype)

        if not self.cfg.send_ack:
            return None
        try:
            return self.driver.build_ack(raw)
        except Exception as exc:  # noqa: BLE001 - ACK is best-effort
            log.warning("[%s] could not build ACK: %s", self.cfg.name, exc)
            return None

    def _archive_raw(self, raw: str, parsed: Dict[str, Any]) -> None:
        """Queue the FULL raw message as a ``comm_log`` audit record.

        The record goes through the same outbox as results, so it survives an
        offline spell. This is best-effort: a failure is counted and logged but
        never raised, so it cannot block result processing.
        """
        try:
            summary = "%s | %s | %d result(s)" % (
                parsed.get("type"),
                parsed.get("barcode") or parsed.get("query_barcode") or "-",
                len(parsed.get("results") or []),
            )
            self.outbox.enqueue(self.cfg.machine_id, {
                "_kind": "comm_log", "machine_id": self.cfg.machine_id,
                "raw_data": raw, "parsed_summary": summary, "protocol": "hl7",
            })
        except Exception as exc:  # noqa: BLE001 - audit copy is best-effort
            self._bump(errors=1, last_error=f"comm_log enqueue failed: {exc}")
            log.warning("[%s] could not archive raw message: %s", self.cfg.name, exc)

    def _queue_results(self, parsed: Dict[str, Any]) -> None:
        """Enqueue one outbox row per parsed result of a ``result`` message.

        A message without a barcode is skipped, as its warning says, rather
        than being queued under an empty barcode. Its raw text is still
        archived by ``_archive_raw``.
        """
        barcode = parsed.get("barcode") or ""
        results = parsed.get("results") or []
        if not barcode:
            log.warning("[%s] result message with no barcode — skipped", self.cfg.name)
            return
        for r in results:
            self.outbox.enqueue(
                self.cfg.machine_id,
                {
                    "machine_id": self.cfg.machine_id,
                    "barcode": barcode,
                    "test_code": r["test_code"],
                    "raw_result": r["raw_result"],
                    "parsed_result": r["raw_result"],
                    "unit": r.get("unit") or None,
                    "raw_data": {
                        "flag": r.get("flag"),
                        "value_type": r.get("value_type"),
                        "hl7_obx": r.get("raw_obx"),
                    },
                },
            )
        self._bump(results=len(results), last_barcode=barcode)
        log.info(
            "[%s] queued %d result(s) for barcode %s",
            self.cfg.name, len(results), barcode,
        )


class _ThreadingTCPServer(socketserver.ThreadingTCPServer):
    allow_reuse_address = True
    daemon_threads = True


def _pump(src: socket.socket, dst: socket.socket, stop: threading.Event) -> None:
    """Copy bytes src→dst until either side closes (used for downstream→analyzer)."""
    try:
        while not stop.is_set():
            data = src.recv(8192)
            if not data:
                break
            dst.sendall(data)
    except OSError:
        pass
    finally:
        stop.set()


def _close_socket(sock: socket.socket) -> None:
    """Shut down and close *sock*, ignoring errors from an already-dead peer.

    ``shutdown()`` comes before ``close()`` because closing alone does not
    reliably wake another thread blocked in ``recv()`` on the same socket
    (the relay pump). After a shutdown, that ``recv()`` returns EOF.
    """
    try:
        sock.shutdown(socket.SHUT_RDWR)
    except OSError:
        pass  # already disconnected
    try:
        sock.close()
    except OSError:
        pass


def _make_handler(device: DeviceServer) -> type[socketserver.BaseRequestHandler]:
    """Return a request-handler class whose connections feed *device*.

    Each accepted analyzer connection is read in chunks. Every complete MLLP
    message goes to ``device.handle_message``.

    When ``cfg.relay_enabled`` is set, the raw bytes are also forwarded to
    ``forward_host:forward_port`` and the downstream's replies are pumped back
    to the analyzer. In that mode the downstream (not us) answers with ACKs.
    """

    class _Handler(socketserver.BaseRequestHandler):
        def handle(self) -> None:
            peer = f"{self.client_address[0]}:{self.client_address[1]}"
            device._bump(connections_open=1, last_peer=peer)
            log.info("[%s] connection from %s", device.cfg.name, peer)

            # Set by the relay pump when the downstream hangs up or fails.
            relay_closed = threading.Event()
            down: socket.socket | None = None
            try:
                down = self._open_relay(relay_closed)
                reader = mllp.MllpReader()
                self.request.settimeout(300)
                while True:
                    # While relaying, mirror a downstream hang-up to the
                    # analyzer. After we detached a failed relay (down is None),
                    # keep tapping instead.
                    if down is not None and relay_closed.is_set():
                        break
                    chunk = self.request.recv(8192)
                    if not chunk:
                        break
                    if down is not None:
                        down = self._forward(down, chunk)
                    for message in reader.feed(chunk):
                        self._dispatch(message, relaying=down is not None)
            except (socket.timeout, ConnectionError) as exc:
                log.info("[%s] connection closed (%s)", device.cfg.name, exc)
            except Exception as exc:  # noqa: BLE001
                device._bump(errors=1, last_error=str(exc))
                log.exception("[%s] handler error: %s", device.cfg.name, exc)
            finally:
                relay_closed.set()
                if down is not None:
                    _close_socket(down)
                device._bump(connections_open=-1)
                log.info("[%s] disconnected %s", device.cfg.name, peer)

        def _open_relay(self, closed: threading.Event) -> socket.socket | None:
            """Connect to the downstream relay and start pumping its replies back.

            Returns the connected socket. Returns ``None`` when no relay is
            configured, or when the downstream is unreachable; in that case the
            failure is counted and logged and the session runs tap-only.
            """
            cfg = device.cfg
            if not cfg.relay_enabled:
                return None
            try:
                down = socket.create_connection((cfg.forward_host, cfg.forward_port), timeout=10)
            except OSError as exc:
                # Downstream unreachable — keep tapping (don't lose results),
                # but surface the failure.
                device._bump(errors=1, last_error=f"relay connect failed: {exc}")
                log.error("[%s] relay connect to %s:%s failed: %s — tapping only",
                          cfg.name, cfg.forward_host, cfg.forward_port, exc)
                return None
            # create_connection() leaves its timeout on the socket. Clear it:
            # otherwise an idle downstream makes the pump's recv() time out,
            # which ends the whole analyzer session after the next chunk.
            down.settimeout(None)
            threading.Thread(target=_pump, args=(down, self.request, closed),
                             name=f"relay-{cfg.name}", daemon=True).start()
            log.info("[%s] relaying to %s:%s", cfg.name, cfg.forward_host, cfg.forward_port)
            return down

        def _forward(self, down: socket.socket, chunk: bytes) -> socket.socket | None:
            """Send *chunk* downstream; on failure detach the relay and return ``None``."""
            try:
                down.sendall(chunk)  # forward raw bytes downstream
            except OSError as exc:
                device._bump(errors=1, last_error=f"relay send failed: {exc}")
                log.error("[%s] relay send failed: %s — tapping only", device.cfg.name, exc)
                # Close (not just drop) the socket so it is not leaked and the
                # pump thread reading from it stops writing to the analyzer.
                _close_socket(down)
                return None
            return down

        def _dispatch(self, message: str, relaying: bool) -> None:
            """Process one message and, unless relaying, send our ACK back.

            A failure while handling a single message is counted and logged but
            does not tear down the connection, so the rest of the session is
            still processed. No ACK is sent for the failed message.
            """
            try:
                ack = device.handle_message(message)
            except Exception as exc:  # noqa: BLE001 - isolate per-message failures
                device._bump(errors=1, last_error=str(exc))
                log.exception("[%s] failed to handle message: %s", device.cfg.name, exc)
                return
            # In relay mode the downstream owns the ACK (pumped back to the
            # analyzer); we only ACK when there's no relay.
            if ack and not relaying:
                self.request.sendall(mllp.frame(ack))

    return _Handler
