#!/usr/bin/env python2.7
'''
Created on July 24, 2015

@author: Ying Jin
@contact: yjin@cshl.edu
@status: 
@version: 0.5.6
'''
import argparse, subprocess,traceback
import sys, os, time, string, re
import warnings, logging
import collections
import math, copy
import sets
from time import strftime
from datetime import datetime
import ctypes
import multiprocessing,threading,Queue

def locate(name, path):
    """Recursively search *path* for a file called *name*.

    Returns the full path of the first match found by os.walk, or None
    when no directory under *path* contains such a file.
    """
    for dirpath, _subdirs, filenames in os.walk(path):
        if name not in filenames:
            continue
        return os.path.join(dirpath, name)
    return None

def locBAMqc(loc, name='libBAMqc.so'):
    """Search every directory listed in the environment variable *loc*
    (e.g. 'PATH' or 'PYTHONPATH') recursively for the file *name*.

    Returns the full path of the first match, or None when the variable is
    unset or empty, or no listed directory contains the file.
    """
    # os.environ[loc] would raise KeyError for an unset variable (PYTHONPATH
    # frequently is), crashing the script at import time.
    for p in os.environ.get(loc, '').split(os.pathsep):
        if not p:
            # Empty entries ("a::b", or an empty variable); os.walk('') finds nothing anyway.
            continue
        potential_file = locate(name, p)
        if potential_file:
            return potential_file
    return None

# Refuse to run under anything but Python 2.7 before doing any other work.
# Exit with a non-zero status so callers/scripts can detect the failure.
if sys.version_info[0] != 2 or sys.version_info[1] != 7:
    sys.stderr.write("\nYou are using python" + str(sys.version_info[0]) + '.' + str(sys.version_info[1]) + " BAMQC needs python2.7!\n")
    sys.exit(1)

# Locate the compiled QC library: a copy under the current directory wins,
# then PATH, then PYTHONPATH.  Each search runs only if the previous one
# missed, because locBAMqc() recursively walks every directory on the
# variable.  in_path / in_pythonpath stay None when their search was skipped.
in_local = locate('libBAMqc.so','./')
in_path = None
in_pythonpath = None
if not in_local:
    in_path = locBAMqc('PATH')
    if not in_path:
        in_pythonpath = locBAMqc('PYTHONPATH')

lib_file = in_local or in_path or in_pythonpath
if not lib_file:
    sys.stderr.write("can not find libBAMqc.so, you're not setup correctly, exiting\n")
    sys.exit(1)

# Module-level ctypes handle; worker() and main() call so.run_qc through it.
try:
    so = ctypes.CDLL(lib_file)
except OSError as e:
    # e.g. wrong architecture or missing dependent shared libraries.
    sys.stderr.write("can not load %s: %s\n" % (lib_file, e))
    sys.exit(1)
    

class pyResults :
    """Container for the QC results of a single sample.

    Every attribute starts out empty: file paths are "", flags are False and
    all statistics are 0.  read_in_res() populates an instance from the
    sample's ``<label>.res.txt`` file.
    """

    def __init__(self):
        # Paths of plots / data files produced for the sample ("" = none).
        for attr in ("filename", "clipping_plot_file", "mapq_plot_file",
                     "mapq_file", "read_cov_plot_file", "trans_cov_plot_file",
                     "insert_plot_file", "insert_file",
                     "read_dist_plot_file1", "read_dist_plot_file2",
                     "read_dup_plot_file", "readLen_plot_file",
                     "geneCount_file"):
            setattr(self, attr, "")

        # Boolean flags.
        for attr in ("is_pairEnd", "no_clipping", "no_rRNA"):
            setattr(self, attr, False)

        # Duplication percentages and read counters.
        for attr in ("seqDeDup_percent", "posDeDup_percent",
                     "total_reads", "uniq_mapped_reads", "multi_mapped_reads",
                     "unmapped_reads", "low_qual", "low_qual_read1",
                     "low_qual_read2", "pcr_dup",
                     "unmapped_read1", "unmapped_read2", "mapped_read1",
                     "mapped_read2", "forward_read", "reverse_read",
                     "paired_reads",
                     "mapped_plus_minus", "mapped_plus_plus",
                     "mapped_minus_plus", "mapped_minus_minus",
                     "ins_read", "del_read",
                     "noSplice", "splice", "paired_diff_chrom",
                     "rRNA_read", "intron_read", "cds_exon_read",
                     "utr_5_read", "utr_3_read", "intergenic_up1kb_read",
                     "intergenic_down1kb_read", "intergenic_read"):
            setattr(self, attr, 0)


def read_in_res(cur_data_dir,label):
    """Parse the per-sample result file ``<cur_data_dir><label>.res.txt``.

    Each non-blank line is ``name<TAB>value``.  Recognised names are stored
    on a fresh pyResults instance after conversion to the attribute's type;
    unrecognised names are ignored.  Blank lines are skipped.

    Returns the pyResults instance.  It is left at its defaults, with a
    warning on stderr, when the file does not exist.  Exits the program with
    status 1 if the file cannot be read or a value cannot be converted.
    """
    res = pyResults()
    fname = cur_data_dir+label+'.res.txt'

    if not os.path.exists(fname):
        sys.stderr.write("output file does not exist for sample %s\n" % (label))
        return res

    # Field name -> converter applied to the raw value text.
    converters = dict.fromkeys((
        "filename", "clipping_plot_file", "mapq_plot_file", "mapq_file",
        "read_cov_plot_file", "trans_cov_plot_file", "insert_plot_file",
        "insert_file", "read_dist_plot_file1", "read_dist_plot_file2",
        "read_dup_plot_file", "readLen_plot_file", "geneCount_file"), str)
    converters.update(dict.fromkeys((
        "rRNA_read", "low_qual_read1", "low_qual_read2", "total_reads",
        "uniq_mapped_reads", "multi_mapped_reads", "unmapped_reads",
        "low_qual", "pcr_dup", "unmapped_read1", "unmapped_read2",
        "mapped_read1", "mapped_read2", "forward_read", "reverse_read",
        "paired_reads", "mapped_plus_minus", "mapped_plus_plus",
        "mapped_minus_plus", "mapped_minus_minus", "ins_read", "del_read",
        "noSplice", "splice", "paired_diff_chrom"), int))
    converters["seqDeDup_percent"] = float
    converters["posDeDup_percent"] = float
    # is_pairEnd is set only by an explicit 1; the no_* flags by any non-zero value.
    converters["is_pairEnd"] = lambda v: int(v) == 1
    converters["no_clipping"] = lambda v: int(v) != 0
    converters["no_rRNA"] = lambda v: int(v) != 0

    try:
        with open(fname, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line used to abort the whole run.
                    continue
                name, value = line.split('\t', 1)
                convert = converters.get(name)
                if convert is not None:
                    setattr(res, name, convert(value))
    except (IOError, OSError, ValueError) as e:
        sys.stderr.write("Error in reading output.\n")
        sys.stderr.write("%s: %s\n" % (fname, e))
        sys.exit(1)
    return res



def worker(in_queue,out_queue) :
    """Child-process entry point used by distr_jobs2(): run QC for one sample.

    Takes one task tuple
    ``(fname, label, cur_data_dir, cur_fig_dir, rRNA_model, ref_gene_model,
    attrID, mapq, stranded[, numThreads])`` from *in_queue*.  If *fname* is
    not None, it calls ``so.run_qc``.  When that returns 1, it parses the
    results, draws the per-sample plots and puts ``(label, res)`` on
    *out_queue*.  Nothing is put on *out_queue* when *fname* is None or
    run_qc fails.
    """
    task = in_queue.get()
    (fname,label,cur_data_dir,cur_fig_dir,rRNA_model,ref_gene_model,attrID,mapq,stranded) = task[:9]
    # main()'s serial path calls run_qc with numThreads as a 10th argument.
    # ctypes does not check argument counts, so omitting it here would leave
    # that C parameter undefined.  Default to 1 thread per child, because
    # several children already run concurrently.
    # NOTE(review): confirm run_qc's 10th parameter is the thread count.
    num_threads = task[9] if len(task) > 9 else 1
    if fname is None:
        return
    ret = so.run_qc(cur_data_dir,cur_fig_dir,ref_gene_model,attrID,fname,rRNA_model,label,mapq,stranded,num_threads)
    if ret != 1:
        sys.stderr.write("BAMqc failed for sample %s (run_qc returned %s)\n" % (label, ret))
        return
    res = read_in_res(cur_data_dir,label)
    create_per_sample_plot(fname,label,cur_data_dir,cur_fig_dir,res)
    out_queue.put((label,res))




def distr_jobs2(args,cur_data_dir,cur_fig_dir):
    """Run QC for every input file in child processes.

    Files are handled in batches of at most ``args.numProc`` processes.  Each
    child runs worker() on one task tuple, and a batch finishes before the
    next one starts.

    Returns a dict mapping label -> pyResults for every sample whose worker
    reported a result.  Samples whose worker produced no result are simply
    absent.  Exits the program with status 1 if a child exits with a
    non-zero code, or on any unexpected error.
    """
    smp_res = dict()
    n_files = len(args.ifiles)
    if n_files == 0:
        return smp_res

    def drain(queue):
        # Only this process reads the queue, so a non-empty queue means get() won't block indefinitely.
        while not queue.empty():
            (label, res) = queue.get()
            smp_res[label] = res

    try:
        batch_size = max(1, min(args.numProc, n_files))
        for start in range(0, n_files, batch_size):
            stop = min(start + batch_size, n_files)
            task_queue = multiprocessing.Queue()
            result_queue = multiprocessing.Queue()

            for i in range(start, stop):
                task_queue.put((args.ifiles[i], args.labels[i], cur_data_dir, cur_fig_dir,
                                args.rRNA_model, args.ref_gene_model, args.attrID,
                                args.mapq, args.stranded))

            procs = []
            for _ in range(start, stop):
                p = multiprocessing.Process(target=worker, args=(task_queue, result_queue))
                p.start()
                procs.append(p)

            # Consume results while the children run.  A child that has put
            # data on a multiprocessing.Queue does not terminate until the
            # data is flushed, so joining before reading can deadlock.
            # Children that never put a result (run_qc failure) are not
            # waited for forever, unlike a fixed number of blocking get()s.
            while True:
                drain(result_queue)
                running = [p for p in procs if p.is_alive()]
                if not running:
                    break
                running[0].join(0.5)
            # Every child has exited, so anything it queued is now readable.
            drain(result_queue)

            for p in procs:
                p.join()
                if p.exitcode != 0:
                    sys.stderr.write("subprocess error %s\n" % (p.exitcode))
                    sys.exit(1)

    # Exception (not a bare except) so the sys.exit(1) above is not swallowed.
    except Exception:
        sys.stderr.write("Error: %s\n" % str(sys.exc_info()[1]))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" %
                         ( sys.exc_info()[1].__class__.__name__,
                          os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]),
                          traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)

    return smp_res



def _r_str(s):
    """Return *s* as a double-quoted R string literal."""
    # Escape backslashes first, then quotes, so labels/paths cannot break the R code.
    return '"' + s.replace('\\', '\\\\').replace('"', '\\"') + '"'


def _r_vector(items):
    """Return an R ``c(...)`` expression of quoted strings (``c()`` when empty)."""
    return 'c(' + ','.join([_r_str(x) for x in items]) + ')'


def _write_smp_corr_script(script_file, fig_dir, corr_files, corr_labels,
                           insert_files, insert_labels, mapq_files, mapq_labels):
    """Write the R script that draws the cross-sample figures into *fig_dir*.

    corr_files / corr_labels: gene-count tables and labels (at least two).
    insert_files / insert_labels: inner-distance tables of paired-end samples;
        the inner-distance plot is skipped when empty.
    mapq_files / mapq_labels: mapping-quality tables; the quality plot is
        skipped when empty (it used to emit invalid R, ``c)``).
    """
    header_corr = _r_vector(corr_labels)
    lines = [
        "library(corrplot)",
        "srcfiles = " + _r_vector(corr_files),
        "destfile = " + _r_str(fig_dir + "smp_corr.png"),
        "f1 = read.delim(srcfiles[1],header=T)",
        "MM=matrix(nrow=length(f1[,1]),ncol=length(srcfiles))",
        "rownames(MM)=f1[,1]",
        "MM[,1]=f1[,2]",
        "for (i in 2:length(srcfiles)){ ",
        "    f = read.delim(srcfiles[i],header=T)",
        "    MM[,i] = f[,2] }",
        "colnames(MM)=" + header_corr,
        "libSize<-colSums(MM)",
        "MM<-t(t(MM)*1000000/libSize)",
        "ss<-rowSums(MM)",
        "M1<-MM[ss>0,]",
        "MM_s<-t(scale(t(M1)))",
        "M.cor<-cor(MM_s,method='sp')",
        "M.cor[is.na(M.cor)]<- 0",
        "png(destfile,width=500,height=500,units='px')",
        "corrplot(M.cor,is.corr=T,order='FPC',method='color',type='full',add=F,diag=T)",
        "dev.state = dev.off()",
        # Percentage of expressed genes detected in each sample.
        "nz_genes = length(M1[,1])",
        "destfile = " + _r_str(fig_dir + "smp_reproducibility.png"),
        "if(nz_genes >0) { ",
        "png(destfile,width=500,height=500,units='px')",
        "nz_gene_mm = rep(0,length(M1[1,]))",
        "for(i in 1:length(M1[1,])) { ",
        "nz_gene_mm[i] = length(which(M1[,i]>0))/nz_genes * 100 } ",
        "bplt <- barplot(nz_gene_mm,beside=T,border='NA',space=1.5,ylim=c(0,100),ylab='Genes reproducibly detected (%)',col='blue',names.arg=colnames(MM))",
        "text(y= nz_gene_mm+2, x= bplt, labels=paste(as.character(round(nz_gene_mm,digits=1)),'%',sep=''), xpd=TRUE)",
        "dev.state = dev.off()}",
        # Per-gene variability (MAD / median) against expression level.
        "destfile = " + _r_str(fig_dir + "smp_var.png"),
        "png(destfile,width=500,height=500,units='px')",
        "mad = rep(0,length(M1[,1]))",
        "nz_gene_median = rep(0,length(M1[,1]))",
        "for(i in 1:length(M1[,1])) { ",
        "nz_gene_median[i] = median(M1[i,]) ",
        "mad[i] = median(abs(M1[i,]-nz_gene_median[i])) } ",
        "mad2 = mad[nz_gene_median >0] ",
        "nz_gene_median2 = nz_gene_median[nz_gene_median>0] ",
        "mad_vs_median = mad2/nz_gene_median2 ",
        "nz_gene_median3 = log(nz_gene_median2, base=2)",
        "dd<-data.frame(nz_gene_median3,mad_vs_median) ",
        "x = densCols(nz_gene_median3,mad_vs_median, colramp=colorRampPalette(c('black', 'white')))",
        "dd$dens <- col2rgb(x)[1,] + 1L ",
        'cols <-  colorRampPalette(c("#000099", "#00FEFF", "#45FE4F", "#FCFF00", "#FF9400", "#FF3100"))(256)',
        "dd$col <- cols[dd$dens]",
        'plot(mad_vs_median ~ nz_gene_median3,data=dd[order(dd$dens),], col=col, pch=20,xlab="Gene expression (median RPM log2)",ylab="Median absolute deviation/median")',
        "dev.state = dev.off()",
        # Gene abundance classes per sample.
        "destfile = " + _r_str(fig_dir + "smp_cov.png"),
        "png(destfile,width=500,height=500,units='px')",
        'xname=c("<0.5","0.5-10","10-100",">=100")',
        "Fn_mm = matrix(0,nrow=length(xname),ncol=length(M1[1,]))",
        "rownames(Fn_mm) = xname ",
        "colnames(Fn_mm) = " + header_corr + " ",
        "for(i in 1:length(M1[1,])) { ",
        "Fn_mm[1,i] = length(which(M1[,i]<0.5)) ",
        "Fn_mm[2,i] = length(which(M1[,i]>=0.5 & M1[,i]<10))",
        "Fn_mm[3,i] = length(which(M1[,i]>=10 & M1[,i]<100))",
        "Fn_mm[4,i] = length(which(M1[,i]>=100)) }",
        'barplot(Fn_mm,main="Gene abundance (RPM)",xlab="Sample",ylab="Frequency",col=c("green","blue","red","yellow"),legend=xname)',
        "dev.state = dev.off()",
    ]
    if insert_files:
        lines += [
            "srcfiles2 = " + _r_vector(insert_files),
            "destfile2 = " + _r_str(fig_dir + "smp_inner_dist.png"),
            "png(destfile2,width=500,height=500,units='px')",
            "f = read.delim(srcfiles2[1],header=T)",
            "freq=rep(round((f[,1]+f[,2]+1)/2,0),times=f[,3])",
            "smp =" + _r_vector(insert_labels),
            'boxplot(freq,outline=F,xlim=c(0,length(smp)+1),ylab="Inner distance (bp)",col="blue",border="black") ',
            # seq_along(...)[-1] is empty for a single file; 2:1 would be c(2,1).
            "for (i in seq_along(srcfiles2)[-1]){ ",
            "    f = read.delim(srcfiles2[i],header=T)",
            "freq=rep(round((f[,1]+f[,2]+1)/2,0),times=f[,3])",
            'boxplot(freq,add=T,outline=F,at=i,col="blue",border="black") }',
            "axis(1,at=seq(1,length(smp),by=1),labels=smp,las=2)",
            "dev.state = dev.off()",
        ]
    if mapq_files:
        lines += [
            "destfile3 = " + _r_str(fig_dir + "smp_qual.png"),
            "srcfiles3 = " + _r_vector(mapq_files),
            "png(destfile3,width=500,height=500,units='px')",
            'xname=c("<3","3-10","10-20","20-30",">=30")',
            "Fn_mm = matrix(0,nrow=length(xname),ncol=length(srcfiles3))",
            "rownames(Fn_mm) = xname ",
            "colnames(Fn_mm) = " + _r_vector(mapq_labels) + " ",
            "for(i in 1:length(srcfiles3)) { ",
            "  f = read.delim(srcfiles3[i],header=T)",
            " if(length(which(f[,1]<3)) >0){ Fn_mm[1,i] = sum(f[which(f[,1]<3),3])/f[1,2]} ",
            "if(length(which(f[,1]>=3 & f[,1]<10)) >0) {Fn_mm[2,i] = sum(f[which(f[,1]<10 & f[,1]>=3),3])/f[1,2]} ",
            "if(length(which(f[,1]>=10 & f[,1]<20)) >0)  {Fn_mm[3,i] = sum(f[which(f[,1]<20 & f[,1]>=10),3])/f[1,2] }",
            "if(length(which(f[,1]>=20 & f[,1]<30)) >0) {Fn_mm[4,i] = sum(f[which(f[,1]<30 & f[,1]>=20),3])/f[1,2]} ",
            "if(length(which(f[,1]>=30)) >0) {Fn_mm[5,i] = sum(f[which(f[,1]>=30),3])/f[1,2] }} ",
            'barplot(Fn_mm,xlab="Sample",main="Mapping Quality",ylim=c(0,1),ylab="Frequency",col=c("blue","green","yellow","orange","red"),legend=xname)',
            "dev.state = dev.off()",
        ]
    with open(script_file, 'w') as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Run BAMqc end to end and write the HTML report.

    Steps:
    - Create ``<dir>``, ``<dir>/figs`` and ``<dir>/data``; refuse to reuse
      an existing ``<dir>``.
    - Run the QC library on every input, in child processes when
      ``numProc >= 2``.
    - Draw cross-sample plots when more than one sample has gene counts.
    - Write ``<dir>/bamqc_output.html`` for the samples that succeeded.
    """
    #read in options
    args = read_opts(prepare_parser())

    info = args.info
    error = args.error

    info("*** Starting BAMqc run. ***\n")

    # label -> pyResults for every sample whose QC run succeeded
    smp_res = dict()

    #working directory and output files
    cur_dir = os.path.abspath(args.dir)
    cur_fig_dir = cur_dir + "/figs/"
    cur_data_dir = cur_dir + "/data/"
    html_file = cur_dir + "/bamqc_output.html"

    # Checked outside the try: the old bare except caught this sys.exit()
    # and printed a second, misleading "create output folder" error.
    if os.path.exists(cur_dir):
        error("Folder already exists!\n")
        sys.exit(1)
    try:
        for folder in (cur_dir, cur_fig_dir, cur_data_dir):
            if not os.path.exists(folder):
                os.makedirs(folder)
    except OSError:
        error("Error in create output folder.\n")
        sys.exit(1)

    if args.numProc >= 2:
        smp_res = distr_jobs2(args, cur_data_dir, cur_fig_dir)
    else:
        for i, ifile in enumerate(args.ifiles):
            label = args.labels[i]
            ret = so.run_qc(cur_data_dir, cur_fig_dir, args.ref_gene_model, args.attrID, ifile, args.rRNA_model, label, args.mapq, args.stranded, args.numThreads)
            if ret == 1:
                res = read_in_res(cur_data_dir, label)
                create_per_sample_plot(ifile, label, cur_data_dir, cur_fig_dir, res)
                smp_res[label] = res

    # Samples whose QC failed have no entry in smp_res.  Leave them out of
    # the plots and the report instead of raising KeyError.
    done_labels = [label for label in args.labels if label in smp_res]

    #sample correlation
    if len(smp_res) > 1:
        corr_files, corr_labels = [], []
        insert_files, insert_labels = [], []
        mapq_files, mapq_labels = [], []
        for key in done_labels:
            res = smp_res[key]
            # Absolute paths under cur_data_dir, so the R script does not
            # depend on the working directory or on args.dir's form.
            if res.mapq_file != "":
                mapq_files.append(cur_data_dir + os.path.basename(res.mapq_file))
                mapq_labels.append(key)
            if res.insert_file != "" and res.is_pairEnd:
                insert_files.append(cur_data_dir + os.path.basename(res.insert_file))
                insert_labels.append(key)
            if res.geneCount_file != "":
                corr_files.append(cur_data_dir + os.path.basename(res.geneCount_file))
                corr_labels.append(key)

        if len(corr_files) > 1:
            info("*** Sample Correlation ***")
            smp_corr_r = cur_data_dir + "smp_correlation.r"
            try:
                _write_smp_corr_script(smp_corr_r, cur_fig_dir, corr_files, corr_labels,
                                       insert_files, insert_labels, mapq_files, mapq_labels)
                if subprocess.call(["Rscript", smp_corr_r]) != 0:
                    sys.stderr.write("Error in computing sample correlation.\n")
            except (IOError, OSError):
                # Script could not be written, or Rscript is not installed.
                sys.stderr.write("Error in computing sample correlation.\n")

            info("*** Correlation completed ***\n")

    outputToHTML(smp_res, done_labels, html_file)
    info("*** BAM QC run completed. ***\n")


def _run_rscript(script):
    """Run ``Rscript <script>``; return True only if the script exists and exits with 0."""
    if not os.path.isfile(script):
        return False
    try:
        # An argument list, not a shell string: paths with spaces or shell
        # metacharacters are passed through unchanged.
        return subprocess.call(["Rscript", script]) == 0
    except OSError:
        # Rscript itself could not be executed (e.g. not on PATH).
        return False


def create_per_sample_plot(ifile,label,cur_data_dir,cur_fig_dir,res):
    """Draw the per-sample figures by running ``<cur_data_dir><label>.*.r``.

    When a figure cannot be produced, the matching ``*_plot_file`` attribute
    of *res* is reset to "" and a message goes to stderr.  Previously
    ``subprocess.call`` never raised on a failing script, so those
    attributes were never cleared.

    Also sets ``res.no_clipping`` when no clipping-profile script exists.
    *ifile* and *cur_fig_dir* are kept for interface compatibility and are
    not used here.
    """
    prefix = cur_data_dir + label

    def cannot(script):
        sys.stderr.write("Cannot generate png file from " + script + "\n")

    if not _run_rscript(prefix + '.read_distr.r'):
        sys.stderr.write("Error in plotting read distributions.\n")
        res.read_dist_plot_file1 = ""
    if not _run_rscript(prefix + '.read_distr_pie.r'):
        sys.stderr.write("Error in plotting read distributions.\n")
        res.read_dist_plot_file2 = ""

    script = prefix + '.clipping_profile.r'
    if not os.path.isfile(script):
        res.no_clipping = True
    elif not _run_rscript(script):
        cannot(script)
        res.clipping_plot_file = ""

    script = prefix + '.mapq_profile.r'
    if not os.path.isfile(script):
        res.mapq_plot_file = ""
    elif not _run_rscript(script):
        cannot(script)
        res.mapq_plot_file = ""

    # The transcript-coverage plot is only attempted alongside the gene-body one.
    script = prefix + '.geneBodyCoverage_plot.r'
    if os.path.isfile(script):
        if not _run_rscript(script):
            cannot(script)
            res.read_cov_plot_file = ""
        trans_script = prefix + '.TransCoverage.r'
        if not _run_rscript(trans_script):
            cannot(trans_script)
            res.trans_cov_plot_file = ""

    script = prefix + '.ReadLen_plot.r'
    if not _run_rscript(script):
        cannot(script)

    script = prefix + '.inner_distance_plot.r'
    if res.is_pairEnd and os.path.isfile(script):
        if not _run_rscript(script):
            cannot(script)
            res.insert_plot_file = ""
    else:
        res.insert_plot_file = ""



#def outputToHTML(res_list,smps,corr_plot_file,smp_inner_plot_file,smp_quality_plot_file,smp_cov_plot_file,html_file):
def outputToHTML(res_list,smps,html_file):
    
    #smps = res_list.keys()
    tohtml = '<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Strict//EN">\n'
    tohtml += '<html>\n'
    tohtml += '<head><title>BAMQC Report</title>\n'
    tohtml += '<style type="text/css">\n'
    #/*ul Styles*/
    tohtml += 'html,body{margin:0;padding:0;height:100%;width:1500px}\n'
    tohtml += 'div#header{background-color:#F3F2ED;}\n'
    tohtml += 'div#header h1{text-align:center; height:80px;line-height:80px;margin:0;padding-left:10px;}\n'
    tohtml += 'div#container{text-align:left;height:100%;width:1500px}\n'
    tohtml += 'div#navigation{background:#F6F0E0;}\n'
    tohtml += 'div#navigation{float:left;width:200px;height:100%}\n'

    tohtml += '.menu-item ul { \n'
    tohtml += 'background: #F6F0E0; \n'
    tohtml += 'font-size: 13px; \n'
    tohtml += 'line-height: 30px; \n'
    tohtml += 'height: 0px; \n'
    tohtml += 'list-style-type: none;'
    tohtml += 'overflow: hidden; \n'
    tohtml += 'padding: 0px; }\n'
    
    tohtml += '.menu-item:hover ul {  height: 220px; }\n '
    #/* table *
    tohtml += 'table{ margin:0;padding:0;width:1300px;table-layout:fixed;text-align:left; }\n'
    tohtml += 'table > thead > tr.tableizer-firstrow > th {  padding: 10px;  background: lavenderblush;} \n'
    #/*border: 4px solid #fff;*/ /*text-overflow: ellipsis;*/ /*overflow: hidden;*/
    tohtml += 'table > tbody > tr > td{ padding: 10px; background: #f8f8f8; word-wrap: break word; } \n'
    tohtml += 'div#footer{background:#BFBD93;}\n'
    tohtml += 'div#footer p{margin:0;padding:5px 10px}\n'
    tohtml += 'div#footer{clear:both;width:100%;text-align:center}\n'
    tohtml += 'div#main{float:right;width:1300px}\n'
    tohtml += 'a{text-decoration:none; color:#000000;}\n'
    tohtml += 'a:hover {text-decoration: underline; }\n'
    

    tohtml += '</style> </head>\n'
    tohtml += '<body>\n'
    tohtml += '<div id ="container">\n'
    tohtml += '<div id="header"><h1>BAMQC Report</h1><p text-align="left">Created On: '+datetime.now().strftime('%m-%d-%Y')+'</p></div>\n'

    i = 0    
    tohtml += '<div id="wrapper">\n'
    tohtml += '<div class="summary" id="navigation">\n'
    #tohtml += '<div id="update_time">\n'
    #tohtml += 
    #tohtml += '</div>\n'
    tohtml += '<h2>Summary</h2>\n'
    tohtml += '<ul>\n'
        
    #i = 0
    for i in range(len(smps)) :
        key = smps[i]
        tohtml += '<li>'
        tohtml += '<div class="menu-item">\n'
        tohtml += '<h4>'+key+'</h4>\n'
        tohtml += '<ul>\n'
        tohtml += '<li><a href="#M'+str(i)+'0">Basic Statistics</a></li>\n'
        tohtml += '<li><a href="#M'+str(i)+'1">Read Distribution</a></li>\n'
        tohtml += '<li><a href="#M'+str(i)+'2">Mappability</a></li>\n'
        tohtml += '<li><a href="#M'+str(i)+'3">Coverage</a></li>\n'
        tohtml += '<li><a href="#M'+str(i)+'4">Read Length and Insertion Size</a></li>\n'
        #tohtml += '<li><a href="#M'+str(i)+'6">rRNA contamination</a></li>\n'
        tohtml += '</ul></div></li>\n'   
        #i += 1 
    
    smp_corr_pos = len(smps)*5 + 100
    
    if len(res_list) > 1:
        tohtml += '<li><div class="menu-item"><a href="#M'+str(smp_corr_pos)+'"><h4>Sample Correlation</h4></a></div></li>\n'
        
    tohtml += '</ul>\n'    
    tohtml += '</div>\n'     
        
    tohtml += '<div id="main" >\n'
    for i in range(len(smps)):
        key = smps[i] 
        res = res_list[key]

        tohtml += '<h2>'+key+'</h2>\n'
        tohtml += '<div class="module"><h2 id="M'+str(i)+'0">Basic Statistics</h2>\n'
        tohtml += '<table>\n'
        tohtml += '<thead><tr class=\"tableizer-firstrow\">\n';
        tohtml += '<th>Measure</th><th>Value</th></tr></thead>\n'
        tohtml += '<tbody><tr><td>Total Reads</td><td>'+str(res.total_reads)+'</td></tr>\n'
        if not res.is_pairEnd :
            tohtml += '<tr><td>Unique Reads</td><td>'+str(res.uniq_mapped_reads)+'</td></tr>\n'
            tohtml += '<tr><td>Multi-reads</td><td>'+str(res.multi_mapped_reads)+'</td></tr>\n'
            tohtml += '<tr><td>Unmapped Reads</td><td>'+str(res.unmapped_reads)+'</td></tr>\n'
            tohtml += '<tr><td>Low Quality Reads</td><td>'+str(res.low_qual)+'</td></tr>\n'
            tohtml += '<tr><td>Forward Reads</td><td>'+str(res.forward_read)+'</td></tr>\n'
            tohtml += '<tr><td>Reverse Reads</td><td>'+str(res.reverse_read)+'</td></tr>\n'
            tohtml += '<tr><td>Splice Reads</td><td>'+str(res.splice)+'</td></tr>\n'
            tohtml += '<tr><td>Non-Splice Reads</td><td>'+str(res.noSplice)+'</td></tr>\n'
            tohtml += '<tr><td>rRNA Reads</td><td>'+str(res.rRNA_read)+'</td></tr></tbody></table></div>\n'
        else : # paired read
            tohtml += '<tr><td>Uniquely Mapped Pairs</td><td>'+str(res.paired_reads)+'</td></tr>\n'
            tohtml += '<tr><td>Uniquely Mapped Read1</td><td>'+str(res.mapped_read1)+'</td></tr>\n'
            tohtml += '<tr><td>Uiquely Mapped Read2</td><td>'+str(res.mapped_read2)+'</td></tr>\n'
            tohtml += '<tr><td>Multi-reads</td><td>'+str(res.multi_mapped_reads)+'</td></tr>\n'
            tohtml += '<tr><td>Unmapped Read1</td><td>'+str(res.unmapped_read1)+'</td></tr>\n'
            tohtml += '<tr><td>Unmapped Read2</td><td>'+str(res.unmapped_read2)+'</td></tr>\n'
            tohtml += '<tr><td>Fraction of read mapped "+,-" </td><td>'+str(res.mapped_plus_minus)+'</td></tr>\n'
            tohtml += '<tr><td>Fraction of read mapped "+,+" </td><td>'+str(res.mapped_plus_plus)+'</td></tr>\n'
            tohtml += '<tr><td>Fraction of read mapped "-,+" </td><td>'+str(res.mapped_minus_plus)+'</td></tr>\n'
            tohtml += '<tr><td>Fraction of read mapped "-,-" </td><td>'+str(res.mapped_minus_minus)+'</td></tr>\n'            
            tohtml += '<tr><td>Low Quality Read1</td><td>'+str(res.low_qual_read1)+'</td></tr>\n'
            tohtml += '<tr><td>Low Quality Read2</td><td>'+str(res.low_qual_read2)+'</td></tr>\n'
            tohtml += '<tr><td>Forward Reads</td><td>'+str(res.forward_read)+'</td></tr>\n'
            tohtml += '<tr><td>Reverse Reads</td><td>'+str(res.reverse_read)+'</td></tr>\n'
            tohtml += '<tr><td>Splice Reads</td><td>'+str(res.splice)+'</td></tr>\n'
            tohtml += '<tr><td>Non-Splice Reads</td><td>'+str(res.noSplice)+'</td></tr>\n'
            tohtml += '<tr><td>Pairs mapped to different chromosomes</td><td>'+str(res.paired_diff_chrom)+'</td></tr>\n'
            tohtml += '<tr><td>rRNA Reads</td><td>'+str(res.rRNA_read)+'</td></tr></tbody></table></div>\n'
            
        tohtml += '<div class="module"><h2 id="M'+str(i)+'1">Read Distribution</h2>\n'
        tohtml += '<p><img class="indented" src="./figs/'+os.path.basename(res.read_dist_plot_file1)+'" alt="Read Distribution"><img class="indented" src="./figs/'+os.path.basename(res.read_dist_plot_file2)+'" alt="Read Distribution"></p></div>\n'
        if res.no_clipping :
            tohtml += '<div class="module"><h2 id="M'+str(i)+'2">Mappability Profile</h2>\n'
            tohtml += '<p>There is no soft clipping. <img class="indented" src="./figs/'+os.path.basename(res.mapq_plot_file)+'" alt="MapQ Profile"> </p></div>\n'
        else :
            tohtml += '<div class="module"><h2 id="M'+str(i)+'2">Mappability</h2>\n'
            tohtml += '<p><img class="indented" src="./figs/'+os.path.basename(res.clipping_plot_file)+'" alt="Mappablity Profile"> <img class="indented" src="./figs/'+os.path.basename(res.mapq_plot_file)+'" alt="MapQ Profile"></p></div>\n'
                        
        tohtml += '<div class="module"><h2 id="M'+str(i)+'3">Coverage</h2>\n'
        tohtml += '<p><img class="indented" src="./figs/'+os.path.basename(res.read_cov_plot_file)+'" alt="Read Coverage"> <img class="indented" src="./figs/'+os.path.basename(res.trans_cov_plot_file)+'" alt="Read Coverage"></p></div>\n'
          
        if res.is_pairEnd :
                tohtml += '<div class="module"><h2 id="M'+str(i)+'4">Read Length and Insertion Size</h2>\n'
                tohtml += '<p><img class="indented" src="./figs/'+os.path.basename(res.readLen_plot_file)+'" alt="Read Length"> <img class="indented" src="./figs/'+os.path.basename(res.insert_plot_file)+'" alt="Insertion Size"></p></div>\n'
        else :
                tohtml += '<div class="module"><h2 id="M'+str(i)+'4">Read Length</h2>\n'
                tohtml += '<p><img class="indented" src="./figs/'+os.path.basename(res.readLen_plot_file)+'" alt="Read Length"></p></div>\n'
                
    if len(res_list) >1 :
        tohtml += '<div class="Smp_corr"><h2 id="M'+str(smp_corr_pos)+'">Sample Correlation and Quality</h2>\n'
        tohtml += '<p><img class="indented" src="./figs/smp_corr.png" alt="Sample Correlation"><img class="indented" src="./figs/smp_qual.png" alt="Sample Correlation"></p>\n'
        
        if res.is_pairEnd :
            tohtml += '<h2 id="M'+str(smp_corr_pos)+'">Sample Coverage and Insert size</h2>\n'
            tohtml += '<p><img class="indented" src="./figs/smp_cov.png" alt="Sample Coverage"><img class="indented" src="./figs/smp_inner.png" alt="Sample insert size"></p></div>\n'
        else :
            tohtml += '<h2 id="M'+str(smp_corr_pos)+'">Sample Coverage</h2>\n'
            tohtml += '<p><img class="indented" src="./figs/smp_cov.png" alt="Sample Coverage"></p></div>\n'
        tohtml += '<h2 id="M'+str(smp_corr_pos)+'">Sample Variation</h2>\n'
        tohtml += '<p><img class="indented" src="./figs/smp_reproducibility.png" alt="Sample Variation"><img class="indented" src="./figs/smp_var.png" alt="Sample Variation"></p></div>\n'

    tohtml += '</div>\n'

    tohtml += '<div id="footer"><p>Produced by Bioinformatics Shared Resource at CSHL (version 0.5)</p></div></div></div></body></html>\n'
    
    try :
        f = open(html_file,'w')
    
        f.write(tohtml+"\n")
        f.close()
    except :
        sys.stderr.write("Cannot generate the final report. \n")
        sys.exit(1)
    
    

def prepare_parser ():
    """ Build the command-line parser for BAMqc.

    Returns
    -------
    argparse.ArgumentParser
        Parser defining input files (-i), the GTF gene model (-r), the GTF
        attribute used for read summation (-f), an optional rRNA BED file
        (--rRNA), the output directory (-o), library strandedness
        (--stranded), the MAPQ threshold (-q), sample labels (-l) and the
        thread count (-t).

    Note: most value options use nargs='?', so passing the flag without a
    value stores None (the option's const) rather than its default; callers
    must validate that case.
    """
    desc = "Quality control of mapped NGS data (BAM/SAM files) . BAMqc version 0.5."
    # The example must only use flags defined below: input files go to -i
    # (-f selects the GTF attribute), and -r expects a GTF file.
    exmp = "Example: BAMqc -i treat1.bam treat2.bam treat3.bam -r mm9_refGene.gtf -o bamqc_out"
    parser = argparse.ArgumentParser(description = desc,epilog = exmp)
    parser.add_argument('-i', '--inputFile', metavar = 'alignment_files', dest = 'ifiles', nargs = '+', required = True,
                   help = 'Alignment files. Could be multiple SAM/BAM files separated by space. Required.')
    parser.add_argument('-r', '--refgene', metavar = 'refgene', dest='ref_gene_model', nargs = '?', type=str, required = True,help = 'refGene GTF file. Required')

    parser.add_argument('-f', metavar='attrID', dest='attrID', nargs='?', type=str, default="gene_id",
                   help='The read summation at which feature level in the GTF file. DEFAULT: gene_id.')

    parser.add_argument('--rRNA', metavar = 'rRNA', dest='rRNA_model', nargs = '?', type=str, default="",
                    help = 'rRNA BED file.')
    parser.add_argument('-o', '--outputDir', metavar = 'dir', dest='dir', nargs = '?', type=str, required = True,
                   help = 'output directory. Required.')
    parser.add_argument('--stranded', metavar='stranded', dest='stranded', nargs='?', type=str, default="yes", choices=['yes','no','reverse'],
                    help='Is this a stranded library? (yes, no, or reverse). DEFAULT: yes.')

    parser.add_argument('-q', '--mapq', metavar = 'mapq', dest='mapq', nargs = '?', default=30, type=int,
                   help = 'Minimum mapping quality (phred scaled) for an alignment to be called uniquely mapped. DEFAULT:30')
    # Default labels are generated by read_opts() as "smp" + zero-based index.
    parser.add_argument('-l','--label',metavar = 'labels', dest = 'labels', nargs = '+',
                   help = 'Labels of input files. DEFAULT:smp0 smp1 ...')
    parser.add_argument('-t', '--threads', dest='numThreads', default=1, type=int,help='Number of threads to use .DEFAULT:1')

    return parser



def read_opts(parser):
    ''' Parse the command line and validate / normalize the options.

    Parameters
    ----------
    parser : object exposing parse_args()
        Normally the ArgumentParser returned by prepare_parser().

    Returns
    -------
    argparse.Namespace
        The parsed options, with these additions:
        - numProc fixed to 1;
        - labels defaulted to "smp0", "smp1", ... when not given;
        - mapq defaulted to 30 and rRNA_model to "" when their flags were
          passed without a value;
        - logging aliases critical/error/warn/debug/info.

    Any invalid option is reported with logging.error and the process exits
    with status 1 (sys.exit).
    '''

    args = parser.parse_args()
    # There is no command-line option for process count; always one process.
    args.numProc = 1
    # Log INFO and above to stderr.  (`filemode` is only meaningful together
    # with `filename`, so it is not passed.)
    logging.basicConfig(level=20,
                        format='%(levelname)-5s @ %(asctime)s: %(message)s ',
                        datefmt='%a, %d %b %Y %H:%M:%S',
                        stream=sys.stderr
                        )

    # Labels: if given, there must be exactly one per input file.
    if args.labels is None :
        args.labels = []
    elif len(args.labels) > 0 and len(args.ifiles) != len(args.labels) :
        logging.error("Number of labels does not match with the number of samples.\n")
        sys.exit(1)

    # Every input file must exist; generate default labels "smp<i>" as we go.
    for i, ifile in enumerate(args.ifiles) :
        if not os.path.isfile(ifile) :
            logging.error("No such file: %s !\n" % (ifile))
            sys.exit(1)
        if len(args.labels) < len(args.ifiles) :
            args.labels.append("smp"+str(i))

    # Also catches `--stranded` given without a value (stored as None).
    if args.stranded not in ['yes', 'no', 'reverse'] :
        logging.error("Does not support such stranded value: %s !\n" % (args.stranded))
        sys.exit(1)

    # `-q` without a value stores None; fall back to the documented default.
    if args.mapq is None :
        args.mapq = 30

    if args.rRNA_model is not None and args.rRNA_model != "":
        if not os.path.isfile(args.rRNA_model) :
            logging.error("No such file : %s \n" %(args.rRNA_model))
            sys.exit(1)

    if args.rRNA_model is None :
        args.rRNA_model = ""

    if args.attrID is None or args.attrID == "":
        logging.error("please specify the read summation at which feature level in the GTF file\n")
        sys.exit(1)

    if args.ref_gene_model is None :
        logging.error("reference gene model is required.\n")
        sys.exit(1)
    elif not os.path.isfile(args.ref_gene_model) :
        logging.error("No such file : %s !\n" %(args.ref_gene_model))
        sys.exit(1)

    # `-o` is declared with nargs='?', so a bare `-o` yields None even though
    # the option is required; reject it here instead of failing later.
    if args.dir is None :
        logging.error("output directory is required.\n")
        sys.exit(1)

    # logging alias
    args.critical = logging.critical
    args.error   = logging.error
    args.warn    = logging.warning
    args.debug   = logging.debug
    args.info    = logging.info

    return args


if __name__ == '__main__':
    # Run the pipeline and report wall-clock time (in minutes) on stderr.
    try:
        start_time = time.time()
        main()
        end_time = time.time()
        sys.stderr.write("Elapsed time was " + str(round((end_time - start_time) / 60, 2)) + " minutes.\n")
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt !\n")
        # Exit non-zero so shells/pipelines do not mistake an aborted run for
        # a successful one; 130 = 128 + SIGINT, the usual shell convention.
        sys.exit(130)
