#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Executable that launches one of the theanolm subcommands.
"""

from traceback import format_tb
import logging
import sys
import argparse

from theanolm.backend import NumberError, TheanoConfigurationError
from theanolm.backend import IncompatibleStateError
from theanolm.commands import train, score, decode, sample, version

def _get_message(e):
    if hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], bytes):
        return e.args[0].decode('utf-8')
    return str(e)

# The gpuarray exception classes can be imported only when the gpuarray
# backend is installed. Otherwise placeholder classes are defined, so that the
# names can still be used in except clauses.
try:
    from pygpu.gpuarray import GpuArrayException
    from theano.gpuarray.type import ContextNotDefined
except ImportError:
    class GpuArrayException(Exception):
        """Placeholder for pygpu.gpuarray.GpuArrayException, used when the
        gpuarray backend cannot be imported.
        """

    class ContextNotDefined(Exception):
        """Placeholder for theano.gpuarray.type.ContextNotDefined, used when
        the gpuarray backend cannot be imported.
        """

def exception_handler(exception_type, exception, traceback):
    """Reports an uncaught exception and terminates the program.

    The function has the signature of ``sys.excepthook``. A short description
    of the exception is printed to standard output, the stack trace is written
    to the debug log one frame per log record, and the process exits with
    status 2.

    :type exception_type: type
    :param exception_type: class of the uncaught exception

    :type exception: BaseException
    :param exception: the uncaught exception instance

    :type traceback: traceback
    :param traceback: traceback object of the uncaught exception
    """

    type_name = exception_type.__name__
    message = _get_message(exception)
    summary = ("An unexpected {} exception occurred. Traceback will be "
               "written to debug log. The error message was: {}")
    print(summary.format(type_name, message))

    logging.debug("Traceback:")
    for frame_text in format_tb(traceback):
        # Each entry ends with a newline, which is dropped before logging.
        logging.debug(frame_text.strip())

    sys.exit(2)

def main():
    """Entry point of the theanolm command.

    Builds the command line parser with one subparser per command, parses the
    arguments, and runs the function of the selected command. Without a
    command, the help text is printed. The exception types handled here are
    reported with an explanatory message, after which the process exits with
    status 2. A keyboard interrupt is only reported.
    """

    parser = argparse.ArgumentParser(prog='theanolm')
    subparsers = parser.add_subparsers(
        title='commands',
        help='selects the command to perform ("theanolm command --help" '
             'displays help for the specific command)')

    # Name, help text, function that registers the command specific arguments
    # (None if the command takes none), and function that runs the command.
    # The commands are registered in this order.
    commands = (
        ('train', 'train a model',
         train.add_arguments, train.train),
        ('score', 'score text or n-best lists using a model',
         score.add_arguments, score.score),
        ('decode', 'decode a word lattice using a model',
         decode.add_arguments, decode.decode),
        ('sample', 'generate text using a model',
         sample.add_arguments, sample.sample),
        ('version', 'display the version number',
         None, version.version),
    )
    for name, description, add_arguments, function in commands:
        command_parser = subparsers.add_parser(name, help=description)
        if add_arguments is not None:
            add_arguments(command_parser)
        command_parser.set_defaults(command_function=function)

    args = parser.parse_args()

    # No command was selected on the command line.
    if not hasattr(args, 'command_function'):
        parser.print_help()
        return

    sys.excepthook = exception_handler
    try:
        # Deep networks need a higher recursion limit than the default.
        sys.setrecursionlimit(100000)

        args.command_function(args)
    except FileNotFoundError as e:
        print('Could not open one of the required files. The error message '
              'was:',
              _get_message(e))
        sys.exit(2)
    except NumberError as e:
        print('A numerical error has occurred. This may happen e.g. when '
              'network parameters go to infinity. If this happens during '
              'training, using a smaller learning rate usually helps. '
              'Another possibility is to use gradient normalization. If '
              'this happens during inference, there is probably something '
              'wrong in the model. The error message was:',
              _get_message(e))
        sys.exit(2)
    except ContextNotDefined as e:
        print('Theano returned error "{}". You need to map the device '
              'contexts that are used in the architecture file to CUDA '
              'devices '
              '(e.g. THEANO_FLAGS=contexts=dev0->cuda0;dev1->cuda1).'
              .format(_get_message(e)))
        sys.exit(2)
    except TheanoConfigurationError as e:
        print('There is a problem with Theano configuration. Please check '
              'THEANO_FLAGS environment variable and .theanorc file. The '
              'error message was:',
              _get_message(e))
        sys.exit(2)
    except GpuArrayException as e:
        message = _get_message(e)
        if message == 'out of memory':
            print('Theano could not fit all the required variables in the '
                  'GPU memory. Memory consumption can be reduced by using '
                  'a smaller vocabulary, smaller layers, smaller '
                  'mini-batch size, shorter sequence length, or dividing '
                  'the model over multiple GPUs. Traceback will be written '
                  'to debug log.')
        else:
            print('Theano returned a gpuarray error "{}". Traceback will '
                  'be written to debug log.'
                  .format(message))
        # The traceback is logged inside the handler, where the exception
        # information is still available to the logging module.
        logging.debug(e, exc_info=True)
        sys.exit(2)
    except IncompatibleStateError as e:
        print('The neural network state that you tried to load is invalid. '
              'It may have been created with an old version of TheanoLM, '
              'or the file may be corrupt. The error message was:',
              _get_message(e))
        sys.exit(2)
    except KeyboardInterrupt:
        print('Interrupted from keyboard.')

# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
