#! /usr/bin/env python3
"""
ad2openldap3 -- a script that performs one-way syncing of certain Active
                Directory records along with a limited set of attributes, into
                an OpenLDAP directory information tree.

Python3 port using ldap3

"""

import argparse
import base64
import datetime
import hashlib
import logging
import os
import pwd
import random
import re
import shutil
import smtplib
import socket
import subprocess
import string
import sys
import time
import traceback
import yaml

import ldap3

from os import stat
from pwd import getpwnam, getpwuid
from grp import getgrnam, getgrgid
from string import Template

__author__ = "Jeff Katcher"
__credits__ = ["Dirk Petersen","Brian Hodges"]
__license__ = "GPLv3"
__version__ = "0.20"

def config_logging(arguments):
    """
    Configure root logging to stdout.

    Errors are shown by default; --verbose raises the level to INFO and
    --debug to DEBUG.  Also logs a start-of-run timestamp.
    """
    log_format = '%(levelname)s:%(module)s:line %(lineno)s:%(message)s'
    level = (logging.DEBUG if arguments.debug
             else logging.INFO if arguments.verbose
             else logging.ERROR)
    logging.basicConfig(stream=sys.stdout, format=log_format, level=level)

    started = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    logging.info('ad2openldap - starting execution at %s' % started)
    logging.debug('Parsed arguments: %s' % arguments)

def main():
    """
    Entry point.  Dispatches on the subcommand named in sys.argv[1],
    loads the YAML configuration, expands the shell-command templates
    used for syncing, and runs the requested action guarded by a pid
    file.

    Returns a shell-style exit status (0 success, 1 error); unexpected
    exceptions are e-mailed via notify() and then re-raised.
    """
    if len(sys.argv)<2:
        logging.error('No action specified, exiting.')
        #run_command("/usr/games/fortune",echo=True)
        return 1

    # NOTE(review): the subcommand is taken straight from sys.argv and
    # must stay in step with the argparse subparser names used below.
    action = sys.argv[1]

    arguments = parse_arguments()
    config_logging(arguments)

    # Load up yaml configuration file.
    config = load_config(arguments.config_file)
    logging.debug('yaml config file, loaded settings: %s' % config)

    # Exit if it looks as though the config file is untouched.
    if action != 'setup' and config['smtp_host'] == 'mx.example.com':
        logging.error('The ad2openldap.conf file must first be configured, unable to proceed, exiting.')
        return 1


    # The script is presently monolithic... for a future release, it is
    # planned to have the argument parsing part of the main script,
    # and to separate most of the ldap and ad code into one
    # or more classes, of which command-line args and config file
    # settings will be instance properties, cleaning up the
    # passing of settings throughout the code.

    # All working files live under the configured tmp_dir.
    tmp_dir = config['tmp_dir']
    pid_file = os.path.join(tmp_dir, config['pid_file'])
    ad_export_file = os.path.join(tmp_dir, config['ad_export_file'])
    ad_export_previous = os.path.join(tmp_dir, config['ad_export_previous'])
    delta_file = os.path.join(tmp_dir, config['delta_file'])
    ldap_tree_dir = config['ldap_tree_dir']
    ldap_tree_dir_perms = config['ldap_tree_dir_perms']
    ldap_rtc_parent_dir = config['ldap_rtc_parent_dir']
    ldap_rtc_dir = config['ldap_rtc_dir']
    # NOTE(review): this reads 'ldap_tree_dir_perms', not
    # 'ldap_rtc_dir_perms' -- looks like a copy/paste slip; confirm
    # against the shipped ad2openldap.conf before changing.
    ldap_rtc_dir_perms = config['ldap_tree_dir_perms']
    openldap_user = config['openldap_user']
    openldap_group = config['openldap_group']
    slapd_start = config['slapd_start']
    slapd_stop = config['slapd_stop']
    # Expand the configured shell-command templates once, up front.
    slapadd_template = Template(config['slapadd'])
    slapadd_rtc_template = Template(config['slapadd_rtc'])
    slapadd_dit_skeleton = slapadd_template.substitute(ldif_file=config['dit_skeleton'])
    slapadd_rtc_ldif = slapadd_rtc_template.substitute(ldif_file=os.path.join(tmp_dir, config['rtc_substituted_file']), ldap_rtc_parent_dir=config['ldap_rtc_parent_dir'])
    slapadd_ad_export_file = slapadd_template.substitute(ldif_file=ad_export_file)
    ldapmodify_template = Template(config['ldapmodify'])
    ldapmodify_delta_file = ldapmodify_template.substitute(ldif_file=os.path.join(tmp_dir, config['delta_file']), bind_dn=config['bind_dn'], bind_dn_password=config['bind_dn_password'])
    iptables_add_template = Template(config['iptables_add'])
    iptables_delete_template = Template(config['iptables_delete'])
    iptables_add = iptables_add_template.substitute(interface=config['interface'])
    iptables_delete = iptables_delete_template.substitute(interface=config['interface'])

    # Setup notification addresses
    # (the command-line flag overrides the config-file setting).
    if arguments.notify_address:
        notify_addresses = arguments.notify_address
    elif 'notify_addresses' in config:
        notify_addresses = config['notify_addresses']
    else:
        notify_addresses = None

    # Use a pid file to make sure two copies don't run in
    # parallel or to notify and log that the last run exited prematurely.
    pid = str(os.getpid())
    if os.path.isfile(pid_file):
        # Notify and log here
        notify(notify_addresses, 'Unable to start',
            '%s already exists, exiting' % pid_file, config)
        logging.critical('%s already exists, exiting' % pid_file)
        return 1
    else:
        # NOTE(review): handle is not explicitly closed; CPython closes
        # it at garbage collection, but a with-statement would be safer.
        open(pid_file, 'w').write(pid)

    try:
        # Make certain ownership and permissions are correct
        # before executing any commands.

        owner_and_perms_check(arguments.config_file, config)

        if action == 'fullsync' or action == 'deltasync':
            # If ad_export_file exists, move it to ad_export_previous.
            if os.path.isfile(ad_export_file):
                shutil.move(ad_export_file, ad_export_previous)

            # Export AD users and groups to a local ldif file
            # Uncomment after testing
            ad_export_objects(config)

            if arguments.regex_replace:
                ldif_replace(ad_export_file, arguments.regex_replace)

        # If both files exist we have data that is diffable.
        diffable = True if os.path.isfile(ad_export_file) and os.path.isfile(ad_export_previous) else False

        if action == 'fullsync' and arguments.full_sync:
            # Only root can perform a full sync due to OS-level permissions.
            if pwd.getpwuid(os.getuid()).pw_name != 'root':
                logging.error('Only the root user can perform a full sync')
                return 1
            # Stop slapd, clean-up ldap dir, quickload DIT skeleton and
            # all AD records to import.
            # The iptables rules shield clients while slapd is down.
            run_command(iptables_add, verbose=arguments.verbose)
            run_command(slapd_stop, verbose=arguments.verbose)
            ldap_tree_dir_clean(ldap_tree_dir)
            run_command(slapadd_dit_skeleton, verbose=arguments.verbose)
            run_command(slapadd_ad_export_file, verbose=arguments.verbose)
            ldap_tree_dir_fix_perms(ldap_tree_dir, ldap_tree_dir_perms, openldap_user, openldap_group)
            run_command(slapd_start, verbose=arguments.verbose)
            run_command(iptables_delete, verbose=arguments.verbose)
        elif action == 'deltasync' and arguments.delta_sync and diffable:
            # NOTE(review): ldif_diff is defined elsewhere in this file.
            ldif_diff(config)
            if os.path.exists(delta_file) and os.path.getsize(delta_file) > 0:
                run_command(ldapmodify_delta_file, verbose=arguments.verbose,
                    fatal=False)
        elif action == 'deltasync' and arguments.delta_sync and not diffable:
            logging.error('Both %s (ad_export_file) and %s (ad_export_previous) must be present to determine delta.' % (ad_export_file, ad_export_previous))
            return 1
        elif action == 'healthcheck' and arguments.common_name:
            health_check(arguments.common_name,config,verbose=arguments.verbose)
        elif action == 'setup':
            setup_information()
        elif action == 'rtcclean' and arguments.rtc_clean:
            # Only root can perform a rtc clean due to OS-level permissions.
            if pwd.getpwuid(os.getuid()).pw_name != 'root':
                logging.error('Only the root user can perform a rtc clean')
                return 1
            # Substitute the SSHA of the bind password into the RTC LDIF
            # template before loading it with slapadd.
            rtc_template = Template(open(config['rtc_ldif']).read())
            rtc_substituted = rtc_template.safe_substitute(ssha=generate_ssha(config['bind_dn_password']).decode())
            with open(os.path.join(tmp_dir, config['rtc_substituted_file']), 'w') as rtc_substituted_file:
                rtc_substituted_file.write(rtc_substituted)

            run_command(iptables_add, verbose=arguments.verbose)
            run_command(slapd_stop, verbose=arguments.verbose, fatal=False)
            ldap_rtc_parent_dir_clean(ldap_rtc_parent_dir)
            run_command(slapadd_rtc_ldif, verbose=arguments.verbose)
            os.remove(os.path.join(tmp_dir, config['rtc_substituted_file']))
            ldap_rtc_fix_perms(ldap_rtc_parent_dir, ldap_rtc_dir,
                ldap_rtc_dir_perms, openldap_user, openldap_group)
            run_command(slapd_start, verbose=arguments.verbose)
            run_command(iptables_delete, verbose=arguments.verbose)
    except:
        # Bare except is deliberate: notify operators about any failure,
        # then re-raise so the full traceback still reaches stderr.
        exc_type, exc_value, exc_traceback = sys.exc_info()
        # Notify with limited information. Full traceback to stderr for now
        notify(notify_addresses, 'A %s exception has been raised' % exc_type,
            '%s' % '\n'.join(traceback.format_exception(exc_type, exc_value,
            exc_traceback)), config)
        raise
    finally:
        # Always drop the pid file, even on failure.
        os.unlink(pid_file)
        logging.info('ad2openldap - finished execution at %s' % datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))



def notify(notify_addresses, subject, text, config):
    """
    Send an e-mail notification to the configured recipients.

    Does nothing when *notify_addresses* is empty/None.  The message
    subject is prefixed with the configured subject_prepend value.
    """
    if not notify_addresses:
        return

    smtp_host = config['smtp_host']
    sender = config['from_address']
    # Remove yaml escape characters - \
    full_subject = config['subject_prepend'].replace('\\', '') + ' ' + subject
    host_name = socket.gethostname()

    template = Template(
        'From: $notify_from\nTo: $notify_to\n' + \
        'Subject: $notify_subject\n\n' + \
        'This is a notification message from ad2openldap, running on \n' + \
        'host $host. Please review the following message:\n\n\n' + \
        '$notify_text\n\nIf output is being captured, you may find\n' + \
        'additional information in your log file.\n'
        )

    full_message = template.substitute(notify_from=sender,
                                       notify_to=','.join(notify_addresses),
                                       notify_subject=full_subject,
                                       host=host_name.upper(),
                                       notify_text=text)

    smtp = smtplib.SMTP(smtp_host)
    smtp.sendmail(sender, notify_addresses, full_message)
    smtp.quit()


def parse_arguments():
    """
    Parse command-line arguments.
    """
    parsers = []
    parser = argparse.ArgumentParser(description='ad2openldap - sync a ' + \
       'destination OpenLDAP DIT with objects originating from ' + \
       'Active Directory.', prog='ad2openldap')

    parsers.append(parser)

    subparsers = parser.add_subparsers(title='subcommands',
        description='valid subcommands', help='additional help')


    # delta sync subcommand
    delta_sync = subparsers.add_parser('deltasync', help='Perform a one-way ' + \
        'sync from Active Directory to the ad2openldap DIT, only working ' + \
        'with the delta')       
    delta_sync.add_argument('--dont-blame-ad2openldap', '-y', action='store_true',
        required=True, dest='delta_sync', default=False,
        help='Sync from Active Directory to OpenLDAP incrementally')
    delta_sync.add_argument('--regex-replace', '-r', help='Specify a ' + \
        'regular expression substitution that will be applied to each line of the AD exported ' + \
	'LDIF file.  Match expression and replacement must be ^^ separated. ' + \
	'Can be called multiple times and order will be preserved. This feature is experimental',
        dest='regex_replace', action='append', metavar="'EXPR^^REPLACEMENT'")
    parsers.append(delta_sync)


    # full sync subcommand
    full_sync = subparsers.add_parser('fullsync', help='Perform a one-way, ' + \
        'full sync from Active Directory to the ad2openldap DIT.  Use ' + \
        'with caution, this command is destructive and currently will ' + \
        'permanently delete all DITs on this system')       
    full_sync.add_argument('--dont-blame-ad2openldap', '-y', action='store_true',
        required=True, dest='full_sync', default=False,
        help='All directory information trees, not only the ad2openldap tree ' + \
        'will be permanently deleted. The base DN synced to by ad2openldap ' + \
        'will be freshly populated. Use with caution') 
    full_sync.add_argument('--regex-replace', '-r', help='Specify a ' + \
        'regular expression substitution that will be applied to each line of the AD exported ' + \
	'LDIF file.  Match expression and replacement must be ^^ separated. ' + \
	'Can be called multiple times and order will be preserved. This feature is experimental',
        dest='regex_replace', action='append', metavar="'EXPR^^REPLACEMENT'")
    parsers.append(full_sync)


    # healthcheck subcommand 
    health_check = subparsers.add_parser('healthcheck', help='Check health of ' + \
        'OpenLDAP instance by performing searches')       
    health_check.add_argument('--common-name', '-N', help='Specify a CN to ' + \
        'search for.  A single matching entry being returned is considered a ' + \
        'success', required=True, dest='common_name')
    parsers.append(health_check)


    # healthcheck subcommand 
    setup = subparsers.add_parser('setup', help='Print information ' + \
        'pertaining to the setup of ad2openldap on this system')       
    parsers.append(setup)


    # rtc subcommand
    rtc_clean = subparsers.add_parser('rtcclean', help='Install a clean copy of ' + \
        'cn=config, which adheres to rfc2307bis.  This is a destructive command ' + \
        'which should be used with caution.  All files under ldap_rtc_parent_dir will ' + \
        'be permanently lost!')       
    rtc_clean.add_argument('--dont-blame-ad2openldap', '-y', action='store_true',
        required=True, dest='rtc_clean',
        help='cn=config, the OpenLDAP real-time configuration in its current ' + \
        'state will be permanently lost. It will be freshly populated ' + \
        'to adhere to rfc2307bis, which may not be compatible with directory ' + \
        'information trees other than what is used by ad2openldap. Use with ' + \
        'caution as the real-time configuration will be completely repopulated')
    parsers.append(rtc_clean)


    # Add shared arguments
    for command in parsers:
        command.add_argument('--config-file', '-C', dest='config_file',
            action='store', default='/etc/ad2openldap/ad2openldap.conf', 
            help='Configuration file to use, default to: %s' % \
            '/etc/ad2openldap/ad2openldap.conf')

        command.add_argument('--notify-address', '-a', dest='notify_address',
            action='append', help='E-mail address to send notifications to, ' + \
            'overrides settings in ad2openldap.conf, use multiple times to ' + \
            'send to more than one recipient')

        """
        # Not implement, output goes to stdout presently.
        command.add_argument('--log-file', '-l', dest='log_file',
            action='store', help='Full path to log file with write permission.')
        """
        command.add_argument('--version', '-V', action='version',
            version='ad2openldap v' + __version__,
            help="Print the version number")

        command.add_argument('--verbose', '-v', action='store_true', 
            dest='verbose', help='Turn verbose output on.  If redirecting ' + \
            'output to an application log file, this will option is important. ' + \
            'A future release may provide the option of specifying a log file, but ' + \
            'for the initial release, output goes to STDOUT and STDERR.')

        command.add_argument('--debug', '-d', action='store_true', 
            dest='debug', help='Turn debugging output on')

        """
        # Not implemented.
        parser.add_argument('--dry-run', '-n', action='store_true', 
            default=False, help='Run in dry-run mode, which will not write ' + \
            'to any files, nor to OpenLDAP.  Useful used in ' + \
            'combination with -v and -d')
        """

    arguments = parser.parse_args()

    return arguments


def ldif_replace(ad_export_file, regex_replace):
    """
    Just after exporting from AD, replace text according to regular
    expressions passed as command-line arguments.

    :param ad_export_file: path to the LDIF file to rewrite in place
    :param regex_replace: list of 'PATTERN^^REPLACEMENT' strings, applied
        to every line in the order given
    """
    # Perform substitutions in-memory
    transformed = []
    with open(ad_export_file, 'r') as exported:
        for line in exported:
            for regex in regex_replace:
                # str.split replaces string.split(), which was a
                # Python 2 idiom removed in Python 3 (crashed here).
                pattern, replace = regex.split('^^')
                line = re.sub(pattern, replace, line)
            transformed.append(line)
    # Overwrite ad_export_file with transformed version
    with open(ad_export_file, "w") as out_file:
        out_file.write("".join(transformed))

def load_config(config_file):
    """
    Load the YAML configuration file and run basic validation.

    :param config_file: path to the ad2openldap YAML configuration
    :returns: dict of configuration settings
    """
    # Context manager guarantees the file handle is closed; safe_load
    # avoids arbitrary Python object construction -- yaml.load() without
    # an explicit Loader is deprecated and unsafe on untrusted input.
    with open(config_file, 'r') as stream:
        config = yaml.safe_load(stream)

    # Perform some configuration file validation here.
    validate_config(config)

    return config


def validate_config(config):
    """
    Basic settings validation. Warn or raise errors as appropriate.

    Currently a no-op placeholder; the commented template below sketches
    the intended per-setting checks for a future release.
    """
    # Template for future per-setting validation:
    #
    # try:
    #     config['<setting>']
    #     if not os.path.isdir(config['<setting>']):
    #         raise os.error('%s does not exist or is not readable.' %
    #             config['<setting>'])
    # except KeyError:
    #     logging.error('must be specified in configuration file.')
    #     raise
    # try:
    #     config['<setting>']
    #     int(config['<setting>'])
    # except KeyError:
    #     logging.error('must be specified in configuration file.')
    #     raise
    # except ValueError:
    #     logging.error('setting must be an integer: "%s" is not valid.' %
    #        '<setting>')
    #     raise
    pass


def health_check(common_name, config, verbose=False):
    """
    Search the configured OpenLDAP server for *common_name*.

    Exactly one matching entry counts as healthy.

    :returns: 0 on success, 1 otherwise
    """
    ldap_conn = open_ldap(config['ldap_url'])

    ldap_conn.search(search_base=config['base_dn'],
                     search_scope=ldap3.SUBTREE,
                     attributes=['cn'],
                     search_filter="(cn=%s)" % common_name)

    results_message = ("Results of search for cn: %s -- %s" %
                       (common_name, ldap_conn.entries))
    if verbose:
        print(results_message)
    logging.debug(results_message)

    results_count = len(ldap_conn.entries)

    if results_count != 1:
        failure = ('Health check returned %i results for cn: %s' %
                   (results_count, common_name))
        if verbose:
            print(failure)
        logging.error(failure)
        return 1

    success = 'Successful search performed for CN: %s' % common_name
    if verbose:
        print(success)
    logging.info(success)

    return 0


def setup_information():
    """
    Print step-by-step setup instructions for ad2openldap.

    :returns: 0 always
    """
    print("""
        Setting up ad2openldap requires a few simple steps. Before
        getting started however, note that ad2openldap will 
        permanently remove your OpenLDAP real-time configuration
        (RTC / cn=config) and replace it with a configuration 
        that is part of the ad2openldap package. Also, be aware 
        that doing a 'fullsync' will completely destroy anything 
        under the directory used to store your directory information
        trees, typically /var/lib/ldap/.

        Steps to follow:

        1) Read the man page:
            
           man ad2openldap        

        2) Configure ad2openldap by editing:

           /etc/ad2openldap/ad2openldap.conf
           
           It is self-documented and most of the defaults will work.

        3) Destroy the default real-time configuration, create a 
           new one (must run this as root):

           ad2openldap rtcclean --dont-blame-ad2openldap -v

        4) Remove any previous directory information trees used 
           by OpenLDAP, create an empty DIT for syncing from 
           Active Directory and perform a full sync (must be run
           as root):

           ad2openldap fullsync --dont-blame-ad2openldap -v

        5) Perform one or more health checks to see if objects
           that should have been imported, were imported.  Use
           common names which are user and group names:

           /usr/sbin/ad2openldap healthcheck -N bhodges -v

        6) LDIF files are now owned by root, prior to running ad2openldap
           under the openldap account, change ownership:

           chown openldap /tmp/ad_export*

        7) Setup cron to run ad2openldap periodically as the openldap
           user, performing delta syncs where only changes are 
           brought over from AD.  Example:

           # Run ad2openldap deltasync once per hour
           SHELL=/bin/bash
           MAILTO=jsmith@example.com
           25 * * * *  openldap  /usr/sbin/ad2openldap deltasync \\
             --dont-blame-ad2openldap -v \\
             >>/var/log/ad2openldap/ad2openldap.log 2>&1; \\
             /usr/sbin/ad2openldap healthcheck -N bhodges

           It is recommended to have notify_addresses set in the 
           configuration file.  /var/log/ad2openldap should be 
           writable by the openldap user.  Performing a health 
           check after each sync is also advised.



        If things go awry at any point, look at the verbose output 
        and consider trying the --debug flag.

        Try running ad2openldap manually similar to this:

            # su -s /bin/bash -c "/usr/sbin/ad2openldap deltasync --dont-blame-ad2openldap -v" openldap

    """)

    return 0


def owner_and_perms_check(config_file, config):
    """
    Sanity-check ownership and permissions of the configuration file.

    The file must be root-owned, group-owned by the configured OpenLDAP
    group, and mode 0640.  These checks warn admins of a few common
    security considerations; they are not extensive nor to be relied on.
    """
    required_group = config['openldap_group']
    owner_check(config_file, 'root', required_group)
    perms_check(config_file, '0640')


def owner_check(path, user_name, group_name):
    """
    Assert that *path* is owned by user_name and by group_name.

    :raises AssertionError: when either the user or group owner differs
    """
    # Stat once; the original looked the file up twice.
    st = os.stat(path)
    owner = getpwuid(st.st_uid).pw_name
    group = getgrgid(st.st_gid).gr_name
    assert owner == user_name, '%s is not owned by user %s' % (path, user_name)
    assert group == group_name, '%s is not owned by group %s' % (path, group_name)


def perms_check(path, perms):
    """
    Assert that *path* carries the expected octal permission string.

    :param perms: four-character octal string, e.g. '0640'
    :raises AssertionError: when the mode differs
    """
    # oct() already returns a str; the last four characters are the
    # permission bits (e.g. '0o100640' -> '0640').
    actual = oct(stat(path).st_mode)[-4:]
    assert actual == perms, '%s is mode %s, %s is required' % (path, actual, perms)


def run_command(command, verbose=False, fatal=True, echo=False):
    """
    Run an external command and wait for it to finish.

    :param command: full command line; split on whitespace (no shell)
    :param verbose: log successful executions at INFO level
    :param fatal: raise OSError on a non-zero exit status
    :param echo: let the child's stdout/stderr pass through to ours
    :raises OSError: when the command fails and *fatal* is true
    """
    # subprocess.DEVNULL instead of open('/dev/null', 'w'): the original
    # leaked one file descriptor per invocation.
    child = subprocess.Popen(command.split(),
        stdout=None if echo else subprocess.DEVNULL,
        stderr=subprocess.STDOUT)
    return_code = child.wait()

    # Filter out password (-w xxxxx) that is common with ldap-utils commands,
    # prior to logging or raising an exception.
    command = re.sub(r'-w\s*.*?(\s|$)', '-w xxxxx ', command)
    if return_code != 0:
        logging.info('A problem (%d) has occurred calling %s' %
            (return_code,command))
        if fatal:
            raise OSError('A problem (%d) has occurred calling %s' %
                (return_code,command))
    elif verbose:
        logging.info('Successfully executed command: %s' % command)


def generate_ssha(password):
    """
    Generate a salted SHA-1 (SSHA) digest of *password*.

    A random four-character alphanumeric salt is appended to the
    password before hashing; the result is base64(digest + salt),
    returned as bytes (encode keeps python3 happy).
    """
    alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits
    salt = ''.join(random.choice(alphabet) for _ in range(4))

    digest = hashlib.sha1((password + salt).encode('utf-8')).digest()
    return base64.b64encode(digest + salt.encode('utf-8'))


def ldap_tree_dir_clean(ldap_tree_dir):
    """
    Recursively remove ldap_tree_dir if it exists, then recreate it as
    an empty directory.
    """
    if os.path.isdir(ldap_tree_dir):
        shutil.rmtree(ldap_tree_dir)
    os.mkdir(ldap_tree_dir)


def ldap_tree_dir_fix_perms(ldap_tree_dir, ldap_tree_dir_perms, openldap_user, openldap_group):
    """
    Fix ownership and permissions for ldap_tree_dir and the files below it.

    :param ldap_tree_dir_perms: octal permission string, e.g. '0755'
    :param openldap_user: user name to chown files to
    :param openldap_group: group name to chown files to
    """
    # Resolve uid/gid once instead of per file.
    uid = getpwnam(openldap_user).pw_uid
    gid = getgrnam(openldap_group).gr_gid
    os.chmod(ldap_tree_dir, int(str(ldap_tree_dir_perms), 8))
    os.chown(ldap_tree_dir, uid, gid)
    for dirpath, dirnames, filenames in os.walk(ldap_tree_dir):
        for file_name in filenames:
            # Join against dirpath (not the tree root) so files in
            # subdirectories are chowned at their actual location; the
            # original joined against ldap_tree_dir, producing wrong
            # paths for anything nested.
            os.chown(os.path.join(dirpath, file_name), uid, gid)


def ldap_rtc_parent_dir_clean(ldap_rtc_parent_dir):
    """
    Recursively remove ldap_rtc_parent_dir if it exists, then recreate
    it as an empty directory.
    """
    if os.path.isdir(ldap_rtc_parent_dir):
        shutil.rmtree(ldap_rtc_parent_dir)
    os.mkdir(ldap_rtc_parent_dir)

            
def ldap_rtc_fix_perms(ldap_rtc_parent_dir, ldap_rtc_dir, ldap_rtc_dir_perms, openldap_user, openldap_group):
    """
    Fix ownership and permissions for the RTC tree under
    ldap_rtc_parent_dir.

    :param ldap_rtc_dir: accepted for interface compatibility; not used
        by the current implementation
    :param ldap_rtc_dir_perms: octal permission string, e.g. '0755'
    """
    # Resolve uid/gid once instead of per file.
    uid = getpwnam(openldap_user).pw_uid
    gid = getgrnam(openldap_group).gr_gid
    os.chmod(ldap_rtc_parent_dir, int(str(ldap_rtc_dir_perms), 8))
    os.chown(ldap_rtc_parent_dir, uid, gid)
    for dirpath, dirnames, filenames in os.walk(ldap_rtc_parent_dir):
        # Chown each directory exactly once -- the original chowned
        # dirpath once per contained file, and never chowned
        # directories that held no files at all.
        os.chown(dirpath, uid, gid)
        for file_name in filenames:
            os.chown(os.path.join(dirpath, file_name), uid, gid)


def ad_export_objects(config):
    """
    Handle flow of AD object exporting.

    Over a single ldap3 connection to Active Directory, retrieves the
    NIS info string, user records and group records, reconciles nested
    group membership, and writes everything to ad_export_file in LDIF
    form for later slapadd/diff steps.
    """
    ad_url = config['ad_url']
    ad_base_dn = config['ad_base_dn']
    base_dn = config['base_dn']
    ad_excluded_group = config['ad_excluded_group']
    default_shell = config['default_shell']
    ad_account = config['ad_account']
    ad_account_password = config['ad_account_password']
    misc_attributes = config['misc_attributes']
    default_gid = config['default_gid']
    tmp_dir = config['tmp_dir']
    ad_export_file = os.path.join(tmp_dir, config['ad_export_file'])

    # Filters to use for AD queries, should move to config file.
    nis_info_filter = config['nis_info_filter']
    user_filter = config['user_filter']
    group_filter = config['group_filter']

    # shared connection to AD
    ldap_conn=open_ldap(ad_url,ad_account,ad_account_password)

    nisinfo = retrieve_ldap_nisinfo(ldap_conn,ad_base_dn,base_dn,
        nis_info_filter)

    users, users_by_group, dn_uid_map = retrieve_ldap_userinfo(ldap_conn,
        ad_base_dn, user_filter, misc_attributes)

    group_gids, groups, excluded_groups = retrieve_ldap_groupinfo(ldap_conn,
        ad_base_dn, group_filter, ad_excluded_group)

    # It is possible that a group may contain groups but no users.
    # Make sure to include any groups in 'groups' dictionary that
    # might have been neglected, as no individual users are members
    for new in groups:
        if new not in users_by_group:
            users_by_group[new] = []


    # Resolve group-in-group membership down to individual users.
    flatten_groups(groups, users_by_group, users, dn_uid_map)

    # Write out users, groups and nisinfo to ldif file
    with open(ad_export_file, 'w') as ad_export_fh:
        print_nisinfo(ad_export_fh, nisinfo)
        print_users(ad_export_fh, users, base_dn, misc_attributes, default_gid,
            default_shell)
        print_groups(ad_export_fh, users_by_group, group_gids, base_dn,
            excluded_groups)


# split LDAP-returned strings into components while editing out '\'s
def skip_split(s, split_char=',', skip_char='\\'):
    """
    Split *s* on split_char while honouring skip_char as an escape.

    An escaped split_char is kept in the current component; the escape
    character itself is dropped from the output.  A trailing empty
    component is discarded.

    :returns: list of components
    """
    split = []
    partial = ""
    escape = False
    # (removed unused 'skip' counter from the original)

    for char in s:
        if char == skip_char:
            escape = True
            continue
        elif char == split_char and escape == False:
            split.append(partial)
            partial = ""
        else:
            partial += char
            # It's possible that non-split characters will get escaped
            escape = False

    # append residual partial if it exists
    if partial != "":
        split.append(partial)

    return split


# utility function to print dictionary attribute if possible
def print_attr(ad_export_fh, user, attr):
    """
    Write "attr: value" for *attr* when present in the *user* dict;
    silently do nothing when the attribute is absent.
    """
    try:
        ad_export_fh.write(attr + ": " + user[attr] + '\n')
    except KeyError:
        pass


# prints out the nisinfo string that was previously built
def print_nisinfo(ad_export_fh, nisinfo):
    """Write the previously built NIS info string to the export file verbatim."""
    ad_export_fh.write(nisinfo)


# iterates through users dictionary uses base as base dn
# if no gidNumber is set in AD, then 'nobody' value is set for export
def print_users(ad_export_fh, users, base, misc_attrs, nobody, def_shell=""):
    """
    Write each user as a posixAccount LDIF entry under ou=people,<base>.

    Users lacking 'uid' or 'uidNumber' are logged and skipped.  When a
    user has no gidNumber the *nobody* gid is substituted; when there is
    no unixHomeDirectory, /home/<uid> is used.  *def_shell*, if set,
    supplies loginShell for users that have none.
    """
    write = ad_export_fh.write
    for cn, user in users.items():
        if 'uid' not in user:
            logging.error('no uid in %s' % user)
            continue
        if 'uidNumber' not in user:
            logging.error('no uidNumber in %s' % user)
            continue

        uid = user['uid']
        write("dn: uid=" + uid + ",ou=people," + base + '\n')
        write("cn: " + uid + '\n')
        write("uid: " + uid + '\n')
        write("objectclass: account" + '\n')
        write("objectclass: posixAccount" + '\n')
        write("uidnumber: " + str(user['uidNumber']) + '\n')

        gid = user['gidNumber'] if 'gidNumber' in user else nobody
        write("gidnumber: " + str(gid) + '\n')

        if 'unixHomeDirectory' in user:
            write("homedirectory: " + user['unixHomeDirectory'] + '\n')
        else:
            write("homedirectory: /home/" + uid + '\n')

        for attr in misc_attrs:
            print_attr(ad_export_fh, user, attr)

        if def_shell and 'loginShell' not in user:
            write("loginShell: " + def_shell + '\n')

        write('\n')


# convert member list to ldif syntax and print
# sort members to compensate for any server specific order inconsistency
def print_members(ad_export_fh, mlist, base):
    """
    Emit groupOfNames 'member' DNs followed by posixGroup 'memberUid'
    lines, sorted so output is stable regardless of server ordering.
    """
    ordered = sorted(mlist)
    for uid in ordered:
        ad_export_fh.write("member: uid=%s,ou=people,%s\n" % (uid, base))
    for uid in ordered:
        ad_export_fh.write("memberUid: %s\n" % uid)


# output groups dictionary in LDIF format
# skip groups ending in '-LS'
def print_groups(ad_export_fh, groups, group_gids, base, xgroups=None, debug=0):
    """
    Write groups as posixGroup/groupOfNames LDIF entries under
    ou=group,<base>.

    Skips groups with no members (groupOfNames requires at least one),
    groups without a known gid, groups whose names end in '-LS', and any
    group listed in *xgroups*.

    :param xgroups: optional list of excluded group names (defaults to
        none; a None sentinel replaces the original mutable-default [])
    """
    if xgroups is None:
        xgroups = []
    if debug:
        logging.debug("groups "+str(len(groups)))
        logging.debug("gids " + str(len(group_gids)))
    gintersect=0

    for group in groups.keys():
        # groupOfNames requires at least one member.  If we've gotten
        # this far and still don't have any members, skip this group.
        if not groups[group]:
            continue
        if group in group_gids and group[-3:]!="-LS" and group not in xgroups:
            ad_export_fh.write("dn: cn="+group+",ou=group,"+base+'\n')
            ad_export_fh.write("cn: "+group+'\n')
            ad_export_fh.write("objectclass: groupOfNames"+'\n')
            ad_export_fh.write("objectclass: posixGroup"+'\n')
            ad_export_fh.write("gidnumber: "+str(group_gids[group])+'\n')

            print_members(ad_export_fh,groups[group],base)

            ad_export_fh.write('\n')

            gintersect+=1

    if debug:
        logging.debug("gintersect " + str(gintersect))


# open connection to LDAP server
def open_ldap(url, account='', password=''):
    """
    Open and bind an ldap3 connection to *url*.

    Bind failures are logged but the (unbound) connection is still
    returned, matching the original best-effort behavior.
    """
    ldap_server = ldap3.Server(url, get_info=ldap3.ALL)

    ldap_conn = ldap3.Connection(ldap_server, user=account, password=password)

    if not ldap_conn.bind():
        # ldap3 stores the bind outcome as a dict; '%d' (as originally
        # written) raised TypeError on the error path -- use '%s'.
        logging.error("open_ldap: error in bind %s" % ldap_conn.result)

    return ldap_conn


# generator encapsulating paged ldap retrieval
def generate_ldap(ldap_conn, base, search_flt):
    """
    Yield the 'attributes' dict of each entry returned by an ldap3
    paged subtree search, skipping entries that carry no attributes.
    """
    pages = ldap_conn.extend.standard.paged_search(
        search_base=base,
        search_scope=ldap3.SUBTREE,
        attributes=ldap3.ALL_ATTRIBUTES,
        search_filter=search_flt)

    for entry in pages:
        if 'attributes' in entry:
            yield entry['attributes']

# if field is present in crud, add to uid
def add_user_field(uid, field, crud):
    """Copy *field* from crud into uid when present; otherwise do nothing."""
    try:
        uid[field] = crud[field]
    except KeyError:
        pass

def generate_members(member_list):
    """For every DN in member_list, yield the value of its first 'cn='
    component (case-insensitive match); DNs without one are skipped."""
    for dn in member_list:
        cn_components = (part for part in skip_split(dn)
                         if part[:3].lower() == "cn=")
        first_cn = next(cn_components, None)
        if first_cn is not None:
            yield first_cn.split('=')[1]

# returns dictionary of users, dictionary of groups by users
def retrieve_ldap_userinfo(user_conn,base,search_flt,misc_attrs):
    """Walk AD user entries matching search_flt and build:

    users      -- dict keyed by uid; each value a dict of selected attrs
    groups     -- dict keyed by group cn; each value a list of member uids
    dn_uid_map -- dict mapping each user's distinguishedName to its uid
    """
    users = {}
    groups = {}
    dn_uid_map = {}

    for record in generate_ldap(user_conn, base, search_flt):
        if 'uid' not in record:
            continue

        current_user = record['uid'][0]
        if current_user in users:
            logging.warning('duplicate uid %s!' % current_user)
            continue

        # record group memberships before building the per-user dict
        if 'memberOf' in record:
            for group_cn in generate_members(record['memberOf']):
                groups.setdefault(group_cn, []).append(current_user)

        uid = {'uid': current_user}

        # prefer the POSIX uidNumber, falling back to employeeID
        if 'uidNumber' in record:
            uid['uidNumber'] = record['uidNumber']
        elif 'employeeID' in record:
            uid['uidNumber'] = record['employeeID']

        for attr in ['gidNumber', 'unixHomeDirectory'] + misc_attrs:
            add_user_field(uid, attr, record)

        if 'distinguishedName' in record:
            dn_uid_map[record['distinguishedName']] = current_user

        users[current_user] = uid

    return users, groups, dn_uid_map

# returns dictionary of groups with members and gids 
# returns dictionary of all groups with members
def retrieve_ldap_groupinfo(ldap_conn,base,search_flt,ad_excluded_group):
    """Walk AD group entries matching search_flt and build:

    ggroups -- gidNumber keyed by group name (POSIX-enabled groups only)
    dgroups -- [distinguishedName, member-list] keyed by group name
    xgroups -- member CNs of the exclusion group ad_excluded_group
    """
    ggroups = {}
    dgroups = {}
    xgroups = []

    for record in generate_ldap(ldap_conn, base, search_flt):
        if 'name' not in record or 'member' not in record:
            continue

        name = record['name']

        if 'gidNumber' in record:
            ggroups[name] = record['gidNumber']
            dgroups.setdefault(name, [])

        if 'distinguishedName' in record:
            dgroups[name] = [record['distinguishedName'], record['member']]

        # hardcode retrieval of this particular hack
        if name == ad_excluded_group:
            xgroups = list(generate_members(record['member']))

    return ggroups, dgroups, xgroups


# confirms if all members of dn_list are substrings of group
def match_dn(dn_list,group):
    """Return True iff every component of dn_list occurs in group."""
    return all(component in group for component in dn_list)

# add user to parent_group if user is not already present
def add_user(cn,ugroups,parent_group,users,dn,dn_uid_map):
    """Append the uid resolved from dn to ugroups[parent_group].

    DNs that cannot be resolved through dn_uid_map/users are silently
    skipped, as are uids already in the group.  (cn is unused; kept for
    interface compatibility with callers.)
    """
    try:
        uid = users[dn_uid_map[dn]]['uid']
    except KeyError:
        return  # unknown dn, unknown user, or user lacks a 'uid' field
    member_list = ugroups[parent_group]
    if uid not in member_list:
        member_list.append(uid)


# iterate through members of group adding users to parent group
# if any members are themselves groups, recursively call on group
def flatten_group(group,groups,ugroups,parent_group,users,dn_uid_map):
    for m in group[1]:
        # split dn into components removing '\ '
        m_dn=skip_split(m)

        # extract cn from dn as group key
        cn=m_dn[0].split('=')[1]

        # if there's a group with this cn
        if cn in groups:
            current_group=groups[cn]

            # compare member dn components with group dn components
            # if child group is root parent group, abort due to infinite loop
            if match_dn(m_dn,current_group[0]) and cn not in parent_group:
                flatten_group(current_group,groups,ugroups,parent_group+[cn],users,dn_uid_map)
        else:
            #bhodges
            # ['CN=Gow, Edward L', 'OU=Users', 'OU=Accounts', 'OU=CRD', 'DC=fhcrc', 'DC=org']
            # problem seems to be that if cn is not the username, it won't match 
            # in users keys.
            # cn is not necessarily distinct, so this may not be good anyway,
            # should use login id.  users really appears to use uid, so cn 
            # will therefor not match much of the time.
            add_user(cn,ugroups,parent_group[0],users,m,dn_uid_map)

# for each group with users
# check to see if any members are groups
# if so, get their members and add to parent
def flatten_groups(groups,ugroups,users,dn_uid_map):
    """For every group that has direct users, expand any member groups
    so nested memberships are flattened into the parent group."""
    for gname in ugroups:
        if gname not in groups:
            logging.warning('%s in users but not in groups' % gname)
            continue

        for member in groups[gname][1]:
            # split dn into components removing '\ '
            components = skip_split(member)

            # extract cn from dn as group key
            cn = components[0].split('=')[1]

            if cn not in groups:
                continue

            child = groups[cn]
            # only descend when the member dn really matches the group dn
            if match_dn(components, child[0]):
                flatten_group(child, groups, ugroups, [gname],
                              users, dn_uid_map)


def print_ldap_list(crud,attr):
    """Return sorted 'attr: value' LDIF lines (newline-terminated) for
    every value of attr in crud, or '' when attr is absent."""
    if attr not in crud:
        return ''
    return ''.join('%s: %s\n' % (attr, value) for value in sorted(crud[attr]))


def retrieve_ldap_nisinfo(ldap_conn,base,dst_base,search_flt):
    """Render NIS netgroup and autofs map entries found under base as an
    LDIF string targeted at dst_base, and return that string.

    Plain autofs entries are attached to the most recently seen nisMap
    via its nisMapName suffix; a second auto.master map is logged.
    """
    last_map_suffix = ""
    auto_master_count = 0
    parts = []  # assembled with one join at the end

    for record in generate_ldap(ldap_conn, base, search_flt):
        if 'nisNetgroup' in record['objectClass']:
            if 'cn' in record and 'nisNetgroupTriple' in record:
                parts.append("dn: cn=" + record['cn'] + ",ou=netgroup," + dst_base + '\n')
                parts.append(print_ldap_list(record, 'objectClass'))
                parts.append("cn: " + record['cn'] + '\n')
                parts.append(print_ldap_list(record, 'nisNetgroupTriple'))
        else:
            cn = record['cn']
            if 'nisMap' in record['objectClass']:
                # a new map: subsequent plain entries hang off this map name
                parts.append("dn: nisMapName=" + cn + ",ou=autofs," + dst_base + '\n')
                last_map_suffix = ",nisMapName=" + cn
                if cn == "auto.master":
                    auto_master_count += 1
                    if auto_master_count > 1:
                        logging.warning('Error: extra auto.master detected!')
            else:
                parts.append("dn: cn=" + cn + last_map_suffix + ",ou=autofs," + dst_base + '\n')

            for attr in ('objectClass', 'nisMapName', 'nisMapEntry'):
                parts.append(print_ldap_list(record, attr))

        # blank line terminates each LDIF entry
        parts.append("\n")

    return ''.join(parts)

def ldif_diff(config):
    """
    Compare current AD download with previous version and 
    create and ldif file that will perform actions to 
    bring the two version into sync.

    Raises ValueError when the two exports do not contain the same
    number of OUs (a sanity check against partial downloads).
    """
    tmp_dir = config['tmp_dir']

    ad_export_file = os.path.join(tmp_dir, config['ad_export_file'])
    ad_export_previous = os.path.join(tmp_dir, config['ad_export_previous'])
    delta_file = os.path.join(tmp_dir, config['delta_file'])

    old_ldif = parse_ldif(ad_export_previous)
    new_ldif = parse_ldif(ad_export_file)

    # Make sure number of OUs is equal between the comparison files.
    # If not, raise an exception.
    if len(old_ldif) != len(new_ldif):
        raise ValueError('Number of OUs is different.  old: %s, new: %s' % (old_ldif.keys(), new_ldif.keys() ))

    with open(delta_file, 'w') as diff_file:
        # pass 1: anything only in the previous export gets deleted
        for ou, old_entries in old_ldif.items():
            logging.info("ad_export_previous -- ou: %s, count: %i" % (ou, len(old_ldif[ou])))
            for dn, (digest, lines) in old_entries.items():
                if dn not in new_ldif[ou]:
                    write_entry(lines, "delete", diff_file)

        # pass 2: new entries are added; changed entries are re-created
        for ou, new_entries in new_ldif.items():
            diffs = 0
            logging.info("ad_export_file -- ou: %s, count: %i" % (ou, len(new_ldif[ou])))
            for dn, (digest, lines) in new_entries.items():
                if dn not in old_ldif[ou]:
                    write_entry(lines, "add", diff_file)
                elif digest != old_ldif[ou][dn][0]:
                    # easier to do modify as delete/add
                    write_entry(lines, "delete", diff_file)
                    write_entry(lines, "add", diff_file)
                    diffs += 1

            if diffs != 0:
                logging.info("Diffs %s" % diffs)


def compute_hash(entry):
    """Return the md5 digest of entry's lines, order-independent
    (lines are sorted before hashing; used only for change detection,
    not for security)."""
    digest = hashlib.md5()
    for line in sorted(entry):
        digest.update(line.encode('utf-8'))
    return digest.digest()


def parse_ldif(ldif_filename):
    """
    Parse an LDIF file into a nested dictionary:

        {ou_name: {dn_line: [md5_of_entry, [entry_lines...]], ...}, ...}

    Entries are blank-line separated; each is keyed by its 'dn:' line and
    bucketed under the first 'ou=' component of that dn.  Returns an
    empty dict (after logging an error) when the file can't be read.
    """
    num_entries=0
    in_entry=0
    ldict={}

    # raw string avoids the invalid '\S' escape warning on modern Python;
    # [^,\s]+ stops the capture at the first comma so only the OU name is
    # taken (the old greedy \S+ also swallowed trailing ',dc=...' text --
    # harmless only because both files in a diff are parsed identically)
    p=re.compile(r'ou=([^,\s]+)[,\n]')

    try:
        with open(ldif_filename,'r') as f:
            for line in f:
                line_s=line.strip()
                if in_entry==0:
                    # a non-blank line opens a new entry
                    if len(line_s)>0:
                        in_entry=1
                        num_entries+=1

                        entry=[line_s]

                        if line_s[0:3]=='dn:':
                            dn=line_s.split(' ',1)[1]
                            # NOTE(review): assumes every dn contains an
                            # 'ou=' component -- IndexError otherwise
                            ou=p.findall(dn)[0]
                            if ou not in ldict:
                                ldict[ou]={}

                else:
                    if len(line_s)==0:
                        # blank line closes the current entry
                        in_entry=0

                        if dn not in ldict[ou]:
                            ldict[ou][dn]=[compute_hash(entry),entry]
                        else:
                            logging.error("Duplicate %s in OU %s" % (dn, ou))

                    else:
                        entry.append(line_s)

        # flush a final entry not terminated by a blank line
        if in_entry:
            if dn not in ldict[ou]:
                ldict[ou][dn]=[compute_hash(entry),entry]
            else:
                logging.error("Duplicate DN %s in OU %s" % (dn, ou))

    except IOError:
        logging.error("Failed to open '%s'" % ldif_filename)

    logging.info("%s has %i entries" % (ldif_filename, num_entries))
    return ldict


def write_entry(entry,changetype,stream):
    """
    Write one ldif entry to stream: the dn line, a 'changetype:' line,
    then (unless deleting) the remaining attribute lines, followed by a
    blank terminator line.
    """
    if entry:
        print(entry[0], file=stream)
        print("changetype:", changetype, file=stream)
        # a delete needs only the dn; other changetypes carry the body
        if changetype != "delete":
            for attribute_line in entry[1:]:
                print(attribute_line, file=stream)

    print(file=stream)
    #print >>stream

if __name__ == '__main__':
    # standard script entry point; propagate main()'s return code to the
    # shell (old body used a non-PEP8 3-space indent)
    sys.exit(main())
