#!/usr/bin/env python
# -*- coding: utf-8 -*-
import flask
import requests
import redis

import argparse
import collections
import grp
import hashlib
import json
import logging
import logging.handlers
import os
import pkg_resources
import pwd
import random
import signal
import sys
import threading
import time

# Stop requests from spamming our logs
logging.getLogger("requests").setLevel(logging.WARNING)

from bundlemaker import BundleMaker

# Name of the cookie that carries the DDeflect session ID.
SESSION_COOKIE_NAME = "X-DDeflect-Session"
# Used as the max_age of the session cookie; the matching Redis session
# key is expired 5 later than this value (see DebundlerServer.root_route).
SESSION_COOKIE_LIFETIME = 60

# The settings module is imported eagerly so that code importing this file
# as a module gets a usable ``settings`` name. When run as a script the
# import is allowed to fail, because main() imports it again after setting
# BUNDLEMANAGER_CONFIG from the command line.
try:
    from bundlemaker import settings
except IOError:
    # IO Error means that we can't load the default settings file.
    if __name__ == "__main__":
        # We'll reload the settings file in main
        pass
    else:
        # We're being imported as a module, there is no hope of
        # recovery.
        raise

def mash_dict(input_dict):
    """Return all keys of ``input_dict`` joined together, followed by all
    of its values (each converted with ``str``) joined together."""
    joined_keys = "".join(input_dict.keys())
    joined_values = "".join(map(str, input_dict.values()))
    return joined_keys + joined_values


class DebundlerMaker(object):
    """Generates and holds the hex-encoded key, IV and HMAC key.

    ``key``, ``iv`` and ``hmac_key`` are each 16 random bytes from
    ``os.urandom`` stored as 32-character hex strings. ``refresher`` is a
    ``threading.Timer`` configured to call ``gen_keys`` after
    ``refresh_period``.
    """

    def __init__(self, refresh_period):
        """Generate the first set of keys and prepare the refresh timer.

        refresh_period -- interval handed to ``threading.Timer``.
        """
        self.refresh_period = refresh_period
        self.key = None
        self.iv = None
        self.hmac_key = None
        self.gen_keys()
        # Hand the timer the bound method itself. Writing
        # ``self.gen_keys()`` here regenerated the keys a second time
        # (logging the fresh ones) and gave the timer ``None`` to call.
        # NOTE(review): nothing in this class starts the timer, so keys
        # are not rotated automatically - confirm whether
        # ``self.refresher.start()`` is meant to be called somewhere.
        self.refresher = threading.Timer(self.refresh_period,
                                         self.gen_keys)

    @staticmethod
    def _random_hex(num_bytes):
        """Return ``num_bytes`` random bytes as a native hex string."""
        import binascii

        encoded = binascii.hexlify(os.urandom(num_bytes))
        # hexlify gives ``str`` on Python 2 (same as the old "hex" codec)
        # and ``bytes`` on Python 3; normalise to the native ``str``.
        if not isinstance(encoded, str):
            encoded = encoded.decode("ascii")
        return encoded

    def gen_keys(self):
        """Replace key, IV and HMAC key with freshly generated values.

        When a complete previous set exists it is logged at INFO level
        before being overwritten.
        """
        new_key = self._random_hex(16)
        new_iv = self._random_hex(16)
        new_hmac_key = self._random_hex(16)
        if self.key and self.iv and self.hmac_key:
            # Arguments follow the order named in the message; the stored
            # values are already hex so they are logged as they are.
            logging.info(("Rotating keys. Old key was %s, "
                          "old hmac key was %s and old IV was %s"),
                         self.key, self.hmac_key, self.iv)

        self.hmac_key = new_hmac_key
        self.key = new_key
        self.iv = new_iv

class VedgeManager(object):
    """Chooses the v-edge handed to each debundler page.

    ``vedge_data`` maps a v-edge name to a dict of its settings; only the
    optional ``"prefix"`` entry is read by the live code.
    """

    def __init__(self, vedge_data):
        self.vedge_data = vedge_data
        #self.redis = redis.Redis()
        self.vedge_threshold = 100

    # The string below is a disabled draft of Redis-backed v-edge tracking.
    # It is never executed.
    '''
    def populateRedisVEdges(self):
        """
        Set initial values for v-edges in redis
        storing time windows and bandwidth constraints
        """
        for edge in self.vedge_data:
            value = json.dumps({
                "start": edge['availability']['start'],
                "end": edge['availability']['end'],
                "total_bandiwdth": edge['total_bandwidth'],
                "used_bandwidth": 0
            })
            self.redis.sadd("vedges", edge.key())
            self.redis.set(edge.key(), value)

            now = datetime.utcnow()
            now_time = now.time()

            start = time.strptime(edge['availability']['start'], "%H:%M")
            end = time.strptime(edge['availability']['end'], "%H:%M")

            if now_time >= start and now_time <= end:
                expiration = now.replace(
                                    hour = end.hour,
                                    minute = end.minute
                            )
                active_key = edge.key() + '_active'
                self.redis.sadd("active_vedges", active_key)
                self.redis.set(active_key, value)
                self.redis.expire(active_key, expiration)

            #Build special set for quick lookup of time windows, represent as hash set

        # Create active set and copy store to it
        # with expiration for keys
        # key should be delted for bandwidth when that quantity is reached
        # in terms of bandwidth, this should be handled by the badnwidth
        # recorder
        '''
    def refresh_vedges(self):
        """
        Rebuild the v-edge list if the number of available v-edges
        has slipped below the predefined threshold. Not implemented yet.
        """
        pass

    def get_vedge(self):
        """
        Return a randomly chosen v-edge name, with the v-edge's
        ``prefix`` setting (if any) appended after a single "/".

        Selection by availability, last access time and bandwidth is
        planned but not implemented; see the disabled draft below.

        Raises ValueError when no v-edges are configured.
        """
        # pop first element in sorted list, reset timestamp
        """
        if self.redis.llen("active_vedges") < self.vedge_threshold:
            self.refresh_vedges()
        vedge = self.redis.srandmember("active_vedges")
        return vedge
        """

        if not self.vedge_data:
            raise ValueError("No v-edges are configured")

        # For now, just return a random v-edge
        chosen_vedge = random.choice(list(self.vedge_data))
        prefix = self.vedge_data[chosen_vedge].get("prefix", "")
        # Join name and prefix with exactly one slash. Previously a prefix
        # that already began with "/" was dropped altogether.
        trimmed_prefix = prefix.lstrip("/") if prefix else ""
        if trimmed_prefix:
            chosen_vedge = chosen_vedge + "/" + trimmed_prefix
        return chosen_vedge


class DebundlerServer(flask.Flask):
    """Flask application that hands out debundler pages and bundles.

    A request without a valid DDeflect session gets a debundler page
    (key material, a v-edge and a bundle signature) and a fresh session
    cookie. Requests under ``_bundle/`` return the stored bundle.
    Requests carrying a valid session cookie are reverse proxied to the
    configured origin without bundling. Bundles and sessions are kept in
    Redis.
    """

    def __init__(self):
        """Build the app from ``settings`` and register all routes."""
        if "template_directory" in settings.general:
            self.template_folder = settings.general["template_directory"]
        else:
            self.template_folder = pkg_resources.resource_filename('bundlemaker', "templates")

        super(DebundlerServer, self).__init__("DebundlerServer",
                                              template_folder=self.template_folder)

        self.debundler_maker = DebundlerMaker(settings.general["refresh_period"])
        self.vedge_manager = VedgeManager(settings.v_edges)

        self.bundles = collections.defaultdict(dict)
        self.remap_rules = settings.remap
        self.bundleMaker = BundleMaker(
            self.remap_rules,
            settings.general["bundler_location"]
        )

        self.salt = settings.general["url_salt"]
        self.refresh_period = settings.general["refresh_period"]

        redis_host = "localhost"
        if "redis_host" in settings.general:
            redis_host = settings.general["redis_host"]
        self.redis = redis.Redis(host=redis_host)

        # wildcard routing
        self.route('/', defaults={'path': ''})(self.root_route)
        self.route('/', methods=['POST'])(self.post_route)
        self.route("/ping")(self.ping)
        self.route("/_bundle/")(self.serve_bundle)
        # more wildcard routing
        self.route('/<path:path>')(self.root_route)
        self.route('/<path:path>', methods=['POST'])(self.post_route)

    def ping(self):
        """Liveness check."""
        return "Hi"

    def reload_vedges(self, vedge_manager):
        """Swap in a new VedgeManager (used on SIGHUP)."""
        self.vedge_manager = vedge_manager

    def _salted_sha512(self, *parts):
        """Return the hex SHA-512 of the URL salt followed by ``parts``.

        Text is encoded as UTF-8 before hashing so that a non-ASCII path
        or header no longer raises UnicodeEncodeError. For pure ASCII
        input the digest is the same as hashing the plain concatenation.
        """
        material = self.salt + "".join(parts)
        if not isinstance(material, bytes):
            material = material.encode("utf-8")
        return hashlib.sha512(material).hexdigest()

    def _bundle_sig_host_path(self, request):
        """Signature over path and host."""
        return self._salted_sha512(request.path, request.headers["host"])

    def _bundle_sig_useragent_ip(self, request):
        """Signature over path, user agent, client address and host."""
        return self._salted_sha512(
            request.path,
            request.user_agent.string,
            request.environ['REMOTE_ADDR'],
            request.headers["host"]
        )

    def _bundle_sig_user_agent_cookies(self, request):
        """Signature over user agent, cookies, host and path."""
        # Mash together all cookies
        cookie_string = mash_dict(request.cookies)

        return self._salted_sha512(
            request.user_agent.string,
            cookie_string,
            request.headers["host"],
            request.path
        )

    def _bundle_sig_useragent_ip_cookies(self, request):
        """Signature over user agent, cookies, client address, host and path."""
        # Mash together all cookies
        cookie_string = mash_dict(request.cookies)

        return self._salted_sha512(
            request.user_agent.string,
            cookie_string,
            request.environ['REMOTE_ADDR'],
            request.headers["host"],
            request.path
        )

    def _bundle_sig_useragent_ip_headers(self, request):
        """ The most "unique" mechanism: user agent, client address, all
        headers, host and path. """
        # Mash together all headers
        header_string = mash_dict(request.headers)

        return self._salted_sha512(
            request.user_agent.string,
            request.environ['REMOTE_ADDR'],
            header_string,
            request.headers["host"],
            request.path
        )

    def gen_bundle_hash(self, request):
        """Return the bundle signature for ``request``.

        Currently always the user agent + client address variant.
        """
        return self._bundle_sig_useragent_ip(request)

    def gen_bundle(self, frequest, path, key, iv, hmac_key):
        """Create a bundle for ``frequest`` and store it in Redis.

        The rendered ``bundle.json`` is stored, together with request
        metadata, under the request's bundle signature; the key expires
        after ``self.refresh_period``. The signature is also added to the
        "bundles" set.

        Returns the bundle signature. Aborts with 503 when the bundler is
        unreachable or returns nothing.
        """
        request_host = frequest.headers.get('Host')
        logging.debug("Bundle request url is %s", frequest.url)

        try:
            bundler_result = self.bundleMaker.createBundle(
                frequest,
                key,
                iv,
                hmac_key
            )
        except requests.ConnectionError as e:
            logging.error("Failed to request resources from Bundler! %s", str(e))
            flask.abort(503)

        if not bundler_result:
            logging.error("Failed to get bundle for %s", frequest.url)
            flask.abort(503)

        #Not 1 thousand percent sure this is the same as what you
        # are currently saving so needs to be rechecked
        logging.info("hmac_sig: %s", bundler_result['hmac_sig'])
        bundle_content = flask.render_template(
            "bundle.json",
            encrypted=bundler_result['bundle'],
            hmac=bundler_result['hmac_sig']
        )
        bundle_signature = self.gen_bundle_hash(frequest)
        self.redis.sadd("bundles", bundle_signature)
        self.redis.set(bundle_signature, json.dumps({
            "host": request_host,
            "path": path,
            "headers": dict(frequest.headers),
            "requestor": frequest.environ["REMOTE_ADDR"],
            "bundle": bundle_content,
            "fetched": time.time()
        }))
        self.redis.expire(bundle_signature, self.refresh_period)

        return bundle_signature

    def post_route(self, path=""):
        """
        Passes POST request directly to the remapped origin
        returns the server's response

        ``path`` is accepted because the ``/<path:path>`` POST route
        supplies it; the target URL is derived from the full request URL.
        Aborts with 403 for hosts that have no remap rule.
        """
        request_host = flask.request.headers.get('Host')
        if request_host not in self.remap_rules:
            flask.abort(403)

        # Same lookup as reverse_proxy: the remap entry names the origin.
        origin = self.remap_rules[request_host]["origin"]

        # Pass all headers
        request_headers = dict(flask.request.headers)
        request_headers['Host'] = request_host
        # requests re-encodes the form and files, so the client's length
        # and multipart boundary no longer describe the outgoing body.
        request_headers.pop("Content-Length", None)
        request_headers.pop("Content-Type", None)

        remapped_origin = flask.request.url.replace(request_host, origin, 1)

        # NOTE(review): only form fields and uploaded files are forwarded;
        # a raw (e.g. JSON) request body is not - confirm whether needed.
        proxied_response = requests.post(
            remapped_origin,
            headers=request_headers,
            files=flask.request.files,
            data=flask.request.form,
            cookies=flask.request.cookies,
        )

        # Relay the origin's body, status and content type rather than
        # the requests Response object itself.
        return flask.Response(
            response=proxied_response.content,
            status=proxied_response.status_code,
            content_type=proxied_response.headers.get("Content-Type")
        )

    def serve_bundle(self, bundlehash=None):
        """Return the stored bundle for ``bundlehash``.

        Aborts with 404 when no hash is given (the bare ``/_bundle/``
        route) or the hash is not in the "bundles" set, and with 503 when
        the stored entry is gone or has no bundle in it.
        """
        if not bundlehash:
            logging.error("Got a bundle request without a bundle hash")
            flask.abort(404)

        logging.info("Got a request for bundle with hash of %s", bundlehash)
        if not self.redis.sismember("bundles", bundlehash):
            logging.error("Got request for bundle %s but it is not in the bundles list",
                          bundlehash)
            flask.abort(404)

        # Members of the "bundles" set do not expire but the bundle key
        # does, so the value can be missing here.
        stored_bundle = self.redis.get(bundlehash)
        bundle_get = json.loads(stored_bundle) if stored_bundle is not None else {}
        if "bundle" not in bundle_get:
            # TODO this indicates that the bundle might have expired -
            # we need to rebundle the page.
            logging.error("Failed to get a valid bundle from bundle key %s", bundlehash)
            flask.abort(503)

        bundle = bundle_get["bundle"]
        #TODO fix this to not be a wildcard
        resp = flask.Response(bundle, status=200)
        resp.headers['Access-Control-Allow-Origin'] = "*"
        return resp

    def reverse_proxy(self, request, path):
        """
        Reverse proxy requests to origin. Strip off any
        DDeflect-related stuff and headers that make the requests
        misbehave.

        Returns the ``requests`` response object. Aborts with 403 when
        the Host header is missing or has no remap rule.
        """

        host = request.headers.get("Host")
        if not host:
            # We don't proxy for domains with no host header
            flask.abort(403)

        if host not in settings.remap:
            # We don't proxy for hosts we're not configured for
            logging.warning("Received an invalid request for %s/%s from %s.",
                            host, path, request.headers.get("X-Forwarded-For"))
            flask.abort(403)

        origin = settings.remap[host]["origin"]
        request_headers = dict(request.headers)
        remapped_url = request.url.replace(host, origin, 1)
        request_cookies = dict(request.cookies)

        # If you don't del content-length the origin will probably throw a 411.
        # The header is not always present, so don't fail when it is missing.
        request_headers.pop("Content-Length", None)
        # Scrub the ddeflect session key
        request_cookies.pop(SESSION_COOKIE_NAME, None)

        # and scrub it from the headers - TODO not sure if this is
        # necessary.
        scrubbed_cookies = [key + "=" + val for key, val in request_cookies.items()]
        request_headers["Cookie"] = ";".join(scrubbed_cookies)

        # Send the scrubbed cookies; passing request.cookies here would
        # hand the DDeflect session cookie to the origin after all.
        proxied_response_obj = requests.get(
            remapped_url,
            headers=request_headers,
            data=dict(request.form),
            cookies=request_cookies
        )

        return proxied_response_obj

    def root_route(self, path):
        """Handle every GET request.

        ``_bundle/<hash>`` paths are answered from the bundle store.
        Requests with a valid session cookie are reverse proxied.
        Anything else gets a debundler page (bundling first if no bundle
        is stored for the request's signature) plus a new session cookie.
        """
        # We don't currently handle favicons because that would slow
        # down bundling. Returning a 404 here because returning other
        # codes can actually break loading in some less enlightened
        # browsers.
        if path == "favicon.ico":
            flask.abort(404)

        request_host = flask.request.headers.get('Host')
        logging.debug("New request received - request is for %s", flask.request.url)

        # HACK - Flask doesn't behave properly as an endpoint. In
        # a perfect world we'd use header_filter in ATS to pass
        # this properly but this isn't always available and we'll
        # need to repackage. For now we hack around it because
        # fuck computers.
        via_header = flask.request.headers.get("Via")
        if via_header and via_header.startswith("https"):
            # Only rewrite plain http URLs, so an https URL never becomes
            # "httpss://...".
            if flask.request.url.startswith("http://"):
                flask.request.url = flask.request.url.replace("http", "https", 1)
        else:
            logging.info("Got a request for %s request_host with no Via header", request_host)

        ddeflect_session_id = flask.request.cookies.get(SESSION_COOKIE_NAME)

        if path.startswith("_bundle"):
            logging.debug("Got a _bundle request at %s", path)
            if "/" not in path:
                logging.error("got request that started with _bundle but had no slash!")
                flask.abort(503)
            return self.serve_bundle(path.split("/")[1])
        elif ddeflect_session_id and self.redis.exists("session_%s" % ddeflect_session_id):

            # We're *not* serving a bundle but we *do* have a valid
            # DDeflect session ID - reverse proxy our request directly
            # with no bundling.

            logging.info("Serving resource %s unbundled because of valid session ID",
                         flask.request.url)

            revproxy_result = self.reverse_proxy(flask.request, path)
            # TODO check this value for excessive use/abuse
            self.redis.incr("session_%s" % ddeflect_session_id)

            if revproxy_result.ok:
                # Relay the raw bytes with the origin's content type and
                # status. Going through ``.text`` mangles binary resources
                # and turned e.g. a 304 into an empty 200.
                return flask.Response(
                    revproxy_result.content,
                    status=revproxy_result.status_code,
                    content_type=revproxy_result.headers.get("Content-Type")
                )
            else:
                logging.error("Failed to fetch resource %s to reverse proxy: %s (error %d)",
                              flask.request.url,
                              revproxy_result.text,
                              revproxy_result.status_code)
                flask.abort(revproxy_result.status_code)

        else:
            v_edge = self.vedge_manager.get_vedge()
            key = self.debundler_maker.key
            iv = self.debundler_maker.iv
            hmac_key = self.debundler_maker.hmac_key

            if not path:
                path = "/"

            bundlehash = None
            requested_hash = self.gen_bundle_hash(flask.request)

            if self.redis.exists(requested_hash):
                # TODO Should we update the last fetched time? Update the expiry?
                bundlehash = requested_hash

            if not bundlehash:
                logging.debug("No bundle hash found. Requesting new bundle")
                bundlehash = self.gen_bundle(flask.request, path,
                                             key, iv, hmac_key)
                if not bundlehash:
                    logging.error("Failing because bundlehash was null")
                    flask.abort(404)

            if not self.redis.sismember("bundles", bundlehash) or not self.redis.exists(bundlehash):
                logging.error("Site not in bundles after bundling was requested!!")
                flask.abort(503)

            render_result = flask.render_template(
                "debundler_template.html.j2",
                hmac_key=unicode(hmac_key),
                key=unicode(key),
                iv=unicode(iv),
                v_edge=unicode(v_edge),
                bundle_signature=bundlehash
            )

            logging.debug("Returning template to user for signature %s",
                          bundlehash)
            resp = flask.Response(render_result, status=200)

            ddeflect_session_id = os.urandom(128).encode("hex")
            # TODO allow for updating of the max_age
            resp.set_cookie(SESSION_COOKIE_NAME, ddeflect_session_id,
                            max_age=SESSION_COOKIE_LIFETIME)
            # The Redis key outlives the cookie by 5 so that a cookie that
            # is still valid always finds its session.
            self.redis.set("session_%s" % ddeflect_session_id, 0)
            self.redis.expire("session_%s" % ddeflect_session_id,
                              SESSION_COOKIE_LIFETIME + 5)

            return resp

class BundleManagerDaemon(object):
    """Runs the DebundlerServer, optionally as a forked background process.

    The pidfile location, host and port are read from ``settings.general``.
    """

    def __init__(self, stdin='/dev/stdin',
                 stdout='/dev/stdout',
                 stderr='/dev/stderr'):
        # The stream paths are only stored; nothing in this class
        # redirects the process's streams to them.
        self.stdin = stdin
        self.stdout = stdout
        self.stderr = stderr
        # Set by run(); read by the SIGHUP handler.
        self.debundleServer = None

    def run(self):
        """Create the DebundlerServer and serve until it exits."""
        port = settings.general["port"]
        host = settings.general["host"]

        self.debundleServer = DebundlerServer()
        logging.info("Starting to serve on port %d", port)
        self.debundleServer.run(debug=False, threaded=True,
                                host=host, port=port, use_reloader=False)

    def delpid(self):
        """Remove the pidfile if it exists."""
        if os.path.exists(settings.general["pidfile"]):
            os.remove(settings.general["pidfile"])

    def daemonise(self):
        """Fork twice, exiting both parents, and write the pidfile.

        Exits with status 1 if either fork fails.
        """
        try:
            pid = os.fork()
            if pid > 0:
                sys.exit(0)
        except OSError as e:
            logging.error("fork #1 failed: %d (%s)\n",
                          e.errno,
                          e.strerror)
            sys.exit(1)
        try:
            pid = os.fork()
            if pid > 0:
                sys.exit(0)
        except OSError as e:
            # Reported the same way as fork #1. The previous
            # sys.stderr.write() call was given three arguments and would
            # itself have raised TypeError.
            logging.error("fork #2 failed: %d (%s)\n",
                          e.errno,
                          e.strerror)
            sys.exit(1)
        pid = str(os.getpid())

        with open(settings.general["pidfile"], 'w+') as f:
            f.write("%s\n" % pid)

    def getpid(self):
        """Return the pid recorded in the pidfile.

        Returns None when the file cannot be read or does not contain an
        integer.
        """
        try:
            with open(settings.general["pidfile"], 'r') as pf:
                return int(pf.read().strip())
        except (IOError, ValueError):
            return None

    def start(self):
        """Daemonise and run the server.

        An existing pidfile is reported but does not prevent start-up.
        """
        if self.getpid():
            logging.error("Stale pidfile exists.\n")

        self.daemonise()
        self.run()

    def stop(self):
        """Kill the process named in the pidfile and remove the pidfile.

        Exits with status 1 when there is no pidfile, or when the kill
        fails for any reason other than the process already being gone.
        """
        import errno

        pid = self.getpid()
        #TODO wat? How can we not be running if we're called from
        #inside ourself.
        if not pid:
            logging.error("Bundlemanager not running\n")
            sys.exit(1)
        #TODO simplify this, it's more than a little overcomplicated.
        try:
            # Keep signalling until the kernel reports the process gone.
            while 1:
                self.delpid()
                os.kill(pid, signal.SIGKILL)
                time.sleep(0.1)
        except OSError as e:
            # ESRCH is "No such process"; checking the errno avoids
            # depending on the wording of the error message.
            if e.errno == errno.ESRCH:
                self.delpid()
            else:
                logging.error(str(e))
                sys.exit(1)

    def restart(self):
        """Stop the running daemon, then start a new one."""
        self.stop()
        self.start()

def drop_privileges(uid_name='nobody', gid_name='no_group'):
    """Switch from root to the given user and group.

    Does nothing when the process is not running as root. Otherwise the
    supplementary groups are cleared, the group id and then the user id
    are changed, and the umask is set to 077.

    uid_name -- name of the user to run as.
    gid_name -- name of the group to run as.

    A failing setgid/setuid is logged and otherwise ignored. An unknown
    user or group name raises KeyError from ``pwd``/``grp``.
    """
    if os.getuid() != 0:
        return

    running_uid = pwd.getpwnam(uid_name).pw_uid
    running_gid = grp.getgrnam(gid_name).gr_gid

    os.setgroups([])
    # The group has to be changed first: once the user id is dropped the
    # process may no longer be allowed to change its group.
    try:
        os.setgid(running_gid)
    except OSError as e:
        logging.error('Could not set effective group id: %s', e)
    try:
        os.setuid(running_uid)
    except OSError as e:
        logging.error('Could not set effective user id: %s', e)
    # Files created from now on are accessible to their owner only.
    os.umask(0o77)

def create_handler(daemon):
    """Return a signal handler bound to ``daemon``.

    SIGTERM stops the daemon. SIGHUP reloads the settings module and
    gives the running server a fresh VedgeManager; it is ignored until
    the server has been created. Other signals are ignored.
    """
    def _handle_signal(signum, frame):
        if signum == signal.SIGTERM:
            # Log first: BundleManagerDaemon.stop() sends SIGKILL to the
            # pid in the pidfile, which may be this very process, so
            # nothing placed after the call is guaranteed to run.
            logging.warning("Closing on SIGTERM")
            daemon.stop()
        elif signum == signal.SIGHUP:
            if daemon.debundleServer:
                logging.warning("Reload V-Edge list")
                from bundlemaker import settings
                settings = reload(settings)
                daemon.debundleServer.reload_vedges(
                    VedgeManager(settings.v_edges)
                )
    return _handle_signal

def main(args):
    """Configure logging, drop privileges and start serving.

    ``args`` is the parsed command line; ``config_path`` and ``verbose``
    are read from it. In verbose mode the server runs in the foreground
    and logs to a stream handler; otherwise it daemonises and logs to
    syslog. Always ends with ``sys.exit(0)``.
    """
    #Backwards compatability: an already exported config path wins
    os.environ.setdefault("BUNDLEMANAGER_CONFIG", args.config_path)

    #Make settings import/reimport available everywhere.
    global settings
    from bundlemaker import settings

    verbose = args.verbose

    root_logger = logging.getLogger()
    if verbose:
        root_logger.setLevel(logging.DEBUG)
        log_handler = logging.StreamHandler()
    else:
        root_logger.setLevel(logging.INFO)
        log_handler = logging.handlers.SysLogHandler(address="/dev/log")
    log_format = logging.Formatter(
        "bundlemanager [%(process)d] %(levelname)s %(message)s")
    log_handler.setFormatter(log_format)
    root_logger.addHandler(log_handler)

    general = settings.general

    #TODO Dropping here breaks binding to ports less than 1025
    if general["port"] < 1025:
        logging.error(("Dropping privileges means that the use of ports"
                       " less than 1025 will fail"))

    drop_privileges(general["uid_name"], general["gid_name"])

    daemon = BundleManagerDaemon()
    for handled_signal in (signal.SIGTERM, signal.SIGHUP):
        signal.signal(handled_signal, create_handler(daemon))

    # Verbose mode stays in the foreground; otherwise fork first.
    serve = daemon.run if verbose else daemon.start
    serve()
    sys.exit(0)

if __name__ == "__main__":
    # Script entry point: parse the command line and hand it to main().
    cli = argparse.ArgumentParser(
        description='Manage DDeflect bundle serving and retreival.'
    )
    cli.add_argument(
        '-c',
        action='store',
        dest='config_path',
        default='/etc/bundlemanager.yaml',
        help='Path to config file.',
    )
    cli.add_argument(
        '-v', '--verbose',
        action='store_true',
        dest='verbose',
        help='Verbose mode, not daemonized',
    )

    main(cli.parse_args())
