#!/usr/bin/env python3

__doc__ = f"""
Datalog (py)
============

An interactive datalog interpreter with commands and persistence

Commands
~~~~~~~~
  .help      (this message)
  .all       display all tuples
  .quit      to exit the REPL

To exit, use .quit or control-d (control-c only cancels the current input line)

The interpreter
~~~~~~~~~~~~~~~

The interpreter reads one line at a time from stdin.
Lines are either
 - definitions (ending in .),
 - queries (ending in ?)
 - retractions (ending in !)

A definition may contain arbitrarily many datalog tuples and rules.

   edge(a, b). edge(b, c).  % A pair of definitions
   ⇒ edge(a, b). % The REPL's response that it has been committed
   ⇒ edge(b, c).

A query may contain definitions, but they exist only for the duration of the query.

   edge(X, Y)? % A query which will enumerate all 2-edges
   ⇒ edge(a, b).
   ⇒ edge(b, c).

   edge(c, d). edge(X, Y)? % A query with a local tuple
   ⇒ edge(a, b).
   ⇒ edge(b, c).
   ⇒ edge(c, d).

A retraction may contain only one tuple or clause, which will be expunged.

   edge(a, b)!   % This tuple is in our dataset
   ⇒ edge(a, b)  % So deletion succeeds

   edge(a, b)!   % This tuple is no longer in our dataset
   ⇒ Ø           % So deletion fails

"""

import argparse
import logging
import sys

from datalog.debris import Timing
from datalog.evaluator import select
from datalog.reader import pr_str, read_command, read_dataset
from datalog.types import CachedDataset, Constant, Dataset, IndexedDataset, LVar, Rule

from prompt_toolkit import print_formatted_text, prompt, PromptSession
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.history import FileHistory
from prompt_toolkit.styles import Style
from yaspin import yaspin


# Style classes referenced by the REPL: "prompt" is used for the input prompt and "time" for the
# elapsed-time footer printed after each query.
_STYLE_RULES = {
    "": "",  # User input (default text).
    "prompt": "ansigreen",
    "time": "ansiyellow",
}

STYLE = Style.from_dict(_STYLE_RULES)


class InterpreterInterrupt(Exception):
  """Signal that the current prompt or evaluation should be abandoned.

  The REPL loop treats this the same way as a KeyboardInterrupt raised while prompting: the
  current input is dropped and a fresh prompt is shown.
  """


def print_(fmt, **kwargs):
  """Print styled text fragments via prompt_toolkit.

  `fmt` is wrapped in a `FormattedText` and handed to `print_formatted_text`; any keyword
  arguments (for example `style=`) are passed through untouched.
  """

  fragments = FormattedText(fmt)
  print_formatted_text(fragments, **kwargs)


def print_db(db):
  """Dump every tuple, then every rule, of `db` to stdout - one `⇒ ...` line per entry."""

  # Tuples are listed before rules; each accessor is only called once its turn comes.
  for accessor in (db.tuples, db.rules):
    for entry in accessor():
      print(f"⇒ {pr_str(entry)}")


def _load_datasets(db, paths):
  """Merge each datalog file named in `paths` into `db` and return the merged database.

  A file which cannot be opened, read or parsed is reported and skipped so that one bad file
  does not prevent the REPL from starting. `paths` may be None or empty.
  """

  for db_file in paths or []:
    try:
      with open(db_file, "r") as f:
        db = db.merge(read_dataset(f.read()))
    except Exception as e:
      print(f"Internal error - {e}")
      print(f"Unable to load db {db_file}, skipping")
    else:
      print(f"Loaded {db_file} ...")

  return db


def _check_definitions(val):
  """Raise ValueError if the dataset `val` breaks rules the parser doesn't impose.

  Rules may not have free variables and tuples may not contain lvars. Explicit raises are used
  instead of `assert` so that the checks survive `python -O`.
  """

  for rule in val.rules():
    if rule.free_vars:
      raise ValueError(f"Rule contains free variables {rule.free_vars!r}")

  for tup in val.tuples():
    if any(isinstance(e, LVar) for e in tup):
      raise ValueError(f"Tuples cannot contain lvars - {tup!r}")


def _set_log_level(level_name):
  """Apply the `logging` level called `level_name` to the "arrdem.datalog" logger.

  Both the logger and every handler currently attached to it are updated, so that records at
  the chosen level are neither dropped by the logger nor filtered by its handlers.

  Returns True on success, False if `level_name` does not name a numeric `logging` level.
  """

  level = getattr(logging, level_name, None)
  if not isinstance(level, int):
    return False

  log = logging.getLogger("arrdem.datalog")
  log.setLevel(level)
  for handler in log.handlers:
    handler.setLevel(level)
  return True


def _run_query(db, db_cls, val):
  """Evaluate the query `val` against `db` and print the matches and the elapsed time.

  In order to support ad-hoc rules (joins), a rule query is bolted on as an overlay over the
  existing database and its pattern is what actually gets selected. The overlay is a fresh
  database, so the ephemeral rule doesn't persist unless the user later `.` defines it.
  Unfortunately doing this merge does nuke caches.

  Control-c during evaluation abandons the query and returns to the caller.
  """

  qdb = db
  if isinstance(val, Rule):
    qdb = db.merge(db_cls([], [val]))
    val = val.pattern

  with yaspin(), Timing() as t:
    try:
      results = list(select(qdb, val))
    except KeyboardInterrupt:
      print(f"Evaluation aborted after {t}")
      return

  # Sort in place so the output order is stable. Only the matched tuples are used as the key;
  # the bindings which accompany them are not part of what gets printed.
  results.sort(key=lambda result: result[0])

  for _results, _bindings in results:
    _result = _results[0]  # select only selects one tuple at a time
    print(f"⇒ {pr_str(_result)}")

  # So we can report empty sets explicitly.
  if not results:
    print("⇒ Ø")

  print_([("class:time", f"Elapsed time - {t}")], style=STYLE)


def _retract(db, db_cls, val):
  """Return `db` with the tuple - or the rules whose pattern is - `val` removed.

  Prints the retracted value on success, or Ø (and returns `db` unchanged) when nothing in the
  database matches.
  """

  if val in db.tuples() or any(r.pattern == val for r in db.rules()):
    print(f"⇒ {pr_str(val)}")
    return db_cls([u for u in db.tuples() if u != val],
                  [r for r in db.rules() if r.pattern != val])

  print("⇒ Ø")
  return db


def main(args):
  """REPL entry point.

  `args` must provide `db_cls` (one of "simple", "cached" or "indexed") and `dbs` (an optional
  list of datalog files to load before the first prompt).

  Reads lines until `.quit` or end-of-input (control-d); control-c at the prompt just discards
  the current line. Each line is either a dot command (`.all`, `.dbg`, `.log LEVEL`, `.help`)
  or is parsed by `read_command` into a definition (`.`), a query (`?`) or a retraction (`!`).

  Raises ValueError if `args.db_cls` does not name a known dataset type.
  """

  if args.db_cls == "simple":
    db_cls = Dataset
  elif args.db_cls == "cached":
    db_cls = CachedDataset
  elif args.db_cls == "indexed":
    db_cls = IndexedDataset
  else:
    raise ValueError(f"Unknown dataset type {args.db_cls!r}")

  print(f"Using dataset type {db_cls}")

  session = PromptSession(history=FileHistory(".datalog.history"))
  db = _load_datasets(db_cls([], []), args.dbs)

  while True:
    try:
      line = session.prompt([("class:prompt", ">>> ")], style=STYLE)
    except (InterpreterInterrupt, KeyboardInterrupt):
      continue
    except EOFError:
      break

    val = None

    if line == ".all":
      op = ".all"
    elif line == ".dbg":
      op = ".dbg"
    elif line == ".quit":
      break

    elif line in {".help", "help", "?", "??", "???"}:
      print(__doc__)
      continue

    elif line.split(" ")[0] == ".log":
      op = ".log"

    else:
      try:
        op, val = read_command(line)
      except Exception:
        print("Got an unknown command or syntax error, can't tell which")
        continue

    # .all dumps the whole database
    if op == ".all":
      print_db(db)

    # .dbg drops to a debugger shell so you can poke at the instance objects (database).
    # Kept inline rather than in a helper so that the debugger frame has `db` in scope.
    elif op == ".dbg":
      import pdb
      pdb.set_trace()

    # .log sets the log level
    elif op == ".log":
      words = line.split()
      if len(words) < 2:
        print("Usage: .log LEVEL")
      else:
        level = words[1].upper()
        if not _set_log_level(level):
          print(f"Unknown log level {level}")

    # Definitions merge into the DB
    elif op == ".":
      # FIXME (arrdem 2019-06-15):
      #   Syntax rules the parser doesn't impose...
      try:
        _check_definitions(val)
      except Exception as e:
        print(f"Error: {e}")
        continue

      db = db.merge(val)
      print_db(val)

    # Queries execute - note that rules as queries have to be temporarily merged.
    elif op == "?":
      _run_query(db, db_cls, val)

    # Retractions try to delete, but may fail.
    elif op == "!":
      db = _retract(db, db_cls, val)


# Command line interface of the REPL.
parser = argparse.ArgumentParser()

# Select which dataset type to use
parser.add_argument(
    "--db-type",
    dest="db_cls",
    default="indexed",
    choices=["simple", "cached", "indexed"],
    help="Choose which DB to use (default indexed)",
)

# May be given several times; every occurrence appends one file name to `dbs`.
parser.add_argument(
    "--load-db",
    dest="dbs",
    action="append",
    help="Datalog files to load first.",
)

if __name__ == "__main__":
  args = parser.parse_args(sys.argv[1:])

  # Stream handler for the datalog logger, starting out at INFO.
  # NOTE(review): only the handler's level is set here; the logger's own level is left at its
  # default.
  ch = logging.StreamHandler()
  ch.setLevel(logging.INFO)
  formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
  ch.setFormatter(formatter)

  logger = logging.getLogger("arrdem.datalog")
  logger.addHandler(ch)

  main(args)
