#!/usr/bin/env python

from __future__ import print_function

import threading
import argparse
import subprocess
import adversarial_vision_challenge
import foolbox
import numpy as np
import yaml
import time
import os
import sys
import socket

def _get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def test_attack(directory, no_cache, gpu, mode):
    """Build an attack submission image and smoke-test it on a mock model.

    The function
      1. removes a leftover ``avc_test_attack`` container and builds the
         image ``avc/test_attack_submission`` with ``crowdai-repo2docker``,
      2. (re)creates the local folders ``avc_images/`` and ``avc_results/``
         and writes one 64x64 sample image plus its label file,
      3. serves a ``MockModel`` through ``model_server`` on a free port in
         a daemon thread,
      4. starts the attack container with ``nvidia-docker`` and waits up
         to 20 seconds for ``avc_results/img0.npy`` to appear.

    Args:
        directory: Folder of the attack submission that is handed to
            ``crowdai-repo2docker``.
        no_cache: If true, a notice about disabling the build cache is
            printed.  The build command itself is currently the same in
            both cases (see the TODO in the printed message).
        gpu: GPU number, passed to the container via ``NV_GPU`` / ``GPU``.
        mode: ``'untargeted'`` selects the untargeted mock behaviour; any
            other value makes the mock model behave as in targeted mode.

    Raises:
        RuntimeError: If building the image fails, the attack container
            cannot be started, or no result file is written in time.
        AssertionError: If the model server does not answer with the
            expected banner.
        SystemExit: Raised through ``sys.exit()`` after a successful test.
    """
    # local import: shutil is not imported at module level
    import shutil

    print('##### START CHECK ######')
    print('Analyzing attack in folder "{}"'.format(directory))

    # remove old container if exists (best effort: a non-zero exit status
    # simply means that there was no container to remove)
    subprocess.Popen("docker rm -f avc_test_attack", shell=True).wait()

    # build container
    if no_cache:
        print('Not using cache for docker build. TODO NOT WORKING!')
    else:
        print('Using cache for docker build (if exists)')
    build_cmd = (
        "crowdai-repo2docker --no-run --image-name "
        "avc/test_attack_submission --debug {}".format(directory))
    # A failed build must abort the check; otherwise a stale image from an
    # earlier run could be started and reported as a successful test.
    if subprocess.Popen(build_cmd, shell=True).wait() != 0:
        raise RuntimeError(
            'Building the attack image from "{}" failed.'.format(directory))

    # build local temporary directories to read sample images and write adversarials
    for folder in ['avc_images/', 'avc_results/']:
        if os.path.exists(folder):
            # start from an empty folder so results of old runs cannot leak in
            shutil.rmtree(folder)
        os.makedirs(folder)

    # write a single source image into image directory
    image, _ = foolbox.utils.imagenet_example()
    image, label = image[:64, :64], 53
    np.save("avc_images/img0.npy", image)

    with open('avc_images/labels.yml', 'w') as outfile:
        yaml.dump({'img0.npy' : label}, outfile)

    # start mock model
    class MockModel(foolbox.models.Model):
        """Mock model whose answer depends only on the distance to `original`."""

        def __init__(self, original, label, mode):
            super(MockModel, self).__init__(bounds=(0, 255),
                                            channel_axis=3,
                                            preprocessing=(1, 1))
            self.original = original
            self.label = label
            self.untargeted = mode == 'untargeted'

        def predictions(self, image):
            # Untargeted: inputs close to the original keep the stored label,
            # everything else gets the "other" label (label + 1) % 200.
            # In any other mode the two answers are swapped.
            if np.linalg.norm(self.original - image) < 10:
                return self.label if self.untargeted else (self.label + 1) % 200
            else:
                return (self.label + 1) % 200 if self.untargeted else self.label

        def batch_predictions(self, batch):
            # only the first element of the batch is evaluated
            label = self.predictions(batch[0])
            return [label]

        def num_classes(self):
            return 200

    print('Creating a mock model')
    fmodel = MockModel(image, label, mode)

    print('Starting a model server')
    from adversarial_vision_challenge import model_server

    # get IP
    ip = 'localhost'
    port = _get_free_port()
    # the chosen port is published through the PORT environment variable
    os.environ["PORT"] = str(port)
    print('Using port: {}'.format(port))

    # daemon thread: the server must not keep the process alive at exit
    thread = threading.Thread(
        target=model_server,
        args=(fmodel,))
    thread.daemon = True
    thread.start()

    # wait until start of model
    print('Waiting for server to start.')
    cmd = 'wget -qO - --tries 30 --retry-connrefused http://{}:{}'.format(
        ip, port)
    p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
    # communicate() reads stdout until EOF, closes the pipe and reaps the child
    stdout, _ = p.communicate()
    response = stdout[:-1].decode('UTF-8')
    assert response == "NIPS 2018 Adversarial Vision Challenge Model Server", \
        'Unexpected answer from model server: {!r}'.format(response)

    # start attack container
    print('Starting attack container')
    hostpath = os.path.abspath('.')
    imagepath = os.path.join(hostpath, 'avc_images')
    resultpath = os.path.join(hostpath, 'avc_results')
    print(imagepath, resultpath)
    run_cmd = (
        "NV_GPU={gpu} nvidia-docker run  -d --net=host "
        "-e GPU={gpu} "
        "-v {imagepath}:/images/ -v {resultpath}:/results/ "
        "-e MODEL_SERVER=localhost "
        "-e MODEL_PORT={port} "
        "-e INPUT_IMG_PATH=/images "
        "-e INPUT_YML_PATH=/images/labels.yml "
        "-e OUTPUT_ADVERSARIAL_PATH=/results "
        "--name=avc_test_attack avc/test_attack_submission bash run.sh".format(
            gpu=gpu, port=port, imagepath=imagepath, resultpath=resultpath))
    # fail right away instead of waiting for the timeout below when the
    # container could not even be started
    if subprocess.Popen(run_cmd, shell=True).wait() != 0:
        raise RuntimeError('Starting the attack container failed.')

    # monitor results folder to check if adversarial is written
    result_file = 'avc_results/img0.npy'
    time_limit = 20  # seconds
    start_time = time.time()

    while not os.path.isfile(result_file) and time.time() - start_time < time_limit:
        time.sleep(1)

    # catch timeout
    if not os.path.isfile(result_file):
        raise RuntimeError('Results file not written with time limit (20 seconds)- something went wrong!')

    print('')
    print('Test successful')
    print('')
    sys.exit()


if __name__ == '__main__':
    # Command line entry point: parse the options and run the attack check.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "directory", help="The directory containing the Dockerfile.")
    parser.add_argument(
        "--no-cache", action='store_true',
        help="Disables the cache when building the attack image.")
    # Restrict the accepted values: test_attack() treats every value other
    # than 'untargeted' as targeted, so a typo would otherwise silently
    # switch the mode.
    parser.add_argument(
        "--mode", default='untargeted', choices=['untargeted', 'targeted'],
        help="Mode can be targeted or untargeted.")
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU number to run container on")
    args = parser.parse_args()
    test_attack(args.directory, no_cache=args.no_cache, gpu=args.gpu, mode=args.mode)
