#!/usr/bin/env python

import argparse
import curses
import json
import os
import re
import shlex
import sys
from itertools import chain
from subprocess import run, Popen, PIPE, CalledProcessError
from textwrap import wrap

__version__ = "0.2"
# Used as the program name in error messages.
__scriptname__ = os.path.basename(__file__)

# Default command of the "open" action (the --open option).
SHELL = os.environ.get("SHELL", "/bin/sh")

# Menu title; only drawn when --border is used.
DEFAULT_TITLE = " blkmenu "

# lsblk properties shown when --columns is not given (or is empty). A leading '<' or
# '>' sets the column alignment (left-aligned when omitted).
DEFAULT_COLUMNS = [
    "name",
    ">size",
    "type",
    "fstype",
    "partlabel",
    "label",
    "rm",
    "ro",
    "mountpoint",
]

# Default key bindings: key name as returned by curses `getkey()` -> action name.
# -a / -A bindings are applied on top of these in main(). The help screen lists one
# row per action, in order of the action's first binding.
DEFAULT_ACTIONS = {
    "q": "quit",
    "j": "movedown",
    "KEY_DOWN": "movedown",
    "k": "moveup",
    "KEY_UP": "moveup",
    "m": "mount",
    "u": "unmount",
    "l": "lock",
    "L": "unlock",
    "e": "eject",
    "o": "open",
    "i": "info",
    "r": "refresh",
    "a": "toggle_filter",
    "?": "help",
}

# Description of each action, displayed in the help screen.
ACTION_DESCRIPTIONS = {
    "quit": "Exit the program",
    "movedown": "Move down a line",
    "moveup": "Move up a line",
    "mount": "Mount the selected device",
    "unmount": "Unmount the selected device",
    "lock": "Lock the selected device",
    "unlock": "Unlock the selected device",
    "eject": "Eject the selected device",
    "open": f"Open the selected device mountpoint with the given --open command (defaults to '{SHELL}')",
    "info": "Display all device properties",
    "refresh": "Refresh the device tree.",
    "toggle_filter": "Toggle active filters (set with -a or -A options)",
    "help": "Display this help screen",
}

# Actions a key may be bound to. Pressing a key bound to anything else reports
# "Unknown action" in the menu.
VALID_ACTIONS = set(DEFAULT_ACTIONS.values())


class DeviceError(Exception):
    """A `udisksctl` device operation failed or the device is not in a usable state."""


class CmdError(Exception):
    """An external command could not be found or exited with an error."""


class ActionError(Exception):
    """A menu action failed; the message is displayed to the user."""


def fatal(message):
    """Terminates the program with `message` as the exit message.

    Args:
        message (str): Text printed to stderr; the exit status is non-zero.

    Raises:
        SystemExit: Always.
    """
    sys.exit(message)


def vifm(server_name, *commands):
    """Sends commands to a running vifm instance.

    Args:
        server_name (str): Name of the target vifm server.
        commands (list): vifm commands, each passed with its own `-c` flag.

    Raises:
        CmdError: vifm is not installed or exited with a non-zero status.
    """
    # vifm requires --remote to come after --server-name
    argv = ["vifm", "--server-name", server_name, "--remote"]
    for command in commands:
        argv += ["-c", command]

    try:
        run(argv, check=True, encoding="utf-8")
    except FileNotFoundError:
        raise CmdError("vifm: command not found")
    except CalledProcessError as err:
        joined = " ".join(argv)
        raise CmdError(f"vifm: command failed with code {err.returncode}: {joined}")


class Device:
    """A block device built from one entry of the `lsblk -JO` output.

    Device operations (mount, unmount, lock, ...) are delegated to `udisksctl`.

    Attributes:
        parent (Device): The parent device or None for a root node.
        children (list): Device children. Initially the raw lsblk dicts; DeviceTree
            replaces them with Device objects while building the tree.
        info (dict): Device properties (the lsblk entry without its "children").
        tree_padding (str): Tree decorations that represent the device hierarchy. Only
            meaningful in the context of a list returned by DeviceTree.list().
        root (bool): Whether or not the device is the root node. The root device holds
            no info.
    """

    CMD = "udisksctl"
    VIEWER = "less"

    def __init__(self, info, parent=None, root=False):
        self.parent = parent
        # moved out of `info` so that `info` only holds the device's own properties
        self.children = info.pop("children", [])
        self.info = info
        self.tree_padding = ""
        self.root = root

    def mount(self, options=None, fstype=None):
        """Mounts device.

        Args:
            options (str): Mount options.
            fstype (str): Filesystem type to use. As per udisksctl docs, if not
                specified, autodetected filesystem type will be used.

        Returns:
            The `udisksctl` standard output.

        Raises:
            CmdError: `udisksctl` was not found.
            DeviceError: `udisksctl` returned an error.
        """
        fstype = [] if fstype is None else ["-f", fstype]
        options = [] if options is None else ["-o", options]
        return self._do("mount", "-b", self.info["path"], *options, *fstype)

    def unmount(self):
        """Unmounts device. Returns/raises like `mount`."""
        return self._do("unmount", "-b", self.info["path"])

    def eject(self):
        """Ejects device (`udisksctl power-off`). Returns/raises like `mount`."""
        return self._do("power-off", "-b", self.info["path"])

    def lock(self):
        """Locks device. Returns/raises like `mount`."""
        return self._do("lock", "-b", self.info["path"])

    def unlock(self):
        """Unlocks device. Returns/raises like `mount`."""
        return self._do("unlock", "-b", self.info["path"])

    def view_details(self):
        """Displays all device properties with the VIEWER pager.

        Each property is printed on its own line as `NAME  value`, with the names
        right-aligned.

        Raises:
            CmdError: The VIEWER command was not found or returned an error.
        """
        # default=0 keeps an empty property dict from raising ValueError
        key_width = max(map(len, self.info), default=0) + 2
        text = "\n".join(
            f"{key.upper():>{key_width}}  {value}" for key, value in self.info.items()
        )
        try:
            run([self.VIEWER], input=text, check=True, encoding="utf-8")
        except FileNotFoundError as e:
            raise CmdError(f"{self.VIEWER}: command not found") from e
        except CalledProcessError as e:
            raise CmdError(f"{self.VIEWER}: exited with code {e.returncode}") from e

    @staticmethod
    def _parse_open_cmd(raw_cmd, mountpoint):
        """Turns an --open command template into an argument list.

        Args:
            raw_cmd (str): Command template. `{}` is replaced by the shell-quoted
                mountpoint. A trailing '&' (surrounding whitespace allowed) requests
                background execution.
            mountpoint (str): The device mountpoint.

        Returns:
            A tuple (argv, background).

        Raises:
            IndexError, KeyError, ValueError: The template is malformed (bad `{}`
                placeholders or unbalanced quotes).
        """
        command = raw_cmd.strip()
        background = command.endswith("&")
        if background:
            command = command.rstrip("&")
        # shlex.quote pairs with shlex.split: any path (spaces, quotes, newlines)
        # survives the round trip as a single argument
        command = command.format(shlex.quote(mountpoint))
        return shlex.split(command), background

    def open(self, raw_cmd):
        """Opens the device mountpoint with the given command.

        The command runs with the mountpoint as its working directory. Its exit status
        is not checked: the default command is an interactive $SHELL, whose status is
        just that of the last command typed.

        Args:
            raw_cmd (str): Command used to open the device mountpoint. `{}` is replaced
                by the mountpoint. If '&' is found at the end of the command, then Popen
                is used to execute the command in background (e.g. when using a gui
                application).

        Raises:
            CmdError: The given command is malformed, empty or was not found.
            DeviceError: The device was not mounted.
        """
        mountpoint = self.info["mountpoint"]
        if not mountpoint:
            raise DeviceError(f"Device '{self.info['name']}' not mounted.")

        try:
            argv, background = self._parse_open_cmd(raw_cmd, mountpoint)
        except (IndexError, KeyError, ValueError) as e:
            raise CmdError(f"{raw_cmd}: invalid command: {e}") from e
        if not argv:
            raise CmdError(f"{raw_cmd!r}: empty command")

        try:
            if background:
                Popen(argv, cwd=mountpoint, encoding="utf-8")
            else:
                run(argv, cwd=mountpoint, encoding="utf-8")
        except FileNotFoundError as e:
            raise CmdError(f"{argv[0]}: command not found") from e

    def _do(self, *args):
        """Runs a udisksctl command.

        Args:
            args (list): Udisksctl arguments.

        Returns:
            The `udisksctl` process standard output, stripped.

        Raises:
            CmdError: The `udisksctl` command was not found.
            DeviceError: The `udisksctl` command returned an error; the message is its
                stripped standard error.
        """
        cmd = [self.CMD, *args]
        try:
            proc = run(cmd, check=True, stdout=PIPE, stderr=PIPE, encoding="utf8")
        except FileNotFoundError as e:
            raise CmdError(f"{self.CMD}: command not found") from e
        except CalledProcessError as e:
            raise DeviceError(e.stderr.strip()) from e
        return proc.stdout.strip()

    def __str__(self):
        return f"<{self.__class__.__name__} {self.info.get('path')}>"

    def __repr__(self):
        return f"{self.__class__.__name__}(path={self.info.get('path')!r})"


class DeviceTree:
    """Device tree.

    Device tree parsed from the json output of lsblk (lsblk -JO).

    Attributes:
        _filters (list): Python expressions for filtering devices. If any of these
            expressions evaluates to True for a device, the device is removed from the
            tree and its children are attached to its parent instead.
        _prunes (list): Python expressions for pruning devices. If any of these
            expressions evaluates to True for a device, the device and all its
            descendants are removed from the tree.
        _tree (Device): The root node of the device tree.
    """

    def __init__(self, filters=None, prunes=None, filter=True):
        self._filters = filters or []
        self._prunes = prunes or []
        self._tree = self._get_device_tree(filter=filter)

    def refresh(self, filter=True):
        """Rebuilds the device tree.

        Args:
            filter (bool): Whether or not to filter devices.
        """
        self._tree = self._get_device_tree(filter=filter)

    def list(self):
        """Returns the device tree as a list, in depth-first (pre-order) order.

        Sets each device's `tree_padding` as a side effect.

        Returns:
            A list of devices.

        Note:
            The list must not be sorted afterwards or the tree padding wouldn't make
            sense anymore.
        """

        def _flatten(device, padding, last_child):
            curr_padding = ""
            next_padding = padding

            # top-level devices (children of the root node) get no connector
            if not device.parent.root:
                curr_padding = padding + ("└─ " if last_child else "├─ ")
                next_padding = padding + ("   " if last_child else "│  ")

            device.tree_padding = curr_padding
            yield device

            last_index = len(device.children) - 1
            for i, child in enumerate(device.children):
                yield from _flatten(child, next_padding, i == last_index)

        return list(
            chain.from_iterable(_flatten(d, "", False) for d in self._tree.children)
        )

    def _match(self, device, rules):
        """Returns True if any of the expression rules evaluates to True.

        Rules are evaluated with the device properties as local names, plus `match`
        and `search` (re.match / re.search). A rule that raises (e.g. it refers to a
        property the device lacks) counts as not matching.

        Note: the rules come from the command line and are run with `eval`. Disabling
        `__builtins__` limits what they can reach but is not a sandbox.

        Args:
            device (Device): The target device to match against the rules.
            rules (list): List of python expressions.

        Returns:
            True if any of the given expression rules evaluates to True.
        """
        env = {"__builtins__": None, "match": re.match, "search": re.search}
        # a copy, so that an assignment expression in a rule cannot alter the device
        props = dict(device.info)
        for rule in rules:
            try:
                if eval(rule, env, props):
                    return True
            except Exception:
                continue
        return False

    def _lsblk(self):
        """Deserializes the `lsblk` json output into a dictionary.

        Returns:
            A raw dict object representing the device tree.

        Raises:
            SystemExit: lsblk is missing, failed, or printed invalid json.
        """
        try:
            proc = run(["lsblk", "-JO"], stdout=PIPE, check=True, encoding="utf8")
        except FileNotFoundError:
            fatal("lsblk: command not found")
        except CalledProcessError as e:
            fatal(f"lsblk: command failed with exit code: {e.returncode}")
        try:
            return json.loads(proc.stdout)
        except json.JSONDecodeError as e:
            fatal(f"lsblk: invalid json output: {e}")

    def _get_device_tree(self, filter=True):
        """Builds the device tree from the `lsblk` output.

        Args:
            filter (bool): Whether or not to apply the prune and filter rules.

        Returns:
            The root Device of the tree.
        """

        def _build_tree(node):
            # returns None (pruned), a list of hoisted children (filtered), or node
            if not node.root and filter and self._match(node, self._prunes):
                return None

            children = []
            for child in node.children:
                subtree = _build_tree(Device(child, parent=node))
                if subtree:
                    children.extend(subtree if isinstance(subtree, list) else [subtree])

            if not node.root and filter and self._match(node, self._filters):
                # hand the surviving children over to this node's parent
                for child in children:
                    child.parent = child.parent.parent
                return children

            node.children = children
            return node

        d = self._lsblk()
        root = Device({"children": d.get("blockdevices", [])}, root=True)
        return _build_tree(root)


class Formatter:
    """Tabularize data.

    Attributes:
        data (iterable): Entries to format.
        columns (list): Entry fields to output. A leading '<' or '>' sets the column
            alignment (left by default); the rest of the string is the field name.
        col_getter (callable): Called as col_getter(entry, field_name) to retrieve a
            field as a string. Defaults to entry[field_name].
        header_transform (callable): Function called with each field name to build the
            column header. Defaults to making the field name uppercase.
        show_header (bool): Whether or not to return the header.
        separator (str): Column separator. Used when `stretch == False`.
        width (int): Table width to fill. Used when `stretch == True`.
        stretch (bool): Whether or not to widen the separators so that the columns span
            `width` characters.
    """

    def __init__(
        self,
        data=None,
        columns=None,
        col_getter=None,
        header_transform=None,
        show_header=True,
        separator=" ",
        width=0,
        stretch=False,
    ):
        self.data = data or []
        self.columns = columns or []
        self.col_getter = col_getter or (lambda entry, col: entry[col])
        self.show_header = show_header
        self.header_transform = header_transform or (lambda col: col.upper())
        self.width = width
        self.stretch = stretch
        self.separator = separator

    @staticmethod
    def _parse_column(spec):
        """Splits a column spec such as '>size' into (alignment, field name)."""
        align = spec[0] if spec[:1] in ("<", ">") else "<"
        return align, spec.lstrip("<>")

    def format(self):
        """Formats the data as a table.

        Returns:
            A tuple containing the header string ("" when `show_header` is False) and a
            list of all formatted entries: (header, lines)
        """
        specs = [self._parse_column(col) for col in self.columns]

        # fetch every cell once; `data` may be any iterable and col_getter may be costly
        cells = [
            [self.col_getter(entry, name) for _, name in specs] for entry in self.data
        ]
        titles = [
            self.header_transform(name) if self.show_header else "" for _, name in specs
        ]

        # each column is as wide as its widest cell (or its header)
        formats = []
        used_width = 0
        for i, (align, _) in enumerate(specs):
            col_width = max([len(row[i]) for row in cells] + [len(titles[i])])
            formats.append(f"{{:{align}{col_width}}}")
            used_width += col_width

        if self.stretch:
            # spread the remaining width evenly between columns (at least one space)
            avail_width = max(self.width - used_width, 0)
            gaps = max(len(specs) - 1, 1)
            sep = " " * max(avail_width // gaps, 1)
        else:
            sep = self.separator

        header = ""
        if self.show_header:
            header = sep.join(fmt.format(t) for fmt, t in zip(formats, titles))

        lines = [
            sep.join(fmt.format(cell) for fmt, cell in zip(formats, row)) for row in cells
        ]
        return header, lines


class Menu:
    """Device menu.

    Attributes:
        stdscr (curses.window): Window object representing the whole screen.
        args (argparse.Namespace): Command line arguments.
        actions (dict): Map of keys/actions.
        filter (bool): Whether or not to filter devices.
        device_tree (DeviceTree): Device tree.
        device_list (list): Device tree as a flat list.
        selected (int): Selected device as an index of `device_list`.
        error (str): Current error message.
        message (str): Feedback message.
        refresh (bool): Whether or not to refresh the device tree after a key press.
    """

    def __init__(self, stdscr, actions, args):

        self.stdscr = stdscr
        self.args = args
        self.actions = actions

        # filtering starts enabled only when filter/prune rules were given
        self.filter = bool(args.filters or args.prunes)
        self.device_tree = DeviceTree(
            filters=args.filters, prunes=args.prunes, filter=self.filter
        )
        self.device_list = []
        self.selected = -1

        self.init_state()

        curses.curs_set(False)
        curses.use_default_colors()
        # color pair 1 is used for error messages (see draw)
        curses.init_pair(1, curses.COLOR_RED, -1)

    def init_state(self):
        """Resets the per-keypress state: error, message and refresh flag."""
        self.error = ""
        self.message = ""
        self.refresh = False

    def _outside_curses(self, func, *args, **kwargs):
        """Calls `func` with curses suspended so it can use the terminal directly.

        Returns:
            Whatever `func` returns. Exceptions propagate after the terminal state saved
            by `curses.savetty` has been restored.
        """
        curses.savetty()
        curses.endwin()
        try:
            return func(*args, **kwargs)
        finally:
            curses.resetty()

    def _device_action(self, func, *args, suspend=False, **kwargs):
        """Runs a device operation, keeps its output as feedback and schedules a refresh.

        Args:
            func (callable): Bound Device method to call.
            suspend (bool): Run it with curses suspended (see `_outside_curses`).

        Raises:
            ActionError: The operation raised CmdError or DeviceError.
        """
        try:
            if suspend:
                self.message = self._outside_curses(func, *args, **kwargs)
            else:
                self.message = func(*args, **kwargs)
        except (CmdError, DeviceError) as e:
            raise ActionError(e) from e
        self.refresh = True

    def do(self, device, action):
        """Performs an action on the given device.

        Navigation, refresh, filter toggling and help work without a device; every
        other action is ignored when `device` is None.

        Args:
            device (Device): The target device, or None.
            action (str): The action to perform.

        Raises:
            ActionError: Action failed.
            SystemExit: User decided to quit or open the device mountpoint with vifm.
        """
        if action == "quit":
            sys.exit()
        elif action == "help":
            self.view_help()
        elif action in ("moveup", "movedown"):
            if self.device_list:
                step = -1 if action == "moveup" else 1
                # wrap around at both ends of the list
                self.selected = (self.selected + step) % len(self.device_list)
        elif action == "refresh":
            self.refresh = True
        elif action == "toggle_filter":
            self.filter = not self.filter
            self.refresh = True
        elif not device:
            return
        elif action == "mount":
            self._device_action(device.mount, options=self.args.mount_opts)
        elif action == "unmount":
            self._device_action(device.unmount)
        elif action == "eject":
            self._device_action(device.eject)
        elif action == "lock":
            self._device_action(device.lock)
        elif action == "unlock":
            # run outside curses, presumably so udisksctl can prompt for a passphrase
            self._device_action(device.unlock, suspend=True)
        elif action == "info":
            try:
                self._outside_curses(device.view_details)
            except CmdError as e:
                raise ActionError(e) from e
        elif action == "open":
            if self.args.vifm:
                mountpoint = device.info["mountpoint"]
                if not mountpoint:
                    raise ActionError(f"Device '{device.info['name']}' not mounted.")
                try:
                    vifm(self.args.vifm, f"cd {mountpoint!r}")
                except CmdError as e:
                    raise ActionError(e) from e
                sys.exit()
            else:
                try:
                    self._outside_curses(device.open, self.args.open_cmd)
                except (CmdError, DeviceError) as e:
                    raise ActionError(e) from e

    def start(self):
        """Starts the main loop: draw, read a key, run its action, refresh if needed."""
        self.device_list = self.device_tree.list()
        self.selected = 0

        while True:

            self.draw()

            key = self.stdscr.getkey()

            self.init_state()

            if key not in self.actions:
                continue

            action = self.actions[key]
            if action not in VALID_ACTIONS:
                self.error = f"Unknown action: {action}"
                continue

            try:
                device = self.device_list[self.selected]
            except IndexError:
                device = None

            try:
                self.do(device, action)
            except ActionError as e:
                self.error = self.format_error(e, device)

            if self.refresh:
                self.device_tree.refresh(filter=self.filter)
                self.device_list = self.device_tree.list()

    def format_error(self, error, device):
        """Transforms an ActionError into a legible error message.

        Backticks become quotes, the "GDBus...: " prefix is dropped and a D-Bus
        "Object /..." path is replaced by the device path.

        Args:
            error (ActionError): The ActionError to format.
            device (Device): The device for which the action failed, or None.

        Returns:
            A nicely formatted error message.
        """
        msg = str(error).replace("`", "'")
        msg = re.sub("GDBus.*?: ", "", msg, flags=re.IGNORECASE)
        if device is not None:
            # a function replacement keeps backslashes in the path from being read as
            # regex template escapes
            msg = re.sub(
                r"Object /\S+", lambda _: device.info["path"], msg, flags=re.IGNORECASE
            )
        return msg

    def _addstr(self, y, x, text, attr=None):
        """Writes text at (y, x), skipping rows that do not fit on the screen.

        Args:
            y, x (int): Screen coordinates.
            text (str): Text to write; expected to be already clipped to the width.
            attr (int): curses attribute. Defaults to curses.A_NORMAL.
        """
        max_y = self.stdscr.getmaxyx()[0]
        # leave the bottom border row intact when a border is drawn
        bottom = max_y - 1 if self.args.border else max_y
        if y >= bottom:
            return
        try:
            self.stdscr.addstr(y, x, text, curses.A_NORMAL if attr is None else attr)
        except curses.error:
            # addstr raises after writing into the bottom-right cell of the window;
            # the text is already on screen, so there is nothing to recover
            pass

    def view_help(self):
        """Displays the help and waits for keypress to exit."""
        self.stdscr.erase()
        y, x = self.draw_border(0, 0, title=" help ")

        # one row per action, listing every key bound to it
        entries = {}
        for key, action in self.actions.items():
            entry = entries.setdefault(
                action,
                {
                    "action": action,
                    "keys": [],
                    # user bindings (-a/-A) may name actions that don't exist
                    "description": ACTION_DESCRIPTIONS.get(action, "Unknown action"),
                },
            )
            entry["keys"].append(key)

        header, lines = Formatter(
            data=entries.values(),
            col_getter=lambda e, c: ", ".join(e[c]) if c == "keys" else e[c],
            columns=["action", "keys", "description"],
            separator="    ",
            stretch=False,
        ).format()

        width = max(self.stdscr.getmaxyx()[1] - (x * 2), 1)

        if header:
            self._addstr(y, x, header[:width], curses.A_BOLD)
            y += 1

        for line in lines:
            self._addstr(y, x, line[:width])
            y += 1

        self.stdscr.refresh()
        self.stdscr.getkey()

    def draw_border(self, y, x, title=""):
        """Draws the window border when --border is set.

        Args:
            x, y (int): Window coordinates that indicate where to start drawing.
            title (str): The window title to display on the border at the top-left
                corner.

        Returns:
            The updated coordinates (unchanged when no border is drawn).
        """
        if self.args.border:
            self.stdscr.border()
            if title:
                self.stdscr.addstr(y, x + 3, title, curses.A_BOLD)
            y, x = y + 1, x + 2
        return y, x

    def _col_getter(self, dev, prop):
        """Extracts the given property from the device as a string.

        None becomes "", booleans become "0"/"1", and the --tree column is prefixed
        with the device's tree padding. Exits on an unknown property.
        """
        prop = prop.lstrip("<>")
        try:
            if dev.info[prop] is None:
                return ""
            if dev.info[prop] is False or dev.info[prop] is True:
                return str(int(dev.info[prop]))
            if self.args.tree is not None and prop == self.args.tree:
                return dev.tree_padding + str(dev.info[prop])
            return str(dev.info[prop])
        except KeyError:
            fatal(f"{__scriptname__}: invalid column: {prop}")

    def draw(self):
        """Draws the menu: header, one line per device, then the feedback message."""

        self.stdscr.erase()
        y, x = self.draw_border(0, 0, title=self.args.title)

        # remove space taken up by the border; keep at least 1 so wrap() stays valid
        width = max(self.stdscr.getmaxyx()[1] - (x * 2), 1)

        header, lines = Formatter(
            data=self.device_list,
            columns=self.args.columns,
            col_getter=self._col_getter,
            show_header=self.args.show_header,
            separator=self.args.sep,
            stretch=self.args.stretch,
            width=width,
        ).format()

        if not self.device_list:
            self._addstr(y, x, "No device found.")
            y += 1

        if header and self.device_list:
            self._addstr(y, x, header[:width], curses.A_BOLD)
            y += 1

        # make sure a device is always selected when toggling filters
        self.selected = max(min(self.selected, len(self.device_list) - 1), 0)

        for i, line in enumerate(lines):
            attr = curses.A_REVERSE if i == self.selected else curses.A_NORMAL
            self._addstr(y, x, line[:width], attr)
            y += 1

        y += 1

        color = curses.color_pair(1) if self.error else curses.A_NORMAL
        message = self.error or self.message
        for line in wrap(message, width=width):
            self._addstr(y, x, line, color)
            y += 1

        self.stdscr.refresh()


def parse_args(argv=None):
    """Parses and post-processes the command line.

    Args:
        argv (list): Arguments to parse. Defaults to sys.argv[1:].

    Returns:
        An argparse.Namespace where:
          - `filters`, `prunes`, `actions`, `actions_override` are flat lists;
          - `columns` is a non-empty list (falls back to DEFAULT_COLUMNS);
          - `tree` is a column name, or None when --flat is given.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    parser.add_argument("-h", "--help", action="help", help="Print help information.")
    parser.add_argument(
        "-V",
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
        help="Print version information.",
    )
    parser.add_argument("--border", action="store_true", help="Show menu border.")
    parser.add_argument(
        "--no-border",
        action="store_false",
        dest="border",
        help="Don't show menu border. This is the default behavior.",
    )
    parser.add_argument(
        "--title",
        default=DEFAULT_TITLE,
        help="Set menu title. Shown only when --border is used.",
    )
    parser.add_argument(
        "--no-title",
        action="store_const",
        const="",
        dest="title",
        help="Don't show menu title. Same as --title ''.",
    )
    parser.add_argument(
        "--columns", default=DEFAULT_COLUMNS, help="Show only the given columns."
    )
    parser.add_argument(
        "--header",
        action="store_true",
        default=True,
        dest="show_header",
        help="Show header. This is the default behavior.",
    )
    parser.add_argument(
        "--no-header", action="store_false", dest="show_header", help="Hide header."
    )
    parser.add_argument(
        "--tree",
        metavar="COLUMN_NAME",
        nargs="?",
        default="",
        const="",
        help="The column displayed as hierarchy tree. Defaults to the first column.",
    )
    parser.add_argument(
        "--flat",
        action="store_const",
        const=None,
        dest="tree",
        help="Don't display hierarchy relationships.",
    )
    parser.add_argument(
        "--stretch",
        action="store_true",
        default=True,
        help="Stretch columns to fill the whole window. This is the default behavior.",
    )
    parser.add_argument(
        "--no-stretch",
        action="store_false",
        dest="stretch",
        help="Don't stretch columns to fill the whole window.",
    )
    parser.add_argument(
        "--sep",
        metavar="SEPARATOR",
        default="  ",
        help="Columns separator when --no-stretch is given.",
    )
    parser.add_argument(
        "--open",
        metavar="CMD",
        default=SHELL,
        dest="open_cmd",
        help="Command used to open a device mountpoint. Append '&' to open the command in background.",
    )
    parser.add_argument(
        "--mount-opts",
        metavar="OPTS",
        default="nosuid,noexec,noatime",
        help="Default mount options.",
    )
    parser.add_argument(
        "--vifm",
        metavar="SERVER_NAME",
        nargs="?",
        const="vifm",
        help="Vifm server name. Bypass the --open option and open the device mountpoint directly in the vifm instance named SERVER_NAME.",
    )
    parser.add_argument(
        "-f",
        metavar="EXPR",
        dest="filters",
        nargs="+",
        action="append",
        default=[],
        help="Exclude devices from the menu. Treated as python expressions.",
    )
    parser.add_argument(
        "-p",
        metavar="EXPR",
        dest="prunes",
        nargs="+",
        action="append",
        default=[],
        help="Exclude devices and all their descendants from the menu. Treated as python expressions.",
    )
    parser.add_argument(
        "-a",
        metavar="KEY:ACTION",
        dest="actions",
        nargs="+",
        action="append",
        default=[],
        help="Bindings as key:action pairs.",
    )
    parser.add_argument(
        "-A",
        metavar="KEY:ACTION",
        dest="actions_override",
        nargs="+",
        action="append",
        default=[],
        help="Same as -a but clears other bindings that map to the same action.",
    )

    args = parser.parse_args(argv)

    # nargs="+" with action="append" yields a list of lists: flatten them
    args.prunes = list(chain.from_iterable(args.prunes))
    args.filters = list(chain.from_iterable(args.filters))

    args.actions = list(chain.from_iterable(args.actions))
    args.actions_override = list(chain.from_iterable(args.actions_override))

    columns = args.columns
    if isinstance(columns, str):
        # tolerate "name, size" and stray commas; an empty value means the default
        columns = [col.strip() for col in columns.split(",") if col.strip()]
    # always a fresh list, so the module-level default is never shared with args
    args.columns = list(columns) if columns else list(DEFAULT_COLUMNS)

    # "" means --tree was not given or given without a value: use the first column
    if args.tree == "":
        args.tree = args.columns[0].lstrip("<>")

    return args


def start_menu(stdscr, actions, args):
    """curses.wrapper callback: builds the menu on the initialized screen and runs it.

    Args:
        stdscr (curses.window): Screen supplied by curses.wrapper.
        actions (dict): Key -> action bindings.
        args (argparse.Namespace): Parsed command line.
    """
    Menu(stdscr, actions, args).start()


def _parse_binding_pair(pair):
    key, _, action = pair.rpartition(":")
    if not key or not action:
        fatal(f"{__scriptname__}: invalid 'key:action' pair: {pair}")
    return key, action


def main():
    """Entry point: parses arguments, builds the key bindings and runs the menu."""

    args = parse_args()

    # work on a copy so rebinding keys never mutates the module-level defaults
    actions = dict(DEFAULT_ACTIONS)

    # -A: force an action to be mapped to only one key
    for pair in args.actions_override:
        key, action = _parse_binding_pair(pair)
        for k, a in list(actions.items()):
            if a == action:
                del actions[k]
        actions[key] = action

    # -a: add (or replace) a single binding
    for pair in args.actions:
        key, action = _parse_binding_pair(pair)
        actions[key] = action

    curses.wrapper(start_menu, actions, args)


# Run the menu only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
