#! /usr/bin/env python
import pandas
import os,argparse,csv
import numpy as np
from Bio import SeqIO
import subprocess
from argparse import RawTextHelpFormatter

def format_for_feature(bam,inputfile):
    """Write a BED file with one whole-sequence interval per FASTA record.

    The output file is named after ``inputfile`` cut at its first ``.`` with
    ``.bed`` appended. Each line holds four tab-separated fields: the record
    description, the start ``1``, the sequence length, and the record
    description again.

    Args:
        bam: Not used by this function; kept so existing callers keep working.
        inputfile: Path of the FASTA file to read.
    """
    # NOTE: the script entry point rebuilds this name with the same
    # expression, so the derivation must stay in sync with it. A path that
    # contains an earlier dot (e.g. "./contigs.fa") is cut at that dot.
    out_name = str(inputfile).split(".")[0]
    # The context manager guarantees the BED file is flushed and closed by
    # the time this function returns, even if parsing raises.
    with open(out_name + ".bed", 'w') as handle:
        for record in SeqIO.parse(inputfile, 'fasta'):
            name = str(record.description)
            fields = [name, "1", str(len(record.seq)), name]
            handle.write('\t'.join(fields) + "\n")

def multiBamCov(bam,bed):
    """Run the external ``multiBamCov`` program on every ``.bam`` file in a directory.

    For each file in ``bam`` whose name ends in ``.bam`` the program is
    called with ``-bams <file> -bed <bed>`` and its standard output is
    written to ``<name>.readcounts`` in the current working directory, where
    ``<name>`` is the file name cut at its first ``.``.

    Args:
        bam: Directory that is scanned for ``.bam`` files.
        bed: Path of the BED file passed to ``-bed``.
    """
    import sys  # local import: only needed for the warning below

    for file_ in os.listdir(bam):
        if not file_.endswith(".bam"):
            continue
        out_name = file_.split(".")[0]
        command = ["multiBamCov", "-bams", os.path.join(bam, file_), "-bed", str(bed)]
        # Close the output file even if launching the program raises.
        with open(out_name + ".readcounts", "w") as outfile:
            status = subprocess.call(command, stdout=outfile)
        if status != 0:
            # Report the failure instead of announcing a successful count;
            # the remaining samples are still processed.
            sys.stderr.write("Warning: multiBamCov exited with status %d for %s\n" % (status, file_))
            continue
        # Single-argument print() behaves the same on Python 2 and 3; the two
        # spaces reproduce the output of the former comma-separated print.
        print("Read Counts calcualted for sample name:  %s" % out_name)
    
def convert_counts_to_coverage(ids,outfile,path):
    """Combine ``.readcounts`` files into one tab-separated coverage table.

    Every file in ``path`` ending in ``.readcounts`` is read as five
    tab-separated, header-less columns (``Chr``, ``start``, ``len``,
    ``geneid``, ``counts``). Coverage is ``counts / len`` (``len`` itself
    where ``len`` is negative). The rows whose ``Chr`` is listed in ``ids``
    form the table; each further file adds one coverage column, matched to
    the existing rows by row position in the file. The result is written
    without header or index to ``<outfile>.cov``.

    Args:
        ids: Path of a text file with one wanted ``Chr`` value per line.
        outfile: Output path prefix; ``.cov`` is appended.
        path: Directory scanned for ``.readcounts`` files.
    """
    import sys  # local import: only needed for the warning below

    with open(str(ids), "r") as id_handle:
        wanted = [line.rstrip() for line in id_handle]

    columns = ['Chr', 'start', 'len', 'geneid', 'counts']
    coverage_dataframe = pandas.DataFrame()
    count = 0
    # Sorted so the column order of the profile is reproducible;
    # os.listdir gives no ordering guarantee.
    for cov_file in sorted(os.listdir(path)):
        if not cov_file.endswith(".readcounts"):
            continue
        count += 1
        count_path = os.path.join(path, cov_file)
        if os.path.getsize(count_path) == 0:
            # An empty file has nothing to parse; skip it rather than add a
            # column of missing values to the profile.
            sys.stderr.write("Warning: skipping empty read count file %s\n" % count_path)
            continue
        # Column names are supplied directly, so no temporary copy of the
        # file with a header line is needed.
        dataframe = pandas.read_csv(count_path, sep="\t", header=None,
                                    names=columns, dtype={'Chr': object})
        # astype(float) makes the division a true division regardless of the
        # Python / pandas version in use.
        dataframe["coverage"] = np.where(dataframe["len"] < 0,
                                         dataframe["len"],
                                         dataframe["counts"].astype(float) / dataframe["len"])
        if coverage_dataframe.empty:
            # .copy() so later column assignments do not act on a slice of
            # this file's frame.
            coverage_dataframe = dataframe.loc[dataframe['Chr'].isin(wanted),
                                               ['Chr', 'coverage']].copy()
        else:
            # Assignment aligns on the row index, i.e. on row position in the
            # count file -- assumes every count file lists the same regions
            # in the same order.
            coverage_dataframe["coverage%i" % count] = dataframe["coverage"]
    coverage_dataframe.to_csv(str(outfile) + ".cov", header=False, index=False, sep="\t")
    
def get_scaled_log(file_,outfile):
    """Write ``log10(100 * x + 1)`` of every value in a tab-separated profile.

    The first field of each row is copied unchanged; every following field
    is parsed as a float and transformed. The result is written to
    ``outfile`` tab-separated, without header or index.

    Args:
        file_: Path of the tab-separated input profile.
        outfile: Path of the transformed profile to write.

    Raises:
        ValueError: If a field after the first cannot be parsed as a float.
    """
    transformed = []
    # Text mode lets csv.reader work on both Python 2 and Python 3, and the
    # context manager closes the handle.
    with open(str(file_), 'r') as handle:
        for fields in csv.reader(handle, delimiter='\t'):
            values = [np.log10(float(value) * 100 + 1) for value in fields[1:]]
            transformed.append(fields[:1] + values)
    pandas.DataFrame(transformed).to_csv(outfile, sep="\t", index=False, header=False)
def get_log(file_,outfile):
    """Write ``log10(x + 1)`` of every value in a tab-separated profile.

    The first field of each row is copied unchanged; every following field
    is parsed as a float and transformed. The result is written to
    ``outfile`` tab-separated, without header or index.

    Args:
        file_: Path of the tab-separated input profile.
        outfile: Path of the transformed profile to write.

    Raises:
        ValueError: If a field after the first cannot be parsed as a float.
    """
    transformed = []
    # Text mode lets csv.reader work on both Python 2 and Python 3, and the
    # context manager closes the handle.
    with open(str(file_), 'r') as handle:
        for fields in csv.reader(handle, delimiter='\t'):
            values = [np.log10(float(value) + 1) for value in fields[1:]]
            transformed.append(fields[:1] + values)
    pandas.DataFrame(transformed).to_csv(outfile, sep="\t", index=False, header=False)
def get_multiple(file_,outfile,num):
    """Multiply every value in a tab-separated profile by an integer factor.

    The first field of each row is copied unchanged; every following field
    is parsed as a float and multiplied by ``int(num)``. The result is
    written to ``outfile`` tab-separated, without header or index.

    Args:
        file_: Path of the tab-separated input profile.
        outfile: Path of the transformed profile to write.
        num: Multiplication factor; converted with ``int``.

    Raises:
        ValueError: If ``num`` or a field after the first cannot be converted.
    """
    factor = int(num)  # convert once instead of once per value
    transformed = []
    # Text mode lets csv.reader work on both Python 2 and Python 3, and the
    # context manager closes the handle.
    with open(str(file_), 'r') as handle:
        for fields in csv.reader(handle, delimiter='\t'):
            values = [float(value) * factor for value in fields[1:]]
            transformed.append(fields[:1] + values)
    pandas.DataFrame(transformed).to_csv(outfile, sep="\t", index=False, header=False)

def get_squareroot(file_,outfile):
    """Write ``sqrt(x + 1)`` of every value in a tab-separated profile.

    The first field of each row is copied unchanged; every following field
    is parsed as a float and transformed. The result is written to
    ``outfile`` tab-separated, without header or index.

    Args:
        file_: Path of the tab-separated input profile.
        outfile: Path of the transformed profile to write.

    Raises:
        ValueError: If a field after the first cannot be parsed as a float.
    """
    transformed = []
    # Text mode lets csv.reader work on both Python 2 and Python 3, and the
    # context manager closes the handle.
    with open(str(file_), 'r') as handle:
        for fields in csv.reader(handle, delimiter='\t'):
            values = [np.sqrt(float(value) + 1) for value in fields[1:]]
            transformed.append(fields[:1] + values)
    pandas.DataFrame(transformed).to_csv(outfile, sep="\t", index=False, header=False)
    
class SmartFormatter(argparse.HelpFormatter):
    """Help formatter that keeps the line breaks of text prefixed with ``R|``."""

    def _split_lines(self, text, width):
        """Return the lines of ``text`` for display at ``width`` columns.

        Text starting with ``R|`` is split on its own line breaks with the
        prefix removed; any other text is wrapped by the base formatter.
        """
        raw_prefix = 'R|'
        if not text.startswith(raw_prefix):
            return argparse.HelpFormatter._split_lines(self, text, width)
        unprefixed = text[len(raw_prefix):]
        return unprefixed.splitlines()
if __name__ == '__main__':
    # Command-line entry point: build the BED file from the FASTA, count
    # reads per BAM file, merge the counts into a coverage profile and
    # apply the requested transformation.
    parser = argparse.ArgumentParser(prog='Binsanity-profile', usage='%(prog)s -i fasta_file -s directory_of_bam_files --contigs contig_ids_file -o output_file [--transform TRANSFORM]',description="Pulls read count information from BAM files, creates a coverage profile, and transforms data for input into BinSanity",formatter_class=RawTextHelpFormatter)
    parser.add_argument("-o", dest="outfile", help="""Specify name for output file
    """)
    parser.add_argument("--contigs",dest="ids",help="""specify file containing contig ids for final coverage profile
    """)
    parser.add_argument("-i", dest="inputfile", help="""Specify fasta file being profiled
    """)
    parser.add_argument("-s", dest="inputBAM", help="""Identify path to the directory containing the BAM files
    """)
    parser.add_argument("--transform",dest="transform", default = "Scale", help ="""
    Indicate what type of data transformation you want in the final file (default is Scale):
    Scale --> Scaled by multiplying by 100 and log transforming
    log --> Log transform
    None --> Raw Coverage Values
    X5 --> Multiplication by 5
    X10 --> Multiplication by 10
    SQR --> Square root
    We recommend using a scaled log transformation for initial testing. Other transformations can be useful on a case by case basis""")
    args = parser.parse_args()

    # Every one of these is needed by the pipeline below, so report all the
    # missing ones together and stop before any work is started.
    required = [
        (args.inputfile, "Please provide the fasta file you are analyzing"),
        (args.outfile, "Specify name for output file"),
        (args.ids, "Specify file containing contig ids for final coverage profile"),
        (args.inputBAM, "Please provide the path to the directory of BAM files"),
    ]
    problems = [message for value, message in required if value is None]
    if problems:
        parser.print_help()
        for message in problems:
            print(message)
        parser.exit(1)

    # Single-argument print() calls behave the same on Python 2 and 3.
    print("""
        ---------------------------------------------------
                         Formating BED file
        ---------------------------------------------------""")
    format_for_feature(args.inputBAM, args.inputfile)
    print("""
        ---------------------------------------------------
                        Generating Read Counts
        ---------------------------------------------------""")
    # Must match the name format_for_feature derives for the BED file.
    name = str(args.inputfile).split(".")[0]
    multiBamCov(args.inputBAM,str(name)+".bed")
    print("""
        -------------------------------------------------
                   Combining all readcount files
        -------------------------------------------------""")
    # The read count files were written to the current working directory.
    convert_counts_to_coverage(args.ids,args.outfile,".")

    print("""
        ------------------------------------------------
                    Transforming Coverage Profile
        ------------------------------------------------""")
    x = str(args.outfile)+".cov"
    if args.transform == "Scale":
        get_scaled_log(x, x+".x100.lognorm")
    elif args.transform == "log":
        get_log(x,x+".lognorm")
    elif args.transform == "X5":
        get_multiple(x,x+".x5", 5)
    elif args.transform == "X10":
        get_multiple(x,x+".x10", 10)
    elif args.transform == "SQR":
        get_squareroot(x, x+".sqrt")
    elif args.transform != "None":
        # "None" deliberately leaves the raw profile; anything else is a
        # value the script does not know, so say so instead of staying silent.
        print("Unrecognized transform '%s'; raw coverage values left in %s" % (args.transform, x))
    print("Transformed Profile Finished!")
