#!/usr/bin/env python

# This bisection script helps isolate commits that changed a behavior of ansible
# To use it:
#   1) check the --help output
#   2) make an ansible checkout
#   3) source hacking/env-setup before running the bisect
#   4) ansible-bisect <checkoutdir> <testscript> <args>

import os
import sys
import subprocess
import argparse
import datetime
import pickle
from distutils.version import LooseVersion
from pprint import pprint

# Filesystem locations and a log-file name.
# NOTE(review): none of these three constants is referenced anywhere in the
# visible part of this script -- confirm whether they are still needed.
DEVELDIR = "/var/cache/ansible/ansible.checkout.clean"
BASEDIR = "/var/cache/ansible/version_checkouts"
LOGFILE = "returncodes.txt"


class Commit(object):
    """A single commit, backed by a dict of parsed ``git log`` header fields.

    The dict is expected to provide the keys 'hash' and 'commitdate'.  The
    commit date is parsed on first access and cached; the timezone offset at
    the end of the date string is discarded, so timestamps are naive.
    """

    _logdata = {}
    _hash = None
    _date = None
    _timestamp = None
    _version = None

    def __init__(self, logdata):
        self._logdata = logdata
        self._hash = self._date = None
        self._timestamp = self._isoformat = None

    def __repr__(self):
        return '%s %s' % (self.isoformat, self.commithash)

    def _splitdate(self, instring):
        """Return instring with its last whitespace-separated token removed.

        'Thu Jan 4 23:55:38 2018 -0800' -> 'Thu Jan 4 23:55:38 2018'
        """
        tokens = instring.split()
        del tokens[-1:]
        return ' '.join(tokens)

    @property
    def commithash(self):
        """The value stored under 'hash' in the log data (None if absent)."""
        return self._logdata.get('hash')

    @property
    def timestamp(self):
        """Commit date as a naive datetime, parsed once and then cached."""
        if self._timestamp:
            return self._timestamp
        raw_date = self._logdata.get('commitdate')
        self._timestamp = datetime.datetime.strptime(
            self._splitdate(raw_date), "%a %b %d %H:%M:%S %Y")
        return self._timestamp

    @property
    def isoformat(self):
        """ISO-8601 rendering of the timestamp, computed once and cached."""
        if self._isoformat:
            return self._isoformat
        self._isoformat = self.timestamp.isoformat()
        return self._isoformat


def run_command(args, capture=False, shell=True, cwd=None):
    """Run a command, wait for it to finish and report how it went.

    args    -- command line; a string when shell is true, otherwise an
               argument list
    capture -- when true, stdout and stderr are collected and returned;
               when false they go to the parent's streams and the returned
               values are None
    shell   -- run the command through the shell (the default)
    cwd     -- working directory for the child process, or None

    Returns a (returncode, stdout, stderr) tuple.  Captured output is always
    text (not bytes), so callers can split it on '\\n' on Python 2 and 3.
    """
    popen_kwargs = {
        # honour the caller's choice; it used to be hard-coded to True
        'shell': shell,
        'cwd': cwd,
        # decode the pipes to text so the result is str on Python 3 as well
        'universal_newlines': True,
    }
    if capture:
        popen_kwargs['stdout'] = subprocess.PIPE
        popen_kwargs['stderr'] = subprocess.PIPE

    p = subprocess.Popen(args, **popen_kwargs)
    (so, se) = p.communicate()
    return (p.returncode, so, se)


def clean_checkout(checkoutdir):
    '''Delete every *.pyc file found below checkoutdir.

    The work is done by a `find | xargs rm -f` shell pipeline; its return
    code and output are not inspected.
    '''
    pipeline = 'find {0} -type f -name "*.pyc" | xargs rm -f'.format(checkoutdir)
    run_command(pipeline, capture=True)


# Disabled earlier implementation of get_commits (hashes only, read from
# `git log --pretty=oneline`).  It lives inside a string literal, so it is
# never executed.
'''
def get_commits(checkoutdir):
    cmd = 'cd %s; git log --pretty=oneline' % checkoutdir
    (rc, so, se) = run_command(cmd, capture=True)
    lines = so.split('\n')
    commits = [x.strip() for x in lines if x.strip()]
    commits = [x.split()[0] for x in commits]
    return commits
'''

def get_commits(checkoutdir):
    """Return a Commit object for every entry of `git log` in checkoutdir.

    The output of `git log --pretty=fuller` is parsed: each line starting
    with 'commit ' opens a new record, and the Author/AuthorDate/Commit/
    CommitDate header lines fill it in.  Commit message lines are ignored
    (the 'message' field stays empty).

    The list is in the order git printed the log.  The final record is
    flushed after the loop; previously it was silently dropped because
    nothing followed it to trigger the append.
    """
    cmd = 'cd %s; git log --pretty=fuller' % checkoutdir
    (rc, so, se) = run_command(cmd, capture=True)

    records = []
    current = None
    for line in so.split('\n'):
        if not line.strip():
            continue

        parts = line.split()

        if line.startswith('commit '):
            # a new header closes the record collected so far
            if current is not None:
                records.append(current)
            current = {
                'hash': parts[1],
                'author': None,
                'authordate': None,
                # the misspelled key is kept for compatibility
                'commiter': None,
                'commitdate': None,
                'message': ''
            }
        elif current is None:
            # header-looking line before any 'commit' line: nothing to fill
            continue
        elif line.startswith('Author:'):
            current['author'] = ' '.join(parts[1:])
        elif line.startswith('AuthorDate:'):
            current['authordate'] = ' '.join(parts[1:])
        elif line.startswith('CommitDate:'):
            current['commitdate'] = ' '.join(parts[1:])
        elif line.startswith('Commit:'):
            current['commiter'] = ' '.join(parts[1:])

    # flush the last record of the log
    if current is not None:
        records.append(current)

    return [Commit(record) for record in records]


def get_checkout_date(checkoutdir):
    """Return the commit date of the currently checked out commit.

    Reads the first 'CommitDate:' line of `git log -1 --pretty=fuller` and
    parses it into a naive datetime; everything from the first '-' or '+'
    onwards (the timezone offset) is cut off.  Returns None when no such
    line is present in the output.
    """
    (rc, so, se) = run_command(
        'cd %s; git log -1 --pretty=fuller' % (checkoutdir),
        capture=True
    )
    for row in so.split('\n'):
        if not row.strip().lower().startswith('commitdate:'):
            continue
        # e.g. 'CommitDate: Fri Nov 11 12:48:45 2016 -0500'
        stamp = row.replace('CommitDate:', '', 1).strip()
        for sign in ('-', '+'):
            stamp = stamp.split(sign, 1)[0].strip()
        return datetime.datetime.strptime(stamp, "%a %b %d %H:%M:%S %Y")
    return None


def get_date_cache(cachedir='.cache'):
    """Load the pickled commit-date cache stored in cachedir.

    cachedir may start with '~'; it is expanded before anything touches the
    filesystem (it used to be expanded only after the directory had been
    created, which produced a literal '~' directory).  The directory is
    created when missing.

    Returns the unpickled object, or an empty dict when the cache file does
    not exist or cannot be read.
    """
    cachedir = os.path.expanduser(cachedir)
    if not os.path.isdir(cachedir):
        os.makedirs(cachedir)
    cfile = os.path.join(cachedir, 'ansible-commit-dates.pickle')

    chash = {}
    if os.path.isfile(cfile):
        # NOTE: pickle.load executes whatever the file describes; only point
        # this at cache files written by save_date_cache.
        try:
            with open(cfile, 'rb') as f:
                chash = pickle.load(f)
        except Exception as e:
            # unpickling can fail in many ways; the cache is only an
            # optimisation, so report the problem and start empty
            print('WARNING: ignoring unreadable cache %s: %s' % (cfile, e))
            chash = {}
    return chash


def save_date_cache(chash, cachedir='.cache'):
    """Pickle chash into <cachedir>/ansible-commit-dates.pickle.

    cachedir may start with '~'; it is expanded before the directory is
    checked and created, so the directory that gets created is the same one
    the file is written to.
    """
    cachedir = os.path.expanduser(cachedir)
    if not os.path.isdir(cachedir):
        os.makedirs(cachedir)
    cfile = os.path.join(cachedir, 'ansible-commit-dates.pickle')
    with open(cfile, 'wb') as f:
        pickle.dump(chash, f)


def get_commit_date(commitid, checkoutdir):
    """Return the commit date of commitid as a naive datetime.

    Scans the output of `git show --pretty=fuller <commitid>` for the first
    line beginning with 'CommitDate:' (compared case-insensitively) and
    parses it after dropping the timezone offset.  Returns None when the
    output holds no such line.
    """
    show_cmd = 'cd %s; git show --pretty=fuller %s' % (checkoutdir, commitid)
    (rc, so, se) = run_command(show_cmd, capture=True)

    matching = (
        text for text in so.split('\n')
        if text.strip().lower().startswith('commitdate:')
    )
    for text in matching:
        # e.g. 'CommitDate: Fri Nov 11 12:48:45 2016 -0500'
        value = text.replace('CommitDate:', '', 1).strip()
        value = value.split('-', 1)[0].strip()
        value = value.split('+', 1)[0].strip()
        return datetime.datetime.strptime(value, "%a %b %d %H:%M:%S %Y")
    return None


def checkout_commit(checkoutdir, githash):
    """Check out githash in checkoutdir and refresh its submodules.

    githash -- a Commit object or anything `git checkout` accepts (hash,
               branch name, 'HEAD', ...)

    Raises AssertionError when either git command exits non-zero.  The
    checks are explicit raises rather than `assert` statements so that they
    still run under `python -O`.  Returns None.
    """
    print('\tcommit: %s' % githash)
    if isinstance(githash, Commit):
        githash = githash.commithash

    # checkout
    cmd = 'cd %s; git checkout %s' % (checkoutdir, githash)
    (rc, so, se) = run_command(cmd, capture=True)
    print('\tcheckout: %s' % rc)
    if rc != 0:
        raise AssertionError("unable to checkout %s" % githash)

    # update submodules
    cmd = 'cd %s; git submodule update --init --recursive' % checkoutdir
    (rc, so, se) = run_command(cmd, capture=True)
    print('\tsubmod: %s' % rc)
    if rc != 0:
        raise AssertionError("unable to update submodules")


def _get_hacking_dir(checkoutdir):
    if not os.path.isdir(os.path.join(checkoutdir, 'hacking')):
        # possibly using the submodule dir as the checkout
        hacking_dir = os.path.abspath('%s/../../../../hacking' % checkoutdir)
        if not os.path.isdir(hacking_dir):
            print('NO HACKING DIRECTORY COULD BE FOUND IN CHECKOUT!!!')
            #sys.exit(1)
            return False
    else:
        hacking_dir = os.path.join(checkoutdir, 'hacking')
    return hacking_dir


def get_version(checkoutdir, checkrc=False):
    """Ask the checkout which ansible version it reports.

    Sources hacking/env-setup and runs `ansible --version`, then takes the
    second whitespace-separated token of the first output line that starts
    with 'ansible ' or contains 'detached HEAD'.

    checkrc is accepted for compatibility and is not used.

    Returns (rc, stdout, stderr, version).  rc is non-zero, and version is
    empty or None, when the hacking directory is missing, the command
    fails, or no version can be found in the output.  That last case used
    to open a debugger and then report success with version None; it is now
    reported as a failure so callers skip the commit.
    """
    hacking_dir = _get_hacking_dir(checkoutdir)
    if not hacking_dir:
        return (1, '', '', '')

    cmd = 'source %s/env-setup ; ansible --version' % hacking_dir
    (rc, so, se) = run_command(cmd, capture=True)
    if rc != 0:
        return (rc, so, se, None)

    version = None
    for x in so.split('\n'):
        if x.startswith('ansible ') or 'detached HEAD' in x:
            fields = x.split()
            # a line with a single token carries no version
            if len(fields) > 1:
                version = fields[1]
            break

    if not version:
        print('\tunable to find a version in the `ansible --version` output')
        return (1, so, se, None)

    return (rc, so, se, version)


def run_test(checkoutdir, testscript, capture=True):
    """Run testscript with the checkout's hacking/env-setup sourced.

    Paths that do not start with '/' are made absolute first.  The command
    is started in the current working directory.  Returns the
    (returncode, stdout, stderr) tuple produced by run_command.
    """
    def _absolute(path):
        return path if path.startswith('/') else os.path.abspath(path)

    checkoutdir = _absolute(checkoutdir)
    testscript = _absolute(testscript)

    env_dir = _get_hacking_dir(checkoutdir)
    shell_line = 'source %s/env-setup ; %s' % (env_dir, testscript)
    return run_command(shell_line, capture=capture, cwd=os.getcwd())


# Disabled earlier implementation of slice_dates that looked up each commit's
# date with get_commit_date and kept the results in the pickle cache.  It
# lives inside a string literal, so it is never executed.
"""
def slice_dates(checkoutdir, commits, start, stop):
    '''Determine date for commits and limit to a given range'''
    #start_idx = None
    #stop_idx = None

    if start:
        start = datetime.datetime.strptime(start, '%Y-%m-%d')
    if stop:
        stop = datetime.datetime.strptime(stop, '%Y-%m-%d')

    # get dates for the commits
    commit_dates = get_date_cache()
    for cid,commitid in enumerate(commits):
        if commitid not in commit_dates:
            # get the date for commit
            ts = get_commit_date(commitid, checkoutdir)
            commit_dates[commitid] = ts
            print('%s %s' % (commitid, commit_dates[commitid]))

    # store dates to save time later
    save_date_cache(commit_dates)

    if start and stop:
        commits = [x for x in commits
                   if commit_dates[x] > start and commit_dates[x] < stop]
    elif start and not stop:
        commits = [x for x in commits if commit_dates[x] > start]
    elif not start and stop:
        commits = [x for x in commits if commit_dates[x] < stop]

    return commits
"""

def slice_dates(checkoutdir, commits, start, stop):
    '''Keep only the commits whose timestamp lies strictly inside a range.

    start and stop are 'YYYY-MM-DD' strings or falsy for "no bound"; both
    bounds are exclusive.  Each commit must expose a `timestamp` datetime.
    checkoutdir is not used.  With neither bound given, the input list is
    returned as is.
    '''
    lower = datetime.datetime.strptime(start, '%Y-%m-%d') if start else start
    upper = datetime.datetime.strptime(stop, '%Y-%m-%d') if stop else stop

    if not lower and not upper:
        return commits

    def _in_range(commit):
        if lower and not commit.timestamp > lower:
            return False
        if upper and not commit.timestamp < upper:
            return False
        return True

    return [commit for commit in commits if _in_range(commit)]

def get_current_branch(checkoutdir):
    """Return the second field of the `git branch` line marked with '*'.

    For a normal checkout that is the branch name.  For a detached HEAD,
    git prints something like '* (HEAD detached at ...)', so the result is
    '(HEAD'.

    The grep pattern is quoted and escaped so the shell cannot glob-expand
    it and it only matches a literal leading asterisk.
    """
    cmd = "git branch | grep '^\\*' | awk '{print $2}'"
    (rc, so, se) = run_command(cmd, capture=True, cwd=checkoutdir)
    return so.strip()


class AnsibleBisector(object):
    """Decide which commit to test next and remember every test result.

    Typical use: call get_commit_to_test(), run the test, report the outcome
    with set_result() (or drop an untestable commit with remove_commit()),
    and repeat until get_commit_to_test() returns None.  Then
    get_bisected_commit() names the answer.

    ``marker`` is the highest index in ``commits`` whose result "matched"
    (see matches()); the bisected commit is the entry right after it.

    Settings (checkrc, negative_test, linear, linear_increment, direction)
    are plain attributes the caller assigns after construction.
    """

    # per-instance state, re-initialised in __init__
    results = {}
    commits = []
    # settings
    checkrc = 0
    negative_test = False
    linear = False
    linear_increment = 1
    direction = 'backward'
    version_start = None
    version_stop = None
    date_start = None
    date_stop = None
    marker = None

    # JUST FOR TESTING
    tcommits = None
    expected = None

    def __init__(self, commits, **kwargs):
        """Store the commit list (by reference); extra kwargs are ignored."""
        self.expected = None
        self.commits = commits
        self.results = {}
        self.current_index = None

    def print_status(self):
        '''Print index, commit and return code for the commits of interest.

        With a marker set, only the commits within 100 positions of it are
        shown and the marker/current/expected entries are labelled.
        '''
        if self.marker is None:
            for idx, x in enumerate(self.commits):
                rc = self.results.get(x, {}).get('rc', None)
                print('%s %s %s' % (idx, x, rc))
            return

        for idx, x in enumerate(self.commits):
            if idx <= (self.marker - 100) or idx >= (self.marker + 100):
                continue

            rc = self.results.get(x, {}).get('rc', None)
            # compare by index only, so that a marker at index 0 is labelled
            if idx == self.marker:
                print('%s %s %s <-- marker' % (idx, x, rc))
            elif idx == self.current_index:
                print('%s %s %s <-- current' % (idx, x, rc))
            elif x == self.expected:
                print('%s %s %s <-- expected' % (idx, x, rc))
            else:
                print('%s %s %s' % (idx, x, rc))

    def get_untested_count(self):
        '''Count the commits at or after the marker that have no result.'''
        first = self.marker or 0
        return sum(
            1 for idx, x in enumerate(self.commits)
            if idx >= first and x not in self.results
        )

    def get_bisected_commit(self):
        '''Return the commit right after the marker, or None.

        None is returned (with a warning) when there is no marker, when the
        marker is the last commit, or when the commit after the marker has
        no result.
        '''
        # we should have a marker at this point
        if self.marker is None:
            print('WARNING: no commit matched, nothing was bisected')
            return None

        if self.marker + 1 >= len(self.commits):
            print('WARNING: the last commit matched, nothing follows it')
            return None

        a = self.commits[self.marker]
        b = self.commits[self.marker + 1]
        if a not in self.results or b not in self.results:
            print('WARNING: the commits around the marker were not tested')
            return None

        # use matches() so that negative tests are judged the same way the
        # marker was moved
        if self.matches(self.results[a]['rc']) and \
                not self.matches(self.results[b]['rc']):
            if self.expected and b != self.expected:
                self.print_status()
                print('WARNING: bisected %s but expected %s'
                      % (b, self.expected))
            return b

        print('WARNING: inconsistent results around the marker')
        return None

    def get_commit_to_test(self):
        '''Tell the caller what to test next (None when there is nothing)'''

        if not self.commits:
            return None

        if self.linear:
            if self.current_index is None:
                if self.direction == 'backward':
                    self.current_index = 0
                else:
                    self.current_index = \
                        len(self.commits) - self.linear_increment
            elif self.direction == 'backward':
                self.current_index = \
                    self.current_index + self.linear_increment
            else:
                self.current_index = \
                    self.current_index - self.linear_increment

            # stop cleanly at either end instead of raising IndexError or
            # wrapping around through negative indexes
            if not 0 <= self.current_index < len(self.commits):
                return None
            return self.commits[self.current_index]

        # start at the middle
        if not self.current_index:
            # floor division: a float cannot be used as a list index
            self.current_index = len(self.commits) // 2

        try:
            chash = self.commits[self.current_index]
        except IndexError:
            # the list shrank (remove_commit); step back onto the last entry
            self.current_index -= 1
            chash = self.commits[self.current_index]

        if chash not in self.results:
            return chash

        # looped back to a tested commit
        if self.marker is not None:
            next_index = self.marker + 1
        else:
            next_index = 0

        try:
            next_commit = self.commits[next_index]
        except IndexError as e:
            print('WARNING: ran out of commits: %s' % e)
            return None

        if next_commit not in self.results:
            self.current_index = next_index
            return next_commit

        # No more to test...
        return None

    def remove_commit(self, commit):
        '''Remove a completely busted commit'''
        if commit in self.results:
            self.results.pop(commit, None)

        if commit not in self.commits:
            print('WARNING: cannot remove unknown commit %s' % commit)
            return True

        # need to move the marker back one if it's greater than this index
        cindex = self.commits.index(commit)
        if self.marker is not None and self.marker > cindex:
            self.marker -= 1

        # iterate to the next testable commit index
        nindex = cindex + 1
        while nindex < len(self.commits):
            if self.commits[nindex] not in self.results:
                break
            nindex += 1

        if nindex > len(self.commits) - 1:
            # went over the size of the list, need to go backwards
            if self.marker is not None:
                nindex = self.marker - 1
            else:
                nindex = 0
            while nindex < len(self.commits):
                if self.commits[nindex] not in self.results:
                    break
                nindex += 1

        if nindex > len(self.commits) - 1:
            print('WARNING: no untested commit left after removing %s'
                  % commit)

        # NOTE(review): nindex was computed before the removal below shifts
        # the list; get_commit_to_test copes with an index past the end.
        self.current_index = nindex

        # drop this commit from the list
        self.commits.remove(commit)

    def set_result(self, commit, returncode, stdout, stderr):
        '''Caller reports the test result with this'''
        if commit not in self.results:
            self.results[commit] = {'rc': None, 'so': None, 'se': None}
        self.results[commit]['rc'] = returncode
        self.results[commit]['so'] = stdout
        self.results[commit]['se'] = stderr

        # move the marker
        self.set_current_index()

    def matches(self, rc):
        '''True when rc equals checkrc; inverted for negative tests'''
        return (rc == self.checkrc) != bool(self.negative_test)

    def set_current_index(self):
        '''Move marker and current_index according to the latest result.

        Does nothing (returns True) in linear mode.
        '''
        if self.linear:
            return True

        last_index = self.current_index
        last_commit = self.commits[last_index]
        last_rc = self.results[last_commit]['rc']

        if self.matches(last_rc):
            # move forward if marker is greater
            if self.marker is None or last_index > self.marker:
                self.marker = last_index
            choices = self.commits[last_index:]
            self.current_index = last_index + len(choices) // 2
        else:
            # move backward
            if self.marker is not None:
                choices = self.commits[self.marker:last_index]
                self.current_index = self.marker + len(choices) // 2
            else:
                choices = self.commits[:last_index]
                self.current_index = len(choices) // 2


def find_first_version_commit(version, checkoutdir, commits):
    """Return the commit that follows the last one reporting < version.

    A bisector is run over commits: each candidate is checked out and asked
    for its version; a version below `version` is recorded as rc 0, anything
    else as rc 1.  Commits whose version cannot be determined are removed
    from the search -- and, because the bisector shares the list, from
    `commits` itself.

    Among the tested commits reporting a version below `version`, the one
    with the latest commit date is located in `commits` and the entry right
    after it is returned.

    Raises IndexError when no tested commit reports a version below
    `version`, or when nothing follows the chosen commit.
    """
    CMAP = {}

    ab = AnsibleBisector(commits)

    tc = ab.get_commit_to_test()
    while tc:

        print('# searching for start of %s' % version)

        CMAP[tc] = {'date': None, 'version': None}

        # get rid of previous remnants
        clean_checkout(checkoutdir)

        # set the checkout version
        checkout_commit(checkoutdir, tc)

        # get timestamp
        CMAP[tc]['date'] = get_checkout_date(checkoutdir)
        print('\tdate: %s' % CMAP[tc]['date'])

        # get the version and use as a sanity check
        (vrc, aso, ase, aversion) = get_version(checkoutdir)
        print('\tversion: %s' % aversion)

        # skip if failed
        if vrc != 0:
            ab.remove_commit(tc)
            CMAP.pop(tc, None)
            # fetch the replacement here so the loop condition sees a None
            # (nothing left) instead of re-using the removed commit
            tc = ab.get_commit_to_test()
            continue

        aversion = LooseVersion(aversion)
        CMAP[tc]['version'] = aversion

        if aversion < version:
            ab.set_result(tc, 0, '', '')
        else:
            ab.set_result(tc, 1, '', '')

        tc = ab.get_commit_to_test()

    candidates = [x for x in CMAP.items() if x[1]['version'] < version]
    if not candidates:
        raise IndexError('no tested commit reports a version below %s'
                         % version)
    candidates = sorted(candidates, key=lambda tup: tup[1]['date'])
    lc = candidates[-1][0]
    lc_index = commits.index(lc)
    fc = commits[lc_index+1]
    return fc


def main():
    """Command-line entry point.

    Parses the arguments, resets the checkout to its current branch, lists
    the commits (optionally limited by date and by starting version) and
    drives an AnsibleBisector: every commit it proposes is checked out,
    sanity-checked through `ansible --version` and then handed to the test
    script.  The bisected commit is printed at the end.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('checkoutdir')
    parser.add_argument('testscript')

    parser.add_argument(
        '--direction',
        help="which direction to test foward|[backward]",
        choices=['backward', 'forward'],
        default='backward'
    )
    parser.add_argument(
        '--checkrc',
        help="what rc should trigger a stop (can be !<INT>)",
        type=str,
        default='1'
    )
    parser.add_argument(
        '--ignorerc',
        help="what returncodes to ignore",
        type=int,
        default=1
    )
    parser.add_argument(
        '--version_start',
        help="what ansible version to start testing at"
    )
    parser.add_argument(
        '--version_stop',
        help="what ansible version to stop testing at"
    )
    parser.add_argument(
        '--date_start',
        help="what date to start testing at (YYYY-MM-DD)"
    )
    parser.add_argument(
        '--date_stop',
        help="what date to stop testing at (YYYY-MM-DD)"
    )
    parser.add_argument('--linear', action='store_true')
    parser.add_argument('--linear_increment', type=int, default=1)
    parser.add_argument('--show_errors', action='store_true')
    args = parser.parse_args()

    # fixup args
    if args.version_start:
        args.version_start = LooseVersion(args.version_start)
    if args.version_stop:
        args.version_stop = LooseVersion(args.version_stop)

    # a leading '!' turns the run into a negative test
    # NOTE(review): args.checkrc is converted here but never handed to the
    # bisector, and args.ignorerc is never read -- confirm the intent.
    negative_test = args.checkrc.startswith('!')
    rc_text = args.checkrc.replace('!', '') if negative_test else args.checkrc
    args.checkrc = int(rc_text)

    # what is the primary branch?
    cbranch = get_current_branch(args.checkoutdir)
    if cbranch.endswith('HEAD'):
        cbranch = 'HEAD'

    # reset the checkout
    print('Resetting branch')
    checkout_commit(args.checkoutdir, cbranch)

    # get all known commits
    commits = get_commits(args.checkoutdir)

    # limit commits to a date range
    if args.date_start or args.date_stop:
        print('Total commits before slicing: %s' % len(commits))
        commits = slice_dates(
            args.checkoutdir, commits, args.date_start, args.date_stop)
        print('Total commits after slicing: %s' % len(commits))

    # reverse list if direction should be forward
    if args.direction == 'forward':
        commits = list(reversed(commits))

    # mapping of commits to their timestamps
    commit_dates = {}

    # slice commits by version
    if args.version_start:
        first_commit = find_first_version_commit(
            args.version_start, args.checkoutdir, commits)
        commits = commits[commits.index(first_commit):]

    ab = AnsibleBisector(commits)
    ab.negative_test = negative_test
    ab.direction = args.direction
    ab.linear = args.linear
    ab.linear_increment = args.linear_increment

    tc = ab.get_commit_to_test()
    while tc:

        print('COMMIT: %s (%s)' % (tc, ab.get_untested_count()))

        # get rid of previous remnants
        clean_checkout(args.checkoutdir)

        # set the checkout version
        checkout_commit(args.checkoutdir, tc)

        # check the commit date
        cdate = get_checkout_date(args.checkoutdir)
        commit_dates[tc] = cdate
        print('\tdate: %s' % cdate)

        # get the version and use as a sanity check
        (vrc, aso, ase, aversion) = get_version(args.checkoutdir)
        if vrc != 0:
            print('\t%s' % cdate)
            print('\tsanity check failed, skipping')
            for line in (str(aso) + str(ase)).split('\n'):
                print('\t\t%s' % line)
            # remove the commit and move on to the next one
            ab.remove_commit(tc)
            tc = ab.get_commit_to_test()
            continue

        print('\tversion: %s' % aversion)

        # move forward if too old
        aversion = LooseVersion(aversion)
        if args.version_start and aversion < args.version_start:
            print('\t%s < %s, skipping' % (aversion, args.version_start))
            # remove the commit and move on to the next one
            ab.remove_commit(tc)
            tc = ab.get_commit_to_test()
            continue

        # run the test
        (rc, so, se) = run_test(args.checkoutdir, args.testscript)
        print('\trc: %s' % rc)
        if args.show_errors and rc != 0:
            print(str(so) + str(se))
        ab.set_result(tc, rc, so, se)
        tc = ab.get_commit_to_test()

    bisected_commit = ab.get_bisected_commit()
    banner = '########################################'
    print(banner)
    print('# BISECTED: %s' % bisected_commit)
    print(banner)


if __name__ == '__main__':
    main()
