#!/usr/bin/env python

# Copyright 2020 Will Gerrard
#This file is part of autoENRICH.

#autoENRICH is free software: you can redistribute it and/or modify
#it under the terms of the GNU Affero General Public License as published by
#the Free Software Foundation, either version 3 of the License, or
#(at your option) any later version.

#autoENRICH is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU Affero General Public License for more details.

#You should have received a copy of the GNU Affero General Public License
#along with autoENRICH.  If not, see <https://www.gnu.org/licenses/>.



# To quit things
import sys
# Preferences files are all in json
import json

# Functions for checking input flags
from autoENRICH.util.flag_handler import hdl_targetflag, flag_combos
# Preferences wizard function
from autoENRICH.util.argparse_wizard import run_wizard, get_minimal_args
# Import main command functions
from autoENRICH.top_level import CMD_trainmodel, CMD_predict
# Import pretty banner printing function (for ego purposes only)
from autoENRICH.util.header import print_header_IMP
# Import argument parser
from autoENRICH.util.arguments.argparser import IMP_parser, combine_args
# Used for memory and code tracing
import ast
import tracemalloc
import cProfile
import pstats
from pstats import SortKey

# Define tracer, used to trace code execution line by line
def trace(frame, event, arg):
	"""Print one line per trace event and keep tracing.

	Intended to be installed with ``sys.settrace``. Each call prints the
	event name followed by the file name and line number of the frame that
	triggered it.

	Parameters:
		frame: frame object the event occurred in.
		event: name of the trace event.
		arg: event-dependent payload; not used here.

	Returns:
		This function itself, so it is used again for subsequent events.
	"""
	source_file = frame.f_code.co_filename
	line_number = frame.f_lineno
	print(f"{event}, {source_file}:{line_number:d}")
	return trace

def _write_prefs(args, pref_file):
	"""Write the *args* dictionary to *pref_file* as indented JSON.

	The file is opened in a ``with`` block so the handle is closed, and the
	contents flushed, before this function returns.
	"""
	with open(pref_file, 'w') as fp:
		json.dump(args, fp, indent=4)


def impression(inpargs):
	"""Command-line entry point: resolve settings and run the chosen command.

	The argument list is parsed with ``IMP_parser``. Depending on the value
	of ``prefs`` the settings are then completed by the preferences wizard,
	written out as a template file, or read from a JSON preferences file.
	All sources are merged with ``combine_args`` and the command stored in
	``args['Command']`` (``setup_train``, ``train``, ``setup_predict``,
	``predict`` or ``test``) is dispatched.

	Optional diagnostics are controlled by the ``tracecode``, ``tracemem``
	and ``tracetime`` settings: line-by-line tracing via ``sys.settrace``,
	a ``tracemalloc`` top-10 report, and a ``cProfile`` top-10 report.

	Parameters:
		inpargs: list of command-line arguments, without the program name.

	Returns:
		None.

	Raises:
		SystemExit: with status 0 after generating a template preferences
			file; with status 1 when the preferences file cannot be read,
			the target flag is invalid, or the model / feature combination
			is invalid.
	"""
	# Parse arguments into args object
	args = vars(IMP_parser(inpargs))

	# Preserve command argument whilst messing about with preferences / args
	COMMAND = args['Command']

	# Settings can come from four places; they are merged further down.
	# NOTE: user_args is the same dict object as args, so in-place edits to
	# args below (e.g. args['prefs']) are visible through user_args as well.
	user_args = args
	default_args = get_minimal_args()
	file_args = {}
	wiz_args = {}

	# Run preferences wizard to ask user for preference choices
	if args['prefs'] in ['wizard', 'default']:
		if args['prefs'] == 'default':
			wiz_args = run_wizard(args, default=True)
		else:
			wiz_args = run_wizard(args)
		pref_file = 'settings_' + args['modelflag'] + '_' + args['featureflag'] + '_' + args['targetflag'] + '.json'
		args['prefs'] = pref_file
		_write_prefs(args, pref_file)
		# Read the file straight back in. The reloaded dict is superseded by
		# the combine_args() call below; the read is kept from the original flow.
		with open(pref_file, 'r') as fp:
			args = json.load(fp)

	# Use default preferences for every argument and generate a settings file to edit
	elif args['prefs'] == 'generate':
		pref_file = 'IMPRESSION_settings.json'
		args['prefs'] = pref_file
		_write_prefs(args, pref_file)
		print('Template preferences file generated')
		sys.exit()

	# No preferences file given: rely on the user arguments plus wizard defaults
	elif args['prefs'] == '':
		print('Only user arguments')
		wiz_args = run_wizard(args, default=True)

	# Else read preferences file or shout at the user for not specifying a usable one
	else:
		print('Reading settings from file ', args['prefs'])
		try:
			with open(args['prefs'], 'r') as json_file:
				file_args = json.load(json_file)
		except Exception as E:
			print('Error reading preferences file ', args['prefs'])
			print('You must specify a preferences file, or generate one using --prefs default')
			print(E)
			print('Exiting. . .')
			# Non-zero status so shells / schedulers can see the failure
			sys.exit(1)

	# sort arguments provided from multiple sources
	args = combine_args(user_args, file_args, wiz_args, default_args)

	# Restore command argument
	args['Command'] = COMMAND

	# Optional, trace code execution
	if args['tracecode']:
		sys.settrace(trace)
	# Optional, trace memory
	if args['tracemem']:
		tracemalloc.start()
	# Optional, profile execution time; profiler stays None when disabled
	profiler = None
	if args['tracetime']:
		profiler = cProfile.Profile()
		profiler.enable()

	# Unless making predictions check combination of feature / model
	if args['Command'] not in ['predict', 'setup_predict']:
		# check target flag is valid (nJxy or XCS):
		print(args['targetflag'])
		target = hdl_targetflag.flag_to_target(args['targetflag'])
		# 0 is the bad number
		if target == 0:
			print('Invalid target flag, ', args['targetflag'])
			sys.exit(1)

		# check flag combination for feature / model
		if not flag_combos.check_combination(args['modelflag'], args['featureflag']):
			print('Invalid model and feature combination: ', args['modelflag'], args['featureflag'])
			sys.exit(1)

	# Print pretty banner
	print_header_IMP()

	# set up submission file for model training
	if args['Command'] == "setup_train":
		# If multiple targets, setup submission script for each target
		# (truthiness check also copes with target_list being None)
		if args['target_list']:
			# Loop through targets
			for target in args['target_list']:
				# Assign targetflag
				args['targetflag'] = target
				# Define preferences file
				pref_file = 'settings_HPS_' + args['modelflag'] + '_' + args['featureflag'] + '_' + args['targetflag'] + '_' + args['searchflag'] + '.json'
				# Assign to args dict
				args['prefs'] = pref_file
				# Dump prefs in pref file
				_write_prefs(args, pref_file)
				# Run train_model setup function to create submission file
				CMD_trainmodel.setup_trainmodel(args)
		else:
			# Run train_model setup function to create submission file
			CMD_trainmodel.setup_trainmodel(args)
		# yay success. . .
		print('Training submission file created . . .')

	# Train a model
	elif args['Command'] == "train":
		# Check for defined logfile
		if args['logfile'] == '':
			# If no logfile defined, make a sensible one
			args['logfile'] = args['modelflag'] + '_' + args['featureflag'] + '_' + args['targetflag'] + '_' + args['searchflag'] + '.log'

		# Tell the user what model they are training
		print('Training model: ', args['modelflag'] , args['featureflag'], args['targetflag'], args['searchflag'])
		# Train the model
		CMD_trainmodel.trainmodel(args)

	# Set up a prediction
	# This seems like a dumb thing to have, but it facilitates predictions on HPC clusters,
	# and its easier to edit a preferences file than it is to define all the flags
	elif args['Command'] == "setup_predict":
		# Set these flags to empty because these are taken from the specified model objects
		args['targetflag'] = ''
		args['modelflag'] = ''
		args['featureflag'] = ''
		# Define basic preferences file if none specified
		if args['prefs'] == '':
			args['prefs'] = 'settings_predict.json'
		_write_prefs(args, args['prefs'])
		CMD_predict.setup_predict(args)

	# Make predictions from a molecule
	elif args['Command'] == "predict":
		CMD_predict.predict(args)

	# Do code testing, or in this case print sarcastic message
	elif args['Command'] == 'test':
		print('Not done yet, if will wasnt so lazy we would have some nice test code ')

	# Output for memory trace. Best effort: a failure here is reported but
	# must not mask the result of the command that has already run.
	# Exception (not a bare except) so Ctrl-C and sys.exit still propagate.
	try:
		if args['tracemem']:
			snapshot = tracemalloc.take_snapshot()
			top_stats = snapshot.statistics('lineno')

			print("[ Top 10 ]")
			for stat in top_stats[:10]:
				print(stat)
	except Exception:
		print('No memory trace option')

	# Output for time profile
	if profiler is not None:
		profiler.disable()
		ps = pstats.Stats(profiler).sort_stats('time')
		ps.print_stats(10)

# Script entry point: hand the command-line arguments (program name stripped)
# to impression() when this file is executed directly rather than imported.
if __name__ == "__main__":
	cli_arguments = sys.argv[1:]
	impression(cli_arguments)














##
