#!/usr/bin/env python
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import sys
from Giraffe_View.function import *

def plot_estimate_acc(input_file, x_min, x_max, x_gap):
    """Replot the per-group density of estimated read accuracy.

    The table is loaded with ``process_in_chunks`` (imported from
    ``Giraffe_View.function``) and must provide the columns ``Accuracy``
    and ``Group``. ``Accuracy`` is multiplied by 100 before plotting
    (presumably it is stored as a fraction -- confirm against the giraffe
    output format).

    Args:
        input_file: Path handed to ``process_in_chunks``.
        x_min: Lower x-axis limit and first tick position.
        x_max: Upper x-axis limit; ticks never exceed this value.
        x_gap: Integer tick spacing, used as the ``range`` step.

    Side effects:
        Writes ``New_read_estimate_accuracy.svg`` to the current working
        directory and closes the figure afterwards.
    """
    df = process_in_chunks(input_file)
    df = pd.DataFrame(df)
    df["Accuracy"] = df["Accuracy"] * 100

    plt.figure(figsize=(8, 6))
    # common_norm=False: each group's density integrates to 1 on its own,
    # so groups of different sizes stay visually comparable.
    ax = sns.kdeplot(data=df, x="Accuracy", hue="Group", fill=True,
                     alpha=0.6, palette="Set2", common_norm=False)
    sns.move_legend(ax, "upper left")

    acc_scale = [x_min, x_max]
    # x_max + 1 so that x_max itself can be a tick when the step lands on it.
    acc_breaks = list(range(x_min, x_max + 1, x_gap))

    plt.xlabel("Estimated read accuracy (%)")
    plt.ylabel("Probability Density Function")
    plt.xlim(acc_scale)
    plt.xticks(acc_breaks)
    plt.tight_layout()
    plt.savefig("New_read_estimate_accuracy.svg", format="svg", dpi=300)
    plt.close()

def plot_observe_acc(input_file, x_min, x_max, x_gap):
    """Replot the per-group density of observed read accuracy.

    The table is loaded with ``process_in_chunks`` (imported from
    ``Giraffe_View.function``) and must provide the columns ``Acc`` and
    ``Group``. ``Acc`` is multiplied by 100 before plotting (presumably it
    is stored as a fraction -- confirm against the giraffe output format).

    Args:
        input_file: Path handed to ``process_in_chunks``.
        x_min: Lower x-axis limit and first tick position.
        x_max: Upper x-axis limit; ticks never exceed this value.
        x_gap: Integer tick spacing, used as the ``range`` step.

    Side effects:
        Writes ``New_read_observe_accuracy.svg`` to the current working
        directory and closes the figure afterwards.
    """
    df = process_in_chunks(input_file)
    df = pd.DataFrame(df)
    df["Acc"] = df["Acc"] * 100

    plt.figure(figsize=(8, 6))
    # common_norm=False: each group's density integrates to 1 on its own,
    # so groups of different sizes stay visually comparable.
    ax = sns.kdeplot(data=df, x="Acc", hue="Group", fill=True,
                     common_norm=False, alpha=0.6, palette="Set2")
    sns.move_legend(ax, "upper left")

    acc_scale = [x_min, x_max]
    # x_max + 1 so that x_max itself can be a tick when the step lands on it.
    acc_breaks = list(range(x_min, x_max + 1, x_gap))

    plt.xlabel("Observed read accuracy (%)")
    plt.ylabel("Probability Density Function")
    plt.xlim(acc_scale)
    plt.xticks(acc_breaks)
    plt.tight_layout()
    plt.savefig("New_read_observe_accuracy.svg", format="svg", dpi=300)
    plt.close()

def plot_observe_mismatch(input_file, y_max, y_gap):
    """Replot per-group box plots of insertion/deletion/substitution rates.

    The table is loaded with ``process_in_chunks`` (imported from
    ``Giraffe_View.function``) and must provide the columns ``Ins``,
    ``Del``, ``Sub``, ``Mat`` and ``Group``. Each error count is expressed
    as a percentage of ``Ins + Del + Sub + Mat`` for the same row.

    Args:
        input_file: Path handed to ``process_in_chunks``.
        y_max: Upper y-axis limit (the lower limit is fixed at 0).
        y_gap: Integer tick spacing, used as the ``range`` step.

    Side effects:
        Writes ``New_observed_mismatch_proportion.svg`` to the current
        working directory and closes the figure afterwards.
    """
    df = process_in_chunks(input_file)
    df = pd.DataFrame(df)

    # Shared denominator, computed once instead of once per error type.
    total = df["Ins"] + df["Del"] + df["Sub"] + df["Mat"]
    df["p_ins"] = 100 * df["Ins"] / total
    df["p_del"] = 100 * df["Del"] / total
    df["p_sub"] = 100 * df["Sub"] / total

    # Long format: one row per (Group, mismatch type) observation.
    df = df.melt(id_vars=["Group"], value_vars=["p_ins", "p_del", "p_sub"],
                 var_name="Mismatch Type", value_name="Mismatch Proportion")

    plt.figure(figsize=(8, 6))
    sns.boxplot(data=df, x="Mismatch Type", y="Mismatch Proportion", hue="Group",
                showfliers=False, width=0.5, saturation=0.6, palette="Set2")

    mis_scale = [0, y_max]
    # y_max + 1 so that y_max itself can be a tick when the step lands on it.
    mis_breaks = list(range(0, y_max + 1, y_gap))

    plt.ylabel("Mismatch proportion (%)")
    plt.ylim(mis_scale)
    plt.yticks(mis_breaks)
    # Label order mirrors the value_vars order passed to melt() above.
    plt.xticks(ticks=[0, 1, 2], labels=["Insertion", "Deletion", "Substitution"])
    plt.xlabel("")

    plt.legend(title='Group')
    plt.tight_layout()
    plt.savefig("New_observed_mismatch_proportion.svg", format="svg", dpi=300)
    plt.close()

def plot_observe_homo(input_file, y_min, y_max, y_gap):
    """Replot homopolymer identification accuracy per base and group.

    The table is loaded with ``process_in_chunks`` (imported from
    ``Giraffe_View.function``) and must provide the columns ``Base``,
    ``Accuracy`` and ``Group``. ``Accuracy`` is multiplied by 100 before
    plotting (presumably it is stored as a fraction -- confirm against the
    giraffe output format).

    Args:
        input_file: Path handed to ``process_in_chunks``.
        y_min: Lower y-axis limit and first tick position.
        y_max: Upper y-axis limit; ticks never exceed this value.
        y_gap: Integer tick spacing, used as the ``range`` step.

    Side effects:
        Writes ``New_homoploymer_summary.svg`` to the current working
        directory and closes the figure afterwards.
    """
    table = pd.DataFrame(process_in_chunks(input_file))
    table["Accuracy"] = table["Accuracy"].mul(100)

    # Arguments common to the line layer and the point layer.
    shared = dict(data=table, x='Base', y='Accuracy', hue='Group', palette="Set2")

    plt.figure(figsize=(8, 6))
    # Lines first, points on top; only the point layer feeds the legend.
    sns.lineplot(linewidth=1.5, markers=True, dashes=False,
                 alpha=0.6, legend=False, **shared)
    sns.scatterplot(s=50, edgecolor="black", **shared)

    # y_max + 1 so that y_max itself can be a tick when the step lands on it.
    tick_positions = list(range(y_min, y_max + 1, y_gap))

    plt.ylim(y_min, y_max)
    plt.yticks(tick_positions)
    plt.ylabel('Accuracy of homopolymer identification (%)')
    plt.xlabel('Base')

    plt.legend(title='Group')
    plt.tight_layout()
    plt.savefig("New_homoploymer_summary.svg", format="svg", dpi=300, bbox_inches='tight')
    plt.close()

def plot_GC_bias(input_file, x_min, x_max, x_gap):
    """Replot normalized sequencing depth against GC content per group.

    The table is loaded with ``process_in_chunks`` (imported from
    ``Giraffe_View.function``) and must provide the columns
    ``GC_content``, ``Normalized_depth`` and ``Group``.

    Args:
        input_file: Path handed to ``process_in_chunks``.
        x_min: Lower x-axis limit and first tick position.
        x_max: Upper x-axis limit; ticks never exceed this value.
        x_gap: Integer tick spacing, used as the ``range`` step.

    Side effects:
        Writes ``New_relationship_normalization.svg`` to the current
        working directory and closes the figure afterwards.
    """
    table = pd.DataFrame(process_in_chunks(input_file))

    # Arguments common to the line layer and the point layer.
    shared = dict(data=table, x="GC_content", y="Normalized_depth",
                  hue="Group", palette="Set2")

    plt.figure(figsize=(8, 5))
    # Only the line layer feeds the legend; the points are drawn on top.
    sns.lineplot(linewidth=1.5, alpha=0.6, **shared)
    sns.scatterplot(edgecolor="black", s=20, legend=False, **shared)

    # Dotted reference line at a normalized depth of 1.
    plt.axhline(1, color="grey", linestyle="dotted")

    # The y-axis is fixed to 0..2 with ticks every 0.2; range() only accepts
    # integer steps, so the tick values are scaled from integer indices.
    plt.ylim(0, 2)
    depth_ticks = [step * 0.2 for step in range(11)]
    plt.yticks(depth_ticks)

    # x_max + 1 so that x_max itself can be a tick when the step lands on it.
    gc_ticks = list(range(x_min, x_max + 1, x_gap))
    plt.xlim([x_min, x_max])
    plt.xticks(gc_ticks)

    plt.xlabel("GC content (%)")
    plt.ylabel("Normalized depth")
    plt.grid(False)
    plt.tight_layout()
    plt.savefig("New_relationship_normalization.svg", format="svg", dpi=300)
    plt.close()

if __name__ == '__main__':
    # Command-line entry point: one subcommand per replot function. Every
    # subcommand takes the giraffe result table plus integer axis settings.
    version = "0.2.2"
    parser = argparse.ArgumentParser(description="",
        usage="\n   # Users can replot the figures by rescaling the regions along the x-axis or y-axis.\n"
              "\n   %(prog)s estimate_acc --input Estimated_information.txt --x_min 50 --x_max 100 --x_gap 10    # For estimated read accuracy!"
              "\n   %(prog)s observe_acc --input Observed_information.txt --x_min 50 --x_max 100 --x_gap 10      # For observed read accuracy!"
              "\n   %(prog)s observe_mismatch --input Observed_information.txt --y_max 5 --y_gap 1               # For mismatch proportion!"
              "\n   %(prog)s observe_homo --input Homoploymer_summary.txt --y_min 90 --y_max 100 --y_gap 2       # For homopolymer accuracy!"
              "\n   %(prog)s gcbias --input Relationship_normalization.txt --x_min 20 --x_max 50 --x_gap 2       # For relationship between normalized depth and GC content!"
              "\n\nversion: " + str(version) + "\n"
              "For more details, please refer to the documentation: https://giraffe-documentation.readthedocs.io/en/latest.")

    subparsers = parser.add_subparsers(dest='function', help=None, description=None, prog="giraffe", metavar="  subcommand and function")

    plot_estimate_acc_parser = subparsers.add_parser('estimate_acc', help='Replot estimated read accuracy')
    plot_estimate_acc_parser.add_argument("--input", type=str, metavar="", required=True, help="the result generated from giraffe (Estimated_information.txt)")
    plot_estimate_acc_parser.add_argument("--x_min", type=int, metavar="", required=True, help="the smallest cutoff for estimated read accuracy")
    plot_estimate_acc_parser.add_argument("--x_max", type=int, metavar="", required=True, help="the largest cutoff for estimated read accuracy")
    plot_estimate_acc_parser.add_argument("--x_gap", type=int, metavar="", required=True, help="the interval between two values on an x-axis")

    plot_observe_acc_parser = subparsers.add_parser('observe_acc', help='Replot observed read accuracy')
    plot_observe_acc_parser.add_argument("--input", type=str, metavar="", required=True, help="the result generated from giraffe (Observed_information.txt)")
    plot_observe_acc_parser.add_argument("--x_min", type=int, metavar="", required=True, help="the smallest cutoff for observed read accuracy")
    plot_observe_acc_parser.add_argument("--x_max", type=int, metavar="", required=True, help="the largest cutoff for observed read accuracy")
    plot_observe_acc_parser.add_argument("--x_gap", type=int, metavar="", required=True, help="the interval between two values on an x-axis")

    plot_observe_mismatch_parser = subparsers.add_parser('observe_mismatch', help='Replot observed mismatch proportion')
    plot_observe_mismatch_parser.add_argument("--input", type=str, metavar="", required=True, help="the result generated from giraffe (Observed_information.txt)")
    plot_observe_mismatch_parser.add_argument("--y_max", type=int, metavar="", required=True, help="the largest cutoff for mismatch proportion")
    plot_observe_mismatch_parser.add_argument("--y_gap", type=int, metavar="", required=True, help="the interval between two values on a y-axis")

    # Help text fixed: this subcommand replots homopolymer accuracy, not read accuracy.
    plot_observe_homo_parser = subparsers.add_parser('observe_homo', help='Replot homopolymer accuracy')
    plot_observe_homo_parser.add_argument("--input", type=str, metavar="", required=True, help="the result generated from giraffe (Homoploymer_summary.txt)")
    plot_observe_homo_parser.add_argument("--y_min", type=int, metavar="", required=True, help="the smallest cutoff for homopolymer accuracy")
    plot_observe_homo_parser.add_argument("--y_max", type=int, metavar="", required=True, help="the largest cutoff for homopolymer accuracy")
    plot_observe_homo_parser.add_argument("--y_gap", type=int, metavar="", required=True, help="the interval between two values on a y-axis")

    # Help text fixed: this subcommand replots the depth/GC relationship, not read accuracy.
    plot_GC_bias_parser = subparsers.add_parser('gcbias', help='Replot relationship between normalized depth and GC content')
    plot_GC_bias_parser.add_argument("--input", type=str, metavar="", required=True, help="the result generated from giraffe (Relationship_normalization.txt)")
    plot_GC_bias_parser.add_argument("--x_min", type=int, metavar="", required=True, help="the smallest cutoff for GC content")
    plot_GC_bias_parser.add_argument("--x_max", type=int, metavar="", required=True, help="the largest cutoff for GC content")
    plot_GC_bias_parser.add_argument("--x_gap", type=int, metavar="", required=True, help="the interval between two values on an x-axis")

    # Subcommand name -> its parser, used to show the matching help page.
    subcommand_parsers = {
        "estimate_acc": plot_estimate_acc_parser,
        "observe_acc": plot_observe_acc_parser,
        "observe_mismatch": plot_observe_mismatch_parser,
        "observe_homo": plot_observe_homo_parser,
        "gcbias": plot_GC_bias_parser,
    }

    # These checks have to run BEFORE parse_args(): every subcommand option
    # is required, so parse_args() would exit with a "required arguments"
    # error on a bare subcommand and the help page would never be shown.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    if len(sys.argv) == 2 and sys.argv[1] in subcommand_parsers:
        subcommand_parsers[sys.argv[1]].print_help(sys.stderr)
        sys.exit(1)

    args = parser.parse_args()

    if args.function == "estimate_acc":
        plot_estimate_acc(args.input, args.x_min, args.x_max, args.x_gap)

    elif args.function == "observe_acc":
        plot_observe_acc(args.input, args.x_min, args.x_max, args.x_gap)

    elif args.function == "observe_mismatch":
        plot_observe_mismatch(args.input, args.y_max, args.y_gap)

    elif args.function == "observe_homo":
        plot_observe_homo(args.input, args.y_min, args.y_max, args.y_gap)

    elif args.function == "gcbias":
        plot_GC_bias(args.input, args.x_min, args.x_max, args.x_gap)


