#!/usr/bin/env python

"""
epic-merge
Merge several epic-matrixes into one.

(Visit github.com/endrebak/epic for examples and help.)

Usage:
    epic-merge  [--nb-cpu=CPU] [--enriched-per-file] [--even-nonenriched] --output=OUT FILE...
    epic-merge  --help

Arguments:
    FILE                    input matrixes
    -o=OUT --output=OUT     gzipped result matrix

Options:
    -h --help               show this help message
    -e --even-nonenriched   even include non-enriched bins in output matrix (much slower and more mem used)
    -f --enriched-per-file  include column about bin-enrichment per file
    -n=CPU --nb-cpu=CPU     number of CPUs to use (at most one per chromosome)  [default: 1]
"""

import sys
import logging
from os.path import basename, dirname
from subprocess import check_output

from collections import defaultdict, OrderedDict
import pandas as pd
from io import StringIO
from docopt import docopt
from natsort import natsorted
from joblib import Parallel, delayed

# Parse the command line against the usage text in the module docstring.
# This runs at module level, i.e. as soon as the file is executed or imported,
# before any of the functions below are called.
args = docopt(__doc__)

# NOTE(review): `logging_settings` is not referenced in the visible part of
# this file; presumably it is imported for a side effect (configuring
# logging) -- confirm before removing or moving this import.
from epic.config import logging_settings


def read_dfs(files):
    """Load every matrix file into a DataFrame.

    Each file is parsed as a space-separated table with a header row; its
    first two columns become the (two-level) row index.  The first remaining
    column is renamed to ``"Enriched_" + basename(path)`` so that it can be
    told apart from the same column of the other files later on.

    Args:
        files: iterable of paths to the matrix files.

    Returns:
        OrderedDict mapping each path, in the order given, to its DataFrame.
    """
    loaded = OrderedDict()

    for path in files:
        table = pd.read_table(path, header=0, sep=" ", index_col=[0, 1])

        # Only the first data column is renamed; the others keep their names.
        file_nick = "Enriched_" + basename(path)
        new_names = list(table.columns)
        new_names[0] = file_nick
        table.columns = new_names

        logging.info("Calling " + path + " " + file_nick + " in matrix file.")

        loaded[path] = table

    return loaded


def enriched_indexes(dfs):
    """Return the union of the row labels flagged as enriched in any file.

    For every DataFrame, the first column whose name contains ``"Enriched_"``
    is inspected and the index labels of the rows where it equals 1 are
    collected.

    Args:
        dfs: mapping of file name -> DataFrame (must contain at least one
            entry; an empty mapping raises ``StopIteration``).

    Returns:
        pandas Index holding every row label that is enriched in at least
        one of the DataFrames.
    """

    def _enriched_rows(frame):
        # Restrict to the "Enriched_" column(s) and use the first of them.
        flags = frame.filter(like="Enriched_")
        flag_column = flags.columns[0]
        return flags.loc[flags[flag_column] == 1].index

    frames = iter(dfs.values())

    # Seed the result with the first file, then fold the others into it.
    enriched = _enriched_rows(next(frames))
    for frame in frames:
        enriched = enriched.union(_enriched_rows(frame))

    return enriched


def remove_nonenriched(enriched, dfs):
    """Drop every row whose index label is not in ``enriched``.

    Args:
        enriched: pandas Index (or other collection of row labels) to keep.
        dfs: mapping of file name -> DataFrame.

    Returns:
        OrderedDict with the same keys, in the same order, as ``dfs``; each
        DataFrame contains only the rows whose label appears in ``enriched``.
    """
    new_dfs = OrderedDict()
    for f, df in dfs.items():
        # `DataFrame.ix` was removed in pandas 1.0.  A boolean mask with
        # `.loc` selects the same rows and also copes with an empty match.
        new_dfs[f] = df.loc[df.index.isin(enriched)]

    return new_dfs


def all_chromosomes(dfs):
    """Collect the chromosome names that occur in any of the DataFrames.

    Args:
        dfs: mapping of file name -> DataFrame whose index has a level
            called ``"Chromosome"``.

    Returns:
        set of the distinct values of the ``"Chromosome"`` index level
        across all DataFrames (empty set when ``dfs`` is empty).
    """
    per_file = (
        frame.index.get_level_values("Chromosome").drop_duplicates()
        for frame in dfs.values()
    )
    return set().union(*per_file)


def split_dfs_into_chromosome_dfs(dfs, chromosomes):
    """Split every DataFrame into one piece per chromosome.

    Args:
        dfs: mapping of file name -> DataFrame whose index has a level
            called ``"Chromosome"``.
        chromosomes: iterable of chromosome names to split on.

    Returns:
        defaultdict(list) mapping each chromosome to a list holding one
        DataFrame per input file, in the iteration order of ``dfs``.  A file
        that lacks the chromosome contributes an empty DataFrame carrying
        that file's columns.
    """
    per_chromosome = defaultdict(list)

    for frame in dfs.values():
        for name in chromosomes:
            bucket = per_chromosome[name]
            try:
                piece = frame.xs(name, level="Chromosome", drop_level=False)
            except KeyError:
                # Chromosome absent from this file: use an empty placeholder
                # so every file still contributes its columns, in order.
                piece = pd.DataFrame(columns=frame.columns)
            bucket.append(piece)

    return per_chromosome

def _merge_dfs(dfs):
    """Join the DataFrames side by side and replace missing cells with 0."""
    side_by_side = pd.concat(dfs, axis=1)
    return side_by_side.fillna(0)

def merge_dfs(chromosome_dfs, nb_cpu):
    """Merge the per-chromosome DataFrame lists into one sorted DataFrame.

    Args:
        chromosome_dfs: mapping of chromosome -> list of DataFrames (one per
            input file).
        nb_cpu: number of joblib jobs used for the per-chromosome merges.

    Returns:
        A single DataFrame: the files are joined column-wise within each
        chromosome, the chromosomes are stacked row-wise, and the rows are
        put into natural sort order of their index labels.
    """
    # One column-wise merge per chromosome, distributed over the workers.
    jobs = (delayed(_merge_dfs)(frames) for frames in chromosome_dfs.values())
    per_chromosome = Parallel(n_jobs=nb_cpu)(jobs)

    # Stack the chromosome results on top of each other.
    stacked = pd.concat(per_chromosome)

    # Row order is not meaningful at this point; restore natural ordering.
    natural_order = natsorted(stacked.index)
    return stacked.reindex(index=natural_order)


if __name__ == '__main__':

    # Load every input matrix, keyed by the path given on the command line.
    dfs = read_dfs(args["FILE"])

    # By default only bins enriched in at least one file are kept.
    if not args["--even-nonenriched"]:
        dfs = remove_nonenriched(enriched_indexes(dfs), dfs)

    # Merge chromosome by chromosome so the work can be spread over CPUs.
    chromosome_dfs = split_dfs_into_chromosome_dfs(dfs, all_chromosomes(dfs))
    merged_df = merge_dfs(chromosome_dfs, int(args["--nb-cpu"]))

    # The per-file flag columns must be listed before "TotalEnriched" is
    # inserted; that name itself does not contain "Enriched_".
    per_file_flags = [col for col in merged_df if "Enriched_" in col]
    merged_df.insert(
        0, "TotalEnriched", merged_df.filter(like="Enriched_").sum(axis=1))
    merged_df = merged_df.set_index("TotalEnriched", append=True)

    # Either expose the per-file flags as extra index levels or drop them.
    if args["--enriched-per-file"]:
        merged_df = merged_df.set_index(per_file_flags, append=True)
    else:
        merged_df = merged_df.drop(per_file_flags, axis=1)

    merged_df.to_csv(args["--output"], sep=" ", compression="gzip")
