#!python

"""Apply black to code cells in jupyter notebooks."""

import sys
import os
import re

import click
import nbformat
import black

from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Pattern,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from functools import partial
from pathlib import Path


# Default value of the --exclude option: paths matching this regex are skipped
# when a directory is searched recursively.
DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist|\.ipynb_checkpoints)/"
# Default value of the --include option: only files matching this regex are
# picked up when a directory is searched recursively.
DEFAULT_INCLUDES = r"\.ipynb$"


# Message helpers built on click.secho with err=True:
# `err` prints in red, `out` prints in bold.
err = partial(click.secho, fg="red", err=True)
out = partial(click.secho, bold=True, err=True)


def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline ``(?x)`` flag; single-line patterns are
    compiled unchanged.
    """
    verbose_prefix = "(?x)" if "\n" in regex else ""
    return re.compile(verbose_prefix + regex)


@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=True,
        readable=True,
        allow_dash=True,
    ),
    is_eager=True,
)
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=88,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    default=DEFAULT_EXCLUDES,
    help=(
        "A regular expression that matches files and directories that should be "
        "excluded on recursive searches.  An empty value means no paths are excluded. "
        "Use forward slashes for directories on all platforms (Windows, too).  "
        "Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    help=(
        "A regular expression that matches files and directories that should be "
        "included on recursive searches.  An empty value means all files are "
        "included regardless of the name.  Use forward slashes for directories on "
        "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
        "later."
    ),
    show_default=True,
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status.  Return code 0 "
        "means nothing would change.  Return code 1 means some files would be "
        "reformatted."
    ),
)
@click.option(
    "--clear-output",
    is_flag=True,
    help=("Clean cell output as part of formatting."),
)
def main(
    src: Tuple[str, ...],
    line_length: int,
    include: str,
    exclude: str,
    check: bool,
    clear_output: bool,
) -> None:
    """Apply black to code cells in jupyter notebooks
    underneath the src path."""
    # NOTE: the docstring above doubles as the --help text shown by click, so
    # implementation notes live in comments instead.
    #
    # Flow: validate the include/exclude regexes, collect notebook paths from
    # `src` (directories are searched recursively, files are taken as given),
    # format every code cell, and write the notebooks back unless --check is
    # set. Exits with status 1 on an invalid regex or when --check finds
    # something that would change, otherwise with status 0.

    if not src:
        out("No path given.")
        sys.exit(0)

    try:
        include_regex = re_compile_maybe_verbose(include)
    except re.error:
        err(f"Invalid regular expression for include given: {include!r}")
        sys.exit(1)
    try:
        exclude_regex = re_compile_maybe_verbose(exclude)
    except re.error:
        err(f"Invalid regular expression for exclude given: {exclude!r}")
        sys.exit(1)

    root = find_project_root(src)
    sources: Set[Path] = set()
    for s in src:
        p = Path(s)
        if p.is_dir():
            sources.update(
                gen_notebook_files_in_dir(
                    p, root, include_regex, exclude_regex
                )
            )
        elif p.is_file() or s == "-":
            # if a file was explicitly given, we don't care about its extension
            # NOTE(review): "-" is accepted here but is later opened like a
            # regular path; reading a notebook from stdin does not appear to
            # be implemented -- confirm intended behaviour.
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    if not sources:
        out("Can't find any notebooks.")
        sys.exit(0)

    # Accumulated over ALL notebooks so that --check reports a failure when
    # any of them would change, not just the last one processed.
    any_changes = False
    # Sorted so the progress output is deterministic between runs.
    for notebook_path in sorted(sources):
        resolved_path = notebook_path.resolve()

        try:
            with open(resolved_path, "r", encoding="utf-8") as notebook_file:
                notebook = nbformat.read(
                    notebook_file, as_version=nbformat.NO_CONVERT
                )
        except nbformat.reader.NotJSONError:
            # Not a notebook (e.g. an arbitrary file matched by --include).
            continue

        click.secho(f"{resolved_path} >", nl=False)

        formatted_cells = []
        for cell in notebook["cells"]:
            # Start every cell from False so the returned flag describes this
            # cell only; the aggregate is kept in `any_changes`.
            cell, cell_changed = handle_cell(
                cell, False, check, clear_output, line_length
            )
            any_changes = any_changes or bool(cell_changed)
            formatted_cells.append(cell)

        if not check:
            # --check promises not to touch the files on disk.
            notebook["cells"] = formatted_cells
            with open(resolved_path, "w", encoding="utf-8") as notebook_file:
                nbformat.write(notebook, notebook_file)
        click.secho("\n", nl=False)

    if check:
        if any_changes:
            err("Does not pass.")
            sys.exit(1)
        out("OK")
    sys.exit(0)


def handle_cell(
    cell: dict,
    changes: bool,
    check: bool,
    clear_output: bool,
    line_length: int,
) -> Tuple[dict, bool]:
    """Format one notebook cell and report whether it needed changes.

    Only cells whose ``cell_type`` is ``"code"`` are processed; any other
    cell is returned untouched. A progress marker is printed per code cell:
    a red ``*`` when black rejects the source, a green ``-`` when the source
    is already formatted, and a plain ``-`` when it would be reformatted.

    Args:
        cell: The notebook cell (a dict-like node); modified in place unless
            `check` is set.
        changes: Change flag carried in by the caller; it is OR-ed with
            whatever this cell contributes.
        check: When true, only detect changes and leave `cell` unmodified.
        clear_output: When true, also clear the cell's execution count and
            outputs (or, with `check`, report whether there is anything to
            clear).
        line_length: Maximum line length passed on to the formatter.

    Returns:
        A ``(cell, changes)`` tuple with the (possibly updated) cell and the
        updated change flag.
    """
    if cell["cell_type"] != "code":
        return cell, changes

    try:
        formatted_source = format_cell(cell["source"], line_length)
    except black.InvalidInput:
        # Source that black cannot parse is left as it is.
        click.secho("*", nl=False, fg="red")
    else:
        # Colour the marker by THIS cell's result, not the accumulated flag.
        source_changed = formatted_source != cell["source"]
        changes |= source_changed

        if source_changed:
            click.secho("-", nl=False)
        else:
            click.secho("-", nl=False, fg="green")

        if not check:
            cell["source"] = formatted_source

    if clear_output:
        # Clearing is a change only when there is something to clear.
        has_output = cell.get("execution_count") is not None or bool(
            cell.get("outputs")
        )
        changes |= has_output

        if not check:
            cell["execution_count"] = None
            # Outputs are a list of output nodes, so "cleared" is an empty list.
            cell["outputs"] = []

    return cell, changes


def format_cell(source: str, line_length: int) -> str:
    """Return the source of one code cell reformatted by black.

    Lines recognised as notebook magic are disguised as comments before
    formatting and restored afterwards. If the cell ended with a semicolon,
    the formatted result is made to end with one as well.

    Args:
        source: The cell's source code.
        line_length: Maximum line length passed to black.

    Returns:
        The formatted source, without trailing whitespace.

    Raises:
        black.InvalidInput: Propagated from black when the source cannot be
            parsed.
    """
    disguised = "\n".join(hide_magic(line) for line in source.splitlines())

    # `endswith` is False for an empty cell, so no IndexError guard is needed.
    had_trailing_semicolon = disguised.rstrip().endswith(";")

    # NOTE(review): `line_length` is passed positionally, which matches the
    # black release this script was written against; newer releases may expect
    # a `mode=` keyword instead -- confirm against the pinned black version.
    formatted_source = black.format_str(disguised, line_length).rstrip()

    # Restore the semicolon only when it is actually gone. When the last line
    # is a comment or a disguised magic ending in ";", the formatted text still
    # ends with it, and appending unconditionally would add another ";" on
    # every run.
    if had_trailing_semicolon and not formatted_source.endswith(";"):
        formatted_source = f"{formatted_source};"

    return reveal_magic(formatted_source)


def contains_magic(line: str) -> bool:
    """Return True if `line` begins with ``%`` or ``!``.

    These are the prefixes treated as notebook magic / shell escapes by this
    script. An empty line is never magic (returns False rather than raising).
    """
    return line.startswith(("%", "!"))


def hide_magic(line: str) -> str:
    """
    Disguise a magic line as a comment by prefixing it with a marker,
    because black cannot parse cell or line magic. The disguised line
    stays at the same position in the reformatted code. Lines that are
    not magic (including empty ones) are returned unchanged.
    """
    try:
        is_magic = contains_magic(line)
    except IndexError:
        # Raised for an empty line, which has no first character to inspect.
        return line

    if not is_magic:
        return line
    return f"###MAGIC###{line}"


def reveal_magic(source: str) -> str:
    """
    Strip every marker inserted by hide_magic() from `source`, turning the
    disguised comment lines back into the original notebook magic.
    """
    pieces = source.split("###MAGIC###")
    return "".join(pieces)


def gen_notebook_files_in_dir(
    path: Path, root: Path, include: Pattern[str], exclude: Pattern[str]
) -> Iterator[Path]:
    """Yield the files below `path` that pass the `exclude`/`include` filters.

    Each entry is matched by its resolved path relative to `root`, written
    with forward slashes and a leading "/" (directories also get a trailing
    "/"). Entries with a non-empty `exclude` match are skipped, directories
    are descended into recursively, and files are yielded only when
    `include` matches. Symbolic links that resolve outside of `root` are
    ignored.
    """
    assert (
        root.is_absolute()
    ), f"INTERNAL ERROR: `root` must be absolute but is {root}"
    for entry in path.iterdir():
        try:
            relative = entry.resolve().relative_to(root)
        except ValueError:
            # Only a symlink is allowed to point outside of the root.
            if not entry.is_symlink():
                raise
            continue

        entry_is_dir = entry.is_dir()
        candidate = "/" + relative.as_posix()
        if entry_is_dir:
            candidate += "/"

        # An empty match (e.g. from an empty exclude regex) excludes nothing.
        excluded = exclude.search(candidate)
        if excluded is not None and excluded.group(0):
            continue

        if entry_is_dir:
            yield from gen_notebook_files_in_dir(entry, root, include, exclude)
        elif entry.is_file() and include.search(candidate):
            yield entry


def find_project_root(srcs: Iterable[str]) -> Path:
    """Return a directory containing .git, .hg, or pyproject.toml.

    The search starts at the deepest directory that contains every path in
    `srcs` (a source directory itself, or a source file's parent) and walks
    upwards. If no directory in the tree contains a marker that would
    specify it's the project root, the root of the file system is returned.

    Args:
        srcs: Paths given on the command line; may be files or directories.

    Returns:
        The first directory, going upwards from the common base, that holds
        a ``.git`` or ``.hg`` directory or a ``pyproject.toml`` file.

    Raises:
        ValueError: If the sources have no common ancestor (for example
            paths on different Windows drives).
    """
    # Materialize first: `srcs` may be a one-shot iterable, for which a plain
    # truthiness test would never detect emptiness.
    resolved = [Path(src).resolve() for src in srcs]
    if not resolved:
        return Path("/").resolve()

    # Start from a directory for every source. Taking min() of the paths would
    # pick the lexicographically smallest one, whose ancestors need not
    # contain the other sources.
    start_dirs = [p if p.is_dir() else p.parent for p in resolved]
    common_base = Path(os.path.commonpath([str(d) for d in start_dirs]))

    directory = common_base
    for directory in (common_base, *common_base.parents):
        if (directory / ".git").is_dir():
            return directory

        if (directory / ".hg").is_dir():
            return directory

        if (directory / "pyproject.toml").is_file():
            return directory

    # No marker found: `directory` is now the top-most ancestor.
    return directory


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
