#!/usr/bin/env python

# Created on Sun Nov 15 10:05:27 2015

# Author: XiaoTao Wang
# Organization: HuaZhong Agricultural University

import argparse, sys, logging, logging.handlers, random, copy, os, datetime, traceback
import numpy as np
import calTADs
from calTADs import myshelve

# Version string shown by the "-v/--version" option; taken from the calTADs package
currentVersion = calTADs.__version__

def getargs():
    """
    Build the command-line parser, parse ``sys.argv[1:]`` and validate
    the combination of options.
    
    When no argument is given at all, "-h" is substituted so the help
    text is printed. The help text is printed and the program exits with
    status 1 when a required companion option is missing:
    
    - "--output" is always required
    - "--cols" is required with "--Format TXT"
    - "--prefix" is required with "--Format NPZ"
    - "--saveto" is required with "--immortal"
    
    Returns a tuple ``(args, commands)``: the parsed namespace and the
    list of raw command-line tokens that were parsed.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(usage = '%(prog)s <-O output> [options]',
                                     description = 'TAD identification.',
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    
    parser.add_argument('-v', '--version', action = 'version',
                        version = ' '.join(['%(prog)s', currentVersion]),
                        help = 'Print version number and exit')

    # Output
    parser.add_argument('-O', '--output', help = 'Prefix of the generated TAD file.')
    
    ## Related to Hi-C data
    parser.add_argument('-p', '--path', default = '.',
                        help = 'Path to the folder with Hi-C data. Support both absolute'
                               ' and relative path.')
    parser.add_argument('-F', '--Format', default = 'NPZ', choices = ['TXT', 'NPZ'],
                        help = 'Format of source data file')
    parser.add_argument('-R', '--resolution', default = 5000, type = int,
                        help = 'Resolution of the binned data')
    parser.add_argument('-T', '--template', default = 'chr%s_chr%s.int',
                        help = 'Template for matching file names using regular expression.'
                        ' Needed when "--Format" is set to be "TXT". Note only within-chromosome'
                        ' data are allowed, and don\'t place inter-chromosome data under the '
                        'folder.')
    parser.add_argument('-C', '--chroms', nargs = '*', default = ['#', 'X'],
                        help = 'Which chromosomes to read. Specially, "#" stands'
                        ' for chromosomes with numerical labels. "--chroms" with zero argument'
                        ' will generate an empty list, in which case all chromosome data will'
                        ' be loaded.')
    parser.add_argument('-c', '--cols', nargs = 3, type = int,
                        help = 'Which 3 columns to read, with 0 being the first. For example,'
                        ' "--cols 1 3 4" will extract the 2nd, 4th and 5th columns. Only '
                        'required when "--Format=TXT".')
    parser.add_argument('--immortal', action = 'store_true',
                        help = 'When specified, a Numpy .npz file will be generated under the same '
                        'folder. This operation will greatly speed up data loading process next'
                        ' time.')
    parser.add_argument('-P', '--prefix', help = 'Prefix of input .npz file name, path not'
                        ' included. Required when "--Format=NPZ".')
    parser.add_argument('-S', '--saveto', help = 'Prefix of output .npz file name, path '
                        'not included. Required with "--immortal".')
    
    ## Related to TAD calculation
    parser.add_argument('-w', '--window', type = int, nargs = '+', default = [30, 60, 100, 200, 400],
                        help = '''Window size used in DI (Directionality Index) calculation.
                        It tells how far we need to look at the interaction patterns of a
                        given bin. Unit: RESOLUTION.''')
    parser.add_argument('-g', '--maxgaplen', type = int, default = 5,
                        help = '''Maximum gap size we can tolerate in a domain. Unit: RESOLUTION''')
    parser.add_argument('-m', '--minsize', type = int, default = 20,
                        help = '''Minimum size of a domain. Unit: RESOLUTION''')
    
    ## Related to HMM model training process
    parser.add_argument('--maxSteps', type = int, default = 500,
                        help = '''Maximum number of Baum-Welch iteration steps.''')
    parser.add_argument('--likelihoodcutoff', type = float, default = 0.0001,
                        help = '''The least relative improvement in likelihood with respect to
                        the last Baum-Welch iteration required to continue.''')
    parser.add_argument('--logFile', default = 'calTADs.log', help = '''Logging file name.''')
    
    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)
    
    args.Format = args.Format.upper()

    ## Conflicts
    # The output prefix is joined into every generated file name, so a
    # missing value must be rejected here rather than crash later on.
    if not args.output:
        parser.print_help()
        parser.exit(1, 'You have to specify "--output"!')
    if (args.Format == 'TXT') and (not args.cols):
        parser.print_help()
        parser.exit(1, 'You have to specify "--cols" with "--Format TXT"!')
    if (args.Format == 'NPZ') and (not args.prefix):
        parser.print_help()
        parser.exit(1, 'You have to specify "--prefix" with "--Format NPZ"!')
    if args.immortal and (not args.saveto):
        parser.print_help()
        parser.exit(1, 'You have to specify "--saveto" with "--immortal" flag!')
    
    return args, commands

def _logTraceback(logfile):
    """
    Append the traceback of the exception currently being handled to
    *logfile* and close the file again.
    """
    with open(logfile, 'a') as out:
        traceback.print_exc(file = out)

def run():
    """
    Command-line entry point.
    
    The work is split into four stages whose states (0 -- ready,
    1 -- running, 2 -- completed) are kept in a job database under
    ``~/.SubDomainCaller`` so that several processes started with the same
    arguments can share the work:
    
    - Stage-1: compute the modified DI tracks and write them to
      "<output>.modifiedDI.txt"
    - Stage-2: per chromosome, train the HMM
    - Stage-3: per region, call the hierarchical domains into a piece file
    - Stage-4: merge all piece files into "<output>.hmmdomain.txt"
    
    Whenever a stage fails, its state is reset to 0, the traceback is
    appended to the log file and the process exits with status 1.
    """
    import shutil, time
    
    # Seconds to wait between two polls of a stage owned by another process
    pollInterval = 1
    
    # Parse Arguments
    args, commands = getargs()
    # Improve the performance if you don't want to run it
    if commands[0] not in ['-h', '-v', '--help', '--version']:
        ## Root Logger Configuration
        logger = logging.getLogger()
        logger.setLevel(20)
        
        filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                           maxBytes = 500000,
                                                           backupCount = 5)
        filehandler.setLevel('INFO')
        formatter = logging.Formatter(fmt = '%(name)-14s %(levelname)-7s @ %(asctime)s: %(message)s',
                                      datefmt = '%m/%d/%y %H:%M:%S')
        
        filehandler.setFormatter(formatter)
        logger.addHandler(filehandler)
        
        ## Logging for argument setting
        arglist = ['# ARGUMENT LIST:',
                   '# output file name = %s' % args.output,
                   '# data folder = %s' % args.path,
                   '# Hi-C format = %s' % args.Format,
                   '# chromosomes = %s' % args.chroms
                   ]
        if args.Format == 'TXT':
            arglist.extend(['# Hi-C file template = %s' % args.template,
                            '# which columns = %s' % args.cols])
        if args.Format == 'NPZ':
            arglist.extend(['# input NPZ prefix = %s' % args.prefix])
        
        arglist.extend(['# data resolution = %s' % args.resolution,
                        '# window size = %s' % args.window,
                        '# maximum gap size = %s' % args.maxgaplen,
                        '# minimum domain size = %s' % args.minsize,
                        '# maximum Baum-Welch steps = %s' % args.maxSteps,
                        '# likelihood cutoff = %s' % args.likelihoodcutoff])
        if args.immortal:
            arglist.append('# output NPZ prefix = %s' % args.saveto)
        arglist.append('# Log file name = %s' % args.logFile)
        
        argtxt = '\n'.join(arglist)
        logger.info('\n' + argtxt)
        
        ### Generate the folder for task queue system if not yet
        des = os.path.abspath(os.path.expanduser('~/.SubDomainCaller'))
        if not os.path.exists(des):
            os.mkdir(des)
        # The job key is built from every argument that influences the result
        key = str((args.path, args.Format, args.resolution, args.prefix, args.template, args.cols,
                   sorted(args.chroms), sorted(args.window), args.maxgaplen, args.minsize,
                   args.maxSteps, args.likelihoodcutoff))
        dbname = os.path.join(des, 'TaskQueues')
        register(dbname, key)
        
        # Floor division keeps this an integer number of bins on Python 3 too
        minregion = max(20, 200000 // args.resolution)
        
        dicts = {'path': args.path, 'Format': args.Format, 'resolution': args.resolution,
                 'template': args.template, 'chroms': args.chroms, 'cols': args.cols,
                 'prefix': args.prefix, 'immortal': args.immortal, 'saveto': args.saveto}
        
        from tadlib import analyze
        
        DI_filename = '.'.join([args.output, 'modifiedDI', 'txt'])
        args.window = sorted(args.window)
        fields = ['-'.join(['DI', str(i)]) for i in args.window]
        title = ['chr','start','end'] + fields
        # Builtin float is what the removed "np.float" alias used to refer to
        pooledTP = np.dtype({'names': title,
                             'formats':['S2', np.int32, np.int32] + [float] * len(args.window)})
        Stage_1 = getJobState(dbname, key, 'Stage-1')
        if (Stage_1 == 0) or (not os.path.exists(DI_filename)):
            logger.info('Calculating DI ...')
            changeJobState(dbname, key, 'Stage-1', state = 1)
            # Bare except on purpose: the stage must be released on any
            # failure, including an interrupt, before exiting.
            try:
                workInters = analyze.Inters(**dicts)
                with open(DI_filename, 'w') as writeout:
                    writeout.write('# Resolution: {:d}\n'.format(args.resolution))
                    writeout.write('# ' + '\t'.join(title) + '\n')
                    chroms = workInters.labels
                    for c in chroms:
                        logger.info('  Chromosome %s ...', c)
                        idata = workInters.data[c]
                        Len = idata['bin2'].max() + 1
                        multiDI = np.zeros(Len, dtype = pooledTP)
                        multiDI['chr'] = c
                        multiDI['start'] = np.arange(Len) * args.resolution
                        multiDI['end'] = np.arange(1, Len + 1) * args.resolution
                        for i in args.window:
                            field_name = '-'.join(['DI', str(i)])
                            mDI = modDI(idata, i)
                            multiDI[field_name] = mDI
                        np.savetxt(writeout, multiDI, delimiter = '\t', fmt = ['%s','%d','%d'] + ['%.7e'] * len(args.window))
                        # Registering the chromosome makes it visible to getChroms
                        changeJobState(dbname, key, 'Stage-2', chrom = c, state = 0)
                        changeJobState(dbname, key, 'Stage-3', chrom = c, state = 0)
                DIs = np.loadtxt(DI_filename, dtype = pooledTP)
                logger.info('Done!')
            except:
                changeJobState(dbname, key, 'Stage-1', state = 0)
                _logTraceback(args.logFile)
                sys.exit(1)
                
            changeJobState(dbname, key, 'Stage-1', state = 2)
        else:
            # Another process owns Stage-1: wait for it instead of spinning
            # on the database.
            # NOTE(review): if that process fails it resets the stage to 0
            # and this wait never ends.
            while Stage_1 != 2:
                time.sleep(pollInterval)
                Stage_1 = getJobState(dbname, key, 'Stage-1')
            logger.info('DI file already exists, load it ...')
            DIs = np.loadtxt(DI_filename, dtype = pooledTP)
        
        chroms = getChroms(dbname, key)
        
        logger.info('Start domain calling process ...')
        tempdir = '-'.join([args.output, 'domainpieces'])
        if not os.path.exists(tempdir):
            os.mkdir(tempdir)
        pieces = {}
        for c in chroms:
            logger.info('  Chromosome %s ...', c)
            chromDI = DIs[DIs['chr'] == c]
            logger.info('    Split chromosome by large gaps ...')
            regionDIs, forceStates = splitChrom(chromDI, args.maxgaplen, minregion, fields)
            for region in regionDIs:
                tempfile = '.'.join([args.output, c, str(region[0]*args.resolution),
                                     str(region[1]*args.resolution), 'hmmdomain', 'txt'])
                pieces[(c, region)] = os.path.join(tempdir, tempfile)
            logger.info('    Done!')
            Stage_2 = getJobState(dbname, key, 'Stage-2', chrom = c)
            if Stage_2 == 0:
                logger.info('    Training HMM model ...')
                changeJobState(dbname, key, 'Stage-2', chrom = c, state = 1)
                try:
                    A, B, pi = learning(regionDIs, args.maxSteps, args.likelihoodcutoff)
                    for region in regionDIs:
                        changeJobState(dbname, key, 'Stage-3', chrom = c, state = 0, region = region)
                except:
                    changeJobState(dbname, key, 'Stage-2', chrom = c, state = 0)
                    _logTraceback(args.logFile)
                    sys.exit(1)
                    
                changeJobState(dbname, key, 'Stage-2', chrom = c, state = 2, params = (A, B, pi))
                logger.info('    Done!')
            else:
                while Stage_2 != 2:
                    time.sleep(pollInterval)
                    Stage_2 = getJobState(dbname, key, 'Stage-2', chrom = c)
                logger.info('    HMM parameters have been learned by others, load them in ...')
                A, B, pi = getParams(dbname, key, 'Stage-2', chrom = c)
                logger.info('    Done!')
            
            logger.info('    Calling hierarchical domains using different scale DI tracks ...')
            for region in regionDIs:
                logger.info('      Current region: chr%s, %d - %d', c,
                            region[0]*args.resolution, region[1]*args.resolution)
                Stage_3 = getJobState(dbname, key, 'Stage-3', chrom = c, region = region)
                if Stage_3 == 0:
                    changeJobState(dbname, key, 'Stage-3', chrom = c, state = 1, region = region)
                    try:
                        # Start from the last (largest window) DI track
                        hierarchyCaller(A, B, pi, regionDIs[region], 0, regionDIs[region].shape[0],
                                        forceStates[region], regionDIs[region].shape[1] - 1)
                        domains = genfromStates(forceStates[region], fields, c, region[0], args.resolution, args.minsize)
                        tempfile = pieces[(c, region)]
                        outputDomains(tempfile, domains, args.resolution)
                    except:
                        changeJobState(dbname, key, 'Stage-3', chrom = c, state = 0, region = region)
                        _logTraceback(args.logFile)
                        sys.exit(1)
                        
                    changeJobState(dbname, key, 'Stage-3', chrom = c, state = 2, region = region)
                else:
                    if Stage_3 == 1:
                        logger.info('      Someone is working on this region, skipping ...')
                    elif Stage_3 == 2:
                        logger.info('      Completed region, skipping ...')
                    continue
            
        check = checkForComplete(dbname, key)
        if check:
            domainfile = '.'.join([args.output, 'hmmdomain', 'txt'])
            logger.info('Output domain positions into %s ...', domainfile)
            Stage_4 = getJobState(dbname, key, 'Stage-4')
            if Stage_4 == 0:
                changeJobState(dbname, key, 'Stage-4', state = 1)
                try:
                    domains = []
                    for f in pieces.values():
                        with open(f) as source:
                            domains.extend(line.rstrip().split() for line in source
                                           if not line.startswith('#'))
                    outputDomains(domainfile, domains, args.resolution)
                    # Remove the piece folder without going through a shell,
                    # so prefixes containing spaces or shell characters are safe
                    shutil.rmtree(tempdir, ignore_errors = True)
                    logger.info('Done!')
                except:
                    changeJobState(dbname, key, 'Stage-4', state = 0)
                    _logTraceback(args.logFile)
                    sys.exit(1)
                    
                changeJobState(dbname, key, 'Stage-4', state = 2)
            else:
                if Stage_4 == 1:
                    logger.info('Conflicted job, exit')
                elif Stage_4 == 2:
                    logger.info('Completed job, exit')
        else:
            logger.info('Uncompleted job, exit')

def register(dbname, job):
    """
    Make sure *job* has an entry in the job database *dbname*.
    
    Once the database holds more than 100 jobs, entries whose timestamp
    is more than 10 days old are purged first. A new entry starts with
    Stage-1 and Stage-4 in state 0 and empty Stage-2/Stage-3 tables.
    The database is closed even if an error occurs.
    """
    db = myshelve.open(dbname)
    try:
        # Clean the database
        if len(db) > 100:
            now = datetime.date.today()
            # Iterate over a snapshot of the keys: entries are deleted
            # inside the loop
            for key in list(db):
                if (now - db[key]['timestamp']).days > 10:
                    del db[key]
        
        if job not in db:
            # Job state: 0 -- ready, 1 -- running, 2 -- completed
            db[job] = {'timestamp': datetime.date.today(),
                       'Stage-1': 0,
                       'Stage-2': {},
                       'Stage-3': {},
                       'Stage-4': 0}
    finally:
        db.close()

def getJobState(dbname, top, stage, chrom = None, region = None):
    """
    Read a state from the job database *dbname*.
    
    - neither *chrom* nor *region*: state of the whole *stage* of job *top*
    - *chrom* only: the 'state' entry of that chromosome
    - *chrom* and *region*: state of that region
    
    Raises ValueError when *region* is given without *chrom*. The
    database is closed even if a lookup fails.
    """
    if (chrom is None) and (region is not None):
        raise ValueError('"region" can only be used together with "chrom"')
    
    db = myshelve.open(dbname, 'r')
    try:
        entry = db[top][stage]
        if chrom is None:
            return entry
        if region is None:
            return entry[chrom]['state']
        return entry[chrom][region]
    finally:
        db.close()

def getChroms(dbname, top):
    """
    Return the list of chromosome labels recorded under 'Stage-2' of job
    *top*. The database is closed even if the lookup fails.
    """
    db = myshelve.open(dbname, 'r')
    try:
        # Materialize the keys so callers get a plain list on any Python version
        return list(db[top]['Stage-2'].keys())
    finally:
        db.close()

def changeJobState(dbname, top, stage, chrom = None, region = None, state = 0, params = None):
    """
    Update a state of job *top* in the job database *dbname*.
    
    - neither *chrom* nor *region*: set the state of the whole *stage*
    - *chrom* only: set the 'state' entry of that chromosome (creating
      the chromosome table if needed) and, when *params* is given, store
      it under 'params'
    - *chrom* and *region*: set the state of that region; when every
      entry of the chromosome other than 'state' equals 2, the
      chromosome 'state' is set to 2 as well
    
    The database is closed even if an error occurs.
    """
    db = myshelve.open(dbname)
    try:
        temp = db[top]
        if chrom is None:
            if region is None:
                temp[stage] = state
        elif region is None:
            chromEntry = temp[stage].setdefault(chrom, {})
            chromEntry['state'] = state
            if params is not None:
                chromEntry['params'] = params
        else:
            chromEntry = temp[stage][chrom]
            chromEntry[region] = state
            if all(chromEntry[r] == 2 for r in chromEntry if r != 'state'):
                chromEntry['state'] = 2
        
        # Assign the modified copy back so the change is stored
        db[top] = temp
    finally:
        db.close()

def getParams(dbname, top, stage, chrom):
    """
    Return what is stored under 'params' for chromosome *chrom* in
    *stage* of job *top*. The database is closed even if the lookup fails.
    """
    db = myshelve.open(dbname, 'r')
    try:
        return db[top][stage][chrom]['params']
    finally:
        db.close()

def checkForComplete(dbname, top):
    """
    Tell whether every chromosome recorded under 'Stage-3' of job *top*
    has state 2 (completed). An empty 'Stage-3' table counts as complete.
    The database is closed even if the lookup fails.
    """
    db = myshelve.open(dbname, 'r')
    try:
        stage3 = db[top]['Stage-3']
        return all(stage3[c]['state'] == 2 for c in stage3)
    finally:
        db.close()

def splitChrom(chromDI, maxgaplen, minregion, fields):
    """
    Split a chromosome into gap-free regions.
    
    A bin is valid when all DI tracks named in *fields* are non-zero
    there. Consecutive valid bins more than ``maxgaplen + 1`` apart are
    separated into different regions, and only regions longer than
    *minregion* bins are kept. A chromosome without any such gap yields
    one region spanning the first to the last valid bin.
    
    Returns two dictionaries keyed by ``(start, end)`` bin tuples (end
    exclusive): the DI values of each region, one column per field, and
    an int8 array of the same shape filled with -1.
    """
    DI_values = np.r_[[chromDI[i] for i in fields]]
    DI_values = DI_values.T
    valid_pos = np.where(np.all(DI_values != 0, axis = 1))[0]
    
    chromRegions = {}
    forceStates = {}
    if valid_pos.size == 0:
        return chromRegions, forceStates
    
    def _addRegion(start, end):
        # Keep [start, end) only if it is long enough
        if end - start > minregion:
            chromRegions[(start, end)] = DI_values[start:end]
            forceStates[(start, end)] = np.zeros(chromRegions[(start, end)].shape,
                                                 dtype = np.int8) - 1
    
    gapsizes = valid_pos[1:] - valid_pos[:-1]
    # Index (in valid_pos) of the last valid bin before each large gap
    endsIdx = np.where(gapsizes > (maxgaplen + 1))[0]
    startsIdx = endsIdx + 1
    
    # Regions enclosed by two large gaps
    for i in range(startsIdx.size - 1):
        _addRegion(valid_pos[startsIdx[i]], valid_pos[endsIdx[i + 1]] + 1)
    
    if startsIdx.size > 0:
        # Region after the last gap, then region before the first gap
        _addRegion(valid_pos[startsIdx[-1]], valid_pos[-1] + 1)
        _addRegion(valid_pos[0], valid_pos[endsIdx[0]] + 1)
    else:
        # No large gap at all: the valid span is a single region
        _addRegion(valid_pos[0], valid_pos[-1] + 1)
    
    return chromRegions, forceStates
            
def modDI(idata, window):
    """
    Calculate modified DI, which considers local standard deviation.
    
    *idata* is a structured array with fields 'bin1', 'bin2' and 'IF'.
    Only records with ``0 < bin2 - bin1 <= window`` are used. For every
    bin, records are grouped by 'bin1' (downstream) and by 'bin2'
    (upstream); a side with at most 2 records contributes a mean and a
    variance term of 0. The returned value for a bin is::
    
        (meanDownstream - meanUpstream) / sqrt(varDownstream + varUpstream)
    
    where each variance term is ``(mean(IF^2) - mean(IF)^2) / (n - 1)``.
    Bins whose denominator is 0 get a DI of 0.
    
    Returns a 1-D float array of length ``max(bin2) + 1`` computed on the
    unfiltered input. If no record falls within *window*, the array is
    all zeros.
    """
    chromLen = idata['bin2'].max() + 1
    mask = ((idata['bin2'] - idata['bin1']) <= window) & (idata['bin2'] != idata['bin1'])
    idata = idata[mask]
    
    corrected = np.zeros(chromLen)
    if idata.size == 0:
        # Nothing within the window: max() below would fail on an empty array
        return corrected
    
    Len = idata['bin2'].max() + 1
    IF = idata['IF']
    squareIF = IF ** 2
    
    def _meanAndVar(bins):
        # Per-bin mean and variance term of IF, grouped by *bins*
        total = np.bincount(bins, weights = IF)
        count = np.bincount(bins)
        totalSquare = np.bincount(bins, weights = squareIF)
        enough = count > 2
        mean = np.zeros_like(total)
        mean[enough] = total[enough] / count[enough]
        meanSquare = np.zeros_like(totalSquare)
        meanSquare[enough] = totalSquare[enough] / count[enough]
        var = np.zeros_like(meanSquare)
        var[enough] = (meanSquare[enough] - mean[enough] ** 2) / (count[enough] - 1)
        return mean, var
    
    # Downstream interactions are indexed by their left bin, upstream
    # interactions by their right bin
    meanD, SD_1 = _meanAndVar(idata['bin1'])
    meanU, SD_2 = _meanAndVar(idata['bin2'])
    
    # Bring both sides to the common length *Len*
    cmeanU = np.zeros(Len)
    cmeanD = np.zeros(Len)
    cSD_1 = np.zeros(Len)
    cSD_2 = np.zeros(Len)
    cmeanU[(Len - meanU.size):] = meanU
    cmeanD[:meanD.size] = meanD
    cSD_1[:SD_1.size] = SD_1
    cSD_2[(Len - SD_2.size):] = SD_2
    
    ## Pooled standard deviation
    SD_Pool = np.sqrt(cSD_1 + cSD_2)
    
    # Modified DI, left at 0 where the pooled deviation is 0
    mDI = np.zeros(cmeanU.size)
    usable = SD_Pool != 0
    mDI[usable] = (cmeanD[usable] - cmeanU[usable]) / SD_Pool[usable]
    
    corrected[:Len] = mDI
    
    return corrected
    
    
def learning(regionDIs, maxIter = 500, likelihood = 0.0001):
    """
    Fit a 4-state HMM with Gaussian-mixture emissions on the DI tracks.
    
    Every column of every array in *regionDIs* is used as one training
    sequence; the sequences are shuffled before Baum-Welch training,
    which is run with *maxIter* and *likelihood* as its two arguments.
    
    Returns ``(A, B, pi)`` as numpy arrays: the 4 x 4 transition matrix,
    a 4 x 3 x numdists array holding the three values reported by
    ``getEmission`` for each state and mixture component, and the
    initial state distribution.
    """
    import ghmm
    
    F = ghmm.Float()
    
    # One observation sequence per (region, track) pair
    seqs = []
    for region in regionDIs:
        block = regionDIs[region]
        for track in range(block.shape[1]):
            seqs.append([block[i, track] for i in range(block.shape[0])])
    
    random.shuffle(seqs)
    train_seqs = ghmm.SequenceSet(F, seqs)
    
    # Hidden States: 0 -- start, 1 -- downstream, 2 -- upstream, 3 -- end
    A = [[0., 1., 0., 0.],
         [0., 0.5, 0.5, 0.],
         [0., 0., 0.5, 0.5],
         [1., 0., 0., 0.]]
    pi = [0.2, 0.3, 0.3, 0.2]
    numdists = 3 # Three-distribution Gaussian Mixtures
    W = 1. / numdists
    var = 7.5 / (numdists - 1)
    components = range(numdists)
    # Initial mixture means of the four states
    means = [[i * 7.5 / ( numdists - 1 ) + 2.5 for i in components],
             [i * 7.5 / ( numdists - 1 ) for i in components],
             [-i * 7.5 / ( numdists - 1 ) for i in components],
             [-i * 7.5 / ( numdists - 1 ) - 2.5 for i in components]]
    B = [[means[state], [var] * numdists, [W] * numdists] for state in range(4)]
    
    # Training ...
    model = ghmm.HMMFromMatrices(F, ghmm.GaussianMixtureDistribution(F), A, B, pi)
    model.baumWelch(train_seqs, maxIter, likelihood)
    
    # Copy the trained parameters into numpy arrays
    trainedA = np.zeros((4, 4))
    trainedB = np.zeros((4, 3, numdists))
    trainedPi = np.zeros(4)
    
    for state in range(4):
        for target in range(4):
            trainedA[state, target] = model.getTransition(state, target)
        for comp in range(numdists):
            emission = model.getEmission(state, comp)
            for k in range(3):
                trainedB[state, k, comp] = emission[k]
        
        trainedPi[state] = model.getInitial(state)
    
    return trainedA, trainedB, trainedPi

def hierarchyCaller(A, B, pi, DIs, start, stop, forceStates, level):
    """
    Recursively place domain boundaries on rows [start, stop) of *DIs*.
    
    The DI track in column *level* is decoded with viterbi(), the
    resulting boundaries are refined by minimizing domain_cost() over
    columns 0..level, and the final positions are written into
    *forceStates* in place (0 at a domain start, 3 at a domain end, for
    columns 0..level). Unless *level* is 0, the function then calls
    itself with ``level - 1`` on every domain and on every stretch
    between two non-adjacent domains.
    
    *A*, *B* and *pi* are passed through unchanged to viterbi() and
    domain_cost(). Boundary positions are relative to *start*; -1 marks
    a missing start or end. Returns None.
    """
    
    ini_path = viterbi(A, B, pi, DIs[start:stop, level], forceStates[start:stop, level])[1]
    current_bounds = genbounds(ini_path)
    
    ## Move close together bounds to adjacent
    for i in range(len(current_bounds) - 1):
        if current_bounds[i + 1][0] - current_bounds[i][1] <= 3:
            midpoint = int(np.floor((current_bounds[i + 1][0] + current_bounds[i][1]) / 2.0))
            current_bounds[i][1] = midpoint
            current_bounds[i + 1][0] = midpoint + 1
    
    ## Optimal positioning ...
    # At most *maxiter* rounds; stop early once a round changes nothing.
    # Each boundary is moved within *search_space* bins of its position.
    maxiter = 5
    con = 0
    search_space = 10
    for num_iter in range(maxiter):
        if con:
            break
        temp_bounds = copy.deepcopy(current_bounds)
        # paired == 1 means the start of domain i was already placed
        # together with the end of domain i - 1
        paired = 0
        for i in range(len(current_bounds)):
            # --- Refine the start of domain i ---
            # The first start is only refined at the top level
            if (i != 0) or (level == DIs.shape[1] - 1):
                if (current_bounds[i][0] != -1) and (paired == 0):
                    if i > 0:
                        # NOTE(review): this takes a difference where the
                        # other branches take a sum (midpoint) -- confirm
                        # that this is intended
                        test_start = max(int(np.ceil((current_bounds[i-1][1] - current_bounds[i][0])/2)),
                                         current_bounds[i][0] - search_space) + 1
                    else:
                        test_start = max(0, current_bounds[i][0] - search_space) + 1
                    if current_bounds[i][1] != -1:
                        test_stop = min(int(np.floor((current_bounds[i][0] + current_bounds[i][1])/2)),
                                        current_bounds[i][0] + search_space)
                    else:
                        test_stop = min(stop-start, current_bounds[i][0] + search_space)
                    
                    # Try every candidate position and keep the cheapest one
                    temp_costs = np.zeros(test_stop - test_start)
                    for j in range(temp_costs.shape[0]):
                        temp_bounds[i][0] = test_start + j
                        temp_costs[j] = domain_cost(A, B, pi, DIs[:,0:(level+1)],
                                                    temp_bounds, start, stop)
                    
                    temp_bounds[i][0] = temp_costs.argmin() + test_start
            
            # --- Refine the end of domain i ---
            # The last end is only refined at the top level
            if (i != len(current_bounds)-1) or (level == DIs.shape[1]-1):
                if current_bounds[i][1] != -1:
                    if current_bounds[i][0] != -1:
                        test_start = max(int(np.ceil((current_bounds[i][1] + current_bounds[i][0])/2)),
                                         current_bounds[i][1] - search_space) + 1
                    else:
                        test_start = max(0, current_bounds[i][1] - search_space) + 1
                    if (i < len(current_bounds)-1) and (current_bounds[i+1][0] - current_bounds[i][1] == 1):
                        # The next domain starts right after this end:
                        # move both boundaries together
                        if current_bounds[i+1][1] != -1:
                            test_stop = min(int(np.floor((current_bounds[i+1][1] + current_bounds[i][1])/2)),
                                            current_bounds[i][1] + search_space)
                        else:
                            test_stop = min(stop - 2 - start, current_bounds[i][1] + search_space)
                        
                        temp_costs = np.zeros(test_stop - test_start)
                        paired = 1
                        for j in range(temp_costs.shape[0]):
                            temp_bounds[i][1] = test_start + j
                            temp_bounds[i + 1][0] = test_start + j + 1
                            temp_costs[j] = domain_cost(A, B, pi, DIs[:,0:(level+1)], temp_bounds, start, stop)
                        temp_bounds[i][1] = temp_costs.argmin() + test_start
                        temp_bounds[i + 1][0] = temp_bounds[i][1] + 1
                    else:
                        # This end is not adjacent to a following start:
                        # move it on its own
                        if i < len(current_bounds)-1:
                            test_stop = min(int(np.floor((current_bounds[i+1][0] + current_bounds[i][1])/2)),
                                            current_bounds[i][1] + search_space)
                        else:
                            test_stop = min(stop - 1 - start, current_bounds[i][1] + search_space)
                        
                        temp_costs = np.zeros(test_stop - test_start)
                        paired = 0
                        for j in range(temp_costs.shape[0]):
                            temp_bounds[i][1] = test_start + j
                            temp_costs[j] = domain_cost(A, B, pi, DIs[:,0:(level+1)], temp_bounds, start, stop)
                        temp_bounds[i][1] = temp_costs.argmin() + test_start
        
        if temp_bounds == current_bounds:
            con = 1
        current_bounds = copy.deepcopy(temp_bounds)
    
    ## Update *forceStates*
    for i in range(len(current_bounds)):
        if current_bounds[i][0] != -1:
            forceStates[(start + current_bounds[i][0]), 0:(level+1)] = 0
        if current_bounds[i][1] != -1:
            forceStates[(start + current_bounds[i][1]), 0:(level+1)] = 3
    
    # Column 0 is the last level: nothing left to refine
    if level == 0:
        return
    
    # Recurse with the next lower column inside each domain; a missing
    # start/end (-1) falls back to *start*/*stop*
    for i in range(len(current_bounds)):
        if current_bounds[i][0] == -1:
            if current_bounds[i][1] == -1:
                hierarchyCaller(A, B, pi, DIs, start, stop, forceStates, level - 1)
            else:
                hierarchyCaller(A, B, pi, DIs, start, start + current_bounds[i][1]+1, forceStates, level - 1)
        else:
            if current_bounds[i][1] == -1:
                hierarchyCaller(A, B, pi, DIs, start + current_bounds[i][0], stop, forceStates, level - 1)
            else:
                hierarchyCaller(A, B, pi, DIs, start + current_bounds[i][0],
                                start + current_bounds[i][1] + 1, forceStates, level -1 )
        
        # ... and in the stretch between this domain and the next one
        # when the two are not adjacent
        if (i < len(current_bounds) - 1) and (current_bounds[i+1][0] - current_bounds[i][1] > 1):
            hierarchyCaller(A, B, pi, DIs, start + current_bounds[i][1],
                            start + current_bounds[i+1][0] + 1, forceStates, level -1)
    
    return

def genbounds(path):
    """
    Extract the boundary states from original state series.
    """
    bounds = []
    if path[0] > 0:
        bounds.append([-1])
    elif path[0] == 0:
        bounds.append([0])
    
    for i in range(1, len(path)):
        if path[i] == 3:
            bounds[-1].append(i)
        elif path[i] == 0:
            bounds.append([i])
    
    if len(bounds[-1]) == 1:
        bounds[-1].append(-1)
    
    calibrated = [b for b in bounds if len(b) > 1]
    
    return calibrated

def domain_cost(A, B, pi, seq, bounds, start, stop):
    """
    Sum the viterbi() cost of rows [start, stop) of every column of
    *seq* when the positions listed in *bounds* are forced.
    
    The first index of each entry of *bounds* is forced to state 0 and
    the last index to state 3; an index of -1 is ignored. Indices are
    relative to *start*.
    """
    forced = np.zeros(stop - start, np.int8) - 1
    for bound in bounds:
        first, last = bound[0], bound[-1]
        if first != -1:
            forced[first] = 0
        if last != -1:
            forced[last] = 3
    
    total_cost = 0
    for col in range(seq.shape[1]):
        cost = viterbi(A, B, pi, seq[start:stop, col], forced)[0]
        total_cost += cost
    
    return total_cost
    

def viterbi(A, B, pi, seq, forceStates):
    """
    Find the most likely hidden state series given the observed *seq*.
    
    *forceStates* provides prior knowledge about the boundary: an entry
    other than -1 forces that state at the corresponding position.
    
    *A* is the transition matrix and *pi* the initial distribution.
    ``B[i, 0, j]``, ``B[i, 1, j]`` and ``B[i, 2, j]`` are used as mean,
    variance and weight of mixture component *j* of state *i*.
    
    Returns ``(cost, path)``: the accumulated negative log-likelihood of
    the best path at its last position, and the list of states.
    
    Note that numpy's divide-by-zero warning is switched off globally
    as a side effect.
    """
    from scipy.stats import norm
    
    np.seterr(divide = "ignore")
    num_states = pi.shape[0]
    numdists = B.shape[2]
    seq_len = seq.shape[0]
    costs = np.zeros((num_states, seq_len))
    # paths[j, i]: best state at position i given state j at position i + 1
    paths = np.zeros((num_states, seq_len - 1), np.int8)
    transition_costs = -np.log(A)

    # Emission cost of every observation under every state
    for i in range(num_states):
        for j in range(numdists):
            costs[i,:] += norm.pdf(seq, loc = B[i, 0, j], scale = B[i, 1, j] ** 0.5 ) * B[i, 2, j]
    costs = -np.log(costs)

    costs[:,0] -= np.log(pi)
    if forceStates[0] != -1:
        # Rule out every state but the forced one
        costs[0:forceStates[0], 0] = np.inf
        costs[(forceStates[0] + 1):, 0] = np.inf
    for i in range(seq_len - 1):
        for j in range(num_states):
            min_cost = costs[0, i] + transition_costs[0, j]
            min_state = 0
            for k in range(1, num_states):
                next_cost = costs[k, i] + transition_costs[k, j]
                if next_cost < min_cost:
                    min_cost = next_cost
                    min_state = k
            costs[j, i + 1] += min_cost
            paths[j, i] = min_state
        if forceStates[i + 1] != -1:
            costs[0:forceStates[i + 1],(i + 1)] += np.inf
            costs[(forceStates[i + 1] + 1):,(i + 1)] += np.inf
    
    # Last state: the forced one if any, otherwise the cheapest
    if forceStates[-1] != -1:
        min_state = forceStates[-1]
    else:
        min_state = 0
        for j in range(1, num_states):
            if costs[j, -1] < costs[min_state, -1]:
                min_state = j
    
    # Trace the best predecessors back to the first position
    path = [min_state]
    for i in range(paths.shape[1])[::-1]:
        path = [paths[path[0], i]] + path
    
    return costs[min_state, -1], path

def genfromStates(forceStates, fields, chrom, region_start, res, minsize):
    """
    Turn the boundary states of a region into domain records.
    
    Column *j* of *forceStates* belongs to ``fields[j]``. Within a
    column, state 0 opens a domain and state 3 closes the most recently
    opened one; positions are offset by *region_start*. A domain is
    reported when it has both a start and an end and spans at least
    *minsize* bins.
    
    Returns a list of ``[chrom, field, start, end]`` records, where
    start and end are bin positions multiplied by *res* and converted to
    strings. Fields are reported from the last to the first.
    """
    temp_domains = {}
    for j in reversed(range(forceStates.shape[1])):
        currentStates = forceStates[:, j]
        # -1 marks a domain whose start lies outside of the region
        called = [[region_start] if currentStates[0] == 0 else [-1]]
        for i in range(1, currentStates.shape[0]):
            if currentStates[i] == 0:
                called.append([region_start + i])
            elif currentStates[i] == 3:
                called[-1].append(region_start + i)
        temp_domains[fields[j]] = called
    
    domains = []
    for key in fields[::-1]:
        for domain in temp_domains[key]:
            # Only the -1 sentinel is rejected: bin 0 is a valid start
            if (len(domain) > 1) and (domain[0] >= 0):
                if domain[1] - domain[0] >= minsize:
                    line = [chrom, key, str(domain[0] * res), str(domain[1] * res)]
                    domains.append(line)
    
    return domains

def outputDomains(filename, domains, res):
    """
    Write *domains* to the text file *filename*.
    
    The file starts with two header lines (the resolution *res* and the
    column names), followed by one tab-separated line per domain. Each
    domain is a sequence of strings.
    """
    # Text mode, since only str is written; "with" also closes the file
    # when a write fails
    with open(filename, 'w') as F:
        F.write('# Resolution: {:d}\n'.format(res))
        F.write('# ' + '\t'.join(['chr', 'level', 'start', 'end']) + '\n')
        for domain in domains:
            F.write('\t'.join(domain) + '\n')
        
    
# Run the domain caller only when this file is executed as a script
if __name__ == '__main__':
    run()
