#!/usr/bin/env python

# Copyright (C) 2013 Ian W. Harry, Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Program for running a multi-detector workflow analysis through coincidence and
then generating post-processing and plots.
"""
import pycbc, pycbc.version, pycbc.events, pycbc.workflow as wf
import os, argparse, ConfigParser, logging, glue.segments


# Build the command-line interface. The module docstring (minus its leading
# newline) doubles as the program description shown by --help.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--workflow-name', default='my_unamed_run',
                    help="Name given to the top-level container workflow.")
# The output directory is created and entered unconditionally right after
# argument parsing, so running without it can never succeed. Requiring it here
# turns the resulting TypeError (os.path.exists(None)) into a clear usage
# error from argparse.
parser.add_argument("-d", "--output-dir", default=None, required=True,
                    help="Path to output directory.")
# The workflow module adds its own command-line group to the same parser.
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

# Log at INFO level, prefixing each message with a timestamp and its level.
log_format = '%(asctime)s:%(levelname)s : %(message)s'
logging.basicConfig(format=log_format, level=logging.INFO)

# Three workflow objects are built from the same parsed arguments:
#   container         - named after --workflow-name; the other two are added
#                       to it at the end of the script
#   workflow          - 'main', receives all the analysis and plotting jobs
#   finalize_workflow - 'finalization', receives the results web page job
container = wf.Workflow(args, args.workflow_name)
workflow = wf.Workflow(args, 'main')
finalize_workflow = wf.Workflow(args, 'finalization')

# Create the output directory if needed and make it the working directory;
# the relative paths used below ('plots/...', 'segments/...') resolve
# against it.
run_dir = args.output_dir
if not os.path.exists(run_dir):
    os.makedirs(run_dir)
os.chdir(run_dir)

# Get segments and find where the data is. setup_datafind_workflow returns
# its own science_segs, which replaces the one from the segment query.
science_segs, science_seg_file = wf.get_analyzable_segments(
    workflow, "segments")
datafind_files, science_segs = wf.setup_datafind_workflow(
    workflow, science_segs, "datafind", science_seg_file)

# Cumulative veto files for the 'segments-veto-groups' and
# 'segments-final-veto-group' names.
cum_veto_files, ind_cats = wf.get_cumulative_veto_group_files(
    workflow, 'segments-veto-groups', "segments")
# NOTE: this second call rebinds ind_cats, so the veto segment plots made
# later only see the value returned for the final veto group.
final_veto_file, ind_cats = wf.get_cumulative_veto_group_files(
    workflow, 'segments-final-veto-group', "segments")

# Template bank stuff: set up the bank jobs, then derive an HDF copy and the
# split-bank files from the same bank_files. All three write under "bank".
bank_files = wf.setup_tmpltbank_workflow(
    workflow, science_segs, datafind_files, "bank")
hdfbank = wf.convert_bank_to_hdf(workflow, bank_files, "bank")
splitbank_files = wf.setup_splittable_workflow(workflow, bank_files, "bank")


# setup the injection files
inj_files, inj_tags = wf.setup_injection_workflow(
    workflow, output_dir="inj_files")

# Placeholder only; the name is rebound by the background plotting loop.
bg_file = None

# One analysis run per tag: the non-injection "full_data" run comes first,
# followed by one run per injection tag. Output directories start out with
# the same names as the tags.
tags = ["full_data"] + inj_tags
output_dirs = ["full_data"]
output_dirs.extend(inj_tags)

# Collects the coincidence output of every injection run.
inj_coincs = wf.FileList()
# Run the analysis once per (injection file, tag, output dir) triple. The
# first triple is always the "full_data" run, which has no injection file;
# its outputs are reused by every injection run that follows.
for inj_file, tag, output_dir in zip([None] + inj_files, tags, output_dirs):
    is_full_data = (tag == 'full_data')

    # Tags for the coincidence stage. Injection runs also switch to a
    # '<tag>_coinc' output directory, which is used for the matched-filter
    # and merge steps below as well.
    if is_full_data:
        ctags = [tag, 'full']
    else:
        ctags = [tag, 'inj']
        output_dir += '_coinc'

    # setup the matchedfilter jobs
    single_insps = wf.setup_matchedfltr_workflow(
        workflow, science_segs, datafind_files, splitbank_files,
        output_dir, injection_file=inj_file, tags=[tag])

    # Keep the un-merged full-data outputs; their segments are written to
    # file after the loop.
    if is_full_data:
        ind_insps = single_insps

    insps = wf.merge_single_detector_hdf_files(
        workflow, hdfbank[0], single_insps, output_dir, tags=[tag])

    # setup coinc for the filtering jobs
    if is_full_data:
        full_insps = insps

        # Coincidence is set up once with the cumulative veto files and once
        # with the final veto file.
        bg_files = wf.setup_interval_coinc(
            workflow, hdfbank, insps, cum_veto_files,
            output_dir, tags=ctags)
        final_bg_file = wf.setup_interval_coinc(
            workflow, hdfbank, insps, final_veto_file,
            output_dir, tags=ctags)

        # One SNR/IFAR plot per coincidence result, final-veto results first.
        for bg_file in (final_bg_file + bg_files):
            wf.make_snrifar_plot(workflow, bg_file, 'plots/background',
                                 tags=bg_file.tags)

        wf.make_foreground_table(workflow, final_bg_file[0], hdfbank[0],
                                 tag, 'plots/foreground')
        wf.make_snrchi_plot(workflow, insps, final_veto_file[0],
                            'plots/background', tags=[tag])
        wf.make_singles_plot(workflow, insps, hdfbank[0], final_veto_file[0],
                             'plots/background', tags=[tag])
        continue

    # Injection run: coincidence uses the full-data triggers and the
    # final-veto background produced by the first iteration.
    inj_coinc = wf.setup_interval_coinc_inj(
        workflow, hdfbank, full_insps, insps,
        final_bg_file, final_veto_file,
        output_dir, tags=ctags)
    found_inj = wf.find_injections_in_hdf_coinc(
        workflow, wf.FileList([inj_coinc]), wf.FileList([inj_file]),
        final_veto_file[0], output_dir, tags=ctags)
    inj_coincs += [inj_coinc]

    # Per-injection-set plots and tables.
    foundmissed_dir = 'plots/foundmissed/%s' % tag
    wf.make_sensitivity_plot(workflow, found_inj,
                             'plots/sensitivity/%s' % tag, tags=ctags)
    wf.make_foundmissed_plot(workflow, found_inj, foundmissed_dir,
                             tags=[tag])
    wf.make_inj_table(workflow, found_inj, foundmissed_dir, tags=[tag])

    # Pair each merged injection file with the full-data file at the same
    # position for the coincident SNR/chi plot.
    for inj_insp, trig_insp in zip(insps, full_insps):
        wf.make_coinc_snrchi_plot(workflow, found_inj, inj_insp,
                                  final_bg_file[0], trig_insp,
                                  'plots/background/%s' % tag, tags=[tag])
   
# Write one INSPIRAL_SEGMENTS file per detector, built from the segments of
# the un-merged full-data matched-filter outputs.
full_segs = []
name = 'INSPIRAL_SEGMENTS'
for ifo, files in zip(*ind_insps.categorize_by_attr('ifo')):
    fname = 'segments/%s-%s.xml' % (ifo, name)
    fsegs = glue.segments.segmentlist([f.segment for f in files])
    seg_file = pycbc.events.segments_to_file(fsegs, fname, name, ifo=ifo)
    full_segs.append(seg_file)

# One veto segment plot per detector.
for ifo, files in zip(*ind_cats.categorize_by_attr('ifo')):
    wf.make_segments_plot(workflow, files, 'plots/segments',
                          tags=['%s_VETO_SEGMENTS' % ifo])

# Segment plots for the inspiral segment files written above and for the
# science segment file.
wf.make_segments_plot(workflow, full_segs, 'plots/segments',
                      tags=['INSPIRAL_SEGMENTS'])
wf.make_segments_plot(workflow, science_seg_file, 'plots/segments',
                      tags=['SCIENCE_MINUS_CAT1'])

# When at least one injection file exists, repeat the found-injection search
# over all injection coincidence results at once and make a combined
# sensitivity plot tagged 'ALLINJ'.
# NOTE(review): output_dir is whatever the per-run loop above left behind,
# i.e. the last injection run's '<tag>_coinc' directory. Confirm that this is
# the intended location for the combined result.
if len(inj_files) > 0:
    found_inj = wf.find_injections_in_hdf_coinc(workflow, inj_coincs,
                                            inj_files, final_veto_file[0], 
                                            output_dir, tags=['ALLINJ'])                                                
    wf.make_sensitivity_plot(workflow, found_inj, 'plots/sensitivity', 
                                            tags=['ALLINJ'])

# save global config file to results directory
ini_file_path = 'plots/configuration.ini'
with open(ini_file_path, 'wb') as ini_fh:
    container.cp.write(ini_fh)

# The results page job goes into the finalization workflow and is given the
# absolute location of the plots directory under the current directory.
plots_dir = os.path.join(os.getcwd(), 'plots')
wf.make_results_web_page(finalize_workflow, plots_dir)

# Add the main and finalization workflows to the container workflow.
container += workflow
container += finalize_workflow

# Declare the main workflow's job as the parent of the finalization
# workflow's job in the container's DAX.
import Pegasus.DAX3 as dax
dep = dax.Dependency(parent=workflow.as_job, child=finalize_workflow.as_job)
# NOTE(review): this reaches into the container's private _adag attribute.
container._adag.addDependency(dep)

# Save the assembled container workflow and report completion.
container.save()
logging.info("Written dax.")
