#!python

from __future__ import print_function

import threading
import argparse
import subprocess
import adversarial_vision_challenge
import foolbox
import numpy as np
import yaml
import time
import os
import sys
import socket

def _get_free_port():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', 0))
    port = s.getsockname()[1]
    s.close()
    return port

def _build_attack_image(directory, no_cache):
    """Build the docker image ``avc/test_attack_submission`` from *directory*.

    NOTE: *no_cache* currently only selects which message is printed; the
    same ``crowdai-repo2docker`` command runs either way (no-cache building is
    still a TODO).
    """
    if no_cache:
        print('Not using cache for docker build. TODO NOT WORKING!')
    else:
        print('Using cache for docker build (if exists)')
    subprocess.Popen(
        "crowdai-repo2docker --no-run --image-name avc/test_attack_submission --debug {}".format(directory),
        shell=True).wait()


def _reset_folders(folders):
    """Delete each folder in *folders* if it exists, then recreate it empty."""
    import shutil

    for folder in folders:
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)


def _write_test_samples(test_samples):
    """Save each ``(image, label)`` sample as ``avc_images/img<k>.npy`` and
    write the ``filename -> label`` mapping to ``avc_images/labels.yml``."""
    labels = {}
    for k, (image, label) in enumerate(test_samples):
        filename = 'img{}.npy'.format(k)
        np.save('avc_images/' + filename, image)
        labels[filename] = label

    with open('avc_images/labels.yml', 'w') as outfile:
        yaml.dump(labels, outfile)


def _make_mock_model(test_samples, mode):
    """Return a nearest-neighbour mock foolbox model over *test_samples*.

    The class is defined inside this function so ``foolbox`` is only
    referenced at call time. The model counts every queried image in
    ``calls`` so the caller can enforce a query budget.
    """
    class MockModel(foolbox.models.Model):

        def __init__(self, mode):
            # NOTE(review): preprocessing=(1, 1) is passed straight to
            # foolbox; its exact meaning depends on the foolbox version.
            super(MockModel, self).__init__(bounds=(0, 255),
                                            channel_axis=3,
                                            preprocessing=(1, 1))
            self.samples = test_samples
            # Flattened samples; only the first 1000 values of each are
            # used for nearest-neighbour matching.
            self.S = np.stack([sample[0] for sample in test_samples]).reshape((len(test_samples), -1))[:, :1000]
            self.untargeted = mode == 'untargeted'
            # Number of images queried so far (single and batched).
            self.calls = 0

        def _predict(self, image):
            """Label for *image*; does not touch the call counter."""
            distances = np.linalg.norm(self.S - image.flatten()[:1000][None], axis=1)
            nearest = np.argmin(distances)
            label = self.samples[nearest][1]
            is_close = distances[nearest] < 50
            # untargeted: close -> true label, far -> shifted (wrong) label
            # targeted:   close -> shifted label, far -> true label
            if is_close == self.untargeted:
                return label
            return (label + 30) % 200

        def predictions(self, image):
            self.calls += 1
            return self._predict(image)

        def batch_predictions(self, batch):
            # Count each image exactly once; going through predictions()
            # here would count every batched image twice.
            self.calls += batch.shape[0]
            return [self._predict(image) for image in batch]

        def num_classes(self):
            return 200

    return MockModel(mode)


def _start_model_server(fmodel, ip='localhost'):
    """Serve *fmodel* on a free port in a daemon thread and wait until it answers.

    ``MODEL_PORT`` is exported before the server thread starts (presumably
    read by ``model_server`` - confirm). Returns the chosen port.

    Raises:
        RuntimeError: if the server does not answer with the expected banner.
    """
    from adversarial_vision_challenge import model_server

    port = _get_free_port()
    os.environ["MODEL_PORT"] = str(port)
    print('Using port: {}'.format(port))

    thread = threading.Thread(
        target=model_server,
        args=(fmodel,))
    thread.daemon = True
    thread.start()

    # wait until start of model
    print('Waiting for server to start.')
    cmd = 'wget -qO - --tries 30 --retry-connrefused http://{}:{}'.format(
        ip, port)
    p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
    # communicate() drains stdout and also reaps the wget process.
    response = p.communicate()[0][:-1].decode('UTF-8')
    expected = "NIPS 2018 Adversarial Vision Challenge Model Server"
    if response != expected:
        raise RuntimeError(
            'Model server did not start correctly (response: {!r})'.format(response))
    return port


def _run_attack_container(gpu, port):
    """Start the ``avc_test_attack`` container detached, mounting
    ``./avc_images`` and ``./avc_results`` and pointing it at the model server."""
    hostpath = os.path.abspath('.')
    imagepath = os.path.join(hostpath, 'avc_images')
    resultpath = os.path.join(hostpath, 'avc_results')
    print(imagepath, resultpath)
    subprocess.Popen(
        "NV_GPU={gpu} nvidia-docker run  -d --net=host "
        "-e GPU={gpu} "
        "-v {imagepath}:/images/ -v {resultpath}:/results/ "
        "-e MODEL_SERVER=localhost "
        "-e MODEL_PORT={port} "
        "-e INPUT_IMG_PATH=/images "
        "-e INPUT_YML_PATH=/images/labels.yml "
        "-e OUTPUT_ADVERSARIAL_PATH=/results "
        "--name=avc_test_attack avc/test_attack_submission bash run.sh".format(gpu=gpu,
        port=port, imagepath=imagepath, resultpath=resultpath), shell=True).wait()


def _monitor_results(num_samples):
    """Poll ``avc_results/`` every 5 s until *num_samples* files exist or the
    ``avc_test_attack`` container stops; return the final directory listing.

    Raises:
        RuntimeError: if nothing is written within ~20 s, or if after 21 s the
            attack averages more than 10 s per written sample.
    """
    start_time = time.time()

    while True:
        time.sleep(5)
        result_files = os.listdir('avc_results/')
        # Read the clock once per iteration so all checks agree. Because an
        # empty folder past 19 s raises first, the division below never
        # sees zero files.
        elapsed = time.time() - start_time
        print('{} result files written after {} seconds.'.format(len(result_files), int(elapsed)))

        if len(result_files) == 0 and elapsed > 19:
            raise RuntimeError('Results file not written with time limit (20 seconds)- something went wrong!')
        elif elapsed > 21 and elapsed / float(len(result_files)) > 10:
            raise RuntimeError('Your attack is too slow (> 10 seconds / sample)!')

        if len(result_files) == num_samples:
            break

        if 'avc_test_attack' not in str(subprocess.check_output('docker ps', shell=True)):
            print("""Your container stopped running before all images were processed. 
                This either means that the attack was not able to produce adversarials 
                for all samples or that the attack stopped because of runtime errors.
                """
                )
            break

    # Re-list: files may have been written between the last listdir() and
    # the `docker ps` check above.
    return os.listdir('avc_results/')


def _check_results(fmodel, result_files, num_samples):
    """Validate the result count and query budget, then report distances.

    Raises:
        RuntimeError: if fewer than half the samples have results, or the
            model was queried ``num_samples * 1000`` times or more.
    """
    if len(result_files) < num_samples / 2:
        raise RuntimeError('The attack produced results for less than 50% of the samples ({}/{}).'.format(len(result_files), num_samples))

    # check that the number of calls is below maximum
    max_calls = num_samples * 1000
    if fmodel.calls < max_calls:
        print('Your attack queried the model {} times (maximum allowed: {})'.format(fmodel.calls, max_calls))
    else:
        raise RuntimeError('Your attack queried the model {} too many times (maximum allowed: {})!'.format(fmodel.calls, max_calls))

    # Report L2 distances between each original and its result. A result
    # counts as adversarial when the distance exceeds 50 - the same threshold
    # the mock model uses (though the mock only compares the first 1000
    # flattened values).
    distances = []
    for filename in result_files:
        original = np.load('avc_images/{}'.format(filename)).astype(np.float64)
        adversarial = np.load('avc_results/{}'.format(filename)).astype(np.float64)
        distances.append(np.linalg.norm(original - adversarial))

    distances = np.array(distances)
    is_adversarial = distances > 50
    print('Number of adversarials {} of {}'.format(is_adversarial.sum(), distances.shape[0]))
    if is_adversarial.any():
        print('Median adversarial distances: {} (optimum = 50)'.format(np.median(distances[is_adversarial])))
    else:
        # Avoid np.median on an empty array (nan + RuntimeWarning).
        print('Median adversarial distances: n/a (no adversarials found)')


def test_attack(directory, no_cache, gpu, mode):
    """Run an end-to-end local check of an attack submission.

    Builds the attack image from *directory*, writes up to 100 test samples to
    ``avc_images/``, serves a nearest-neighbour mock model, runs the attack
    container against it and validates what it writes to ``avc_results/``.

    Args:
        directory: Path handed to ``crowdai-repo2docker`` to build the image.
        no_cache: Currently only changes the printed build message.
        gpu: GPU index forwarded as ``NV_GPU`` and ``GPU`` to the container.
        mode: ``'untargeted'`` selects the untargeted mock behaviour; any
            other value gives the targeted behaviour.

    Raises:
        RuntimeError: on a failed server start, missing or slow results,
            results for fewer than 50% of samples, or too many model queries.

    Calls ``sys.exit()`` when every check passes.
    """
    print('##### START CHECK ######')
    print('performing basic setup tests')
    subprocess.Popen('avc-test-setup', shell=True).wait()

    print('Analyzing attack in folder "{}"'.format(directory))

    # remove old container if exists
    subprocess.Popen("docker rm -f avc_test_attack", shell=True).wait()

    _build_attack_image(directory, no_cache)

    # local temporary directories to read sample images and write adversarials
    _reset_folders(['avc_images/', 'avc_results/'])

    test_samples = adversarial_vision_challenge.utils.get_test_data()[:100]
    _write_test_samples(test_samples)

    print('Creating a mock model')
    fmodel = _make_mock_model(test_samples, mode)

    print('Starting a model server')
    port = _start_model_server(fmodel)

    print('Starting attack container')
    _run_attack_container(gpu, port)

    result_files = _monitor_results(len(test_samples))
    _check_results(fmodel, result_files, len(test_samples))

    print('')
    print('Test successful')
    print('')
    sys.exit()


if __name__ == '__main__':
    # Log file location, exported via the environment (presumably consumed
    # by the adversarial_vision_challenge tooling - confirm).
    os.environ["LOG_FILE"] = 'avc.log'

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "directory", help="The directory containing the Dockerfile.")
    parser.add_argument(
        "--no-cache", action='store_true',
        help="Disables the cache when building the attack image.")
    # Restrict to the two supported modes: any other string would silently
    # run the targeted mock behaviour inside test_attack().
    parser.add_argument(
        "--mode", default='untargeted', choices=['targeted', 'untargeted'],
        help="Mode can be targeted or untargeted.")
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU number to run container on")
    args = parser.parse_args()
    test_attack(args.directory, no_cache=args.no_cache, gpu=args.gpu, mode=args.mode)