#!/usr/bin/env python


#==============================================================#
#                                                              #
#  File:       pysc.py                                         #
#  Usage:      command-line driver for crystal structure tasks #
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>            #
#  Date:       Jun 03, 2023                                    #
#                                                              #
#==============================================================#


import os
import numpy as np
import warnings
import itertools
import copy 
from pysupercell import __version__
from pysupercell.pysupercell import *
from pysupercell.arguments import str2bool, add_control_arguments, add_io_arguments, add_structure_arguments
 

def get_args(desc_str):
    """Build the pysupercell command-line parser and parse the command line.

    The shared control, I/O and structure options are registered by the
    helpers from ``pysupercell.arguments``; the options listed below are the
    ones specific to this script.

    NOTE: several options use ``type=eval`` so that tuples such as ``0,1`` can
    be passed directly. eval executes arbitrary Python, so never feed this CLI
    untrusted input.

    Parameters
    ----------
    desc_str : str
        Description/usage text shown by ``--help``.

    Returns
    -------
    (argparse.ArgumentParser, argparse.Namespace)
        The parser and the parsed arguments.
    """
    import argparse
    parser = argparse.ArgumentParser(prog='pysupercell', description = desc_str)
    for register in (add_control_arguments, add_io_arguments, add_structure_arguments):
        register(parser)

    script_options = (
        ('--strain',      dict(type=eval,     default=0,       help='magnitude of strain')),
        ('--strain_dirs', dict(type=eval,     default=None,    help='directions for strain, along crystal axes')),
        ('--angle',       dict(type=float,    default=0,       help='angle to rotate the crystal axes')),
        ('--symmprec',    dict(type=float,    default=1E-4,    help='accuracy to find crystal symmetry in Angstrom')),
        ('--case',        dict(type=str,      default='case',  help='case name for wien2k input')),
        ('--kgrid',       dict(type=float,    default=0.02,    help='density of kgrid, in unit of 2pi/Angs')),
        ('--idir_shift',  dict(type=int,      default=2,       help='latt vector index, for shifting')),
        ('--shift',       dict(type=float,    default=1,       help='atom shift distance along certain direction')),
        ('--to_home',     dict(type=str2bool, default=True,    help='shift atom coordinates to home cell or not')),
        ('--inv_center',  dict(type=eval,     default=(0,0,0), help='Inversion center to transform the structure')),
        ('--nimages',     dict(type=int,      default=5,       help='Number of images to interpolate two structures')),
    )
    for flag, options in script_options:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()
    return parser, parsed



def get_twin_structures(poscar1,poscar2):
    """Load two POSCAR files that are expected to describe the same atoms.

    The two structures must list the same species with the same counts and
    the same total number of atoms. The atom order itself is not checked,
    which is why a reminder is printed.

    Parameters
    ----------
    poscar1, poscar2 : str
        Paths of the two POSCAR files.

    Returns
    -------
    (struct_1, struct_2)
        The structures loaded with ``cryst_struct.load_poscar``.

    Raises
    ------
    FileNotFoundError
        If either file does not exist.
    ValueError
        If species, counts or total number of atoms differ.
    """
    print ('Please make sure that all atoms are ordered in the same sequence!\n')
    print ('Structure 1 from file {}'.format(poscar1))
    print ('Structure 2 from file {}'.format(poscar2))
    # Explicit raises instead of assert: asserts are stripped under `python -O`.
    for fil in (poscar1, poscar2):
        if not os.path.isfile(fil):
            raise FileNotFoundError('cannot find {}'.format(fil))
    struct_1=cryst_struct.load_poscar(poscar1)
    struct_2=cryst_struct.load_poscar(poscar2)
    # np.array_equal returns False for sequences of different length, whereas
    # an element-wise == on numpy arrays would fail to broadcast.
    if not np.array_equal(struct_1._species, struct_2._species):
        raise ValueError('Twin-structure error: species inconsistent!')
    if not np.array_equal(struct_1._counts, struct_2._counts):
        raise ValueError('Twin-structure error: No. of atoms inconsistent!')
    if struct_1._natom != struct_2._natom:
        raise ValueError('Twin-structure error numbers of atoms inconsistent')
    return struct_1, struct_2


def sort_structs_atoms(poscar1,poscar2,ncell=1,boundary_condition=(1,1,0)):
    """Reorder the atoms of structure 2 so that they follow the order of structure 1.

    For every atom i of structure 2, find the nearest atom j of structure 1
    carrying the same chemical symbol. The distance is the minimum over the
    lattice translations returned by ``gen_Rvecs(ncell, boundary_condition)``.
    Atom i is then stored at index j, shifted by the translation that realizes
    that minimum. After this, structure 2 can be compared or interpolated with
    structure 1 atom by atom. The sorted structure is written to
    ``POSCAR_sorted``.

    Parameters
    ----------
    poscar1 : str
        Reference structure, which defines the target atom order.
    poscar2 : str
        Structure whose atoms are reordered.
    ncell : int
        Forwarded to ``gen_Rvecs`` (presumably the number of neighbouring
        cells per direction; confirm against gen_Rvecs).
    boundary_condition : sequence of 3 ints
        Forwarded to ``gen_Rvecs``. The default (1,1,0) presumably means
        periodic along a and b only; confirm against gen_Rvecs.

    Returns
    -------
    (struct_1, struct_2)
        struct_2 has its ``_pos`` and ``_pos_cart`` replaced by the sorted
        coordinates.
    """
    struct_1, struct_2 = get_twin_structures(poscar1,poscar2)
    # One translation set is used for both the search and the final shift, so
    # the stored image indices always refer to the same vectors.
    Rvecs = np.asarray(gen_Rvecs(ncell, list(boundary_condition)))
    cell = struct_2._cell
    pos_1 = np.asarray(struct_1._pos)
    pos_2 = np.array(struct_2._pos, dtype=float)   # private copy of the unsorted positions
    symbols_1 = np.asarray(struct_1._symbols)
    nat = struct_2._natom

    sort_idx = np.zeros(nat, int)   # sort_idx[i]: partner of struct-2 atom i in struct 1
    sort_iR = np.zeros(nat, int)    # sort_iR[i]: index of the translation realizing that match
    for iat in range(nat):
        # dists[j, r] = |(p1_j - p2_iat + R_r) . cell|
        disp = pos_1[:, np.newaxis, :] - pos_2[iat] + Rvecs[np.newaxis, :, :]
        dists = np.linalg.norm(np.dot(disp, cell), axis=-1)
        # only atoms of the same element are acceptable partners
        dists[symbols_1 != struct_2._symbols[iat]] = np.inf
        # argmin over the flattened (j, r) grid: first j, then first r, at the global minimum
        jat, iR = np.unravel_index(np.argmin(dists), dists.shape)
        sort_idx[iat] = jat
        sort_iR[iat] = iR

    if len(np.unique(sort_idx)) != nat:
        warnings.warn('sort_structs_atoms: nearest-atom matching is not one-to-one; '
                      'unmatched atoms keep their original positions in POSCAR_sorted')

    # Atom iat of struct 2 moves to slot sort_idx[iat] (its partner's slot in struct 1),
    # translated to the image nearest that partner: p2_iat - R.
    sorted_pos = pos_2.copy()
    sorted_pos[sort_idx] = pos_2 - Rvecs[sort_iR]
    struct_2._pos = sorted_pos
    struct_2._pos_cart = np.dot(sorted_pos, cell)
    struct_2.write_poscar_head(filename='POSCAR_sorted')
    struct_2.write_poscar_atoms(filename='POSCAR_sorted',mode='a')
    return struct_1,struct_2
 


def compare_structs(poscar1,poscar2):
    """Print an atom-by-atom comparison of two structures and return the distances.

    For each atom index i, the reported distance is the smallest Cartesian
    distance between atom i of structure 1 and any lattice image of atom i of
    structure 2. The images are built from the translations of ``gen_Rvecs()``.
    The table also lists the Cartesian displacement (structure 1 minus the
    nearest image). It is followed by the total and the maximum distance.

    Parameters
    ----------
    poscar1, poscar2 : str
        POSCAR files of the two structures (see ``get_twin_structures``).

    Returns
    -------
    numpy.ndarray
        Per-atom minimum distances, in Angstrom.
    """
    print ('\nComparing two crystal structures\n')
    st_a, st_b = get_twin_structures(poscar1,poscar2)

    double_rule = '='*60
    header_fmt = '{:4s} '*3 + '{:>8s} '*4
    print ('\n' + double_rule)
    print (header_fmt.format('idx','st1','st2','dist (Ang)','dx','dy','dz'))
    print ('-'*60)

    shifts = gen_Rvecs()
    natom = st_a._natom
    diff = np.zeros(natom)
    for iat in range(natom):
        candidates = np.dot(st_b._pos[iat] + shifts, st_b._cell)
        deltas = st_a._pos_cart[iat] - candidates
        lengths = np.linalg.norm(deltas, axis=1)
        shortest = np.min(lengths)
        nearest = np.flatnonzero(lengths == shortest)[0]
        diff[iat] = shortest
        label = ' {:<4d} {:4s} {:4s} '.format(iat+1, st_a._symbols[iat], st_b._symbols[iat])
        values = ('{:8.4f} '*4).format(diff[iat], *deltas[nearest])
        print (label + ' ' + values)

    print (double_rule)
    print ('{0:14s}'.format('Total dist : ')+'   {:8.4f}'.format(np.sum(diff)))
    print ('{0:14s}'.format('Max   dist : ')+'   {:8.4f}\n'.format(np.max(diff)))
    return diff


def interpolate_structs(poscar1,poscar2,nimages=5,outdir='interpolated_images',f_archive='XDATCAR',sort_atom=True):
    """Generate ``nimages`` linearly interpolated images between two structures.

    Both the lattice vectors and the Cartesian atomic positions are
    interpolated linearly. The two end structures themselves are not
    included. Image k (k = 1..nimages) is written to
    ``<outdir>/POSCAR_im_<k>``. All images are also collected in
    ``f_archive`` as "Direct configuration = k" blocks.

    Parameters
    ----------
    poscar1, poscar2 : str
        Initial and final structures.
    nimages : int
        Number of intermediate images; must be >= 1.
    outdir : str
        Output directory. An existing directory with this name is removed first.
    f_archive : str
        File collecting all images.
    sort_atom : bool
        If True, first reorder the atoms of poscar2 to follow poscar1 using
        ``sort_structs_atoms``, which writes POSCAR_sorted, and interpolate
        towards that file instead.

    Returns
    -------
    list
        The interpolated structures, in order.

    Raises
    ------
    ValueError
        If ``nimages < 1``.
    """
    import shutil

    print ('\nInterpolating images between two crystal structures\n')
    # validate before anything is written or deleted
    if nimages < 1:
        raise ValueError('Number of images should not be less than 1!')
    if sort_atom:
        sort_structs_atoms(poscar1,poscar2)   # writes POSCAR_sorted
        poscar2 = 'POSCAR_sorted'
    struct_1, struct_2 = get_twin_structures(poscar1,poscar2)

    nsteps = nimages + 1
    diff_cell = (struct_2._cell - struct_1._cell)/nsteps
    diff_pos_cart = (struct_2._pos_cart - struct_1._pos_cart)/nsteps

    # shutil.rmtree instead of os.system('rm -r ...'): portable and no shell involved
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)
    os.mkdir(outdir)

    st_arch = copy.deepcopy(struct_1)
    st_arch._system = 'All images'
    st_arch.write_poscar_head(filename=f_archive)

    st_images = []
    with open(f_archive,'a') as fw:
        fw.write(' '.join(['{:4s}'.format(item) for item in st_arch._species])+'\n')
        fw.write(' '.join(['{:4d}'.format(item) for item in st_arch._counts])+'\n')
        for im in range(1, nimages+1):
            st_im = copy.deepcopy(struct_1)
            st_im._system = 'Image {}'.format(im)
            st_im._cell += diff_cell*im
            st_im._pos_cart += diff_pos_cart*im
            st_im._pos = np.dot(st_im._pos_cart, np.linalg.inv(st_im._cell))
            filpos = os.path.join(outdir, 'POSCAR_im_{}'.format(im))
            st_im.write_poscar_head(filename=filpos)
            st_im.write_poscar_atoms(filename=filpos,mode='a')
            st_images.append(st_im)
            fw.write('Direct configuration = {}\n'.format(im))
            fw.writelines(('{:22.15f} '*3+'\n').format(*pos) for pos in st_im._pos)
    return st_images


# Task names accepted by main(); any other --task value just prints this list.
# 'Select_Dynamics'/'SD' are listed so that main()'s selective-dynamics branch is reachable.
task_list=[
'crystal_info',
'redefine',
'slab',
'tube',
'strain',
'bond',
'rotate_z',
'shift',
'inversion',
'kmesh',
'reset_vacuum',
'bending',
'screw_dislocation',
'wien',
'write_pw',
'compare',
'cmp',
'interpolate',
'Select_Dynamics',
'SD',
'None',
]



# Usage examples printed when no task is given; {0} is this script's file name.
desc_str='''
input example:
    {0}
    {0} --task=redefine --sc1=1,-1,0 --sc2=1,1,0 --sc3=0,0,1
    {0} --task=slab --hkl=121
    {0} --task=tube --chiral_num=2,4
    {0} --task=bond --atom_index=0,1
    {0} --task=wien
    {0} --task=crystal_info
    {0} --task=cmp --poscar1=POSCAR --poscar2=CONTCAR
    {0} --task=interpolate --poscar1=POSCAR --poscar2=CONTCAR --nimages=5
    {0} --task=strain --strain_dirs=0,1 --strain=0.01
    {0} --task=kmesh
    {0} --task=bending --nn=8 --idir_per=1 --idir_bend=2
    {0} --task=shift --idir_shift=2 --shift=10
    {0} --task=screw_dislocation --burgers_vector=[0,0,1] --screw_idir=2 --display_structure=True
'''.format(os.path.basename(__file__))



def main(args,task_list,desc_str):

    print ('task = {}'.format(args.task))
    if args.task==None or args.task=='None': 
        try: cprint (desc_str,'green')
        except: print (desc_str)
        return
    if args.task not in task_list:
        print ('\n{}\navailable tasks:'.format('-'*30))
        print ('\n'.join([task for task in task_list]))
        print ('\n{0}'.format('-'*30))
        return 


    if args.task!='cmp' and args.task!='compare':
        if args.source=='VASP':
            struct = cryst_struct.load_poscar(args.poscar)
        elif args.source=='QE':
            struct = cryst_struct.load_pwscf_in(args.filpw)
        elif args.source=='cif':
            struct = cryst_struct.load_cif(args.filcif)

    if args.task=='crystal_info': 
        struct.verbose_crystal_info()

    elif args.task=='redefine':
        filpos = 'POSCAR_redefine'
        redef_struct = struct.redefine_lattice(args.sc1,args.sc2,args.sc3,args.cell_orientation)
        redef_struct.write_poscar_head(filename=filpos)
        redef_struct.write_poscar_atoms(filename=filpos,mode='a')

    elif args.task=='reset_vacuum':
        redef_struct=copy.copy(struct)
        redef_struct._cell[2,2]=args.vacuum
        central_z = np.average(redef_struct._pos_cart[:,2])
        redef_struct._pos_cart[:,2] -= central_z + redef_struct._cell[2,2]/2
        redef_struct._pos=np.dot(redef_struct._pos_cart,np.linalg.inv(redef_struct._cell))
        redef_struct.shift_atoms_to_home()
        flpos='POSCAR_reset_vacuum'
        redef_struct.write_poscar_head(filename=flpos)
        redef_struct.write_poscar_atoms(filename=flpos,mode='a')
        
    elif args.task=='rotate_z':
        a=struct.latt_param()['a']
        b=struct.latt_param()['b']
        gamma=struct.latt_param()['gamma']
        cell_1=[a*cos((args.angle)/180*np.pi),a*np.sin((args.angle)/180*np.pi),0]
        cell_2=[b*cos((gamma+args.angle)/180*np.pi),b*np.sin((gamma+args.angle)/180*np.pi),0]
        struct._cell=np.array([cell_1,cell_2,struct._cell[2]])
        struct._system='POSCAR_rotate_z'
        flpos='POSCAR_rotate_z'
        struct.write_poscar_head(filename=flpos)
        struct.write_poscar_atoms(filename=flpos,mode='a')

    elif args.task=='inversion':
        inversion_center_cart = np.dot(args.inv_center,struct._cell)
        struct_inv = copy.deepcopy(struct)
        struct_inv._pos_cart = 2 * inversion_center_cart - struct._pos_cart
        struct_inv._pos = np.dot(struct_inv._pos_cart, np.linalg.inv(struct_inv._cell) )
        flpos='POSCAR_inv'
        struct_inv.write_poscar_head(filename=flpos)
        struct_inv.write_poscar_atoms(filename=flpos,mode='a')

    elif args.task=='slab':
        hkl=args.hkl
        if len(hkl)>3:
            exit('we do not allow negative index!')
        h,k,l=list(map(int,hkl))
        struct_slab = struct.build_slab(h,k,l,args.thickness,args.vacuum,args.atom_shift)
        struct_slab._system='slab'
        flpos='POSCAR_slab_{}'.format(hkl)
        struct_slab.write_poscar_head(filename=flpos)
        struct_slab.write_poscar_atoms(filename=flpos,mode='a')

    elif args.task=='tube':
        n,m=args.chiral_num
        struct_tube = struct.build_tube(n,m,negative_R=args.negative_R)
        fltube = "POSCAR_tube_{0}_{1}".format(n,m)
        struct_tube._system=system='{0}_{1}_nanotube'.format(n,m)
        struct_tube.write_poscar_head(filename=fltube)
        struct_tube.write_poscar_atoms(filename=fltube,mode='a')

    elif args.task=='bending':
        struct_bend = struct.build_bending_supercell(args.nn,args.amp,args.idir_per,args.idir_bend,args.central_z)
        struct_bend._system = 'bending struct'
        flpos='POSCAR_bending'
        struct_bend.write_poscar_head(filename=flpos)
        struct_bend.write_poscar_atoms(filename=flpos,mode='a')

    elif args.task=='strain':
        for idir in args.strain_dirs:
            struct._cell[idir] *= 1+args.strain
        struct._pos_cart=np.dot(struct._pos,struct._cell)
        filpos='POSCAR_{0:5.3f}'.format(args.strain)
        struct.write_poscar_head(filename=filpos)
        struct.write_poscar_atoms(filename=filpos)

    elif args.task=='shift':        
        struct.shift_pos(args.idir_shift,args.shift,to_home=args.to_home)
        filpos='POSCAR_shifted'
        struct.write_poscar_head(filename=filpos)
        struct.write_poscar_atoms(filename=filpos)

    elif args.task=='bond':
        i,j=args.atom_index
        cc=np.append(0,np.cumsum(struct._counts))
        ic=i-cc[np.where(i-cc>=0)[0][-1]]+1
        jc=j-cc[np.where(j-cc>=0)[0][-1]]+1
        print ('\n{0}\natomic positions in the home cell\n{0}'.format('-'*80))
        print ('symbol  atom '+('{:>10s} '*6).format('x','y','z','x_cart','y_cart','z_cart'))
        for ii,jj in zip((i,j),(ic,jc)):
            print ('{:>6s} {:5d} '.format(struct._symbols[ii],jj),end='')
            print (' '.join(['{:10.5f}'.format(item) for item in np.append(struct._pos[ii],struct._pos_cart[ii])]))
        print ( '-'*80)
        output='min distance between {}{} and {}{} is: {:8.5f} Angstrom'
        print (output.format(struct._symbols[i],ic,struct._symbols[j],jc,struct.bond_length(i,j)))

    elif args.task=='screw_dislocation':
        burgers_vector = np.dot(args.burgers_vector,struct._cell)
        struct_screw = struct.make_screw_dislocation(burgers_vector,args.screw_center,args.screw_normal,args.screw_idir)
        struct_screw._system = 'screw dislocation structure'
        struct_screw.write_poscar_head(filename='POSCAR_screw')
        struct_screw.write_poscar_atoms(filename='POSCAR_screw',mode='a')
        if args.display_structure: map_data(struct_screw._cell,struct_screw._pos)

    elif args.task=='Select_Dynamics' or args.task=='SD':
        struct.write_poscar_head(filename='POSCAR_sd')
        struct.write_poscar_atoms(selective_dynamics=True,fix_dirs=args.dirs,filename='POSCAR_sd')

    elif args.task=='wien':         
        struct.write_wien2k_struct(args.case,symmprec=args.symmprec)

    elif args.task=='write_pw':
        struct.write_pw_cell()
        struct.write_pw_atoms()

    elif args.task=='cmp' or args.task=='compare':
        compare_structs(args.poscar1,args.poscar2)

    elif args.task=='interpolate':
        st_images = interpolate_structs(args.poscar1,args.poscar2,args.nimages)

    elif args.task=='kmesh':
        struct.writekp(args.kgrid)



if __name__=='__main__':
    # Script entry point: banner, parse the command line, then dispatch the task.
    verbose_pkg_info(__version__)
    _parser, cli_args = get_args(desc_str)
    main(cli_args, task_list, desc_str)
