#!/usr/bin/env python

#======================================================
# A parallel (MPI) version of the serial supercell-matching script:
# finds matched supercells for two 2D crystals and
# writes all found superstructures
# in parallel.
# Author: Shunhong Zhang
# szhang2@ustc.edu.cn
# Date: Oct 31, 2019
#=======================================================

from __future__ import print_function
import os
import shutil
import mpi_tools as mt
from pysupercell import __version__
import pysupercell.pysupercell as psc
from pysupercell.superlattice import *


def build_superlattice(scale_1,sc_1,scale_2,sc_2,struct_1,struct_2,vacuum=15,interlayer_gap=3):
    """Stack supercells of two 2D crystals into one bilayer superlattice.

    struct_1 is expanded by sc_1 and struct_2 by sc_2 (via their
    ``build_supercell`` methods). The result uses the supercell lattice of
    struct_2. Layer 1 keeps the fractional coordinates it has in its own
    supercell, so any in-plane mismatch between the two supercells is
    absorbed by layer 1.

    Along z, layer 1 is placed on top of layer 2 with ``interlayer_gap``
    between them. The bilayer is centred in the cell, and the c length is
    chosen so that ``vacuum`` of empty space remains (half above, half below).

    Parameters
    ----------
    scale_1, scale_2 :
        Labels of the matched supercells; used here only to build the
        system name.
    sc_1, sc_2 :
        Supercell specifications passed to ``build_supercell`` of struct_1 /
        struct_2 respectively.
    struct_1, struct_2 : psc.cryst_struct
        Top and bottom layer, respectively.
    vacuum : float
        Total vacuum thickness along z (presumably in Angstrom, the POSCAR
        length unit -- TODO confirm).
    interlayer_gap : float
        Distance between the lowest atom of layer 1 and the highest atom of
        layer 2 (same unit as ``vacuum``).

    Returns
    -------
    psc.cryst_struct
        The combined structure, with ``system`` set to
        ``'POSCAR_<scale_1>_<scale_2>_<natoms>atoms'``.

    Notes
    -----
    Only ``_cell[2,2]`` is rewritten, which assumes the c vector points
    along z with no in-plane components -- TODO confirm for general inputs.
    NOTE(review): if ``build_supercell`` shares arrays with its source
    structure, the in-place edits below also modify struct_1 / struct_2.
    """
    sc_struct_1 = struct_1.build_supercell(sc_1)
    sc_struct_2 = struct_2.build_supercell(sc_2)
    # No copy: sc_cell IS sc_struct_2._cell, so the c-length update below
    # also changes layer 2's cell (which makes the later reassignment redundant).
    sc_cell    = sc_struct_2._cell
    # Layer thicknesses = z extent of each layer's atoms.
    thick_1 = np.max(sc_struct_1._pos_cart[:,2]) - np.min(sc_struct_1._pos_cart[:,2])
    thick_2 = np.max(sc_struct_2._pos_cart[:,2]) - np.min(sc_struct_2._pos_cart[:,2])
    # c length = slab (both layers + gap) + vacuum.
    sc_cell[2,2] = thick_1 + thick_2 + interlayer_gap + vacuum
    sc_struct_1._cell[2,2] = sc_cell[2,2]
    sc_struct_2._cell[2,2] = sc_cell[2,2]
    # z of the top of layer 1 and the bottom of layer 2, chosen so that the
    # whole slab is centred at c/2.
    top =  (thick_1+thick_2+interlayer_gap)/2+sc_cell[2,2]/2
    bot = -(thick_1+thick_2+interlayer_gap)/2+sc_cell[2,2]/2
    # Shift layer 1 so its highest atom sits at `top`, layer 2 so its lowest
    # atom sits at `bot`; the space between them is then interlayer_gap.
    sc_struct_1._pos_cart[:,2] +=  top-np.max(sc_struct_1._pos_cart[:,2])
    sc_struct_2._pos_cart[:,2] +=  bot-np.min(sc_struct_2._pos_cart[:,2])
    # Recompute fractional coordinates from the shifted Cartesian ones. Layer
    # 1 uses its OWN cell here, but the result is placed in layer 2's cell.
    sc_struct_1._pos = np.dot(sc_struct_1._pos_cart, np.linalg.inv(sc_struct_1._cell))
    sc_struct_2._pos = np.dot(sc_struct_2._pos_cart, np.linalg.inv(sc_struct_2._cell))
    # Concatenate species/counts/positions: layer 1 first, then layer 2.
    sc_species = np.append(sc_struct_1._species,sc_struct_2._species)
    # NOTE(review): sc_symbols is computed but never used below.
    sc_symbols = np.append(sc_struct_1._symbols,sc_struct_2._symbols)
    sc_counts  = np.append(sc_struct_1._counts,sc_struct_2._counts)
    sc_pos     = np.append(sc_struct_1._pos,sc_struct_2._pos,axis=0)
    filpos='POSCAR_'+str(scale_1)+'_'+str(scale_2)+'_'+str(len(sc_pos))+'atoms'
    return psc.cryst_struct(sc_cell,sc_species,sc_counts,sc_pos,system=filpos)



def mpi_find_match_supercell(scale_set,cell_1,cell_2,tolerance,min_sc_angle,max_sc_angle):
    """Search for matched supercells in parallel over MPI ranks.

    ``scale_set`` is split into contiguous slices with ``mt.assign_task``.
    Each rank runs ``find_match_sc`` on its own slice. The per-rank results
    are gathered on rank 0 and concatenated into one flat list, which is
    then broadcast so that every rank returns the same list.

    Parameters
    ----------
    scale_set : sequence
        Candidate scale pairs (e.g. from ``find_match_area``). The program
        exits if it is empty.
    cell_1, cell_2 :
        Lattice matrices of the two crystals, passed through to
        ``find_match_sc``.
    tolerance, min_sc_angle, max_sc_angle :
        Matching criteria, passed through to ``find_match_sc``.

    Returns
    -------
    list
        All matches found by any rank, ordered by rank (identical on all ranks).
    """
    import sys

    if len(scale_set) == 0:
        # sys.exit rather than the site-provided ``exit`` helper, which is not
        # guaranteed to exist (e.g. under ``python -S`` or in frozen builds).
        sys.exit('Not matched supercell found')
    comm, size, rank, node = mt.get_mpi_handles()
    start, last = mt.assign_task(len(scale_set), size, rank)
    if start < last:
        local_matches = find_match_sc(scale_set[start:last],
                                      cell_1, cell_2, tolerance, min_sc_angle, max_sc_angle)
    else:
        # More ranks than tasks: this rank has nothing to search.
        local_matches = []
    # gather is a collective operation: EVERY rank must call it, otherwise
    # the root blocks forever waiting for the others. It also synchronizes
    # the ranks, so no separate barrier is required.
    gathered = comm.gather(local_matches, root=0)
    match_sc_set = None
    if rank == 0:
        # Flatten the list of per-rank lists into one list of matches
        # (assumes find_match_sc returns a list -- TODO confirm).
        match_sc_set = [match for chunk in gathered for match in chunk]
    return comm.bcast(match_sc_set, root=0)
    

def mpi_build_all_superlattice(match_sc_set,struct_1,struct_2,vacuum=15,interlayer_gap=3,pos_type='direct',outdir='matched_structures'):
    """Build every matched superlattice and write it as a POSCAR file, in parallel.

    Rank 0 deletes and recreates ``outdir``. The matches are then split over
    the ranks with ``mt.assign_task``. Each rank builds its share with
    ``build_superlattice`` and writes one file per match into ``outdir``,
    named after the structure's system name plus a global ``_sc<N>`` index
    (1-based).

    Parameters
    ----------
    match_sc_set : sequence of (scale_1, sc_1, scale_2, sc_2)
        Matched supercells; must be identical on all ranks. The program
        exits if it is empty.
    struct_1, struct_2 :
        Top and bottom layer structures.
    vacuum, interlayer_gap :
        Passed through to ``build_superlattice``.
    pos_type : str
        Coordinate style passed as ``postype`` to ``write_poscar_atoms``.
    outdir : str
        Output directory; any existing directory of that name is removed.
    """
    import sys

    if len(match_sc_set) == 0:
        # sys.exit rather than the site-provided ``exit`` helper, which is not
        # guaranteed to exist (e.g. under ``python -S`` or in frozen builds).
        sys.exit()
    comm, size, rank, node = mt.get_mpi_handles()
    if rank == 0:
        print('\nBuilding superlattices, parallel on {0} cores'.format(size))
        if os.path.isdir(outdir):
            shutil.rmtree(outdir)
        os.mkdir(outdir)
    # All ranks must wait until rank 0 has recreated outdir. Without this
    # barrier, other ranks may write before the directory exists, or have
    # their freshly written files deleted by rank 0's rmtree.
    comm.barrier()
    start, last = mt.assign_task(len(match_sc_set), size, rank)
    for offset, (scale_1, sc_1, scale_2, sc_2) in enumerate(match_sc_set[start:last]):
        superlatt = build_superlattice(scale_1, sc_1, scale_2, sc_2, struct_1, struct_2,
                                       vacuum=vacuum, interlayer_gap=interlayer_gap)
        # Global 1-based index (slice start + local offset), so file names
        # are unique across ranks.
        superlatt._system += '_sc{0}'.format(start + offset + 1)
        filename = os.path.join(outdir, superlatt._system)
        superlatt.write_poscar_head(filename=filename)
        superlatt.write_poscar_atoms(filename=filename, postype=pos_type)
    comm.barrier()
    mt.pprint('done')
        

def main():
    """Command-line entry point.

    Every rank parses the arguments and loads both POSCAR files. Rank 0
    alone searches for matching areas and supercells and writes the match
    log. The results are broadcast to all ranks, and the superlattices are
    then built and written in parallel.
    """
    comm, size, rank, node = mt.get_mpi_handles()
    _, args = get_args()
    is_root = rank == 0

    if is_root:
        psc.verbose_pkg_info(__version__)
        print('Runing script {0}'.format(__file__.rstrip('\n')))
        print('Parallel on {0} cores'.format(size))
        verbose_setup(args)

    layer_top = psc.cryst_struct.load_poscar(args.poscar1)
    layer_bottom = psc.cryst_struct.load_poscar(args.poscar2)

    # Only the root rank performs the (serial) search; the others receive
    # the results through the broadcasts below.
    areas, matches = None, None
    if is_root:
        cell_top = layer_top._cell
        cell_bottom = layer_bottom._cell
        areas = find_match_area(cell_top, cell_bottom, args.maxarea, args.tolerance)
        matches = find_match_supercell(
            areas, cell_top, cell_bottom,
            args.tolerance, args.min_sc_angle, args.max_sc_angle,
        )
        write_supercell_log(matches, cell_top, cell_bottom)

    # Both broadcasts are collective calls, so every rank must reach them.
    areas = comm.bcast(areas, root=0)
    matches = comm.bcast(matches, root=0)

    mpi_build_all_superlattice(
        matches, layer_top, layer_bottom,
        vacuum=args.vacuum, interlayer_gap=args.interlayer_gap,
    )
 

if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
