#!/usr/bin/env python

from __future__ import print_function

import threading
import argparse
import subprocess
import adversarial_vision_challenge
import foolbox
import numpy as np
import yaml
import time
import os
import sys
import socket


def _get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def _docker_inspect(form, container_name):
    """Return the stripped output of ``docker inspect`` for one format string.

    Args:
        form: Template string handed to ``docker inspect --format``.
        container_name: Name of the container to inspect.

    Returns:
        The command's stdout decoded as UTF-8, with surrounding whitespace
        (including the trailing newline) removed.
    """
    cmd = "docker inspect --format='{}' {}".format(form, container_name)
    proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
    # communicate() drains the pipe while waiting for the process, so it
    # cannot stall on a full pipe buffer like wait() followed by read() can.
    out, _ = proc.communicate()
    return out.decode('UTF-8').strip()


def test_model(directory, no_cache, gpu):
    """Build the model image from *directory*, run it and check the server.

    The function removes any previous test container, builds the image with
    ``crowdai-repo2docker``, starts the container through ``nvidia-docker``,
    waits for the HTTP server inside it to answer and then queries it through
    ``adversarial_vision_challenge.TinyImageNetBSONModel``.

    Args:
        directory: Path handed to ``crowdai-repo2docker`` as the build source.
        no_cache: Flag intended to disable the docker build cache. Both
            settings currently run the same build command, so it only
            changes the message that is printed.
        gpu: Value exported to the container as ``NV_GPU`` / ``GPU``.

    Raises:
        AssertionError: If the channel axis is not 3, the bounds are not
            ``(0, 255)`` or a prediction is not an ``int``.
        IOError: If the top-1 accuracy on the test data is below 50%.
        SystemExit: With the container's exit code if the container stops
            while waiting for the server, or via a plain ``sys.exit()`` once
            all checks have passed.
    """
    print('##### START CHECK ######')
    print('Analyzing model in folder "{}"'.format(directory))

    # remove old container if exists
    container_name = 'avc_test_model_submission'
    subprocess.Popen("docker rm -f {cn}".format(cn=container_name), shell=True).wait()

    # build container
    # NOTE: the build command is the same with and without ``no_cache``;
    # disabling the cache is not implemented yet (see the printed TODO).
    if no_cache:
        print('Not using cache for docker build. TODO NOT WORKING!')
    else:
        print('Using cache for docker build (if exists)')
    build_cmd = "crowdai-repo2docker --no-run --image-name avc/model_submission --debug {}".format(
        directory)
    subprocess.Popen(build_cmd, shell=True).wait()

    # start model container
    print('##### RUNNING DOCKER CONTAINER ######')
    port = _get_free_port()
    print('using port: {}'.format(port))
    cmd = "NV_GPU={gpu} nvidia-docker run --expose={port} -d -p {port}:{port} "
    cmd += "-e GPU={gpu} -e PORT={port} --name={cn} avc/model_submission bash run.sh"
    cmd = cmd.format(port=port, gpu=gpu, cn=container_name)
    subprocess.Popen(cmd, shell=True).wait()

    # attach to container to print output in case of failures
    cmd = "docker attach {}".format(container_name)
    dump = subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE)

    # get IP
    ip = _docker_inspect('{{ .NetworkSettings.IPAddress }}', container_name)
    print('Received IP of server: ', ip)

    # wait until start of server
    print('Waiting for server to start.')
    probe_cmd = 'wget -qO - --tries 2 --retry-connrefused http://{}:{}'.format(ip, port)
    for _ in range(30):
        # check if the server is up
        if subprocess.Popen(probe_cmd, shell=True).wait() == 0:
            break

        # check if the container is running
        running = _docker_inspect('{{ .State.Running }}', container_name) == 'true'
        if not running:
            # print any errors in the code and exit with the container's
            # own exit code
            print(dump.stderr.read().decode('UTF-8'))
            exit_code = _docker_inspect('{{ .State.ExitCode }}', container_name)
            # sys.exit is always available; the bare ``exit`` helper is only
            # injected by the ``site`` module.
            sys.exit(int(exit_code))
    else:
        # Every probe failed but the container is still running: keep going
        # as before, but make the situation visible.
        print('WARNING: server did not answer any probe; trying to connect anyway.')

    # start client model
    print('##### CONNECTING TO THE RUNNING SERVER ######')
    model = adversarial_vision_challenge.TinyImageNetBSONModel('http://{}:{}'.format(ip, port))

    print('Server version', model.server_version())

    # get predictions and/or gradient for model
    print('Performing basic checks for the code.')
    channel_axis = model.channel_axis()
    bounds = model.bounds()

    # The values are wrapped in 1-tuples so that a tuple value (such as the
    # bounds) is formatted as a whole instead of being unpacked by ``%``.
    assert channel_axis == 3, "model channel axis should be 3, but is: %s" % (channel_axis,)
    assert bounds == (0, 255), "model bounds should be (0,255), but got: %s" % (bounds,)

    # test predictions
    image = np.random.uniform(size=(64, 64, 3)).astype(np.float32) * 255

    # Query the server once and reuse the answer for the error message.
    prediction = model(image)
    assert isinstance(prediction, int), "prediction should be an int-value, but got: %s" % type(prediction)

    # test prediction performance
    print('### TESTING MODEL PERFORMANCE ON TEST DATA ###')
    test_samples = adversarial_vision_challenge.utils.get_test_data()
    num_samples = len(test_samples)
    correct = 0

    for k, (image, label) in enumerate(test_samples):
        # One request per sample; the result is both printed and scored.
        prediction = model(image)
        print(prediction, label)
        if prediction == label:
            correct += 1

        if k % 20 == 1:
            # k is zero-based, so k + 1 samples have been evaluated so far.
            print(correct / float(k + 1))

    print('The top-1 performance of your model is {}%.'.format(100 * correct / float(num_samples)))

    if correct / float(num_samples) < 0.5:
        raise IOError('The performance of your model is too low (< 50%)! Please check whether your preprocessing is correctly implemented.')

    print('')
    print('Test successful')
    print('')
    sys.exit()


if __name__ == '__main__':
    # Command-line entry point: read the options and run the model check.
    cli = argparse.ArgumentParser()
    cli.add_argument("directory",
                     help="The directory containing the Dockerfile.")
    cli.add_argument("--no-cache",
                     action='store_true',
                     help="Disables the cache when building the model image.")
    cli.add_argument("--gpu",
                     type=int,
                     default=0,
                     help="GPU number to run container on")
    options = cli.parse_args()
    test_model(options.directory, no_cache=options.no_cache, gpu=options.gpu)
