#!/usr/bin/env python
import sys
import os
import logging
import coloredlogs

import click
import argparse
import yaml
import json

import numpy as np

from grouper import eqnet
from grouper import processLabels
from grouper import iterLabel
from grouper import filtGraph
from subprocess import call
import subprocess

# Passed to click.command() below so that "-h" is accepted as well as "--help".
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}

def which(program):
    """Locate an executable file, similar to the shell's ``which``.

    Args:
        program: Either a path containing a directory component (checked
            as given) or a bare command name (looked up in each directory
            of the ``PATH`` environment variable, in order).

    Returns:
        The path of the first match that is a regular file with the
        execute permission set, or ``None`` when there is no such file.
    """
    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    directory = os.path.split(program)[0]
    if directory:
        # An explicit path is never searched for in PATH.
        return program if is_exe(program) else None

    # PATH can legitimately be unset (e.g. in a stripped-down service
    # environment); fall back to the platform default search path instead
    # of raising KeyError.
    search_path = os.environ.get("PATH", os.defpath)
    for path in search_path.split(os.pathsep):
        # Entries are sometimes wrapped in double quotes; drop them.
        exe_file = os.path.join(path.strip('"'), program)
        if is_exe(exe_file):
            return exe_file

    return None

def _abort(message):
    """Log ``message`` at CRITICAL level and terminate with exit status 1."""
    logging.critical(message)
    sys.exit(1)


def _runMcl(graphFile, clustFile):
    """Cluster the ``--abc`` formatted graph ``graphFile`` with MCL.

    The clusters are written to ``clustFile``. The program is aborted when
    ``mcl`` reports a non-zero exit status, because the steps that follow
    read ``clustFile``.
    """
    status = call(["mcl", graphFile, "--abc", "-o", clustFile])
    if status != 0:
        _abort("MCL failed with exit status %s while clustering %s" % (status, graphFile))


def _writeClusterStats(clustFile, summaryFile):
    """Write min/max/mean cluster size and cluster count to a JSON file.

    Args:
        clustFile: Cluster file in which every line is one cluster and the
            whitespace-separated tokens on the line are its members.
        summaryFile: Path of the JSON summary to (over)write.
    """
    with open(clustFile) as ifile:
        sizes = [len(line.rstrip().split()) for line in ifile]

    if sizes:
        sizeArr = np.array(sizes)
        # Cast the NumPy scalars to builtin numbers: json cannot serialise
        # NumPy integer types on Python 3.
        stats = {
            'min clust size': int(sizeArr.min()),
            'max clust size': int(sizeArr.max()),
            'mean clust size': float(sizeArr.mean()),
            'num clusts': len(sizes),
        }
    else:
        # min()/max() of an empty array raise; report an empty clustering
        # explicitly instead of crashing with an opaque NumPy error.
        logging.warning("No clusters found in %s", clustFile)
        stats = {
            'min clust size': 0,
            'max clust size': 0,
            'mean clust size': 0.0,
            'num clusts': 0,
        }

    with open(summaryFile, 'w') as ofile:
        json.dump(stats, ofile)


def _collectLabelOutputs(outdir, seedFile):
    """Tidy up ``outdir`` after the labeling module has run.

    Removes the temporary junto files, renames ``seedFile`` to
    ``final.labels.txt`` and moves every entry of ``outdir`` whose name
    contains ``label`` into the ``Annotated`` sub-directory.

    Every step is best-effort: a failure is logged as a warning and the
    remaining steps still run, so a tidy-up problem never discards the
    results that were already computed.
    """
    for leftover in ("tempOutput", "junto.config"):
        try:
            os.remove(os.path.join(outdir, leftover))
        except OSError as err:
            logging.warning("Could not remove %s: %s", leftover, err)

    try:
        os.rename(seedFile, os.path.join(outdir, "final.labels.txt"))
    except OSError as err:
        logging.warning("Could not rename %s: %s", seedFile, err)

    annotatedDir = os.path.join(outdir, "Annotated")
    try:
        if not os.path.isdir(annotatedDir):
            os.makedirs(annotatedDir)
    except OSError as err:
        logging.warning("Could not create %s: %s", annotatedDir, err)
        return

    # In-process equivalent of `mv outdir/*label* outdir/Annotated`: no shell
    # is involved (so unusual characters in outdir are harmless) and the
    # moves have finished when this function returns. Hidden entries are
    # skipped, as a shell glob would skip them.
    for name in sorted(os.listdir(outdir)):
        if "label" not in name or name.startswith("."):
            continue
        try:
            os.rename(os.path.join(outdir, name), os.path.join(annotatedDir, name))
        except OSError as err:
            logging.warning("Could not move %s into %s: %s", name, annotatedDir, err)


# NOTE: click uses the docstring of the command function as its --help text,
# so it is kept short and user-facing.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option('--config', required=True, help='Config file describing the experimental setup')
def processQuant(config):
    """Run the Grouper pipeline described by the YAML file given via --config.

    The mapping ambiguity graph of the listed samples is built, filtered and
    clustered with MCL. When the configuration provides "fasta" or "labels",
    the labeling module is run as well and its output is collected in the
    "Annotated" sub-directory of the output directory.
    """
    with open(config, 'r') as yamCfg:
        # safe_load only builds plain Python objects; yaml.load without an
        # explicit Loader is unsafe and is rejected by recent PyYAML releases.
        cfg = yaml.safe_load(yamCfg)

    if not isinstance(cfg, dict) or 'outdir' not in cfg:
        _abort("Your configuration file must contain an \"outdir\" entry!")
    outdir = cfg['outdir']

    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif not os.path.isdir(outdir):
        _abort("The output directory already exists, and is not a directory!")

    logging.basicConfig(filename=os.path.join(outdir, "log.txt"), level=logging.INFO)
    coloredlogs.install(level='INFO')

    if which("mcl") is None:
        _abort("MCL not added to PATH. Install MCL.")

    #ensure that fasta file 1 is TS and file 2 is AS
    #ensure label files are in order [TS->ASdb, AS->TSdb]

    if 'conditions' not in cfg:
        _abort("Your configuration file must contain a \"conditions\" entry!")
    if 'labels' in cfg and 'fasta' in cfg:
        _abort("Provide either FASTA files to run BLAST or labels file in the required formats")

    conditions = cfg['conditions'] or []
    samples = cfg.get('samples')
    if not isinstance(samples, dict):
        _abort("Your configuration file must contain a \"samples\" entry listing the samples of each condition!")

    # conditionDict maps condition -> {sample index -> sample directory};
    # sampleDirs is the flat list of all sample directories.
    conditionDict = {}
    sampleDirs = []
    for c in conditions:
        if c not in samples:
            _abort("No samples are listed for condition \"%s\"!" % c)
        condSamples = samples[c] or []
        conditionDict[c] = dict(enumerate(condSamples))
        sampleDirs.extend(condSamples)

    if not sampleDirs:
        _abort("Your configuration file does not list any sample directories!")

    cutoff = cfg.get('cutoff', 10.0)
    threads = cfg.get('threads', 8)
    # Kept as "!= False" on purpose: any value that does not compare equal to
    # false (including an empty entry) switches the option on, exactly as
    # before.
    orphan = cfg.get('orphan', False) != False  # noqa: E712
    mincut = cfg.get('mincut', False) != False  # noqa: E712

    # The labeling module runs when either FASTA files (BLAST is run first)
    # or precomputed label files are configured.
    runBLAST = 'fasta' in cfg
    labeling = runBLAST or 'labels' in cfg
    faFiles = cfg.get('fasta')
    labelFiles = cfg.get('labels')

    if runBLAST and (not isinstance(faFiles, (list, tuple)) or len(faFiles) < 2):
        _abort("The \"fasta\" entry of the configuration file must list two FASTA files!")

    # Fail before the expensive graph construction rather than after it.
    if labeling and which("junto") is None:
        _abort("Junto library not added to PATH. Install junto.")

    # The first sample's cmd_info.json tells where the auxiliary files live.
    with open(os.path.join(sampleDirs[0], 'cmd_info.json')) as json_file:
        cmd_info = json.load(json_file)
    auxDir = cmd_info['auxDir'] if 'auxDir' in cmd_info else 'aux'

    # Make the network
    netFile = os.path.join(outdir, "mag.net")
    logging.info("Building multiple alignment graph")
    eqnet.buildNetFile(sampleDirs, netFile, cutoff, auxDir)

    # add orphans before filter step
    if orphan:
        logging.info("Adding information from orphans")
        orphanNetFile = os.path.join(outdir, "mag.orphan.net")
        eqnet.addOrphanLinks(sampleDirs, auxDir, '/orphan_links.txt', cutoff, netFile, orphanNetFile)
    else:
        orphanNetFile = netFile

    # Filter the graph
    logging.info("Filtering multiple alignment graph")
    filtNetFile = os.path.join(outdir, "mag.filt.net")
    eqnet.filterGraph(conditionDict, orphanNetFile, filtNetFile, auxDir, mincut)

    keys = None
    if labeling:
        # write the junto configuration file beforehand
        juntoConfigFile = os.path.join(outdir, "junto.config")
        keys = {
            "graph_file": os.path.join(outdir, "label.graph.txt"),
            "seed_file": os.path.join(outdir, "seed.txt"),
            "output_file": os.path.join(outdir, "tempOutput"),
        }

        with open(juntoConfigFile, "w") as configFile:
            for name in ("graph_file", "seed_file", "output_file"):
                configFile.write(name + " = " + keys[name] + "\n")
            configFile.write("data_format = edge_factored\n")
            configFile.write("iters = 1\n")
            configFile.write("prune_threshold = 0\n")
            configFile.write("algo = adsorption\n")

        # final labels file
        finalLabelFile = os.path.join(outdir, "seed.labels.txt")

        if runBLAST:
            logging.info("Running BLAST on the input FASTA files.")
            labelFiles = processLabels.runBLAST(faFiles[0], faFiles[1], outdir, threads)
        logging.info("Processing the BLAST output.")
        processLabels.genFinalLabels(conditionDict, keys, labelFiles, finalLabelFile, outdir)

        iterLabel.run(keys, finalLabelFile, juntoConfigFile, filtNetFile, outdir)

        # Keep the untrimmed graph under a new name; the trimmed graph is then
        # written to the original graph path.
        orgGraphFile = os.path.join(outdir, "raw.label.graph.txt")
        try:
            os.rename(keys["graph_file"], orgGraphFile)
        except OSError as err:
            _abort("Could not move %s to %s: %s" % (keys["graph_file"], orgGraphFile, err))
        logging.info("Trimming edges after label propagations.")
        filtGraph.filter(conditionDict, orgGraphFile, keys["graph_file"], auxDir)

        # Cluster the graph from Labeling Module
        labelClustFile = os.path.join(outdir, "label.mag.clust")
        logging.info("Clustering multiple alignment graph")
        _runMcl(keys["graph_file"], labelClustFile)

        # Flatten and summarize the file
        labelFlatClustFile = os.path.join(outdir, "label.mag.flat.clust")
        eqnet.flattenClusters(labelClustFile, labelFlatClustFile)
        _writeClusterStats(labelClustFile, os.path.join(outdir, "label.stats.json"))

    # Cluster the graph from Clustering Module
    clustFile = os.path.join(outdir, "mag.clust")
    logging.info("Clustering original mapping ambiguity graph")
    _runMcl(filtNetFile, clustFile)

    # Flatten and summarize the cluster file
    flatClustFile = os.path.join(outdir, "mag.flat.clust")
    eqnet.flattenClusters(clustFile, flatClustFile)
    _writeClusterStats(clustFile, os.path.join(outdir, "stats.json"))

    if labeling:
        _collectLabelOutputs(outdir, keys["seed_file"])


# Script entry point: when the file is executed directly, invoke the click
# command, which reads --config from the command line.
if __name__ == "__main__":
    processQuant()
