#!/usr/bin/env python
import matplotlib
matplotlib.use('TkAgg')

import numpy as np
import matplotlib.pyplot as py
import os
import argparse
import warnings
import sys

from FAST import plotparams
from scipy import optimize

def sms(x,params):
    """Evaluate the fit model Vmax/(1 + Vmax*x/Ks) at x.

    x      : load value(s); a scalar or anything supporting the arithmetic below.
    params : indexable sequence; params[0] is Vmax and params[1] is Ks.

    Returns the model value(s) at x.
    """
    vmax,ks     = params[0],params[1]
    denominator = 1.0+vmax*x/ks

    return vmax/denominator

def sms_utr_half(params):
    """Return (Ks/Vmax, Vmax) for the model parameters.

    params : indexable sequence; params[0] is Vmax and params[1] is Ks.

    Ks/Vmax is the x at which sms() equals Vmax/2.
    """
    vmax = params[0]

    return params[1]/vmax,vmax

def err_sms(params,x,y):
    """Residuals between the observations y and the sms model evaluated at x.

    The argument order (params first) is the one scipy.optimize.leastsq
    expects from its error function.
    """
    predicted = sms(x,params)
    return y - predicted

def read_stats_files(fname):
    """Read a tab-separated statistics file.

    The first line of the file is skipped. On every following line the first
    character is a marker column: it is dropped before parsing, and a '#'
    there flags the row as not valid. The remaining text is split on tabs
    into slide number (int), experiment number (int), file name, protein
    name and then any number of float columns.

    Lines with nothing after the marker character (e.g. a trailing empty
    line) are ignored.

    fname : path of the file to read.

    Returns (valid_lines, stats_out):
      valid_lines : integer numpy array with the indices into stats_out of
                    the rows that are not flagged with '#'.
      stats_out   : list with one list per parsed row (flagged rows included):
                    [slide_num, exp_num, file_name, protein, float, float, ...]

    Raises ValueError when a numeric column cannot be converted and
    IndexError when a row has fewer than four columns.
    """
    #Parsed rows (valid and flagged ones)
    stats_out   = []

    #Indices of the valid rows in stats_out
    valid_lines = []

    #The context manager closes the file even if reading fails
    with open(fname,'r') as g:
        lines = g.readlines()

    for line in lines[1:]:
        payload = line[1:].strip()

        #Nothing to parse after the marker character
        if not payload:
            continue

        entries   = payload.split('\t')
        slide_num = int(entries[0])
        exp_num   = int(entries[1])
        data_file = entries[2].strip()
        protein   = entries[3].strip()

        #Rest of the entries are data
        data      = [slide_num,exp_num,data_file,protein] + [float(x) for x in entries[4:]]

        #The data line is valid if the line does not start with #
        if line[0] != '#':
            valid_lines.append(len(stats_out))

        stats_out.append(data)

    #Integer dtype keeps the result usable as an index even when it is empty
    return np.array(valid_lines,dtype=int),stats_out

#Suppress all the warnings
warnings.filterwarnings("ignore")

#Definition of the program (used as the usage banner of the parser)
usage            = ['%(prog)s -d [DIRECTORY]',
                    '------------------------------------------------------------------------',
                    'LIMA v1.0: Loaded In vitro Motility Assay',
                    '05/15/2015',
                    'Tural Aksel',
                    '',
                    'LIMA extracts the loaded motility result from FAST outputs',
                    'For bugs and other problems please contact Tural Aksel at turalaksel@gmail.com',
                    '------------------------------------------------------------------------'
                    ]

#Create the parser
parser = argparse.ArgumentParser(description='',usage='\n'.join(usage)) 
parser.add_argument('-d' ,default = None , help='top directory of the output files to be analyzed')
parser.add_argument('-amin' ,default = 0.0  , type=float              , help='minimum load concentration for analysis')
parser.add_argument('-amax' ,default = 0.0  , type=float              , help='maximum load concentration for analysis')
parser.add_argument('-pmin' ,default = 0.0  , type=float              , help='minimum load concentration for plotting')
parser.add_argument('-pmax' ,default = 0.0  , type=float              , help='maximum load concentration for plotting')
parser.add_argument('-p'  ,nargs   = '*'  , default = None, type=str  , help='protein names to be analyzed')
#-r is a store_false switch: parameters are printed unless the flag is given,
#so the help text has to describe the flag as turning the printing off
parser.add_argument('-r'  ,default = True, action = 'store_false'     , help='do not print fit parameters on the plot')
parser.add_argument('-g'  ,default = None , nargs = '*', type=float   , help='initial guess for the parameters to be fitted')
parser.add_argument('-cl' ,nargs   = '*'  , default = None, type=str  , help='plotting colors')

args = parser.parse_args()
args_dict = vars(args)

#The help/banner is printed on every run, not only on request
parser.print_help()

#Get main directory and the rest of the options
main_dir           = args_dict['d']
max_load_analysis  = args_dict['amax']
min_load_analysis  = args_dict['amin']
max_load_plot      = args_dict['pmax']
min_load_plot      = args_dict['pmin']
protein_names      = args_dict['p']
print_params       = args_dict['r']
init_params        = args_dict['g']
colors             = args_dict['cl']
#Plotting colors: use the default palette when -cl is missing or was given without values
if not colors:
    colors        = ['b','r','g','c','m','y']

#Fit function parameters
mobile_init_params = [100, 0.2]
fit_function      = sms
fit_err_function  = err_sms
fit_utr_half      = sms_utr_half
fit_function_name = '_stop_model'

#By default the load molecule is utrophin
load_molecule  = r'utrophin'

#Check if the last character is '/' - if yes, remove it
if main_dir is not None and main_dir.endswith('/'):
    main_dir = main_dir[:-1]

#Check if the directory exists
if main_dir is None or not os.path.isdir(main_dir):
    sys.exit("Directory or file doesn't exist. Program is exiting.")
##Combined-averaged data
all_mean_file      = main_dir+'/combined/MEAN_values.txt'
all_std_file       = main_dir+'/combined/SEM_values.txt'

#Both files are read below, so both have to be present
if not (os.path.isfile(all_mean_file) and os.path.isfile(all_std_file)):
    sys.exit("Combined analysis file does not exist. Program is exiting.")

#Read the data; the valid-row indices of the MEAN file are the ones applied to both tables
valid,std_data  = read_stats_files(all_std_file)
valid,mean_data = read_stats_files(all_mean_file)

#Nothing can be plotted or fitted without at least one valid row
if len(valid) == 0:
    sys.exit("No valid data lines in the combined analysis file. Program is exiting.")

#Header name for the files generated
header = main_dir+'/combined/lima'
if not os.path.exists(header):
    os.mkdir(header)

#Get only valid data
#Columns 0-3 are slide, experiment number, file name and protein; column 6 is read as the load concentration
header_data    = [mean_data[i][:9] for i in valid]
slide_list     = np.array([int(x[0]) for x in header_data])
expnum_list    = np.array([int(x[1]) for x in header_data])
fname_list     = np.array([str(x[2]) for x in header_data])
protein_list   = np.array([x[3] for x in header_data])
utr_list       = np.array([float(x[6]) for x in header_data])

#Averaged data (everything from column 7 on)
mean_data     = np.array([mean_data[i][7:] for i in valid])
std_data      = np.array([std_data[i][7:]  for i in valid])
#Get protein names
proteins      = sorted(set(protein_list))

#Set of the load (utrophin) concentrations present in the data
utr_set       = sorted(set(utr_list))

#If maximum utrophin for plot/analysis is not entered, retrieve from the data
if max_load_analysis == 0.0:
    max_load_analysis = max(utr_set)

if max_load_plot == 0.0:
    max_load_plot = max_load_analysis

#Pick only proteins you want to analyze
#-p given without any name is treated like -p not given: analyze every protein
if protein_names:
    proteins = protein_names
else:
    protein_names = proteins

#Cap-size
cap_size = 25

#Parameters tail
params_tail = '_with_parameters' if print_params else ''

#Tail for the plots generated
tail = params_tail

#Plot dimensions
x_plot,y_plot = plotparams.get_figsize(1200)

#All plots in one canvas
py.figure(0,figsize=(2*x_plot,1*y_plot))
#Percent time mobile (left panel) and MVIS (right panel) versus load - avg
lines        = []   #fit-line handles used by both legends
line_labels  = []   #protein name of each handle in lines (only fitted proteins get a line)
ref_velocity = 0

#Initial guess of the fit: the values given with -g, otherwise the built-in default
if init_params:
    if len(init_params) != len(mobile_init_params):
        sys.exit("Initial guess must have %d parameters. Program is exiting."%len(mobile_init_params))
    fit_init_params = list(init_params)
else:
    fit_init_params = mobile_init_params

#Load values at which the fitted curve is drawn and saved
sim_utr = np.linspace(min_load_analysis,max_load_analysis,1000)

for i,protein in enumerate(proteins):
    #Cycle through the palette when there are more proteins than colors
    color     = colors[i%len(colors)]
    valid_pro = np.nonzero(protein_list == protein)[0]
    
    #A protein name without any row in the data can not be analyzed
    if len(valid_pro) == 0:
        sys.exit("No data points for plotting!.Exiting.")
    
    mean_data_filtered = mean_data[valid_pro,:]
    std_data_filtered  = std_data[valid_pro,:]
    
    #Combine data for sorting by load concentration.
    #The rows are viewed as records so that one sort keeps all the columns together;
    #ties in the load are broken by the following columns, in this order.
    columns = (utr_list[valid_pro],
               mean_data_filtered[:,1],std_data_filtered[:,1],   #percent stuck
               mean_data_filtered[:,2],std_data_filtered[:,2],   #MVEL
               mean_data_filtered[:,5],std_data_filtered[:,5],   #MVIS
               mean_data_filtered[:,0],std_data_filtered[:,0])   #maximum velocity
    full_data = np.ascontiguousarray(np.column_stack(columns),dtype='f8')
    full_data = full_data.view(','.join(['f8']*len(columns)))
    full_data.sort(order='f0',axis=0)
    
    utr                 = full_data['f0'].flatten()
    mean_frac_mobile    = 1.0 - full_data['f1'].flatten()/100.0
    std_frac_mobile     = full_data['f2'].flatten()/100.0
    
    mean_MVEL           = full_data['f3'].flatten()
    std_MVEL            = full_data['f4'].flatten()
    
    mean_MVIS           = full_data['f5'].flatten()
    std_MVIS            = full_data['f6'].flatten()
    
    maxvelocity         = full_data['f7'].flatten()
    
    #Reference velocity is the maximum velocity of the first row after sorting,
    #i.e. at the lowest load in the data (0 nM utrophin is expected there)
    ref_velocity          = maxvelocity[0]
    
    mobile_correction     = mean_MVEL/ref_velocity
    mean_frac_time_mobile = mean_frac_mobile*mobile_correction
    std_frac_time_mobile  = std_frac_mobile*mobile_correction
    
    #Analyze only the data inside the requested load range
    in_range              = (utr <= max_load_analysis) & (utr >= min_load_analysis)
    utr                   = utr[in_range]
    
    mean_frac_mobile      = mean_frac_mobile[in_range]
    std_frac_mobile       = std_frac_mobile[in_range]
    
    mean_MVEL             = mean_MVEL[in_range]
    std_MVEL              = std_MVEL[in_range]
    
    mean_MVIS             = mean_MVIS[in_range]
    std_MVIS              = std_MVIS[in_range]
    
    mean_frac_time_mobile = mean_frac_time_mobile[in_range]
    std_frac_time_mobile  = std_frac_time_mobile[in_range]
    
    num_points            = len(utr)
    
    if num_points == 0:
        sys.exit("No data points for plotting!.Exiting.")
    
    #Plot percent time mobile
    py.subplot(121)
    py.errorbar(utr,mean_frac_time_mobile*100,yerr=std_frac_time_mobile*100,marker='o',color=color,linestyle='None',capsize=cap_size)
    
    #Plot MVIS
    py.subplot(122)
    py.errorbar(utr,mean_MVIS,yerr=std_MVIS,marker='o',color=color,linestyle='None',capsize=cap_size)
    
    #Save the plotted data: one row per load, columns are load, mean, error
    data_combined = np.column_stack((utr,mean_frac_mobile*100.0,std_frac_mobile*100.0))
    np.savetxt(header+'/'+protein+'_percent_mobile.txt',data_combined)
    
    data_combined = np.column_stack((utr,mean_frac_time_mobile*100.0,std_frac_time_mobile*100.0))
    np.savetxt(header+'/'+protein+'_percent_time_mobile.txt',data_combined)
    
    data_combined = np.column_stack((utr,mean_MVIS,std_MVIS))
    np.savetxt(header+'/'+protein+'_MVIS.txt',data_combined)
    
    #The MVEL file holds the MVEL columns (not a second copy of MVIS)
    data_combined = np.column_stack((utr,mean_MVEL,std_MVEL))
    np.savetxt(header+'/'+protein+'_MVEL.txt',data_combined)
    
    #There should be at least 3 data points for fitting
    if num_points < 3:
        continue
    
    #Back to percent time mobile panel
    py.subplot(121)
    
    #Fit model to data
    best_params,success = optimize.leastsq(fit_err_function,fit_init_params,args=(utr,mean_frac_time_mobile),maxfev=1000)
    
    #The parameter drawn on the plots
    Ks         = best_params[1]
    
    #Model value at x = Ks; top of the dashed marker line
    #NOTE(review): sms_utr_half() gives Ks/Vmax as the half-maximum load, while the marker is drawn at Ks - confirm which one is intended
    X_val      = fit_function(Ks,best_params)
    
    #Fitted curve, evaluated once and reused for plotting and saving
    sim_values = fit_function(sim_utr,best_params)
    
    #Plot line
    line       = py.plot(sim_utr,sim_values*100,color=color,linestyle='-')
    
    if print_params:
        py.text(Ks,50-i*10,r'$%.2f nM^{K_S}$'%(Ks),color=color,fontsize=50)
    
    #Keep handle and label together so the legend stays correct when a protein is skipped above
    lines.append(line[0])
    line_labels.append(protein)
    
    #Residuals
    residuals  = fit_err_function(best_params,utr,mean_frac_time_mobile)
    
    #Plot half-load/half-value
    py.plot([Ks,Ks],[X_val*100,0.0],color=color,linestyle='--',linewidth=10)
    
    data_combined = np.column_stack((utr,residuals*100.0))
    np.savetxt(header+'/'+protein+'_percent_time_mobile_fit_residuals.txt',data_combined)
    
    data_combined = np.column_stack((sim_utr,sim_values*100))
    np.savetxt(header+'/'+protein+'_percent_time_mobile_simulated.txt',data_combined)
    
    #Plot the velocity panel
    py.subplot(122)
    
    #Plot line
    py.plot(sim_utr,sim_values*ref_velocity,color=color,linestyle='-')
    
    #Plot half-load/half-value
    py.plot([Ks,Ks],[X_val*ref_velocity,0.0],color=color,linestyle='--',linewidth=10)
    
    data_combined = np.column_stack((sim_utr,sim_values*ref_velocity))
    np.savetxt(header+'/'+protein+'_MVIS_simulated.txt',data_combined)

py.subplot(121)
py.legend(lines,line_labels,loc=1)
py.xlim([min_load_plot,max_load_plot])
py.ylim([0,100])
py.ylabel('% Time mobile')
py.xlabel('Utrophin (nM)')

py.subplot(122)
py.legend(lines,line_labels,loc=1)
py.xlim([min_load_plot,max_load_plot])
py.ylabel('MVIS (nm/s)')
py.xlabel('Utrophin (nM)')

py.savefig(header+'/'+'_vs_'.join(proteins)+fit_function_name+tail+'.png',dpi=100)
py.close()
    
#Filament length - avg
#The filament length is read from the third column from the end of the data tables
py.figure(1,figsize=(x_plot,y_plot))
lines = []
for i,protein in enumerate(proteins):
    valid_pro = np.nonzero(protein_list == protein)[0]
    
    #Combine data for sorting by load concentration.
    #The rows are viewed as records so that one sort keeps the columns together.
    columns   = (utr_list[valid_pro],mean_data[valid_pro,-3],std_data[valid_pro,-3])
    full_data = np.ascontiguousarray(np.column_stack(columns),dtype='f8')
    full_data = full_data.view(','.join(['f8']*len(columns)))
    full_data.sort(order='f0',axis=0)
    
    utr             = full_data['f0'].flatten()
    mean_length     = full_data['f1'].flatten()
    std_length      = full_data['f2'].flatten()
    
    #Analyze only the data inside the requested load range
    in_range        = (utr <= max_load_analysis) & (utr >= min_load_analysis)
    utr             = utr[in_range]
    mean_length     = mean_length[in_range]
    std_length      = std_length[in_range]
    
    #Cycle through the palette when there are more proteins than colors
    line = py.errorbar(utr, mean_length,yerr=std_length,marker='o',color=colors[i%len(colors)],linestyle='-',capsize=cap_size)
    lines.append(line[0])
    
    #One row per load: load, mean length, error
    data_combined = np.column_stack((utr,mean_length,std_length))
    np.savetxt(header+'/'+protein+'_length.txt',data_combined)

#Plotting parameters
#Same plotting range as the other figure, so -pmax is honoured here too
py.xlim([min_load_plot,max_load_plot])
py.xlabel(load_molecule+'(nM)')
py.ylabel('Filament length(nm)')
py.legend(lines,proteins,loc=1)

py.savefig(header+'/'+'_vs_'.join(proteins)+'_length'+tail+'.png',dpi=200)
py.close()