#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Copyright (c) 2019 Wolfgang Rohdewald <wolfgang@rohdewald.de>
# See LICENSE for details.

"""implement a server using the mapmytracks protocol.

There is one notable difference:

https://github.com/MapMyTracks/api/blob/master/services/stop_activity.md
says stop_activity has no parameter activity_id. Our server needs it,
Oruxmaps delivers it. Maybe the MMT API definition is wrong.
See https://github.com/MapMyTracks/api/issues/25

"""

# PYTHON_ARGCOMPLETE_OK
# for command line argument completion, put this into your .bashrc:
# eval "$(register-python-argcomplete gpxdo)"
# or see https://argcomplete.readthedocs.io/en/latest/


import os
import sys
import base64
import datetime
import argparse
import logging
import logging.handlers
import traceback
import ssl

from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs

from gpxpy import gpx as mod_gpx

GPXTrackSegment = mod_gpx.GPXTrackSegment
GPXTrackPoint = mod_gpx.GPXTrackPoint
GPXXMLSyntaxException = mod_gpx.GPXXMLSyntaxException

# When started from a source checkout, prefer the development copy of gpxity over an installed one
_ = os.path.dirname(sys.path[0] or sys.path[1])
if os.path.exists(os.path.join(_, 'gpxity', '__init__.py')):
    sys.path.insert(0, _)
# pylint: disable=wrong-import-position

from gpxity import Gpx, GpxFile, Directory, MMT, Account, DirectoryAccount, Backend, Lifetrack  # noqa pylint: disable=no-name-in-module

try:
    import argcomplete
    # pylint: disable=unused-import
except ImportError:
    pass


class TrackingMessage:

    """Life tracking: represent a received message.

    Args:
        command: the name of the HTTP command, as passed by the caller
        from_ip: the sender
        parsed: the parsed POST data, a dict, or None if there was nothing to parse
        result: the POST answer
        response: tuple(HTTP status code, message) if an error was answered, else None

    """

    def __init__(self, command, from_ip, parsed, result, response):
        self.time = datetime.datetime.now()
        self.command = command
        self.from_ip = from_ip
        self.parsed = parsed
        self.result = result
        self.response = response

    def log(self, prefix=''):
        """Log the message.

        A message carrying an error response goes to the error log,
        everything else to the info log.

        """
        if self.response:
            logging.error('E %s%s', prefix, self)
        else:
            logging.info('I %s%s', prefix, self)

    def __str__(self):
        # A request without a body has no parsed data: treat that like an empty request
        # instead of failing on None.get().
        parsed = self.parsed or {}
        response = '{} {}'.format(*self.response) if self.response else ''
        ident = parsed.get('activity_id')
        request = parsed.get('request')
        point_msg = ''
        if 'points' in parsed:
            try:
                points = MMTHandler.parse_points(parsed['points'])
                if points:
                    point_msg = ' with {} points: {}'.format(
                        len(points), ','.join(str(round(x.elevation)) for x in points))
            except Exception as exc:  # pylint: disable=broad-except
                # Formatting a log line must never fail: show the raw points instead.
                point_msg = '{}: {}'.format(exc, parsed['points'])
        return '{} {} from {} id={} request={} {} --> {}  {}'.format(
            self.time, self.command, self.from_ip, ident, request, point_msg, self.result, response).strip()


class MMTHandler(BaseHTTPRequestHandler):

    """handles all HTTP requests."""

    users = None
    login_user = None

    def log_info(self, format, *args):  # pylint: disable=redefined-builtin
        """Override: Redirect into logger."""
        self.server.logger.info(format % args)

    def log_error(self, format, *args):  # pylint: disable=redefined-builtin
        """Override: redirect into logger."""
        self.server.logger.error(format % args)

    def check_basic_auth_pw(self):
        """basic http authentication."""
        if 'Authorization' not in self.headers:
            self.send_header('WWW-Authenticate', 'Basic realm=\"Test\"')
            self.return_error(401, 'Authorization required')

        if self.users is None:
            self.load_users()
        for pair in self.users.items():
            expect = b'Basic ' + base64.b64encode(':'.join(pair).encode('utf-8'))
            expect = expect.decode('utf-8')
            if expect == self.headers['Authorization']:
                return
        self.return_error(401, 'Authorization failed')

    def load_users(self):
        """load user authentication data from serverdirectory/.users."""
        self.users = dict()
        users_filename = os.path.join(self.server.targets[0].account.url, '.users')
        if not os.path.exists(users_filename):
            _ = users_filename
            prev = _
            while not os.path.exists(os.path.dirname(_)):
                prev = _
                _ = os.path.dirname(_)
            self.return_error(401, 'Cannot find {}: {} missing'.format(users_filename, prev))
            return
        with open(users_filename) as user_file:
            for line in user_file:
                user, password = line.strip().split(':')
                self.users[user] = password

    def return_error(self, code, reason):
        """Answer the clint with an xml formatted error message."""
        if self.error_response:
            reason = '{} ({})'.format(reason, self.error_response[1])
        self.error_response = (code, reason)  # pylint: disable=attribute-defined-outside-init
        try:
            self.send_response(code, reason)
            xml = '<type>error</type><reason>{}</reason>'.format(reason)
            self.send_header('Content-Type', 'text/xml; charset=UTF-8')
            xml = '<?xml version="1.0" encoding="UTF-8"?><message>{}</message>'.format(xml)
            self.send_header('Content-Length', len(xml))
            self.end_headers()
            self.wfile.write(bytes(xml.encode('utf-8')))
        except BaseException as exc:
            logging.error('return_error failed: %s', exc)

    def parseRequest(self):  # noqa pylint: disable=invalid-name
        """Get interesting things.

        Returns:
            A dict with the parsed results or None

        """
        if 'Content-Length' in self.headers:
            data_length = int(self.headers['Content-Length'])
            data = self.rfile.read(data_length).decode('utf-8')
            parsed = parse_qs(data)
            for key, value in parsed.items():
                if len(value) != 1:
                    self.return_error(400, '{} must appear only once'.format(key))
                parsed[key] = parsed[key][0]
            self._fix_headers(parsed)
            return parsed
        return None

    def homepage(self):
        """Build the answer for the root url.

        The value of the hidden field "mid" is the position of
        the login user in the sorted list of all user names.

        """
        self.load_users()
        position = sorted(self.users).index(self.login_user)
        return """
            <input type="hidden" value="{}" name="mid" id="mid" />
            """.format(position)

    @staticmethod
    def answer_with_categories():
        """Return one html list item for every category in MMT.supported_categories."""
        items = [
            '<li><input name="add-activity-x">&nbsp;{}</li>'.format(category)
            for category in MMT.supported_categories]
        return ''.join(items)

    def do_GET(self):  # noqa pylint: disable=invalid-name
        """Override standard."""
        self.server.logger.info(
            '%s GET %s %s %s', self.server.second(), self.client_ip(), self.server.server_port, self.path)
        self.parseRequest()  # side effect: may output debug info
        self.send_response(200, 'OK')
        self.send_header('WWW-Authenticate', 'Basic realm="MMTracks API"')
        if self.path == '/':
            xml = self.homepage()
        elif self.path.endswith('/explore/wall'):
            # the client wants to find out legal categories
            xml = self.answer_with_categories()
        elif self.path.startswith('//assets/php/gpx.php'):
            # the client wants the entire gpxfile
            parameters = self.path.split('?')[1]
            request = parse_qs(parameters)
            wanted_id = request['tid'][0]
            xml = self.server.targets[0][wanted_id].xml()
        else:
            xml = ''
        self.send_header('Content-Type', 'text/xml; charset=UTF-8')
        self.server.logger.info('%s  returning %s', self.server.second(), xml)
        self.send_header('Content-Length', len(xml))
        self.end_headers()
        self.wfile.write(bytes(xml.encode('utf-8')))

    def client_ip(self):
        """Find out who sent the request.

        Returns:
            The value of the header X-Forwarded-For if the request has one,
            otherwise the address of the connected peer

        """
        peer = self.client_address[0]
        return self.headers.get('X-Forwarded-For', peer)

    def do_POST(self):  # noqa pylint: disable=invalid-name
        """override standard.

        Parses the request, checks the authorization and dispatches to the
        method xml_<request>. Its answer is wrapped into an xml message.
        Every failure is answered through return_error(), and at most one
        answer is written. At the end the request is logged and
        appended to the server history.

        """
        self.error_response = None  # pylint: disable=attribute-defined-outside-init
        answer = ''
        message = None
        parsed = None  # must exist for the finally clause even if parsing raises
        got_post_time = datetime.datetime.now()
        try:
            # A request without body parses to None: handle it like an empty
            # request, which is answered below with "No request given".
            parsed = self.parseRequest() or {}
            message = TrackingMessage('POST', self.client_ip(), parsed, answer, self.error_response)
            if self.error_response:
                # parseRequest already answered the client
                return
            self.check_basic_auth_pw()
            if self.error_response:
                return
            if self.path != '/api/':
                self.return_error(400, 'Url {}: POST api: Path must be /api/'.format(self.path))
                return
            request = parsed.get('request')
            if not request:
                self.return_error(400, 'No request given in {}'.format(parsed))
                return
            try:
                method = getattr(self, 'xml_{}'.format(request))
            except AttributeError:
                self.return_error(400, 'Unknown request {}'.format(request))
                return
            try:
                answer = method(parsed) or ''
            except Exception as exc:  # pylint: disable=broad-except
                self.return_error(400, '{}: {}'.format(request, exc))
                self.server.logger.debug(traceback.format_exc())
                return
            if self.error_response:
                # the method answered with an error itself, do not send a second answer
                return
            xml = '<?xml version="1.0" encoding="UTF-8"?><message>{}</message>'.format(answer)
            # Encode once: Content-Length must be the number of bytes, not of characters
            body = xml.encode('utf-8')
            self.send_response(200, 'OK')
            self.send_header('WWW-Authenticate', 'Basic realm="MMTracks API"')
            self.send_header('Content-Type', 'text/xml; charset=UTF-8')
            self.send_header('Content-Length', len(body))
            self.end_headers()
            self.wfile.write(body)
        finally:
            try:
                if message is None:
                    message = TrackingMessage('POST', self.client_ip(), parsed, answer, self.error_response)
                else:
                    message.result = answer
                    message.response = self.error_response
                message.log(prefix='  ')
                self.server.history.append(message)
            except Exception as exc:  # pylint: disable=broad-except
                # logging the request must never hide the real outcome
                logging.error(exc)
            logging.debug('POST turnaround time: %s', datetime.datetime.now() - got_post_time)

    @staticmethod
    def xml_get_time(_) ->str:
        """Get server time as defined by the mapmytracks API.

        Returns:
            Our answer

        """
        return '<type>time</type><server_time>{}</server_time>'.format(
            int(datetime.datetime.now().timestamp()))

    def xml_get_activities(self, parsed) ->str:
        """List all gpxfiles as defined by the mapmytracks API.

        # TODO: untested!

        Returns:
            Our answer

        """
        a_list = list()
        if parsed['offset'] == '0':
            for idx, _ in enumerate(self.server.targets[0]):
                a_list.append(
                    '<track{}><id>{}</id>'
                    '<title><![CDATA[ {} ]]></title>'
                    '<activity_type>{}</activity_type>'
                    '<date>{}</date>'
                    '</track{}>'.format(
                        idx + 1, _.id_in_backend, _.title, _.category,
                        int(_.time.timestamp()), idx + 1))
        return '<tracks>{}</tracks>'.format(''.join(a_list))

    @staticmethod
    def parse_points(raw):
        """convert raw data back into list(GPXTrackPoint).

        Returns:
            list(GPXTrackPoint)

        """
        values = raw.split()
        if len(values) % 4:
            raise Exception('Point element count {} is not a multiple of 4'.format(len(values)))
        result = list()
        for idx in range(0, len(values), 4):
            try:
                time = datetime.datetime.utcfromtimestamp(float(values[idx + 3]))
            except ValueError:
                logging.error('Point has illegal time stamp %s', str(values))
                time = 0
            point = GPXTrackPoint(
                latitude=float(values[idx]),
                longitude=float(values[idx + 1]),
                elevation=float(values[idx + 2]),
                time=time)
            result.append(point)
        return result

    def xml_upload_activity(self, parsed) ->str:
        """Answer the request upload_activity of the mapmytracks API.

        The uploaded gpx file is added to the first target and then to all
        other targets. The answer holds the id it got in the first target.

        Returns:
            Our answer

        """
        uploaded = GpxFile()
        uploaded.gpx = Gpx.parse(parsed['gpx_file'])
        targets = self.server.targets
        targets[0].add(uploaded)  # TODO this was self.targets and no unittest triggered
        # remember the id now: adding to the other targets follows
        ident = uploaded.id_in_backend
        for other in targets[1:]:
            other.add(uploaded)
        return '<type>success</type><id>{}</id>'.format(ident)

    @staticmethod
    def _fix_headers(parsed):
        """Fix some not so nice things in headers."""
        if 'privicity' in parsed:
            parsed['privacy'] = parsed['privicity']
            del parsed['privicity']
        if parsed.get('request') == 'start_activity' and 'title' not in parsed:
            parsed['title'] = ''
            if parsed.get('source') == 'OruxMaps':
                # shorten monster title 2018-10-03 00:0020181003_0018
                parsed['title'] = parsed['title'][:16]

    def _new_tracker(self, parsed) ->Lifetrack:
        """Create a Lifetrack object and register it for the sender IP.

        Trackers from the same sender IP which are already done are
        removed from the server trackers first, under their sender IP
        and under their tracker id.

        Returns: The new object

        """
        fresh = Lifetrack(self.client_ip(), self.server.targets, tracker_id=parsed.get('activity_id'))
        if fresh is None:
            return fresh
        origin = self.client_ip()
        registry = self.server.trackers
        # collect first, the registry is changed in the loop below
        finished = [x for x in registry.values() if x.sender_ip == fresh.sender_ip and x.done]
        for old in finished:
            if old.sender_ip in registry:
                del registry[old.sender_ip]
            if old.tracker_id() in registry:
                del registry[old.tracker_id()]
        registry[origin] = fresh
        return fresh

    def find_tracker(self, parsed):
        """Find matching tracker.

        Returns: the wanted tracker or None.

        """
        sender_ip = self.client_ip()
        trackers = self.server.trackers
        request = parsed['request']
        if request == 'start_activity':
            tracker = trackers.get(sender_ip)
            if tracker is not None:
                if tracker.done:
                    del trackers[sender_ip]
                    tracker = None
                else:
                    parsed['activity_id'] = tracker.tracker_id()
        elif 'activity_id' in parsed:
            _ = parsed['activity_id']
            if _ == "0":
                raise KeyError("activity_id 0 is illegal")
            tracker = trackers.get(_)
            if tracker:
                # the sender IP may have changed, we need this for standard stop_activity
                # without activity_id
                trackers[sender_ip] = tracker
        elif request == 'stop_activity':
            # The standard mapmytracks protocol does not include activity_id in stop_activity, but GPS Forwarder does.
            # This branch covers the standard.
            tracker = trackers.get(sender_ip)
        else:
            raise KeyError("activity_id is missing")
        if tracker is None and trackers:
            logging.error(
                '%s: No tracker found for IP=%s and activity_id=%s in %s',
                request, sender_ip, parsed.get('activity_id'),
                ', '.join('{}:{}'.format(k, v.tracker_id()) for k, v in trackers.items()))
        return tracker

    def xml_start_activity(self, parsed) ->str:
        """start Lifetrack server.

        Returns:
            Our answer or None if there was an error

        """
        if self.error_response is not None:
            return None
        tracker = self.find_tracker(parsed)
        try:
            points = self.parse_points(parsed['points'])
        except BaseException as exc:
            self.return_error(400, str(exc))
            return None

        if tracker is None:
            tracker = self._new_tracker(parsed)
            title = parsed.get('title', '')
            public = parsed.get('privacy', 'private') == 'public'
            # the MMT API example uses cycling instead of Cycling,
            # and Oruxmaps does so too.
            category = parsed.get('activity', MMT.supported_categories[0]).capitalize()
            tracker.start(points, title, public, category)
            self.server.trackers[tracker.tracker_id()] = tracker
        else:
            tracker.update_trackers(points)
        logging.info(
            'Starting %s, remote software: %s version %s',
            tracker, parsed.get('source'), parsed.get('version'))
        return '<type>activity_started</type><activity_id>{}</activity_id>'.format(tracker.tracker_id())

    def xml_update_activity(self, parsed) ->str:
        """Get new points.

        Returns:
            Our answer

        """
        updated = '<type>activity_updated</type>'
        tracker = self.find_tracker(parsed)
        if tracker is None:
            tracker = self._new_tracker(parsed)
            self.server.trackers[tracker.tracker_id()] = tracker

        tracker.update_trackers(self.parse_points(parsed['points']))
        if tracker.done:
            logging.error('update_activity: %s was already stopped', tracker)
        return updated

    def xml_stop_activity(self, parsed) ->str:  # pylint: disable=unused-argument
        """Client says stop.

        mapmytracks.com says we do not need to get activity_id here. So we just
        stop all trackers which might be meant. See
        https://github.com/MapMyTracks/api/issues/25

        # TODO: This will fail if the sender ip changes just before stop_activity
        # and I see now way how to fix that.

        Returns:
            Our answer

        """
        trackers = [self.find_tracker(parsed)]
        if trackers == [None]:
            trackers = [x for x in self.server.trackers.values() if not x.done]
        for tracker in trackers:
            if not tracker.done:
                logging.info('Stopping %s', tracker)
                tracker.end()
        return '<type>activity_stopped</type>'


class LifeServerMMT(HTTPServer):

    """A simple MMT server for life tracking.

    Attributes:
        logger: The root logger, it also writes into gpxity_server.log in the first target.
        targets: The backends receiving the data. The first one is a Directory.
        start_second: The time stamp of the server start.
        trackers: A dict of Lifetrack instances. MMTHandler stores every tracker
            under its tracker id (activity_id) and under the IP of its sender.
        history: A list of all handled POST requests as TrackingMessage.

    """

    def __init__(self, options):
        """See class docstring.

        Args:
            options: the parsed command line, see create_parser()

        """
        self.logger = self.define_logger(options)
        offer_https = options.certfile and options.keyfile
        port = options.port or (443 if offer_https else 80)
        super().__init__((options.servername, port), MMTHandler)
        if offer_https:
            # ssl.wrap_socket() is gone since Python 3.12, an SSLContext does the same
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            context.load_cert_chain(certfile=options.certfile, keyfile=options.keyfile)
            self.socket = context.wrap_socket(self.socket, server_side=True)
        self.targets = [Directory(DirectoryAccount(options.target[0], id_method='counter'))]
        Account.path = self.targets[0].url + '/accounts'
        self.targets.extend(Backend.instantiate(x) for x in options.target[1:])
        self.start_second = datetime.datetime.now().timestamp()
        self.trackers = dict()
        self.history = list()

    @staticmethod
    def define_logger(options):
        """Setup logging.

        The root logger gets the level from options.loglevel and an additional
        handler writing into the file gpxity_server.log in the first target.

        Returns: The root logger

        """
        result = logging.getLogger()
        result.setLevel(options.loglevel.upper())
        logfile = logging.FileHandler(os.path.join(options.target[0], 'gpxity_server.log'))
        logfile.setLevel(logging.DEBUG)
        result.addHandler(logfile)
        logging.getLogger('urllib3').setLevel(logging.DEBUG)
        return result

    def second(self):
        """The timestamp.

        Returns: A string with the seconds since the server was started,
            with 4 decimals and padded to at least 10 characters.

        """
        return '{:10.4f}'.format(datetime.datetime.now().timestamp() - self.start_second)

def create_parser():
    """Create the parser for the command line options of the server.

    Returns: The parser

    """
    epilog = """
    The MMT server uses BASIC AUTH for login. Define user:password in  the file .users in the first target.
    Define authorization for optional other targets in the file accounts the first target."""
    # one entry per argument: its name and the settings for add_argument
    arguments = (
        ('--servername', dict(help='the name of this server', required=True)),
        ('--certfile', dict(help='if certfile and keyfile are given: offer https')),
        ('--keyfile', dict(help='if certfile and keyfile are given, offer https')),
        ('--port', dict(help='listen on PORT. Default is 80 for http and 443 for https', type=int)),
        ('--loglevel', dict(
            help='set the loglevel, default is error',
            choices=('debug', 'info', 'warning', 'error'), default='error')),
        ('target', dict(
            help='backends who should receive the data. The first one must be Directory',
            nargs='+')),
    )
    parser = argparse.ArgumentParser('gpxity_server', epilog=epilog)
    for name, settings in arguments:
        parser.add_argument(name, **settings)
    return parser

class Main:  # pylint: disable=too-few-public-methods

    """The command line entry point: parse the options and run the server."""

    def __init__(self):
        """Parse the command line, create the server and call its serve_forever()."""
        parser = create_parser()

        try:
            argcomplete.autocomplete(parser)
        except NameError:
            # argcomplete is optional: its import at the top of the file failed
            pass

        server = LifeServerMMT(parser.parse_args())
        server.serve_forever()

# Run the server only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    Main()
