#!python

import argparse
import subprocess
import adversarial_vision_challenge
import numpy as np
import os
import sys
import socket
from tqdm import tqdm
from adversarial_vision_challenge.common import check_track
from adversarial_vision_challenge.utils import _wait_for_server_start

def checkmark():
    """Print a space followed by a check mark to flag a finished step."""
    tick = u'\N{CHECK MARK}'
    print(u' ' + tick)


def _get_free_port():
    """Return a TCP port number picked by the operating system.

    A temporary IPv4 stream socket is bound to port 0 on all interfaces,
    the port the OS assigned is read back, and the socket is closed
    again when the ``with`` block ends. Nothing keeps the port reserved
    after this function returns.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(('', 0))
        _host, assigned_port = probe.getsockname()
    return assigned_port


def test_model(directory, no_cache, gpu):
    """Build a model submission as a docker image, start it and test it.

    The function performs the following steps in order:

    1. Checks the submission track via ``check_track`` and runs
       ``avc-test-setup``.
    2. Builds the docker image with ``crowdai-repo2docker``.
    3. Removes a previously created container of the same name, if any.
    4. Starts the container with ``nvidia-docker`` on a free port and
       attaches to it.
    5. Connects a ``TinyImageNetBSONModel`` client to the container.
    6. Checks channel axis, bounds and the type of a prediction.
    7. Measures the top-1 accuracy on the test samples.

    Args:
        directory: Submission folder handed to ``check_track`` and
            ``crowdai-repo2docker``.
        no_cache: If true, the function aborts because building without
            cache is not supported.
        gpu: GPU number; passed to the container as ``NV_GPU`` / ``GPU``.

    Raises:
        AssertionError: If ``no_cache`` is set, if one of the model
            checks fails, if there are no test samples, or if the
            accuracy is below 50%.
        subprocess.CalledProcessError: If one of the checked external
            commands exits with a non-zero status.

    On success the function does not return; it ends the interpreter
    with ``sys.exit()``.
    """
    check_track(directory, 'nips-2018-avc-robust-model-track')

    image_name = 'avc/model_submission'
    container_name = 'avc_test_model_submission'

    subprocess.check_call('avc-test-setup', shell=True)

    print()
    print('Building docker image of submission')
    print('-----------------------------------')

    print('Submission folder: "{}"'.format(directory))

    # build image
    if no_cache:
        print('Building image (without cache)...')
        raise AssertionError(
            'Option does not yet exist because repo2docker does'
            ' not allow us to build without cache.')

    print('Building image from cache (if exists)...', end="")
    # Pass the arguments as a list (no shell) so that a submission path
    # containing spaces or shell metacharacters reaches repo2docker intact.
    subprocess.check_call([
        'crowdai-repo2docker', '--no-run',
        '--image-name', image_name,
        '--debug', directory,
    ])
    checkmark()

    # remove old container if exists
    containers = subprocess.check_output(
        'docker ps -a', shell=True).decode('UTF-8', 'replace')
    if container_name in containers:
        print('Removing existing submission container...', end="")
        cmd = "docker rm -f {cn}".format(cn=container_name)
        subprocess.check_call(cmd, shell=True)
        checkmark()

    # start model container
    port = _get_free_port()
    cmd = "NV_GPU={gpu} nvidia-docker run --expose={port} -d -p {port}:{port} "
    cmd += "-e GPU={gpu} -e MODEL_PORT={port} --name={cn} {im} bash run.sh"
    cmd = cmd.format(port=port, gpu=gpu, cn=container_name, im=image_name)
    print('Starting container on port {} using the command \n \n {}'.format(
        port, cmd))
    subprocess.check_call(cmd, shell=True)

    # attach to container to print output in case of failures
    # NOTE(review): stderr of the attach process is piped but never read
    # in this function -- confirm that this is intended.
    cmd = "docker attach {}".format(container_name)
    dump = subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE)  # noqa: F841

    # get IP
    # check_output drains the pipe while waiting and raises if
    # `docker inspect` fails, instead of silently yielding an empty IP.
    form = '{{ .NetworkSettings.IPAddress }}'
    cmd = "docker inspect --format='{}' {}".format(form, container_name)
    ip = subprocess.check_output(cmd, shell=True).decode('UTF-8').strip()
    print('Received IP of server: ', ip)

    # start client model
    print('Connecting to model sever...', end="")
    url = 'http://{}:{}'.format(ip, port)
    model = adversarial_vision_challenge.TinyImageNetBSONModel(url)
    _wait_for_server_start(model)
    checkmark()

    print('Server version', model.server_version())
    # get predictions and/or gradient for model
    print('Checking model axis and bounds...', end="")
    # Explicit raises instead of `assert` so the checks are not stripped
    # when Python runs with -O.
    channel_axis = model.channel_axis()
    if channel_axis != 3:
        raise AssertionError(
            "model channel axis should be 3, but is: %s" % channel_axis)

    bounds = model.bounds()
    if bounds != (0, 255):
        # Wrap in a 1-tuple: a bare tuple on the right of % would be
        # interpreted as the list of format arguments.
        raise AssertionError(
            "model bounds should be (0,255), but got: %s" % (bounds,))
    checkmark()

    # test predictions
    print('Checking return value...', end="")
    image = np.random.uniform(size=(64, 64, 3)).astype(np.float32) * 255

    # Query the model once and reuse the result for the error message.
    prediction = model(image)
    if not isinstance(prediction, int):
        raise AssertionError(
            "prediction should be an int-value, but got: %s"
            % type(prediction))
    checkmark()

    # test prediction performance
    print('Testing model accuracy on test samples...')
    test_samples = adversarial_vision_challenge.utils.get_test_data()
    num_samples = len(test_samples)
    if num_samples == 0:
        # Avoid a ZeroDivisionError below and report the real problem.
        raise AssertionError(
            'No test samples available to evaluate the model.')

    correct = 0
    for image, label in tqdm(test_samples):
        if model(image) == label:
            correct += 1

    print('The top-1 performance of your model is {}%.'.format(
        100 * correct / float(num_samples)))

    if correct / float(num_samples) < 0.5:
        raise AssertionError(
            'The performance of your model is too low (< 50%)! Please check'
            ' whether your preprocessing is correctly implemented.')

    print('')
    print('All tests successful, have fun submitting!')
    print('')
    sys.exit()


if __name__ == '__main__':
    # Announce which script is running before parsing any arguments.
    script_name = os.path.basename(__file__)
    print('running {}'.format(script_name))

    # Command line: one positional submission directory plus two options.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "directory",
        help="The directory containing the Dockerfile.")
    cli.add_argument(
        "--no-cache",
        action='store_true',
        help="Disables the cache when building the model image.")
    cli.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU number to run container on")
    options = cli.parse_args()

    test_model(
        options.directory,
        no_cache=options.no_cache,
        gpu=options.gpu,
    )
