#!python

# Created on Tue Mar 14 13:44:34 2023
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, traceback, raichu, os

# Version string of the imported raichu package; reported by the "-v/--version" option.
currentVersion = raichu.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the program is invoked without any argument, ``-h`` is substituted
    so that the help message is printed and the process exits.

    Returns
    -------
    args : argparse.Namespace
        The parsed command-line options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` if that list was empty).

    Raises
    ------
    SystemExit
        Raised by argparse for ``-h``/``-v``, for malformed arguments, and
        when the mandatory ``--cool-uri`` option is missing.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''A cross-platform method for chromatin
                                     contact normalization.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # The URI is mandatory: nothing can be computed without an input cooler,
    # so let argparse report its absence up front with a usage message.
    parser.add_argument('--cool-uri', required=True, help='''Cool URI.''')
    parser.add_argument('-C', '--chroms', nargs = '*', default = ['#', 'X'],
                        help = 'List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')
    parser.add_argument('--window-size', default=200, type=int,
                        help='''The window size in the unit of pixels.''')
    parser.add_argument('--ignore-diags', default=0, type=int,
                        help='''Number of diagonals of the contact matrix to ignore.
                        Examples: 0 ignores nothing, 1 ignores the main diagonal, and 2
                        ignores diagonals (-1, 0, 1), etc.''')
    parser.add_argument('--min-nnz', default=10, type=int,
                        help='''Ignore bins with less than this number of non-zero data points
                        on corresponding intra-chromosomal matrices.''')
    parser.add_argument('--maxiter', default=100, type=int,
                        help='''The maximum number of global search iterations for dual annealing.''')
    parser.add_argument('--lower-bound', default=0.001, type=float,
                        help='''Lower bound value of the search space.''')
    parser.add_argument('--upper-bound', default=1000, type=float,
                        help='''Upper bound value of the search space.''')
    parser.add_argument('-p', '--nproc', default=4, type=int,
                        help='Number of processes to be allocated.')
    parser.add_argument('-t', '--threads', default=2, type=int,
                        help='Number of threads.')
    parser.add_argument('-n', '--name', default='obj_weight',
                        help='''Name of the column to write to.''')
    parser.add_argument('-f', '--force', action = 'store_true',
                        help = '''When specified, overwrite the target bias vector.''')
    parser.add_argument('--logFile', default = '3Dnorm.log', help = '''Logging file name.''')

    ## Parse the command-line arguments
    # A bare invocation falls back to "-h" so the user sees the help text.
    commands = sys.argv[1:] or ['-h']
    args = parser.parse_args(commands)

    return args, commands


def _select_chromosomes(chromnames, labels):
    """Return the entries of *chromnames* that are selected by *labels*.

    A chromosome is kept when its label (the name with a literal leading
    ``chr`` removed) is listed in *labels*, or when the label is purely
    numeric and ``'#'`` is listed. An empty *labels* selects everything.
    """
    if not labels:
        return list(chromnames)

    selected = []
    for name in chromnames:
        # Remove only an exact "chr" prefix. str.lstrip('chr') would strip any
        # leading run of the characters c/h/r, turning e.g. "c1" into "1".
        label = name[3:] if name.startswith('chr') else name
        if (label.isdigit() and '#' in labels) or (label in labels):
            selected.append(name)

    return selected


def run():
    """Command-line entry point.

    Parses the arguments, configures the root logger (console and log file),
    then computes a bias vector for the selected chromosomes and writes it
    to the cooler under the requested column name. If the column already
    exists and ``--force`` was not given, nothing is recomputed.

    Any exception raised during the computation has its traceback appended
    to the log file and the process exits with status 1.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip the heavy imports and the logging setup for help/version requests
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(10)
    console = logging.StreamHandler()
    filehandler = logging.FileHandler(args.logFile)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    ## Unified Formatter
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Input Cool URI = {0}'.format(args.cool_uri),
               '# Included Chromosomes = {0}'.format(args.chroms),
               '# Window Size = {0}'.format(args.window_size),
               '# Number of Diagonals to Ignore = {0}'.format(args.ignore_diags),
               '# Minimum Non-zero Pixels per Bin = {0}'.format(args.min_nnz),
               '# Maximum Iter = {0}'.format(args.maxiter),
               '# Lower Bound = {0}'.format(args.lower_bound),
               '# Upper Bound = {0}'.format(args.upper_bound),
               '# Column Name = {0}'.format(args.name),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Number of Threads = {0}'.format(args.threads),
               '# Log File Name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Heavy third-party imports are deferred until we know real work is needed
    import cooler
    import numpy as np
    from raichu.util import pipeline, calculate_expected, write_weights
    from joblib import Parallel, delayed

    try:
        # input params
        cool_uri = args.cool_uri
        ws_bin = args.window_size
        maxiter = args.maxiter
        column_name = args.name
        nproc = args.nproc
        ndiag = args.ignore_diags
        min_nnz = args.min_nnz
        n_threads = args.threads
        # The search bounds are handed to the optimizer on a log10 scale
        lb = np.log10(args.lower_bound)
        ub = np.log10(args.upper_bound)

        # prepare for the calculation
        cool_path, group_path = cooler.util.parse_cooler_uri(cool_uri)
        clr = cooler.Cooler(cool_uri)
        ws_bp = ws_bin * clr.binsize # window size in the unit of base pairs
        if (column_name not in clr.bins()) or args.force:
            logger.info('Calculating the average contact frequency for each genomic distance ...')
            chroms = _select_chromosomes(clr.chromnames, args.chroms)
            Ed = calculate_expected(clr, chroms, max_dis=ws_bp, nproc=nproc)
            logger.info('Done')

            # initialize the bias vector
            bias = np.zeros(clr.info['nbins'], dtype=np.float64)
            logger.info('Optimizing the bias vector using dual annealing ...')
            # One job per chromosome, all sharing the same settings
            queue = [
                (clr, c, Ed, ws_bin, ndiag, lb, ub, maxiter, min_nnz, n_threads)
                for c in chroms
            ]
            collect = Parallel(n_jobs=nproc)(delayed(pipeline)(*i) for i in queue)
            for weights, indices in collect:
                # NOTE(review): weights are presumably log10-scaled (the bounds
                # above are), so they are exponentiated and inverted here.
                bias[indices] = 1 / (10 ** weights)
            write_weights(bias, cool_path, group_path, column_name)
            logger.info('Done')
        else:
            logger.info('{0} column already exists, exit'.format(column_name))
    except Exception:
        # Record the failure in the log file; the handle is closed by "with".
        # KeyboardInterrupt/SystemExit are deliberately left to propagate.
        with open(args.logFile, 'a') as log:
            traceback.print_exc(file = log)
        sys.exit(1)

# Run the command-line workflow only when executed as a script, not on import.
if __name__ == '__main__':
    run()