#!/usr/bin/env python
import os
import sys
import commands
import time
import numpy as np
import argparse
import warnings
import getpass

from FAST import motility

#Utility functions
def is_number(s):
    """Return True when ``s`` can be converted to a float, False otherwise.

    Everything ``float()`` accepts counts as a number, including strings
    such as ``'1e3'``, ``'nan'`` and ``'inf'``. Values ``float()`` cannot
    handle at all (for example ``None``) are reported as non-numeric
    instead of raising.
    """
    try:
        float(s)
    except (ValueError, TypeError):
        return False
    return True

#Silence all warnings for the rest of the run
warnings.filterwarnings("ignore")

#Usage text of the program (joined with newlines when the parser is created)
usage = [
    '%(prog)s -d [DIRECTORY]',
    '------------------------------------------------------------------------',
    'FAST v1.0: Fast Actin Filament Spud Trekker',
    '03/26/2017',
    'Tural Aksel',
    '',
    'FAST provides fast and accurate analysis of actin gliding assay movies.',
    'For bugs and other problems please contact Tural Aksel at turalaksel@gmail.com',
    '------------------------------------------------------------------------',
]

#Command line options as (flag, add_argument keyword arguments), in the order they are registered
_cli_options = [
    ('-d',        dict(help='top directory of the movies to be analyzed')),
    ('-f',        dict(action='store_true', default=False, help='force analyze all the movies')),
    ('-r',        dict(action='store_true', default=False, help='recalculate instantaneous velocities from saved filament files')),
    ('-m',        dict(action='store_true', default=False, help='make filament tracking movie')),

    ('-p',        dict(default=5, type=int, help='minimum length for the paths to be analyzed (Default:5)')),
    ('-n',        dict(default=5, type=int, help='number of consecutive frames for averaging (Default:5)')),

    ('-pt',       dict(default=500, type=int, help='percent tolerance (Default:none)')),
    ('-px',       dict(default=80.65, type=float, help='pixel size in nm (Default:80.65 nm)')),
    ('-ymax',     dict(default=1500, type=int, help='maximum velocity for the plot in nm/s (Default:1500)')),
    ('-xmax',     dict(default=10000, type=int, help='maximum length for the plot in nm (Default:10000)')),
    ('-cl',       dict(default='b', type=str, help='color for maximum velocity points (Default:blue)')),
    ('-fx',       dict(default='none', choices=['none','exp','uyeda'], type=str, help='function to be fitted to maximum velocity data')),
    ('-maxd',     dict(default=2016.25, type=float, help='maximum allowed distance in nm between adjacent frames for a filament (Default:2016.25 nm)')),
    ('-minv',     dict(default=80, type=float, help='minimum average path velocity for a filament to be classified as stuck (Default:80 nm/s)')),

    ('-oscore',   dict(default=0.4, type=float, help='overlap score cutoff value to make the connections between filaments - advanced option (Default:0.4)')),
    ('-lascore',  dict(default=1.0, type=float, help='log-area score cutoff value to make the connections between filaments - advanced option (Default:1.0)')),
    ('-dlascore', dict(default=0.5, type=float, help='difference-log-area score cutoff value to make the connections between filaments - advanced option (Default:0.5)')),

    ('-v',        dict(action='store_true', default=False, help='verbose output for debugging')),
]

#Create the parser and register every option
parser = argparse.ArgumentParser(description='', usage='\n'.join(usage))
for _flag, _settings in _cli_options:
    parser.add_argument(_flag, **_settings)

#Parse the command line and keep the values as a plain dictionary
args      = parser.parse_args()
args_dict = vars(args)

#The help text is printed on every run, not only when -h is given
parser.print_help()

#Get username and greet the user
user_name = getpass.getuser()
print('-'*10+'Welcome %s'%(user_name)+'-'*10)

#Get the arguments
main_dir             = args_dict['d']
force_analysis       = args_dict['f']
recalculate          = args_dict['r']
num_frames_ave       = args_dict['n']
make_movie           = args_dict['m']
min_path_length      = args_dict['p']
percent_tolerance    = args_dict['pt']
plot_ymax            = args_dict['ymax']
plot_xmax            = args_dict['xmax']
maxvel_color         = args_dict['cl']
fit_function         = args_dict['fx']
pixel_size           = args_dict['px']
#The -maxd option is a distance in nm between adjacent frames; it is stored under the name max_velocity
max_velocity         = args_dict['maxd']
min_velocity         = args_dict['minv']

#Advanced score cutoffs, taken from the command line (-oscore, -lascore, -dlascore)
overlap_score_cutoff        = args_dict['oscore']
log_area_score_cutoff       = args_dict['lascore']
diff_log_area_score_cutoff  = args_dict['dlascore']


#Force analysis flag as text - it is pasted into the generated frame-reader scripts
force_analysis_str = 'True' if force_analysis else 'False'

#Remove every trailing '/' - the output directory name is built by appending a suffix to main_dir
if main_dir is not None:
    main_dir = main_dir.rstrip('/')

#Check if the directory was given and exists
if main_dir is None or not os.path.isdir(main_dir):
    sys.exit("Directory doesn't exist. Program is exiting.")

#The value 500 (the default of -pt) is reported as 'none' in the output directory name
if percent_tolerance == 500:
    tolerance_prop = 'none'
else:
    tolerance_prop = str(percent_tolerance)

#Main output directory - its name records the parameters of the analysis
main_out_dir   = main_dir+'__pt_'+str(tolerance_prop)+'__n_'+str(num_frames_ave)+'__ymax_'+str(int(plot_ymax))+'__p_'+str(min_path_length)+'__fx_'+fit_function

#Directory the program was started from
cwd            = os.getcwd()

#Prepare the files for data analysis:
#for every lowest-level folder with tif files, write next to the folder
#  <folder>.py - script that reads and saves a single frame (run once per frame number)
#  <folder>.in - the frame numbers of the movie, one per line

#Tail for the tif files - keeps the value found in the last prepared folder
tail_tif = ""

for root, subFolders, files in os.walk(main_dir):
    #Only folders without sub-folders are prepared
    if len(subFolders) != 0:
        continue

    tif_files    = [x for x in files if x[-4:] == '.tif']
    if len(tif_files) == 0:
        continue

    #Determine the naming scheme of the tif files: third '_' separated field of the first tif file
    tail_tif     = tif_files[0].split('_')[2]

    #Remove all the spaces in the folder name as it causes problems in analysis
    head,tail_dir = os.path.split(root)
    tail_dir      = '_'.join(tail_dir.split())

    #Without a parent part the name is used as is - prepending '/' would point to the filesystem root
    if head:
        new_root = head+'/'+tail_dir
    else:
        new_root = tail_dir

    #Rename only when the name actually changed
    if new_root != root:
        os.rename(root,new_root)
    root = new_root

    #Prepare the python script to read frames
    #repr() produces a valid string literal even if a name contains quotes or backslashes
    script_lines = ["#!/usr/bin/env python",
                    "import os",
                    "from FAST import motility",
                    "import sys",
                    "new_Motility            = motility.Motility()",
                    "new_Motility.directory  = "+repr(tail_dir),
                    "new_Motility.header     = 'img_000000'",
                    "new_Motility.tail       = "+repr(tail_tif),
                    "new_Motility.read_frame(float(sys.argv[1]),"+force_analysis_str+")",
                    "new_Motility.save_frame()"]

    with open(root+'.py','w') as script:
        script.write('\n'.join(script_lines))

    #Prepare the input file for ppss: the frame number is the second '_' separated field of each tif file name
    frame_nos = sorted(int(x.split('_')[1]) for x in os.listdir(root) if os.path.splitext(x)[1] == '.tif')

    #Write the frame numbers in ascending order
    with open(root+'.in','w') as frame_list:
        frame_list.writelines(str(no)+'\n' for no in frame_nos)

#The analysis starts here

#Create outputs/<main_out_dir>/combined
#makedirs also creates the missing parent folders, which is needed when main_out_dir contains '/'
combined_out_dir = 'outputs/'+main_out_dir+'/combined'
if not os.path.isdir(combined_out_dir):
    os.makedirs(combined_out_dir)

#Output files for parameter averages - absolute paths, because the working directory is changed during the analysis
out_MEAN_fname  = os.path.abspath(combined_out_dir+'/MEAN_values.txt')
out_SEM_fname   = os.path.abspath(combined_out_dir+'/SEM_values.txt')

#If we force the analysis or recalculate the parameters (or there is no MEAN file yet), rewrite the values
#Otherwise new lines are appended to the existing files
if not os.path.isfile(out_MEAN_fname) or force_analysis or recalculate:
    data_header = ("%6s\t%4s\t%80s"+"\t%20s"*13+"\n")%('slide','exp','filename','protein','points-filtered','conc(mg/ml)','utrophin(nM)','top-vel-5','p-stuck','MVEL','MVEL-filtered','plateau','MVIS','mean-length-all','mean-length-filtered','mean-length-mobile')

    m_stats    = open(out_MEAN_fname,'w')
    m_stats.write(data_header)

    s_stats    = open(out_SEM_fname,'w')
    s_stats.write(data_header)
else:
    m_stats   = open(out_MEAN_fname,'a')
    s_stats   = open(out_SEM_fname,'a')

#Folders to process: parent folder -> names of the movie folders inside it
process_folders = {}

for root, subFolders, files in os.walk(main_dir):
    #Only folders without sub-folders can be movie folders
    if len(subFolders) > 0:
        continue

    #A folder is processed when it holds *.tif files or files whose name starts with 'filXYs'
    has_tif    = any(x[-4:] == '.tif' for x in files)
    has_filXYs = any(x[:6] == 'filXYs' for x in files)
    if not (has_tif or has_filXYs):
        continue

    #Split the path into the parent folder and the movie folder name
    entries    = root.split('/')
    top_folder = '/'.join(entries[:-1])
    exp_num    = entries[-1]
    process_folders.setdefault(top_folder, []).append(exp_num)

#Sorted top folders - the movie folders of each top folder are sorted as well
sorted_top_roots = sorted(process_folders.keys())
for top_root in sorted_top_roots:
    process_folders[top_root] = sorted(process_folders[top_root])

#Data counters
all_data_counter      = 0
combined_data_counter = 0

#Number of frames in a movie
number_of_frames      = 0

#Go through all the folders: analyze every movie of a top folder, then combine the movies of that folder
for top_root in sorted_top_roots:
    #Statistics of the individual movies in this top folder
    combined_stats        = []

    #Combined full length velocity data
    combined_full_len_vel = []

    #Combined max length velocity data
    combined_max_len_vel  = []

    #Final - folder : lowest level folder
    data_info        = top_root.split('/')
    top_folder       = data_info[-1]
    root_header      = '_'.join(data_info)

    #Extract slide number from a parent folder named 'slide_<number>' (stays -1 otherwise)
    slide_num = -1
    if len(data_info) > 1:
        slide_entries = data_info[-2].split('_')

        if len(slide_entries) == 2 and slide_entries[0] == 'slide':
            slide_num = int(slide_entries[1])

    #Protein name and concentrations are read from the '_' separated fields of the folder name
    fname_entries = top_folder.split('_')

    #Protein name is the first field (split always returns at least one field)
    protein_name  = fname_entries[0]

    #Utrophin concentration: a '<number>nM' field right before the field 'utr', otherwise 0
    utrophin_conc = 0.0
    if 'utr' in fname_entries:
        pos = fname_entries.index('utr')
        if pos > 0 and fname_entries[pos-1][-2:] == 'nM' and is_number(fname_entries[pos-1][:-2]):
            utrophin_conc = float(fname_entries[pos-1][:-2])

    #Protein concentration: a '<number>mg' field right before the field 'ml', otherwise 0
    protein_conc = 0.0
    if 'ml' in fname_entries:
        pos = fname_entries.index('ml')
        if pos > 0 and fname_entries[pos-1][-2:] == 'mg' and is_number(fname_entries[pos-1][:-2]):
            protein_conc = float(fname_entries[pos-1][:-2])

    #Combined length-velocity plot
    combined_vl_png_name = cwd+'/outputs/'+main_out_dir+'/combined/'+root_header+'_length_velocity.png'

    #If combined analysis is already performed skip
    if not recalculate and not force_analysis and os.path.isfile(combined_vl_png_name):
        continue

    #Change directory to the location that the analysis will be performed
    os.chdir(top_root)

    for final_folder in process_folders[top_root]:
        #Construct root name
        root = top_root+'/'+final_folder

        #Check movie quality on the first frame
        new_Frame            = motility.Frame()
        new_Frame.directory  = final_folder
        new_Frame.header     = 'img_000000'
        new_Frame.tail       = tail_tif
        file_exists          = new_Frame.read_frame(0)

        #Frame width and height
        frame_width          = new_Frame.width
        frame_height         = new_Frame.height

        #When the first frame could not be read the quality check is skipped
        if not file_exists:
            picture_quality = 'good'
        else:
            picture_quality = new_Frame.check_picture_quality()

        #If the picture quality is bad, move on to the next movie
        if picture_quality == 'bad':
            print('Bad picture quality in %s'%(root))
            continue

        #Output filenames - the '/' of the movie path are replaced with '_'
        out_prefix        = cwd+'/outputs/'+main_out_dir+'/'+'_'.join(root.split('/'))
        out_vl_png_fname  = out_prefix+'_length_velocity.png'
        out_vl_txt_fname  = out_prefix+'_'
        out_path_fname    = out_prefix+'_paths'

        #Input and script filenames for ppss(parallel processing)
        input_file   = final_folder+'.in'
        script_file  = final_folder+'.py'

        print('Processing tif files in %s'%(root))

        #Start the timer
        start_t      = time.time()

        #If ppss directory exists remove it recursively before starting
        if os.path.exists('ppss_dir'):
            commands.getstatusoutput("rm -rf ppss_dir")

        #Parallel-execution of filament extraction of all the frames
        out          = commands.getstatusoutput("ppss -f "+input_file+" -c 'python "+script_file+" '")

        #Check ppss installation
        if len(out) == 2 and 'ppss: not found' in out[1]:
            sys.exit("PPSS is not installed. Please install PPSS to use FAST.")

        #Delete ppss output directory
        commands.getstatusoutput("rm -rf ppss_dir")

        #Read the frame numbers of the movie
        with open(input_file,'r') as f:
            frame_nos = [int(line.strip()) for line in f]

        #Determine the number of frames in a movie
        number_of_frames = len(frame_nos)

        if len(frame_nos) > 0:
            #Create the Motility object - max_velocity is converted from nm to pixels
            new_motility                = motility.Motility()
            new_motility.dx             = 1.0*pixel_size
            new_motility.max_velocity   = 1.0*max_velocity/pixel_size
            new_motility.num_frames     = len(frame_nos)
            new_motility.directory      = final_folder
            new_motility.force_analysis = force_analysis
            new_motility.width          = frame_width
            new_motility.height         = frame_height
            new_motility.min_velocity   = min_velocity

            #Assign the score cutoffs
            new_motility.overlap_score_cutoff       = overlap_score_cutoff
            new_motility.log_area_score_cutoff      = log_area_score_cutoff
            new_motility.diff_log_area_score_cutoff = diff_log_area_score_cutoff

            #If links are already constructed read frame links and skip the next parts
            if not new_motility.read_frame_links():
                new_motility.load_frame1(0)
                new_motility.read_metadata()

                #Link every frame to the previous one
                for no in frame_nos[1:]:
                    print('Making the links: Frame: %d'%(no))
                    new_motility.load_frame2(no)
                    new_motility.make_frame_links()
                    new_motility.frame1 = new_motility.frame2

                #Save the frame-links
                new_motility.save_links()

            #Process the frame links to create paths, path velocities
            new_motility.process_frame_links(num_frames_ave)

            #Plot created paths
            new_motility.plot_2D_path_data(num_frames_ave, extra_fname=out_path_fname)

            if make_movie:
                #Reconstruct skeleton images from frame links
                new_motility.reconstruct_skeleton_images()

                #Make the movie
                new_motility.make_movie(extra_fname=out_vl_txt_fname)

            #If there is no velocity point in the data move on - we need at least 10 points
            if len(new_motility.full_len_vel) < 10:
                continue

            (top_5_velocity, percent_stuck, MVEL, MVEL_filtered, max_vel_u, MVIS,
             mean_len_stuck, mean_len_filtered, mean_len_mobile, mean_len_all,
             num_points_filtered) = new_motility.plot_length_velocity(extra_fname=out_vl_png_fname, max_vel=plot_ymax, max_length = plot_xmax, min_path_length=min_path_length, percent_tolerance=percent_tolerance, min_points=10, print_plot=True, maxvel_color = maxvel_color, fit_f = fit_function)

            #If early terminated - skip to next iteration
            if top_5_velocity == -1:
                continue

            #Add parameters to array - column 0 is the number of filtered points
            combined_stats.append([num_points_filtered,top_5_velocity, percent_stuck, MVEL, MVEL_filtered, max_vel_u, MVIS, mean_len_all, mean_len_filtered, mean_len_mobile])

            #Write length-velocity data
            new_motility.write_length_velocity(extra_fname=out_vl_txt_fname)

            #Add length-velocity data to combined data set
            combined_full_len_vel.append(new_motility.full_len_vel)
            combined_max_len_vel.append(new_motility.max_len_vel)

            #Update the counters
            all_data_counter += 1

        #The time is reported only for movies that reach this point (skipped movies 'continue' earlier)
        end_t = time.time()
        print("Time spent: %.1f"%(end_t-start_t))


    #Go back to main directory
    os.chdir(cwd)

    #If there is no data point move on
    if len(combined_full_len_vel) == 0:
        continue

    #Prompt which files are combined
    print("Combining data in %s"%(root_header))

    #Stack the data of all the movies
    combined_full_len_vel = np.vstack(combined_full_len_vel)
    combined_max_len_vel  = np.vstack(combined_max_len_vel)

    #Motility object for combined length-velocity data
    #NOTE(review): number_of_frames is the frame count of the last movie read above - confirm this is intended
    combined_motility              = motility.Motility()
    combined_motility.directory    = 'outputs/'+main_out_dir+'/combined'
    combined_motility.full_len_vel = combined_full_len_vel
    combined_motility.max_len_vel  = combined_max_len_vel
    combined_motility.num_frames   = number_of_frames
    combined_motility.dx           = 1.0*pixel_size
    combined_motility.max_velocity = 1.0*max_velocity/pixel_size
    combined_motility.min_velocity = min_velocity

    #Assign the score cutoffs
    combined_motility.overlap_score_cutoff       = overlap_score_cutoff
    combined_motility.log_area_score_cutoff      = log_area_score_cutoff
    combined_motility.diff_log_area_score_cutoff = diff_log_area_score_cutoff

    #Total number of points
    total_num_points = len(combined_full_len_vel[:,0])

    #If there is no velocity point in the data move on - we need at least 10 points
    if total_num_points  < 10:
        continue

    #Process data and write results
    (top_5_velocity, percent_stuck, MVEL, MVEL_filtered, max_vel_u, MVIS,
     mean_len_stuck, mean_len_filtered, mean_len_mobile, mean_len_all,
     num_points_filtered) = combined_motility.plot_length_velocity(header=root_header+'_', max_vel=plot_ymax, max_length = plot_xmax, min_path_length=min_path_length, percent_tolerance=percent_tolerance, min_points=10, print_plot=True, maxvel_color = maxvel_color, fit_f = fit_function)

    #If early terminated - skip to next iteration
    if top_5_velocity == -1:
        continue

    #Write length-velocity data
    combined_motility.write_length_velocity(header=root_header+'_')

    #Mean and standard deviation of the per-movie statistics
    #NOTE(review): the SEM file receives plain standard deviations (np.std), not standard errors - confirm this is intended
    combined_stats = np.array(combined_stats)
    mvals = np.mean(combined_stats,axis=0)
    svals = np.std(combined_stats,axis=0)

    #Total filtered points
    total_filtered_points =  np.sum(combined_stats[:,0])

    #Stats lines to write
    #NOTE(review): combined_data_counter is never incremented in this script, so the 'exp' column is always 0
    fmt_string = "%6d\t%4d\t%80s\t%20s\t%20d"+"\t%20.3f"*11+"\n"
    mean_line  = fmt_string%(slide_num,combined_data_counter, root_header, protein_name, total_filtered_points, protein_conc, utrophin_conc, mvals[1], mvals[2], mvals[3], mvals[4], mvals[5], mvals[6], mvals[7], mvals[8], mvals[9])
    #The last column uses svals[9]; it used to repeat the mean value mvals[9]
    std_line   = fmt_string%(slide_num,combined_data_counter, root_header, protein_name, total_filtered_points, protein_conc, utrophin_conc, svals[1], svals[2], svals[3], svals[4], svals[5], svals[6], svals[7], svals[8], svals[9])

    m_stats.write(mean_line)
    s_stats.write(std_line)

#Close the files
m_stats.close()
s_stats.close()
