import os

import argparse

# Command-line interface: every option of the matrix-merging script is
# registered on this module-level parser.
parser = argparse.ArgumentParser(
    prog=os.path.basename(__file__),
    description="Merge several epic count-matrixes into one.")

# One or more input matrixes (mandatory).
parser.add_argument(
    '--matrixes', '-m',
    nargs='+',
    type=str,
    required=True,
    help="epic-count matrixes to merge.")

# Optional bed file(s) restricting the merge to given regions.
parser.add_argument(
    '--regions', '-r',
    nargs='+',
    type=str,
    help="Bed file(s) with regions to use. Does not work with --keep-nonenriched.")

# Boolean switch; off unless given on the command line.
parser.add_argument(
    '--keep-nonenriched', '-k',
    action='store_true',
    help="Keep non-enriched bins also (takes much more time/mem). Not usable with --regions.")

# Boolean switch; off unless given on the command line.
parser.add_argument(
    '--enriched-per-file', '-e',
    action='store_true',
    help="Keep a column of enrichment info per matrix used.")

# Destination of the merged matrix (mandatory).
parser.add_argument(
    '--output', '-o',
    type=str,
    required=True,
    help="Path to write gzipped output matrix.")

# Degree of parallelism; stored as `number_cores`.
parser.add_argument(
    '--number-cores', '-cpu',
    type=int,
    default=1,
    help="Number of cpus to use. Can use at most one per chromosome. Default: 1.")


import sys
from subprocess import check_output
import logging

from math import gcd
from functools import reduce

import pandas as pd
from io import StringIO

from epic.config import logging_settings
from epic.merge.merge import merge_matrixes, read_dfs
from epic.merge.compute_bed_bins import merge_bed_bins, compute_bins


def compute_bin_size(dfs):
    """Infer the bin size shared by all count matrixes.

    For each matrix, the greatest common divisor of the first 100000 values
    of the "Bin" index level is taken as that matrix's bin size. Matrixes
    without any rows are skipped, since they put no constraint on the size.

    Args:
        dfs: Mapping of name -> DataFrame whose index has a level named
            "Bin" holding integer bin positions.

    Returns:
        The bin size common to all non-empty matrixes.

    Raises:
        AssertionError: If the matrixes yield different bin sizes, or if no
            matrix contains any bins to infer a size from.
    """

    bin_sizes = []
    for name, df in dfs.items():
        # Only a prefix of each matrix is inspected to keep this cheap.
        bins = df.head(100000).index.get_level_values("Bin")
        if bins.empty:
            # reduce() without an initial value raises a cryptic TypeError
            # on an empty sequence; an empty matrix is simply uninformative.
            logging.warning(
                "Matrix %s contains no bins; ignoring it when computing the bin size.",
                name)
            continue
        bin_sizes.append(reduce(gcd, bins))

    # Raised explicitly (not via `assert`) so the checks survive `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if not bin_sizes:
        raise AssertionError(
            "Could not compute bin size: no matrix contains any bins.")
    if len(set(bin_sizes)) != 1:
        raise AssertionError(
            "Matrixes have different bin sizes: " + str(bin_sizes))

    bin_size = bin_sizes.pop()
    logging.info("Bin size: %s", bin_size)

    return bin_size


def _remove_epic_enriched(dfs):

    new_dfs = {}
    for n, df in dfs.items():
        bad_cols = [c for c in df.columns if "Enriched_" in c]
        df = df.drop(bad_cols, axis=1)
        new_dfs[n] = df

    return new_dfs

if __name__ == '__main__':
    # Script entry point: read the matrixes, optionally restrict them to the
    # bins covered by the given bed regions, merge, and write gzipped output.

    args = parser.parse_args()
    files = args.matrixes
    nb_cpus = args.number_cores
    regions = args.regions
    keep_nonenriched = args.keep_nonenriched
    enriched_per_file = args.enriched_per_file
    output = args.output

    dfs = read_dfs(files)
    bin_size = compute_bin_size(dfs)

    if regions:
        # Drop epic's own "Enriched_" columns; presumably the region-derived
        # "Enriched_<bed file>" columns replace them (verify in compute_bins).
        dfs = _remove_epic_enriched(dfs)
        names = ["Enriched_" + os.path.basename(r) for r in regions]
        # Only the first three bed columns are read. The separator is a raw
        # string: "\s" in a plain literal is an invalid escape sequence.
        regions = [
            (name, pd.read_table(path, sep=r"\s+", usecols=[0, 1, 2], header=None,
                                 names=["Chromosome", "Start", "End"]))
            for name, path in zip(names, regions)]
        bins = [compute_bins(df, bin_size, name) for name, df in regions]
        region_bins = merge_bed_bins(bins)
        overlapping_indices = [df.index.intersection(region_bins.index) for df in dfs.values()]

        # `.ix` was removed in pandas 1.0; the keys here are index labels,
        # so `.loc` is the label-based equivalent.
        df_items = list(dfs.items())
        key, first_df = df_items[0]
        first_ix = overlapping_indices[0]
        # Only the first matrix is joined with the region columns; the other
        # matrixes are just restricted to their overlapping bins.
        new_dfs = {key: first_df.loc[first_ix].join(region_bins.loc[first_ix], how="inner")}
        for ix, (k, df) in zip(overlapping_indices[1:], df_items[1:]):
            new_dfs[k] = df.loc[ix]
        # Keep only the first row for any duplicated index entry.
        dfs = {k: df[~df.index.duplicated(keep="first")] for k, df in new_dfs.items()}
        # --regions overrides --keep-nonenriched (the two are documented as
        # incompatible in the option help).
        keep_nonenriched = False

    merged_df = merge_matrixes(dfs, keep_nonenriched, regions, enriched_per_file, nb_cpus)

    merged_df.to_csv(output, sep=" ", compression="gzip")
