#!/usr/bin/env python3
"""A little program that displays progress information as a model trains."""
import curses
import sys
import argparse
import re
from typing import Any
from bpreveal import logUtils
# Color attributes. They stay None until a Screen is created on a terminal
# that supports color, at which point Screen.__init__ fills them in.
_COLOR_GREENBG = None
_COLOR_REDBG = None
_COLOR_ALARM = None
_COLOR_HIGHLIGHT = None
_COLOR_CRITICAL = None

# Patterns for routine output seen during training. They are applied with
# re.match, so each one is anchored at the start of the line.
_NORMAL_REGEXES = list(map(re.compile, (
    r"^INFO.*",
    r".*Compiled cluster using XLA.*",
    r".*val_loss improved",
    r".*val_loss did not improve",
    r"^Restoring model weights.*",
    r".*early stopping$",
    r".*Unable to register cu.*",
    r".*Could not find TensorRT.*",
    r".*is called are written to STDERR.*",
    r".*ReduceLROnPlateau reducing.*",
    r".*does not guarantee that XLA will be used.*",
    r"^DEBUG.*",
    r"^I0000 00:00.*StreamExecutor device.*",
    r"^I0000 00:00.*Loaded cuDNN version.*",
    r"^I0000 00:00.*Created device.*",
    r"^AttributeError: module .ml_dtypes.*",
)))


def isNormalMessage(msg: str) -> bool:
    """Decide whether a log line is routine output we expect during training.

    :param msg: The line from the input file.
    :return: True if ``msg`` matches, from its first character, one of the
        known-benign patterns; False if the user should be alerted.
    """
    return any(pattern.match(msg) for pattern in _NORMAL_REGEXES)


class Screen:
    """A utility to draw tables from the logs of the training script.

    :param stdscr: The curses screen, generated by a call to wrapper()
    :param noDebug: If true, don't show debug messages.
    :param border: How many character spaces should be kept around the screen edge?
    :param colSep: How many spaces between boxes?
    :param statusHeight: How tall should the status box be?
    :param messageHeight: How tall should the message and debug panels be?
        Note that this is for each panel, so the message and debug panels together
        will occupy ``2 * messageHeight + colSep`` rows.
    """

    _epochWin = None
    _batchWin = None
    _λWin = None
    _mesgWin = None
    _statWin = None
    _batchWidth = 0
    _epochWidth = 0
    _mesgWidth = 0
    _charsToClear = {}
    """Can I write things to the screen now?"""
    exitTime: int
    """How long should the printer wait around after a successful exit?"""

    def __init__(self, stdscr, noDebug: bool, border: int, colSep: int,  # noqa: ANN001
                 statusHeight: int, messageHeight: int | None):
        self.noDebug = noDebug
        self._border = border

        self._stdscr = stdscr
        self._messageBuffer = []
        self._debugBuffer = []

        self._stdscr.clear()
        self._stdscr.refresh()
        if curses.has_colors():
            global _COLOR_GREENBG, _COLOR_REDBG, _COLOR_ALARM
            global _COLOR_HIGHLIGHT, _COLOR_CRITICAL
            curses.init_pair(1, curses.COLOR_BLACK, curses.COLOR_GREEN)
            curses.init_pair(2, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
            curses.init_pair(3, curses.COLOR_GREEN, curses.COLOR_BLACK)
            curses.init_pair(4, curses.COLOR_RED, curses.COLOR_BLACK)
            curses.init_pair(5, curses.COLOR_BLACK, curses.COLOR_RED)
            _COLOR_GREENBG = curses.color_pair(1) | curses.A_BOLD
            _COLOR_ALARM = curses.color_pair(2)
            _COLOR_HIGHLIGHT = curses.color_pair(3)
            _COLOR_CRITICAL = curses.color_pair(4) | curses.A_BOLD
            _COLOR_REDBG = curses.color_pair(5)

        # Layout:
        # ┌───────────────────────────┐
        # │┌─────────────────────────┐│
        # ││ STATUS                  ││
        # │└─────────────────────────┘│
        # │┌────────────────┐┌───────┐│
        # ││ EPOCH          ││BATCH  ││
        # ││                ││       ││
        # ││                ││       ││
        # ││                ││       ││
        # ││                ││       ││
        # ││                │└───────┘│
        # ││                │┌───────┐│
        # ││                ││λ      ││
        # ││                ││       ││
        # ││                ││       ││
        # ││                ││       ││
        # │└────────────────┘└───────┘│
        # │┌─────────────────────────┐│
        # ││ MESSAGES                ││
        # ││                         ││
        # ││                         ││
        # │└─────────────────────────┘│
        # │┌─────────────────────────┐│
        # ││ DEBUG                   ││
        # ││                         ││
        # ││                         ││
        # │└─────────────────────────┘│
        # └───────────────────────────┘

        self._width = width = curses.COLS  # pylint: disable=no-member
        height = curses.LINES  # pylint: disable=no-member
        if messageHeight is None:
            messageHeight = height // 7
        self.joinMessages = noDebug or messageHeight < 8
        if self.joinMessages:
            messageHeight = height // 5

        self._messageBufferSize = messageHeight - 2
        messageAreaHeight = messageHeight if self.joinMessages else \
            2 * messageHeight + colSep
        tableTotalWidth = width - 2 * border - colSep
        tableTotalHeight = height - statusHeight - messageAreaHeight \
            - 2 * border - 2 * colSep
        batchHeight = tableTotalHeight * 3 // 5
        self._epochWidth = epochWidth = (tableTotalWidth) * 3 // 5
        self._batchWidth = batchWidth = tableTotalWidth - epochWidth

        self._statWin = curses.newwin(3, width - 2 * border, border, border)
        self._epochWin = curses.newwin(tableTotalHeight,
                                       epochWidth,
                                       border + colSep + statusHeight,
                                       border)
        self._batchWin = curses.newwin(batchHeight,
                                       batchWidth,
                                       border + colSep + statusHeight,
                                       border + epochWidth + colSep)
        self._λWin = curses.newwin(tableTotalHeight - batchHeight - colSep,
                                   batchWidth,
                                   border + statusHeight + colSep * 2 + batchHeight,
                                   border + epochWidth + colSep)
        if self.joinMessages:
            self._mesgWin = curses.newwin(messageHeight,
                                          width - 2 * border,
                                          height - messageHeight - border,
                                          border)
        else:
            self._mesgWin = curses.newwin(messageHeight,
                                          width - 2 * border,
                                          height - messageHeight * 2 - border - colSep,
                                          border)
            self._debugWin = curses.newwin(messageHeight,
                                           width - 2 * border,
                                           height - messageHeight - border,
                                           border)

        for winName in "BλESMD":
            self.printString(1, 1, winName, " ")
        status = "Waiting for input."
        self.printString(1, 1, "S", status)

    def _getWindowProperties(self, winName: str) -> tuple[Any, int, str]:
        match winName:
            case "B":
                win = self._batchWin
                width = self._batchWidth
                title = "BATCH"
            case "V":
                win = self._batchWin
                width = self._batchWidth
                title = "BATCH"
            case "λ":
                win = self._λWin
                title = "λ VALUES"
                width = self._batchWidth
            case "E":
                title = "EPOCH"
                win = self._epochWin
                width = self._epochWidth
            case "S":
                title = "STATUS"
                win = self._statWin
                width = self._width - 2 * self._border
            case "M":
                title = "MESSAGES"
                win = self._mesgWin
                width = self._width - 2 * self._border
            case "D":
                if self.joinMessages:
                    title = "MESSAGES"
                    win = self._mesgWin
                else:
                    title = "DEBUG LOG"
                    win = self._debugWin
                width = self._width - 2 * self._border
            case _:
                raise ValueError(f"No window for {winName}")
        assert win is not None, "No window initialized."
        return win, width, title

    def printString(self, row: int, col: int, winName: str, text: str,
                    color: int | None = None) -> None:
        """Print a given string at a location in a window.

        :param row: What row within the window should the string start at?
        :param col: What column should the string start at?
        :param winName: Which window? One of ``SEBλMD``.
        :param text: The text to display.
        :param color: If provided, the output from ``curses.color_pair()``.
        """
        win, width, _ = self._getWindowProperties(winName)
        printWidth = width - col - 5
        if len(text) > printWidth:
            text = text[:printWidth]
        clearLen = self._charsToClear.get((row, col, winName), 1)
        try:
            win.addstr(row, col, " " * clearLen)
            self._charsToClear[(row, col, winName)] = len(text)
            if color is not None and curses.has_colors():
                win.addstr(row, col, text, color)
            else:
                win.addstr(row, col, text)
            if winName != "S":
                self._drawBorders(winName)
            win.refresh()
        except curses.error:
            pass

    def _drawBorders(self, highlightWin: str | None = None) -> None:
        """Update all the borders, and highlight the label for the given window."""
        if highlightWin == "V":
            highlightWin = "B"
        for windowName in "BλESMD":
            win, width, title = self._getWindowProperties(windowName)

            if windowName == highlightWin:
                win.border()
                assert _COLOR_HIGHLIGHT is not None, "Colors are not supported by this terminal."
                win.addstr(0, (width - (len(title) + 2)) // 2, " " + title + " ",
                           _COLOR_HIGHLIGHT)
            else:
                win.border()
                win.addstr(0, (width - (len(title) + 2)) // 2, " " + title + " ")
            win.refresh()

    def _writeMessages(self) -> None:
        """Write all of the messages in the message and debug buffers."""
        for i, msg in enumerate(self._messageBuffer):
            if isNormalMessage(msg):
                self.printString(i + 1, 1, "M", msg)
            else:
                self.printString(i + 1, 1, "M", msg, _COLOR_ALARM)
        assert self._mesgWin is not None, "No message window to write to."
        self._mesgWin.refresh()
        if not self.joinMessages:
            for i, msg in enumerate(self._debugBuffer[:-1]):
                self.printString(i + 1, 1, "D", msg)
            self._debugWin.refresh()

    def updateStatus(self, winName: str) -> None:
        """Update the status box to indicate which area is active.

        :param winName: A character, one of ``BλEMDX``
            If winName is ``X`` or ``F``, starts the shutdown process.
            ``X`` indicates successful termination, ``F`` indicates that the
            end message was not seen.
        """
        if winName in "MD":
            self.printString(1, 1, "S", "Log")
        elif winName == "E":
            self.printString(1, 1, "S", "Epoch summary")
        elif winName == "B":
            self.printString(1, 1, "S", "Training")
        elif winName == "V":
            self.printString(1, 1, "S", "Validation")
        elif winName == "λ":
            self.printString(1, 1, "S", "λ update")
        elif winName in "FX":
            self.shutdown(winName)

    def shutdown(self, winName: str) -> None:
        """Keep the display up for a bit, then close it out.

        :param winName: Either ``X`` or ``F``. If ``X``, make the window green
            and show a happy message; if ``F``, make the window red and display
            a sad message.
        """
        msgHeader = "Training complete."
        color = _COLOR_HIGHLIGHT
        winColor = _COLOR_GREENBG
        if winName == "F":
            msgHeader = "Abnormal end of file detected."
            color = _COLOR_CRITICAL
            winColor = _COLOR_REDBG
        self._stdscr.bkgdset(winColor)
        self._stdscr.bkgd(winColor)
        self._stdscr.clear()
        self._stdscr.border()
        self._stdscr.redrawwin()
        self._stdscr.refresh()
        # Drawing the border on stdscr causes all the other windows to disappear.
        for windowName in "BλESMD":
            win, _, _ = self._getWindowProperties(windowName)
            win.redrawwin()
            win.refresh()
        for i in range(self.exitTime * 10):  # Tenths of a second, so 10 s.
            msg = msgHeader + " This window will close " \
                f"in {((self.exitTime * 10) - i) // 10}, or exit with <C-c>."
            self.printString(1, 1, "S", msg, color)
            try:
                curses.napms(100)
            except KeyboardInterrupt:
                # Time to quit.
                return

    def addLine(self, line: str) -> bool:
        r"""Take a line from the log and put the info in the right place.

        :param line: A line from a BPReveal log file.
        :return: ``True`` if more lines are expected, ``False`` if it's time to shut down.

        The return value is important because a program like tail won't know that no more
        input will be added to the logs, but we can look for a line that contains
        ``Training job completed successfully`` and know it's time to quit.
        Lines that have position information will contain a double integral sign
        as a delimiter, like this:

        #ifdef MAN_PAGE

        ``∬INFO : 2024-02-14 13:39:11 :callbacks.py:207 :∬2∬30∬E∬ 47 /  200``

        #else

        :math:`\tt INFO : 2024-02-14 13:39:11 :callbacks.py:207 :
        \iint{}2\iint{}30\iint{}E\iint{} 47 /  200`

        #endif

        The entries are, in order: row, column, window, message.
        A window may be two characters, the second one being either ``H`` or ``A``,
        meaning that the text should be shown in green or red, respectively.

        #ifdef MAN_PAGE
        If a line does not contain ``∬``, then it is displayed in the
        message tab or the debug tab.
        The debug tab is reserved for messages starting with the
        string ``DEBUG``, and messages gets everything else.
        #else
        If a line does not contain ":math:`\tt \iint`", then it is displayed in the
        message tab or the debug tab.
        The debug tab is reserved for messages starting with the
        string ``DEBUG``, and messages gets everything else.
        #endif
        """
        # Is this a line with position information?
        if "∬" in line:
            _, row, col, winMode, text = line.split("∬")
            winName = winMode[0]
            if len(winMode) > 1:
                match winMode[1]:
                    case "H":
                        color = _COLOR_HIGHLIGHT
                    case "A":
                        color = _COLOR_ALARM
                    case _:
                        color = _COLOR_CRITICAL
                self.printString(int(row), int(col), winName, text, color)
            else:
                self.printString(int(row), int(col), winName, text)
            self.updateStatus(winName)
        elif self.noDebug and re.search("DEBUG", line):
            pass  # Just ignore a debug line if we have noDebug set.
        elif (not self.joinMessages) and re.search("DEBUG", line):
            self._debugBuffer.append(line.strip())
            if len(self._debugBuffer) > self._messageBufferSize:
                self._debugBuffer = self._debugBuffer[-self._messageBufferSize:]
            self._writeMessages()
            self.updateStatus("D")
        elif re.search("Training job completed successfully", line.strip()):
            return True
        else:
            self._messageBuffer.append(line.strip())
            if len(self._messageBuffer) > self._messageBufferSize:
                self._messageBuffer = self._messageBuffer[-self._messageBufferSize:]
            self._writeMessages()
            self.updateStatus("M")
        return False


def getParser() -> argparse.ArgumentParser:
    """Build the command-line parser; every argument is optional.

    :return: An ArgumentParser, ready to parse_args()
    """
    parser = argparse.ArgumentParser(
        description="Takes the logs from training and shows them with a little TUI.")
    parser.add_argument("--no-exit", action="store_true",
                        help="Instead of exiting --exit-delay seconds "
                        "after the training is done, keep this window open so you "
                        "can look at numbers.", dest="noExit")
    parser.add_argument("--exit-delay", type=int, default=10, dest="exitDelay",
                        help="Pause for this many seconds after training is complete "
                             "before closing the window.")
    parser.add_argument("--delay",
                        help="After reading a line, pause for this many milliseconds.",
                        type=int, default=0)
    # A plain on/off switch: without store_true argparse would demand a value
    # after --read-tty.
    parser.add_argument("--read-tty", action="store_true",
                        help="Allow the program to read from a terminal. "
                        "There are no good times to use this except for debugging.",
                        dest="readTTY")
    parser.add_argument("--no-debug", help="Don't show debug-level messages.",
                        dest="noDebug", action="store_true")
    parser.add_argument("--message-height", type=int,
                        help="The height (rows) of the message area at "
                        "the bottom of the window.",
                        default=None, dest="messageHeight")
    return parser


def runScreen(stdscr: Any, args: argparse.Namespace) -> None:
    """Build the Screen and stream stdin into it; this is what the wrapper calls.

    Blank lines are skipped. The screen is shut down as a success when a line
    reports completion, or as a failure if stdin runs out first.

    :param stdscr: The screen that will be drawn upon.
    :param args: Other values that will be needed to run the program.
    """
    curses.curs_set(0)  # Hide the cursor while drawing.
    printer = Screen(stdscr, noDebug=args.noDebug, border=1, colSep=0,
                     statusHeight=3, messageHeight=args.messageHeight)
    printer.exitTime = 10000 if args.noExit else args.exitDelay
    finished = False
    for rawLine in sys.stdin:
        stripped = rawLine.strip()
        if not stripped:
            continue
        finished = printer.addLine(stripped)
        if args.delay > 0:
            curses.napms(args.delay)
        if finished:
            break
    # If input ended without the completion line, that is an abnormal end.
    printer.updateStatus("X" if finished else "F")


def main() -> None:
    """A zero-argument wrapper for CLI use.

    Parses the command line and runs the TUI on piped-in stdin. If stdin is a
    terminal and ``--read-tty`` was not given, logs an error and exits with
    status 1.
    """
    parsedArgs = getParser().parse_args()
    if sys.stdin.isatty() and not parsedArgs.readTTY:
        logUtils.error("Refusing to read from tty. This program expects input"
                       " to be piped in over stdin.")
        # Report the refusal to the calling shell instead of exiting with success.
        sys.exit(1)
    curses.wrapper(runScreen, parsedArgs)


if __name__ == "__main__":
    # Run the TUI when this file is executed directly as a script.
    main()
# Copyright 2022-2025 Charles McAnany. This file is part of BPReveal. BPReveal is free software: You can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your option) any later version. BPReveal is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with BPReveal. If not, see <https://www.gnu.org/licenses/>.  # noqa
