#!/usr/bin/env python


#===========================================================================#
#                                                                           #
#  File:       match_latt.py                                                #
#  Dependence: parse.py,crystal_structure.py                                #
#  Usage:      find matched superlattices for two given structures          #      
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>                         #
#  Date:       Sep 25, 2019                                                 #
#                                                                           #
#===========================================================================#


from __future__ import print_function
import os
import shutil
import numpy as np
from pysupercell import __version__
import pysupercell.pysupercell as psc
from pysupercell.superlattice import *


def build_superlattice(scale_1,sc_1,scale_2,sc_2,struct_1,struct_2,vacuum=15,interlayer_gap=3):
    """
    Stack one supercell of each input structure along z and return the
    combined structure.

    The supercell of struct_1 is placed on top and the supercell of struct_2
    at the bottom, separated by interlayer_gap; the pair is centred in a cell
    whose [2,2] component is thick_1 + thick_2 + interlayer_gap + vacuum.

    Parameters
    ----------
    scale_1, scale_2 : only used to label the returned structure
        ('POSCAR_<scale_1>_<scale_2>_<natom>atoms').
    sc_1, sc_2 : supercell specifications, passed unchanged to
        struct_1.build_supercell / struct_2.build_supercell.
    struct_1, struct_2 : structures providing build_supercell(); the objects
        it returns must expose _cell, _pos_cart, _pos, _species and _counts.
    vacuum : length added to the cell [2,2] component on top of the two slab
        thicknesses and the gap.
    interlayer_gap : z distance between the lowest atom of the upper slab and
        the highest atom of the lower slab.

    Returns
    -------
    psc.cryst_struct built from the cell of the struct_2 supercell (with the
    new [2,2] component), the concatenated species/counts and the
    concatenated fractional positions.

    Notes
    -----
    * Only cell[2,2] is modified, i.e. the third lattice vector is presumably
      expected to lie along z -- TODO confirm with the callers.
    * The fractional positions of each slab are computed with respect to its
      own supercell cell, while the returned structure uses the cell of the
      struct_2 supercell.
    * NOTE(review): the two supercells are modified in place; this assumes
      build_supercell() returns new objects rather than the inputs -- confirm.
    """
    sc_struct_1 = struct_1.build_supercell(sc_1)
    sc_struct_2 = struct_2.build_supercell(sc_2)

    # Slab thicknesses: extent of the atomic z coordinates.
    thick_1 = np.max(sc_struct_1._pos_cart[:,2]) - np.min(sc_struct_1._pos_cart[:,2])
    thick_2 = np.max(sc_struct_2._pos_cart[:,2]) - np.min(sc_struct_2._pos_cart[:,2])
    stack_height = thick_1 + thick_2 + interlayer_gap

    # Work on a copy so that the returned structure does not share its cell
    # array with the temporary supercell of struct_2.
    sc_cell = sc_struct_2._cell.copy()
    sc_cell[2,2] = stack_height + vacuum
    sc_struct_1._cell[2,2] = sc_cell[2,2]
    sc_struct_2._cell[2,2] = sc_cell[2,2]

    # z limits of the stacked pair, centred at half of the new cell height.
    top =  stack_height/2+sc_cell[2,2]/2
    bot = -stack_height/2+sc_cell[2,2]/2

    # Slab 1 hangs from the top limit, slab 2 sits on the bottom limit.
    sc_struct_1._pos_cart[:,2] +=  top-np.max(sc_struct_1._pos_cart[:,2])
    sc_struct_2._pos_cart[:,2] +=  bot-np.min(sc_struct_2._pos_cart[:,2])

    # Refresh the fractional coordinates after the shift and the cell change.
    sc_struct_1._pos = np.dot(sc_struct_1._pos_cart, np.linalg.inv(sc_struct_1._cell))
    sc_struct_2._pos = np.dot(sc_struct_2._pos_cart, np.linalg.inv(sc_struct_2._cell))

    sc_species = np.append(sc_struct_1._species,sc_struct_2._species)
    sc_counts  = np.append(sc_struct_1._counts,sc_struct_2._counts)
    sc_pos     = np.append(sc_struct_1._pos,sc_struct_2._pos,axis=0)

    filpos = 'POSCAR_{0!s}_{1!s}_{2!s}atoms'.format(scale_1,scale_2,len(sc_pos))
    return psc.cryst_struct(sc_cell,sc_species,sc_counts,sc_pos,system=filpos)

def build_all_superlattice(match_sc_set,struct_1,struct_2,vacuum=15,interlayer_gap=3):
    """
    Build one superlattice for every entry of match_sc_set.

    Each entry is unpacked as (scale_1, sc_1, scale_2, sc_2) and handed to
    build_superlattice together with struct_1, struct_2, vacuum and
    interlayer_gap.  The suffix '_sc<k>' (k counted from 1, in iteration
    order) is appended to the _system attribute of each result.

    Returns the list of built structures, in the order of match_sc_set.
    """
    print ("\nBuilding superlattices")
    built = []
    for serial, match in enumerate(match_sc_set, 1):
        scale_1, sc_1, scale_2, sc_2 = match
        lattice = build_superlattice(
            scale_1, sc_1, scale_2, sc_2, struct_1, struct_2,
            vacuum=vacuum, interlayer_gap=interlayer_gap)
        # Tag the structure with its 1-based position in the match list.
        lattice._system += '_sc{0}'.format(serial)
        built.append(lattice)
    print ('done')
    return built

def write_superlattice(superlatt_list,pos_type,outdir='matched_structures'):
    """
    Write every structure of superlatt_list into a freshly created outdir.

    WARNING: if outdir already exists as a directory it is removed with all
    of its contents before the new files are written.

    Parameters
    ----------
    superlatt_list : iterable of structures providing write_poscar_head(),
        write_poscar_atoms() and a _system attribute, which is used as the
        file name inside outdir.
    pos_type : forwarded as the postype argument of write_poscar_atoms().
    outdir : output directory; missing parent directories are created.
    """
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)
    # makedirs (rather than mkdir) so that a nested outdir also works.
    os.makedirs(outdir)
    print ("\nWriting structures")
    for superlatt in superlatt_list:
        # Both writers receive the same file name: head first, then atoms.
        filename = os.path.join(outdir,superlatt._system)
        superlatt.write_poscar_head(filename=filename)
        superlatt.write_poscar_atoms(filename=filename,postype=pos_type)
    print ("done\n")


def main():
    """
    Command-line driver: load two POSCAR files, search for matching
    supercells, log them, then build and write the stacked superlattices.

    get_args, verbose_setup, find_match_area, find_match_supercell and
    write_supercell_log are not defined in this file; presumably they are
    provided by the star import of pysupercell.superlattice -- verify.

    Raises
    ------
    SystemExit
        With exit status 1 when no matching supercell is found (the log is
        still written before exiting).
    """
    _parser, args = get_args()
    verbose_setup(args)
    struct_1 = psc.cryst_struct.load_poscar(args.poscar1)
    struct_2 = psc.cryst_struct.load_poscar(args.poscar2)
    scale_set = find_match_area(struct_1._cell,struct_2._cell,args.maxarea,args.tolerance)
    match_sc_set = find_match_supercell(scale_set,struct_1._cell,struct_2._cell,
                                        args.tolerance,args.min_sc_angle,args.max_sc_angle)
    write_supercell_log(match_sc_set,struct_1._cell,struct_2._cell)
    if len(match_sc_set)==0:
        # The builtin exit() is injected by the site module for interactive
        # use and may be missing (python -S, frozen builds); raising
        # SystemExit gives the same exit status without that dependency.
        raise SystemExit(1)
    superlatt_list = build_all_superlattice(match_sc_set,struct_1,struct_2,
                                            vacuum=args.vacuum,interlayer_gap=args.interlayer_gap)
    write_superlattice(superlatt_list,args.pos_type)


if __name__=='__main__':
    psc.verbose_pkg_info(__version__)

    # Drop only a literal leading './'.  str.lstrip('./') would strip every
    # leading '.' and '/' character and so mangle '../x.py' or absolute paths.
    script_name = __file__[2:] if __file__.startswith('./') else __file__
    print ('\nRunning the script: {0}\n'.format(script_name))
    try:
        # termcolor is optional: coloured banner when it is available.
        from termcolor import cprint
        cprint (desc_str,'blue')
        cprint (notes,'green')
        cprint (alert,'red')
    except Exception:
        # Best-effort fallback to plain text; unlike a bare except this does
        # not swallow KeyboardInterrupt or SystemExit.
        print ('{0}\n{1}\n{2}'.format(desc_str,notes,alert))
    main()
