#!python

from __future__ import division, unicode_literals, print_function

"""
This script serves as a management tool for vasp projects, covering 
everything from encut, kpoint or other parameter optimization up to 
slab solvation. Just define all types of calculations with their 
corresponding specifications needed for the project in a yaml file 
and run or rerun calculations as required.

it takes 3 arguments: input yaml file, type of calculation and the 
run mode

example:
   mpint -i naf.yaml -t bulk_calibrate run

this will read in the specifications for 'bulk_calibrate' job from the 
input yaml file, naf.yaml(in examples folder), and runs the job i.e 
submits to the PBS que.

run modes supported:
  1. run : submits job to the que
  2. check: check whether the job is finished or not
  3. energies: get the energies and the optimum parameter values

Every time jobs are submitted or their status is queried, information 
such as job ids, job folders etc. is written to the log file 
'mpint.log'. This makes it easier to identify job ids and their 
corresponding job folders.
 
Note: use your own materials project key to download the required 
structure

Note: this script submits jobs only to the PBS ques such as hipergator
"""

import os
import sys
import shutil
import datetime
import yaml
import logging
from collections import OrderedDict
from argparse import ArgumentParser

from pymatgen.matproj.rest import MPRester
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.vasp.inputs import Incar, Poscar
from pymatgen.io.vasp.inputs import Potcar, Kpoints
from pymatgen.io.vasp.outputs import Outcar
from pymatgen.apps.borg.queen import BorgQueen

from fireworks.user_objects.queue_adapters.common_adapter import CommonAdapter

from mpinterfaces import get_struct_from_mp, Interface
from mpinterfaces.calibrate import Calibrate, CalibrateSlab
from mpinterfaces.data_processor import MPINTVaspDrone

# Module-wide logger: every record (DEBUG and above) is appended to
# 'mpint.log' in the current working directory and echoed to stdout,
# both using the same "LEVEL:timestamp:message" layout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(asctime)s:%(message)s')
fh = logging.FileHandler('mpint.log', mode='a')
sh = logging.StreamHandler(stream=sys.stdout)
for _handler in (fh, sh):
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)

# Materials Project API key, read from the environment; an empty string
# means that no key has been configured.
MAPI_KEY = os.getenv("MAPI_KEY", "")

class mpint(object):
    """
    Command line driver for vasp calibration jobs.

    The job specifications are read from a yaml input file; depending on
    the sub-command the jobs are submitted ('run'), polled ('check') or
    post-processed ('energies'). Progress is reported through the
    module-level ``logger``.
    """

    def __init__(self):
        """
        Set up the default input set: a single Li atom cell, an 8x8x8
        Monkhorst-Pack kpoint mesh and a default INCAR.
        """
        # chemical formula, read from the 'FORMULA' key of the input file
        self.formula = None
        # queue adapter; stays None unless a 'QUE' section is processed
        self.qadapter = None
        # job folder name: <formula>_<calculation type>
        self.cal_job_dir = None
        # parameters to be varied, mapped to the list of values to try
        self.turn_knobs = OrderedDict(
            [
                ('KPOINTS', []),
                ('ENCUT', [])
            ]
        )
        self.incar_dict = {
            'PREC': 'Accurate',
            'ENCUT': 500,
            'ISMEAR': 1,
            'EDIFF': 1e-06,
            'NPAR': 4,
            'SIGMA': 0.1,
            'NSW': 0,
            'ISIF': 3,
            'IBRION': 2
        }
        self.incar = Incar.from_dict(self.incar_dict)
        # some default value
        self.poscar = Poscar.from_string('Li1\n' +
                                         '1.0\n' +
                                         '2.797985 0.000000 -0.989237\n' +
                                         '-1.398992 2.423126 -0.989237\n' +
                                         '0.000000 0.000000 2.967711\n' +
                                         'Li\n' +
                                         '1\n' +
                                         'direct\n' +
                                         '0.000000 0.000000 0.000000 Li\n')
        self.kpoints = Kpoints.monkhorst_automatic(kpts=(8, 8, 8))
        self.potcar = Potcar(self.poscar.site_symbols)
        # slab specification taken from the 'SLAB' entry of the input
        # file; it is unpacked positionally into get_slab_struct, i.e.
        # [bulk structure file, hkl, min_thick, min_vac]
        self.slab = {}

    def get_d(self, nnodes=1, nprocs=32, walltime=4,
              bin='/home/km468/Software/VASP/vaspsol_nln_kappa.5.3.5/vasp',
              mem=1000):
        """
        Build the PBS queue specification.

        Args:
            nnodes: number of nodes
            nprocs: total number of processors; the processors per node
                is int(nprocs/nnodes)
            walltime: walltime, passed through unchanged
            bin: path to the vasp binary, launched through mpirun
            mem: memory in mb, requested with '#PBS -l pmem='

        Returns:
            dict with the keys 'type' (always 'PBS') and 'params' (the
            keyword arguments handed to CommonAdapter by
            update_inputset)
        """
        params = {
            'nnodes': str(nnodes),
            'ppnode': str(int(nprocs / nnodes)),
            'walltime': walltime,
            'job_name': 'vasp_job',
            'pre_rocket': '#PBS -l pmem=' + str(mem) + 'mb',
            'rocket_launch': 'mpirun ' + bin
        }
        return {'type': 'PBS', 'params': params}

    def get_struct(self):
        """
        Fetch the structure for self.formula through get_struct_from_mp
        and return a copy of its conventional standard cell.

        Exits with status 1 if the MAPI_KEY environment variable was not
        set.
        """
        logger.info("obtaining structure from materialsproject and coverting it to conventional cell..should verify")
        logger.info("the structure obtained from the materials project correspond to the one with the lowest energy above the hull")
        if not MAPI_KEY:
            print('get API KEY from materialsproject and set it to the MAPI_KEY environment variable')
            # a missing key is a failure: do not report success (status 0)
            sys.exit(1)
        strt = get_struct_from_mp(self.formula, MAPI_KEY=MAPI_KEY)
        sa = SpacegroupAnalyzer(strt)
        return sa.get_conventional_standard_structure().copy()

    def get_slab_struct(self, fin, hkl, min_thick=10, min_vac=12):
        """
        Create a slab from the bulk structure stored in a file.

        Args:
            fin: bulk structure file, read with Structure.from_file
            hkl: miller index of the slab surface
            min_thick: minimum slab thickness, passed to Interface
            min_vac: minimum vacuum spacing, passed to Interface

        Returns:
            the created and sorted Interface object
        """
        logger.info("ase is used to create an orthogonal slab")
        logger.info("bulk structure file: {0}".format(fin))
        logger.info("hkl = {0}, min_thick = {1}, min_vac = {2}".format(hkl, min_thick, min_vac))
        bulk = Structure.from_file(fin)
        iface = Interface(bulk, hkl=hkl,
                          min_thick=min_thick, min_vac=min_vac,
                          primitive=False, from_ase=True)
        iface.create_interface()
        iface.sort()
        return iface

    def update_inputset(self, update, from_mp=True):
        """
        Update the vasp input set and the queue settings.

        Args:
            update: dict of settings; the keys POSCAR, SLAB, INCAR,
                KPOINTS, QUE and WAVECAR are recognised
            from_mp: if True and neither POSCAR nor SLAB is given, the
                structure is obtained through get_struct
        """
        logger.info("updating the inputset")
        logger.info("note: using monkhorst automatic kpoints")
        # structure: poscar file > slab creation > materialsproject
        if update.get('POSCAR'):
            logger.info("setting poscar from file {0}".format(update['POSCAR']))
            self.poscar = Poscar.from_file(update['POSCAR'])
        elif update.get('SLAB'):
            self.slab = update['SLAB']
            # the slab is created only if more than the bulk file and
            # the hkl are given
            if len(self.slab) > 2:
                logger.info("creating slab")
                strt = self.get_slab_struct(*self.slab)
                self.poscar = Poscar(strt)
        elif from_mp:
            logger.info("neither poscar nor slab creation setup provided, getting the bulk structure for the given formula from the materialsproject database")
            strt = self.get_struct()
            self.poscar = Poscar(strt)
        logger.info("setting potcar from poscar symbols")
        self.potcar = Potcar(self.poscar.site_symbols)
        logger.info("updating INCAR")
        self.incar_dict['SYSTEM'] = self.formula
        if update.get('INCAR'):
            self.incar_dict.update(update['INCAR'])
        # rebuild unconditionally, otherwise the SYSTEM tag set above is
        # lost whenever there is no INCAR section
        self.incar = Incar.from_dict(self.incar_dict)
        if update.get('KPOINTS'):
            logger.info("updating KPOINTS")
            self.kpoints = Kpoints.monkhorst_automatic(kpts=update['KPOINTS'])
        if update.get('QUE'):
            logger.info("updating que settings")
            self._update_qadapter(update['QUE'])
        if update.get('WAVECAR'):
            self._copy_wavecar(update['WAVECAR'])

    def _update_qadapter(self, que):
        """
        Create the queue adapter from the QUE section; missing or falsy
        entries fall back to the defaults.
        """
        nprocs = que.get('NPROCS') or 32
        nnodes = que.get('NNODES') or 1
        mem = que.get('MEM') or 1000
        bin = (que.get('BIN') or
               '/home/km468/Software/VASP/vaspsol_nln_kappa.5.3.5/vasp')
        walltime = '04:00:00'
        if que.get('TIME'):
            # TIME is given in hours
            walltime = str(que['TIME']) + ':00:00'
        d = self.get_d(nnodes, nprocs, walltime, bin, mem)
        self.qadapter = CommonAdapter(d['type'], **d['params'])

    def _copy_wavecar(self, wavecar_file):
        """
        Copy the given wavecar file to the job folder as 'WAVECAR',
        creating the job folder if needed.
        """
        if not os.path.isfile(wavecar_file):
            logger.warning("wavecar file {0} not found, nothing copied".format(wavecar_file))
            return
        logger.info("copying wavecar from {0} to {1}".format(wavecar_file, self.cal_job_dir))
        if not os.path.exists(self.cal_job_dir):
            os.makedirs(self.cal_job_dir)
        shutil.copy(wavecar_file, os.path.join(self.cal_job_dir, 'WAVECAR'))

    @staticmethod
    def _expand_knob(key, spec):
        """
        Expand a [start, stop, step] knob specification to the list of
        values to try; stop is excluded as in range().

        For the 'KPOINTS' knob each value k becomes the mesh [k, k, k].
        """
        values = range(spec[0], spec[1], spec[2])
        if key == 'KPOINTS':
            return [[k, k, k] for k in values]
        # materialise: range() is a lazy object in python 3
        return list(values)

    def process_input(self, args):
        """
        Read the yaml input file and set the formula, the job folder,
        the input set (for the 'run' command only), the slab
        specification and the knobs.

        Args:
            args: parsed command line arguments; the attributes input,
                type and command are used

        Raises:
            KeyError: if the input file has no 'FORMULA' entry
        """
        if not args.input:
            return
        # safe_load: plain data is all that is needed and yaml.load
        # without an explicit Loader can construct arbitrary objects
        with open(args.input) as f:
            update = yaml.safe_load(f) or {}
        self.formula = update['FORMULA']
        self.cal_job_dir = self.formula + '_single'
        if args.type:
            self.cal_job_dir = self.formula + '_' + args.type
            logger.info("{0} type job for {1} system".format(args.type, self.formula))
        logger.info("job folder: {0}".format(self.cal_job_dir))
        # initial input set
        if args.command == 'run':
            self.update_inputset(update, from_mp=False)
        # run specific input set
        job_spec = update.get(args.type)
        if not job_spec:
            return
        if args.command == 'run':
            # NOTE(review): from_mp defaults to True here, so a job
            # section without POSCAR/SLAB fetches the structure from the
            # materialsproject even if the top level section provided
            # one -- confirm that this is intended
            self.update_inputset(job_spec)
        if job_spec.get('SLAB'):
            self.slab = job_spec['SLAB']
        knobset = job_spec.get('KNOBS')
        if knobset:
            logger.info("updating knob settings")
            for key, spec in knobset.items():
                self.turn_knobs[key] = self._expand_knob(key, spec)

    def get_cal(self, args):
        """
        Process the input and return the set up calibration object:
        CalibrateSlab if the VACUUM or THICKNESS knobs are set,
        Calibrate otherwise.

        Raises:
            ValueError: if a slab calibration is requested without a
                SLAB entry providing the bulk structure file and the hkl
        """
        self.process_input(args)
        if self.turn_knobs.get('VACUUM') or self.turn_knobs.get('THICKNESS'):
            if len(self.slab) < 2:
                raise ValueError("slab calibration requires a SLAB entry "
                                 "with the bulk structure file and the hkl")
            logger.info("slab calibration: thickness or vacuum")
            logger.info("bulk structure from file: {0}".format(self.slab[0]))
            self.poscar = Poscar.from_file(self.slab[0])
            logger.info("hkl: {0}".format(self.slab[1]))
            system = {'hkl': self.slab[1]}
            cal = CalibrateSlab(self.incar, self.poscar,
                                self.potcar, self.kpoints,
                                turn_knobs=self.turn_knobs,
                                Grid_type='M',
                                qadapter=self.qadapter,
                                job_dir=self.cal_job_dir,
                                job_cmd=None, wait=False,
                                system=system, from_ase=True)
        else:
            cal = Calibrate(self.incar, self.poscar,
                            self.potcar, self.kpoints,
                            turn_knobs=self.turn_knobs,
                            Grid_type='M',
                            qadapter=self.qadapter,
                            job_dir=self.cal_job_dir,
                            job_cmd=None, wait=False)
        cal.setup()
        return cal

    def run(self, args):
        """
        'run' command: run the calibration job and log the job folders
        along with the job ids.
        """
        cal_objs = [self.get_cal(args)]
        for cal in cal_objs:
            cal.run()
            logger.info("the job ids for the job in the directory {0} are {1}".format(cal.job_dir_list, cal.job_ids))

    def check(self, args):
        """
        'check' command: print whether Calibrate.check_calcs reports the
        calculations as done.
        """
        cal_objs = [self.get_cal(args)]
        done = Calibrate.check_calcs(cal_objs)
        if done:
            print('\n all done ...')
        else:
            print('\n not yet  ...')

    def get_energies(self, args):
        """
        'energies' command: evaluate the knob responses and log the
        optimum knob values.
        """
        cal_objs = [self.get_cal(args)]
        for cal in cal_objs:
            cal.set_knob_responses()
            cal.set_sorted_optimum_params()
            logger.info('optium values for the knob parameters based on the cutoff criterion : {0}'.format(cal.optimum_knob_responses))

    def main(self):
        """
        Parse the command line and dispatch to the handler of the chosen
        sub-command (run, check or energies).
        """
        # local import: only needed here to keep the layout of the
        # description in the help output
        from argparse import RawDescriptionHelpFormatter

        m_description = """
Management tool for vasp projects, starting from 
encut, kpoint or other parameter optimization of till the slab solvation.

it takes 3 arguments: input yaml file, type of calculation and the run mode
example:
   mpint -i naf.yaml -t bulk_calibrate run
this will read in the specifications for 'bulk_calibrate' job from the 
input yaml file, naf.yaml, and runs the job i.e submits to the que.

run modes supported:
  1. run : submits job to the que
  2. check: check whether the job is finished or not
  3. energies: get the energies and the optimum parameter values

Everytime jobs are submitted or its sttaus queried, information such as job ids, 
job folders etc are written to the log file 'mpint.log'. This makes it easier to
identify job ids and their corresponding to job folders.
 
Note: use your own materials project key to download the required 
structure

Note: this script submits jobs only to the PBS ques such as hipergator 
"""
        parser = ArgumentParser(description=m_description,
                                formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument('-i', '--input', help="yaml input file")
        parser.add_argument('-t', '--type', help="type of calculation")

        subparsers = parser.add_subparsers(help='command', dest='command')
        # the sub-command is optional by default in python 3; without it
        # args.func does not exist and the dispatch below would crash
        subparsers.required = True

        run_parser = subparsers.add_parser('run', help='run the specified job')
        run_parser.set_defaults(func=self.run)

        check_parser = subparsers.add_parser('check', help='check whether the calculation of the specified type for the given system is done or not')
        check_parser.set_defaults(func=self.check)

        energies_parser = subparsers.add_parser('energies', help='get energies for the given system and calculation type')
        energies_parser.set_defaults(func=self.get_energies)

        args = parser.parse_args()
        args.func(args)


if __name__ == '__main__':
    # script entry point: build the driver and let it parse the command line
    app = mpint()
    app.main()
