#!/usr/bin/env python
"""Tool for cassandra tracing analysis."""

from __future__ import print_function
from cassandra.cluster import Cluster
from cassandra import OperationTimedOut
import re
import argparse
import textwrap


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv (list of str, optional):   argument strings to parse instead of
            the process command line; when None, argparse falls back to
            sys.argv[1:], which is what a call without arguments did before

    Returns:
        namespace object:       parsed arguments, with the attributes
            node_ip, tombstone_threshold, time_threshold, result_cap and slim
    """
    # The description is pre-formatted, so it is dedented here and rendered
    # verbatim through RawDescriptionHelpFormatter below.
    description = textwrap.dedent(
        """\
        This script facilitates analysis of Cassandra tracing data,
        which on its own is difficult to draw conclusions from. The
        output displays the following information, sorted by duration:
          - Session ID
          - Duration in microseconds
          - Max number of tombstones encountered on a single session
          - Total tombstones encountered
          - Flags signifying special behavior, according to the legend:
              - R = read repair
              - T = timeout
              - I = index used
          - Starting timestamp of the session (hidden in slim mode)
          - CQL query or inferred query fragment (hidden in slim mode)

          The results can be limited by thresholds on both duration and
          tombstones, and by a result cap. The result cap gathers
          the same number of results but caps the display - when using
          this, the results displayed are always the longest duration
          events.

          Regardless of limiting mechanisms, note that full table scans
          must be done on the "sessions" and "events" tables in the
          "system_traces" keyspace.
          """
    )
    arg_parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    arg_parser.add_argument(
        "node_ip",
        help="Specify ip of the node being queried."
    )
    arg_parser.add_argument(
        '-b', '--tombstoneThreshold', dest="tombstone_threshold", default=0,
        type=int, help=(
            'Show only sessions who read tombstones equal to or greater than '
            'this threshold, default 0.'
        )
    )
    arg_parser.add_argument(
        '-t', '--time', dest="time_threshold", default=10000, type=int,
        help=(
            'Show only sessions that took longer than this threshold in '
            'microseconds, default 10000.'
        )
    )
    arg_parser.add_argument(
        '-r', '--resultCap', dest='result_cap', default=100, type=int,
        help=(
            'Maximum number of results to print out, default 100. To print '
            'all results, use 0.'
        )
    )
    arg_parser.add_argument(
        '-s', '--slim', dest='slim', default=False, action='store_true',
        help=(
            'Enable slim mode, where the query and time started are '
            'suppressed in the output.'
        )
    )
    return arg_parser.parse_args(argv)


def print_update(current_cnt, total_cnt):
    """Print progress bar to stdout.

    The line starts with a carriage return and has no trailing newline, so
    successive calls overwrite each other on a terminal.

    Args:
        current_cnt (int):                current number of scanned sessions
        total_cnt (int):                  total number of sessions to scan

    Returns:
        none
    """
    if total_cnt > 0:
        percent = int(current_cnt * 100 / total_cnt)
    else:
        # Nothing was expected to be scanned; report completion rather than
        # dividing by zero.
        percent = 100
    # The total may be stale relative to the rows actually scanned, so keep
    # the bar within its fixed 100-column width.
    percent = max(0, min(100, percent))
    print(
        "\r{x}% Complete: {bar}|100".format(
            x=str(percent).zfill(3),
            bar=("X" * percent).ljust(100)
        ),
        end=""
    )


def process_sessions(dbsession, time_threshold, tombstone_threshold):
    """Process sessions collection of traces keyspace.

    Every row of the sessions table is visited. Rows with a falsy duration
    are skipped and reported, rows below the time threshold are ignored, and
    rows whose event lookup raises are skipped and reported with the error
    text. A summary of the skipped sessions is printed to stdout.

    Args:
        dbsession (Cassandra session object):       system_traces connection
        time_threshold (int):                       minimum time in microsecs
            for a session to qualify for processing
        tombstone_threshold (int):                  minimum number of total
            tombstones scanned to qualify for processing

    Returns:
        list(
            {
                "duration": int (microseconds),
                "session_id": UUID
                "max_tombstones": int,
                "tot_tombstones": int,
                "time_started": ISO str,
                "query": str,
                "flags": str
            }
        )
        sorted by ascending duration
    """
    sessions = dbsession.execute('SELECT * FROM sessions')
    output = []
    sessions_cnt = dbsession.execute('SELECT COUNT(*) FROM sessions')
    total_cnt = sessions_cnt[0].count
    skipped_null = []
    skipped_err = []
    for current_cnt, sess in enumerate(sessions, 1):
        print_update(current_cnt, total_cnt)
        if not sess.duration:
            skipped_null.append(str(sess.session_id))
            continue
        if sess.duration < time_threshold:
            continue
        try:
            (
                tot_tombstones,
                max_tombstones,
                time_started,
                query_frag,
                flags
            ) = get_event_info(dbsession, sess.session_id)
        except Exception as e:
            # Best effort: one unreadable session must not abort the whole
            # scan, so it is recorded and reported in the summary instead.
            skipped_err.append(
                (str(sess.session_id), str(e))
            )
            continue
        if tot_tombstones < tombstone_threshold:
            continue
        # Prefer the query recorded on the session, then the fragment
        # inferred from its events, then the request type.
        if sess.parameters and sess.parameters.get('query'):
            query = sess.parameters['query']
        elif query_frag:
            query = query_frag
        else:
            query = sess.request
        output.append(
            {
                "duration": sess.duration,
                "session_id": sess.session_id,
                "max_tombstones": max_tombstones,
                "tot_tombstones": tot_tombstones,
                "query": query,
                "flags": flags,
                "time_started": time_started
            }
        )
    print('\n\n')
    print('Total skipped due to null duration:\t{sn}'
          .format(sn=len(skipped_null)))
    if skipped_null:
        print('\t' + '\n\t'.join(skipped_null))
    print('Total skipped due to error:\t{se}'
          .format(se=len(skipped_err)))
    if skipped_err:
        # One line per skipped session: "<session id>\t<error text>".
        print(
            '\t' + '\n\t'.join(
                sess_id + '\t' + err for sess_id, err in skipped_err
            )
        )
    print('\n\n')
    return sorted(output, key=lambda k: k['duration'])


def get_event_info(dbsession, session_id):
    """Get additional info about the events corresponding to a session_id.

    The query fragment in the return tuple is the best-guess that this program
    is making by parsing the events as to what the query was operating on in
    case the parameters field does not itself contain the query, as happens
    with thrift clients.

    The flags in the return tuple are a concatenation of single-letter
    representations of query properties, in alphabetical order, including:
        - read repair (R)
        - timeout (T)
        - index used (I)

    Args:
        dbsession (Cassandra session object):       system_traces connection
        session_id (UUID):                          session_id to count the
            tombstones of

    Returns:
        tuple(
            total tombstone count(int),
            max tombstone count(int),
            time started (ISO str, empty if the session has no events),
            query fragment (str),
            flags (str)
        )
    """
    events = dbsession.execute(
        'SELECT dateOf(event_id), activity FROM events '
        'WHERE session_id={sid}'.format(
            sid=session_id
        )
    )
    tombstone_re = re.compile(r"([0-9]+) tombstone")
    idx_scan_re = re.compile(r".*(Scanning.*\.$)")
    max_tombstone_cnt = 0
    tot_tombstone_cnt = 0
    time_started = ''
    query = ''
    flags = set()
    for event in events:
        activity = event.activity
        lowered = activity.lower()

        tombstone_match = tombstone_re.search(activity)
        if tombstone_match:
            tombstone_cnt = int(tombstone_match.group(1))
            tot_tombstone_cnt += tombstone_cnt
            max_tombstone_cnt = max(max_tombstone_cnt, tombstone_cnt)

        # Only the first event that yields a non-empty fragment is used.
        if not query:
            if "Parsing" in activity:
                # Keep whatever follows the marker, wherever it appears.
                query = activity.partition("Parsing")[2]
            elif "idx" in activity:
                scan_match = idx_scan_re.search(activity)
                # Fall back to the whole activity when the index event is not
                # a "Scanning ..." line, instead of failing the session.
                query = scan_match.group(1) if scan_match else activity
                flags.add('I')
            elif "memtable" in activity or "query" in activity:
                query = activity

        # Checked independently so one activity can contribute several flags.
        if "read-repair" in lowered:
            flags.add('R')
        if "idx" in lowered:
            flags.add('I')
        if "timeout" in lowered:
            flags.add('T')

        # The timestamp of the first event returned is taken as the start.
        if not time_started:
            time_started = event.dateOf_event_id.isoformat()
    # Sorted so the flag string does not depend on set iteration order.
    return (tot_tombstone_cnt, max_tombstone_cnt, time_started, query.strip(),
            ''.join(sorted(flags)))


def print_output(output, result_cap, show_less=False):
    """Print the results of the analysis to stdout.

    Args:
        output(list<dict>):     List of dictionaries, see process_sessions
                                documentation for schema; expected in
                                ascending duration order, since the cap
                                keeps the tail of the list
        result_cap(int):        Limit on number of results to show (preferring
                                longest duration events); 0 or a negative
                                value shows every result
        show_less(boolean, optional):   Show less information to be friendly
                                        for small screens (i.e. "slim mode")

    Returns:
        None
    """
    row_format = "{sessid}{duration}{maxtomb}{tottomb}{flag}{ts}{query}"
    capped = result_cap > 0
    print("{x} sessions satisfying criteria.".format(x=len(output)))
    # Only announce the cap when it actually hides some results.
    if capped and result_cap < len(output):
        print("Showing {x} longest running results.".format(x=result_cap))
    print(
        row_format.format(
            sessid="Session Id".center(40),
            duration="Duration (microsecs)".center(22),
            maxtomb="Max Tombstones".center(20),
            tottomb="Total Tombstones".center(20),
            flag="Flags".center(10),
            ts="" if show_less else "Time Started".center(30),
            query="" if show_less else "Query"
        )
    )
    if capped:
        output = output[-result_cap:]
    for entry in output:
        print(
            row_format.format(
                sessid=str(entry["session_id"]).center(40),
                duration=str(entry["duration"]).center(22),
                maxtomb=str(entry["max_tombstones"]).center(20),
                tottomb=str(entry["tot_tombstones"]).center(20),
                flag=entry["flags"].center(10),
                ts="" if show_less else str(entry["time_started"]).center(30),
                query="" if show_less else entry["query"]
            )
        )
    print("EOM")


def main():
    """Run tracing processor, used only if file is executed as a script.

    Connects to the node given on the command line, analyses the
    system_traces keyspace and prints the report to stdout.

    Args:
        None

    Returns:
        None
    """
    args = parse_args()
    cluster = Cluster([args.node_ip])
    try:
        dbsession = cluster.connect('system_traces')
        output = process_sessions(
            dbsession,
            args.time_threshold,
            args.tombstone_threshold
        )
    finally:
        # Release the driver's connections whether or not the scan succeeded.
        cluster.shutdown()
    print_output(output, args.result_cap, args.slim)

if __name__ == '__main__':
    try:
        main()
    except OperationTimedOut:
        # Exit through SystemExit so the message goes to stderr and the
        # process reports a nonzero status instead of looking successful.
        raise SystemExit(
            'Timeout in gossip communication, ensure cluster is in steady '
            'state (no nodes leaving or joining) before retrying'
        )
