#!/usr/bin/env python

#===========================================================================#
#                                                                           #
#  File:       v2qe.py                                                      #
#  Dependence: pysupercell.py                                               #
#  Usage:      convert the POSCAR file to part of input file for PWscf(QE)  #      
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>                         #
#  Date:       Jun 03, 2023                                                 #
#                                                                           #
#===========================================================================#

import sys
import numpy as np
from pysupercell.QE_ibrav_lib import *
from pysupercell import __version__
from pysupercell.pysupercell import cryst_struct,verbose_pkg_info
from pysupercell.arguments import str2bool
try: from termcolor import cprint,colored
except: pass
import os
import shutil

# Major version number of the running Python interpreter.
pyver=sys.version_info[0]

# Description text: used as the argparse description in get_args() and
# printed by verbose_head() when verbose notes are requested.
Note='''
This file can be used to generate input file for PWscf (Quantum ESPRESSO) by using VASP-POSCAR as input.
The definition of primitive cell basis vectors follows the Quantum ESPRESSO code, please refer to:
http://www.quantum-espresso.org/wp-content/uploads/Doc/INPUT_PW.html#idm6425376
'''
# Usage instructions, printed by verbose_head().
Usage='''
Usage: Please prepare the POSCAR file in the conventional cell form, use direct (fractional) coordinates to indicate atomic positions.
Then run this script by type the command: python v2qe.py
'''

# Warnings about cases where the detected space group / ibrav may be wrong,
# printed by verbose_head().
# NOTE(review): this user-facing text has typos ('inrealistic', 'may adds',
# 'may differs'); left unchanged here because it is runtime output.
Alert='''
Caution on the space group and ibrav when dealing with the following systems:
1.  Low dimensional materials: 
    The periodicity in the vacuum direction(s) are inrealistic
    so the 'spacegroup' may be wrong.
2.  Magnetic materials: 
    The spin polarization may adds extra properties to the atoms, 
    so the magnetic unit cell may differs from the chemical primitive cell.
3.  Body/Face/Base-centered structures: 
    The choice of the base plane is alternative, 
    please check the generated structure carefully using xcrysden.
'''

# Hint for users who supplied a primitive cell of a centered lattice.
# NOTE(review): not referenced anywhere in the visible code; possibly unused.
conv_cell_prompt='''
This is a structure with face/body/base-centered symmetry
The POSCAR you provide is a primitive cell
Use phonopy to generate a conventional cell (BPOSCAR)
and then try again
Good luck!
'''


def generate_standard_poscar_by_phonopy(poscar,symprec=5e-5,file_std_poscar='POSCAR_standardized'):
    """Write a symmetry-standardized copy of a POSCAR file using phonopy.

    The structure in ``poscar`` is read with phonopy and its symmetry dataset
    is computed with tolerance ``symprec``. The dataset's standardized
    lattice, atom types and fractional positions are then written in VASP
    format (direct coordinates) to ``file_std_poscar``.
    """
    from phonopy.interface.vasp import read_vasp, write_vasp
    from phonopy.structure.grid_points import get_symmetry_dataset
    from phonopy.structure import atoms

    original = read_vasp(poscar)
    sym_data = get_symmetry_dataset(original, symprec=symprec)
    standardized = original.copy()

    # Column 1 of each phonopy atom_data row is used as the element symbol.
    std_numbers = sym_data['std_types']
    std_symbols = [atoms.atom_data[num][1] for num in std_numbers]

    standardized._set_parameters(
        cell=sym_data['std_lattice'],
        symbols=std_symbols,
        numbers=std_numbers,
        scaled_positions=sym_data['std_positions'],
    )
    write_vasp(file_std_poscar, standardized, direct=True)


def gen_pw_nml(args,struct):
    """Build a minimal pw.x namelist dictionary (CONTROL and SYSTEM groups).

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``calculation``, ``outdir``, ``pseudo_dir`` and ``prefix``.
        The values are stored unchanged.
    struct : object
        Must provide ``_get_ibrav()``. Only the first element of its result
        is used, as ``ibrav``.

    Returns
    -------
    dict
        ``{'CONTROL': {...}, 'SYSTEM': {'ibrav': ...}}``, in the nested-dict
        form accepted by ``write_pwi_nml``.
    """
    # NOTE(review): main() calls struct.get_ibrav(...); confirm that the
    # private _get_ibrav() actually exists on cryst_struct.
    ibrav = struct._get_ibrav()[0]
    control_nml = {
        # pw.x's &CONTROL variable is 'calculation' (singular).
        'calculation': args.calculation,
        'restart_mode': 'from_scratch',
        'outdir': args.outdir,
        'pseudo_dir': args.pseudo_dir,
        'prefix': args.prefix,
        'tprnfor': False,
    }
    # celldm is not included yet; struct._find_celldm(ibrav) was a candidate.
    system_nml = {'ibrav': ibrav}
    return {'CONTROL': control_nml, 'SYSTEM': system_nml}



#============================================================#
# input file for Quantum ESPRESSO pw.x (vc-relax by default) #
#============================================================#


def write_pwi_nml(pwi_nml,fil='rx.in'):
    """Serialize a namelist dictionary to the text file ``fil`` with f90nml."""
    import f90nml
    handle = open(fil, 'w')
    with handle:
        f90nml.write(pwi_nml, handle)


def write_pwi(setup_dic,ibrav,struct_std,struct,filename="rx.in"):
    """Write a pw.x (Quantum ESPRESSO) input file, or print it to stdout.

    Parameters
    ----------
    setup_dic : dict
        Required keys: 'prefix', 'ecutwfc', 'ecutrho', 'pseudo_dir', 'upf',
        'kmesh', 'kshift'. 'prefix' is wrapped in single quotes; the other
        values are written verbatim, so string values must carry their own
        Fortran quotes.
        Optional keys (the defaults reproduce the historical output):
        'calculation' (default "'vc-relax'"), 'outdir' (default "'./tmp/'"),
        'cell_dofree' (default "'2Dxy'").
    ibrav : int
        When 0, an explicit CELL_PARAMETERS card is written from
        ``struct._cell``.
    struct_std : object
        Its ``write_pw_cell`` is called inside the &SYSTEM namelist.
        NOTE: its ``_natom`` is overwritten with ``struct._natom``.
    struct : object
        Supplies ``_species``, ``_cell`` and ``write_pw_atoms``.
    filename : str or None
        Output path. A falsy value prints the input to stdout between banners.
    """
    # important change: use cell of conventional cell, natom of primitive cell
    struct_std._natom = struct._natom

    try:
        import phonopy.structure.atoms as atoms
        atomic_mass=[atoms.atom_data[atoms.symbol_map[sym]][-1] for sym in struct._species]
    except Exception:
        # Best effort: phonopy missing, or a species absent from its tables.
        print ('Fail to load atomic mass from phonopy, atomic mass will be displayed as -1')
        # One placeholder per species (ATOMIC_SPECIES has one line per species).
        atomic_mass = -np.ones(len(struct._species))

    to_stdout = not filename
    if to_stdout:
        fout = None
        print ("\n{0}\nSample input for PWscf(Quantum ESPRESSO),Start\n{0}\n".format('-'*60))
    else:
        fout = open(filename,"w")

    try:
        print ("&CONTROL", file=fout)
        print ("calculation = ", setup_dic.get('calculation', "'vc-relax'"), file=fout)
        print ("restart_mode = ","'from_scratch'", file=fout)
        print ("outdir = {}".format(setup_dic.get('outdir', "'./tmp/'")), file=fout)
        print ("pseudo_dir = ",setup_dic['pseudo_dir'], file=fout)
        print ("prefix = '{0}'".format(setup_dic['prefix']), file=fout)
        print ('tprnfor=.true.',file=fout)
        print ('tstress=.true.',file=fout)
        print ('etot_conv_thr=1.d-12',file=fout)
        print ('forc_conv_thr=1.d-8',file=fout)
        print ("/", file=fout)
        print ("&SYSTEM", file=fout)
        struct_std.write_pw_cell(filename=fout)
        print ("ecutwfc = ",setup_dic['ecutwfc'], file=fout)
        print ("ecutrho = ",setup_dic['ecutrho'], file=fout)
        print ("occupations = 'smearing'", file=fout)
        print ("smearing ='gaussian'", file=fout)
        print ("degauss = 0.001", file=fout)
        print ("/", file=fout)
        print ("&ELECTRONS", file=fout)
        print ("conv_thr=1.d-8", file=fout)
        print ("/", file=fout)
        print ('&IONS',file=fout)
        print ('/',file=fout)
        print ('&CELL',file=fout)
        # '2Dxy' relaxes only the in-plane (x, y) cell components, which suits
        # slabs; pass 'cell_dofree': "'all'" in setup_dic for bulk crystals.
        print ("cell_dofree={}".format(setup_dic.get('cell_dofree', "'2Dxy'")),file=fout)
        print ('press_conv_thr=1.d-1',file=fout)
        print ('/',file=fout)
        if ibrav==0:
            print ('CELL_PARAMETERS angstrom', file=fout)
            print ('\n'.join([' '.join(['{0:15.10f}'.format(struct._cell[i,j]) for j in range(3)]) for i in range(3)]), file=fout)

        print ("ATOMIC_SPECIES", file=fout)
        for item,mass in zip(struct._species,atomic_mass):
            print ("{:2s} {:10.5f} {:>20s}".format(item,mass,item+setup_dic['upf']), file=fout)

        struct.write_pw_atoms(filename=fout)

        print ("K_POINTS automatic", file=fout)
        print (setup_dic['kmesh'],setup_dic['kshift'], file=fout)
    finally:
        # Close the file even if one of the writers raises, so no handle leaks.
        if fout is not None:
            fout.close()

    if to_stdout:
        print ("\n{0}\nSample input for PWscf(Quantum ESPRESSO),End\n{0}\n".format('-'*60))



def get_args(argv=None):
    """Build the v2qe command-line parser and parse the arguments.

    Several string defaults (--calculation, --outdir, --pseudo_dir) include
    their own single quotes because they are written verbatim into the
    Fortran namelist.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
        which is the previous behaviour.

    Returns
    -------
    tuple
        ``(parser, args)``: the ``argparse.ArgumentParser`` and the parsed
        ``argparse.Namespace``.
    """
    import argparse
    parser = argparse.ArgumentParser(prog='v2qe.py', description = Note)
    parser.add_argument('--poscar',type=str,default='POSCAR',help='name of the POSCAR file')
    parser.add_argument('--symmprec', type=float, default=5e-4, help='deviation tolerance for finding crystal symmetry, in angstrom')
    parser.add_argument('--ecutwfc', type=float, default=100, help='plane wave cutoff')
    parser.add_argument('--ecutrho', type=float, default=500, help='charge density cutoff')
    parser.add_argument('--calculation',type=str,default="'vc-relax'",help="calculation task of QE")
    parser.add_argument('--outdir',type=str,default="'./tmp'",help="directory for temporary files")
    parser.add_argument('--prefix', type=str, default="pw", help='prefix for the pw calculation')
    parser.add_argument('--pseudo_dir',type=str,default="'/home/zsh/pseudo/pbe'",help="directory for pseudopotential files")
    parser.add_argument('--upf',    type=str, default=".pbe-mt_fhi.UPF",help="type of pseudopotential")
    parser.add_argument('--kmesh', type=str, default='auto', help='k point mesh using the Monkhorst Pack scheme')
    parser.add_argument('--kshift', type=str, default="0 0 0", help='k point mesh shift from the Gamma point')
    parser.add_argument('--filpw',type=str,default='rx.in',help='QE pw.x input file with structures')
    parser.add_argument('--redefine_ibrav',type=str,default="n",help='Redefine ibrav manually or not')
    parser.add_argument('--verbose_notes',type=str2bool,default=False,help='Verbose notes for vasp to QE file conversion')
    args = parser.parse_args(argv)
    return parser, args

def verbose_head(verbose_notes):
    """Print the introductory notes, or a hint on how to request them.

    With ``verbose_notes`` true, Note, Usage, Alert and def_ibrav are printed,
    coloured through termcolor when possible and as plain text otherwise.
    """
    if not verbose_notes:
        print ('Use "v2qe --verbose_notes=T" to see notes for usage.\n')
        return
    try:
        from termcolor import cprint as _cprint
        _cprint(Note, 'cyan')
        _cprint(Usage, 'blue')
        _cprint(Alert, "red")
        _cprint(def_ibrav, 'green')
    except BaseException:
        # Any failure (typically termcolor not installed) falls back to plain text.
        print ('{0}\n{1}\n{2}\n{3}'.format(Note, Usage, Alert, def_ibrav))

 
def main(args,check_outdir='Sanity_check'):
    """Convert a POSCAR into a pw.x input file and write sanity-check files.

    Steps:
    1. Load ``args.poscar`` and choose the k-mesh.
    2. Determine ibrav, optionally overridden by the user.
    3. Build the connect matrix.
    4. Standardize the cell with phonopy.
    5. Write the primitive cell to ``check_outdir``.
    6. Print the pw.x input to stdout and write it to ``args.filpw``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options as produced by ``get_args``.
    check_outdir : str
        Directory for intermediate files (POSCAR_for_symm_analysis,
        POSCAR_standardized, POSCAR_Primitive). Created if missing.
    """
    def _red(text):
        """Return ``text`` coloured red when termcolor is available."""
        try:
            return colored(text,"red")
        except NameError:
            # The optional termcolor import at the top of the file failed.
            return text

    # basename, not lstrip('./'): lstrip removes characters, not a prefix,
    # and would mangle absolute paths such as '/usr/.../v2qe.py'.
    print ('\nrunning the script {0}\n'.format(os.path.basename(__file__)))
    verbose_head(args.verbose_notes)

    struct = cryst_struct.load_poscar(args.poscar)
    print ('VASP structure from {}'.format(args.poscar))

    if args.kmesh=='auto': kgrid = ('{:3d} '*3).format(*tuple(struct.get_kmesh(0.02)))
    else: kgrid=args.kmesh

    pw_setup_dic={ "prefix"     : args.prefix,
                   "ecutwfc"    : args.ecutwfc,
                   "ecutrho"    : args.ecutrho,
                   "pseudo_dir" : args.pseudo_dir,
                   "upf"        : args.upf,
                   "kmesh"      : kgrid,
                   "kshift"     : args.kshift,
                   # forward the --calculation/--outdir options to write_pwi
                   "calculation": args.calculation,
                   "outdir"     : args.outdir}

    # All intermediate files below are written into check_outdir.
    os.makedirs(check_outdir, exist_ok=True)
    filposcar = '{}/POSCAR_for_symm_analysis'.format(check_outdir)
    ibrav,_brav,_center = struct.get_ibrav(symmprec=args.symmprec,filposcar=filposcar)

    try:
        ldef = args.redefine_ibrav
    except AttributeError:
        # args lacks the option (e.g. a hand-built namespace): ask interactively.
        ldef = input(_red("\nDo you want to define ibrav manually?(n/y, default: n)\n"))
    if ldef=="y": ibrav=input(_red("ibrav = "))
    ibrav=int(ibrav)

    connect=get_connect(ibrav,struct.latt_param())
    print ( "connect matrix\n")
    print ((('{:7.4f} '*3+'\n')*3+'\n').format(*tuple(connect.flatten())))

    file_std = '{}/POSCAR_standardized'.format(check_outdir)
    generate_standard_poscar_by_phonopy(poscar=args.poscar,symprec=args.symmprec,file_std_poscar=file_std)
    struct_std = cryst_struct.load_poscar(file_std)
    # A user-defined ibrav applies to the input cell, not the standardized one.
    if ldef=='y':struct_std=struct
    # Result unused here; the call is kept in case find_celldm updates
    # struct_std's state -- TODO confirm whether it has side effects.
    struct_std.find_celldm(ibrav=ibrav)
    print ("Lattice Parameters of standardized POSCAR:")
    struct_std.print_latt_param()

    struct_pm = struct_std.build_supercell(connect)
    struct_pm._system='primitive cell'
    file_prim = "{}/POSCAR_Primitive".format(check_outdir)
    struct_pm.write_poscar_head(filename=file_prim)
    struct_pm.write_poscar_atoms(filename=file_prim,mode='a')

    # Show the input on the terminal, then write the actual pw.x input file.
    write_pwi(pw_setup_dic,ibrav,struct_std,struct_pm,filename=None)
    write_pwi(pw_setup_dic,ibrav,struct_std,struct_pm,filename=args.filpw)


# One-line description printed when the script starts.
desc_str = 'Convert POSCAR into QE pwscf input'

if __name__=='__main__':
    verbose_pkg_info(__version__)
    parser, args=get_args()
    try:
        cprint (desc_str,'green')
    except NameError:
        # termcolor is optional; its import at the top of the file may have failed.
        print (desc_str)
    main(args)
