#!/usr/bin/env python3
# Copyright ©2001-2019 Python Software Foundation

# Author: Jasper Boom

# Prerequisites:
# - conda install -c bioconda vsearch=2.14.1

# Imports
import argparse
import shutil
import string
import sys
import random
import re
import os
import zipfile
import pandas as pd
import subprocess as sp

# The removeWorkDirs function.
# This function removes all temporary working directories.
def removeWorkDirs(strTempDir):
    """Delete the temporary working directory and everything below it.

    Args:
        strTempDir: path of the directory to remove. Errors (e.g. a missing
            directory) are not ignored and propagate to the caller.
    """
    shutil.rmtree(strTempDir, ignore_errors=False)

# The setOutputFiles function.
# This function creates the tabular and BLAST output files.
def setOutputFiles(strDir, flBlast, flTabular):
    """Write the tabular overview and the BLAST FASTA file of all centroids.

    Every file in ``<strDir>/postClustering`` is expected to be named
    ``<umi number>_<umi sequence>.fasta`` and to contain two-line FASTA
    records (header line, sequence line). A file of exactly two lines yields
    one row whose ID is the UMI number. A longer file yields one row per
    record with IDs ``<umi number>.1``, ``<umi number>.2``, .... Files with
    fewer than two lines are skipped.

    Args:
        strDir: temporary working directory created by setWorkDirs.
        flBlast: path of the FASTA output (one record per centroid, upper
            case). The file is rewritten, so content from an earlier run is
            not kept.
        flTabular: path of the tab-separated output, indexed by "UMI ID",
            with the columns "UMI SEQ", "Read Count" and "Centroid Read".
    """
    strClusterDir = strDir + "/postClustering/"
    lstRows = []
    # Open the BLAST file once instead of once per centroid.
    with open(flBlast, "w") as flOutput:
        for strFileName in os.listdir(strDir + "/postClustering"):
            strUmiNumber = strFileName.split("_")[0]
            # Strip the ".fasta" extension (6 characters) from the UMI part.
            strUmiString = strFileName.split("_")[1][:-6]
            with open(strClusterDir + strFileName) as oisUmiFile:
                lstLines = oisUmiFile.readlines()
            if len(lstLines) < 2:
                continue
            # Two lines means a single centroid; more lines means several
            # centroids, which are numbered as versions of the same UMI.
            bolVersioned = len(lstLines) > 2
            intVersionCount = 1
            itLines = iter(lstLines)
            for strLine in itLines:
                if not strLine.startswith(">"):
                    continue
                strReadCount = _getReadCount(strLine)
                strRead = next(itLines, "").strip("\n").upper()
                if bolVersioned:
                    strUmiId = strUmiNumber + "." + str(intVersionCount)
                    intVersionCount += 1
                else:
                    strUmiId = strUmiNumber
                lstRows.append([strUmiId, strUmiString, strReadCount,
                                strRead])
                flOutput.write(">" + strUmiId + "\n")
                flOutput.write(strRead + "\n")
    # Build the frame once; growing it row by row with .loc is quadratic.
    dfOutput = pd.DataFrame(lstRows, columns=["UMI ID", "UMI SEQ",
                                              "Read Count", "Centroid Read"])
    dfOutput = dfOutput.set_index("UMI ID")
    dfOutput.to_csv(flTabular, sep="\t", encoding="utf-8")


def _getReadCount(strHeader):
    """Return the abundance stored in a centroid header as a string.

    The value of the last ``size=N`` annotation is used, so other "="
    characters in the label do not matter. Without such an annotation the
    text after the first "=" is returned (this raises IndexError when the
    header contains no "=" at all).
    """
    lstSizes = re.findall(r"size=(\d+)", strHeader)
    if lstSizes:
        return lstSizes[-1]
    return strHeader.split("=")[1].strip("\n")

# The setClusterSize function.
# This function controls the VSEARCH clustering. Every fasta file created by
# setSortBySize is clustered using VSEARCH. The expected result is a single
# centroid sequence. This is checked in the setOutputFiles function.
def setClusterSize(strDir, strIdentity):
    """Cluster every sorted per-UMI FASTA file with VSEARCH --cluster_size.

    Input files are the names in ``<strDir>/clustering`` starting with
    "sorted". The centroids are written to ``<strDir>/postClustering`` under
    the same name minus its first 11 characters (the "sorted" + "derep"
    prefixes added by the earlier steps).

    Args:
        strDir: temporary working directory created by setWorkDirs.
        strIdentity: identity threshold, passed to VSEARCH ``--id`` as-is.

    A non-zero VSEARCH exit status is reported on stderr together with
    VSEARCH's error output; the remaining files are still processed.
    """
    for strFileName in os.listdir(strDir + "/clustering"):
        if not strFileName.startswith("sorted"):
            continue
        strInputCommand = strDir + "/clustering/" + strFileName
        # Drop the "sorted" (6) + "derep" (5) prefixes.
        strOutputCommand = strDir + "/postClustering/" + strFileName[11:]
        rafClustering = sp.run(["vsearch", "--cluster_size", strInputCommand,
                                "--fasta_width", "0", "--id", strIdentity,
                                "--sizein", "--minseqlength", "1",
                                "--centroids", strOutputCommand,
                                "--sizeout"], stdout=sp.PIPE, stderr=sp.PIPE)
        if rafClustering.returncode != 0:
            sys.stderr.write("VSEARCH --cluster_size failed for "
                             + strFileName + " (exit status "
                             + str(rafClustering.returncode) + "):\n"
                             + rafClustering.stderr.decode("utf-8",
                                                           "replace"))

# The setSortBySize function.
# This function controls the VSEARCH sorting. Every fasta file created by
# setDereplication is sorted based on abundance. Any reads with an abundance
# lower than strAbundance will be discarded.
def setSortBySize(strDir, strAbundance):
    """Sort every dereplicated per-UMI FASTA file by abundance with VSEARCH.

    Input files are the names in ``<strDir>/clustering`` starting with
    "derep". Each sorted result is written next to its input as
    ``sorted<name>``. Reads with an abundance below ``strAbundance``
    (VSEARCH ``--minsize``) are left out.

    Args:
        strDir: temporary working directory created by setWorkDirs.
        strAbundance: minimum abundance, passed to VSEARCH as-is.

    A non-zero VSEARCH exit status is reported on stderr together with
    VSEARCH's error output; the remaining files are still processed.
    """
    strClusterDir = strDir + "/clustering/"
    # os.listdir returns a list up front, so the "sorted*" files written
    # into the same directory below are not picked up by this loop.
    for strFileName in os.listdir(strDir + "/clustering"):
        if not strFileName.startswith("derep"):
            continue
        strInputCommand = strClusterDir + strFileName
        strOutputCommand = strClusterDir + "sorted" + strFileName
        rafSort = sp.run(["vsearch", "--sortbysize", strInputCommand,
                          "--output", strOutputCommand, "--minseqlength",
                          "1", "--minsize", strAbundance],
                         stdout=sp.PIPE, stderr=sp.PIPE)
        if rafSort.returncode != 0:
            sys.stderr.write("VSEARCH --sortbysize failed for " + strFileName
                             + " (exit status " + str(rafSort.returncode)
                             + "):\n"
                             + rafSort.stderr.decode("utf-8", "replace"))

# The setDereplication function.
# This function controls the VSEARCH dereplication. Every fasta file created by
# getFastaFile is dereplicated. This step is necessary for the sorting step to
# work.
def setDereplication(strDir):
    """Dereplicate every per-UMI FASTA file with VSEARCH --derep_fulllength.

    Input files are the ``*.fasta`` names in ``<strDir>/preZip``. Each
    result is written to ``<strDir>/clustering/derep<name>`` with size
    annotations (``--sizeout``), which the sorting step relies on.

    Args:
        strDir: temporary working directory created by setWorkDirs.

    A non-zero VSEARCH exit status is reported on stderr together with
    VSEARCH's error output; the remaining files are still processed.
    """
    for strFileName in os.listdir(strDir + "/preZip"):
        if not strFileName.endswith(".fasta"):
            continue
        strInputCommand = strDir + "/preZip/" + strFileName
        strOutputCommand = strDir + "/clustering/" + "derep" + strFileName
        rafDerep = sp.run(["vsearch", "--derep_fulllength",
                           strInputCommand, "--output", strOutputCommand,
                           "--minseqlength", "1", "--sizeout"],
                          stdout=sp.PIPE, stderr=sp.PIPE)
        if rafDerep.returncode != 0:
            sys.stderr.write("VSEARCH --derep_fulllength failed for "
                             + strFileName + " (exit status "
                             + str(rafDerep.returncode) + "):\n"
                             + rafDerep.stderr.decode("utf-8", "replace"))

# The getZipArchive function.
# This function creates a zip archive from all files in the specified directory.
def getZipArchive(flZip, strDir):
    """Bundle the ``*.fasta`` files of ``<strDir>/preZip`` into a zip file.

    Args:
        flZip: path of the archive to create; an existing file is
            overwritten (zip mode "w").
        strDir: temporary working directory created by setWorkDirs.

    Files are stored under their bare file name, without directories.
    """
    strSource = strDir + "/preZip"
    with zipfile.ZipFile(flZip, "w") as objZip:
        lstFasta = [strName for strName in os.listdir(strSource)
                    if strName.endswith(".fasta")]
        for strName in lstFasta:
            objZip.write(strSource + "/" + strName, arcname=strName)

# The getFastaFile function.
# This function creates separate fasta files for every unique UMI. The function
# creates a unique name for every UMI file and combines that with the desired
# output path. A file is opened or created based on this combination. The
# read header and the read itself are appended to it.
def getFastaFile(strDir, dicUniqueUmi, strHeader, strRead, strCode):
    """Append one read to the FASTA file of its UMI.

    The target is ``<strDir>/preZip/UMI#<number>_<strCode>.fasta``, where
    ``<number>`` is ``dicUniqueUmi[strCode]``. The file is created on first
    use and appended to afterwards.

    Args:
        strDir: temporary working directory created by setWorkDirs.
        dicUniqueUmi: mapping from UMI code to its running number.
        strHeader: header line, written verbatim (including its newline).
        strRead: read line, written verbatim (including its newline).
        strCode: the UMI code of this read.
    """
    strTarget = "".join([strDir, "/preZip/", "UMI#",
                         str(dicUniqueUmi[strCode]), "_", strCode, ".fasta"])
    with open(strTarget, "a") as flOutput:
        flOutput.write(strHeader + strRead)

# The useZeroPosition function.
# This function will isolate either a 5'-end, 3'-end or double UMI based on
# the starting or ending position of a read.
# It will check if both the forward and reverse primer can be found. If this
# check is passed, the 5'-end, 3'-end UMI or double UMI will be isolated by
# adding or subtracting the UMI length from the first or last position of the
# read. The function will return (if possible) the UMI nucleotides.
def useZeroPosition(strSearch, intUmiLength, strRead, strForward, strReverse):
    """Cut the UMI(s) directly from the start and/or end of the read.

    Both ``strForward`` and ``strReverse`` (regex strings) must occur in the
    read; otherwise nothing is returned.

    Args:
        strSearch: "umi5" (first nucleotides), "umi3" (last nucleotides) or
            "umidouble" (both).
        intUmiLength: UMI length (int or numeric string).
        strRead: the read nucleotides.
        strForward: regex of the forward sequence.
        strReverse: regex of the reverse sequence.

    Returns:
        The UMI string, a (5'-UMI, 3'-UMI) tuple for "umidouble", or None.
    """
    # Guard clauses: both sequences have to be present.
    if re.search(strForward, strRead) is None:
        return None
    if re.search(strReverse, strRead) is None:
        return None
    if strSearch == "umi5":
        return strRead[0:int(intUmiLength)]
    if strSearch == "umidouble":
        intLength = int(intUmiLength)
        return strRead[0:intLength], strRead[-intLength:]
    if strSearch == "umi3":
        return strRead[-int(intUmiLength):]
    return None

# The useAdapter function.
# This function searches for a regex string in the provided read. It will
# isolate either a 5'-end, 3'-end or double UMI. The isolation is based on
# this read structure:
#     ADAPTER(F)-UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3')-ADAPTER(R).
# When looking for the 5'-end UMI, the last position of ADAPTER(F) is used,
# when looking for the 3'-end UMI, the first position of ADAPTER(R) is used,
# when looking for the double UMI, both mentioned positions are used.
# These positions plus or minus the UMI length result in the UMI nucleotides.
# In the case of umi5 or umi3, a check needs to be passed. This check makes
# sure the opposite adapters are also present, otherwise no UMI is returned.
# The function will return (if possible) the UMI nucleotides.
def useAdapter(strSearch, intUmiLength, strRead, strForward, strReverse):
    """Isolate the UMI(s) that sit between the adapters and the primers.

    Expected read structure:
        ADAPTER(F)-UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3')-ADAPTER(R).
    The 5'-end UMI is the ``intUmiLength`` nucleotides directly after the
    first match of ``strForward``. The 3'-end UMI is the ``intUmiLength``
    nucleotides directly before the first match of ``strReverse``. Both
    adapters must be found for every search mode.

    Args:
        strSearch: "umi5", "umi3" or "umidouble".
        intUmiLength: length of a single UMI (int or numeric string).
        strRead: the read nucleotides.
        strForward: regex of the forward adapter.
        strReverse: regex of the reverse adapter as it occurs in the read.

    Returns:
        The UMI string for "umi5"/"umi3", a (5'-UMI, 3'-UMI) tuple for
        "umidouble", or None when an adapter is missing, a UMI window does
        not fit completely inside the read, or ``strSearch`` is unknown.
    """
    if strSearch not in ("umi5", "umi3", "umidouble"):
        return None
    intLength = int(intUmiLength)
    objForward = re.search(strForward, strRead)
    objReverse = re.search(strReverse, strRead)
    if objForward is None or objReverse is None:
        return None
    # 5'-end UMI: directly after ADAPTER(F). A window running past the end
    # of the read would only yield a truncated UMI, so reject it.
    intForwardStart = objForward.end()
    intForwardEnd = intForwardStart + intLength
    strUmiForward = None
    if intForwardEnd <= len(strRead):
        strUmiForward = strRead[intForwardStart:intForwardEnd]
    # 3'-end UMI: directly before ADAPTER(R). A negative start index would
    # silently wrap around to the end of the read, so reject it.
    intReverseEnd = objReverse.start()
    intReverseStart = intReverseEnd - intLength
    strUmiReverse = None
    if intReverseStart >= 0:
        strUmiReverse = strRead[intReverseStart:intReverseEnd]
    if strSearch == "umi5":
        return strUmiForward
    if strSearch == "umi3":
        return strUmiReverse
    if strUmiForward is None or strUmiReverse is None:
        return None
    return strUmiForward, strUmiReverse

# The usePrimer function.
# This function searches for a regex string in the provided read. It will
# isolate either a 5'-end, 3'-end or double UMI. The isolation is based on
# this read structure:
#     UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3').
# When looking for the 5'-end UMI, the first position of PRIMER(F) is used,
# when looking for the 3'-end UMI, the last position of PRIMER(R) is used,
# when looking for the double UMI, both mentioned positions are used.
# These positions plus or minus the UMI length result in the UMI nucleotides.
# In the case of umi5 or umi3, a check needs to be passed. This check makes
# sure the opposite primer is also present, otherwise no UMI is returned.
# The function will return (if possible) the UMI nucleotides.
def usePrimer(strSearch, intUmiLength, strRead, strForward, strReverse):
    """Isolate the UMI(s) that flank the primers of a read.

    Expected read structure: UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3').
    The 5'-end UMI is the ``intUmiLength`` nucleotides directly before the
    first match of ``strForward``. The 3'-end UMI is the ``intUmiLength``
    nucleotides directly after the first match of ``strReverse``. Both
    primers must be found for every search mode.

    Args:
        strSearch: "umi5", "umi3" or "umidouble".
        intUmiLength: length of a single UMI (int or numeric string).
        strRead: the read nucleotides.
        strForward: regex of the forward primer.
        strReverse: regex of the reverse primer as it occurs in the read.

    Returns:
        The UMI string for "umi5"/"umi3", a (5'-UMI, 3'-UMI) tuple for
        "umidouble", or None when a primer is missing, a UMI window does not
        fit completely inside the read, or ``strSearch`` is unknown.
    """
    if strSearch not in ("umi5", "umi3", "umidouble"):
        return None
    intLength = int(intUmiLength)
    objForward = re.search(strForward, strRead)
    objReverse = re.search(strReverse, strRead)
    if objForward is None or objReverse is None:
        return None
    # 5'-end UMI: directly before PRIMER(F). A negative start index would
    # silently wrap around to the end of the read, so reject it.
    intForwardEnd = objForward.start()
    intForwardStart = intForwardEnd - intLength
    strUmiForward = None
    if intForwardStart >= 0:
        strUmiForward = strRead[intForwardStart:intForwardEnd]
    # 3'-end UMI: directly after PRIMER(R). A window running past the end of
    # the read would only yield a truncated UMI, so reject it.
    intReverseStart = objReverse.end()
    intReverseEnd = intReverseStart + intLength
    strUmiReverse = None
    if intReverseEnd <= len(strRead):
        strUmiReverse = strRead[intReverseStart:intReverseEnd]
    if strSearch == "umi5":
        return strUmiForward
    if strSearch == "umi3":
        return strUmiReverse
    if strUmiForward is None or strUmiReverse is None:
        return None
    return strUmiForward, strUmiReverse

# The getReverseComplement function.
# This function creates a complementary string using a nucleotide string as
# input. The function loops through a list version of the nucleotide string
# and checks/changes every character. The function then returns the new string.
# Note that the string is only complemented here, not reversed.
def getReverseComplement(strLine):
    """Return the IUPAC complement of every nucleotide in strLine.

    Only complementing happens here; the string is not reversed. Upper-case
    IUPAC codes are supported; any other character raises KeyError.
    """
    dicPairs = {"A": "T", "T": "A", "G": "C", "C": "G", "M": "K",
                "R": "Y", "W": "W", "S": "S", "Y": "R", "K": "M",
                "V": "B", "H": "D", "D": "H", "B": "V", "N": "N"}
    return "".join([dicPairs[strBase] for strBase in strLine])

# The getRegex function.
# This function creates a regex string using a nucleotide string as input. This
# regex string is based on IUPAC ambiguity codes. The function loops through
# a list version of the nucleotide string and checks per character if it is an
# ambiguous character. If an ambiguous character is found, it is replaced by a
# regex version. The function then returns the new string.
def getRegex(strLine):
    """Translate an IUPAC nucleotide string into a regex pattern.

    A, C, G and T are kept as-is. Every ambiguity code becomes a character
    class (e.g. "N" -> "[GATC]"). Any other character raises KeyError.
    """
    dicClasses = {"M": "[AC]", "R": "[AG]", "W": "[AT]", "S": "[CG]",
                  "Y": "[CT]", "K": "[GT]", "V": "[ACG]", "H": "[ACT]",
                  "D": "[AGT]", "B": "[CGT]", "N": "[GATC]"}
    lstPieces = []
    for strBase in strLine:
        if strBase in ("A", "T", "G", "C"):
            lstPieces.append(strBase)
        else:
            lstPieces.append(dicClasses[strBase])
    return "".join(lstPieces)

# The getUmi function.
# This function controls the UMI searching approach. It first uses the functions
# getRegex and getReverseComplement to create regex strings of both the forward
# and reverse primers/adapters. The regex strings are then directed to the
# associated approach functions [primer/adapter/zero].
def getUmi(strSearch, strApproach, intUmiLength, strForward, strReverse,
           strRead):
    """Search one read for its UMI(s) with the selected approach.

    The forward sequence is turned into a regex directly. The reverse
    sequence is reversed, complemented and then turned into a regex.

    Args:
        strSearch: "umi5", "umi3" or "umidouble".
        strApproach: "primer", "adapter" or "zero".
        intUmiLength: length of a single UMI.
        strForward: forward primer/adapter nucleotides (IUPAC codes).
        strReverse: reverse primer/adapter nucleotides (IUPAC codes).
        strRead: the read line; trailing newlines are stripped.

    Returns:
        Whatever the approach function returns, or None for an unknown
        approach or when that function raises AttributeError.
    """
    strCleanRead = strRead.strip("\n")
    strPatternForward = getRegex(strForward)
    strPatternReverse = getRegex(getReverseComplement(strReverse[::-1]))
    if strApproach == "primer":
        funSearch = usePrimer
    elif strApproach == "adapter":
        funSearch = useAdapter
    elif strApproach == "zero":
        funSearch = useZeroPosition
    else:
        return None
    try:
        return funSearch(strSearch, intUmiLength, strCleanRead,
                         strPatternForward, strPatternReverse)
    except AttributeError:
        # e.g. .start()/.end() called on a failed re.search (None).
        return None

# The processInputFile function.
# This function opens the input file and loops through it. It stores the read
# header and read nucleotides. For every read the getUmi function is called,
# this outputs one or two UMI codes. In the case of a double UMI search [umidouble],
# the two UMIs are combined. The length of the UMI is checked before continuing.
# The getFastaFile function is called for every read that contains a UMI.
def processInputFile(flInput, strSearch, strApproach, intUmiLength, strForward,
                     strReverse, strDir, strOperand):
    """Write every read that carries a valid UMI to its per-UMI FASTA file.

    A line counts as a read header when it starts with ``strOperand``
    followed by an alphanumeric character; the next line is the read. For
    FASTQ input (``strOperand == "@"``) the "+" and quality lines after the
    read are skipped. A quality line may itself start with "@" and a letter
    and would otherwise be taken for a header.

    Only UMIs of exactly ``intUmiLength`` nucleotides are kept. For
    "umidouble" the two UMIs are concatenated and must total twice that
    length. Each unique UMI gets a running number (1, 2, ...) in order of
    first appearance. Headers are always written with a ">" prefix, so the
    per-UMI files are FASTA even when the input is FASTQ.

    Args:
        flInput: path of the input FASTA/FASTQ file.
        strSearch: "umi5", "umi3" or "umidouble".
        strApproach: "primer", "adapter" or "zero".
        intUmiLength: length of a single UMI (int or numeric string).
        strForward: forward primer/adapter nucleotides.
        strReverse: reverse primer/adapter nucleotides.
        strDir: temporary working directory created by setWorkDirs.
        strOperand: header character (">" for FASTA, "@" for FASTQ).
    """
    dicUniqueUmi = {}
    intUniqueUmi = 1
    strForwardUpper = strForward.upper()
    strReverseUpper = strReverse.upper()
    intExpectedLength = int(intUmiLength)
    if strSearch == "umidouble":
        intExpectedLength *= 2
    with open(flInput) as oisInput:
        for strLine in oisInput:
            if strLine[:1] != strOperand or \
               re.match("[A-Za-z0-9]", strLine[1:2]) is None:
                continue
            # A missing read line at end of file yields "" instead of
            # raising StopIteration.
            strRead = next(oisInput, "").upper()
            if strOperand == "@":
                next(oisInput, "")  # "+" separator line
                next(oisInput, "")  # quality line
            strUmi = getUmi(strSearch, strApproach, intUmiLength,
                            strForwardUpper, strReverseUpper, strRead)
            strCode = _getUmiCode(strUmi, strSearch)
            if strCode is None or len(strCode) != intExpectedLength:
                continue
            if strCode not in dicUniqueUmi:
                dicUniqueUmi[strCode] = intUniqueUmi
                intUniqueUmi += 1
            strHeader = ">" + strLine[1:]
            getFastaFile(strDir, dicUniqueUmi, strHeader, strRead, strCode)


def _getUmiCode(objUmi, strSearch):
    """Turn a getUmi result into a single UMI code, or None if absent."""
    if objUmi is None:
        return None
    if strSearch == "umi5" or strSearch == "umi3":
        return objUmi
    if strSearch == "umidouble":
        # A double UMI is identified by the 5' UMI followed by the 3' UMI.
        return objUmi[0] + objUmi[1]
    return None

# The setFormat function.
# This function specifies the first character of the read headers. This
# character is based on the input file format. The function then returns this
# character.
def setFormat(strFormat):
    """Return the header character that marks reads in the input file.

    Args:
        strFormat: "fasta" or "fastq".

    Returns:
        ">" for "fasta", "@" for "fastq". For anything else a message is
        printed and None is returned.
    """
    for strName, strOperand in (("fasta", ">"), ("fastq", "@")):
        if strFormat == strName:
            return strOperand
    print("File format name not recognized.")
    return None

# The setWorkDirs function.
# This function creates the main temporary directory and its sub-directories
# (preZip, clustering and postClustering).
# It checks if the directories already exist and creates them if they don't.
def setWorkDirs(strDir):
    """Create a fresh temporary working directory with its sub-directories.

    The directory is created with ``tempfile.mkdtemp``, which picks a unique
    name and creates it atomically. An existing directory is therefore never
    reused, so removeWorkDirs can never delete someone else's data.

    Args:
        strDir: parent directory in which the working directory is created.
            When None, the platform's default temporary location is used.

    Returns:
        Path of the new working directory. It contains the empty
        sub-directories "preZip", "clustering" and "postClustering".
    """
    import tempfile
    strTempDir = tempfile.mkdtemp(dir=strDir)
    for strSubDir in ("preZip", "clustering", "postClustering"):
        os.mkdir(strTempDir + "/" + strSubDir)
    return strTempDir

# The parseArgvs function.
# This function defines and parses the command line arguments.
def parseArgvs():
    """Define and parse the command line arguments.

    The format, search and approach options are restricted to the values
    the pipeline understands, so a typo is rejected by argparse right away.
    Without this check the run would continue and silently produce empty
    output.

    Returns:
        The argparse namespace; every value is a string or None.
    """
    strDescription = "A python package to process UMI tagged mixed amplicon\
                      metabarcoding data."
    strEpilog = "This python package requires one extra dependency which can\
                 be easily installed with conda (conda install -c bioconda vsearch)."
    parser = argparse.ArgumentParser(description = strDescription,
                                     epilog = strEpilog)
    parser.add_argument("-v", "-version", action = "version",
                        version = "%(prog)s [0.2]")
    parser.add_argument("-i", "-input", action = "store", dest = "fisInput",
                        help = "The location of the input fasta/fastq file.")
    parser.add_argument("-t", "-tabular", action = "store", dest = "fosTabular",
                        help = "The location of the output tabular file.")
    parser.add_argument("-z", "-zip", action = "store", dest = "fosPreZip",
                        help = "The location of the pre validation zip file.")
    parser.add_argument("-b", "-blast", action = "store", dest = "fosBlast",
                        help = "The location of the output fasta file.")
    parser.add_argument("-f", "-format", action = "store", dest = "disFormat",
                        choices = ["fasta", "fastq"],
                        help = "The format of the input file [fasta/fastq].")
    parser.add_argument("-s", "-search", action = "store", dest = "disSearch",
                        choices = ["umi5", "umi3", "umidouble"],
                        help = "Search UMIs at the 5'-end [umi5], 3'-end [umi3]\
                                or at the 5'-end and 3'-end [umidouble].")
    parser.add_argument("-a", "-approach", action = "store", dest = "disApproach",
                        choices = ["primer", "adapter", "zero"],
                        help = "The UMI search approach [primer/adapter/zero].")
    parser.add_argument("-u", "-length", action = "store", dest = "disUmiLength",
                        help = "The length of the UMI sequence.")
    parser.add_argument("-p", "-identity", action = "store", dest = "disIdentity",
                        help = "The identity percentage with which to perform\
                                the validation.")
    parser.add_argument("-c", "-abundance", action = "store", dest = "disAbundance",
                        help = "The minimum abundance of a read in order to be\
                                included during validation.")
    parser.add_argument("-w", "-forward", action = "store", dest = "disForward",
                        help = "The 5'-end search nucleotides.")
    parser.add_argument("-r", "-reverse", action = "store", dest = "disReverse",
                        help = "The 3'-end search nucleotides.")
    parser.add_argument("-d", "-directory", action = "store", dest = "fisDirectory",
                        help = "The location where the temporary working directory\
                                will be created.")
    argvs = parser.parse_args()
    return argvs

# The main function.
def main():
    """Run the complete UMI pipeline using the command line arguments.

    Steps, in order:
        1. Split the reads into one FASTA file per UMI.
        2. Zip those files.
        3. Dereplicate.
        4. Sort by abundance.
        5. Cluster.
        6. Write the tabular and BLAST outputs.
    The temporary working directory is removed afterwards, also when one of
    the steps raises, so failed runs do not leave it behind.
    """
    argvs = parseArgvs()
    strOperand = setFormat(argvs.disFormat)
    strTempDir = setWorkDirs(argvs.fisDirectory)
    try:
        processInputFile(argvs.fisInput, argvs.disSearch, argvs.disApproach,
                         argvs.disUmiLength, argvs.disForward,
                         argvs.disReverse, strTempDir, strOperand)
        getZipArchive(argvs.fosPreZip, strTempDir)
        setDereplication(strTempDir)
        setSortBySize(strTempDir, argvs.disAbundance)
        setClusterSize(strTempDir, argvs.disIdentity)
        setOutputFiles(strTempDir, argvs.fosBlast, argvs.fosTabular)
    finally:
        removeWorkDirs(strTempDir)

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
