#!/usr/bin/env python

# Created on Tue Mar 14 13:44:34 2023
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, traceback, raichu, os

# Version string of the installed raichu package; reported by the -v/--version option.
currentVersion = raichu.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When no arguments are given on the command line, ``-h`` is substituted so
    that the help message is shown.

    Returns
    -------
    tuple
        ``(args, commands)``: the parsed ``argparse.Namespace`` and the list
        of raw command-line tokens that was parsed.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='''A cross-platform method for chromatin
                                     contact normalization.''')
    add = cli.add_argument

    # Version
    add('-v', '--version', action='version',
        help='Print version number and exit.',
        version='%(prog)s' + ' ' + currentVersion)

    # Input selection
    add('--cool-uri', help='''Cool URI.''')
    add('-C', '--chroms', default=['#', 'X'], nargs='*',
        help=('List of chromosome labels. Only Hi-C data within the specified '
              'chromosomes will be included. Specially, "#" stands for chromosomes '
              'with numerical labels. "--chroms" with zero argument will include '
              'all chromosome data.'))
    add('--regions',
        help='''Path to a BED file where each row specifies a genomic
                        region. If this file is provided, Raichu will restrict its calculation to
                        only the regions defined in the file. By default, the bias vector is calculated
                        for the whole genome.''')

    # Calculation parameters
    add('--window-size', type=int, default=200,
        help='''Size of the sliding window used during calculation, specified in bins.''')
    add('--max-distance', type=int, default=200,
        help='''Maximum genomic distance in bins. Only interactions with a distance less
                        than or equal to this value will be considered.''')
    add('--ignore-diags', type=int, default=0,
        help='''Number of diagonals of the contact matrix to ignore.
                        Examples: 0 ignores nothing, 1 ignores the main diagonal, and 2
                        ignores diagonals (-1, 0, 1), etc.''')
    add('--min-nnz', type=int, default=10,
        help='''Ignore bins with less than this number of non-zero data points
                        on corresponding intra-chromosomal matrices.''')
    add('--maxiter', type=int, default=100,
        help='''The maximum number of global search iterations for dual annealing.''')
    add('--lower-bound', type=float, default=0.001,
        help='''Lower bound value of the search space.''')
    add('--upper-bound', type=float, default=1000,
        help='''Upper bound value of the search space.''')

    # Resources and output
    add('-p', '--nproc', type=int, default=4,
        help='Number of processes to be allocated.')
    add('-t', '--threads', type=int, default=2,
        help='Number of threads.')
    add('-n', '--name', default='obj_weight',
        help='''Name of the column to write to.''')
    add('-f', '--force', action='store_true',
        help='''When specified, overwrite the target bias vector.''')
    add('--logFile', default='3Dnorm.log', help='''Logging file name.''')

    # Fall back to the help message when the program is called without arguments
    commands = sys.argv[1:] or ['-h']
    return cli.parse_args(commands), commands


def _configure_logging(log_file):
    """Attach a console handler and a file handler (both at INFO level) to the root logger."""
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')
    for handler in (logging.StreamHandler(), logging.FileHandler(log_file)):
        handler.setLevel('INFO')
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def _select_chroms(chromnames, labels):
    """Return the names in *chromnames* that match the requested *labels*.

    A name matches when its label (the name without a leading ``chr`` prefix)
    is listed in *labels*, or when the label is purely numeric and ``'#'`` is
    listed. An empty *labels* selects every chromosome.
    """
    if not labels:
        return list(chromnames)

    selected = []
    for name in chromnames:
        # Remove only a literal "chr" prefix. str.lstrip('chr') strips any
        # leading run of the characters c/h/r and would mangle names such as
        # "chrrandom".
        label = name[3:] if name.startswith('chr') else name
        if (label.isdigit() and '#' in labels) or (label in labels):
            selected.append(name)

    return selected


def run():
    """Command-line entry point: compute a bias vector and write it to the cooler.

    Parses the command line, configures logging to the console and to the log
    file, and, unless the target column already exists and ``--force`` was not
    given, calculates the distance-dependent expected values, runs ``pipeline``
    for every selected chromosome in parallel and stores the resulting bias
    vector under the requested column name.

    Any exception raised during the calculation has its traceback appended to
    the log file, after which the process exits with status 1.
    """
    # Parse Arguments
    args, commands = getargs()
    # argparse exits on its own for help/version requests; this guard simply
    # avoids the logging setup and heavy imports should that ever not happen.
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    ## Root Logger Configuration
    logger = _configure_logging(args.logFile)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Input Cool URI = {0}'.format(args.cool_uri),
               '# Included Chromosomes = {0}'.format(args.chroms),
               '# Region File Path = {0}'.format(args.regions),
               '# Window Size = {0}'.format(args.window_size),
               '# Maximum Genomic Distance = {0}'.format(args.max_distance),
               '# Number of Diagonals to Ignore = {0}'.format(args.ignore_diags),
               '# Minimum Number of Non-zero Data Points = {0}'.format(args.min_nnz),
               '# Maximum Iter = {0}'.format(args.maxiter),
               '# Lower Bound = {0}'.format(args.lower_bound),
               '# Upper Bound = {0}'.format(args.upper_bound),
               '# Column Name = {0}'.format(args.name),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Number of Threads = {0}'.format(args.threads),
               '# Log File Name = {0}'.format(args.logFile)
               ]
    logger.info('\n' + '\n'.join(arglist))

    # Heavy dependencies are imported only once a real run is about to start
    import cooler
    import numpy as np
    from raichu import util
    from raichu.util import pipeline, calculate_expected, write_weights, load_BED
    from joblib import Parallel, delayed

    #util.numba.set_num_threads(args.threads)

    try:
        # input params
        cool_uri = args.cool_uri
        ws_bin = args.window_size
        maxiter = args.maxiter
        column_name = args.name
        nproc = args.nproc
        ndiag = args.ignore_diags
        min_nnz = args.min_nnz
        n_threads = args.threads
        # the search bounds are handed to pipeline in log10 space
        lb = np.log10(args.lower_bound)
        ub = np.log10(args.upper_bound)

        # prepare for the calculation
        cool_path, group_path = cooler.util.parse_cooler_uri(cool_uri)
        clr = cooler.Cooler(cool_uri)
        if args.regions is not None:
            included_bins = load_BED(args.regions, clr.binsize)
        else:
            included_bins = None
        max_dis = args.max_distance * clr.binsize # maximum distance in the unit of base pairs

        if (column_name not in clr.bins()) or args.force:
            logger.info('Calculating the average contact frequency for each genomic distance ...')
            chroms = _select_chroms(clr.chromnames, args.chroms)
            Ed = calculate_expected(clr, chroms, included_bins, max_dis=max_dis, nproc=nproc)
            logger.info('Done')

            # initialize the bias vector
            bias = np.zeros(clr.info['nbins'], dtype=np.float64)
            logger.info('Optimizing the bias vector using dual annealing ...')
            queue = []
            for c in chroms:
                if included_bins is None:
                    chrom_bins = None
                elif c in included_bins:
                    chrom_bins = included_bins[c]
                else:
                    # chromosome has no entry in the region file: skip it
                    continue
                queue.append((clr, c, Ed, ws_bin, chrom_bins, ndiag, lb, ub, maxiter, min_nnz, n_threads))

            collect = Parallel(n_jobs=nproc)(delayed(pipeline)(*job) for job in queue)
            for weights, indices in collect:
                # each job yields a pair (weights, indices); the stored bias is 1 / 10**weights
                bias[indices] = 1 / (10 ** weights)
            write_weights(bias, cool_path, group_path, column_name)
            logger.info('Done')
        else:
            logger.info('{0} column already exists, exit'.format(column_name))
    except Exception:
        # Append the traceback to the log file and close the handle promptly.
        # KeyboardInterrupt/SystemExit are deliberately left to propagate.
        with open(args.logFile, 'a') as log_handle:
            traceback.print_exc(file = log_handle)
        sys.exit(1)

# Start the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    run()