#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Further analyze one or more bins in a collection.

   This is especially useful when there are one or more highly contaminated
   bins in a merged profile.
"""

import os
import sys
import json
import argparse
import webbrowser
import datetime
import random

from multiprocessing import Process
from bottle import route, static_file, redirect, request, BaseRequest, response
from bottle import run as run_server

import anvio
import anvio.utils as utils
import anvio.refine as refine
import anvio.terminal as terminal

from anvio.errors import ConfigError, FilesNPathsError, DictIOError, SamplesError, RefineError


# module metadata; the version string is taken from the anvio package itself
__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"


# terminal helpers used for every message this script prints
run = terminal.Run()
progress = terminal.Progress()


# the static files of the interface are looked up relative to the directory
# that holds the anvio.utils module
_anvio_package_dir = os.path.dirname(utils.__file__)
static_dir = os.path.join(_anvio_package_dir, 'data/interactive')


# command line interface. every argument is declared through anvio.A / anvio.K
# (positional and keyword arguments for `add_argument`), and the groups below
# only control how the arguments are organized in the help output.
parser = argparse.ArgumentParser(description="Start the anvi'o interactive interface for refining")

groupA = parser.add_argument_group('DEFAULT INPUTS', "The interactive interface can be started with and without "
                                                      "anvi'o databases. The default use assumes you have your "
                                                      "profile and contigs database, however, it is also possible "
                                                      "to start the interface using ad-hoc input files. See 'MANUAL "
                                                      "INPUT' section for other set of parameters that are mutually "
                                                      "exclusive with databases.")
groupB = parser.add_argument_group('REFINE-SPECIFICS', "Parameters that are essential to the refinement process.")
groupC = parser.add_argument_group('ADDITIONAL STUFF', "Parameters to provide additional layers, views, or layer data.")
groupD = parser.add_argument_group('VISUALS RELATED', "Parameters that give access to various adjustments regarding "
                                                       "the interface.")
groupE = parser.add_argument_group('SWEET PARAMS OF CONVENIENCE', "Parameters and flags that are not quite essential (but "
                                                                   "nice to have).")
groupF = parser.add_argument_group('SERVER CONFIGURATION', "For power users.")

groupA.add_argument(*anvio.A('profile-db'), **anvio.K('profile-db'))
groupA.add_argument(*anvio.A('contigs-db'), **anvio.K('contigs-db'))
groupA.add_argument(*anvio.A('samples-information-db'), **anvio.K('samples-information-db'))
groupB.add_argument(*anvio.A('collection-id'), **anvio.K('collection-id'))
groupB.add_argument(*anvio.A('bin-id'), **anvio.K('bin-id'))
groupB.add_argument(*anvio.A('bin-ids-file'), **anvio.K('bin-ids-file'))
groupC.add_argument(*anvio.A('additional-view'), **anvio.K('additional-view'))
groupC.add_argument(*anvio.A('additional-layers'), **anvio.K('additional-layers'))
groupD.add_argument(*anvio.A('split-hmm-layers'), **anvio.K('split-hmm-layers'))
groupE.add_argument(*anvio.A('dry-run'), **anvio.K('dry-run'))
groupE.add_argument(*anvio.A('debug'), **anvio.K('debug'))
groupF.add_argument(*anvio.A('ip-address'), **anvio.K('ip-address'))
groupF.add_argument(*anvio.A('port-number'), **anvio.K('port-number'))
groupF.add_argument(*anvio.A('read-only'), **anvio.K('read-only'))
groupF.add_argument(*anvio.A('server-only'), **anvio.K('server-only'))

args = parser.parse_args()

# ask anvio for a usable port, starting the search from the one requested on
# the command line
ip = args.ip_address
port = utils.get_available_port_num(start=args.port_number, ip=ip)

if not port:
    run.info_single('anvio failed to find a port number that is available :(', mc='red')
    sys.exit(-1)

# a random number the interface can fetch through /data/session_id
unique_session_id = random.randint(0, 9999999999)

# build the refinement object (`r`) and the interactive data object (`d`) that
# every bottle callback below reads from. known anvi'o errors are printed and
# turned into distinct negative exit codes instead of a traceback. the
# `except ... as e` / `print(e)` spelling behaves the same on Python 2.6+ and
# is also valid Python 3.
try:
    r = refine.RefineBins(args)
    d = r.refine()
except ConfigError as e:
    print(e)
    sys.exit(-1)
except FilesNPathsError as e:
    print(e)
    sys.exit(-2)
except DictIOError as e:
    print(e)
    sys.exit(-3)
except SamplesError as e:
    print(e)
    sys.exit(-4)
except RefineError as e:
    # RefineError is imported for the refinement code path, and is reported
    # here the same way as the other anvi'o errors
    print(e)
    sys.exit(-5)



#######################################################################################################################
# bottle callbacks start
#######################################################################################################################

def set_default_headers(response):
    """Declare `response` as JSON and set the headers that forbid caching.

    `response` is any object with a `set_header(name, value)` method; the four
    headers below are set on it, in this order. Nothing is returned.
    """
    default_headers = (
        ('Content-Type', 'application/json'),
        ('Pragma', 'no-cache'),
        ('Cache-Control', 'no-cache, no-store, max-age=0, must-revalidate'),
        ('Expires', 'Thu, 01 Dec 1994 16:00:00 GMT'),
    )

    for header_name, header_value in default_headers:
        response.set_header(header_name, header_value)

@route('/')
def redirect_to_app():
    """Redirect requests for the server root to the index page of the app."""
    landing_page = '/app/index.html'
    redirect(landing_page)

@route('/app/:filename#.*#')
def send_static(filename):
    """Serve `filename` (anything requested under /app/) from `static_dir`."""
    set_default_headers(response)

    served_file = static_file(filename, root=static_dir)
    return served_file

@route('/data/<name>')
def send_data(name):
    """Return the JSON-encoded piece of interface data registered under `name`.

    Each known name maps to a small function that collects the value from `d`,
    `r`, `args` or the session id; only the function for the requested name is
    called. A name that is not registered yields `None`.
    """
    set_default_headers(response)

    def split_lengths():
        # split name -> length of that split
        return dict((split, d.splits_basic_info[split]['length']) for split in d.splits_basic_info)

    def bin_prefix():
        # a single bin lends its own name to the prefix
        if len(r.bins) == 1:
            return list(r.bins)[0] + "_"
        return "Refined_"

    providers = {
        "clusterings": lambda: (d.p_meta['default_clustering'], d.p_meta['clusterings']),
        "views": lambda: (d.default_view, dict((view, view) for view in d.views.keys())),
        "default_view": lambda: d.views[d.default_view],
        "contig_lengths": split_lengths,
        "title": lambda: d.title,
        "mode": lambda: "refine",
        "read_only": lambda: args.read_only,
        "bin_prefix": bin_prefix,
        "session_id": lambda: unique_session_id,
        "samples_order": lambda: d.samples_order_dict,
        "samples_information": lambda: d.samples_information_dict,
        "samples_information_default_layer_order": lambda: d.samples_information_default_layer_order,
    }

    provider = providers.get(name)
    if provider is None:
        # nothing is registered under this name
        return None

    return json.dumps(provider())

@route('/data/view/<view_id>')
def get_view_data(view_id):
    """Return the JSON-encoded content of the view stored under `view_id`."""
    requested_view = d.views[view_id]
    return json.dumps(requested_view)

@route('/tree/<tree_id>')
def send_tree(tree_id):
    """Return the newick entry of clustering `tree_id` as JSON.

    An empty JSON string is returned when `tree_id` is not among the
    clusterings of the profile.
    """
    set_default_headers(response)

    if tree_id not in d.p_meta['clusterings']:
        return json.dumps("")

    run.info_single("Clustering of '%s' has been requested" % (tree_id))
    newick = d.p_meta['clusterings'][tree_id]['newick']
    return json.dumps(newick)

@route('/data/charts/<split_name>')
def charts(split_name):
    """Collect everything the charts page shows for the split `split_name`.

    The result has these entries:

      - 'layers': the sorted sample names
      - 'coverage', 'variability', 'competing_nucleotides': one item per
        layer, in the order of 'layers'
      - 'index' / 'total': 1-based position of the split in
        `d.split_names_ordered`, and the number of splits in it
      - 'previous_contig_name' / 'next_contig_name': the neighbors of the
        split in that order (None at either end)
      - 'genes': one dict per gene in the split, extended with 'direction',
        'function' and the 'level' the gene is assigned to

    If the split is unknown or auxiliary data is not available, the empty
    skeleton is returned as a plain dict; otherwise a JSON string is returned.
    """
    data = {'layers': [],
             'index': None,
             'total': None,
             'coverage': [],
             'variability': [],
             'competing_nucleotides': [],
             'previous_contig_name': None,
             'next_contig_name': None,
             'genes': []}

    # these two early exits hand back the dict itself (not a JSON string)
    if split_name not in d.split_names:
        return data

    if not d.auxiliary_data_available:
        return data

    index_of_split = d.split_names_ordered.index(split_name)
    if index_of_split:
        data['previous_contig_name'] = d.split_names_ordered[index_of_split - 1]
    if (index_of_split + 1) < len(d.split_names_ordered):
        data['next_contig_name'] = d.split_names_ordered[index_of_split + 1]

    data['index'] = index_of_split + 1
    data['total'] = len(d.split_names_ordered)

    layers = sorted(d.p_meta['samples'])

    coverage_values_dict = d.split_coverage_values.get(split_name)
    data['coverage'] = [coverage_values_dict[layer].tolist() for layer in layers]

    # get the variability information dict for split:
    split_variability_info_dict = d.get_variability_information_for_split(split_name)

    for layer in layers:
        data['layers'].append(layer)
        data['competing_nucleotides'].append(split_variability_info_dict[layer]['competing_nucleotides'])

        # the variability dict is keyed by position; expand it into a list
        # that has one value per position of the split (0 where there is no
        # entry)
        variability = [0] * d.splits_basic_info[split_name]['length']
        positions = split_variability_info_dict[layer]['variability']
        for pos in positions:
            variability[pos] = positions[pos]
        data['variability'].append(variability)

    # genes are distributed over levels so that two genes on the same level
    # are never closer than `padding` to each other
    padding = 100
    levels_occupied = {1: []}
    for entry_id in d.split_to_genes_in_splits_ids[split_name]:
        # work on a copy, so the entries of d.genes_in_splits are not modified
        # by the attributes added below
        p = dict(d.genes_in_splits[entry_id])
        prot_id = p['prot']
        # p looks like this at this point:
        #
        # {'percentage_in_split': 100,
        #  'start_in_split'     : 16049,
        #  'stop_in_split'      : 16633}
        #  'prot'               : u'prot2_03215',
        #  'split'              : u'D23-1contig18_split_00036'}
        #
        # we will add two more attributes (and the level, further below):
        p['direction'] = d.genes_in_contigs_dict[prot_id]['direction']
        p['function'] = d.genes_in_contigs_dict[prot_id]['function'] or None

        # normalize the span so the overlap test does not depend on the order
        # of start and stop
        gene_lo, gene_hi = sorted((p['start_in_split'], p['stop_in_split']))

        # take the lowest level on which the gene (with padding) collides with
        # nothing. the interval test also catches a gene that completely
        # encloses one that is already placed.
        target_level = None
        for level in sorted(levels_occupied):
            collides = any(gene_lo <= taken_hi + padding and gene_hi >= taken_lo - padding
                           for taken_lo, taken_hi in levels_occupied[level])
            if not collides:
                target_level = level
                break

        if target_level is None:
            # every existing level is taken; open a new one above all of them
            target_level = max(levels_occupied) + 1
            levels_occupied[target_level] = []

        levels_occupied[target_level].append((gene_lo, gene_hi))
        p['level'] = target_level

        data['genes'].append(p)

    return json.dumps(data)

# last state posted to /data/charts/set_state; empty until the first post
state_for_charts = dict()

@route('/data/charts/set_state', method='POST')
def set_state():
    """Keep the 'state' field of the posted form in `state_for_charts`."""
    global state_for_charts

    posted_state = request.forms.get('state')
    state_for_charts = posted_state

@route('/data/charts/get_state')
def get_parent_state():
    """Return whatever is currently kept in `state_for_charts`."""
    set_default_headers(response)

    current_state = state_for_charts
    return current_state

@route('/data/contig/<split_name>')
def split_info(split_name):
    """Return the `d.split_sequences` entry of `split_name` as JSON."""
    set_default_headers(response)

    sequence_entry = d.split_sequences[split_name]
    return json.dumps(sequence_entry)

@route('/store_refined_bins', method='POST')
def store_refined_bins():
    """Store the bins posted by the interface through `r.store_refined_bins`.

    The 'data' and 'colors' form fields are JSON documents. The reply is a
    JSON object with a 'status' (0 on success, -1 on failure) and a 'message'
    for the user. Nothing is stored when the server runs with --read-only.
    """
    # a read-only instance must not write to the database
    if args.read_only:
        return json.dumps({'status': -1, 'message': 'This is a read-only instance, refined bins can not be stored.'})

    data = json.loads(request.forms.get('data'))
    colors = json.loads(request.forms.get('colors'))

    try:
        r.store_refined_bins(data, colors)
    except RefineError as e:
        return json.dumps({'status': -1, 'message': e.clear_text()})

    message = 'Done! Collection %s is updated in the database. You can close your browser window (or continue updating).' % (r.collection_id)
    return json.dumps({'status': 0, 'message': message})

@route('/data/completeness', method='POST')
def completeness():
    """Return completeness info for the splits listed in the posted form.

    The 'split_names' and 'bin_name' form fields are JSON documents. The reply
    is an empty JSON object when `d.completeness` is not set, otherwise a JSON
    object with 'stats' and 'refs'.
    """
    if not d.completeness:
        return json.dumps({})

    form = request.forms
    split_names = json.loads(form.get('split_names'))
    bin_name = json.loads(form.get('bin_name'))

    run.info_single('Completeness info has been requested for %d splits in %s' % (len(split_names), bin_name))

    reply = {'stats': d.completeness.get_info_for_splits(set(split_names)),
             'refs': d.completeness.http_refs}

    return json.dumps(reply)

@route('/state/autoload')
def state_autoload():
    """Return `d.state` as JSON (see --state parameter)."""
    set_default_headers(response)

    autoload_state = d.state
    return json.dumps(autoload_state)

@route('/state/all')
def state_all():
    """Return every state in `d.states_table.states` as JSON."""
    set_default_headers(response)

    all_states = d.states_table.states
    return json.dumps(all_states)

@route('/state/get', method='POST')
def get_state():
    """Return the content of the state named in the 'name' form field as JSON.

    An empty JSON string is returned when there is no state with that name
    (or when the stored entry is None).
    """
    set_default_headers(response)

    name = request.forms.get('name')

    # an unknown name must end up in the empty reply below rather than in a
    # KeyError
    try:
        state = d.states_table.states[name]
    except KeyError:
        state = None

    if state is None:
        return json.dumps("")

    return json.dumps(state['content'])

@route('/state/save', method='POST')
def save_state():
    """Store the posted state through `d.states_table.store_state`.

    The reply is a JSON object whose 'status_code' is '1' when the state has
    been stored, and '0' when the server runs with --read-only (in which case
    nothing is stored).
    """
    if args.read_only:
        return json.dumps({'status_code': '0'})

    form = request.forms
    state_name = form.get('name')
    state_content = form.get('content')
    timestamp = datetime.datetime.now().strftime("%d.%m.%Y %H:%M:%S")

    d.states_table.store_state(state_name, state_content, timestamp)

    return json.dumps({'status_code': '1'})


#######################################################################################################################
# bottle callbacks end
#######################################################################################################################

# increase maximum size of form data to 100 MB
BaseRequest.MEMFILE_MAX = 1024 * 1024 * 100

if args.dry_run:
    run.info_single('Dry run, eh? Bye!', 'red', nl_before = 1, nl_after=1)
    sys.exit()

# the bottle server runs in its own process, and this process waits for it
# until CTRL+C is pressed. the process object is created before the `try`
# block so that it is always defined when the interrupt handler runs.
server_process = Process(target=run_server, kwargs={'host': ip, 'port': port, 'quiet': True})

try:
    server_process.start()

    # with --server-only the server is started without opening a browser
    if not args.server_only:
        webbrowser.open_new("http://%s:%d" % (ip, port))

    run.info_single('When you are finished, press CTRL+C to terminate the server.', 'green', nl_before = 1, nl_after=1)
    server_process.join()
except KeyboardInterrupt:
    run.warning('The server is being terminated.', header='Please wait...')
    server_process.terminate()
    sys.exit(1)
