#!python

"""AttaCut: Fast and Reasonably Accurate Word Tokenizer for Thai

Usage:
  attacut-cli <src> [--dest=<dest>] [--model=<model>] [--num-cores=<num-cores>] [--batch-size=<batch-size>] [--gpu]
  attacut-cli [-v | --version]
  attacut-cli [-h | --help]

Arguments:
  <src>             Path to input text file to be tokenized

Options:
  -h --help         Show this screen.
  --model=<model>   Model to be used [default: attacut-sc].
  --dest=<dest>     If not specified, it'll be <src>-tokenized-by-<model>.txt
  -v --version      Show version
  --num-cores=<num-cores>  Use multiple-core processing [default: 0]
  --batch-size=<batch-size>  Batch size [default: 20]
"""

import sys
import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


from attacut import Tokenizer, __version__, preprocessing, utils, models
from docopt import docopt

# from https://github.com/pytorch/pytorch/issues/1494#issuecomment-305993854
from multiprocessing import set_start_method
# Select the "spawn" start method for multiprocessing (see the PyTorch issue
# linked above the import).
try:
    set_start_method('spawn')
except RuntimeError:
    # set_start_method raises RuntimeError when a start method has already
    # been fixed for this process; keep whatever is in place in that case.
    pass

# Delimiter written between the predicted words on each output line.
SEP = "|"

def get_argument(dict, name, default):
    """Look up ``name`` in ``dict``, falling back to ``default``.

    The fallback is used when the key is missing or when its value is
    ``None``; any other value (including falsy ones such as ``0`` or ``""``)
    is returned unchanged.
    """
    value = dict.get(name)
    if value is None:
        return default
    return value

class AttaCutCLIDataset(Dataset):
    """Dataset that eagerly featurizes every line of a text file.

    Each item has the form ``(((x, seq), labels), tokens)``, where ``tokens``
    and the features come from ``tokenizer.dataset.make_feature`` and the
    remaining values from ``tokenizer.dataset.prepare_model_inputs``.

    Args:
        src: Path of the text file to read, one sample per line.
        tokenizer: Object exposing ``dataset.make_feature`` and
            ``dataset.prepare_model_inputs``.
        device: Stored on the instance; not used while building the items.
    """

    def __init__(self, src, tokenizer, device):
        self.src = src
        # Kept as a public attribute for existing callers; __len__ no longer
        # relies on it (see below).
        self.total_lines = utils.wc_l(self.src)

        self.tokenizer = tokenizer

        self.device = device

        self.data = []
        # NOTE(review): the file is opened with the platform default encoding;
        # confirm whether UTF-8 should be forced for Thai input.
        with open(self.src, "r") as fin:
            for txt in fin:
                txt = preprocessing.TRAILING_SPACE_RX.sub("", txt)

                tokens, features = self.tokenizer.dataset.make_feature(txt)

                # prepare_model_inputs takes (features, labels); there are no
                # labels here, so zeros are passed as a placeholder.
                # presumably features[1] is the sequence length - TODO confirm
                inputs = (features, torch.zeros(features[1]))

                (x, seq), labels, _ = self.tokenizer.dataset.prepare_model_inputs(
                    inputs,
                )

                # Drop all size-1 dimensions from the feature tensor.
                x = torch.squeeze(x)

                self.data.append((((x, seq), labels), tokens))

    def __len__(self):
        # Report the number of items actually loaded, so that every index the
        # DataLoader asks for is valid for __getitem__ and no loaded line is
        # skipped, even if the separate line count disagrees.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(tokenizer, batch, device):
    """Collate dataset items into one model batch plus their tokens.

    ``batch`` is a sequence of ``(model_input, tokens)`` pairs. The model
    inputs are handed to ``tokenizer.dataset.collate_fn``; the ``x`` and
    ``seq`` values it returns are moved to ``device``, while the labels and
    the permutation index are passed through untouched.

    Returns:
        ``(((x, seq), labels, perm_idx), tokens)`` where ``tokens`` keeps the
        order of ``batch``.
    """
    pairs = [(model_input, token_seq) for model_input, token_seq in batch]
    inputs = [pair[0] for pair in pairs]
    tokens = [pair[1] for pair in pairs]

    (x, seq), labels, perm_idx = tokenizer.dataset.collate_fn(inputs)

    on_device = (x.to(device), seq.to(device))
    return (on_device, labels, perm_idx), tokens


if __name__ == "__main__":
    # CLI entry point: parse the arguments, featurize the input file, run the
    # model batch by batch and write one SEP-joined line of words per input
    # line to the destination file.
    arguments = docopt(__doc__, version=f"AttaCut: version {__version__}")

    src = arguments["<src>"]
    model = arguments["--model"]
    num_cores = int(arguments["--num-cores"])
    batch_size = int(arguments["--batch-size"])

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if num_cores < 0:
        sys.exit(
            "Input given to --num-cores should be greater than or equal to zero"
        )

    if not src:
        print(__doc__)
        sys.exit(0)

    # for a custom model, use the last dir's name.
    # Trailing slashes are dropped first so "path/to/model/" yields "model"
    # rather than an empty name.
    model_name = model.rstrip("/").split("/")[-1]

    dest = get_argument(
        arguments,
        "--dest",
        utils.add_suffix_to_file_path(
            src,
            f"tokenized-by-{model_name}"
        )
    )

    print(f"Tokenizing {src}")
    print(f"Using {model}")
    print(f"Output: {dest}")

    tokenizer = Tokenizer(model)

    total_lines = utils.wc_l(src)

    device = "cuda" if arguments["--gpu"] else "cpu"

    # NOTE: num_cores is only reported here; the DataLoader below always runs
    # with num_workers=0.
    if num_cores == 0:
        print(f"Use main process processing for {total_lines} lines")
    else:
        print(f"Use {num_cores} cores for processing for {total_lines} lines")

    print(f"device={device}")

    ds = AttaCutCLIDataset(src, tokenizer, device)
    dataloader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,  # output lines must follow the input order
        num_workers=0,  # only use main process
        collate_fn=lambda batch: collate_fn(tokenizer, batch, device)
    )

    tokenizer.model.to(device)

    with torch.no_grad(), \
            tqdm(total=total_lines) as tq, \
            open(dest, "w") as fout:
        for batch in dataloader:
            (x, labels, perm_idx), tokens = batch
            probs = tokenizer.model(x)

            # Threshold the model output at 0.5 to get boolean predictions.
            preds = probs.cpu().detach().numpy() > 0.5
            # x is (features, seq); the largest value in seq is used as the
            # row width when reshaping the predictions.
            max_seq = int(x[1].max().item())

            perm_idx = perm_idx.cpu().detach().numpy()
            preds = preds.reshape((-1, max_seq))

            # tokens is in the original batch order while preds follows the
            # order given by perm_idx; argsort(perm_idx) maps each original
            # position to its row in preds.
            for ori_ix, after_sorting_ix in enumerate(np.argsort(perm_idx)):
                pred = preds[after_sorting_ix, :]
                token = tokens[ori_ix]

                words = preprocessing.find_words_from_preds(token, pred)
                fout.write(f"{SEP.join(words)}\n")

            tq.update(n=preds.shape[0])