#!/usr/bin/env python

# ansible_debug_logparser:
#   A script to combine and enumerate ansible's debug stdout and syslog outputs.
#
# Purpose:
#   Troubleshooting ansible performance can be difficult. If an issue can't
#   be narrowed down to a single task and a single host, this script will
#   aggregate all the debug data into human-readable insights.
#
# Usage:
#   1) Run ansible with debug logging and with the logfile set ...
#       ANSIBLE_DEBUG=1 ANSIBLE_LOG_PATH=ansible_syslog.log ansible-playbook -vvvv ... | tee -a ansible_stdout.log
#   2) Pass both log files to this script ...
#       ./ansible_debug_logparser ansible_syslog.log ansible_stdout.log
#

import argparse
import datetime
import csv
import os
import re
import sys
import time

from collections import OrderedDict
from pprint import pprint


def split_executor_line(line):
    '''Chop all of the info from a taskexecutor log entry.

    Two layouts are understood:

      pylogging (the ANSIBLE_LOG_PATH file):
        2018-10-12 01:29:39,173 p=5489 u=vagrant |    7705 1539307779.17295:
          running TaskExecutor() for sshd_145/TASK: Check for /usr/local/sync
      stdout+stderr (ANSIBLE_DEBUG output):
        5502 1539307714.25537: done running TaskExecutor() for
          sshd_250/TASK: wipe out the rules [525400a6-0421-65e9-9a84-00000000002e]

    A stdout line may carry unrelated text in front of the pid; that prefix
    is dropped by anchoring on the 'TaskExecutor()' token.

    Returns a dict with the keys date, time, ts, ppid, pid, uid, uuid, host
    and task. date, time, ppid and uid are None for stdout lines, uuid is
    None when the line does not end in a bracketed id.

    Raises ValueError when the pid, the 'TaskExecutor()' anchor or the
    'for' keyword cannot be located in the line.
    '''

    def is_pylogging(fields):
        # pylogging lines have "date time p=N u=NAME |" in front of the pid
        return len(fields) > 4 and fields[4] == '|'

    parts = line.split()

    if not is_pylogging(parts) and not (parts and parts[0].isdigit()):
        # something is glued in front of the pid; re-anchor on the
        # TaskExecutor() token, which sits 3 fields after the pid for
        # "running ..." and 4 fields after it for "done running ..."
        teidx = parts.index('TaskExecutor()')
        lead = 4 if 'done running TaskExecutor' in line else 3
        if teidx < lead:
            raise ValueError('no pid/timestamp in front of TaskExecutor(): %r' % line)
        parts = parts[teidx - lead:]
        if not parts[0].isdigit():
            # keep only the digits that follow the last non-digit character
            badchars = [x for x in parts[0] if not x.isdigit()]
            parts[0] = parts[0].split(badchars[-1])[-1]

    if is_pylogging(parts):
        log_date = parts[0]
        log_time = parts[1]
        ppid = int(parts[2].replace('p=', ''))
        uid = parts[3].replace('u=', '')
        pid = int(parts[5])
        ts = float(parts[6].replace(':', ''))
    else:
        log_date = None
        log_time = None
        ppid = None
        uid = None
        try:
            pid = int(parts[0])
        except ValueError:
            raise ValueError('unable to parse a pid from log line: %r' % line)
        ts = float(parts[1].replace(':', ''))

    # "done running" lines end with the task uuid in square brackets
    uuid = None
    if parts[-1].startswith('[') and parts[-1].endswith(']'):
        uuid = parts[-1].replace('[', '').replace(']', '')

    # "... for <host>/TASK: <task name> [<uuid>]"
    for_index = parts.index('for')
    host = parts[for_index+1].split('/', 1)[0]

    if uuid:
        task = ' '.join(parts[for_index+2:-1])
    else:
        task = ' '.join(parts[for_index+2:])

    return {
        'date': log_date,
        'time': log_time,
        'ts': ts,
        'ppid': ppid,
        'pid': pid,
        'uid': uid,
        'uuid': uuid,
        'host': host,
        'task': task
    }


def _ssh_option(line, name):
    '''Return the text following "<name>=" up to the next whitespace, or None'''
    match = re.search(re.escape(name) + r'=(\S*)', line)
    if match is None:
        return None
    return match.group(1)


def split_ssh_exec(line):
    '''Chop all of the info out of an ssh connection string.

    Example input:
      <dockerhost> SSH: EXEC sshpass -d90 ssh -vvv -C -o ControlMaster=auto
        -o ControlPersist=60s -o StrictHostKeyChecking=no -o Port=33017
        -o User=root -o ConnectTimeout=10 -o ControlPath=/home/vagrant/.ansible/cp/da9b210846
        dockerhost '/bin/sh -c '"'"'echo ~root && sleep 0'"'"''

    Returns a dict with:
      hostname          the name between the leading <...>
      hostkey_checking  False only when StrictHostKeyChecking=no is present
      port              value of Port= (string) or None
      user              value of User= (string) or None
      cp                True when a ControlMaster option is passed
      cp_path           value of ControlPath= or None
      sshpass           True when the command mentions sshpass
      timeout           value of ConnectTimeout= (string) or None
    '''

    parts = line.split()
    hostname = parts[0].replace('<', '').replace('>', '')

    return {
        'hostname': hostname,
        # option names are matched case-insensitively here
        'hostkey_checking': 'stricthostkeychecking=no' not in line.lower(),
        'port': _ssh_option(line, 'Port'),
        'user': _ssh_option(line, 'User'),
        'cp': '-o ControlMaster=' in line,
        'cp_path': _ssh_option(line, 'ControlPath'),
        'sshpass': 'sshpass' in line,
        'timeout': _ssh_option(line, 'ConnectTimeout'),
    }


def main():
    '''Parse ansible debug logs and report where the time was spent.

    Command line:
      filename ...       one or more log files (pylogging and/or stdout output)
      --host-durations   print the total duration for each host
      --timegaps         print the ten largest pauses between consecutive
                         log entries of a pid; --task (substring match) and
                         --host narrow down the pids that are considered
      --task T --host H  print the log of the pid that ran exactly that
                         task on that host, annotated with the wall clock
                         time and the seconds elapsed since its first entry
                         (written to --dest when given, else printed)
      (no mode)          write a task/host/duration/lag/start/stop CSV to
                         --dest, or to stdout when --dest is not given

    A short summary (fork count, average host duration, slowest host and
    its slowest task) is printed first whenever the data for it exists.
    '''

    parser = argparse.ArgumentParser()
    parser.add_argument('--dest', default=None)
    parser.add_argument('--task')
    parser.add_argument('--host')
    parser.add_argument('--timegaps', action='store_true',
                        help="find gaps in time")
    parser.add_argument('--host-durations', action='store_true',
                        help="show total duration for each host")
    parser.add_argument('filename', nargs='+')
    args = parser.parse_args()

    total_forks = None
    # task name -> host -> {'start': ..., 'stop': ..., 'duration': ...}
    tasks = OrderedDict()
    # pid -> {'log': [lines]} in order of first appearance
    pids = OrderedDict()
    pidsmeta = {}

    # iterate files and lines to classify and string chop them
    for fn in args.filename:
        with open(fn, 'r') as f:
            for line in f:
                if not line.strip():
                    continue

                if 'p=' in line and 'u=' in line:
                    # pylogging entries
                    if ': running TaskExecutor() for ' in line:
                        data = split_executor_line(line)
                        task_hosts = tasks.setdefault(data['task'], OrderedDict())
                        task_hosts[data['host']] = {'start': data}
                    elif ': done running TaskExecutor() for ' in line:
                        data = split_executor_line(line)
                        entry = tasks.get(data['task'], {}).get(data['host'])
                        # a "done" without a recorded start (e.g. a
                        # truncated log) can not be timed, so it is dropped
                        if entry is not None:
                            entry['stop'] = data
                    continue

                # stdout+stderr logs
                if line.startswith('TASK') or 'SSH: EXEC' in line:
                    # task banners and ssh command lines are recognised so
                    # they are never mistaken for pid-prefixed debug lines;
                    # none of the reports below consume their content
                    continue

                if not re.search(r"\ [0-9]+\.[0-9]+\:", line):
                    continue

                # "<pid> <epoch>.<fraction>: message"
                numbers = re.findall(r"[0-9]+", line)

                if 'worker is' in line and 'out of' in line:
                    total_forks = int(numbers[-1])

                pid = int(numbers[0])
                pids.setdefault(pid, {'log': []})['log'].append(line.lstrip())

    # further eval each pid's data
    for pid, pid_data in pids.items():
        isparent = False
        timestamps = []
        hosts = []
        task_name = None
        task_uuid = None

        for line in pid_data['log']:
            line = line.rstrip()

            numbers = re.findall(r"[0-9]+", line)
            timestamps.append(float(numbers[1] + '.' + numbers[2]))

            if 'starting run' in line:
                isparent = True
            if 'running TaskExecutor()' in line:
                data = split_executor_line(line)
                if data['host'] not in hosts:
                    hosts.append(data['host'])
                task_name = data['task']
                if data['uuid']:
                    task_uuid = data['uuid']

        pidsmeta[pid] = {
            'pid': pid,
            'task_name': task_name,
            'task_uuid': task_uuid,
            'isparent': isparent,
            'start': timestamps[0],
            'stop': timestamps[-1],
            'duration': timestamps[-1] - timestamps[0],
            'hosts': hosts,
        }

    # find the duration for each task for each host
    host_durations = {}
    for task, hosts in tasks.items():
        for host, results in hosts.items():
            if 'stop' not in results:
                # the task never finished on this host within the log
                continue
            duration = results['stop']['ts'] - results['start']['ts']
            results['duration'] = duration
            host_durations[host] = host_durations.get(host, 0.0) + duration

    # find the slowest host among the group
    sorted_durations = sorted(host_durations.items(), key=lambda x: x[1])

    print('# total forks: %s' % total_forks)
    if sorted_durations:
        durations = [x[1] for x in sorted_durations]
        avg = sum(durations) / float(len(durations))
        print('# average total duration for each host: %ss' % avg)

        sh, sh_total = sorted_durations[-1]
        print('# slowest host')
        print(' name: %s' % sh)
        print(' total duration: %ss' % sh_total)
        _pids = [x for x in pidsmeta.values() if sh in x['hosts']]
        if _pids:
            slowest = sorted(_pids, key=lambda x: x['duration'])[-1]
            print(' slowest task: [p=%s] %s (t=%ss)' % (
                slowest['pid'],
                slowest['task_name'],
                slowest['duration']
            ))
    else:
        print('# no completed tasks found in the pylogging output')

    if args.host_durations:
        hds = [[x[1], x[0]] for x in host_durations.items()]
        hds = sorted(hds, key=lambda x: x[0], reverse=True)
        print('# total duration for each host ...')
        for hd in hds:
            print("{0:20} - {1}".format(hd[1], hd[0]))

    elif args.timegaps:
        # find gaps in time for tasks+hosts

        # all time gaps as [delta, pid, index of the line after the gap]
        gaps = []

        # the first pid seen is left out (presumably the controlling
        # process, whose log spans the whole run -- TODO confirm)
        first_pid = next(iter(pids), None)

        # iterate pids and calculate gaps between each pid's log entries
        for pid, pinfo in pids.items():
            if pid == first_pid:
                continue

            if args.task or args.host:
                pmeta = pidsmeta[pid]
                if args.task and (pmeta['task_name'] is None or
                                  args.task not in pmeta['task_name']):
                    continue
                if args.host and args.host not in pmeta['hosts']:
                    continue

            t0 = None
            for idx, x in enumerate(pinfo['log']):
                match = re.search(r' [0-9]+\.[0-9]+\: ', x)
                if match is None:
                    # timestamp not followed by a message; nothing to time
                    continue
                tn = float(match.group().replace(':', '').strip())
                if t0 is not None and tn - t0 > 0.0:
                    gaps.append([tn - t0, pid, idx])
                t0 = tn

        # sort the gaps and print info about the top 10 ...
        gaps = sorted(gaps, key=lambda x: x[0])
        for idg, gap in enumerate(gaps[-10:][::-1]):
            delta, pid, gap_idx = gap
            pinfo = pidsmeta[pid]
            host = pinfo['hosts'][0] if pinfo['hosts'] else None
            print('# %s. %s second gap for host: %s in task: %s' % (idg, delta, host, pinfo['task_name']))

            # show the five entries leading up to the gap plus the one after
            log = pids[pid]['log']
            for lineno in range(max(gap_idx - 5, 0), gap_idx + 1):
                match = re.search(r' [0-9]+\.[0-9]+\: .*', log[lineno])
                text = match.group() if match else log[lineno].rstrip()
                print('\t' + ' ' + str(lineno) + '.' + ' ' + text)

    elif args.task and args.host:
        # pick out the log entries for a specific task+host
        host_entry = tasks.get(args.task, {}).get(args.host)
        if host_entry is None:
            parser.error('no TaskExecutor entries found for task %r on host %r'
                         % (args.task, args.host))
        _pid = host_entry['start']['pid']
        if _pid not in pids:
            parser.error('no debug output found for pid %s' % _pid)

        # work on a copy with the leading "<pid> " removed
        prefix = str(_pid) + ' '
        log = [x.replace(prefix, '', 1) for x in pids[_pid]['log']]

        t0 = float(log[0].split(':', 1)[0])
        for idx, x in enumerate(log):
            ts = float(x.split(':', 1)[0])
            ets = datetime.datetime.fromtimestamp(ts)

            # "<epoch>: msg" -> "<epoch> <iso time>  <elapsed>  msg"
            insert = ' %s ' % ets.isoformat()
            insert += ' %s ' % str(ts - t0)
            log[idx] = x.replace(':', insert, 1)

        if args.dest:
            with open(args.dest, 'w') as f:
                f.writelines(log)
        else:
            for line in log:
                print(line.rstrip())

    else:
        # dump data to file for comparison
        rows = [['task', 'host', 'duration', 'lag', 'start', 'stop']]
        for task, td in tasks.items():
            _hosts = sorted(td.keys())

            # lag is the delta between a host starting vs task starting
            time0 = min(td[x]['start']['ts'] for x in _hosts)

            for _host in _hosts:
                result = td[_host]
                start_ts = result['start']['ts']
                stop = result.get('stop')
                # unfinished tasks get empty duration/stop columns
                rows.append([
                    task,
                    _host,
                    result.get('duration'),
                    start_ts - time0,
                    start_ts,
                    stop['ts'] if stop else None
                ])

        if args.dest:
            with open(args.dest, 'w') as f:
                csv.writer(f).writerows(rows)
        else:
            csv.writer(sys.stdout).writerows(rows)


# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
