#!/usr/bin/env python

# Copyright (C) 2013 Ian W. Harry, Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" 
Program for running multi-detector workflow analysis through coincidence and then
generate post-processing and plots.
"""
import pycbc
import pycbc.version
# Script metadata. The version and date strings come from pycbc.version;
# __version__ is what the --version option prints, and __program__ is the
# name passed to the analysis-logging setup further down.
__author__  = "Alex Nitz <alex.nitz@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__    = pycbc.version.date
__program__ = "pycbc_make_coinc_search_workflow"

import sys
import pycbc.events, pycbc.workflow as wf
import os, argparse, ConfigParser, logging, glue.segments, numpy, lal, datetime
from pycbc.events.veto import multi_segments_to_file
from pycbc.results import create_versioning_page, static_table, layout
from pycbc.results.versioning import save_fig_with_metadata
from pycbc.results.metadata import html_escape

# Command line interface. The module docstring (minus its leading space)
# doubles as the program description shown by --help.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version', version=__version__)
parser.add_argument('--workflow-name', default='my_unamed_run')
# The output directory is mandatory: this script unconditionally creates it
# and changes into it, so a missing value is reported by argparse up front
# instead of surfacing later as a TypeError from os.chdir(None).
parser.add_argument("-d", "--output-dir", required=True,
                    help="Path to output directory.")
# Hand the parser to pycbc.workflow so it can register its own options.
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

# Log at INFO level with a timestamp and level prefix.
logging.basicConfig(format='%(asctime)s:%(levelname)s : %(message)s', 
                    level=logging.INFO)

# Three workflow objects are used:
#   container         - named after --workflow-name; the other two are added
#                       to it at the end of the script and it is what is saved
#   workflow          - 'main'; receives all analysis and plotting jobs
#   finalize_workflow - 'finalization'; receives the results web page job and
#                       is made to depend on 'main' before saving
container = wf.Workflow(args, args.workflow_name)
workflow = wf.Workflow(args, 'main')
finalize_workflow = wf.Workflow(args, 'finalization')

# Everything below uses paths relative to the output directory.
wf.makedir(args.output_dir)
os.chdir(args.output_dir)
             
# Top-level sections of the 'results' directory, listed in page order.
# NOTE(review): SectionNumber presumably derives each section's directory
# name from its position in this list -- confirm in pycbc.results.layout.
result_sections = [
    'configuration',
    'analysis_time',
    'detector_sensitivity',
    'single_triggers',
    'coincident_triggers',
    'injections',
    'search_sensitivity',
    'result',
]
rdir = layout.SectionNumber('results', result_sections)

# Put the analysis start / end time at the top of the summary page.
# The interval is reported both as GPS seconds and as UTC date strings.
interval = workflow.analysis_time
gps_start = int(interval[0])
gps_end = int(interval[1])
utc_start, utc_end = [str(datetime.datetime(*lal.GPSToUTC(gps)[0:6]))
                      for gps in (gps_start, gps_end)]
duration_days = float(gps_end - gps_start) / 86400.0
time_str = '<center><p><b>GPS Interval [%s,%s). UTC Interval %s - %s. Interval duration = %.3f days.</b></p></center>' % (
    gps_start, gps_end, utc_start, utc_end, duration_days)

# The HTML snippet lives in the base results directory; the directory is
# created after the File object is declared, as before.
time_file = wf.File(workflow.ifos, 'time', workflow.analysis_time,
                    extension='.html', directory=rdir.base)
wf.makedir(rdir.base)

# Write the snippet together with its title, caption and the command line
# used to invoke this script.
save_fig_with_metadata(
    time_str,
    time_file.storage_path,
    title='Search Workflow Duration (Wall Clock Time)',
    caption="Wall clock start and end times for this invocation of the workflow. "
            " The command line button shows the arguments used to invoke the workflow "
            " creation script.",
    cmd=' '.join(sys.argv),
)

# Get segments and find where the data is
# Outputs go under "segments" and "datafind". Note that science_segs is
# rebound to the value returned by the datafind setup, so later stages see
# the datafind-adjusted segments rather than the first return value.
science_segs, data_segs, science_seg_file = wf.get_analyzable_segments(workflow, "segments")
datafind_files, science_segs = wf.setup_datafind_workflow(workflow, 
                                         science_segs, "datafind", science_seg_file)

# Veto files and names for the 'segments-veto-groups' and
# 'segments-final-veto-group' configuration entries.
# NOTE(review): ind_cats from the first call is overwritten by the second
# call, so only the final-veto-group value is used by the veto segment plots
# later in the script -- confirm this is intended.
cum_veto_files, veto_names, ind_cats = wf.get_cumulative_veto_group_files(workflow, 
                                        'segments-veto-groups', "segments")
final_veto_file, final_veto_name, ind_cats = wf.get_cumulative_veto_group_files(workflow, 
                                        'segments-final-veto-group', "segments")

# Precalculated PSDs
precalc_psd_files = wf.setup_psd_workflow(workflow, science_segs,
                                            datafind_files, "psdfiles")

# Template bank stuff
# The bank is generated (optionally using the precalculated PSDs), converted
# to HDF, and split; the split banks feed the matched-filter jobs while
# hdfbank is used by the merging, coincidence and plotting stages.
bank_files = wf.setup_tmpltbank_workflow(workflow, science_segs, 
                                            datafind_files, 
                                            "bank",
                                            psd_files=precalc_psd_files)
hdfbank = wf.convert_bank_to_hdf(workflow, bank_files, "bank")
splitbank_files = wf.setup_splittable_workflow(workflow, bank_files, "bank") 

# One-row layout entry holding the template bank plot; it is appended to the
# summary page rows near the end of the script.
bank_plot = [(wf.make_template_plot(workflow, hdfbank[0], rdir['coincident_triggers']),)]

# setup the injection files
# inj_files and inj_tags are iterated pairwise by the injection loop below.
inj_files, inj_tags = wf.setup_injection_workflow(workflow, 
                                                     output_dir="inj_files")

# setup gating files if provided
gate_files = wf.setup_gating_workflow(workflow, science_segs,
                                            datafind_files, "gating")

######################## Setup the FULL DATA run ##############################
# tag labels the jobs of this run and doubles as the output directory name.
# Both names (and ctags) are rebound inside the injection loop later on.
tag = output_dir = "full_data"
ctags = [tag, 'full']

# setup the matchedfilter jobs                                                     
# ind_insps keeps the matched-filter outputs as returned here; insps is
# rebound to the merged files just below. ind_insps is used later to collect
# the per-detector inspiral segments.
ind_insps = insps = wf.setup_matchedfltr_workflow(workflow, science_segs, 
                                   datafind_files, splitbank_files, 
                                   output_dir, gate_files=gate_files,
                                   tags = [tag])

insps = wf.merge_single_detector_hdf_files(workflow, hdfbank[0], 
                                           insps, output_dir, tags=[tag])

# NOTE: the return value is not used again in this script.
anal_log_files = wf.setup_analysislogging(workflow,
                                          science_seg_file, insps, args,
                                          "logging", program_name=__program__) 

# setup coinc for the filtering jobs
# full_insps remembers the merged full-data triggers because insps is reused
# for the injection runs. Coincidence is set up twice: once with the
# cumulative veto groups and once with the final veto group.
full_insps = insps
bg_files = wf.setup_interval_coinc(workflow, hdfbank, insps, 
                               cum_veto_files, veto_names, 
                               output_dir, tags=ctags)  
final_bg_files =  wf.setup_interval_coinc(workflow, hdfbank, insps, 
                               final_veto_file, final_veto_name,
                               output_dir, tags=ctags)
# Each element of final_bg_files is indexed as (file, bins); the first
# element supplies the final result file and its per-bin files.
final_bg_file = final_bg_files[0][0]                   
bin_files = final_bg_files[0][1]

# Veto used for the closed-box plots, written under 'segments' with the
# segment name 'closed_box'.
# NOTE(review): presumably this removes foreground event times so the
# closed-box plots do not reveal them -- confirm in pycbc.workflow.
censored_veto = wf.make_foreground_censored_veto(workflow, 
                       final_bg_file, final_veto_file[0], final_veto_name[0],
                       'closed_box', 'segments')      
              
# Closed-box SNR/IFAR plots for every bin of every coincidence result.
# Only those belonging to the final result file are kept for the summary page.
closed_snrifar = []
for coinc_file, coinc_bins in (bg_files + final_bg_files):
    for coinc_bin in coinc_bins:
        closed_plot = wf.make_snrifar_plot(
            workflow, coinc_bin, rdir['coincident_triggers'],
            closed_box=True, tags=coinc_bin.tags + ['closed'])
        if not (coinc_file == final_bg_file):
            continue
        closed_snrifar.append(closed_plot)

# Open-box result plots: one (snr-ifar, ifar) row per bin of the final result.
results_page = []
for result_bin in bin_files:
    open_snrifar = wf.make_snrifar_plot(workflow, result_bin,
                                        rdir['result'], tags=result_bin.tags)
    open_ifar = wf.make_ifar_plot(workflow, result_bin,
                                  rdir['result'], tags=result_bin.tags)
    results_page.append((open_snrifar, open_ifar))

# IFAR plot for the final result file itself (return value not used).
wf.make_ifar_plot(workflow, final_bg_file, rdir['result'])

# Foreground tables for the final result. The HTML table is placed on the
# result page further down; the two XML variants differ only in their tags
# and their return values are not used again in this script.
table = wf.make_foreground_table(workflow, final_bg_file, 
                    hdfbank[0], tag, rdir['result'], singles=insps,
                    extension='.html', tags=["html"])
fore_xmlall = wf.make_foreground_table(workflow, final_bg_file,
                    hdfbank[0], tag, rdir['result'], singles=insps,
                    extension='.xml', tags=["xmlall"])
fore_xmlloudest = wf.make_foreground_table(workflow, final_bg_file,
                    hdfbank[0], tag, rdir['result'], singles=insps,
                    extension='.xml', tags=["xmlloudest"])

# Single-detector SNR/chi plots using the 'closed_box' censored veto.
snrchi = wf.make_snrchi_plot(workflow, insps, censored_veto, 
                    'closed_box', rdir['single_triggers'], tags=[tag])
layout.group_layout(rdir['single_triggers'], snrchi)

# Per-detector single-trigger plots and histograms. The histogram call is
# made twice: once excluding 'summ' entries and once requiring them; only the
# 'summ' histograms are collected, in pairs, for the summary page.
hist_summ = []
for insp in full_insps:
    wf.make_singles_plot(workflow, [insp], hdfbank[0], censored_veto, 
           'closed_box', rdir['single_triggers/%s-hexbin' % insp.ifo], tags=[tag])
    wf.make_single_hist(workflow, insp, censored_veto, 'closed_box', 
           rdir['single_triggers/%s-hist' % insp.ifo], 
           exclude='summ', tags=[tag])
    hists = wf.make_single_hist(workflow, insp, censored_veto, 'closed_box', 
           rdir['single_triggers/%s-hist' % insp.ifo], 
           require='summ', tags=[tag])
    hist_summ += list(layout.grouper(hists, 2))

# Calculate the inspiral psds and make plots
# For each detector this collects:
#   insp_segs[ifo]      - segmentlist of the matched-filter output segments
#   full_segs           - the files those segment lists are written to
#   insp_data_segs[ifo] - first the segmentlist of each file's 'data_seg'
#                         metadata, then (after the PSD job is set up) rebound
#                         to the file that segment list is written to
#   psd_files           - one PSD calculation output per detector
full_segs, psd_files, insp_segs, insp_data_segs = [], [], {}, {}                      
for ifo, files in zip(*ind_insps.categorize_by_attr('ifo')):
    name = 'INSPIRAL_SEGMENTS'
    fname = 'segments/%s-' % ifo + name + '.xml'
    insp_segs[ifo] = glue.segments.segmentlist([f.segment for f in files])
    full_segs.append(pycbc.events.segments_to_file(insp_segs[ifo], fname, name, ifo=ifo))

    insp_data_segs[ifo] = glue.segments.segmentlist([f.metadata['data_seg'] for f in files])

    # The PSD job needs the segment list itself, so it is set up before
    # insp_data_segs[ifo] is replaced by the written segment file.
    psd_files += [wf.setup_psd_calculate(workflow, datafind_files.find_output_with_ifo(ifo), 
                  ifo, insp_data_segs[ifo], 'INSPIRAL_DATA', 'psds', gate_files=gate_files)]

    insp_data_segs[ifo] = pycbc.events.segments_to_file(insp_data_segs[ifo],
                   'segments/%s-INSP_DATA.xml' % ifo, 'INSPIRAL_DATA', ifo=ifo)

# Spectrum plot plus range plots, the latter split into 'summ' and non-'summ'.
s = wf.make_spectrum_plot(workflow, psd_files, rdir['detector_sensitivity'],
                          precalc_psd_files=precalc_psd_files)
r = wf.make_range_plot(workflow, psd_files, rdir['detector_sensitivity'], require='summ')
r2 = wf.make_range_plot(workflow, psd_files, rdir['detector_sensitivity'], exclude='summ')

# Summary row: spectrum plot next to the first 'summ' range plot, or None
# when no such plot was produced. The remaining range plots follow in pairs.
det_summ = [(s, r[0] if len(r) != 0 else None)]
layout.two_column_layout(rdir['detector_sensitivity'], det_summ + list(layout.grouper(r2, 2)))    


# run minifollowups on the output of the loudest events 
# insp_data_segs maps each detector to its 'INSPIRAL_DATA' segment file.
wf_layout = wf.setup_minifollowups(workflow, final_bg_file,
                                  full_insps,  hdfbank[0],
                                  insp_data_segs, 'INSPIRAL_DATA',
                                  rdir['result'], tags=final_bg_file.tags)

# Result page: per-bin plot rows, then the HTML foreground table, then the
# followup layout rows.
layout.two_column_layout(rdir['result'], results_page + [(table,)] + wf_layout)

# Bins that carry a bin_name additionally get their own followups in a
# 'result/<bin_name>' subsection; bins without that attribute are skipped.
for bin_file in bin_files:
    if hasattr(bin_file, 'bin_name'):
        wf_layout = wf.setup_minifollowups(workflow, bin_file, full_insps, hdfbank[0],
                                  insp_data_segs, 'INSPIRAL_DATA', 
                                  rdir['result/%s' % bin_file.bin_name],
                                  tags=bin_file.tags)
        layout.two_column_layout(rdir['result/%s' % bin_file.bin_name], wf_layout)

################## Setup segment / veto related plots #########################    
# do plotting of segments / veto times
# One veto segment plot per detector, grouping the ind_cats files by their
# 'ifo' attribute.
# NOTE(review): ind_cats here is the value returned for the final veto group
# only; the one returned for 'segments-veto-groups' was overwritten.
for ifo, files in zip(*ind_cats.categorize_by_attr('ifo')):
    wf.make_segments_plot(workflow, files, rdir['analysis_time/segments'], 
                          tags=['%s_VETO_SEGMENTS' % ifo])

# Plots of the inspiral job segments and of the science segment file.
wf.make_segments_plot(workflow, full_segs,
                 rdir['analysis_time/segments'], tags=['INSPIRAL_SEGMENTS'])
wf.make_segments_plot(workflow, science_seg_file,
                 rdir['analysis_time/segments'], tags=['SCIENCE_MINUS_CAT1'])

# Segment types written to the segment summary XML file, paired with the
# per-detector segment containers they are read from.
seg_summ_names = ['DATA', 'ANALYZABLE_DATA', 'TRIGGERS_PRODUCED']
seg_summ_seglists = [data_segs, science_segs, insp_segs]

# Comparison segment names for the veto table on the summary page: one
# 'TRIGGERS_PRODUCED&<veto>' entry per veto name.
veto_summ_names = []
for veto_label in veto_names + final_veto_name:
    veto_summ_names.append('TRIGGERS_PRODUCED&' + veto_label)

# Write the segment summary XML file. Flatten to one row per (segment type,
# detector) pair, segment type varying slowest, then split the rows into the
# three parallel lists expected by multi_segments_to_file.
summary_rows = [(per_ifo_segs[det], seg_label, det)
                for per_ifo_segs, seg_label in zip(seg_summ_seglists,
                                                   seg_summ_names)
                for det in workflow.ifos]
seg_list = [row[0] for row in summary_rows]
names = [row[1] for row in summary_rows]
ifos = [row[2] for row in summary_rows]
filename = 'segments/%s-WORKFLOW_SEGMENT_SUMMARY.xml' % ''.join(workflow.ifos)
seg_summ_file = multi_segments_to_file(seg_list, filename, names, ifos)

# make segment table for summary page
# Built from the segment summary file only, using the three summary names.
seg_summ_table = wf.make_seg_table(workflow, [seg_summ_file], seg_summ_names,
                        rdir['analysis_time/segments'], ['SUMMARY'],
                        title_text='Input and output',
                        description='This shows the total amount of input data, analyzable data, and the time for which triggers are produced.')
# Veto table: the summary file plus all veto files, with the
# 'TRIGGERS_PRODUCED&<veto>' names.
# NOTE(review): the files are passed final-group first, while
# veto_summ_names lists the cumulative names first -- presumably the order is
# irrelevant because names are looked up in the files; confirm.
veto_summ_table = wf.make_seg_table(workflow,
                        [seg_summ_file] + final_veto_file + cum_veto_files,
                        veto_summ_names, rdir['analysis_time/segments'],
                        ['VETO_SUMMARY'], 
                        title_text='Time removed by vetoes',
                        description='This shows the time removed from the output time by the vetoes applied to the triggers.')

# make segment plot for summary page
seg_summ_plot = wf.make_seg_plot(workflow, [seg_summ_file],
                        rdir['analysis_time/segments'],
                        seg_summ_names, ['SUMMARY'])

# make veto definer table
# NOTE: the return value is not used again in this script.
vetodef_table = wf.make_veto_table(workflow, rdir['analysis_time/veto_definer'])

############################## Setup the injection runs #######################                                                                   
# One pass per injection set. Each pass rebinds the module-level names tag,
# ctags, output_dir and insps that the full-data run used earlier; the
# full-data triggers remain available as full_insps.
inj_coincs = wf.FileList()
for inj_file, tag in zip(inj_files, inj_tags):
    ctags = [tag, 'inj']
    output_dir = '%s_coinc' % tag

    # When 'strip-injections' is set for this tag in [workflow-injections],
    # filter with a reduced injection file produced from the
    # 'TRIGGERS_PRODUCED' segments; otherwise use the injection file as is.
    if workflow.cp.has_option_tags('workflow-injections', 'strip-injections', tags=[tag]): 
        small_inj_file = wf.veto_injections(workflow, inj_file, seg_summ_file, 
                                 'TRIGGERS_PRODUCED', "inj_files", tags=[tag])
    else:
        small_inj_file = inj_file

    # setup the matchedfilter jobs                                                     
    insps = wf.setup_matchedfltr_workflow(workflow, science_segs, 
                                       datafind_files, splitbank_files, 
                                       output_dir, injection_file=small_inj_file,
                                       gate_files=gate_files,
                                       tags = [tag])

    insps = wf.merge_single_detector_hdf_files(workflow, hdfbank[0], 
                                               insps, output_dir, tags=[tag])

    # Injection coincidence uses both the full-data and the injection
    # triggers, the final result bins, and the first final veto file/name.
    inj_coinc = wf.setup_interval_coinc_inj(workflow, hdfbank,
                                    full_insps, insps, bin_files, 
                                    final_veto_file[0], final_veto_name[0],
                                    output_dir, tags = ctags)
    # NOTE(review): injection finding is given the original inj_file, not
    # small_inj_file, even when stripping is enabled -- confirm intended.
    found_inj = wf.find_injections_in_hdf_coinc(workflow, wf.FileList([inj_coinc]),
                                    wf.FileList([inj_file]), final_veto_file[0], 
                                    final_veto_name[0],
                                    output_dir, tags=ctags)
    inj_coincs += [inj_coinc]  

    # Found/missed and sensitivity plots selected with require='sub'; these
    # are the ones laid out on the per-injection-set subsection pages below.
    s = wf.make_sensitivity_plot(workflow, found_inj, 
                  rdir['search_sensitivity/%s' % tag], 
                   exclude=['all', 'summ'], require='sub', tags=ctags)
    f = wf.make_foundmissed_plot(workflow, found_inj, 
                  rdir['injections/%s' % tag],
                   exclude=['all', 'summ'], require='sub', tags=[tag])                 

    # Extra foundmissed/sensitivity plots                    
    # Their return values are discarded, so they are not added to the layout.
    wf.make_sensitivity_plot(workflow, found_inj, 
                  rdir['search_sensitivity/%s' % tag], 
                   exclude=['all', 'summ', 'sub'], tags=ctags)
    wf.make_foundmissed_plot(workflow, found_inj, 
                  rdir['injections/%s' % tag],
                   exclude=['all', 'summ', 'sub'], tags=[tag])

    wf.make_inj_table(workflow, found_inj, 
                  rdir['injections/%s' % tag], tags=[tag])

    # Pair each injection trigger file with a full-data trigger file by
    # position.
    # NOTE(review): assumes insps and full_insps list detectors in the same
    # order -- confirm.
    for inj_insp, trig_insp in zip(insps, full_insps):
        f += wf.make_coinc_snrchi_plot(workflow, found_inj, inj_insp, 
                                  final_bg_file, trig_insp,
                                  rdir['injections/%s' % tag], tags=[tag])

    # Only lay out a subsection page when there is something to show.
    if len(s) > 0: layout.group_layout(rdir['search_sensitivity/%s' % tag], s)        
    if len(f) > 0: layout.group_layout(rdir['injections/%s' % tag], f)  
    
# Make combined injection plots
# inj_summ stays empty when there are no injection sets, so the summary page
# simply gets no injection rows in that case.
inj_summ = []
if len(inj_files) > 0:
    # Combine all injection sets, using the 'closed_box' censored veto, with
    # outputs in the 'allinj' directory.
    found_inj = wf.find_injections_in_hdf_coinc(workflow, inj_coincs,
                            inj_files, censored_veto, 'closed_box',
                            'allinj', tags=['ALLINJ'])
    # Plots selected with require='all' go on the top-level sensitivity and
    # injections pages.
    sen = wf.make_sensitivity_plot(workflow, found_inj, rdir['search_sensitivity'],
                            require='all', tags=['ALLINJ'])
    layout.group_layout(rdir['search_sensitivity'], sen)
    inj = wf.make_foundmissed_plot(workflow, found_inj, rdir['injections'],
                            require='all', tags=['ALLINJ'])
    layout.group_layout(rdir['injections'], inj)

    # Make summary page foundmissed and sensitivity plot    
    # sen and inj are rebound to the require='summ' plots, which are grouped
    # in pairs (found/missed first) for the summary page.
    sen = wf.make_sensitivity_plot(workflow, found_inj, 
                rdir['search_sensitivity'], require='summ', tags=['ALLINJ'])
    inj = wf.make_foundmissed_plot(workflow, found_inj, 
                rdir['injections'], require='summ', tags=['ALLINJ'])
    inj_summ = list(layout.grouper(inj + sen, 2))
                            
# make full summary
# Each tuple is one row of the two-column summary page, in display order:
# analysis time, segment plot, segment/veto tables, detector sensitivity,
# single-trigger histograms, bank plot, closed-box plots, injection plots.
import errno

summ = ([(time_file,)] + [(seg_summ_plot,)] + [(seg_summ_table, veto_summ_table)] + 
       det_summ + hist_summ + bank_plot + list(layout.grouper(closed_snrifar, 2)) + inj_summ)

# Link every summary file into the base results directory. Rows may contain
# None placeholders (e.g. when no summary range plot exists); skip those.
for row in summ:
    for f in row:
        if f is None:
            continue
        link_path = os.path.join(rdir.base, f.name)
        try:
            os.symlink(f.storage_path, link_path)
        except OSError as err:
            # Linking is best-effort. An already existing link is expected
            # and ignored; anything else is reported so that a summary page
            # with missing entries can be diagnosed, but it does not abort
            # workflow generation.
            if err.errno != errno.EEXIST:
                logging.warning("Could not link %s to %s: %s",
                                f.storage_path, link_path, err)
layout.two_column_layout(rdir.base, summ)

# save global config file to results directory
base = rdir['configuration']
wf.makedir(base)
ini_file_path = os.path.join(base, 'configuration.ini')
with open(ini_file_path, 'wb') as ini_fh:
    container.cp.write(ini_fh)
# Register the written file so it can be shown on the configuration page.
# The detectors must come from workflow.ifos: the module-level name `ifos`
# is the flattened list built for the segment summary file, holding one
# entry per (segment type, detector) pair, and is not the detector list.
ini_file = wf.FileList([wf.File(workflow.ifos, '', workflow.analysis_time,
                                file_url='file://'+ini_file_path)])
layout.single_layout(base, ini_file)

# Create versioning information
create_versioning_page(rdir['configuration'], container.cp)

# The results web page is built by the finalization workflow, from the
# absolute path of the base results directory.
wf.make_results_web_page(finalize_workflow, os.path.join(os.getcwd(), rdir.base))

# Nest both sub-workflows inside the container workflow.
container += workflow
container += finalize_workflow

# Make finalization run only after the main workflow has completed.
import Pegasus.DAX3 as dax
dep = dax.Dependency(parent=workflow.as_job, child=finalize_workflow.as_job)
container._adag.addDependency(dep)

container.save()

# Protect the open box results folder
# Owner-only access (rwx------). Written as 0o700, which is the same value as
# the legacy 0700 spelling but is valid on Python 2.6+ and Python 3 alike.
os.chmod(rdir['result'], 0o700)

logging.info("Written dax.")
