#!python

# Copyright (C) 2018 Jasper Boom

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License version 3 as
# published by the Free Software Foundation.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Imports
import argparse
import shutil
import string
import random
import re
import os
import zipfile
import pyfastx
import ntpath
import time
import multiprocessing as mp
import pandas as pd
import subprocess as sp
from Bio.Seq import Seq
from Bio.Alphabet import generic_dna


# The setOutputFiles function.
# This function creates the tabular and blast output files.
def setOutputFiles(strDir, strOutputFileName):
    """Write the tabular and BLAST output files from the clustering results.

    Every file in ``<strDir>/PostClustering`` is read as a two-line-per-record
    fasta file (header line, sequence line). The UMI id is the part of the
    file name before the first underscore. A file holding a single record
    keeps the plain UMI id; a file holding several records gets numbered ids
    (``<umi>.1``, ``<umi>.2``, ...). Files with fewer than two lines are
    skipped.

    Parameters:
        strDir: working directory that contains the PostClustering folder.
        strOutputFileName: output path prefix; ``_BLAST.fasta`` and
            ``_TABULAR.tbl`` are appended to it.

    Output:
        ``<prefix>_TABULAR.tbl`` is always written (tab separated, columns
        UMI_ID and Read_Count). ``<prefix>_BLAST.fasta`` is appended to, and
        only touched when at least one record was found.
    """
    flBlast = strOutputFileName + "_BLAST.fasta"
    flTabular = strOutputFileName + "_TABULAR.tbl"
    strClusterDir = strDir + "/PostClustering"
    lstRows = []
    lstBlastLines = []
    for strFileName in os.listdir(strClusterDir):
        strUmiNumber = strFileName.split("_")[0]
        # Read the file once; the line count decides how records are named.
        with open(strClusterDir + "/" + strFileName) as oisUmiFile:
            lstLines = oisUmiFile.readlines()
        if len(lstLines) < 2:
            continue
        blnSingleRecord = len(lstLines) == 2
        intVersionCount = 1
        oisLines = iter(lstLines)
        for strLine in oisLines:
            if not strLine.startswith(">"):
                continue
            # The count is whatever follows the LAST "=" of the header, so a
            # read name that itself contains "=" cannot shift the field.
            strReadCount = strLine.rsplit("=", 1)[1].strip("\n")
            # The sequence is the line directly after its header.
            strRead = next(oisLines)
            if blnSingleRecord:
                strUmiId = strUmiNumber
            else:
                strUmiId = strUmiNumber + "." + str(intVersionCount)
                intVersionCount += 1
            lstRows.append([strUmiId, strReadCount])
            lstBlastLines.append(">" + strUmiId + "\n")
            lstBlastLines.append(strRead.strip("\n").upper() + "\n")
    # Write all fasta records in one go instead of reopening the file for
    # every record; leave the file untouched when there is nothing to write.
    if lstBlastLines:
        with open(flBlast, "a") as flOutput:
            flOutput.writelines(lstBlastLines)
    # Build the table once from the collected rows rather than growing a
    # DataFrame row by row.
    dfOutput = pd.DataFrame(lstRows, columns=["UMI_ID", "Read_Count"])
    dfOutput = dfOutput.set_index("UMI_ID")
    dfOutput.to_csv(flTabular, sep="\t", encoding="utf-8")


# The setClusterSize function.
# This function controls the Vsearch clustering. Every fasta file created by
# getSortBySize is clustered using Vsearch. The expected result is a single
# centroid sequence. This is checked in the setOutputFiles function.
def setClusterSize(strDir, fltIdentity):
    """Cluster every sorted fasta file in the Clustering folder with vsearch.

    For each file in ``<strDir>/Clustering`` whose name starts with
    "sorted", ``vsearch --cluster_size`` is run and the centroids are written
    to ``<strDir>/PostClustering`` under the file name minus its first 11
    characters. vsearch's stdout and stderr are captured and discarded.

    Parameters:
        strDir: working directory holding Clustering and PostClustering.
        fltIdentity: value handed to vsearch ``--id`` as a command line
            argument.
    """
    strSourceDir = strDir + "/Clustering"
    for strFileName in os.listdir(strSourceDir):
        if not strFileName.startswith("sorted"):
            continue
        strInputCommand = "".join((strDir, "/Clustering/", strFileName))
        strOutputCommand = "".join(
            (strDir, "/PostClustering/", strFileName[11:])
        )
        lstCommand = ["vsearch", "--cluster_size", strInputCommand]
        lstCommand += ["--fasta_width", "0"]
        lstCommand += ["--id", fltIdentity]
        lstCommand += ["--sizein"]
        lstCommand += ["--minseqlength", "1"]
        lstCommand += ["--centroids", strOutputCommand]
        lstCommand += ["--sizeout"]
        # Wait for vsearch to finish; its output streams are not used.
        sp.Popen(lstCommand, stdout=sp.PIPE, stderr=sp.PIPE).communicate()


# The getSortBySize function.
# This function controls the Vsearch sorting. Every fasta file created by
# getDereplication is sorted based on abundance. Any reads with an abundance
# lower than intAbundance will be discarded.
def getSortBySize(strDir, intAbundance):
    """Sort every dereplicated fasta file in the Clustering folder.

    For each file in ``<strDir>/Clustering`` whose name starts with "derep",
    ``vsearch --sortbysize`` is run and the result is written next to it with
    "sorted" prefixed to the file name. vsearch's stdout and stderr are
    captured and discarded.

    Parameters:
        strDir: working directory holding the Clustering folder.
        intAbundance: value handed to vsearch ``--minsize`` as a command
            line argument.
    """
    strSourceDir = strDir + "/Clustering"
    for strFileName in os.listdir(strSourceDir):
        if not strFileName.startswith("derep"):
            continue
        strInputCommand = "".join((strDir, "/Clustering/", strFileName))
        strOutputCommand = "".join(
            (strDir, "/Clustering/", "sorted", strFileName)
        )
        lstCommand = ["vsearch", "--sortbysize", strInputCommand]
        lstCommand += ["--output", strOutputCommand]
        lstCommand += ["--minseqlength", "1"]
        lstCommand += ["--minsize", intAbundance]
        # Wait for vsearch to finish; its output streams are not used.
        sp.Popen(lstCommand, stdout=sp.PIPE, stderr=sp.PIPE).communicate()


# The getDereplication function.
# This function controls the Vsearch dereplication. Every fasta file created by
# getFastaFile is dereplicated. This step is necessary for the sorting step to
# work.
def getDereplication(strDir):
    """Dereplicate every fasta file in the PreValidation folder.

    For each file in ``<strDir>/PreValidation`` whose name ends with
    ".fasta", ``vsearch --derep_fulllength`` is run and the result is written
    to ``<strDir>/Clustering`` with "derep" prefixed to the file name.
    vsearch's stdout and stderr are captured and discarded.

    Parameters:
        strDir: working directory holding PreValidation and Clustering.
    """
    strSourceDir = strDir + "/PreValidation"
    for strFileName in os.listdir(strSourceDir):
        if not strFileName.endswith(".fasta"):
            continue
        strInputCommand = "".join((strDir, "/PreValidation/", strFileName))
        strOutputCommand = "".join(
            (strDir, "/Clustering/", "derep", strFileName)
        )
        lstCommand = ["vsearch", "--derep_fulllength", strInputCommand]
        lstCommand += ["--output", strOutputCommand]
        lstCommand += ["--minseqlength", "1"]
        lstCommand += ["--sizeout"]
        # Wait for vsearch to finish; its output streams are not used.
        sp.Popen(lstCommand, stdout=sp.PIPE, stderr=sp.PIPE).communicate()


# The getFastaFile function.
# This function creates separate fasta files for every unique UMI. The function
# creates a unique name for every UMI file and combines that with the desired
# output path. A file is opened or created based on this combination. The
# read header and the read itself are appended to it.
def getFastaFile(strDir, dicUniqueUmi, strHeader, strRead, strCode):
    """Append one read to the fasta file that belongs to its UMI.

    The target file is ``<strDir>/PreValidation/UMI#<n>_<strCode>.fasta``,
    where ``<n>`` is the value stored for strCode in dicUniqueUmi. The file
    is created when it does not exist yet.

    Parameters:
        strDir: working directory holding the PreValidation folder.
        dicUniqueUmi: mapping from UMI code to its number.
        strHeader: read header, written without the leading ">".
        strRead: read sequence.
        strCode: UMI code of the read.
    """
    strFileName = "".join(
        (
            strDir,
            "/PreValidation/",
            "UMI#",
            str(dicUniqueUmi[strCode]),
            "_",
            strCode,
            ".fasta",
        )
    )
    lstRecord = [">" + strHeader + "\n", strRead + "\n"]
    with open(strFileName, "a") as flOutput:
        flOutput.writelines(lstRecord)


# The useZeroPosition function.
# This function will isolate either a 5'-end, 3'-end or double UMI based on
# the start or end position of a sequence.
# It will check if both the forward and reverse primer can be found. If this
# check is passed, the 5'-end, 3'-end UMI or double UMI will be isolated by
# adding or subtracting the UMI length from the first or last position of the
# sequence. The function will return (if possible) the UMI nucleotides.
def useZeroPosition(
    strLocation, intUmiLength, strRead, strForward, strReverse
):
    """Take the UMI from the very start and/or end of the read.

    Both regex patterns must be found in the read, otherwise nothing is
    returned.

    Parameters:
        strLocation: "umi5", "umi3" or "umidouble".
        intUmiLength: UMI length; converted with int() before use.
        strRead: the read sequence.
        strForward: regex for the forward anchor.
        strReverse: regex for the reverse anchor.

    Returns:
        The first intUmiLength characters ("umi5"), the last intUmiLength
        characters ("umi3"), a (first, last) tuple ("umidouble"), or None
        when an anchor is missing or strLocation is not recognised.
    """
    # The reverse anchor is only searched for once the forward one is found.
    if re.search(strForward, strRead) is None:
        return None
    if re.search(strReverse, strRead) is None:
        return None
    if strLocation == "umi5":
        return strRead[: int(intUmiLength)]
    if strLocation == "umidouble":
        strHead = strRead[: int(intUmiLength)]
        strTail = strRead[-int(intUmiLength) :]
        return strHead, strTail
    if strLocation == "umi3":
        return strRead[-int(intUmiLength) :]
    return None


# The useAdapter function.
# This function searches for a regex string in the provided sequence. It will
# isolate either a 5'-end, 3'-end or double UMI. The isolation is based on
# this sequence structure:
#     ADAPTER(F)-UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3')-ADAPTER(R).
# When looking for the 5'-end UMI, the last position of ADAPTER(F) is used,
# when looking for the 3'-end UMI, the first position of ADAPTER(R) is used,
# when looking for the double UMI, both mentioned positions are used.
# These positions plus or minus the UMI length result in the UMI nucleotides.
# In the case of umi5 or umi3, a check needs to be passed. This check makes
# sure the opposite adapters are also present, otherwise no UMI is returned.
# The function will return (if possible) the UMI nucleotides.
def useAdapter(strLocation, intUmiLength, strRead, strForward, strReverse):
    """Take the UMI that sits directly next to the adapter(s).

    The 5'-end UMI is the intUmiLength characters following the first match
    of strForward; the 3'-end UMI is the intUmiLength characters preceding
    the first match of strReverse.

    Parameters:
        strLocation: "umi5", "umi3" or "umidouble".
        intUmiLength: UMI length; converted with int() before use.
        strRead: the read sequence.
        strForward: regex for the forward adapter.
        strReverse: regex for the reverse adapter.

    Returns:
        The UMI string ("umi5"/"umi3"), a (forward, reverse) tuple
        ("umidouble"), or None when the opposite adapter is missing or
        strLocation is not recognised.

    Raises:
        AttributeError: when the adapter whose position is needed is not
            found (the regex search returned None).
    """
    if strLocation in ("umi5", "umidouble"):
        intAdapterEnd = re.search(strForward, strRead).end()
        strUmiForward = strRead[
            intAdapterEnd : intAdapterEnd + int(intUmiLength)
        ]
        if strLocation == "umi5":
            # Only accept the 5' UMI when the reverse adapter is present.
            if re.search(strReverse, strRead) is None:
                return None
            return strUmiForward
        intAdapterStart = re.search(strReverse, strRead).start()
        strUmiReverse = strRead[
            intAdapterStart - int(intUmiLength) : intAdapterStart
        ]
        return strUmiForward, strUmiReverse
    if strLocation == "umi3":
        # Only accept the 3' UMI when the forward adapter is present.
        if re.search(strForward, strRead) is None:
            return None
        intAdapterStart = re.search(strReverse, strRead).start()
        return strRead[intAdapterStart - int(intUmiLength) : intAdapterStart]
    return None


# The usePrimer function.
# This function searches for a regex string in the provided sequence. It will
# isolate either a 5'-end, 3'-end or double UMI. The isolation is based on
# this sequence structure:
#     UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3').
# When looking for the 5'-end UMI, the first position of PRIMER(F) is used,
# when looking for the 3'-end UMI, the last position of PRIMER(R) is used,
# when looking for the double UMI, both mentioned positions are used.
# These positions plus or minus the UMI length result in the UMI nucleotides.
# In the case of umi5 or umi3, a check needs to be passed. This check makes
# sure the opposite primer is also present, otherwise no UMI is returned.
# The function returns (if possible) the UMI nucleotides.
def usePrimer(strLocation, intUmiLength, strRead, strForward, strReverse):
    """Take the UMI that sits directly next to the primer(s).

    The 5'-end UMI is the intUmiLength characters preceding the first match
    of strForward; the 3'-end UMI is the intUmiLength characters following
    the first match of strReverse.

    Parameters:
        strLocation: "umi5", "umi3" or "umidouble".
        intUmiLength: UMI length; converted with int() before use.
        strRead: the read sequence.
        strForward: regex for the forward primer.
        strReverse: regex for the reverse primer.

    Returns:
        The UMI string ("umi5"/"umi3"), a (forward, reverse) tuple
        ("umidouble"), or None when the opposite primer is missing or
        strLocation is not recognised.

    Raises:
        AttributeError: when the primer whose position is needed is not
            found (the regex search returned None).
    """
    if strLocation in ("umi5", "umidouble"):
        intPrimerStart = re.search(strForward, strRead).start()
        strUmiForward = strRead[
            intPrimerStart - int(intUmiLength) : intPrimerStart
        ]
        if strLocation == "umi5":
            # Only accept the 5' UMI when the reverse primer is present.
            if re.search(strReverse, strRead) is None:
                return None
            return strUmiForward
        intPrimerEnd = re.search(strReverse, strRead).end()
        strUmiReverse = strRead[
            intPrimerEnd : intPrimerEnd + int(intUmiLength)
        ]
        return strUmiForward, strUmiReverse
    if strLocation == "umi3":
        # Only accept the 3' UMI when the forward primer is present.
        if re.search(strForward, strRead) is None:
            return None
        intPrimerEnd = re.search(strReverse, strRead).end()
        return strRead[intPrimerEnd : intPrimerEnd + int(intUmiLength)]
    return None


# The getRegex function.
# This function creates a regex string using a nucleotide string as input. This
# regex string is based on the IUPAC ambiguity codes. The function loops
# through a list version of the nucleotide string and checks per character if
# it is an ambiguous character. If an ambiguous character is found, it is
# replaced by a regex version. The function returns the new regex string.
def getRegex(strLine):
    """Translate a nucleotide string into a regex pattern.

    A, T, G and C are copied unchanged; every other character is replaced
    by the character class of its IUPAC ambiguity code.

    Parameters:
        strLine: iterable of single nucleotide characters (upper case).

    Returns:
        The regex pattern as a string.

    Raises:
        KeyError: for a character that is neither A/T/G/C nor one of the
            listed ambiguity codes.
    """
    dicAmbiguityCodes = {
        "M": "[AC]",
        "R": "[AG]",
        "W": "[AT]",
        "S": "[CG]",
        "Y": "[CT]",
        "K": "[GT]",
        "V": "[ACG]",
        "H": "[ACT]",
        "D": "[AGT]",
        "B": "[CGT]",
        "N": "[GATC]",
    }
    tplPlainBases = ("A", "T", "G", "C")
    lstPattern = [
        strBase if strBase in tplPlainBases else dicAmbiguityCodes[strBase]
        for strBase in list(strLine)
    ]
    return "".join(lstPattern)


# The getUmi function.
# This function controls the UMI searching approach. It first uses the
# functions getRegex to create regex strings of both the forward and reverse
# primers/adapters. The regex strings are then directed to the associated
# anchor functions [primer/adapter/zero].
def getUmi(
    strLocation, strAnchor, intUmiLength, strForward, strReverse, strRead
):
    """Extract the UMI(s) from a read using the requested anchor type.

    The forward anchor is turned into a regex as given; the reverse anchor
    is reverse complemented first. Both patterns are then handed to
    usePrimer, useAdapter or useZeroPosition depending on strAnchor.

    Parameters:
        strLocation: "umi5", "umi3" or "umidouble".
        strAnchor: "primer", "adapter" or "zero".
        intUmiLength: UMI length, passed through to the anchor function.
        strForward: forward anchor nucleotides.
        strReverse: reverse anchor nucleotides.
        strRead: the read sequence.

    Returns:
        Whatever the anchor function returns, or None when strAnchor is not
        recognised or the anchor function raised AttributeError.
    """
    strRegexForward = getRegex(strForward)
    strRegexReverseComplement = getRegex(
        Seq(strReverse, generic_dna).reverse_complement()
    )
    if strAnchor == "primer":
        fncAnchor = usePrimer
    elif strAnchor == "adapter":
        fncAnchor = useAdapter
    elif strAnchor == "zero":
        fncAnchor = useZeroPosition
    else:
        return None
    try:
        return fncAnchor(
            strLocation,
            intUmiLength,
            strRead,
            strRegexForward,
            strRegexReverseComplement,
        )
    except AttributeError:
        # Raised by the anchor functions when a required anchor is absent
        # from the read; such a read yields no UMI.
        return None


# The processInputFile function.
# This function calls the getUmi function for the supplied read, this outputs
# one or two UMI codes. In the case of a double UMI [umidouble], the two UMIs
# are combined. The length of the UMI is checked before continuing. The
# getFastaFile function is called if the read contains a UMI.
def processInputFile(
    strHeader,
    strRead,
    strLocation,
    strAnchor,
    intUmiLength,
    strForward,
    strReverse,
    strTempDir,
    dicUniqueUmi,
    intUniqueUmi,
):
    """Extract the UMI of one read and file the read under that UMI.

    getUmi is called for the read. For "umi5"/"umi3" the UMI must be exactly
    intUmiLength long; for "umidouble" the two UMIs are concatenated and
    must together be twice intUmiLength long. A read that passes is appended
    to its UMI's fasta file through getFastaFile; a UMI code seen for the
    first time is registered under the next free number.

    Parameters:
        strHeader: read header.
        strRead: read sequence.
        strLocation: "umi5", "umi3" or "umidouble".
        strAnchor: anchor type, passed to getUmi.
        intUmiLength: UMI length as int or numeric string (the command line
            parser delivers it as text).
        strForward: forward anchor nucleotides (upper-cased here).
        strReverse: reverse anchor nucleotides (upper-cased here).
        strTempDir: working directory, passed to getFastaFile.
        dicUniqueUmi: mapping from UMI code to its number; updated in place.
        intUniqueUmi: the next free UMI number.

    Returns:
        (dicUniqueUmi, intUniqueUmi) with intUniqueUmi increased by one when
        a new UMI code was registered.
    """
    strUmi = getUmi(
        strLocation,
        strAnchor,
        intUmiLength,
        strForward.upper(),
        strReverse.upper(),
        strRead,
    )
    if strUmi is None:
        return dicUniqueUmi, intUniqueUmi
    # Convert before doing arithmetic: multiplying the raw string would turn
    # "5" into "55" and make the double-UMI length check unpassable.
    intExpectedLength = int(intUmiLength)
    if strLocation == "umi5" or strLocation == "umi3":
        strCode = strUmi
    elif strLocation == "umidouble":
        strCode = strUmi[0] + strUmi[1]
        intExpectedLength = intExpectedLength * 2
    else:
        return dicUniqueUmi, intUniqueUmi
    # A UMI cut short by the end of the read is rejected.
    if len(strCode) != intExpectedLength:
        return dicUniqueUmi, intUniqueUmi
    if strCode not in dicUniqueUmi:
        dicUniqueUmi[strCode] = intUniqueUmi
        intUniqueUmi += 1
    getFastaFile(strTempDir, dicUniqueUmi, strHeader, strRead, strCode)
    return dicUniqueUmi, intUniqueUmi


# The runCaltha function.
# This function controls and calls the main functionality of Caltha. It loops
# through the input file(s) using pyfastx.
def runCaltha(
    flInput,
    strMainDir,
    strLocation,
    strAnchor,
    intUmiLength,
    strForward,
    strReverse,
    intAbundance,
    fltIdentity,
    strFormat,
):
    """Run the complete Caltha pipeline on one input file.

    A fresh working directory is created under strMainDir, every record of
    the input file is passed to processInputFile, the per-UMI fasta files
    are zipped to ``<strMainDir>/<input name>_PREVALIDATION.zip`` and then
    dereplicated, sorted and clustered before the output files are written.

    Parameters:
        flInput: path of the fasta/fastq input file.
        strMainDir: directory that receives the working directory and the
            output files.
        strLocation, strAnchor, intUmiLength, strForward, strReverse:
            UMI search settings, passed to processInputFile.
        intAbundance: passed to getSortBySize.
        fltIdentity: passed to setClusterSize.
        strFormat: "fasta" or "fastq"; any other value reads no records.

    Returns:
        A string "<input file name>: <seconds>s" with the elapsed time.
    """
    fltStart = time.time()
    strInputFileName = flInput.rsplit("/", 1)[-1].split(".", 1)[0]
    strOutputFileName = strMainDir + "/" + strInputFileName
    strTempDir = setWorkDirs(strMainDir, True)
    dicUniqueUmi = {}
    intUniqueUmi = 1
    # Pick the pyfastx reader once; both formats share the same loop.
    if strFormat == "fasta":
        oisRecords = pyfastx.Fasta(flInput)
    elif strFormat == "fastq":
        oisRecords = pyfastx.Fastq(flInput)
    else:
        oisRecords = ()
    for oisRecord in oisRecords:
        dicUniqueUmi, intUniqueUmi = processInputFile(
            oisRecord.name,
            oisRecord.seq.upper(),
            strLocation,
            strAnchor,
            intUmiLength,
            strForward,
            strReverse,
            strTempDir,
            dicUniqueUmi,
            intUniqueUmi,
        )
    getZipArchive(
        strTempDir + "/PreValidation",
        strOutputFileName + "_PREVALIDATION.zip",
        ".fasta",
    )
    getDereplication(strTempDir)
    getSortBySize(strTempDir, intAbundance)
    setClusterSize(strTempDir, fltIdentity)
    setOutputFiles(strTempDir, strOutputFileName)
    strElapsed = "%.2f" % (time.time() - fltStart)
    return ntpath.basename(flInput) + ": " + strElapsed + "s"


# The removeWorkDirs function.
# This function removes all temporary working directories.
def removeWorkDirs(strDir):
    """Delete the directory strDir together with everything inside it."""
    shutil.rmtree(path=strDir)


# The getZipArchive function.
# This function creates a zip archive from all files in the specified
# directory.
def getZipArchive(strDir, flZip, strExtension):
    """Zip every file in strDir whose name ends with strExtension.

    Parameters:
        strDir: directory to collect files from (not searched recursively).
        flZip: path of the zip archive to create; an existing file is
            overwritten.
        strExtension: file name suffix a file must have to be included.

    The files are stored in the archive under their bare file names.
    """
    with zipfile.ZipFile(flZip, "w") as objArchive:
        lstSelected = [
            strName
            for strName in os.listdir(strDir)
            if strName.endswith(strExtension)
        ]
        for strName in lstSelected:
            strSource = strDir + "/" + strName
            objArchive.write(strSource, os.path.basename(strSource))


# The createInputList function.
# This function copies the input file(s) to a new temporary folder. It then
# creates a list of all input file names.
def createInputList(flInput, strMainDir):
    """Stage the input file(s) in ``<strMainDir>/inputFiles`` and list them.

    A zip archive (recognised by its ".zip" extension, in any letter case)
    is extracted into the new folder; any other file is copied into it.

    Parameters:
        flInput: path of a zip archive or of a single input file.
        strMainDir: existing directory in which "inputFiles" is created.

    Returns:
        A list with the path of every staged file. Directory entries of a
        zip archive are extracted but not listed, because only files can be
        processed as input.

    Raises:
        OSError: when ``<strMainDir>/inputFiles`` already exists
            (os.makedirs is called without exist_ok).
    """
    strCreate = strMainDir + "/" + "inputFiles"
    os.makedirs(strCreate)
    if os.path.splitext(flInput)[1].lower() == ".zip":
        with zipfile.ZipFile(flInput, "r") as objZip:
            objZip.extractall(strCreate)
            # Archive member names that end with "/" are folders.
            lstInput = [
                strCreate + "/" + strFile
                for strFile in objZip.namelist()
                if not strFile.endswith("/")
            ]
    else:
        strCopy = strCreate + "/" + ntpath.basename(flInput)
        shutil.copyfile(flInput, strCopy)
        lstInput = [strCopy]
    return lstInput


# The setWorkDirs function.
# This function creates the main temporary working directory and all subprocess
# directories. It checks if the directories already exist and creates them
# if they don't.
def setWorkDirs(strDirectory, blnExtra):
    """Create a randomly named working directory inside strDirectory.

    The directory name consists of 10 random lower case letters. Directories
    that already exist are left as they are.

    Parameters:
        strDirectory: parent directory of the working directory.
        blnExtra: when exactly True, the sub directories PreValidation,
            Clustering and PostClustering are created as well.

    Returns:
        The path of the working directory.
    """
    lstLetters = [random.choice(string.ascii_lowercase) for _ in range(10)]
    strBase = strDirectory + "/" + "".join(lstLetters)
    lstTargets = [strBase]
    if blnExtra is True:
        for strSub in ("PreValidation", "Clustering", "PostClustering"):
            lstTargets.append(strBase + "/" + strSub)
    for strTarget in lstTargets:
        if not os.path.exists(strTarget):
            os.makedirs(strTarget)
    return strBase


# The parseArgvs function.
# This function defines and parses the command line arguments.
def parseArgvs():
    """Define and parse the command line arguments.

    The input, the three output archives and both anchor sequences are
    required, because main() reads all of them unconditionally. Format,
    location and anchor are limited to the values the pipeline recognises,
    so a typo is reported instead of silently producing empty output. The
    UMI length and the thread count are parsed as integers.

    Returns:
        The argparse namespace with the attributes flInput, flTabular,
        flPreValidation, flBlast, strFormat, strLocation, strAnchor,
        intUmiLength, fltIdentity, intAbundance, strForward, strReverse,
        strDirectory and intThreads.

    Raises:
        SystemExit: raised by argparse for -h/-v and for invalid or missing
            arguments.
    """
    strDescription = "A python package for processing UMI tagged mixed amplicon\
                      metabarcoding data."
    strEpilog = "This python package requires one extra dependency which can\
                 be easily installed with conda (conda install -c bioconda\
                 vsearch=2.14.2)."
    parser = argparse.ArgumentParser(
        description=strDescription,
        epilog=strEpilog,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-v", "--version", action="version", version="%(prog)s [0.6]"
    )
    # default=argparse.SUPPRESS keeps "(default: None)" out of the help text
    # of the required options.
    parser.add_argument(
        "-i",
        "--input",
        action="store",
        dest="flInput",
        required=True,
        default=argparse.SUPPRESS,
        help="The input fasta/fastq file(s). This can either be a zip archive\
              or a single fasta/fastq file.",
    )
    parser.add_argument(
        "-t",
        "--tabular",
        action="store",
        dest="flTabular",
        required=True,
        default=argparse.SUPPRESS,
        help="The output tabular zip file.",
    )
    parser.add_argument(
        "-z",
        "--zip",
        action="store",
        dest="flPreValidation",
        required=True,
        default=argparse.SUPPRESS,
        help="The pre validation zip file.",
    )
    parser.add_argument(
        "-b",
        "--blast",
        action="store",
        dest="flBlast",
        required=True,
        default=argparse.SUPPRESS,
        help="The output blast zip file.",
    )
    parser.add_argument(
        "-f",
        "--format",
        action="store",
        dest="strFormat",
        default="fasta",
        choices=["fasta", "fastq"],
        help="The format of the input file [fasta/fastq].",
    )
    parser.add_argument(
        "-l",
        "--location",
        action="store",
        dest="strLocation",
        default="umi5",
        choices=["umi5", "umi3", "umidouble"],
        help="Search for UMIs at the 5'-end\
              [umi5], 3'-end [umi3] or at the 5'-end and 3'-end [umidouble].",
    )
    parser.add_argument(
        "-a",
        "--anchor",
        action="store",
        dest="strAnchor",
        default="primer",
        choices=["primer", "adapter", "zero"],
        help="Which anchor type to use [primer/adapter/zero].",
    )
    parser.add_argument(
        "-u",
        "--length",
        action="store",
        dest="intUmiLength",
        type=int,
        default="5",
        help="The length of the UMI sequence.",
    )
    # Identity and abundance stay strings: they are placed unchanged in the
    # vsearch command line.
    parser.add_argument(
        "-y",
        "--identity",
        action="store",
        dest="fltIdentity",
        default="0.97",
        help="The identity percentage with which to perform the validation.",
    )
    parser.add_argument(
        "-c",
        "--abundance",
        action="store",
        dest="intAbundance",
        default="1",
        help="The minimum abundance of a sequence\
              in order for it to be included during validation.",
    )
    parser.add_argument(
        "-w",
        "--forward",
        action="store",
        dest="strForward",
        required=True,
        default=argparse.SUPPRESS,
        help="The 5'-end anchor nucleotides.",
    )
    parser.add_argument(
        "-r",
        "--reverse",
        action="store",
        dest="strReverse",
        required=True,
        default=argparse.SUPPRESS,
        help="The 3'-end anchor nucleotides.",
    )
    parser.add_argument(
        "-d",
        "--directory",
        action="store",
        dest="strDirectory",
        default=".",
        help="The location of the temporary working directory (not created\
              by program).",
    )
    parser.add_argument(
        "-@",
        "--threads",
        action="store",
        dest="intThreads",
        type=int,
        default=mp.cpu_count(),
        help="The number of threads to run Caltha with.",
    )
    argvs = parser.parse_args()
    return argvs


# The main function.
def main():
    """Run Caltha on every input file and collect the results.

    A temporary main directory is created inside the directory given on the
    command line, the input is staged in it, and runCaltha is started for
    every input file on a process pool. The resulting .tbl, .zip and .fasta
    files are then bundled into the three requested output archives, and
    the per-file timing strings returned by runCaltha are printed.

    The temporary main directory is removed even when a step fails, so an
    aborted run does not leave its working files behind. An exception raised
    inside a worker is re-raised here when its result is collected.
    """
    disArgvs = parseArgvs()
    strMainDir = setWorkDirs(disArgvs.strDirectory, False)
    try:
        lstInput = createInputList(disArgvs.flInput, strMainDir)
        with mp.Pool(int(disArgvs.intThreads)) as mpPool:
            mpTemporary = [
                mpPool.apply_async(
                    runCaltha,
                    args=(
                        flInput,
                        strMainDir,
                        disArgvs.strLocation,
                        disArgvs.strAnchor,
                        disArgvs.intUmiLength,
                        disArgvs.strForward,
                        disArgvs.strReverse,
                        disArgvs.intAbundance,
                        disArgvs.fltIdentity,
                        disArgvs.strFormat,
                    ),
                )
                for flInput in lstInput
            ]
            # Leaving the with-block terminates the pool, so wait for all
            # submitted jobs to finish first.
            mpPool.close()
            mpPool.join()
        mpOutput = [mpProcess.get() for mpProcess in mpTemporary]
        getZipArchive(strMainDir, disArgvs.flTabular, ".tbl")
        getZipArchive(strMainDir, disArgvs.flPreValidation, ".zip")
        getZipArchive(strMainDir, disArgvs.flBlast, ".fasta")
    finally:
        removeWorkDirs(strMainDir)
    print(mpOutput)


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
