#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Converts anvi'o variability output to VCF
"""


import re
import sys
import argparse
import datetime

import numpy as np
import pandas as pd

from collections import defaultdict
from collections import OrderedDict

import anvio

import anvio.utils as utils
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError


__author__ = "Srinidhi Varadharajan (Center for Ecological and Evolutionary Synthesis (CEES), University of Oslo, Norway)"
__copyright__ = "Copyleft 2015-2019, the Meren Lab (http://merenlab.org/)"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
__description__ = ("A script to convert SNV output obtained from anvi-gen-variability-profile "
                   "to the standard VCF format")


progress = terminal.Progress()
run = terminal.Run()


def _build_vcf_records(input_df):
    """Turn per-sample SNV rows into sorted VCF data lines.

    Parameters
    ==========
    input_df : pandas.DataFrame
        One row per (position, sample). Must contain the columns
        `unique_pos_identifier`, `split_name`, `pos`, `sample_id`, `coverage`,
        `reference`, and `competing_nts`. It is not modified.

    Returns
    =======
    (sample_names, records) : tuple
        `sample_names` is the sorted list of unique `sample_id` values.
        `records` is a sorted list of lists, one per `unique_pos_identifier`,
        holding the nine fixed VCF columns followed by one `GT:DP` entry per
        sample, in the order of `sample_names`.
    """

    sample_names = sorted(input_df['sample_id'].unique())

    # the two characters of `competing_nts` are the two alleles of a sample. they are
    # oriented so that if the second one is the reference, it becomes the first one.
    competing_nts = input_df['competing_nts'].astype(str)
    first, second = competing_nts.str[0], competing_nts.str[1]
    second_is_reference = second == input_df['reference']
    allele1 = np.where(second_is_reference, second, first)
    allele2 = np.where(second_is_reference, first, second)

    alt_alleles = defaultdict(list)    # position -> ALT alleles in the order they were first seen
    sample_fields = defaultdict(dict)  # position -> sample -> 'GT:DP' entry
    site_info = {}                     # position -> (split_name, pos, reference)

    rows = zip(input_df['unique_pos_identifier'], input_df['split_name'], input_df['pos'],
               input_df['sample_id'], input_df['coverage'], input_df['reference'],
               allele1, allele2)

    for key, split_name, pos, sample_name, coverage, reference, a1, a2 in rows:
        alts = alt_alleles[key]

        # ALT alleles are only ever appended, so the 1-based indices handed out to
        # samples seen earlier remain valid
        for allele in (a1, a2):
            if allele != reference and allele not in alts:
                alts.append(allele)

        if a2 == reference:
            # after the orientation above this means both alleles are the reference
            genotype = '0/0'
        elif a1 == reference:
            genotype = '0/%d' % (alts.index(a2) + 1)
        else:
            genotype = '%d/%d' % (alts.index(a1) + 1, alts.index(a2) + 1)

        sample_fields[key][sample_name] = '%s:%s' % (genotype, coverage)
        site_info[key] = (split_name, pos, reference)

    records = []
    for key, (split_name, pos, reference) in site_info.items():
        # a position where no sample carries a non-reference allele has no ALT; an
        # empty column would make the line malformed, so the VCF missing value is used
        alt_column = ','.join(alt_alleles[key]) or '.'

        # QUAL (99), FILTER (PASS) and INFO (.) are constant for every record.
        # NOTE(review): `pos` is written as-is while VCF POS is 1-based -- confirm the
        # convention of the `pos` column in the variability output.
        record = [split_name, pos, key, reference, alt_column, 99, 'PASS', '.', 'GT:DP']

        # samples with no row for this position get the missing genotype
        record.extend(sample_fields[key].get(sample, './.') for sample in sample_names)
        records.append(record)

    # lists compare element by element: split name first, then position
    return sample_names, sorted(records)


def _write_vcf(output_file_path, sample_names, records):
    """Write the VCF meta lines, the column header, and one line per record."""

    header = ["#CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT"]

    with open(output_file_path, 'w') as output_file:
        output_file.write('##fileformat=VCFv4.0' + '\n')
        output_file.write('##fileDate=' + datetime.datetime.now().strftime("%Y%m%d") + '\n')
        output_file.write('##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">' + '\n')
        output_file.write('##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read Depth">' + '\n')

        # sample ids may have been parsed as numbers, hence the explicit str()
        output_file.write('\t'.join(header + [str(name) for name in sample_names]) + '\n')

        for record in records:
            output_file.write('\t'.join(str(field) for field in record) + '\n')


def main(args):
    """Convert a nucleotide-level variability profile into a VCF file.

    Parameters
    ==========
    args : argparse.Namespace
        `input` is the path to the TAB-delimited variability table. `output_file`
        is the path of the VCF to write; when it is not set, the output goes to
        the input path with `.vcf` appended.

    Raises
    ======
    ConfigError
        If the input file is missing columns this program relies on.
    """

    input_file_path = getattr(args, 'input', None)
    output_file_path = getattr(args, 'output_file', None) or '%s.vcf' % (str(input_file_path))

    filesnpaths.is_file_tab_delimited(input_file_path)
    filesnpaths.is_output_file_writable(output_file_path)

    run.info("Input file path", input_file_path)
    run.info("Output file path", output_file_path)

    # some basic checks
    fields_in_input_file = utils.get_columns_of_TAB_delim_file(input_file_path)
    fields_expected = ['A', 'C', 'G', 'N', 'unique_pos_identifier', 'corresponding_gene_call', 'competing_nts', 'pos']
    fields_missing = [f for f in fields_expected if f not in fields_in_input_file]
    if fields_missing:
        raise ConfigError("The input file %s does not look like a file generated by the program "
                          "`anvi-gen-variability-profile` with the nucleotide engine as it is missing "
                          "some of the key columns :/ Here is what seems to be missing: %s."
                          % (input_file_path, ', '.join(fields_missing)))

    if 'split_name' not in fields_in_input_file:
        raise ConfigError("While your input file looks largely OK, it is missing the split name information that is "
                          "necessary for a proper VCF output. Regenerating the same anvi-gen-variability output with "
                          "the flag `--include-split-names` will solve this issue.")

    columns_of_interest = ['unique_pos_identifier', 'split_name', 'pos', 'sample_id', 'coverage', 'reference', 'competing_nts']

    progress.new("Processing")
    progress.update('...')

    # the progress is ended even if something goes wrong, so that the error message
    # is not printed on top of a half-drawn progress line
    try:
        input_df = pd.read_table(input_file_path)

        # every column read while building the records must be there, otherwise the
        # user would get a bare KeyError instead of an explanation
        columns_missing = [c for c in columns_of_interest if c not in input_df.columns]
        if columns_missing:
            raise ConfigError("The input file %s is missing the following column(s) that are necessary "
                              "to generate a VCF output: %s." % (input_file_path, ', '.join(columns_missing)))

        sample_names, records = _build_vcf_records(input_df[columns_of_interest])
        _write_vcf(output_file_path, sample_names, records)
    finally:
        progress.end()


if __name__ == '__main__':
    # command line interface: the input SNV table, and an optional output file path
    # (the output-file argument definition comes from anvi'o itself)
    parser = argparse.ArgumentParser(description=__description__)

    parser.add_argument('-i', '--input', metavar='FILE_PATH', help='Filepath to the SNV table. This is the output from the \
                                anvi-gen-variability-profile program with the nucleotide engine (which is the default engine).')
    parser.add_argument(*anvio.A('output-file'), **anvio.K('output-file'))

    args = anvio.get_args(parser)

    # both error types are reported the same way: print the message, exit with status 1
    try:
        main(args)
    except (ConfigError, FilesNPathsError) as e:
        print(e)
        sys.exit(1)
