#!/usr/bin/env python


#===========================================================================#
#                                                                           #
#  File:       qe2v.py                                                      #
#  Dependence: crystal_structure.py                                         #
#  Usage:      convert QE pw.x input/output structure into POSCAR          #
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>                         #
#  Date:       Feb 24, 2020                                                 #
#                                                                           #
#===========================================================================#

from __future__ import print_function
import os
import numpy as np
from pysupercell import __version__
from pysupercell.pysupercell import *
import collections
import re


def make_cell_qe(ibrav,celldm,verbosity=True):
    """Build the lattice vectors of a pw.x Bravais lattice from ibrav/celldm.

    The vectors follow the conventions of the pw.x input description
    (INPUT_PW, ``ibrav``) as implemented in QE's ``latgen.f90``; in
    particular ibrav=2/3 give the fcc/bcc *primitive* vectors, which is what
    crystal coordinates in ATOMIC_POSITIONS refer to.

    Parameters
    ----------
    ibrav : int
        pw.x Bravais-lattice index.  Supported: 0, 1, 2, 3, -3, 4, 5, -5,
        6, 7, 8, 9, -9, 91, 10, 11, 12, -12, 13, -13, 14.
    celldm : sequence of float
        Indexed like the Fortran array ``celldm(1:6)``; entry 0 is ignored.
        ``celldm[1]`` is the lattice constant a (Bohr in pw.x),
        ``celldm[2]``/``celldm[3]`` are b/a and c/a, and ``celldm[4..6]`` are
        angle cosines whose meaning depends on ``ibrav`` (e.g. ibrav=-12/-13
        use ``celldm[5]``, ibrav=14 uses all three).
    verbosity : bool
        Echo the input parameters before building the cell.

    Returns
    -------
    numpy.ndarray
        3x3 array whose rows are the lattice vectors, multiplied by
        ``Bohr_to_Angstrom``.  For ibrav=0 this is a zero matrix (the cell
        has to come from a CELL_PARAMETERS card instead).

    Raises
    ------
    SystemExit
        If ``ibrav`` is not one of the supported values.
    """
    if verbosity:
        print ('\nbuild cell from the following parameters:')
        print ('ibrav={0}'.format(ibrav))
        for i in range(1,7):
            print ('celldm({0}) ={1:12.6f}'.format(i,celldm[i]))
        print ('make sure they are correct\n')

    sqrt=np.sqrt
    a=celldm[1]
    b=a*celldm[2]
    c=a*celldm[3]
    cell=np.zeros((3,3),float)

    if ibrav==0:
        print ('free lattice, no celldm implemented!')
    elif ibrav==1:
        cell=np.diag([a]*3)
    elif ibrav==2:
        # cubic F (fcc), primitive vectors
        cell=a/2.*np.array([[-1,0,1],[0,1,1],[-1,1,0]],float)
    elif ibrav==3:
        # cubic I (bcc), primitive vectors
        cell=a/2.*np.array([[1,1,1],[-1,1,1],[-1,-1,1]],float)
    elif ibrav==-3:
        # cubic I (bcc), more symmetric axis choice
        cell=a/2.*np.array([[-1,1,1],[1,-1,1],[1,1,-1]],float)
    elif ibrav==4:
        cell[0,0]=a
        cell[1,0]=a*np.cos(2*np.pi/3)
        cell[1,1]=a*np.sin(2*np.pi/3)
        cell[2,2]=c
    elif abs(ibrav)==5:
        # trigonal R; celldm(4) is the cosine of the angle between any pair of vectors
        cg=celldm[4]
        tx=sqrt((1-cg)/2.)
        ty=sqrt((1-cg)/6.)
        tz=sqrt((1+2*cg)/3.)
        if ibrav==5:
            # 3-fold axis along z
            cell=a*np.array([[tx,-ty,tz],[0,2*ty,tz],[-tx,-ty,tz]])
        else:
            # 3-fold axis along <111>
            u=tz-2*sqrt(2.)*ty
            v=tz+sqrt(2.)*ty
            cell=a/sqrt(3.)*np.array([[u,v,v],[v,u,v],[v,v,u]])
    elif ibrav==6:
        cell=np.diag([a,a,c])
    elif ibrav==7:
        cell=np.array([[a,-a,c],[a,a,c],[-a,-a,c]])/2.
    elif ibrav==8:
        cell=np.diag([a,b,c])
    elif ibrav==9:
        # orthorhombic base-centred (C)
        cell=np.array([[a/2.,b/2.,0],[-a/2.,b/2.,0],[0,0,c]])
    elif ibrav==-9:
        # orthorhombic base-centred, alternate description
        cell=np.array([[a/2.,-b/2.,0],[a/2.,b/2.,0],[0,0,c]])
    elif ibrav==91:
        # orthorhombic one-face base-centred, A-type
        cell=np.array([[a,0,0],[0,b/2.,-c/2.],[0,b/2.,c/2.]])
    elif ibrav==10:
        # orthorhombic face-centred
        cell=np.array([[a/2.,0,c/2.],[a/2.,b/2.,0],[0,b/2.,c/2.]])
    elif ibrav==11:
        # orthorhombic body-centred
        cell=np.array([[a/2.,b/2.,c/2.],[-a/2.,b/2.,c/2.],[-a/2.,-b/2.,c/2.]])
    elif ibrav==12:
        # monoclinic P, unique axis c; celldm(4)=cos(ab)
        cg=celldm[4]
        cell=np.array([[a,0,0],[b*cg,b*sqrt(1-cg**2),0],[0,0,c]])
    elif ibrav==-12:
        # monoclinic P, unique axis b; pw.x takes cos(ac) from celldm(5), not celldm(4)
        cb=celldm[5]
        cell=np.array([[a,0,0],[0,b,0],[c*cb,0,c*sqrt(1-cb**2)]])
    elif ibrav==13:
        # monoclinic base-centred, unique axis c; celldm(4)=cos(gamma)
        cg=celldm[4]
        cell=np.array([[a/2.,0,-c/2.],[b*cg,b*sqrt(1-cg**2),0],[a/2.,0,c/2.]])
    elif ibrav==-13:
        # monoclinic base-centred, unique axis b (QE >= 6.4.1 definition); celldm(5)=cos(beta)
        cb=celldm[5]
        cell=np.array([[a/2.,b/2.,0],[-a/2.,b/2.,0],[c*cb,0,c*sqrt(1-cb**2)]])
    elif ibrav==14:
        # triclinic; celldm(4)=cos(bc), celldm(5)=cos(ac), celldm(6)=cos(ab)
        ca,cb,cg=celldm[4],celldm[5],celldm[6]
        sg=sqrt(1-cg**2)
        cell=np.array([[a,0,0],
                       [b*cg,b*sg,0],
                       [c*cb,c*(ca-cb*cg)/sg,
                        c*sqrt(1+2*ca*cb*cg-ca**2-cb**2-cg**2)/sg]])
    else:
        # SystemExit directly: the site-provided `exit` is not always available
        raise SystemExit('make_cell_qe: ibrav={0} is not supported!'.format(ibrav))
    return cell*Bohr_to_Angstrom


def parse_pwi_nml(filpw='scf.in'):
    """Return the Fortran namelists of the pw.x input ``filpw`` as parsed by f90nml."""
    from f90nml import read as read_namelists
    return read_namelists(filpw)

def parse_pwin_struct(filpw='scf.in'):
    """Read the crystal structure from a pw.x input file.

    The lattice is rebuilt from ``ibrav``/``celldm`` of the &SYSTEM namelist
    with ``make_cell_qe``, or, for ibrav=0, read from the CELL_PARAMETERS
    card.  Species come from the ATOMIC_SPECIES card (``ntyp`` rows) and
    coordinates from the last ATOMIC_POSITIONS card, exactly as written
    (no unit conversion of the coordinates).

    Parameters
    ----------
    filpw : str
        Path of the pw.x input file.

    Returns
    -------
    cell : numpy.ndarray
        3x3 lattice, rows are the vectors.
    species : list of str
        Species labels in ATOMIC_SPECIES order.
    counts : numpy.ndarray
        Number of atoms of each species, aligned with ``species``.
    pos : numpy.ndarray
        (nat, 3) coordinates grouped species by species in ``species`` order
        (POSCAR layout); atoms of one species keep their input order.  Any
        ``if_pos`` flags after the coordinates are dropped.

    Raises
    ------
    SystemExit
        If the file or a required card is missing, a card has too few rows,
        or an atom uses a label not listed in ATOMIC_SPECIES.
    """
    if not os.path.isfile(filpw):
        raise SystemExit('cannot find {0}'.format(filpw))
    with open(filpw) as fh:
        lines=fh.readlines()

    def card(name,nrow):
        # (option, rows) of the last `name` card: the lower-cased option with
        # brackets stripped, and the fields of its first `nrow` data rows.
        # pw.x cards are case-insensitive; blank and comment lines are skipped.
        pattern=re.compile(r'\s*{0}\b'.format(name),re.IGNORECASE)
        heads=[i for i,line in enumerate(lines) if pattern.match(line)]
        if not heads:
            raise SystemExit('cannot find {0} in {1}'.format(name,filpw))
        head=heads[-1]
        option=lines[head].strip()[len(name):].strip().strip('{}()').strip().lower()
        rows=[]
        for line in lines[head+1:]:
            if len(rows)==nrow:
                break
            text=line.strip()
            if not text or text[0] in '#!':
                continue
            rows.append(text.split())
        if len(rows)<nrow:
            raise SystemExit('{0} in {1}: expected {2} rows, found {3}'.format(
                name,filpw,nrow,len(rows)))
        return option,rows

    system=parse_pwi_nml(filpw=filpw)['SYSTEM']
    ibrav=system['ibrav']
    nat=system['nat']
    ntyp=system['ntyp']

    # f90nml yields celldm as a scalar or a list with None for unset entries
    # (or no key at all); copy the set values into a 1-based array that
    # mirrors the Fortran celldm(1:6)
    celldm_=np.atleast_1d(np.array(system.get('celldm',[]),dtype=object))
    idx=np.where(celldm_)[0]
    celldm=np.zeros(7,float)
    celldm[idx+1]=celldm_[idx]

    species=[row[0] for row in card('ATOMIC_SPECIES',ntyp)[1]]
    pos_rows=card('ATOMIC_POSITIONS',nat)[1]
    symbols=[row[0] for row in pos_rows]
    # keep x, y, z only: optional if_pos flags may follow the coordinates
    pos=np.array([row[1:4] for row in pos_rows],float)

    rank={spec:i for i,spec in enumerate(species)}
    unknown=sorted(set(symbols)-set(rank))
    if unknown:
        raise SystemExit('atoms {0} in {1} are not listed in ATOMIC_SPECIES'.format(unknown,filpw))
    cc=collections.Counter(symbols)
    counts=np.array([cc[spec] for spec in species])
    # POSCAR expects atoms grouped by species in `species` order; sorted() is
    # stable, so already-grouped input is left unchanged
    order=sorted(range(nat),key=lambda i: rank[symbols[i]])
    pos=pos[order]

    if ibrav==0:
        # free lattice: vectors come from CELL_PARAMETERS, scaled per its option
        unit,rows=card('CELL_PARAMETERS',3)
        if unit=='angstrom':
            scale=1.0
        elif unit=='bohr':
            scale=Bohr_to_Angstrom
        elif celldm[1]:
            scale=celldm[1]*Bohr_to_Angstrom   # alat given as celldm(1), in Bohr
        elif system.get('a'):
            scale=float(system['a'])           # alat given as A, in Angstrom
        elif unit=='':
            scale=Bohr_to_Angstrom             # no option and no lattice constant: Bohr
        else:
            raise SystemExit('CELL_PARAMETERS {0} in {1} needs celldm(1) or A'.format(unit,filpw))
        cell=scale*np.array([row[:3] for row in rows],float)
    else:
        cell=make_cell_qe(ibrav,celldm)
    return cell,species,counts,pos


def parse_pwout_struct(filpw='rx.out'):
    """Read the last structure printed in a pw.x output file.

    Intended for relax / vc-relax runs, where pw.x reprints the structure
    at every step: the cell is read from the last CELL_PARAMETERS block and
    the coordinates from the last ATOMIC_POSITIONS block, taken as printed
    (no unit conversion of the coordinates).

    Parameters
    ----------
    filpw : str
        Path of the pw.x output file.

    Returns
    -------
    cell : numpy.ndarray
        3x3 lattice (rows are vectors) in Angstrom; the CELL_PARAMETERS
        header may be ``(alat= ...)``, ``(bohr)`` or ``(angstrom)``.
    species : list of str
        Species labels from the last "atomic species" table.
    counts : numpy.ndarray
        Number of atoms of each species, aligned with ``species``.
    pos : numpy.ndarray
        (nat, 3) coordinates grouped species by species in ``species`` order;
        any fix-flag columns after the coordinates are dropped.

    Raises
    ------
    SystemExit
        If the file or a required block is missing, the cell unit is not
        recognised, or an atom label is absent from the species table.
    """
    if not os.path.isfile(filpw):
        raise SystemExit('cannot find {0}'.format(filpw))
    with open(filpw) as fh:
        lines=fh.readlines()

    def last_line(key):
        # index of the last line containing `key`
        hits=[i for i,line in enumerate(lines) if key in line]
        if not hits:
            raise SystemExit('cannot find "{0}" in {1}'.format(key,filpw))
        return hits[-1]

    nat=int(lines[last_line('number of atoms')].split()[-1])
    ntyp=int(lines[last_line('number of atomic types')].split()[-1])

    idx=last_line('CELL_PARAMETERS')
    # header looks like "CELL_PARAMETERS (alat= 10.20000000)" or "CELL_PARAMETERS (angstrom)"
    option=lines[idx].split('CELL_PARAMETERS',1)[1].strip().strip('(){}')
    unit,_,value=option.partition('=')
    unit=unit.strip().lower()
    if unit=='alat':
        scale=float(value)*Bohr_to_Angstrom
    elif unit=='bohr':
        scale=Bohr_to_Angstrom
    elif unit=='angstrom':
        scale=1.0
    else:
        raise SystemExit('unsupported CELL_PARAMETERS unit "{0}" in {1}'.format(unit,filpw))
    cell=scale*np.array([line.split()[:3] for line in lines[idx+1:idx+4]],float)

    # the species table may be printed more than once (e.g. final scf of vc-relax)
    first=last_line('atomic species')+1
    species=[line.split()[0] for line in lines[first:first+ntyp]]

    idx=last_line('ATOMIC_POSITIONS')
    rows=[line.split() for line in lines[idx+1:idx+1+nat]]
    symbols=[row[0] for row in rows]
    # keep x, y, z only: fixed atoms carry extra if_pos columns
    pos=np.array([row[1:4] for row in rows],float)

    rank={spec:i for i,spec in enumerate(species)}
    unknown=sorted(set(symbols)-set(rank))
    if unknown:
        raise SystemExit('atoms {0} in {1} are not in the atomic species table'.format(unknown,filpw))
    # counts must follow `species` order, not first appearance in the positions
    cc=collections.Counter(symbols)
    counts=np.array([cc[spec] for spec in species])
    # POSCAR expects atoms grouped by species; sorted() is stable
    order=sorted(range(nat),key=lambda i: rank[symbols[i]])
    pos=pos[order]
    return cell,species,counts,pos


def get_args(desc_str='qe2v'):
    """Parse the qe2v command line.

    The control and I/O argument groups are the shared ones from
    ``astk.utility.arguments``.  Any ``--source`` other than ``'pwi'`` or
    ``'pwo'`` is replaced by ``'pwi'``.

    Returns the argument parser and the parsed namespace, in that order.
    """
    import argparse
    from astk.utility import arguments
    cli = argparse.ArgumentParser(prog='qe2v.py', description=desc_str)
    for add_group in (arguments.add_control_arguments, arguments.add_io_arguments):
        add_group(cli)
    parsed = cli.parse_args()
    if parsed.source not in ('pwi', 'pwo'):
        parsed.source = 'pwi'
    return cli, parsed



def main(args):
    """Convert the QE structure in ``args.filpw`` and print it.

    ``args.source`` picks the reader: ``'pwi'`` for a pw.x input file and
    ``'pwo'`` for a pw.x output file; any other value aborts.  The structure
    is printed in POSCAR form (direct coordinates) and then, for checking,
    as QE cell/atom cards.
    """
    print ('Running qe2v')
    choices = (
        ('pwi', 'Convert QE PWscf input file into POSCAR', parse_pwin_struct),
        ('pwo', 'Convert QE PWscf output file into POSCAR', parse_pwout_struct),
    )
    for key, banner, reader in choices:
        if args.source == key:
            break
    else:
        exit('--source can only be pwi or pwo for input or output files respectively!')
    print (banner)
    struct_data = reader(filpw=args.filpw)

    print ('Structure from QE file {}'.format(args.filpw))
    print ('Converted structure written in POSCAR\n')

    cell, species, counts, pos = struct_data
    struct = cryst_struct(cell, species, counts, pos)
    struct.write_poscar_head()
    struct.write_poscar_atoms(postype='direct')

    rule = '-' * 80
    print ('\n{0}\nQE structure for double-checking\n{0}'.format(rule))
    struct.write_pw_cell()
    struct.write_pw_atoms()
    print ('{0}\nQE structure for double-checking\n{0}\n'.format(rule))



if __name__=='__main__':
    # Script entry point.  verbose_pkg_info comes from the pysupercell star
    # import and presumably prints package/version info -- confirm there.
    verbose_pkg_info(__version__)
    # parse the command line, then run the QE -> POSCAR conversion
    parser,args=get_args()
    main(args)


