#!/usr/bin/env python

"""
epic-overlap
Plot nb. and nucleotides of overlaps between enriched regions in different conditions

(Visit github.com/endrebak/epic for examples and help.)

Usage:
    epic-overlap [--cpu=CPU] [--outfile=OUT] [--matrix=MAT] [--title=TIT] [--ylabs=YLB] [--xlabs=XLB] [--fill-labs=FLB] [--nucleotides-heatmap|--regions-heatmap|--nucleotides-bargraph|--regions-bargraph] FILE...
    epic-overlap --help

Arguments:
    FILE              files of regions

Options:
    --nucleotides-heatmap   heatmap of the number of nucleotides overlapping in each file
    --regions-heatmap       heatmap of the number of regions overlapping (by at least one nucleotide) in each file
    --nucleotides-bargraph  stacked bargraph of the number of other files each nucleotide is overlapping
    --regions-bargraph      stacked bargraph of the number of other files each region is overlapping
    -o OUT --outfile OUT    name of output pdf file
    -m MAT --matrix MAT     write the matrix (dataframe) used for graphing to MAT
    -t TTL --title TTL      title of graph  [default: epic-overlaps]
    -y YLB --ylabs YLB      y-axis text
    -x XLB --xlabs XLB      x-axis text
    -f FLB --fill-labs FLB  color fill name
    -c CPU --cpu CPU        nb. cpus to use, can use at most one per input file  [default: 1]
    -h --help               show this help message
"""

import logging
from os.path import basename, dirname
from subprocess import check_output

import pandas as pd
from io import StringIO
from docopt import docopt
from natsort import natsorted, index_natsorted, order_by_index
from joblib import Parallel, delayed


import warnings
# Silence UserWarnings coming from rpy2. The filter is installed before the
# rpy2 imports below so it is already in place when those modules load.
warnings.filterwarnings("ignore",
                        category=UserWarning,
                        module="rpy2")  # rpy2 is noisy

from rpy2.robjects import r
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr
# Activate rpy2's pandas conversion layer (pandas2ri).
pandas2ri.activate()

# Parse the command line against the usage text in the module docstring.
# NOTE(review): this runs at module import time, not only under __main__.
args = docopt(__doc__)

# NOTE(review): logging_settings is not referenced in the visible part of
# this file; presumably it is imported for its side effects -- confirm.
from epic.config import logging_settings

# Load the ggplot2 R package into the embedded R session; the R snippets
# built by the plotting functions call ggplot() and friends.
importr("ggplot2")

def labs_str(args):
    """Build the ggplot2 ``labs(...)`` suffix from the label options.

    Args:
        args: mapping holding the keys ``--ylabs``, ``--xlabs`` and
            ``--fill-labs``; falsy values are skipped.

    Returns:
        A string of the form ``"+ labs(y='..', x='..', fill='..')"``
        containing only the labels that were set, or ``""`` when none
        of them is set.
    """
    # (option name, ggplot2 labs() argument), in output order.
    option_to_aesthetic = (
        ("--ylabs", "y"),
        ("--xlabs", "x"),
        ("--fill-labs", "fill"),
    )

    pieces = [aesthetic + "='" + args[option] + "'"
              for option, aesthetic in option_to_aesthetic
              if args[option]]

    if not pieces:
        return ""

    return "+ labs(" + ", ".join(pieces) + ")"

def heatmap(df, outfile, args):
    """Draw a heatmap of overlaps with ggplot2 and write it to a PDF.

    Args:
        df: pandas DataFrame with the columns ``Main``, ``Other`` and
            ``Overlaps``. ``Other`` is plotted on the x-axis, ``Main`` on
            the y-axis and ``Overlaps`` is the tile fill.
        outfile: path of the PDF file to write.
        args: option mapping; ``--title`` is used as the plot title and
            ``--ylabs``, ``--xlabs`` and ``--fill-labs`` are turned into
            a ``labs(...)`` call by ``labs_str``.
    """
    logging.debug(df.head())

    # Read the title from the arguments that were passed in rather than from
    # a module global that only exists when the file is run as a script.
    title = args["--title"]

    # Order the rows naturally by (Main, Other) and use the natural order of
    # the Main values as the factor levels of both axes.
    row_order = index_natsorted(zip(df.Main, df.Other))
    df = df.reindex(index=order_by_index(df.index, row_order))
    factor_order = natsorted(list(df.Main.drop_duplicates()))
    factors = '"' + '", "'.join(factor_order) + '"'

    labs = labs_str(args)

    f = """function(df) {{

    df$Main = factor(as.character(df$Main), levels=c({factors}), ordered=TRUE)
    df$Other = factor(df$Other,levels=c({factors}), ordered=TRUE)
    df$Overlaps = as.numeric(df$Overlaps)
    p = ggplot(df, aes(x=Other, y=Main, fill=Overlaps)) + geom_tile() +  scale_fill_gradient(low="white", high="#832424FF") + ggtitle("{title}") {labs}

    p = p + theme(axis.text.x = element_text(angle = 45, hjust = 1))

    print(p)
    }}
    """.format(factors=factors, title=title, labs=labs)
    logging.debug(f)

    plot_df_nb_regions = r(f)

    r("pdf('{}')".format(outfile))
    try:
        logging.info("Writing heatmaps to file: " + outfile)

        r_dataframe = pandas2ri.py2ri(df)
        plot_df_nb_regions(r_dataframe)
    finally:
        # Always close the R graphics device, even when plotting fails, so
        # the PDF is not left open.
        r("dev.off()")


def bargraph(df, title, outfile):
    """Draw a stacked bargraph of overlaps with ggplot2 and write it to a PDF.

    Args:
        df: pandas DataFrame with the columns ``Main``, ``Other`` and
            ``Overlaps``. ``Main`` is plotted on the x-axis, ``Overlaps``
            is the bar height and ``Other`` is the fill colour.
        title: title of the plot.
        outfile: path of the PDF file to write.

    The axis and legend labels are fixed in the R code below.
    """
    # Order the rows naturally by (Main, Other) and use the natural order of
    # the Main values as the factor levels of the x-axis.
    row_order = index_natsorted(zip(df.Main, df.Other))
    df = df.reindex(index=order_by_index(df.index, row_order))
    factor_order = natsorted(list(df.Main.drop_duplicates()))
    factors = '"' + '", "'.join(factor_order) + '"'

    f = """function(df) {{

    df$Main = factor(as.character(df$Main), levels=c({factors}), ordered=TRUE)
    df$Other = factor(df$Other, ordered=TRUE)
    p = ggplot(df, aes(x=Main, y=Overlaps, fill=Other)) + geom_bar(stat="identity") + ggtitle("{title}") + scale_fill_brewer(palette="Set3") + labs(x="Timepoints", y="Number of regions", fill="Overlaps")
    p = p + theme(axis.text.x = element_text(angle = 45, hjust = 1))

    print(p)
    }}
    """.format(factors=factors, title=title)
    logging.debug(f)

    plot_df_nb_regions = r(f)

    r("pdf('{}')".format(outfile))
    try:
        logging.info("Writing bargraphs to file: " + outfile)

        r_dataframe = pandas2ri.py2ri(df)
        plot_df_nb_regions(r_dataframe)
    finally:
        # Always close the R graphics device, even when plotting fails, so
        # the PDF is not left open.
        r("dev.off()")


def normalize_heatmaps(overlap_column):
    """Min-max scale *overlap_column*.

    Subtracts the column minimum from every value and divides by the
    spread ``max - min``, so the smallest value maps to 0 and the largest
    to 1.
    """
    lowest = overlap_column.min()
    spread = overlap_column.max() - lowest
    return (overlap_column - lowest) / spread


if __name__ == '__main__':
    # Script entry point: build the overlap dataframe for the requested plot
    # type, draw it, and optionally write the dataframe to a file.

    all_files = args["FILE"]
    outfile = args["--outfile"]
    title = args["--title"]
    nb_cpu = int(args["--cpu"])

    # The project modules are imported inside each branch so that only the
    # code for the selected plot type is loaded.
    if args["--regions-bargraph"]:
        from epic.scripts.overlaps.overlaps import overlap_matrix_region_counts
        df = overlap_matrix_region_counts(all_files, nb_cpu)

        bargraph(df, title, outfile)

    elif args["--nucleotides-bargraph"]:
        from epic.scripts.overlaps.nucleotide_bargraph import overlap_matrix_nucleotides
        df = overlap_matrix_nucleotides(all_files, nb_cpu)

        bargraph(df, title, outfile)

    elif args["--regions-heatmap"]:
        from epic.scripts.overlaps.overlaps import overlap_matrix_regions
        df = overlap_matrix_regions(all_files, nb_cpu)

        heatmap(df, outfile, args)

    elif args["--nucleotides-heatmap"]:
        from epic.scripts.overlaps.nucleotides_heatmap import nucleotide_overlaps_per_file

        df = nucleotide_overlaps_per_file(all_files, nb_cpu)

        # Sum over the Chromosome rows to get one row per (Main, Other) pair.
        df = df.set_index(["Chromosome"]).groupby("Main Other".split()).sum().reset_index()

        heatmap(df, outfile, args)

    else:
        # Without a plot type there is no dataframe to draw or to write, so
        # fail loudly instead of exiting silently (or hitting a NameError on
        # `df` below when --matrix is given).
        raise SystemExit(
            "epic-overlap: no plot type given; choose one of "
            "--nucleotides-heatmap, --regions-heatmap, "
            "--nucleotides-bargraph or --regions-bargraph.")

    if args["--matrix"]:
        # `index=False` leaves the row index out of the file. The previous
        # `index_col` keyword belongs to read_csv and is rejected by to_csv.
        df.to_csv(args["--matrix"], sep=" ", index=False)
