#!python

from __future__ import division, unicode_literals, print_function, \
    absolute_import

"""
This script serves as a management tool for VASP projects, from encut,
kpoint or other parameter optimization through to slab solvation. Just
define all types of calculations, with the specifications each one needs
for the project, in a yaml file and run or rerun calculations as required.

Note: use your own materials project key to download the required 
structure
"""

from six.moves import range

import shutil
import yaml
from argparse import ArgumentParser
from fnmatch import fnmatch

from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.vasp.inputs import Incar
from pymatgen.io.vasp.inputs import Potcar, Kpoints

from mpinterfaces import get_struct_from_mp
from mpinterfaces.calibrate import Calibrate, CalibrateSlab
from mpinterfaces.interface import Interface
from mpinterfaces.utils import *

# Module-wide logger at DEBUG level. Every record is written twice with the
# same format: appended to 'mpint.log' in the current directory and echoed
# to stdout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(asctime)s:%(message)s')

# file handler (attached first)
fh = logging.FileHandler('mpint.log', mode='a')
fh.setFormatter(formatter)
logger.addHandler(fh)

# stdout handler (attached second)
sh = logging.StreamHandler(stream=sys.stdout)
sh.setFormatter(formatter)
logger.addHandler(sh)

# Materials Project API key, read from the environment; empty string when
# the MAPI_KEY variable is not set.
MAPI_KEY = os.getenv("MAPI_KEY", "")


def process_dir(val):
    """Load every POSCAR-like file found directly inside a directory.

    Args:
        val: path of the directory to scan (not scanned recursively).

    Returns:
        list: one ``Poscar`` per regular file whose *name* matches
        ``*POSCAR*``, in sorted file-name order. Empty when ``val`` is not
        a directory or holds no matching file.
    """
    if not os.path.isdir(val):
        return []
    poscar_list = []
    # os.listdir returns entries in arbitrary order; sort so the resulting
    # list is the same from one run to the next.
    for f in sorted(os.listdir(val)):
        fname = os.path.join(val, f)
        # Match on the file name only: matching the joined path would make
        # every file qualify as soon as the directory name itself contains
        # "POSCAR".
        if fnmatch(f, "*POSCAR*") and os.path.isfile(fname):
            poscar_list.append(Poscar.from_file(fname))
    return poscar_list


class mpint(object):
    """Driver behind the ``mpint`` command line tool.

    An instance holds a default VASP input set (INCAR, POSCAR, KPOINTS,
    POTCAR), the queue settings and the "knobs" (parameters to scan). It
    reads overrides from a yaml input file and builds ``Calibrate`` /
    ``CalibrateSlab`` objects from them.
    """

    def __init__(self):
        """Populate the default input set and empty job settings."""
        self.formula = None
        self.qadapter = None
        self.job_cmd = None
        self.cal_job_dir = None
        # parameters to scan, keyed by name; filled from the KNOBS section
        # of the yaml input by process_input
        self.turn_knobs = OrderedDict(
            [
                ('KPOINTS', []),
                ('ENCUT', [])
            ]
        )
        self.incar_dict = {
            'PREC': 'Accurate',
            'ENCUT': 500,
            'ISMEAR': 1,
            'EDIFF': 1e-06,
            'NPAR': 4,
            'SIGMA': 0.1,
            'NSW': 0,
            'ISIF': 3,
            'IBRION': 2
        }
        self.incar = Incar.from_dict(self.incar_dict)
        # some default value
        self.poscar = Poscar.from_string('Li1\n' +
                                         '1.0\n' +
                                         '2.797985 0.000000 -0.989237\n' +
                                         '-1.398992 2.423126 -0.989237\n' +
                                         '0.000000 0.000000 2.967711\n' +
                                         'Li\n' +
                                         '1\n' +
                                         'direct\n' +
                                         '0.000000 0.000000 0.000000 Li\n')
        self.kpoints = Kpoints.monkhorst_automatic(kpts=(8, 8, 8))
        self.potcar = Potcar(self.poscar.site_symbols)
        self.slab = {}

    def get_struct(self):
        """Fetch the structure for ``self.formula`` via ``get_struct_from_mp``.

        Returns:
            A copy of the conventional standard structure computed by
            ``SpacegroupAnalyzer``, or ``None`` when no formula is set.

        Calls ``sys.exit()`` when a formula is set but ``MAPI_KEY`` is empty.
        """
        if not self.formula:
            return None
        if not MAPI_KEY:
            print(
                'get API KEY from materialsproject and set it to the MAPI_KEY environment variable')
            sys.exit()
        strt = get_struct_from_mp(self.formula, MAPI_KEY=MAPI_KEY)
        sa = SpacegroupAnalyzer(strt)
        structure_conventional = sa.get_conventional_standard_structure()
        strt = structure_conventional.copy()
        logger.info(
            "obtaining structure from materialsproject and coverting it to conventional cell..should verify")
        logger.info(
            "the structure obtained from the materials project correspond to the one with the lowest energy above the hull")
        return strt

    def get_slab_struct(self, fin, hkl, min_thick=10, min_vac=12):
        """Build a slab ``Interface`` from a bulk structure file.

        Args:
            fin: path of the bulk structure file (read with
                ``Structure.from_file``).
            hkl: passed to ``Interface`` as ``hkl``.
            min_thick: passed to ``Interface`` as ``min_thick``.
            min_vac: passed to ``Interface`` as ``min_vac``.

        Returns:
            The ``Interface`` object after ``create_interface()`` and
            ``sort()`` have been called on it.
        """
        logger.info("ase is used to create an orthogonal slab")
        logger.info("bulk structure file: {0}".format(fin))
        logger.info(
            "hkl = {0}, min_thick = {1}, min_vac = {2}".format(hkl, min_thick,
                                                               min_vac))
        bulk = Structure.from_file(fin)
        iface = Interface(bulk, hkl=hkl,
                          min_thick=min_thick, min_vac=min_vac,
                          primitive=False, from_ase=True)
        iface.create_interface()
        iface.sort()
        return iface

    def update_inputset(self, update, from_mp=True):
        """Apply the settings in ``update`` to the current input set.

        Args:
            update (dict): settings read from the yaml input. Keys used:
                ``POSCAR`` (structure file), ``SLAB`` (arguments for
                ``get_slab_struct``, used when it has more than two
                entries), ``INCAR`` (dict merged into the INCAR settings),
                ``KPOINTS`` (passed to ``Kpoints.monkhorst_automatic``),
                ``QUE`` (``NPROCS``, ``NNODES``, ``MEM``, ``TIME``,
                ``BIN``) and ``WAVECAR`` (file copied into the job folder).
            from_mp (bool): when neither ``POSCAR`` nor ``SLAB`` is given,
                try to obtain the structure with ``get_struct``.
        """
        logger.info("updating the inputset")
        logger.info("note: using monkhorst automatic kpoints")

        # --- structure: explicit POSCAR file, slab recipe, or database ---
        if update.get('POSCAR'):
            logger.info(
                "setting poscar from file {0}".format(update['POSCAR']))
            self.poscar = Poscar.from_file(update['POSCAR'])
        elif update.get('SLAB'):
            self.slab = update['SLAB']
            if len(self.slab) > 2:
                logger.info("creating slab")
                strt = self.get_slab_struct(*self.slab)
                self.poscar = Poscar(strt)
        elif from_mp:
            logger.info(
                "neither poscar nor slab creation setup provided, will try to get the bulk structure from the materialsproject database(if FORMULA is set)")
            strt = self.get_struct()
            if strt:
                self.poscar = Poscar(strt)
            else:
                logger.warning("No formula or poscar file provided")

        logger.info("setting potcar from poscar symbols")
        self.potcar = Potcar(self.poscar.site_symbols)

        # --- INCAR / KPOINTS ---
        self.incar_dict['SYSTEM'] = self.formula
        if update.get('INCAR'):
            logger.info("updating INCAR")
            self.incar_dict.update(update['INCAR'])
            self.incar = Incar.from_dict(self.incar_dict)
        if update.get('KPOINTS'):
            logger.info("updating KPOINTS")
            self.kpoints = Kpoints.monkhorst_automatic(
                kpts=update['KPOINTS'])

        # --- queue settings; missing or falsy entries keep the defaults ---
        que = update.get('QUE')
        if que:
            logger.info("updating que settings")
            nprocs = que.get('NPROCS') or 32
            nnodes = que.get('NNODES') or 1
            mem = que.get('MEM') or 1000
            vasp_bin = (que.get('BIN') or
                        '/home/km468/Software/VASP/vasp.5.3.5/vasp')
            # TIME is expanded into a 'TIME:00:00' walltime string
            if que.get('TIME'):
                walltime = str(que['TIME']) + ':00:00'
            else:
                walltime = '04:00:00'
            self.qadapter, self.job_cmd = get_run_cmmnd(nnodes,
                                                        nprocs,
                                                        walltime,
                                                        vasp_bin, mem)

        # --- optional WAVECAR copied into the job folder ---
        wavecar_file = update.get('WAVECAR')
        if wavecar_file and os.path.isfile(wavecar_file):
            if self.cal_job_dir is None:
                # No job folder has been chosen (e.g. the 'update' command
                # never sets one); there is nowhere to copy the file to.
                logger.warning(
                    "no job folder set, skipping copy of wavecar {0}".format(
                        wavecar_file))
            else:
                logger.info(
                    "copying wavecar from {0} to {1}".format(wavecar_file,
                                                             self.cal_job_dir))
                if not os.path.exists(self.cal_job_dir):
                    os.makedirs(self.cal_job_dir)
                shutil.copy(wavecar_file,
                            os.path.join(self.cal_job_dir, 'WAVECAR'))

    @staticmethod
    def _expand_knob(key, val):
        """Turn one KNOBS entry of the yaml input into its list of values.

        ``KPOINTS``: ``[k, k, k]`` for every k in ``range(start, stop, step)``.
        ``POSCAR``: ``Poscar`` objects read from the given file(s) and/or
        directory(ies). Any other key: ``np.arange(start, stop, step)``.
        """
        if key == 'KPOINTS':
            return [[k, k, k] for k in range(val[0], val[1], val[2])]
        if key == 'POSCAR':
            if not isinstance(val, list):
                return process_dir(val)
            paramlist = []
            for pfile in val:
                if os.path.isdir(pfile):
                    paramlist += process_dir(pfile)
                else:
                    paramlist.append(Poscar.from_file(str(pfile)))
            return paramlist
        return np.arange(val[0], val[1], val[2])

    def _knob_is_set(self, name):
        """Return True when knob ``name`` holds at least one value.

        Uses ``len`` rather than truthiness because knob values may be
        numpy arrays, whose truth value is ambiguous for more than one
        element.
        """
        values = self.turn_knobs.get(name)
        return values is not None and len(values) > 0

    def process_input(self, args):
        """Read the yaml input named by ``args.input`` and configure the job.

        Sets ``self.formula`` and ``self.cal_job_dir``, applies the
        top-level settings and then those of the section named by
        ``args.type`` (both only for the ``run`` command), and fills
        ``self.slab`` and ``self.turn_knobs`` from that section. Does
        nothing when ``args.input`` is not given.

        Args:
            args: parsed command line arguments; ``input``, ``type`` and
                ``command`` are read.
        """
        if not args.input:
            return
        with open(args.input) as f:
            # yaml.load() without a Loader is rejected by PyYAML >= 6 and is
            # unsafe on older releases. An empty file parses to None.
            # NOTE(review): safe_load refuses python-specific tags; confirm
            # no input file relies on them.
            update = yaml.safe_load(f) or {}
        self.formula = update.get('FORMULA')

        # job folder name: "<formula>_<type>", falling back to whichever
        # part is available
        if args.type:
            self.cal_job_dir = (self.formula + '_' + args.type
                                if self.formula else args.type)
        else:
            self.cal_job_dir = (self.formula + '_single'
                                if self.formula else 'vasp_job')
        logger.info(
            "{0} type job for {1} system".format(args.type, self.formula))
        logger.info("job folder: {0}".format(self.cal_job_dir))

        # initial input set
        if args.command == 'run':
            self.update_inputset(update, from_mp=False)

        # run specific input set
        if not update.get(args.type):
            return
        update = update[args.type]
        if args.command == 'run':
            self.update_inputset(update)
        if update.get('SLAB'):
            self.slab = update['SLAB']
        knobset = update.get('KNOBS')
        if knobset:
            logger.info("updating knob settings")
            for key, val in knobset.items():
                self.turn_knobs[key] = self._expand_knob(key, val)

    def get_cal(self, args):
        """Create and set up the calibration object for this job.

        A ``CalibrateSlab`` is built when a ``VACUUM`` or ``THICKNESS``
        knob has values, otherwise a ``Calibrate``.

        Returns:
            The calibration object, after its ``setup()`` has been called.
        """
        self.process_input(args)
        if self._knob_is_set('VACUUM') or self._knob_is_set('THICKNESS'):
            logger.info("slab calibration: thickness or vacuum")
            logger.info("bulk structure from file: {0}".format(self.slab[0]))
            self.poscar = Poscar.from_file(self.slab[0])
            logger.info("hkl: {0}".format(self.slab[1]))
            system = {'hkl': self.slab[1]}
            cal = CalibrateSlab(self.incar, self.poscar,
                                self.potcar, self.kpoints,
                                turn_knobs=self.turn_knobs,
                                Grid_type='M',
                                qadapter=self.qadapter,
                                job_dir=self.cal_job_dir,
                                job_cmd=self.job_cmd, wait=False,
                                system=system, from_ase=True)
        else:
            cal = Calibrate(self.incar, self.poscar,
                            self.potcar, self.kpoints,
                            turn_knobs=self.turn_knobs,
                            Grid_type='M',
                            qadapter=self.qadapter,
                            job_dir=self.cal_job_dir,
                            job_cmd=self.job_cmd, wait=False)

        cal.setup()
        return cal

    def run(self, args):
        """Handler of the ``run`` command: build the job, run it, log ids."""
        cal_objs = [self.get_cal(args)]
        for cal in cal_objs:
            cal.run()
            logger.info(
                "the job ids for the job in the directory {0} are {1}".format(
                    cal.job_dir_list, cal.job_ids))

    def get_energies(self, args):
        """Collect the knob responses of the job and log the optimum values."""
        cal_objs = [self.get_cal(args)]
        for cal in cal_objs:
            cal.set_knob_responses()
            cal.set_sorted_optimum_params()
            logger.info(
                'optium values for the knob parameters based on the cutoff criterion : {0}'.format(
                    cal.optimum_knob_responses))

    def update(self, args):
        """Handler of the ``update`` command: call ``update_checkpoint``.

        INCAR, KPOINTS and queue settings start out as ``None`` and are
        only set by what the yaml input (``args.input``) provides. Entries
        of ``args.jids`` containing '.json' are treated as checkpoint files
        and each is passed as ``jfile``; with no job ids at all
        ``update_checkpoint`` is called without arguments.
        """
        self.incar = None
        self.kpoints = None
        self.qadapter = None
        if args.input:
            with open(args.input) as f:
                # see process_input for why safe_load is used
                settings = yaml.safe_load(f) or {}
            self.update_inputset(settings, from_mp=False)
        if not args.jids:
            update_checkpoint()
            return
        json_files = [jid for jid in args.jids if '.json' in jid]
        if json_files:
            for jf in json_files:
                update_checkpoint(job_ids=args.jids,
                                  jfile=jf,
                                  incar=self.incar,
                                  kpoints=self.kpoints,
                                  que=self.qadapter)
        else:
            update_checkpoint(job_ids=args.jids,
                              incar=self.incar,
                              kpoints=self.kpoints,
                              que=self.qadapter)

    def main(self):
        """Parse the command line and dispatch to ``run`` or ``update``.

        Exits with argparse's usage error when no sub-command is given.
        """
        m_description = """
Management tool for vasp projects, starting from 
encut, kpoint or other parameter optimization of till the slab solvation.

it takes 3 arguments: input yaml file, type of calculation and the run mode
example:
   mpint -i naf.yaml -t bulk_calibrate run
this will read in the specifications for 'bulk_calibrate' job from the 
input yaml file, naf.yaml, and runs the job i.e submits to the que.

run modes supported:
  1. run : submits job to the que

Everytime jobs are submitted or its sttaus queried, information such as job ids, 
job folders etc are written to the log file 'mpint.log'. This makes it easier to
identify job ids and their corresponding to job folders.
 
Note: use your own materials project key to download the required 
structure

Note: this script submits jobs only to the PBS ques such as hipergator 
"""
        parser = ArgumentParser(description=m_description)
        parser.add_argument('-i', '--input', help="yaml input file")
        parser.add_argument('-t', '--type', help="type of calculation")

        subparsers = parser.add_subparsers(help='command', dest='command')
        # Sub-commands are optional by default on Python 3; without this a
        # missing command would surface as an AttributeError on args.func
        # instead of a usage message.
        subparsers.required = True

        cal_parser = subparsers.add_parser('run', help='run the specified job')
        cal_parser.set_defaults(func=self.run)

        update_parser = subparsers.add_parser('update',
                                              help='update/rerun the checkpoint file calibrate.json ')
        update_parser.add_argument('jids', type=str, nargs='*',
                                   help='list of job ids')
        update_parser.set_defaults(func=self.update)

        args = parser.parse_args()
        args.func(args)


if __name__ == '__main__':
    # Script entry point: create the tool and let it parse the command line.
    cli = mpint()
    cli.main()
