#!/usr/bin/env python3

# ------------------------------------
# python modules
# ------------------------------------

import os
import sys
import argparse as ap
import pandas as pd
import tempfile
import warnings

from PEALS.io.constant import *
# ------------------------------------
# own python modules
# ------------------------------------

# ------------------------------------
# Main function
# ------------------------------------
def main():
    """Command-line entry point of PEALS.

    Parses the command line, silences pandas ``DtypeWarning`` messages and
    hands the parsed arguments to the implementation of the selected
    subcommand ('callpeak' or 'diffpeak').
    """
    # Build the parser and read the command line in one go.
    args = prepare_argparser().parse_args()
    chosen = args.subcommand
    # Ignore pandas DtypeWarning for the rest of the run.
    warnings.simplefilter(action='ignore', category=pd.errors.DtypeWarning)
    # The subcommand modules are imported lazily, only for the branch taken.
    if chosen == "callpeak":
        from PEALS.subcmd.callpeak_cmd import run as run_callpeak
        run_callpeak( args )
    elif chosen == 'diffpeak':
        from PEALS.subcmd.diffpeak_cmd import run as run_diffpeak
        run_diffpeak( args )

def prepare_argparser ():
    """Build the top-level argument parser with all subcommands attached.

    The returned parser offers -v/--version and requires one subcommand,
    whose name is stored in ``args.subcommand``. New subcommands should be
    registered here.
    """
    top_parser = ap.ArgumentParser(
        description = "%(prog)s -- Peak-based Enhancement Analysis PipeLine for MeRIP-Seq",
        epilog = "For command line options of each command, type: %(prog)s COMMAND -h",
    )
    top_parser.add_argument("-v", "--version", dest = "version", action = "version",
                            version = "%(prog)s v" + VERSION)
    command_parsers = top_parser.add_subparsers( dest = 'subcommand' )
    # A subcommand must always be given on the command line.
    command_parsers.required = True

    # Register 'callpeak' first, then 'diffpeak'.
    for register in ( add_callpeak_parser, add_diffpeak_parser ):
        register( command_parsers )
    return top_parser

def add_input_option ( parser ):
    """Register the input-related options on *parser*.

    Adds -i/--input and -m/--matrix (both required), the --recursive flag
    and --type (one of 'mature', 'primary', 'mixture'; default 'mature').
    *parser* is anything exposing ``add_argument`` (a parser or a group).
    """
    # One entry per option: (option strings, keyword settings).
    option_table = (
        ( ("-i", "--input"),
          dict( dest = "inputdir", type = str, required = True,
                help = "Input directory contains all *.bam and *.bai files." ) ),
        ( ("-m", "--matrix"),
          dict( dest = "matrix", type = str, required = True,
                help = "The tab-separated matrix contain all information of input samples. Header format: label,library,condition,sample,replicate,bam,..." ) ),
        ( ("--recursive",),
          dict( dest = "recursive", action = "store_true", default = False,
                help = "Recursively walk through --input to find *.bam and *.bai files." ) ),
        ( ("--type",),
          dict( dest = "rnatype", action = "store",
                choices = ['mature', 'primary', 'mixture'],
                default = 'mature',
                help = "The type of RNA transcripts in input MeRIP library (mature: only contain mature transcripts; \
                                primary: only contain primary transcripts; mixture: contain both mature and primay transcripts)." ) ),
    )
    for flags, settings in option_table:
        parser.add_argument( *flags, **settings )

def add_output_option ( parser ):
    """Register the output-related options on *parser*.

    Adds -o/--output and -P/--prefix (both required), -T/--temp (no default
    is set here, so it is None when omitted), the --no-bwtrack and
    --keep-temp flags, and --verbose (0-3, default 2).
    """
    # One entry per option: (option strings, keyword settings).
    option_table = (
        ( ("-o", "--output"),
          dict( dest = "outputdir", type = str, required = True,
                help = "All output files will be written to this directory." ) ),
        ( ("-T", "--temp"),
          dict( dest = "tempdir", type = str,
                help = "If specified, all temporary files will be written to this directory. Default is the same with --output." ) ),
        ( ("-P", "--prefix"),
          dict( dest = "prefix", type = str, required = True,
                help = "The prefix of all output files" ) ),
        ( ("--no-bwtrack",),
          dict( dest = "nobwtrack", action = 'store_true', default = False,
                help = "If specified, will not output bigwig signal tracks which can be loaded in IGV." ) ),
        ( ("--keep-temp",),
          dict( dest = "keeptemp", action = 'store_true', default = False,
                help = "Keep the temporary files" ) ),
        ( ("--verbose",),
          dict( dest = "verbose", type = int,
                choices = [0, 1, 2, 3],
                default = 2,
                help = "Set verbose level of runtime message. 0: only show critical message, 1: show additional warning message, \
                                2: show process information, 3: show debug messages." ) ),
    )
    for flags, settings in option_table:
        parser.add_argument( *flags, **settings )

def add_library_option ( parser ):
    """Register the sequencing-library options on *parser*.

    Adds -l/--library (0, 1 or 2; default 2), --estlib-size (default
    'primary_mapped'), -e/--extsize (default taken from EXT_SIZE) and the
    --pairend flag.
    """
    # One entry per option: (option strings, keyword settings).
    option_table = (
        ( ("-l", "--library"),
          dict( dest = "library", action = "store", type = int,
                choices = [0, 1, 2],
                default = 2,
                help = "Library protocols of input bam files. For details, 0 (unstranded), 1 (stranded or fr-secondstrand) \
                                and 2 (reversely stranded or fr-firstrand)." ) ),
        ( ("--estlib-size",),
          dict( dest = "estlibsize", action = "store", type = str,
                choices = ["primary", "mapped", "primary_mapped"],
                default = "primary_mapped",
                help = "How to estimate the library size when calling peaks and scaled bigwig outputs." ) ),
        ( ("-e", "--extsize"),
          dict( dest = "extsize", type = int,
                default = EXT_SIZE,
                help = "Extend reads in 5\'->3\' direction to # bp (not recommend to be set when handling the paired-end data)." ) ),
        ( ("--pairend",),
          dict( dest = "pairend", action = 'store_true', default = False,
                help = "If speficied, all input bam files will be treated as PAIRED-END data." ) ),
    )
    for flags, settings in option_table:
        parser.add_argument( *flags, **settings )

def add_normalize_option ( parser ):
    """Register the normalization options on *parser*.

    Adds --exp-method (default 'count'), --estpeak-sizefactor (default
    'peak') and --bw-scale (default 'rpm'); each accepts a fixed set of
    choices.
    """
    # One entry per option: (option string, keyword settings).
    option_table = (
        ( "--exp-method",
          dict( dest = "expmethod", type = str,
                choices = ['count', 'TPM', 'FPKM'],
                default = 'count',
                help = "The method to estimate gene expression (mainly used for filter non-expressed genes)." ) ),
        ( "--estpeak-sizefactor",
          dict( dest = "peaksizefactor", type = str,
                choices = ['peak', 'gene'],
                default = 'peak',
                help = "How to normalize reads count of peak candidates for normalization before applying DESEq2 model." ) ),
        ( "--bw-scale",
          dict( dest = "bwscale", type = str,
                choices = ['raw', 'rpm', 'paired', 'whole'],
                default = 'rpm',
                help = "How to handle the signal track output (in bigwig format)." ) ),
    )
    for flag, settings in option_table:
        parser.add_argument( flag, **settings )

def add_bam_option ( parser ):
    """Register the BAM handling options on *parser*.

    Adds --scale-sample (default 'to_small'), the flags --sortbam,
    --no-fraction and --ignore-dup, and the float-valued --frac-overlap
    (default 0) and --exp-cutoff (default 1).
    """
    # One entry per option: (option string, keyword settings).
    option_table = (
        ( "--scale-sample",
          dict( dest = "scalesample", type = str,
                choices = ['to_small', 'to_large', 'raw'],
                default = "to_small",
                help = "The larger dataset will be scaled towards the smaller or larger dataset. \"to_small\" is recommended." ) ),
        ( "--sortbam",
          dict( dest = "sortbam", action = 'store_true', default = False,
                help = "Sort bams before passing to featureCounts" ) ),
        ( "--no-fraction",
          dict( dest = "nofraction", action = 'store_true', default = False,
                help = "Do not count the multiple-aligned reads with fraction." ) ),
        ( "--frac-overlap",
          dict( dest = "fracoverlap", type = float,
                default = 0,
                help = "Only overlap # fraction of reads larger than this value will count to peak." ) ),
        ( "--ignore-dup",
          dict( dest = "ignoredup", action = 'store_true', default = False,
                help = "If specified, will ignore PCR duplicate reads flagged by '1024'." ) ),
        ( "--exp-cutoff",
          dict( dest = "expcutoff", type = float,
                default = 1,
                help = "The cutoff for an expressed transcript. Only call peaks on expressed transcripts." ) ),
    )
    for flag, settings in option_table:
        parser.add_argument( flag, **settings )

def add_constant_option ( parser ):
    """Register the runtime-constant options on *parser*.

    Adds -t/--thread, -start/--thread-start, --reltol and --buffer. Their
    defaults come from the module-level constants THREAD,
    THREAD_START_METHOD, MATH_RELTOL and BUFFER_SIZE.
    """
    # One entry per option: (option strings, keyword settings).
    option_table = (
        ( ("-t", "--thread"),
          dict( dest = "thread", type = int,
                default = THREAD,
                help = "How many threads used for parallel computation. Please note that more threads will use more physical and virtual memories." ) ),
        ( ("-start", "--thread-start"),
          dict( dest = "threadstart", type = str,
                choices = ['forkserver', 'fork', 'spawn'],
                default = THREAD_START_METHOD,
                help = "Ways to start a multi-threaded process" ) ),
        ( ("--reltol",),
          dict( dest = "reltol", type = float,
                default = MATH_RELTOL,
                help = "The relative tolerance passed to 'rel_tol' parameter of math.isclose()." ) ),
        ( ("--buffer",),
          dict( dest = "buffer", type = int,
                default = BUFFER_SIZE,
                help = "The buffer size use to open files." ) ),
    )
    for flags, settings in option_table:
        parser.add_argument( *flags, **settings )

def add_anno_option ( parser ):
    """Register the genome / gene-annotation options on *parser*.

    Adds --gsize and --gff (both required), --gff-type (default 'GTF'),
    --gff-source (default 'GENCODE') and --identifier (a colon-separated
    default list of attribute names).
    """
    # One entry per option: (option string, keyword settings).
    option_table = (
        ( "--gsize",
          dict( dest = "gsize", type = str, required = True,
                help = "The file contain the sizes of chromosomes. Can be downloaded from UCSC genome browser (e.g. hg38.chrom.sizes)." ) ),
        ( "--gff",
          dict( dest = "gff", type = str, required = True,
                help = "The gene annotation file. The chromosomes naming should be consistent with --gsize and input bam files." ) ),
        ( "--gff-type",
          dict( dest = "gfftype", type = str,
                choices = ["GTF", "GFF3"],
                default = "GTF",
                help = "The format of input --gff." ) ),
        ( "--gff-source",
          dict( dest = "gffsource", type = str,
                choices = ["GENCODE", "ENSEMBL", "other"],
                default = "GENCODE",
                help = "The source of gene annotations." ) ),
        ( "--identifier",
          dict( dest = "identifier", type = str,
                default = 'gene_id:gene_name:gene_type:transcript_id:transcript_name:transcript_type',
                help = "The identifier used to extract transcript information from annotation file." ) ),
    )
    for flag, settings in option_table:
        parser.add_argument( flag, **settings )

def add_peak_option ( parser ):
    """Register the peak-detection options on *parser*.

    Adds --split, --span-method, --span-loop, --complexity-rate,
    --lookahead, --csaps-p, --peak-size, --ipratio and --center. All but
    --peak-size take their default from a module-level constant;
    --peak-size has no default here (None when omitted).
    """
    # One entry per option: (option string, keyword settings).
    option_table = (
        ( "--split",
          dict( dest = "txsizemax", type = int,
                default = TX_SPLIT_MAX_SIZE,
                help = "Split the transcript into subgroups by #bp long. (must be >= 5000)" ) ),
        ( "--span-method",
          dict( dest = "spanmethod", type = str,
                choices = ["move", "none"],
                default = MOVE_METHOD,
                help = "Method used to further smooth the coverage data on transcript after calling CSAPS model." ) ),
        ( "--span-loop",
          dict( dest = "spanloop", type = int,
                default = SMOOTH_LOOP,
                help = "How many times that apply the method indicated by '--span-method' after calling CSAPS model." ) ),
        ( "--complexity-rate",
          dict( dest = "comprate", type = float,
                default = COMPLEXITY_RATE,
                help = "If the complexity of signals in a detected region is lower than this value (eg., high PCR duplicates), then this region will no be considered as a peak candidates." ) ),
        ( "--lookahead",
          dict( dest = "lookahead", type = int,
                default = LOOK_AHEAD,
                help = "The distance to look ahead from a peak candidate to determine if it is the actual peak." ) ),
        ( "--csaps-p",
          dict( dest = "csapsp", type = float,
                default = CSAPS_SMOOTH,
                help = "The constant smooth parameter p (smooth) passed to csaps(). Smaller p means more smoothing." ) ),
        ( "--peak-size",
          dict( dest = "peaksize", type = int,
                help = "The minimun peak size. Default is the maximum value between 'half of the --extsize' and '{}'.".format(PEAK_SIZE) ) ),
        ( "--ipratio",
          dict( dest = "ipratio", type = float,
                default = IP_RATIO,
                help = "Define the candidate peaks within which regions have the ratio of ip/input >= --ipratio." ) ),
        ( "--center",
          dict( dest = "center", type = float,
                default = CENTER,
                help = "Center the candidate peak region by shearing points of which are less than '--center' of highest coverage. Should be the range [0, 0.5]." ) ),
    )
    for flag, settings in option_table:
        parser.add_argument( flag, **settings )

def add_model_option ( parser ):
    """Register the statistical-model options on *parser*.

    Adds --fit-type (default 'parametric'), --shrink (default 'apeglm'),
    --test (default 'Wald') and --formula (no default, None when omitted).
    """
    # One entry per option: (option string, keyword settings).
    option_table = (
        ( "--fit-type",
          dict( dest = "fittype", type = str,
                choices = ["parametric", "local", "mean", "glmGamPoi"],
                default = "parametric",
                help = "Type of fitting of dispersions to the mean intensity (passed to DESeq2)." ) ),
        ( "--shrink",
          dict( dest = "shrink", action = "store", type = str,
                choices = ['none', 'apeglm', 'ashr'],
                default = 'apeglm',
                help = 'The method used for shrinking the fold change (lfcShrink() in DESeq2)' ) ),
        ( "--test",
          dict( dest = "test", action = "store", type = str,
                choices = ['Wald', 'LRT'],
                default = "Wald",
                help = "The test method for calculating p-value (DESeq2)." ) ),
        ( "--formula",
          dict( dest = "formula", action = "store", type = str,
                help = "The design formula passed to DESeq2(). Will overwrite the default formula." ) ),
    )
    for flag, settings in option_table:
        parser.add_argument( flag, **settings )

def add_filter_option ( parser ):
    """Register the significance-filter options on *parser*.

    Adds -p/--pval (default 0.05), -q/--padj (default 0.05) and
    --padj-method (default 'BH').
    """
    # One entry per option: (option strings, keyword settings).
    option_table = (
        ( ("-p", "--pval"),
          dict( dest = "pvalcutoff", type = float,
                default = 0.05,
                help = "If specified, only peaks have p-value <= --pval will be recognized significant peaks." ) ),
        ( ("-q", "--padj"),
          dict( dest = "padjcutoff", type = float,
                default = 0.05,
                help = "If specified, only peaks have adjusted p-value (FDR) <= --padj will be recognized significant peaks." ) ),
        ( ("--padj-method",),
          dict( dest = "padjmethod", type = str,
                choices = ["holm", "hochberg", "hommel", "bonferroni", "BH", "BY", "fdr", "none"],
                default = 'BH',
                help = "The method used to calculate adjusted p-value" ) ),
    )
    for flags, settings in option_table:
        parser.add_argument( *flags, **settings )

def add_callpeak_parser( subparsers ):
    """Register the 'callpeak' subcommand and its option groups on *subparsers*.

    The groups are created in the order they should appear in the help
    output: input, output, library, bam, normalization, constants, genome,
    peak, model and filter.
    """
    callpeak = subparsers.add_parser("callpeak",
                                     formatter_class = ap.ArgumentDefaultsHelpFormatter,
                                     epilog = EPILOG_CALL_PEAK,
                                     help="Main PEALS Function: Call peaks from alignment results.")

    def new_group( title, *populators ):
        # Create a titled argument group and let each populator add its options.
        group = callpeak.add_argument_group( title )
        for populate in populators:
            populate( group )
        return group

    # input options, plus the callpeak-specific sample subset selector
    input_group = new_group( "Input options", add_input_option )
    input_group.add_argument("-s", "--sample", dest = "sample", type = str,
                        help = "If specified, will select samples in an additional column of sample matrix ('--matrix') used for testing \
                        by the following format: \"column_name:value\", in which the value should be unique for each condition, \
                        e.g., \"group:HeLa_control\".")
    # groups that only carry the shared options
    new_group( "Output options", add_output_option )
    new_group( "Library options", add_library_option )
    new_group( "Bam options", add_bam_option )
    new_group( "Normalization options", add_normalize_option )
    new_group( "Constants options", add_constant_option )
    new_group( "Genome options", add_anno_option )
    new_group( "Peak options", add_peak_option )
    new_group( "Model options", add_model_option )
    # filter options: the fold-change cutoff is listed before the shared ones
    filter_group = new_group( "Filter options" )
    filter_group.add_argument("-f", "--fold", dest = "foldcutoff", type = float,
                        default = 2,
                        help = "If specified, only peaks have fold change (IP/input) (>= fold) will be recognized significant peaks.")
    add_filter_option( filter_group )

def add_diffpeak_parser( subparsers ):
    """Add the 'diffpeak' (differential peak calling) argument parser.

    Registers the 'diffpeak' subcommand on *subparsers* and attaches its
    option groups. On top of the shared option sets it adds -s/--sample
    (control/treated selector), --estip-eff, -f/--fold, -F/--diff-fold,
    --diff-pval and --diff-padj.
    """
    argparser_diffpeak = subparsers.add_parser("diffpeak",
                                               formatter_class = ap.ArgumentDefaultsHelpFormatter,
                                               epilog = EPILOG_DIFF_PEAK,
                                               help="Main PEALS Function: Call differential peaks from alignment results.")
    ## group for input
    group_input = argparser_diffpeak.add_argument_group( "Input files options" )
    add_input_option( group_input )
    ## add subset sample option
    group_input.add_argument("-s", "--sample", dest = "sample", type = str,
                        help = "If specified, will select samples in an additional column of sample matrix ('--matrix') used for testing \
                        by the following format: \"column_name:control=value1,treated=value2\", \
                        in which the value should be unique for each condition, \
                        e.g., \"group:control=HeLa_control,treated=HeLa_treated\".")
    ## group for output
    group_output = argparser_diffpeak.add_argument_group( "Output files options" )
    add_output_option( group_output )
    ## group for library
    group_library = argparser_diffpeak.add_argument_group( "Sequencing library options" )
    add_library_option( group_library )
    ## group for bam options
    group_bam = argparser_diffpeak.add_argument_group( "Bam options" )
    add_bam_option( group_bam )
    ## group for normalization options
    group_normalize = argparser_diffpeak.add_argument_group( "Normalization options" )
    add_normalize_option ( group_normalize )
    group_normalize.add_argument("--estip-eff", dest = "estipeff", type = str,
                        choices = ['across', 'within'],
                        default = 'within',
                        help = "How estimate the IP efficiency factors of control and treated conditions. If 'within' is specified, it will estimate IP efficiency factors within each condition, repspectively. \
                        If 'across' is specified, will estimate IP efficiency factors across 2 conditions (may introduce over-estimation).")
    ## group for constant options
    group_constant = argparser_diffpeak.add_argument_group( "Constants options" )
    add_constant_option( group_constant)
    ## group for genome and gene annotation options
    group_anno = argparser_diffpeak.add_argument_group( "Genome and gene annotation options" )
    add_anno_option( group_anno )
    ## group for peak options
    group_peak = argparser_diffpeak.add_argument_group( "Peak detection options" )
    add_peak_option( group_peak )
    ## group for model options
    group_model = argparser_diffpeak.add_argument_group( "Model options" )
    add_model_option( group_model )
    ## group for filter options (-f is listed before the shared filter options)
    group_filter = argparser_diffpeak.add_argument_group( "Filter options" )
    group_filter.add_argument("-f", "--fold", dest = "foldcutoff", type = float,
                        default = 2,
                        help = "If specified, only peaks have fold change (IP/input) in either conditions (>= fold) will be recognized significant peaks.")
    add_filter_option( group_filter )
    ## differential cutoffs; their help texts name the --diff-* options themselves
    group_filter.add_argument("-F", "--diff-fold", dest = "difffoldcutoff", type = float,
                        default = 1,
                        help = "If specified, only peaks have fold change (IP/input) (>= fold or <= 1/fold) will be recognized significantly changed peaks.")
    group_filter.add_argument("--diff-pval", dest = "diffpvalcutoff", type = float,
                        default = 0.05,
                        help = "If specified, only peaks have p-value <= --diff-pval will be recognized significantly changed peaks.")
    group_filter.add_argument("--diff-padj", dest = "diffpadjcutoff", type = float,
                        default = 1,
                        help = "If specified, only peaks have adjusted p-value (FDR) <= --diff-padj will be recognized significantly changed peaks.")
    ##
    return

if __name__ == '__main__':
    # NOTE(review): presumably a workaround so that multiprocessing start
    # methods do not trip over the script's __spec__ -- confirm before removing.
    __spec__ = None
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is treated as a normal, user-requested stop (exit status 0).
        sys.stderr.write("User interrupted me! ;-) Bye!\n")
        sys.exit(0)
    except MemoryError:
        # sys.exit() with a string prints it to stderr and exits with status 1,
        # so no further exit call is needed (a second one would never run).
        sys.exit( "MemoryError occurred. Please try to run with less threads." )
