#!/usr/bin/env python

# Copyright (C) 2015 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import pycbc
import pycbc.version

import os
import copy
import logging
import urlparse
import argparse
import lal
import Pegasus.DAX3 as dax

from glue import lal
from glue import segments

import pycbc.workflow as _workflow

################
# setup workflow
################

# read command line
# NOTE: the usage/description strings are written as single clean literals.
# A backslash continuation inside a string literal keeps the indentation of
# the continued source line as part of the string.
parser = argparse.ArgumentParser(
    usage='pycbc_make_sngl_workflow [--options]',
    description='Workflow generator for a single detector analysis.')

# --name is mandatory; it is used below for the run directory and for the
# name of the top-level workflow
parser.add_argument("--name", type=str, required=True,
                    help="Descriptive name of the analysis.")

# pass the parser to pycbc.workflow so it can add its own option group
_workflow.add_workflow_command_line_group(parser)
args = parser.parse_args()

# setup log (INFO level, 12-hour clock timestamps)
log_format = '%(asctime)s:%(levelname)s : %(message)s'
logging.basicConfig(level=logging.INFO, format=log_format,
                    datefmt='%I:%M:%S')

# directory names: everything lives below <cwd>/<name>
runDir = os.path.join(os.getcwd(), args.name)
datafindDir = '/'.join([runDir, 'datafind'])
fulldataDir = '/'.join([runDir, 'full_data'])
outputDir = '/'.join([runDir, 'output'])
configDir = '/'.join([outputDir, 'config'])
segDir = '/'.join([outputDir, 'segments'])

# create the run dir if needed and change into it
run_dir_missing = not os.path.exists(runDir)
if run_dir_missing:
    os.makedirs(runDir)
os.chdir(runDir)

# setup workflow and sub-workflows; "container" is the top-level workflow
# that the 'main' and 'html' sub-workflows are appended to before saving
container = _workflow.Workflow(args, args.name)
workflow = _workflow.Workflow(args, 'main')
html_workflow = _workflow.Workflow(args, 'html')

###################################
# retrieve segments and frame files
###################################

# get segments
segment_setup = _workflow.setup_segment_generation(workflow, segDir)
scienceSegs, segFileList = segment_setup

# get frames; scienceSegs is rebound to the value returned by the
# datafind setup
datafind_setup = _workflow.setup_datafind_workflow(workflow, scienceSegs,
                                                   datafindDir, segFileList)
datafinds, scienceSegs = datafind_setup

###########################
# setup template bank nodes
###########################

# setup template bank (datafindDir is passed as its directory argument)
banks = _workflow.setup_tmpltbank_workflow(
    workflow, scienceSegs, datafinds, datafindDir)

# setup splitbank on the template banks
splitBanks = _workflow.setup_splittable_workflow(
    workflow, banks, datafindDir)

###############################
# setup matched filtering nodes
###############################

# setup matched filtering on the split banks (fulldataDir is passed as its
# directory argument)
insps = _workflow.setup_matchedfltr_workflow(
    workflow, scienceSegs, datafinds, splitBanks, fulldataDir)

########################
# setup ligolw_add nodes
########################

# setup ligolw_add executable; its ifo string is the concatenation of the
# keys of scienceSegs
llwadds = _workflow.FileList([])
all_ifos = ''.join(scienceSegs.keys())
llwadd_exe = _workflow.LigolwAddExecutable(workflow.cp, 'llwadd',
                                           ifo=all_ifos,
                                           out_dir=fulldataDir)

# walk the inspiral files grouped first by IFO, then by valid segment
ifo_keys, ifo_groups = insps.categorize_by_attr('ifo')
for ifo, ifo_insps in zip(ifo_keys, ifo_groups):

    seg_keys, seg_groups = ifo_insps.categorize_by_attr('segment')
    for seg, seg_insps in zip(seg_keys, seg_groups):

        # a group with a single inspiral file gets no ligolw_add node and
        # contributes nothing to llwadds
        if len(seg_insps) <= 1:
            continue

        # one ligolw_add node per (IFO, valid segment) group
        merged_file = _workflow.File(ifo, 'LLWADD', seg,
                                     directory=fulldataDir,
                                     extension='xml.gz',
                                     tags=['UNCLUSTERED'])
        merge_node = llwadd_exe.create_node(seg, seg_insps,
                                            output=merged_file)
        merge_node.add_opt('--lfn-start-time', seg[0])
        merge_node.add_opt('--lfn-end-time', seg[1])
        workflow.add_node(merge_node)
        llwadds.append(merged_file)

########################
# setup clustering nodes
########################

# setup clustering executable
clusters = _workflow.FileList([])
cluster_exe = _workflow.Executable(workflow.cp, 'cluster', out_dir=fulldataDir)

# get all tags for cluster executable: for every configuration section
# named "<cluster_exe.name>-<tag>" keep the part after the first dash
tags = []
for section in workflow.cp.sections():
    exe_name, dash, tag_part = section.partition('-')
    if dash and exe_name == cluster_exe.name:
        tags.append(tag_part)

# setup clustering nodes for each ligolw_add and tag combination
for merged_file in llwadds:
    for cluster_tag in tags:
        clustered_file = _workflow.File(merged_file.ifo, 'LLWADD_CLUSTER',
                                        merged_file.segment,
                                        directory=fulldataDir,
                                        extension='xml.gz',
                                        tags=[cluster_tag])
        node = cluster_exe.create_node()
        node.add_input_opt('--input-file', merged_file)
        node.add_output_opt('--output-file', clustered_file)

        # the window is read from the configuration for this tag
        window = workflow.cp.get_opt_tags(cluster_exe.name, 'cluster-window',
                                          [cluster_tag])
        node.add_opt('--cluster-window', window)

        workflow.add_node(node)
        clusters.append(clustered_file)

#######################
# setup histogram nodes
#######################

# setup histogram executable; trigs (unclustered + clustered trigger files)
# and images are also used by the plotting sections that follow
trigs = _workflow.FileList(llwadds+clusters)
images = _workflow.FileList([])
histogram_exe = _workflow.Executable(workflow.cp, 'histogram', out_dir=outputDir)


def _add_histogram_nodes(ifo, base_tags, span, trig_files, image_dir,
                         extra_tags):
    """Add an 'snr' and a 'new_snr' histogram node to the main workflow.

    ifo        -- IFO key given to the image File
    base_tags  -- list of tags shared by the trigger files
    span       -- indexable pair; items 0 and 1 become the GPS start/end
                  options and the pair itself is given to the image File
    trig_files -- files passed through --trigger-files
    image_dir  -- directory given to the image File
    extra_tags -- list appended after the statistic name in the image tags
    """
    for stat in ['snr', 'new_snr']:
        image = _workflow.File(ifo, 'HISTOGRAM', span,
                               directory=image_dir, extension='png',
                               tags=base_tags + [stat] + extra_tags)
        node = histogram_exe.create_node()
        node.add_input_list_opt('--trigger-files', trig_files)
        node.add_output_opt('--output-file', image)
        node.add_opt('--gps-start-time', span[0])
        node.add_opt('--gps-end-time', span[1])
        if stat == 'new_snr':
            node.add_arg('--new-snr')
        workflow.add_node(node)
        images.append(image)


# walk the trigger files grouped by IFO, then by tags
ifo_keys, ifo_groups = trigs.categorize_by_attr('ifo')
for ifo, ifo_trigs in zip(ifo_keys, ifo_groups):

    tag_keys, tag_groups = ifo_trigs.categorize_by_attr('tags')
    for tag_list, tag_trigs in zip(tag_keys, tag_groups):

        # histograms over the whole analysis time for this IFO and tags
        tag_dir = outputDir + '/' + '_'.join([ifo] + tag_list)
        _add_histogram_nodes(ifo, tag_list, workflow.analysis_time,
                             tag_trigs, tag_dir, ['ALLTIME'])

        # histograms for each valid segment, stored below <tag_dir>/jobs
        seg_keys, seg_groups = tag_trigs.categorize_by_attr('segment')
        for seg, seg_trigs in zip(seg_keys, seg_groups):
            seg_dir = tag_dir + '/jobs/' + str(seg[0]) + '-' + str(seg[1])
            _add_histogram_nodes(ifo, tag_list, seg, seg_trigs, seg_dir, [])

########################
# setup glitchgram nodes
########################

# setup glitchgram executable
glitchgram_exe = _workflow.Executable(workflow.cp, 'glitchgram', out_dir=outputDir)


def _add_glitchgram_nodes(ifo, base_tags, span, trig_files, image_dir,
                          extra_tags):
    """Add an 'snr' and a 'new_snr' glitchgram node to the main workflow.

    Each node writes a png image and an ini configuration file that share
    the same IFO, span, directory and tags; only the image is recorded in
    images.

    ifo        -- IFO key given to the output Files
    base_tags  -- list of tags shared by the trigger files
    span       -- indexable pair; items 0 and 1 become the GPS start/end
                  options and the pair itself is given to the output Files
    trig_files -- files passed through --trigger-files
    image_dir  -- directory given to the output Files
    extra_tags -- list appended after the statistic name in the output tags
    """
    for stat in ['snr', 'new_snr']:
        image = _workflow.File(ifo, 'GLITCHGRAM', span,
                               directory=image_dir, extension='png',
                               tags=base_tags + [stat] + extra_tags)
        config = _workflow.File(ifo, 'GLITCHGRAM', span,
                                directory=image_dir, extension='file.ini',
                                tags=base_tags + [stat] + extra_tags)
        node = glitchgram_exe.create_node()
        node.add_input_list_opt('--trigger-files', trig_files)
        node.add_output_opt('--output-file', image)
        node.add_output_opt('--config-file', config)
        node.add_opt('--gps-start-time', span[0])
        node.add_opt('--gps-end-time', span[1])
        if stat == 'new_snr':
            node.add_arg('--new-snr')
        workflow.add_node(node)
        images.append(image)


# walk the trigger files grouped by IFO, then by tags
ifo_keys, ifo_groups = trigs.categorize_by_attr('ifo')
for ifo, ifo_trigs in zip(ifo_keys, ifo_groups):

    tag_keys, tag_groups = ifo_trigs.categorize_by_attr('tags')
    for tag_list, tag_trigs in zip(tag_keys, tag_groups):

        # glitchgrams over the whole analysis time for this IFO and tags
        tag_dir = outputDir + '/' + '_'.join([ifo] + tag_list)
        _add_glitchgram_nodes(ifo, tag_list, workflow.analysis_time,
                              tag_trigs, tag_dir, ['ALLTIME'])

        # glitchgrams for each valid segment, stored below <tag_dir>/jobs
        seg_keys, seg_groups = tag_trigs.categorize_by_attr('segment')
        for seg, seg_trigs in zip(seg_keys, seg_groups):
            seg_dir = tag_dir + '/jobs/' + str(seg[0]) + '-' + str(seg[1])
            _add_glitchgram_nodes(ifo, tag_list, seg, seg_trigs, seg_dir, [])

########################
# setup timeseries nodes
########################

# setup timeseries executable
timeseries_exe = _workflow.Executable(workflow.cp, 'timeseries', out_dir=outputDir)


def _add_timeseries_nodes(ifo, base_tags, span, trig_files, image_dir,
                          extra_tags):
    """Add an 'snr' and a 'new_snr' timeseries node to the main workflow.

    ifo        -- IFO key given to the image File
    base_tags  -- list of tags shared by the trigger files
    span       -- indexable pair; items 0 and 1 become the GPS start/end
                  options and the pair itself is given to the image File
    trig_files -- files passed through --trigger-files
    image_dir  -- directory given to the image File
    extra_tags -- list appended after the statistic name in the image tags
    """
    for stat in ['snr', 'new_snr']:
        image = _workflow.File(ifo, 'TIMESERIES', span,
                               directory=image_dir, extension='png',
                               tags=base_tags + [stat] + extra_tags)
        node = timeseries_exe.create_node()
        node.add_input_list_opt('--trigger-files', trig_files)
        node.add_output_opt('--output-file', image)
        node.add_opt('--gps-start-time', span[0])
        node.add_opt('--gps-end-time', span[1])
        if stat == 'new_snr':
            node.add_arg('--new-snr')
        workflow.add_node(node)
        images.append(image)


# walk the trigger files grouped by IFO, then by tags
ifo_keys, ifo_groups = trigs.categorize_by_attr('ifo')
for ifo, ifo_trigs in zip(ifo_keys, ifo_groups):

    tag_keys, tag_groups = ifo_trigs.categorize_by_attr('tags')
    for tag_list, tag_trigs in zip(tag_keys, tag_groups):

        # timeseries over the whole analysis time for this IFO and tags
        tag_dir = outputDir + '/' + '_'.join([ifo] + tag_list)
        _add_timeseries_nodes(ifo, tag_list, workflow.analysis_time,
                              tag_trigs, tag_dir, ['ALLTIME'])

        # timeseries for each valid segment, stored below <tag_dir>/jobs
        seg_keys, seg_groups = tag_trigs.categorize_by_attr('segment')
        for seg, seg_trigs in zip(seg_keys, seg_groups):
            seg_dir = tag_dir + '/jobs/' + str(seg[0]) + '-' + str(seg[1])
            _add_timeseries_nodes(ifo, tag_list, seg, seg_trigs, seg_dir, [])

#######################
# setup plotbank nodes
#######################

# setup plotbank executable
plotbank_exe = _workflow.Executable(workflow.cp, 'plotbank', out_dir=outputDir)

# walk the trigger files grouped by IFO, then by tags
ifo_keys, ifo_groups = trigs.categorize_by_attr('ifo')
for ifo, ifo_trigs in zip(ifo_keys, ifo_groups):

    tag_keys, tag_groups = ifo_trigs.categorize_by_attr('tags')
    for tag_list, tag_trigs in zip(tag_keys, tag_groups):

        # one plotbank node over the whole analysis time per (IFO, tags)
        plot_dir = '/'.join([outputDir, '_'.join([ifo] + tag_list)])
        plot_file = _workflow.File(ifo, 'PLOTBANK', workflow.analysis_time,
                                   directory=plot_dir, extension='png',
                                   tags=tag_list + ['ALLTIME'])

        node = plotbank_exe.create_node()
        node.add_input_list_opt('--trigger-files', tag_trigs)
        node.add_output_opt('--output-file', plot_file)
        node.add_opt('--gps-start-time', workflow.analysis_time[0])
        node.add_opt('--gps-end-time', workflow.analysis_time[1])

        workflow.add_node(node)
        images.append(plot_file)

#######################
# setup plotchisq nodes
#######################

# setup plotchisq executable
plotchisq_exe = _workflow.Executable(workflow.cp, 'plotchisq', out_dir=outputDir)

# walk the trigger files grouped by IFO, then by tags
ifo_keys, ifo_groups = trigs.categorize_by_attr('ifo')
for ifo, ifo_trigs in zip(ifo_keys, ifo_groups):

    tag_keys, tag_groups = ifo_trigs.categorize_by_attr('tags')
    for tag_list, tag_trigs in zip(tag_keys, tag_groups):

        # one plotchisq node over the whole analysis time per (IFO, tags)
        plot_dir = '/'.join([outputDir, '_'.join([ifo] + tag_list)])
        plot_file = _workflow.File(ifo, 'PLOTCHISQ', workflow.analysis_time,
                                   directory=plot_dir, extension='png',
                                   tags=tag_list + ['ALLTIME'])

        node = plotchisq_exe.create_node()
        node.add_input_list_opt('--trigger-files', tag_trigs)
        node.add_output_opt('--output-file', plot_file)
        node.add_opt('--gps-start-time', workflow.analysis_time[0])
        node.add_opt('--gps-end-time', workflow.analysis_time[1])

        workflow.add_node(node)
        images.append(plot_file)


########################
# setup plot rates nodes
########################

# setup plot rates executable
plotrates_exe = _workflow.Executable(workflow.cp, 'plotrates', out_dir=outputDir)

# walk the trigger files grouped by IFO, then by tags
ifo_keys, ifo_groups = trigs.categorize_by_attr('ifo')
for ifo, ifo_trigs in zip(ifo_keys, ifo_groups):

    tag_keys, tag_groups = ifo_trigs.categorize_by_attr('tags')
    for tag_list, tag_trigs in zip(tag_keys, tag_groups):

        # one plot rates node over the whole analysis time per (IFO, tags)
        plot_dir = '/'.join([outputDir, '_'.join([ifo] + tag_list)])
        plot_file = _workflow.File(ifo, 'PLOTRATES', workflow.analysis_time,
                                   directory=plot_dir, extension='png',
                                   tags=tag_list + ['ALLTIME'])

        node = plotrates_exe.create_node()
        node.add_input_list_opt('--trigger-files', tag_trigs)
        node.add_output_opt('--output-file', plot_file)
        node.add_opt('--gps-start-time', workflow.analysis_time[0])
        node.add_opt('--gps-end-time', workflow.analysis_time[1])

        workflow.add_node(node)
        images.append(plot_file)

######################
# setup html page node
######################

# the html executable is created with universe='local'
html_exe = _workflow.Executable(workflow.cp, 'html', universe='local')

# a single node, pointed at the output directory, goes into the separate
# html sub-workflow rather than the main workflow
page_node = html_exe.create_node()
page_node.add_opt('--plots-dir', outputDir)
html_workflow.add_node(page_node)

###############
# save and exit
###############

# setup cache files to be used by summary pages: one cache file is written
# for every (IFO, tags) group of trigger files and of template bank files

# group trigger files by IFO and loop over IFOs
trigs_ifos, trigs_ifos_filelists = trigs.categorize_by_attr('ifo')
for trigs_ifo, trigs_ifo_filelist in zip(trigs_ifos, trigs_ifos_filelists):

    # group trigger files by tags and loop over tags
    trigs_tags, trigs_tags_filelists = trigs_ifo_filelist.categorize_by_attr('tags')
    for trigs_tag, trigs_tag_filelist in zip(trigs_tags, trigs_tags_filelists):

        # convert this group of files to a cache
        cache = trigs_tag_filelist.convert_to_lal_cache()

        # write cache file (mimics old workflow filename); the with-block
        # makes sure the file is flushed and closed
        tag = '_'.join(trigs_tag).upper()
        filename = outputDir+'/'+trigs_ifo+'-INSPIRAL_'+tag+'.cache'
        with open(filename, 'w') as cache_fileobj:
            cache.tofile(cache_fileobj)

# group tmpltbank files by IFO and loop over IFOs
# (this loop must iterate over banks, not trigs, and must not reuse the
# loop variables left over from the trigger loop above)
# NOTE(review): assumes banks supports categorize_by_attr and
# convert_to_lal_cache like the trigger file lists do -- confirm
banks_ifos, banks_ifos_filelists = banks.categorize_by_attr('ifo')
for banks_ifo, banks_ifo_filelist in zip(banks_ifos, banks_ifos_filelists):

    # group tmpltbank files by tags and loop over tags
    banks_tags, banks_tags_filelists = banks_ifo_filelist.categorize_by_attr('tags')
    for banks_tag, banks_tag_filelist in zip(banks_tags, banks_tags_filelists):

        # convert this group of files to a cache
        cache = banks_tag_filelist.convert_to_lal_cache()

        # write cache file (mimics old workflow filename)
        tag = '_'.join(banks_tag).upper()
        filename = outputDir+'/'+banks_ifo+'-TMPLTBANK_'+tag+'.cache'
        with open(filename, 'w') as cache_fileobj:
            cache.tofile(cache_fileobj)

# write configuration file; open() inside a with-block replaces the
# Python-2-only file() builtin and closes the handle once written
if not os.path.exists(configDir):
    os.makedirs(configDir)
with open(configDir+'/'+'workflow_configuration.ini', 'w') as config_fileobj:
    workflow.cp.write(config_fileobj)

# append sub-workflows to workflow; the html sub-workflow is made a child
# of the main sub-workflow through an explicit DAX dependency
container += workflow
container += html_workflow
dep = dax.Dependency(parent=workflow.as_job, child=html_workflow.as_job)
container._adag.addDependency(dep)

# write dax
container.save()
logging.info('Finished.')
