#!/usr/bin/env python

from __future__ import print_function

import threading
import argparse
import subprocess
import adversarial_vision_challenge
import foolbox
import numpy as np
import yaml
import time
import os
import sys
import socket
from tqdm import tqdm

def checkmark():
    """Finish the current progress line by printing a space and a check mark."""
    mark = u'\u2713'
    print(u' ' + mark)

def _get_free_port():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', 0))
    port = s.getsockname()[1]
    s.close()
    return port

def _build_submission_image(image_name, directory, no_cache):
    """Build the submission's docker image from *directory* with crowdai-repo2docker.

    Raises:
        AssertionError: if *no_cache* is set; repo2docker cannot build
            without its cache, so the option is not supported yet.
        subprocess.CalledProcessError: if the build command fails.
    """
    if no_cache:
        print('Building image (without cache)...')
        raise AssertionError('Option does not yet exist because repo2docker does not allow us to build without cache.')

    print('Building image from cache (if exists)...', end="")
    # Argument list instead of a shell string: directories with spaces work.
    subprocess.check_call(
        ['crowdai-repo2docker', '--no-run', '--image-name', image_name,
         '--debug', directory])
    checkmark()


def _docker_container_names(all_containers=False):
    """Return the exact names of docker containers.

    Only running containers are listed unless *all_containers* is true
    (adds ``docker ps -a``).
    """
    cmd = ['docker', 'ps']
    if all_containers:
        cmd.append('-a')
    cmd += ['--format', '{{.Names}}']
    return subprocess.check_output(cmd).decode('utf-8').split()


def _reset_folders(folders):
    """Delete each folder in *folders* if it exists, then recreate it empty."""
    import shutil

    for folder in folders:
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)


def _write_test_samples(test_samples, images_dir):
    """Save each (image, label) sample as ``img<k>.npy`` plus a ``labels.yml`` index."""
    labels = {}
    for k, (image, label) in enumerate(test_samples):
        name = 'img{}.npy'.format(k)
        np.save(os.path.join(images_dir, name), image)
        labels[name] = label

    with open(os.path.join(images_dir, 'labels.yml'), 'w') as outfile:
        yaml.dump(labels, outfile)


def _make_mock_model(test_samples, mode):
    """Return a nearest-neighbour mock foolbox model built over *test_samples*.

    The model finds the test sample closest to the queried image, comparing
    only the first 1000 flattened values. In 'untargeted' mode it returns that
    sample's label when the distance is below 50, and the label shifted by
    30 (mod 200) otherwise. Any other mode inverts the two cases. Every
    queried image increments ``calls`` exactly once.
    """

    class MockModel(foolbox.models.Model):

        def __init__(self, mode):
            super(MockModel, self).__init__(bounds=(0, 255),
                                            channel_axis=3,
                                            preprocessing=(1, 1))
            self.samples = test_samples
            self.S = np.stack([sample[0] for sample in test_samples]).reshape((len(test_samples), -1))[:, :1000]
            self.untargeted = mode == 'untargeted'
            self.calls = 0

        def _label(self, image):
            # Label for a single image; does NOT touch the call counter.
            distances = np.linalg.norm(self.S - image.flatten()[:1000][None], axis=1)
            nearest = np.argmin(distances)
            true_label = self.samples[nearest][1]
            shifted_label = (true_label + 30) % 200
            close = distances[nearest] < 50
            if self.untargeted:
                return true_label if close else shifted_label
            return shifted_label if close else true_label

        def predictions(self, image):
            self.calls += 1
            return self._label(image)

        def batch_predictions(self, batch):
            # Count each image once; previously predictions() was called here
            # too, so batched queries were counted twice.
            self.calls += batch.shape[0]
            return [self._label(image) for image in batch]

        def num_classes(self):
            return 200

    return MockModel(mode)


def _wait_for_model_server(ip, port):
    """Block until the model server at ip:port answers with its banner.

    Raises:
        subprocess.CalledProcessError: if wget gives up connecting.
        RuntimeError: if the server answers with an unexpected response.
    """
    cmd = ['wget', '-qO', '-', '--tries', '30', '--retry-connrefused',
           'http://{}:{}'.format(ip, port)]
    # check_output waits for wget (no zombie) and raises on failure.
    response = subprocess.check_output(cmd).decode('UTF-8').strip()
    expected = "NIPS 2018 Adversarial Vision Challenge Model Server"
    if response != expected:
        raise RuntimeError('Unexpected response from model server: {!r}'.format(response))


def _start_attack_container(image_name, container_name, gpu, port, imagepath, resultpath):
    """Start the attack container detached, wired to the local model server.

    Raises:
        subprocess.CalledProcessError: if ``nvidia-docker run`` fails.
    """
    env = dict(os.environ, NV_GPU=str(gpu))
    subprocess.check_call(
        ['nvidia-docker', 'run', '-d', '--net=host',
         '-e', 'GPU={}'.format(gpu),
         '-v', '{}:/images/'.format(imagepath),
         '-v', '{}:/results/'.format(resultpath),
         '-e', 'MODEL_SERVER=localhost',
         '-e', 'MODEL_PORT={}'.format(port),
         '-e', 'INPUT_IMG_PATH=/images',
         '-e', 'INPUT_YML_PATH=/images/labels.yml',
         '-e', 'OUTPUT_ADVERSARIAL_PATH=/results',
         '--name={}'.format(container_name),
         image_name, 'bash', 'run.sh'],
        env=env)


def _monitor_results(container_name, num_samples, results_dir):
    """Poll *results_dir* once per second until all results arrive or the container stops.

    Returns:
        The number of files in *results_dir* after monitoring ends.

    Raises:
        RuntimeError: if no result appears within ~20 seconds, or the attack
            averages more than 10 seconds per sample after 21 seconds.
    """
    start_time = time.time()

    with tqdm(total=num_samples) as pbar:
        while True:
            time.sleep(1)
            num_results = len(os.listdir(results_dir))
            elapsed = time.time() - start_time

            if pbar.n < num_results:
                pbar.update(num_results - pbar.n)

            # The first branch covers num_results == 0 for any elapsed > 19,
            # so the division below never sees zero.
            if num_results == 0 and elapsed > 19:
                raise RuntimeError('Results file not written within time limit (20 seconds) - something went wrong!')
            elif elapsed > 21 and elapsed / float(num_results) > 10:
                raise RuntimeError('Your attack is too slow (> 10 seconds / sample)!')

            if num_results >= num_samples:
                break

            if container_name not in _docker_container_names():
                print("""Your container stopped running before all images were processed. 
                    This either means that the attack was not able to produce adversarials 
                    for all samples or that the attack stopped because of runtime errors.
                    """
                    )
                break

    # Re-list: the container may have written more files since the last poll.
    return len(os.listdir(results_dir))


def _report_distances(images_dir, results_dir):
    """Print how many results are farther than 50 (L2) from their original, and their median distance."""
    distances = []
    for name in os.listdir(results_dir):
        original = np.load(os.path.join(images_dir, name)).astype(np.float64)
        adversarial = np.load(os.path.join(results_dir, name)).astype(np.float64)
        distances.append(np.linalg.norm(original - adversarial))

    distances = np.array(distances)
    successful = distances[distances > 50]
    print('Number of adversarials {} of {}'.format(successful.size, distances.shape[0]))
    if successful.size:
        print('Median adversarial distances: {} (optimum = 50)'.format(np.median(successful)))
    else:
        # Avoid np.median on an empty array (nan + RuntimeWarning).
        print('Median adversarial distances: n/a (no adversarials found)')


def test_attack(directory, no_cache, gpu, mode, samples):
    """Build, run and check an attack submission against a local mock model.

    Steps: build the docker image from *directory*, write *samples* test
    images to ``avc_images/``, and serve a nearest-neighbour mock model on a
    free local port. Then start the attack container, wait for adversarials
    in ``avc_results/``, and validate the result count, the query budget
    (1000 per sample) and the distances.

    Args:
        directory: folder containing the submission (passed to repo2docker).
        no_cache: request a cache-less build (currently unsupported).
        gpu: GPU index exposed to the container via NV_GPU / GPU.
        mode: 'untargeted' or anything else (treated as targeted).
        samples: number of test samples to use.

    Raises:
        AssertionError: if *no_cache* is set.
        subprocess.CalledProcessError: if an external command fails.
        RuntimeError: if the server misbehaves, the attack is too slow,
            produces results for fewer than half the samples, or exceeds
            the query budget.
        SystemExit: raised via ``sys.exit()`` after all checks pass.
    """
    image_name = 'avc/attack_submission'
    container_name = 'avc_test_attack_submission'
    images_dir = 'avc_images/'
    results_dir = 'avc_results/'

    subprocess.check_call(['avc-test-setup'])

    print()
    print('Building docker image of submission')
    print('-----------------------------------')

    print('Submission folder: "{}"'.format(directory))
    _build_submission_image(image_name, directory, no_cache)

    # remove old container if exists (exact name match, not substring)
    if container_name in _docker_container_names(all_containers=True):
        print('Removing existing submission container...', end="")
        subprocess.check_call(['docker', 'rm', '-f', container_name])
        checkmark()

    print('Creating empty folders for test samples and results...', end="")
    _reset_folders([images_dir, results_dir])
    checkmark()

    print('Saving test samples into samples_folder (default: avc_images/)...', end="")
    test_samples = adversarial_vision_challenge.utils.get_test_data()[:samples]
    _write_test_samples(test_samples, images_dir)
    checkmark()

    print('Creating a mock model...', end="")
    fmodel = _make_mock_model(test_samples, mode)
    checkmark()

    from adversarial_vision_challenge import model_server

    ip = 'localhost'
    port = _get_free_port()
    os.environ["MODEL_PORT"] = str(port)
    print('Starting a model server at {}:{}...'.format(ip, port))

    thread = threading.Thread(target=model_server, args=(fmodel,))
    thread.daemon = True
    thread.start()

    print('Waiting for server to start...', end="")
    _wait_for_model_server(ip, port)
    checkmark()

    print('Starting attack container...', end="")
    hostpath = os.path.abspath('.')
    _start_attack_container(
        image_name, container_name, gpu, port,
        imagepath=os.path.join(hostpath, 'avc_images'),
        resultpath=os.path.join(hostpath, 'avc_results'))
    checkmark()

    print('Starting to test attack against mock model and test samples...')
    print("If you'd like to test with less or more samples, append e.g --samples 50 to your avc-test-XXX command.")
    num_results = _monitor_results(container_name, len(test_samples), results_dir)

    # Integer-safe "fewer than half" check (no floor division under Python 2).
    if 2 * num_results < len(test_samples):
        raise RuntimeError('The attack produced results for less than 50% of the samples ({}/{}).'.format(num_results, len(test_samples)))

    max_calls = len(test_samples) * 1000
    if fmodel.calls > max_calls:
        raise RuntimeError('Your attack queried the model {} times, which is too many (maximum allowed: {})!'.format(fmodel.calls, max_calls))
    print('Your attack queried the model {} times (maximum allowed: {})'.format(fmodel.calls, max_calls))

    _report_distances(images_dir, results_dir)

    print('')
    print('All tests successful, have fun submitting!')
    print('')
    sys.exit()


if __name__ == '__main__':
    # Export the log file name through the environment
    # (presumably read by adversarial_vision_challenge code — TODO confirm).
    os.environ["LOG_FILE"] = 'avc.log'

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "directory", help="The directory containing the Dockerfile.")
    parser.add_argument(
        "--no-cache", action='store_true',
        help="Disables the cache when building the attack image.")
    # Restrict to known modes: test_attack treats any value other than
    # 'untargeted' as targeted, so a typo would silently switch modes.
    parser.add_argument(
        "--mode", default='untargeted', choices=['untargeted', 'targeted'],
        help="Mode can be targeted or untargeted.")
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU number to run container on")
    parser.add_argument(
        "--samples", type=int, default=100, help="Number of samples for testing.")
    args = parser.parse_args()
    test_attack(args.directory, no_cache=args.no_cache, gpu=args.gpu, mode=args.mode, samples=args.samples)