#!python
import os, errno, sys
import argparse as ap
import numpy as np
import csv
import subprocess
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
import pysam
from collections import defaultdict

# standardize the logging output
import logging
this_script=os.path.basename(__file__)
logging.basicConfig(format='['+this_script+'] %(levelname)s:%(message)s', level=logging.INFO)

class SmartFormatter(ap.HelpFormatter):
    def _split_lines(self, text, width):
            if text.startswith('R|'):
                return text[2:].splitlines()
            # this is the RawTextHelpFormatter._split_lines
            return ap.HelpFormatter._split_lines(self, text, width)

def setup_argparse():
    parser = ap.ArgumentParser(prog='amplicov',
        description = '''Script to parse amplicon region coverage and generate barplot in html''',
        formatter_class = SmartFormatter)

    inGrp = parser.add_argument_group('Amplicon Input (required, mutually exclusive)')
    inGrp_me = inGrp.add_mutually_exclusive_group(required=True)
    inGrp_me.add_argument('--bed', metavar='[FILE]', type=str,help='amplicon bed file (bed6 format)')
    inGrp_me.add_argument('--bedpe', metavar='[FILE]', type=str,help='amplicon bedpe file')
   
    covGrp = parser.add_argument_group('Coverage Input (required, mutually exclusive)')
    covGrp_me = covGrp.add_mutually_exclusive_group(required=True)
    covGrp_me.add_argument('--bam', metavar='[FILE]', type=str,help='sorted bam file (ex: samtools sort input.bam -o sorted.bam)')
    covGrp_me.add_argument('--cov', metavar='[FILE]', type=str,help='coverage file [position\tcoverage]')
    
    outGrp = parser.add_argument_group('Output')
    outGrp.add_argument('-o', '--outdir', metavar='[PATH]',type=str, default='.', help='output directory')
    outGrp.add_argument('-p', '--prefix', metavar='[STR]',type=str , help='output prefix')

    #optGrp = parser.add_argument_group('Options')
    parser.add_argument('--pp', action='store_true', help='process proper paired only reads from bam file (illumina)')
    parser.add_argument('--count_primer', action='store_true', help='count overlapped primer region to unqiue coverage')
    parser.add_argument('--mincov', metavar='[INT]', type=int, help='minimum coverage to count as ambiguous N site [default:10]', default=10)
    parser.add_argument('-r', '--refID', metavar='[STR]',type=str , help='reference accession (bed file first field)')
    
    parser.add_argument('--depth_lines', default=[5,10,20,50], type=int, nargs='+', help='Add option to display lines at these depths (provide depths as a list of integers) [default:5 10 20 50]')
    parser.add_argument('--gff', metavar='[FILE]', type=str, help='gff file for data hover info annotation')
    parser.add_argument('--version', action='version', version='%(prog)s 0.3.3')
    args_parsed = parser.parse_args()
    if not args_parsed.outdir:
        args_parsed.outdir = os.getcwd()

    return args_parsed

def mkdir_p(directory_name):
    try:
        os.makedirs(directory_name)
    except OSError as exc: 
        if exc.errno == errno.EEXIST and os.path.isdir(directory_name):
            pass



def is_bed6(input):
    with open(input,'r') as f:
        bedline=[]
        num=0
        for line in f:
            bedline = line.rstrip().split("\t")
            num += 1
            if (len(bedline) != 6 or not bedline[1].isnumeric() or not bedline[2].isnumeric()):
                logging.error(f"The input bed file is not in bed6 format.\nline:{num} {line}")
                sys.exit(1)

def convert_bed_to_amplicon_dict(input,cov_array,RefID="",unique=False, count_primer=False):
    ## convert bed file to amplicon region dictionary
    input_bed = input
    cov_zero_array = np.zeros_like(cov_array)
    amplicon=defaultdict(dict)
    primers_pos=list()
    RefID = '""' if RefID is None else RefID
    cmd = 'grep -v alt %s | grep %s | paste - - | cut -f 2,3,4,8,9 | sed -e "s/_LEFT//g" -e "s/_RIGHT//g" ' % (input_bed, RefID)
    proc = subprocess.Popen(cmd,shell=True,stdout=subprocess.PIPE) 
    previous_id=''
    outs, errs = proc.communicate()
    if proc.returncode == 0:
        outs = outs.splitlines()
    else:
        logging.error("Failed %d %s %s" % (proc.returncode, outs, errs))
   
    for i in range(len(outs)):
        fstart, fend, id, rstart, rend = outs[i].decode().rstrip().split("\t")
        if id in amplicon:
            continue
        amplicon[id]=range(int(fend),int(rstart))   
        primers_pos.extend(range(int(fstart),int(fend)))
        primers_pos.extend(range(int(rstart),int(rend)))
        for i in range(int(fend),int(rstart)):
            cov_zero_array[i] += 1

    if unique:
        unique_region_set = set(np.where(cov_zero_array == 1)[0]) if count_primer else set(np.where(cov_zero_array == 1)[0]) - set(primers_pos)
        for i in range(len(outs)):
            fstart, fend, id, rstart, rend = outs[i].decode().rstrip().split("\t")
            unique_region = sorted(unique_region_set.intersection(set(range(int(fstart),int(rend)))))
            if unique_region:
                amplicon[id] = range(list(unique_region)[0],list(unique_region)[-1]+1)
            else:
                amplicon[id] = range(0)
         
    return amplicon

def convert_bedpe_to_amplicon_dict(input,cov_array,RefID="",unique=False, count_primer=False):
    ## convert bed file to amplicon region dictionary 
    input_bedpe = input
    cov_zero_array = np.zeros_like(cov_array)
    amplicon=defaultdict(dict)
    primers_pos=list()
    RefID = '""' if RefID is None else RefID
    cmd = 'grep %s %s | cut -f 2,3,5,6,7 ' % (RefID, input_bedpe)
    proc = subprocess.Popen(cmd,shell=True,stdout=subprocess.PIPE) 
    previous_id=''
    outs, errs = proc.communicate()
    if proc.returncode == 0:
        outs = outs.splitlines()
    else:
        logging.error("Failed %d %s %s" % (proc.returncode, outs, errs))

    for i in range(len(outs)):
        fstart, fend, rstart, rend, id = outs[i].decode().rstrip().split("\t")
        if id in amplicon:
            continue
        amplicon[id]=range(int(fend),int(rstart))   
        primers_pos.extend(range(int(fstart),int(fend)))
        primers_pos.extend(range(int(rstart),int(rend)))
        for i in range(int(fend),int(rstart)):
            cov_zero_array[i] += 1
    
    if unique:
        unique_region_set = set(np.where(cov_zero_array == 1)[0]) if count_primer else set(np.where(cov_zero_array == 1)[0]) - set(primers_pos)
        for i in range(len(outs)):
            fstart, fend, rstart, rend, id = outs[i].decode().rstrip().split("\t")
            unique_region = sorted(unique_region_set.intersection(set(range(int(fstart),int(rend)))))
            if unique_region:
                amplicon[id] = range(list(unique_region)[0],list(unique_region)[-1]+1)
            else:
                amplicon[id] = range(0)

    return amplicon

def parse_gff_file(input,RefID):
    gff_file = input
    anno_dict = defaultdict(dict)
    f = open(gff_file,'r')
    for line in f:
        if line.startswith("#"):
            continue
        else:
            gffline = line.rstrip().split("\t")
            gff_ref_id = gffline[0].replace(".","_")
            if len(gffline) > 2 and gffline[2] == "CDS":
                if RefID and (RefID != gff_ref_id and RefID != gffline[0]):
                    continue
                annotations = dict(x.split("=") for x in gffline[8].split(";"))
                for i in range(int(gffline[3]),int(gffline[4])+1):
                    anno_dict[i]['name']= annotations['Name'] if 'Name' in annotations else None
                    anno_dict[i]['locus_tag']= annotations['locus_tag'] if 'locus_tag' in annotations else None
                    anno_dict[i]['product']= annotations['product'] if 'product' in annotations else None
                    anno_dict[i]['protein_id']= annotations['protein_id'] if 'protein_id' in annotations else None
    f.close()
    return anno_dict

def parse_cov_file(input):
    ## read the coverage and save in a list
    cov_file = input
    cov_list = []
    with open(cov_file,'r') as cov_file:
        for line in cov_file:
            pos, cov = line.rstrip().split("\t")
            cov_list.append(int(cov))

    cov_array=np.array(cov_list)
    return cov_array

def parse_bam_file(bam,pp,outdir,RefID):

    bam_input = bam
    if pp:
        bam_prefix = Path(bam).stem
        bam_input = outdir + os.sep + bam_prefix + ".pp.bam"
        samfile = pysam.AlignmentFile(bam, "rb")
        ppcount=0;
        allcount=0;
        proper_pairedreads = pysam.AlignmentFile(bam_input, "wb", template=samfile)
        for read in samfile.fetch(until_eof=True):
            allcount+=1
            if read.is_proper_pair:
                ppcount+=1;
                proper_pairedreads.write(read)

        proper_pairedreads.close()
        samfile.close()
        logging.info("Total Reads: %d" % (allcount))
        logging.info("Proper Paired Reads: %d" % (ppcount))
        if ppcount==0:
            logging.error("Failed: There is no proper paired reads in your input file %s." % (bam))
            sys.exit()
        
    cov_list = []
    for line in pysam.samtools.depth("-aa","-d0", bam_input ,split_lines=True):
        id, pos, cov = line.rstrip().split("\t")
        if RefID and RefID == id:
            cov_list.append(int(cov))
        else:
            cov_list.append(int(cov))
    cov_array=np.array(cov_list)
    return cov_array

def calculate_mean_cov_per_amplicon(cov_np_array,amplicon_d, anno_d):
    mean_dict=defaultdict(dict)
    for key in amplicon_d:
        if amplicon_d[key]:
            start = amplicon_d[key][0]
            end = amplicon_d[key][-1]
            #print(key,start,end)
            mean_dict[key]['cov'] = 0 if start > cov_np_array.size else cov_np_array[start:end+1].mean()
            mean_dict[key]['range'] = f"{start + 1} - {end + 1}"
            if anno_d:
                for i in range(start, end + 1):
                    if anno_d[i]:
                        mean_dict[key]['product'] = anno_d[i]['product']
                        mean_dict[key]['name'] = anno_d[i]['name']
                        mean_dict[key]['locus_tag'] = anno_d[i]['locus_tag']
                        mean_dict[key]['protein_id'] = anno_d[i]['protein_id']
        else:
            # not have any range in this key region
            mean_dict[key]['cov'] = 0 
            
    
        
    return mean_dict

def calculate_N_per_amplicon(cov_np_array,amplicon_d, mincov):
    num_low_depth=dict()
    for key in amplicon_d:
        if amplicon_d[key]:
            start = amplicon_d[key][0]
            end = amplicon_d[key][-1]
            if start > cov_np_array.size:
                num_low_depth[key] = 0
            else:
                arr = cov_np_array[start:end+1]
                num_low_depth[key] = np.size(np.where(arr < mincov))
        else:
            # not have any range in this key region
            num_low_depth[key] = 0
    return num_low_depth
    
def write_dict_to_file(mean_d,uniq_mean_d,ambiguity_d,uniq_ambiguity_d,outdir,prefix,mincov):
    output_amplicon_cov_txt = outdir + os.sep + prefix + '_amplicon_coverage.txt'
    try:
        os.remove(output_amplicon_cov_txt)
    except OSError:
        pass
    with open(output_amplicon_cov_txt,"w") as f:
        f.write(f"ID\tWhole_Amplicon\tUnique_Amplicon\tWhole_Amplicon_Ns(cov<{mincov})\tUnique_Amplicon_Ns(cov<{mincov})\n")
        for key in mean_d:
            f.write("%s\t%.2f\t%.2f\t%.2f\t%.2f\n" %
              (
                key,
                mean_d[key]['cov'],uniq_mean_d[key]['cov'],
                ambiguity_d[key],uniq_ambiguity_d[key]
              )
            )
      

def barplot(mean_dict,uniq_mean_d,ambiguity_d,uniq_ambiguity_d,input_bed,overall_mean,outdir,prefix,mincov,depth_lines=[5,10,20,50]):
    #plot bar chart
    x_name=Path(input_bed).stem
    x=list(mean_dict.keys())
    y=[ mean_dict[k]['cov'] for k in x ]
    anno_list=list(mean_dict.keys())
    for i in range(0,len(x)):
        k=x[i]
        if 'product' in mean_dict[k]:
            anno_list[i] = "Product: " +  mean_dict[k]['product'] + "<br>"
        if 'name' in mean_dict[k]:
            anno_list[i] += "name: " +  mean_dict[k]['name'] + "<br>"
        if 'protein_id' in mean_dict[k]:
            anno_list[i] += "protein_id: " +  mean_dict[k]['protein_id'] + "<br>"
        if 'locus_tag' in mean_dict[k]:
            anno_list[i] += "locus_tag: " +  mean_dict[k]['locus_tag'] + "<br>"
        if 'range' in mean_dict[k]:
            anno_list[i] += "range: " +  mean_dict[k]['range'] + "<br>"
    uniq_x=list(uniq_mean_d.keys())
    uniq_y=[ uniq_mean_d[k]['cov'] for k in uniq_x ]
    uniq_anno_list=list(uniq_mean_d.keys())
    for i in range(0,len(uniq_x)):
        k=uniq_x[i]
        if 'product' in uniq_mean_d[k]:
            uniq_anno_list[i] = "Product: " +  uniq_mean_d[k]['product'] + "<br>"
        if 'name' in uniq_mean_d[k]:
            uniq_anno_list[i] += "name: " +  uniq_mean_d[k]['name'] + "<br>"
        if 'protein_id' in uniq_mean_d[k]:
            uniq_anno_list[i] += "protein_id: " +  uniq_mean_d[k]['protein_id'] + "<br>"
        if 'locus_tag' in uniq_mean_d[k]:
            uniq_anno_list[i] += "locus_tag: " +  uniq_mean_d[k]['locus_tag'] + "<br>"
        if 'range' in uniq_mean_d[k]:
            uniq_anno_list[i] += "range: " +  uniq_mean_d[k]['range'] + "<br>"
    barcolor1 = ['lightsalmon' if i >= 20 else 'blue' if i >5 else 'black' for i in y]
    barcolor2 = ['lightsalmon' if i >= 20 else 'blue' if i >5 else 'black' for i in uniq_y]
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # coverage levels
    if anno_list == x :
        fig.add_trace(
          go.Bar(x=uniq_x,y=uniq_y,marker_color=barcolor2,visible=True,name="Coverage",showlegend=False, 
             hovertemplate="Coverage: %{y:.1f}" + "<extra></extra>"),
          secondary_y=False
        )
        fig.add_trace(
          go.Bar(x=x, y=y, marker_color=barcolor1,visible=False,name="Coverage",showlegend=False,
             hovertemplate="Coverage: %{y:.1f}" + "<extra></extra>"),
          secondary_y=False
        )
    else:
        fig.add_trace(
          go.Bar(x=uniq_x,y=uniq_y,marker_color=barcolor2,visible=True,name="Coverage",showlegend=False,  text=uniq_anno_list,
                 hovertemplate="Coverage: %{y:.1f}<br>" + "%{text}" + "<extra></extra>"),
          secondary_y=False
        )
        fig.add_trace(
          go.Bar(x=x, y=y, marker_color=barcolor1,visible=False,name="Coverage",showlegend=False, text=anno_list,
                 hovertemplate="Coverage: %{y:.1f}<br>" + "%{text}" + "<extra></extra>"),
          secondary_y=False
        )
    #fig = go.Figure(data=[go.Bar(x=x, y=y, marker_color=barcolor1,visible=False)])

    #fig.add_trace(go.Bar(x=uniq_x,y=uniq_y,marker_color=barcolor2,visible=True))

    # Add Ns
    y2=list(ambiguity_d.values())
    uniq_y2=list(uniq_ambiguity_d.values())
    fig.add_trace(
      go.Scatter(x=uniq_x,y=uniq_y2,marker_color='blueviolet',visible=True,name="Num of Ns (cov<" + str(mincov) +")", hovertemplate="Num of Ns: %{y}<extra></extra>" ),
      secondary_y=True
    )
    fig.add_trace(
      go.Scatter(x=x,y=y2,marker_color='blueviolet',visible=False,name="Num of Ns (cov<" + str(mincov) +")", hovertemplate="Num of Ns: %{y}<extra></extra>" ),
      secondary_y=True
    )
    
    depthMean = [dict(type='line',
                    xref='paper',x0=0,x1=0.94,
                    yref='y',y0=overall_mean,y1=overall_mean,
                    line=dict(
                        color="red",
                        width=1,
                        dash='dot',
                    )
                )]

    depth_buttons = list([
        dict(label="Mean(" + str(int(overall_mean)) + 'X)',
             method="relayout",
             args=["shapes", depthMean])])

    depth_selectors = depthMean[:]
    for depth in depth_lines:
        name = '{}X'.format(depth)
        depth_selector = [dict(type='line',
                               xref='paper', x0=0, x1=0.94,
                               yref='y', y0=depth, y1=depth,
                               line=dict(
                                   color="black",
                                   width=1,
                                   dash='dot',
                               ))]
        depth_selectors.extend(depth_selector)
        depth_buttons.append(dict(label=name,
                                  method="relayout",
                                  args=["shapes", depth_selector]))
    depth_buttons.append(
        dict(label="All", method="relayout", args=["shapes", depth_selectors])
    )
    updatemenus = list([
        dict(
             buttons=list([
                dict(label='Linear Scale',
                     method='relayout',
                     args=[{'title': '',
                           'yaxis.type': 'linear'}]),
                dict(label='Log Scale',
                     method='relayout',
                     args=[{'title': '',
                            'yaxis.type': 'log'}])
                ]),
             direction="down",
             x=0,
             xanchor="left",
             y=1.03,
             yanchor="top"
        ),
        dict(
            buttons=list([
                dict(label='Unique',
                     method='update',
                     args=[{"visible": [True, False, True, False]}]),
                dict(label='Whole Amplicon',
                     method='update',
                     args=[{"visible": [False, True, False, True]}])
                ]),
             direction="down",
             x=0.15,
             xanchor="left",
             y=1.03,
             yanchor="top"
        ),
        dict(
            buttons=depth_buttons,
             direction="down",
             x=0.30,
             xanchor="left",
             y=1.03,
             yanchor="top",
             showactive=True
        )
    ])

    fig.update_xaxes(
        tickfont=dict(family='Courier New, monospace', size=8),
        ticks="outside",
        tickwidth=0.5,
        ticklen=3
    )
    fig.update_yaxes(
        tickfont=dict(family='Courier New, monospace', size=8),
        ticks="outside",
        tickwidth=0.5,
        ticklen=3
    )
    fig.update_yaxes(
        tickfont=dict(family='Courier New, monospace', size=8, color="blueviolet"),
        ticks="outside",
        tickwidth=0.5,
        ticklen=3,
        secondary_y=True,
        title_text="Num of Ns (cov<" + str(mincov) +")",
        titlefont=dict(
            color="blueviolet"
        ),
    )
    fig.update_layout(
        updatemenus=updatemenus,
        xaxis_title=x_name,
        yaxis_title="Mean Coverage(X)",
        font=dict(
            family="Courier New, monospace",
            size=10,
        ),
        annotations=[
            dict(text="Y-axis:",x=0, xref="paper",
                                y=1.06, yref="paper", 
                                showarrow=False),
            dict(text="Region:",x=0.15, xref="paper",
                                y=1.06, yref="paper", 
                                showarrow=False),
            dict(text="DepthLine:", x=0.30, xref="paper",
                                y=1.06, yref="paper", 
                                showarrow=False),
        ],
        shapes=[
            dict(type='line',
            xref='paper',x0=0,x1=0.94,
            yref='y',y0=overall_mean,y1=overall_mean,
            line=dict(
                color="red",
                width=1,
                dash='dashdot',
            ),
            ),
        ],
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=0.95
        ),
        hovermode="x unified"
    )

    output_html = outdir + os.sep + prefix + '_amplicon_coverage.html'
    fig.write_html(output_html)

def run(argvs):
    
    if (argvs.cov):
        cov_array = parse_cov_file(argvs.cov)
        prefix = Path(argvs.cov).stem
    if (argvs.bam):
        cov_array = parse_bam_file(argvs.bam,argvs.pp, argvs.outdir, argvs.refID)
        prefix = Path(argvs.bam).stem 
    if (argvs.bed):
        is_bed6(argvs.bed)
        amplicon_dict = convert_bed_to_amplicon_dict(argvs.bed,cov_array,argvs.refID)
        uniq_amplicon_dict = convert_bed_to_amplicon_dict(argvs.bed,cov_array,argvs.refID,True,argvs.count_primer)
        bedfile = argvs.bed
    if (argvs.bedpe):
        amplicon_dict = convert_bedpe_to_amplicon_dict(argvs.bedpe,cov_array,argvs.refID)
        uniq_amplicon_dict = convert_bedpe_to_amplicon_dict(argvs.bedpe,cov_array,argvs.refID,True,argvs.count_primer)
        bedfile = argvs.bedpe
    
    anno_dict = parse_gff_file(argvs.gff,argvs.refID) if (argvs.gff) else None
    
    if not argvs.prefix:
        argvs.prefix = prefix

    # Count the number of Ns < argvs.mincov per amplicon
    ambiguity_d = calculate_N_per_amplicon(cov_array,amplicon_dict, argvs.mincov)
    uniq_ambiguity_d = calculate_N_per_amplicon(cov_array,uniq_amplicon_dict,argvs.mincov)

    # Get the coverage per amplicon
    amplicon_mean_d = calculate_mean_cov_per_amplicon(cov_array,amplicon_dict,anno_dict)
    uniq_amplicon_mean_d = calculate_mean_cov_per_amplicon(cov_array,uniq_amplicon_dict,anno_dict)

    # Write results to a tab-delimited file
    write_dict_to_file(amplicon_mean_d,uniq_amplicon_mean_d,ambiguity_d,uniq_ambiguity_d,argvs.outdir,argvs.prefix,argvs.mincov)

    barplot(amplicon_mean_d,uniq_amplicon_mean_d,ambiguity_d,uniq_ambiguity_d,bedfile,cov_array.mean(),argvs.outdir,argvs.prefix,argvs.mincov,argvs.depth_lines)

if __name__ == '__main__':
    argvs = setup_argparse()
    mkdir_p(argvs.outdir)
    run(argvs)
