#!/usr/bin/env python

# Copyright (c) 2015, Amit Zeisel, Gioele La Manno and Sten Linnarsson
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This .py file can be used as a library or as a command-line version of BackSPIN.
# This version of BackSPIN was implemented by Gioele La Manno.
# The BackSPIN biclustering algorithm was developed by Amit Zeisel and is described
# in Zeisel et al., "Cell types in the mouse cortex and hippocampus revealed by
# single-cell RNA-seq", Science 2015 (PMID: 25700174, doi: 10.1126/science.aaa1934).
#
# Building using pyinstaller:
# pyinstaller -F backSPIN.py -n backspin-mac-64-bit
#

from __future__ import division
from numpy import *
import getopt
import sys
import os
from backspinpy import SPIN, backSPIN, fit_CV, feature_selection
from backspinpy import CEF_obj



def usage_quick():
    """Print the one-line synopsis of the backSPIN command line.

    The synopsis mirrors the option string handed to getopt by the script
    body: only -h and -v are argument-less flags, every other option
    (including -b) takes a value.
    """
    message ='''usage: backSPIN [-hv] [-i inputfile] [-o outputfile] [-d int] [-f int] [-t int] [-s float] [-T int] [-S float] [-g int] [-c int] [-k float] [-r float] [-b axisvalue]
    manual: backSPIN -h
    '''
    print (message)

def usage():
    """Print the full manual of the backSPIN command-line tool.

    The text lists every option accepted by the script together with the
    default value used when the option is not given.
    """
    message='''
       backSPIN commandline tool
       -------------------------

       The options are as follows:

       -i [inputfile]
       --input=[inputfile]
              Path of the cef formatted tab delimited file.
              Rows should be genes and columns single cells/samples.
              For further information on the cef format visit:
              https://github.com/linnarsson-lab/ceftools

       -o [outputfile]
       --output=[outputfile]
              The name of the file to which the output will be written

       -d [int]
              Depth/Number of levels: The number of nested splits that will be tried by the algorithm
              Defaults to 2
       -t [int]
              Number of the iterations used in the preparatory SPIN.
              Defaults to 10
       -f [int]
              Feature selection is performed before BackSPIN. Argument controls how many genes are selected.
              Selection is based on expected noise (a curve fit to the CV-vs-mean plot).
              Feature selection is skipped unless this option is given.
       -s [float]
              Controls the decrease rate of the width parameter used in the preparatory SPIN.
              Smaller values will increase the number of SPIN iterations and result in higher
              precision in the first step but longer execution time.
              Defaults to 0.1
       -T [int]
              Number of the iterations used for every width parameter.
              Does not apply on the first run (use -t instead)
              Defaults to 8
       -S [float]
              Controls the decrease rate of the width parameter.
              Smaller values will increase the number of SPIN iterations and result in higher
              precision but longer execution time.
              Does not apply on the first run (use -s instead)
              Defaults to 0.3
       -g [int]
              Minimal number of genes that a group must contain for splitting to be allowed.
              Defaults to 2
       -c [int]
              Minimal number of cells that a group must contain for splitting to be allowed.
              Defaults to 2
       -k [float]
              Minimum score that a breaking point has to reach to be suitable for splitting.
              Defaults to 1.15
       -r [float]
              If the difference between the average expression of two groups is lower than threshold the algorithm
              uses highly correlated genes to assign the gene to one of the two groups
              Defaults to 0.2
       -b [axisvalue]
              Run normal SPIN instead of backSPIN.
              Normal spin accepts the parameters -T -S
              An axis value 0 to only sort genes (rows), 1 to only sort cells (columns) or 'both' for both
              must be passed
       -v
              Verbose. Print extra details of what is happening to the standard output

    '''

    print(message)



def _parse_options(argv):
    """Parse backSPIN command-line arguments into a settings dictionary.

    Args:
        argv: list of argument strings, without the program name.

    Returns:
        dict mapping setting names to their values; options that were not
        given on the command line keep the defaults listed below.

    The process is terminated through ``sys.exit``: with status 0 after the
    manual (-h) or the quick usage (no arguments at all) has been printed,
    and with status 2 for an unknown option, a missing option argument or
    an option value that cannot be interpreted.
    """
    settings = {
        'input_path': None,        # -i / --input
        'outfiles_path': None,     # -o / --output
        'numLevels': 2,            # -d
        'feature_fit': False,      # -f
        'feature_genes': 2000,
        'first_run_iters': 10,     # -t
        'first_run_step': 0.1,     # -s
        'runs_iters': 8,           # -T
        'runs_step': 0.3,          # -S
        'split_limit_g': 2,        # -g
        'split_limit_c': 2,        # -c
        'stop_const': 1.15,        # -k
        'low_thrs': 0.2,           # -r
        'normal_spin': False,      # -b
        'normal_spin_axis': 'both',
        'verbose': False,          # -v
    }
    # Options whose value is stored after a plain int()/float() conversion.
    int_options = {'-d': 'numLevels', '-t': 'first_run_iters', '-T': 'runs_iters',
                   '-g': 'split_limit_g', '-c': 'split_limit_c'}
    float_options = {'-s': 'first_run_step', '-S': 'runs_step',
                     '-k': 'stop_const', '-r': 'low_thrs'}

    try:
        optlist, args = getopt.gnu_getopt(argv, "hvi:o:f:d:t:s:T:S:g:c:k:r:b:", ["help", "input=","output="])
    except getopt.GetoptError as err:
        # Unknown option or missing option argument: report it instead of
        # letting the traceback reach the user.
        print ('Error: %s\n' % err)
        usage_quick()
        sys.exit(2)

    if not optlist and not args:
        usage_quick()
        sys.exit()

    for opt, a in optlist:
        try:
            if opt in ("-h", "--help"):
                usage()
                sys.exit()
            elif opt in ('-i', '--input'):
                settings['input_path'] = a
            elif opt in ("-o", "--output"):
                settings['outfiles_path'] = a
            elif opt in int_options:
                settings[int_options[opt]] = int(a)
            elif opt in float_options:
                settings[float_options[opt]] = float(a)
            elif opt == '-f':
                settings['feature_fit'] = True
                if a != '':
                    settings['feature_genes'] = int(a)
            elif opt == '-v':
                settings['verbose'] = True
            elif opt == '-b':
                settings['normal_spin'] = True
                if a != '':
                    axis = a if a == 'both' else int(a)
                    # Any other axis would match none of the output branches
                    # and silently produce an empty output file.
                    if axis not in (0, 1, 'both'):
                        raise ValueError(a)
                    settings['normal_spin_axis'] = axis
            else:
                # getopt only yields options from the spec above, so this is
                # purely defensive; a raise survives `python -O`, assert does not.
                raise AssertionError("%s option is not supported" % opt)
        except ValueError:
            print ('Invalid value %r for option %s.\n' % (a, opt))
            sys.exit(2)

    return settings


def _load_input(settings):
    """Read the input CEF file and prepare the expression matrix.

    When feature selection was requested, the CEF object is reduced to the
    selected rows. The returned matrix is log2(x+1) transformed and every
    row is centered on its mean.

    Args:
        settings: dictionary produced by ``_parse_options``.

    Returns:
        (input_cef, data): the CEF object and the transformed numpy matrix.

    Exits with status 1 when loading/parsing fails or when the matrix is
    judged too small to be a valid input.
    """
    verbose = settings['verbose']
    try:
        if verbose:
            print ('Loading file.')
        input_cef = CEF_obj()
        input_cef.readCEF(settings['input_path'])

        data = array(input_cef.matrix)

        if settings['feature_fit']:
            if verbose:
                print ("Performing feature selection")
            ix_features = feature_selection(data, settings['feature_genes'], verbose=verbose)
            if verbose:
                print ("Selected %i genes" % len(ix_features))
            data = data[ix_features, :]
            # Keep the CEF object in sync with the reduced matrix.
            input_cef.matrix = data.tolist()
            input_cef.row_attr_values = atleast_2d( array( input_cef.row_attr_values ))[:,ix_features].tolist()
            input_cef.update()

        data = log2(data+1)
        data = data - data.mean(1)[:,newaxis]
        if data.shape[0] <= 3 and data.shape[1] <= 3:
            print ('Input file is not correctly formatted.\n')
            # SystemExit is not an Exception subclass, so the handler below
            # does not intercept it.
            sys.exit(1)
    except Exception:
        import traceback
        print ('There was an error')
        print (traceback.format_exc())
        print ('Error occurred in parsing the input file.')
        print ('Please check that your input file is a correctly formatted cef file.\n')
        sys.exit(1)
    return input_cef, data


def _copy_headers(input_cef, output_cef):
    """Copy every header name/value pair from input_cef to output_cef."""
    for h_name, h_val in zip( input_cef.header_names, input_cef.header_values):
        output_cef.add_header(h_name, h_val )


def _matrix_str_fmt(matrix):
    """Return '%i' if the first and last matrix rows hold no float, else '%.6g'."""
    # Explicit loops on purpose: `from numpy import *` at the top of this file
    # shadows the builtin sum/any, and numpy.sum on a generator is deprecated.
    for row in (matrix[0], matrix[-1]):
        for value in row:
            if isinstance(value, float):
                return '%.6g'
    return '%i'


def _run_backspin(settings, input_cef, data):
    """Run backSPIN on data and write the reordered, annotated CEF file."""
    print ('backSPIN started\n----------------\n')
    print ('Input file:\n%s\n' % settings['input_path'])
    print ('Output file:\n%s\n' % settings['outfiles_path'])
    print ('numLevels: %i\nfirst_run_iters: %i\nfirst_run_step: %.3f\nruns_iters: %i\nruns_step: %.3f\nsplit_limit_g: %i\nsplit_limit_c: %i\nstop_const: %.3f\nlow_thrs: %.3f\n' % (
        settings['numLevels'], settings['first_run_iters'], settings['first_run_step'],
        settings['runs_iters'], settings['runs_step'], settings['split_limit_g'],
        settings['split_limit_c'], settings['stop_const'], settings['low_thrs']))

    results = backSPIN(data, settings['numLevels'], settings['first_run_iters'],
                       settings['first_run_step'], settings['runs_iters'], settings['runs_step'],
                       settings['split_limit_g'], settings['split_limit_c'],
                       settings['stop_const'], settings['low_thrs'], settings['verbose'])

    sys.stdout.flush()
    print ('\nWriting output.\n')

    output_cef = CEF_obj()
    _copy_headers(input_cef, output_cef)

    # Existing attributes follow the new cell/gene ordering.
    for c_name, c_val in zip( input_cef.col_attr_names, input_cef.col_attr_values):
        output_cef.add_col_attr(c_name, array(c_val)[results.cells_order])
    for r_name, r_val in zip( input_cef.row_attr_names, input_cef.row_attr_values):
        output_cef.add_row_attr(r_name, array(r_val)[results.genes_order])

    # One extra attribute per level holding the group assigned at that level.
    for level, groups in enumerate( results.genes_gr_level.T ):
        output_cef.add_row_attr('Level_%i_group' % level, [int(el) for el in groups])
    for level, groups in enumerate( results.cells_gr_level.T ):
        output_cef.add_col_attr('Level_%i_group' % level, [int(el) for el in groups])

    output_cef.set_matrix(array(input_cef.matrix)[results.genes_order,:][:,results.cells_order])
    output_cef.writeCEF( settings['outfiles_path'], matrix_str_fmt=_matrix_str_fmt(input_cef.matrix) )


def _run_spin(settings, input_cef, data):
    """Run plain SPIN on data and write the reordered CEF file."""
    axis = settings['normal_spin_axis']

    print ('normal SPIN started\n----------------\n')
    print ('Input file:\n%s\n' % settings['input_path'])
    print ('Output file:\n%s\n' % settings['outfiles_path'])

    results = SPIN(data, widlist=settings['runs_step'], iters=settings['runs_iters'], axis=axis, verbose=settings['verbose'])

    print ('\nWriting output.\n')

    output_cef = CEF_obj()
    _copy_headers(input_cef, output_cef)

    if axis == 'both':
        # results[0] is used as the row ordering, results[1] as the column ordering.
        for c_name, c_val in zip( input_cef.col_attr_names, input_cef.col_attr_values):
            output_cef.add_col_attr(c_name, array(c_val)[results[1]])
        for r_name, r_val in zip( input_cef.row_attr_names, input_cef.row_attr_values):
            output_cef.add_row_attr(r_name, array(r_val)[results[0]])
        output_cef.set_matrix(array(input_cef.matrix)[results[0],:][:,results[1]])
    elif axis == 0:
        for r_name, r_val in zip( input_cef.row_attr_names, input_cef.row_attr_values):
            output_cef.add_row_attr(r_name, array(r_val)[results])
        output_cef.set_matrix(array(input_cef.matrix)[results,:])
    elif axis == 1:
        for c_name, c_val in zip( input_cef.col_attr_names, input_cef.col_attr_values):
            output_cef.add_col_attr(c_name, array(c_val)[results])
        output_cef.set_matrix(array(input_cef.matrix)[:,results])

    output_cef.writeCEF( settings['outfiles_path'] )


def main(argv=None):
    """Command-line entry point of backSPIN.

    Args:
        argv: list of argument strings without the program name; when None
            the arguments are taken from ``sys.argv[1:]``.

    Parses the options, loads and transforms the input CEF file, runs either
    backSPIN or (with -b) plain SPIN and writes the result to the output
    file. Exits with a non-zero status when the input or output path is
    missing, an option is invalid, or the input cannot be loaded.
    """
    print("")
    settings = _parse_options(sys.argv[1:] if argv is None else argv)

    if settings['input_path'] is None:
        print ('No input file was provided.\nYou need to specify an input file\n(e.g. backSPIN -i path/to/your/file/foo.cef)\n')
        sys.exit(1)
    if settings['outfiles_path'] is None:
        print ('No output file was provided.\nYou need to specify an output file\n(e.g. backSPIN -o path/to/your/file/bar.cef)\n')
        sys.exit(1)

    input_cef, data = _load_input(settings)

    if settings['normal_spin']:
        _run_spin(settings, input_cef, data)
    else:
        _run_backspin(settings, input_cef, data)


if __name__ == '__main__':
    main()
