#!/usr/bin/env python


#=======================================================================
#                                                              
#  File:       pysc.py                                         
#  Usage:      some functionalities to modulate and analyze structures
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>            
#  Date:       Mar 30, 2024                                    
#                                                              
#========================================================================


import copy
import os

from pysupercell.utility.twin_structure import *
from pysupercell.arguments import *
 

def get_args(desc_str, argv=None):
    """Build the pysupercell command-line parser and parse the arguments.

    Parameters
    ----------
    desc_str : str
        Usage text shown by ``--help``. Its line breaks and indentation are
        preserved (RawDescriptionHelpFormatter); the default formatter would
        reflow the example commands into a single paragraph.
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
        exactly as before.

    Returns
    -------
    (argparse.ArgumentParser, argparse.Namespace)
        The parser and the parsed options.
    """
    import argparse
    parser = argparse.ArgumentParser(
        prog='pysupercell',
        description=desc_str,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add_control_arguments(parser)
    add_io_arguments(parser)
    add_structure_arguments(parser)
    add_plotting_arguments(parser)
    args = parser.parse_args(argv)
    return parser, args




# Task names accepted by --task. Every task dispatched by main() is listed,
# including 'Select_Dynamics' and its short alias 'SD', so that
# check_task_validity does not report them as unavailable.
task_list = [
    'crystal_info',
    'redefine',
    'slab',
    'tube',
    'strain',
    'bond',
    'rotate_z',
    'shift',
    'inversion',
    'kmesh',
    'reset_vacuum',
    'bending',
    'screw_dislocation',
    'wien',
    'write_pw',
    'compare',
    'cmp',
    'interpolate',
    'Select_Dynamics',
    'SD',
    'None',
]



# Usage examples shown by --help and when no task is given; {0} is replaced
# by the script's file name (os.path.basename works for both '/' and the
# platform's own path separator).
desc_str = '''
input example:
    {0}
    {0} --task=redefine --sc1=1,-1,0 --sc2=1,1,0 --sc3=0,0,1
    {0} --task=slab --hkl=121
    {0} --task=tube --chiral_num=2,4
    {0} --task=bond --atom_index=0,1
    {0} --task=wien
    {0} --task=crystal_info
    {0} --task=cmp --poscar1=POSCAR --poscar2=CONTCAR
    {0} --task=strain --dirs=0,1 --strain=0.01
    {0} --task=kmesh
    {0} --task=bending --nn=8 --idir_per=1 --idir_bend=2
    {0} --task=shift --idir_shift=2 --shift=10
    {0} --task=screw_dislocation --burgers_vector=[0,0,1] --screw_idir=2 --display_structure=True
'''.format(os.path.basename(__file__))



def check_task_validity(task):
    """Report the requested task and tell whether it can be run.

    Always prints the task name. When no task was given (``None`` or the
    string ``'None'``), the usage examples in ``desc_str`` are printed
    (coloured through ``cprint`` when available). When the task is not in
    ``task_list``, the list of available tasks is printed.

    Returns
    -------
    bool
        True if ``task`` names a runnable task, False otherwise.
    """
    print('task = {}'.format(task))
    if task is None or task == 'None':
        try:
            cprint(desc_str, 'green')
        except Exception:
            # Coloured output is optional: fall back to plain text when
            # cprint is unavailable or fails, without swallowing
            # KeyboardInterrupt/SystemExit like a bare except would.
            print(desc_str)
        return False

    if task not in task_list:
        separator = '-' * 30
        print('\n{}\navailable tasks:'.format(separator))
        print('\n'.join(task_list))
        print('\n{0}'.format(separator))
        return False

    return True


def _load_structure(args):
    """Load the input structure from the format named by ``args.source``.

    'VASP' reads ``args.poscar``, 'QE' reads ``args.filpw`` and 'cif' reads
    ``args.filcif`` with the matching ``cryst_struct`` loader. Any other
    source stops the program with a readable message, instead of leaving the
    structure undefined and failing later with a NameError.
    """
    if args.source == 'VASP':
        return cryst_struct.load_poscar(args.poscar)
    if args.source == 'QE':
        return cryst_struct.load_pwscf_in(args.filpw)
    if args.source == 'cif':
        return cryst_struct.load_cif(args.filcif)
    raise SystemExit('unsupported structure source: {}'.format(args.source))


def _write_poscar(st, filename, **atom_kws):
    """Write ``st`` to ``filename`` as a POSCAR: header first, then atoms appended.

    Extra keyword arguments are forwarded to ``write_poscar_atoms``. The
    atoms are always written with ``mode='a'`` so they never overwrite the
    header that was just written.
    """
    st.write_poscar_head(filename=filename)
    st.write_poscar_atoms(filename=filename, mode='a', **atom_kws)


def main(args, task_list, desc_str):
    """Run the task selected by ``args.task``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (see get_args). ``args.task`` selects the
        operation. The other attributes read depend on the task.
    task_list : list of str
        Task names accepted by this script.
    desc_str : str
        Usage text. It is kept for interface compatibility; the help text is
        printed by check_task_validity.

    Most tasks write their result to a ``POSCAR_*`` file in the working
    directory. 'crystal_info' and 'bond' print to stdout.
    """
    task = args.task
    check_task_validity(task)

    # Nothing to run without a task, and an unknown task has already been
    # reported above. Stop before reading any structure file.
    # 'Select_Dynamics'/'SD' are dispatched below even if task_list omits them.
    if task is None or task == 'None':
        return
    if task not in task_list and task not in ('Select_Dynamics', 'SD'):
        return

    # These tasks read their own pair of structure files.
    if task not in ('interpolate', 'cmp', 'compare'):
        struct = _load_structure(args)

    if task == 'crystal_info':
        struct.verbose_crystal_info()

    elif task == 'redefine':
        redef_struct = struct.redefine_lattice(args.sc1, args.sc2, args.sc3, args.cell_orientation)
        _write_poscar(redef_struct, 'POSCAR_redefine')

    elif task == 'reset_vacuum':
        # Deep copy: a shallow copy would share and mutate struct's arrays.
        redef_struct = copy.deepcopy(struct)
        redef_struct._cell[2, 2] = args.vacuum
        # Move the mean z of the atoms to the middle of the new cell. This
        # matches the previous "- (central_z + c/2)" shift after wrapping
        # into the home cell, but reads as what is intended.
        central_z = np.average(redef_struct._pos_cart[:, 2])
        redef_struct._pos_cart[:, 2] += redef_struct._cell[2, 2] / 2 - central_z
        redef_struct._pos = np.dot(redef_struct._pos_cart, np.linalg.inv(redef_struct._cell))
        redef_struct.shift_atoms_to_home()
        _write_poscar(redef_struct, 'POSCAR_reset_vacuum')

    elif task == 'rotate_z':
        latt = struct.latt_param()
        a, b, gamma = latt['a'], latt['b'], latt['gamma']
        theta_a = (args.angle) / 180 * np.pi
        theta_b = (gamma + args.angle) / 180 * np.pi
        cell_1 = [a * np.cos(theta_a), a * np.sin(theta_a), 0]
        cell_2 = [b * np.cos(theta_b), b * np.sin(theta_b), 0]
        struct._cell = np.array([cell_1, cell_2, struct._cell[2]])
        # Fractional coordinates are unchanged by the rotation. Refresh the
        # Cartesian ones so they match the new cell, as the strain task does.
        struct._pos_cart = np.dot(struct._pos, struct._cell)
        struct._system = 'POSCAR_rotate_z'
        _write_poscar(struct, 'POSCAR_rotate_z')

    elif task == 'inversion':
        inversion_center_cart = np.dot(args.inv_center, struct._cell)
        struct_inv = copy.deepcopy(struct)
        # Point inversion r -> 2*r0 - r about the given fractional centre.
        struct_inv._pos_cart = 2 * inversion_center_cart - struct._pos_cart
        struct_inv._pos = np.dot(struct_inv._pos_cart, np.linalg.inv(struct_inv._cell))
        _write_poscar(struct_inv, 'POSCAR_inv')

    elif task == 'slab':
        hkl = args.hkl
        if len(hkl) > 3:
            raise SystemExit('we do not allow negative index!')
        if len(hkl) < 3:
            # Unpacking into h, k, l below would otherwise crash with ValueError.
            raise SystemExit('hkl must have exactly three indices, e.g. --hkl=121')
        h, k, l = map(int, hkl)
        struct_slab = struct.build_slab(h, k, l, args.thickness, args.vacuum, args.atom_shift)
        struct_slab._system = 'slab'
        _write_poscar(struct_slab, 'POSCAR_slab_{}'.format(hkl))

    elif task == 'tube':
        n, m = args.chiral_num
        struct_tube = struct.build_tube(n, m, negative_R=args.negative_R)
        struct_tube._system = '{0}_{1}_nanotube'.format(n, m)
        _write_poscar(struct_tube, 'POSCAR_tube_{0}_{1}'.format(n, m))

    elif task == 'bending':
        struct_bend = struct.build_bending_supercell(args.nn, args.amp, args.idir_per, args.idir_bend, args.central_z)
        struct_bend._system = 'bending struct'
        _write_poscar(struct_bend, 'POSCAR_bending')

    elif task == 'strain':
        # NOTE(review): the usage text shows --dirs, but the directions are
        # read from args.strain_dirs. Confirm the option name.
        for idir in args.strain_dirs:
            struct._cell[idir] *= 1 + args.strain
        struct._pos_cart = np.dot(struct._pos, struct._cell)
        _write_poscar(struct, 'POSCAR_{0:5.3f}'.format(args.strain))

    elif task == 'shift':
        struct.shift_pos(args.idir_shift, args.shift, to_home=args.to_home)
        _write_poscar(struct, 'POSCAR_shifted')

    elif task == 'bond':
        i, j = args.atom_index
        # Start offset of each block of struct._counts. ic/jc are the 1-based
        # positions of atoms i/j inside their block. This presumably assumes
        # atoms are stored grouped in _counts order.
        cc = np.append(0, np.cumsum(struct._counts))
        ic = i - cc[np.where(i - cc >= 0)[0][-1]] + 1
        jc = j - cc[np.where(j - cc >= 0)[0][-1]] + 1
        print('\n{0}\natomic positions in the home cell\n{0}'.format('-' * 80))
        print('symbol  atom ' + ('{:>10s} ' * 6).format('x', 'y', 'z', 'x_cart', 'y_cart', 'z_cart'))
        for ii, jj in zip((i, j), (ic, jc)):
            print('{:>6s} {:5d} '.format(struct._symbols[ii], jj), end='')
            print(' '.join(['{:10.5f}'.format(item) for item in np.append(struct._pos[ii], struct._pos_cart[ii])]))
        print('-' * 80)
        output = 'min distance between {}{} and {}{} is: {:8.5f} Angstrom'
        print(output.format(struct._symbols[i], ic, struct._symbols[j], jc, struct.bond_length(i, j)))

    elif task == 'screw_dislocation':
        burgers_vector = np.dot(args.burgers_vector, struct._cell)
        struct_screw = struct.make_screw_dislocation(burgers_vector, args.screw_center, args.screw_normal, args.screw_idir)
        struct_screw._system = 'screw dislocation structure'
        _write_poscar(struct_screw, 'POSCAR_screw')

        map_kws = dict(
            marked_positions=None,
            scatter_size=args.scatter_size,
            cmap=args.cmap,
            colorbar_orientation='vertical',
            vmin=0, vmax=0,
            show=args.display_structure,
            grid_x=0, grid_y=0,
        )
        map_data(struct_screw._cell, struct_screw._pos, **map_kws)

    elif task in ('Select_Dynamics', 'SD'):
        _write_poscar(struct, 'POSCAR_sd', selective_dynamics=True, fix_dirs=args.dirs)

    elif task == 'wien':
        struct.write_wien2k_struct(args.case, symmprec=args.symmprec)

    elif task == 'write_pw':
        struct.write_pw_cell()
        struct.write_pw_atoms()

    elif task in ('cmp', 'compare'):
        compare_structs(args.poscar1, args.poscar2)

    elif task == 'interpolate':
        interpolate_structs(args.poscar1, args.poscar2, args.nimages, sort_atoms=args.sort_atoms, n_extrapolate=args.n_extrapolate)

    elif task == 'kmesh':
        struct.writekp(args.kgrid)



if __name__ == '__main__':
    # Script entry point: report package info via verbose_pkg_info before
    # parsing options (so it is shown even when argparse exits for --help),
    # then run the requested task.
    from pysupercell import __version__ as pkg_version
    verbose_pkg_info(pkg_version)
    parser, args = get_args(desc_str)
    main(args, task_list, desc_str)
