#! /usr/bin/env python
from argparse import RawTextHelpFormatter
import os, sys, time, argparse, shutil
import numpy as np
from Bio import SeqIO
from sklearn.cluster import KMeans
from sklearn.cluster import AffinityPropagation

def create_fasta_dict(fastafile):
    """Index the sequences of a FASTA source by record identifier.

    ``fastafile`` is handed unchanged to ``SeqIO.parse`` using the "fasta"
    format. The result maps each record's ``id`` to its ``seq``; when an
    identifier occurs more than once the last record read wins.
    """
    return {entry.id: entry.seq for entry in SeqIO.parse(fastafile, "fasta")}

def get_cov_data(h):
    """Read a whitespace-delimited coverage table into a dictionary.

    Args:
        h: path of the coverage file (converted with ``str`` before opening).

    Returns:
        dict mapping the first field of every non-blank line to the list of
        ALL fields on that line (first field included, everything kept as
        strings). If a first field repeats, the last line read wins.
    """
    all_cov_data = {}
    # The context manager guarantees the handle is closed even on error.
    with open(str(h), "r") as cov_file:
        for line in cov_file:
            cov_data = line.split()
            if not cov_data:
                # Blank / whitespace-only line: there is no key to index.
                continue
            all_cov_data[cov_data[0]] = cov_data
    print("coverage data retrieved")
    return all_cov_data

def cov_array(a,b,filename,size):
    """Build the coverage matrix for the contigs of one FASTA file.

    Args:
        a: dict as produced by ``get_cov_data`` - contig id mapped to the
            list ``[id, cov_1, cov_2, ...]`` of string fields. It is only
            read, never modified.
        b: directory containing the FASTA file.
        filename: name of the FASTA file inside ``b``.
        size: minimum contig length (anything ``int`` accepts); shorter
            contigs are skipped entirely.

    Returns:
        Tuple ``(cov_array, names, unused, count)``:
        * cov_array - float array with one row per retained contig. As
          before, a single retained contig yields a 1-D array; with no
          retained contig an empty array is returned.
        * names - contig ids in the same order as the rows.
        * unused - ids of contigs that are long enough but have no entry
          in ``a``.
        * count - number of retained contigs.

    Raises:
        ValueError: if a coverage field cannot be converted to float or the
            rows do not all have the same number of values.
    """
    min_length = int(size)
    names = []
    unused = []
    rows = []
    for record in SeqIO.parse(os.path.join(b, filename), "fasta"):
        if len(record.seq) < min_length:
            continue
        data = a.get(record.id)
        if data is None:
            unused.append(str(record.id))
            continue
        names.append(data[0])
        # Slice instead of removing in place so the caller's dictionary
        # stays intact and can be reused for further lookups.
        rows.append(np.array(data[1:], dtype=float))

    # Stack once at the end rather than growing the matrix row by row.
    if not rows:
        coverage = np.empty((0,), dtype=float)
    elif len(rows) == 1:
        coverage = rows[0]
    else:
        coverage = np.vstack(rows)
    print(coverage.shape)
    return coverage, names, unused, len(names)

def kmean(array,names,filename,unused,path,clusters):
    """Split contigs into coarse bins with K-means and write them to disk.

    Args:
        array: coverage matrix, one row per contig.
        names: contig ids, in the same order as the rows of ``array``.
        filename: FASTA file (inside ``path``) holding the contig sequences.
        unused: accepted for interface compatibility; not used.
        path: directory containing ``filename``.
        clusters: number of K-means clusters (anything ``int`` accepts).

    Side effects:
        Creates the directory ``KMEAN-BINS`` in the current working
        directory and writes one ``kmean_<label>.fna`` file per cluster that
        has at least 5 contigs or contains a contig longer than 50000 bp.
        Prints the size of every cluster and the number of bins written.

    Raises:
        OSError: if ``KMEAN-BINS`` already exists.
    """
    bin_dir = "KMEAN-BINS"
    os.mkdir(bin_dir)
    # The cluster count may arrive as a string from the command line.
    model = KMeans(n_clusters=int(clusters), n_jobs=-1,n_init=10000)
    labels = model.fit_predict(array)

    # Group contig names by the cluster label assigned to them.
    outfile_data = {}
    for name, label in zip(names, labels):
        outfile_data.setdefault(label, []).append(name)

    # Read the sequences once, after all contigs have been grouped.
    with open(os.path.join(path,filename),"r") as input2_file:
        fasta_dict = create_fasta_dict(input2_file)

    count = 0
    for k in outfile_data:
        members = outfile_data[k]
        # Keep a small cluster only if it holds at least one long contig.
        if len(members) >= 5 or any((len(fasta_dict[x])>50000) for x in members):
            with open(os.path.join(bin_dir,"kmean_%s.fna" % (k)), "w") as output_file:
                for x in members:
                    output_file.write(">"+str(x)+"\n"+str(fasta_dict[x])+"\n")
            count = count + 1
        print("Cluster "+str(k)+": "+str(len(members)))
    print ("""
                Total Number of Bins: %i""" % count)

    # Collect stray k-means bins left in the working directory. Both
    # substrings are required so unrelated FASTA files (for example the
    # input contigs) are never moved.
    for binname in os.listdir("."):
        if "kmean" in binname and ".fna" in binname:
            shutil.move(binname, bin_dir)
            
def affinity_propagation(array,names,file_name,damping,iterations,convergence,preference,path,output_directory):
    """Cluster contigs with affinity propagation and write the bins.

    Args:
        array: coverage matrix, one row per contig.
        names: contig ids, in the same order as the rows of ``array``.
        file_name: FASTA file (inside ``path``) holding the sequences; the
            text before its first "." prefixes every output file name.
        damping: damping factor (converted with ``float``).
        iterations: maximum number of iterations (converted with ``int``).
        convergence: convergence iteration count (converted with ``int``).
        preference: affinity propagation preference (converted with
            ``float`` so fractional values are honoured).
        path: directory containing ``file_name``.
        output_directory: directory receiving the bins; created if missing.

    Side effects:
        Writes ``<prefix>-bin_<label>.fna`` for every cluster that has at
        least 5 contigs or contains a contig longer than 50000 bp, and
        prints the size of every cluster plus the number of bins written.
    """
    if not os.path.isdir(str(output_directory)):
        os.mkdir(output_directory)
    model = AffinityPropagation(damping=float(damping), max_iter=int(iterations), convergence_iter=int(convergence), copy=True, preference=float(preference), affinity='euclidean', verbose=False)
    apclust = model.fit_predict(array)
    print("""-------------------------------------------------------
                            --Creating Bins--
            -------------------------------------------------------""")

    # Group contig names by the cluster label assigned to them.
    outfile_data = {}
    for name, label in zip(names, apclust):
        outfile_data.setdefault(label, []).append(name)

    out_name = file_name.split(".")[0]
    with open(os.path.join(path,file_name),"r") as input2_file:
        fasta_dict = create_fasta_dict(input2_file)

    count = 0
    for k in outfile_data:
        members = outfile_data[k]
        # Keep a small cluster only if it holds at least one long contig.
        if len(members) >= 5 or any((len(fasta_dict[x])>50000) for x in members):
            bin_path = os.path.join(output_directory,str(out_name)+"-bin_%s.fna" % (k))
            with open(bin_path, "w") as output_file:
                for x in members:
                    output_file.write(">"+str(x)+"\n"+str(fasta_dict[x])+"\n")
            count = count + 1
        print("Cluster "+str(k)+": "+str(len(members)))
    print ("""Total Number of Bins: %i""" % count)
        

class Logger(object):
    """Stand-in for ``sys.stdout`` that also appends output to a log file.

    Everything written is sent both to the stream that was ``sys.stdout``
    when the instance was created and to ``binsanity-log.txt`` (opened in
    append mode in the current working directory).
    """

    def __init__(self):
        self.terminal = sys.stdout
        self.log = open("binsanity-log.txt", "a")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both targets so an explicit flush reaches the log file too."""
        self.terminal.flush()
        self.log.flush()
def _build_parser():
    """Create the command-line parser for Binsanity-lc."""
    parser = argparse.ArgumentParser(prog='Binsanity-lc', usage="""%(prog)s -c [Raw Coverage File] -tcov [transformed coverage file] -f [Path To Contig File] -l [Suffix Linking Contig files] 
    {optional [-x Contig Size Cut Off] [-p Preference] [-m Max Iterations] [-v Convergence Iterations] [-d Damping factor] -o [Output directory]}""",description="""
             Binsanity clusters contigs based on coverage. 
    ----------------------------------------------------------------
    Binsanity-lc is for assemblies with greater than 100,000 contigs. 
    It uses K-means to subset your initial contigs making it feasible 
    to use Affinity Propagation. This script should only be used if 
    the memory requirements for the script Binsanity are too high for 
    your current computer.
    
    *****************NOTE THIS IS A BETA VERSION AND HAS NOT BEEN FULLY VETTED YET*****************""",formatter_class=RawTextHelpFormatter)
    parser.add_argument("-tcov", dest="inputTCov", help="""
    Specify the transformed Coverage File
    e.g log normalized""")
    parser.add_argument("-f", dest="inputContigFiles", help="""
    Specify directory containing your contigs""")
    parser.add_argument("-p", type=float, dest="preference", default=-3, help="""
    Specify a preference (default is -3) Note: decreasing the preference leads to 
    more lumping, increasing will lead to more splitting. If your range of coverages 
    are low you will want to decrease the preference, if you have 10 or less 
    replicates increasing the preference could benefit you.""")
    parser.add_argument("-m", type=int, dest="maxiter", default=4000, help="""
    Specify a max number of iterations (default is 4000)""")
    parser.add_argument("-v", type=int, dest="conviter",default=400, help="""
    Specify the convergence iteration number (default is 400)
    e.g Number of iterations with no change in the number of estimated clusters that stops the convergence.""")
    parser.add_argument("-d",default=0.95, type=float, dest="damp", help="""
    Specify a damping factor between 0.5 and 1, default is 0.95""")
    parser.add_argument("-l",dest="fastafile", help="""
    Specify the fasta file containing contigs you want to cluster""")
    parser.add_argument("-x",dest="ContigSize", type=int, default=1000,help="""
    Specify the contig size cut-off (Default 1000 bp)""")
    parser.add_argument("-o",dest="outputdir", default="BINSANITY-RESULTS", help="""
    Give a name to the directory BinSanity 
    results will be output in [Default is 'BINSANITY-RESULTS']""")
    parser.add_argument("-c", dest="inputCovFile", help="""
    Specify a Raw Coverage File""")
    # Parsed as an integer so it can be handed straight to K-means.
    parser.add_argument("-n",dest="ClusterNum", type=int, default=None,help="""
    Specify the number of initial bins for hierarchical clustering""")
    return parser


def main():
    """Run Binsanity-lc from the command line.

    With fewer than 50000 retained contigs, affinity propagation is run
    directly on the transformed coverage (-tcov). Otherwise the contigs are
    first split with K-means into KMEAN-BINS, each of those bins is refined
    with affinity propagation using the raw coverage file (-c), and
    KMEAN-BINS is finally moved to the output directory.

    Exits with status 1 when -tcov, -f or -l is missing.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.inputCovFile is None:
        if (args.inputContigFiles is None) and (args.fastafile is None):
            parser.print_help()
    missing = False
    if args.inputTCov is None:
        print("Please indicate -tcov coverage file")
        missing = True
    if args.inputContigFiles is None:
        print("Please indicate -f directory containing your contigs")
        missing = True
    if args.fastafile is None:
        print("Please indicate -l fasta file containing your contigs")
        missing = True
    if missing:
        # Stop here: every later step needs all three inputs.
        sys.exit(1)

    start_time = time.time()
    # From here on everything printed is mirrored to binsanity-log.txt.
    sys.stdout = Logger()
    print("""
        -------------------------------------------------------
                    ---Computing Coverage Array ---
        -------------------------------------------------------
        """)
    val1, val2, val3, val4 = cov_array(get_cov_data(args.inputTCov), args.inputContigFiles, args.fastafile, args.ContigSize)

    if val4 < 50000:
        print("""
            -------------------------------------------------------
                        ---Affinity Propagation---
            -------------------------------------------------------        
            """)

        affinity_propagation(val1, val2, args.fastafile, args.damp, args.maxiter, args.conviter, args.preference, args.inputContigFiles, args.outputdir)

        print("""
        
            --------------------------------------------------------
                --- Putative Bins Computed in %s seconds ---
            --------------------------------------------------------""" % (time.time() - start_time))
        return

    if args.inputCovFile is None:
        parser.error("-c raw coverage file is required when the K-means step is needed")

    print("""
            -------------------------------------------------------
                        ---Initial K-mean Clustering---
            -------------------------------------------------------
            """)
    if args.ClusterNum is None:
        cluster_number = val4 // 10000
    else:
        cluster_number = args.ClusterNum
    print("cluster number for k means is:  %s" % cluster_number)
    kmean(val1, val2, args.fastafile, val3, args.inputContigFiles, cluster_number)

    print("""
                -------------------------------------------------------
                        ---Affinity Propagation---
                -------------------------------------------------------        
                """)
    # Read the raw coverage table once; it is the same for every bin.
    # NOTE(review): the refinement step uses the raw (-c) coverage while
    # the direct path above uses the transformed (-tcov) one - confirm
    # that this difference is intended.
    raw_coverage = get_cov_data(args.inputCovFile)
    for binname in os.listdir("KMEAN-BINS"):
        if "kmean" not in binname:
            continue
        val5, val6, val7, val8 = cov_array(raw_coverage, "KMEAN-BINS", binname, args.ContigSize)
        if val8 == 0:
            # Nothing in this bin passed the size/coverage filter.
            continue
        # The original called an undefined conduct_AP here;
        # affinity_propagation performs the clustering on the same data.
        affinity_propagation(val5, val6, binname, args.damp, args.maxiter, args.conviter, args.preference, "KMEAN-BINS", args.outputdir)

    print("""
        
                --------------------------------------------------------
                    --- Putative Bins Computed in %s seconds ---
                --------------------------------------------------------""" % (time.time() - start_time))
    shutil.move("KMEAN-BINS", args.outputdir)


if __name__ == '__main__':
    main()

