#!/usr/local/cpython-3.3/bin/python

'''Unit tests for binary_tree_dict_mod'''

import sys
import math
import random

import binary_tree_dict_mod


def test_depth():
    '''
    Check that depth() of a small binary tree dict lies within plausible bounds.

    Nine key/value pairs are inserted in a fixed, unsorted order.  The reported
    depth must be at least round(log2(len(tree))) and at most len(tree).

    Returns True if the depth is within bounds; otherwise writes a diagnostic to
    stderr and returns False.
    '''
    pairs = (
        (2, 'def'), (3, 'ghi'), (1, 'abc'),
        (4, 'jkl'), (5, 'mno'), (7, 'stu'),
        (8, 'vwx'), (9, 'yz'), (6, 'pqr'),
    )

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key, value in pairs:
        tree[key] = value

    depth = tree.depth()
    size = len(tree)
    lower_bound = int(round(math.log(size, 2)))

    # In the worst case a binary tree is a linear list, hence the upper bound of size
    if lower_bound <= depth <= size:
        return True

    sys.stderr.write('%s: test_depth: Bogus depth: %d\n' % (sys.argv[0], depth))
    return False

def test_insert():
    '''
    Store the keys 0..9 (each mapped to its own str()), then read every key back.

    Returns True if every lookup yields the stored value; each mismatch is
    reported on stderr and makes the function return False.
    '''
    tree = binary_tree_dict_mod.BinaryTreeDict()
    keys = list(range(10))

    for key in keys:
        tree[key] = str(key)

    ok = True
    for key in keys:
        if str(key) == tree[key]:
            continue
        ok = False
        sys.stderr.write(
            '%s: test_insert: Found mismatched key: Got %s, expected %s\n'
            % (sys.argv[0], tree[key], str(key))
        )

    return ok

def test_remove():
    '''
    Insert the keys 0..9 in random order, delete them in another random order,
    and verify each deletion.

    After every deletion a lookup of the deleted key must raise KeyError, and once
    all keys are gone the tree must be falsy.  Returns True if all of that holds;
    otherwise diagnostics go to stderr and False is returned.
    '''
    keys = list(range(10))

    # Random insertion order, so the tree is not built from sorted input
    random.shuffle(keys)
    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    # A second, independent order for the deletions
    random.shuffle(keys)

    ok = True
    for key in keys:
        del tree[key]
        try:
            # Only the KeyError matters here; the looked-up value is discarded
            tree[key]
        except KeyError:
            continue
        ok = False
        sys.stderr.write('%s: test_remove: element %s not removed\n' % (sys.argv[0], key))

    if tree:
        ok = False
        sys.stderr.write('%s: test_remove: final tree nonempty\n' % (sys.argv[0], ))

    return ok

def test_large_inserts():
    '''
    Insert a large number of keys, then check find_min() and find_max().

    The keys are the successive values of (k * 997) % 100000 for k = 1, 2, ...,
    stopping when the sequence reaches 0.  Because 997 and 100000 are coprime,
    this visits every integer in 1..99999 exactly once, in a scattered order.

    Returns True if the tree reports the same minimum and maximum as the
    inserted keys; otherwise writes diagnostics to stderr and returns False.
    '''
    modulus = 100000
    stride = 997

    # Generate the key sequence up front; it is never empty because stride != 0
    keys = []
    current = stride
    while current != 0:
        keys.append(current)
        current = (current + stride) % modulus

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    expected_min = min(keys)
    expected_max = max(keys)
    ok = True

    actual_min = tree.find_min()
    if actual_min != expected_min:
        sys.stderr.write(
            '%s: Large binary_tree_dict did not return correct minimum: expected: %s, actual: %s\n'
            % (sys.argv[0], expected_min, actual_min)
        )
        ok = False

    actual_max = tree.find_max()
    if actual_max != expected_max:
        sys.stderr.write(
            '%s: Large binary_tree_dict did not return correct maximum: expected: %s, actual: %s\n'
            % (sys.argv[0], expected_max, actual_max)
        )
        ok = False

    return ok

def test_nonempty():
    '''
    Verify that a tree holding the keys 0..9 is truthy.

    Returns True if so; otherwise writes a diagnostic to stderr and returns False.
    '''
    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in range(10):
        tree[key] = str(key)

    if tree:
        return True

    sys.stderr.write('%s: nonempty binary_tree_dict looks empty\n' % sys.argv[0])
    return False

def test_empty():
    '''
    Verify that a freshly constructed tree is falsy.

    Returns True if so; otherwise writes a diagnostic to stderr and returns False.
    '''
    tree = binary_tree_dict_mod.BinaryTreeDict()

    if not tree:
        return True

    sys.stderr.write('%s: empty binary_tree_dict looks nonempty\n' % sys.argv[0])
    return False

def test_min_max():
    '''
    Insert the keys 0..9, then check that find_min() is 0 and find_max() is 9.

    Returns True if both checks pass; each failing check is reported on stderr
    and makes the function return False.
    '''
    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in range(10):
        tree[key] = str(key)

    ok = True

    smallest = tree.find_min()
    if smallest != 0:
        ok = False
        sys.stderr.write('%s: minimum was not 0: %s\n' % (sys.argv[0], smallest))

    largest = tree.find_max()
    if largest != 9:
        ok = False
        sys.stderr.write('%s: maximum was not 9: %s\n' % (sys.argv[0], largest))

    return ok

def test_values():
    '''
    Insert a few key-value pairs, and make sure they come back out OK.

    Keys 1..9 are inserted in a fixed, unsorted order.  The test then checks that
    find_min() is 1, find_max() is 9, and that key 5 maps to 'mno'.

    Returns True if all three checks pass; each failing check is reported on
    stderr and makes the function return False.
    '''

    all_good = True

    binary_tree_dict = binary_tree_dict_mod.BinaryTreeDict()
    binary_tree_dict[2] = 'def'
    binary_tree_dict[3] = 'ghi'
    binary_tree_dict[1] = 'abc'
    binary_tree_dict[4] = 'jkl'
    binary_tree_dict[5] = 'mno'
    binary_tree_dict[7] = 'stu'
    binary_tree_dict[8] = 'vwx'
    binary_tree_dict[9] = 'yz'
    binary_tree_dict[6] = 'pqr'

    # The smallest key inserted above is 1, so the message must say 1 (it used to say 0)
    if binary_tree_dict.find_min() != 1:
        sys.stderr.write('%s: test_values: Minimum was not 1\n' % sys.argv[0])
        all_good = False

    if binary_tree_dict.find_max() != 9:
        sys.stderr.write('%s: test_values: Maximum was not 9\n' % sys.argv[0])
        all_good = False

    if binary_tree_dict[5] != 'mno':
        sys.stderr.write('%s: test_values: Middle was not mno\n' % sys.argv[0])
        all_good = False

    return all_good

def test_inorder_traversal():
    '''
    Check that inorder_traversal() visits the (key, value) pairs in ascending key order.

    The keys 1, 4, 7, ..., 28 are inserted in ascending order, each mapped to its
    own str().  The traversal callback records every (key, value) it receives, and
    the recorded list must equal the inserted pairs.

    Returns True on a match; otherwise writes the expected and actual lists to
    stderr and returns False.
    '''
    keys = [3 * multiplier + 1 for multiplier in range(10)]
    expected = [(key, str(key)) for key in keys]

    visited = []

    def collect(key, value):
        '''Record one visited node as a (key, value) tuple'''
        visited.append((key, value))

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    tree.inorder_traversal(collect)

    if expected == visited:
        return True

    sys.stderr.write('%s: test_inorder_traversal: inorder_traversal failed to rebuild the list\n' % (sys.argv[0], ))
    sys.stderr.write('Expected %s\n' % (expected, ))
    sys.stderr.write('Got %s\n' % (visited, ))
    return False

def test_str():
    '''
    Test formatting a binary_tree_dict tree as a string.

    Keys 1..9 are inserted in ascending order and the tree is converted with
    str().  The number of newline characters in the result must lie between
    round(log2(len(tree))) and len(tree), inclusive.  (Presumably the string has
    about one line per tree level, given that the bounds are named as depths -
    confirm against BinaryTreeDict.__str__.)

    Returns True if the newline count is within bounds; otherwise writes the
    count and the formatted tree to stderr and returns False.
    '''

    all_good = True

    binary_tree_dict = binary_tree_dict_mod.BinaryTreeDict()
    binary_tree_dict[1] = 'abc'
    binary_tree_dict[2] = 'def'
    binary_tree_dict[3] = 'ghi'
    binary_tree_dict[4] = 'jkl'
    binary_tree_dict[5] = 'mno'
    binary_tree_dict[6] = 'pqr'
    binary_tree_dict[7] = 'stu'
    binary_tree_dict[8] = 'vwx'
    binary_tree_dict[9] = 'yz'

    # Perform one lookup before formatting; the value itself is not needed.
    # NOTE(review): the purpose of this read is not evident from this file.
    binary_tree_dict[3]

    string = str(binary_tree_dict)

    count = string.count('\n')
    len_binary_tree_dict = len(binary_tree_dict)
    maximum_allowable_depth = len_binary_tree_dict

    minimum_allowable_depth = int(round(math.log(len_binary_tree_dict, 2.0)))

    # Fail when count is OUTSIDE [minimum, maximum].  The previous condition,
    # minimum >= count >= maximum, could never be true while minimum < maximum,
    # so this check silently always passed.
    if not minimum_allowable_depth <= count <= maximum_allowable_depth:
        sys.stderr.write('%s: test_str: bad number of newlines: %d\n' % (sys.argv[0], count))
        sys.stderr.write('%s\n' % count)
        sys.stderr.write('%s\n' % (string, ))
        all_good = False

    return all_good

def test_iterator():
    '''
    Test iterating over the entire tree via values().

    Keys 1..9 are inserted in a fixed, unsorted order; values() must then yield
    the stored values in ascending key order.

    Returns True if the order matches; otherwise writes a diagnostic to stderr
    and returns False.
    '''
    pairs = (
        (2, 'def'), (3, 'ghi'), (1, 'abc'),
        (4, 'jkl'), (5, 'mno'), (7, 'stu'),
        (8, 'vwx'), (9, 'yz'), (6, 'pqr'),
    )

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key, value in pairs:
        tree[key] = value

    expected = ['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx', 'yz']
    actual = list(tree.values())

    if expected != actual:
        sys.stderr.write('%s: test_iterator: values did not come back in correct order\n' % sys.argv[0])
        return False

    return True


def test_sequential():
    '''
    Test inserting lots of sequential values.

    The keys 0..1999 are inserted in ascending order.  A RuntimeError raised by
    any insertion is treated as the stack having been exhausted.

    Returns True if every insertion succeeds; otherwise writes a diagnostic to
    stderr and returns False.
    '''
    tree = binary_tree_dict_mod.BinaryTreeDict()

    # 2000 is above CPython's default recursion limit (1000), so an insert that
    # recurses once per level of a worst-case, list-shaped tree would presumably
    # fail here
    count = 2000

    try:
        for key in range(count):
            tree[key] = 1
    except RuntimeError:
        sys.stderr.write('%s: Stack blown on __setitem__\n' % (sys.argv[0], ))
        return False

    return True


def test_duplication():
    '''
    Test inserting duplicate keys.

    The keys 0..19 are inserted twice, each pass in a freshly shuffled order and
    each key mapped to 2 ** key.  The tree must still report exactly 20 elements.

    Returns True if len(tree) is 20; otherwise writes a diagnostic to stderr and
    returns False.
    '''
    tree = binary_tree_dict_mod.BinaryTreeDict()
    numbers = list(range(20))

    # Two passes over the same keys; the second pass only re-assigns existing keys
    for _ in range(2):
        random.shuffle(numbers)
        for number in numbers:
            tree[number] = 2 ** number

    if len(tree) != 20:
        sys.stderr.write('%s: number of elements is not 20: %s\n' % (sys.argv[0], len(tree)))
        return False

    return True


def main():
    '''
    Main function: run every test, and exit with status 1 if any of them failed.

    All tests are run even after one has failed, so that every failure is reported.
    '''
    tests = (
        test_depth,
        test_insert,
        test_min_max,
        test_large_inserts,
        test_nonempty,
        test_empty,
        test_values,
        test_inorder_traversal,
        test_str,
        test_iterator,
        test_remove,
        test_sequential,
        test_duplication,
    )

    # Build the complete list of outcomes first, so nothing is short-circuited
    outcomes = [test() for test in tests]

    if all(outcomes):
        return

    sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0])
    sys.exit(1)

# Run the tests only when this file is executed as a script; merely importing
# it should not run the suite or call sys.exit()
if __name__ == '__main__':
    main()

