import os

import argparse

# Command-line interface of the merge script.  Keyword arguments that only
# restate argparse defaults (required=False, default=False for store_true
# flags) are left out; the resulting parser behaves the same.
parser = argparse.ArgumentParser(
    prog=os.path.basename(__file__),
    description="""Merge several epic count-matrixes into one.""")

# Input: one or more count matrixes (mandatory).
parser.add_argument(
    '--matrixes', '-m',
    nargs='+',
    type=str,
    required=True,
    help='''epic-count matrixes to merge.''')

# Optional bed file(s) used to restrict the merge to given regions.
parser.add_argument(
    '--regions', '-r',
    nargs='+',
    type=str,
    help='''Bed file(s) with regions to use. Does not work with --keep-nonenriched.''')

# Boolean switches, both off unless given on the command line.
parser.add_argument(
    '--keep-nonenriched', '-k',
    action='store_true',
    help='''Keep non-enriched bins also (takes much more time/mem). Not usable with --regions.''')

parser.add_argument(
    '--enriched-per-file', '-e',
    action='store_true',
    help='''Keep a column of enrichment info per matrix used.''')

# Output location (mandatory).
parser.add_argument(
    '--output', '-o',
    type=str,
    required=True,
    help='''Path to write gzipped output matrix.''')

# Degree of parallelism.
parser.add_argument(
    '--number-cores', '-cpu',
    type=int,
    default=1,
    help='''Number of cpus to use. Can use at most one per chromosome. Default: 1.''')


import sys
from subprocess import check_output
import logging

from math import gcd
from functools import reduce

import pandas as pd
from io import StringIO

from epic.config import logging_settings
from epic.merge.merge import merge_matrixes, read_dfs
from epic.merge.compute_bed_bins import merge_bed_bins, compute_bins


def compute_bin_size(dfs):
    """Infer the bin size shared by all matrixes.

    For each data frame, the bin size is estimated as the greatest common
    divisor of the values in the "Bin" index level of its first 100000 rows.

    Args:
        dfs: Mapping whose values are data frames with an index level
            named "Bin" holding integer bin positions.

    Returns:
        The common bin size.  This is 0 only when no matrix contains a
        non-zero bin, in which case no size can be inferred.

    Raises:
        AssertionError: If `dfs` is empty or the matrixes disagree on the
            bin size.
    """
    bin_sizes = []
    for df in dfs.values():
        bins = df.head(100000).index.get_level_values("Bin")
        # Seeding with 0 keeps the result unchanged (gcd(0, x) == x) and
        # makes an empty matrix yield 0 instead of a TypeError from reduce.
        bin_sizes.append(reduce(gcd, bins, 0))

    distinct_sizes = set(bin_sizes)
    if len(distinct_sizes) > 1:
        # A gcd of 0 means the matrix was empty or only contained bin 0, so
        # it says nothing about the bin size; let the other matrixes decide.
        distinct_sizes.discard(0)

    # Raised explicitly rather than via an assert statement so the check is
    # not stripped when Python runs with -O; the exception type is the same.
    if len(distinct_sizes) != 1:
        raise AssertionError(
            "Matrixes have different bin sizes: " + str(bin_sizes))

    bin_size = distinct_sizes.pop()
    logging.info("Bin size: " + str(bin_size))

    return bin_size



if __name__ == '__main__':
    # Script entry point: read the count matrixes, optionally restrict them
    # to the bins covered by the given bed regions, merge them and write the
    # result as a gzipped, space-separated file.

    args = parser.parse_args()
    files = args.matrixes
    nb_cpus = args.number_cores
    regions = args.regions
    keep_nonenriched = args.keep_nonenriched
    enriched_per_file = args.enriched_per_file
    output = args.output

    dfs = read_dfs(files)
    bin_size = compute_bin_size(dfs)

    if regions:
        if keep_nonenriched:
            # The flag is overridden below; say so instead of silently
            # ignoring what the user asked for.
            logging.warning("--keep-nonenriched is ignored when --regions is given.")

        # Only the first three bed columns are read; a raw string avoids the
        # invalid "\s" escape of a plain string literal.
        regions = [
            pd.read_table(
                path,
                sep=r"\s+",
                usecols=[0, 1, 2],
                header=None,
                names=["Chromosome", "Start", "End"])
            for path in regions
        ]
        # NOTE(review): compute_bins/merge_bed_bins are project helpers;
        # presumably they return frames indexed like the count matrixes -
        # confirm against epic.merge.compute_bed_bins.
        bins = [compute_bins(df, bin_size) for df in regions]
        r = merge_bed_bins(bins)

        # Index labels present both in each matrix and in the region bins.
        overlapping_indices = [df.index.intersection(r.index) for df in dfs.values()]
        # Label-based selection with .loc; DataFrame.ix was removed in
        # pandas 1.0.
        dfs = {
            k: df.loc[ix].join(r.loc[ix], how="outer")
            for ix, (k, df) in zip(overlapping_indices, dfs.items())
        }
        # Keep only the first row for each repeated index label.
        dfs = {k: df[~df.index.duplicated(keep="first")] for k, df in dfs.items()}
        keep_nonenriched = False

    merged_df = merge_matrixes(dfs, keep_nonenriched, enriched_per_file, nb_cpus)

    merged_df.to_csv(output, sep=" ", compression="gzip")
