#!/usr/bin/env python
#===========================================================================#
#                                                                           #
#  File:       boltz_tools.py                                               #
#  Dependence: numpy, scipy, matplotlib, pysupercell, astk                  #
#  Usage:      pre- and post-processing of BoltzTraP                        #
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>                         #
#  Date:       Sep 23, 2017                                                 #
#                                                                           #
#===========================================================================#

from __future__ import print_function
import numpy as np
import os
from pysupercell import __version__
from pysupercell.pysupercell import *

# Plotting setup. NOTE: text.usetex=True makes matplotlib typeset every label
# with an external LaTeX installation, which must be available at plot time.
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}' #for \text command

# Unit-conversion factors taken from scipy.constants.
from scipy.constants import physical_constants as pc
Ryd_to_eV = pc['Rydberg constant times hc in eV'][0]   # 1 Ry expressed in eV
Bohr_to_Angstrom = pc['Bohr radius'][0]*1e10           # Bohr radius: m -> Angstrom


# Index -> short label maps used when building plot labels and directory names.
dir_dic=dict(enumerate(('x','y')))      # Cartesian direction labels
spin_dic=dict(enumerate(('up','dn')))   # spin channel labels


def write_intrans(case='case',nelect=1,dos_method='HISTO',efermi=0.):
    """Write a BoltzTraP control file ``<case>.intrans`` for the GENE interface.

    Parameters
    ----------
    case : str
        Output prefix; the file written is ``case + '.intrans'``.
    nelect : int
        Number of electrons. It is formatted with ``{:d}``, so it must be an int.
    dos_method : str
        Written verbatim as the last line (e.g. 'HISTO' or 'TETRA').
    efermi : float
        Written verbatim into the Fermi-level field. The inline comment of
        that line labels the field "Ry", while the console message labels it
        "eV". NOTE(review): confirm which unit callers actually pass.

    All other settings (energy grid, lpfac, efcut, Tmax, temperature step,
    band energy range) are fixed values.
    """
    print ('You set nelect={:3d}, efermi = {:6.3f} eV'.format(nelect,efermi))
    print ('Please make sure they are correct, or modify them in the .intrans file manually')
    # The Fermi-level line and its trailing comment are emitted back to back,
    # so the comment ends up on the same physical line as the numbers.
    fermi_fields='{0:<7.4f}  0.0005  0.5  {1:5d}'.format(efermi,nelect)
    contents=(
        'GENE                           # use generic interface\n',
        '0  0  0  0.0                   # iskip (not presently used) idebug setgap shiftgap\n',
        fermi_fields,
        '   # Fermilevel (Ry,energygrid,energy span around Fermilevel,number of electrons\n',
        'CALC                           # CALC (calculate expansion coeff,NOCALC read from file\n',
        '5                              # lpfac,number of latt-points per k-point\n',
        'BOLTZ                          # run mode (only BOLTZ is supported\n',
        '0.15                           # (efcut energy range of chemical potential\n',
        '800.0  50.0                    # Tmax,temperature grid\n',
        '-1.0                           # energyrange of bands given individual DOS output sig_xxx and dos_xxx (xxx is band number\n',
        '{0}\n'.format(dos_method),
    )
    with open(case+'.intrans','w') as out:
        out.writelines(contents)


def get_data(outdir='./',return_h5=False):
    """Read BoltzTraP transport tensors from ``<outdir>/<case>.condtens``.

    ``case`` is the stem of a ``*.struct`` file in ``outdir``. If there are
    several, the last one in sorted order is used, matching the former
    ``ls``-based lookup. The first line of the condtens file is skipped.

    Column layout used here: 0 energy, 1 temperature, 2 carrier
    concentration, 3-11 sigma/tau, 12-20 Seebeck, 21- electronic kappa.
    Each tensor is stored as 9 components (row-major 3x3). Rows are assumed
    grouped by energy, with the temperature varying fastest.

    Parameters
    ----------
    outdir : str
        Directory holding the BoltzTraP output.
    return_h5 : bool
        If True, also write every array to ``boltz.hdf5`` in the current
        working directory and return the h5py File object instead of the
        tuple. NOTE: that object is already closed when returned; reopen
        ``boltz.hdf5`` to read the data.

    Returns
    -------
    tuple
        ``(energy[nedos], temperatures[ntemp], concentration[nedos, ntemp],
        e_cond, seebeck, e_kappa, powerfactor)``. The last four have shape
        ``(nedos, ntemp, 9)``. Seebeck is multiplied by 1e6 and e_kappa is
        divided by 1e15. powerfactor is the 3x3 matrix product of the
        element-wise squared (scaled) Seebeck tensor with sigma/tau.

    Raises
    ------
    IOError
        If ``outdir`` contains no ``*.struct`` file.
    ValueError
        If the row count is not a whole multiple of the number of temperatures.
    """
    import glob

    struct_files=sorted(glob.glob(os.path.join(outdir,'*.struct')))
    if not struct_files:
        raise IOError('no *.struct file found in {0}'.format(outdir))
    # splitext keeps case names that themselves contain dots intact
    case=os.path.splitext(os.path.basename(struct_files[-1]))[0]
    filename=os.path.join(outdir,'{0}.condtens'.format(case))
    # Passing the path lets numpy open *and close* the file.
    data=np.loadtxt(filename,skiprows=1,ndmin=2)
    temperatures=np.array(sorted(set(data[:,1])))
    ntemp=len(temperatures)
    nrow=data.shape[0]
    if ntemp==0 or nrow%ntemp:
        raise ValueError('{0}: {1} data rows cannot be split over {2} temperatures'.format(filename,nrow,ntemp))
    nedos=nrow//ntemp   # integer division: reshape needs an int on Python 3
    energy=data[:,0].reshape(nedos,ntemp)[:,0]
    concentration=data[:,2].reshape(nedos,ntemp)
    e_cond=data[:,3:12].reshape(nedos,ntemp,9)
    seebeck=data[:,12:21].reshape(nedos,ntemp,9)*1e6
    e_kappa=data[:,21:].reshape(nedos,ntemp,9)/1e15
    # Vectorised form of np.dot(S**2, sigma) for every (energy, temperature) pair.
    seebeck_sq=(seebeck.reshape(nedos,ntemp,3,3))**2
    powerfactor=np.matmul(seebeck_sq,e_cond.reshape(nedos,ntemp,3,3)).reshape(nedos,ntemp,9)
    if return_h5:
        import h5py
        datasets=(('energy',energy),
                  ('temperatures',temperatures),
                  ('concentration',concentration),
                  ('e_cond',e_cond),
                  ('seebeck',seebeck),
                  ('e_kappa',e_kappa),
                  ('powerfactor',powerfactor))
        with h5py.File("boltz.hdf5","w") as h5_data:
            for name,values in datasets:
                h5_data.create_dataset(name,data=values,dtype='f8')
        return h5_data   # closed by the with-block; kept for backward compatibility
    return energy,temperatures,concentration,e_cond,seebeck,e_kappa,powerfactor


def plot_data(args,quantity='seebeck',plot_temp=None):
    """Plot a BoltzTraP transport quantity against energy at several temperatures.

    Data come from ``get_data(outdir=args.outdir)``. One subplot is drawn per
    entry ``i`` of ``args.dirs``, showing tensor component ``4*i``. That is the
    diagonal element (0 -> xx, 1 -> yy, 2 -> zz) of the row-major 3x3 tensor.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``outdir``, ``dirs`` (int or sequence of ints) and ``elim``.
        ``elim = (emin, emax)`` in eV sets the energy window when
        emin < emax; otherwise the full energy range is shown.
    quantity : str
        'seebeck' (default), 'e_cond' (sigma/tau) or 'e_kappa'. Any other
        value plots the Seebeck data but is still used as the file name.
    plot_temp : float or sequence of float, optional
        Temperatures in K. Defaults to 300, 400, ..., 800 K. Temperatures not
        present in the data are reported and skipped.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, also saved to a file named ``quantity`` (dpi 600).
        The y range spans the plotted curves inside the energy window.
        sigma/tau and kappa_e start at 0.
    """
    energy,temperatures,_,e_cond,seebeck,e_kappa,_=get_data(outdir=args.outdir)
    # Raw strings: these are LaTeX fragments, not Python escape sequences.
    data=seebeck; symbol='S'; unit=r'\mu V/K'; nonnegative=False
    if quantity=='e_cond':
        data=e_cond; symbol=r'\sigma/\tau'; unit=r'(\Omega\ m\ s)^{-1}'; nonnegative=True
    elif quantity=='e_kappa':
        data=e_kappa; symbol=r'\kappa_e'; unit=r'10^{15}\ W/mKs'; nonnegative=True
    colors=['r','g','b','c','m','purple','k']
    if plot_temp is None or np.size(plot_temp)==0:
        plot_temp=np.linspace(300,800,6)   # default temperatures
    plot_temp=np.atleast_1d(plot_temp).astype(float)
    print ( 'plotting {0:10s}, temperature to plot:{1}'.format(quantity,plot_temp))

    # Match requested temperatures against the data once.
    # np.isclose tolerates float round-off in the file.
    selected=[]   # (colour index, temperature, temperature column in data)
    for index,temp in enumerate(plot_temp):
        match=np.flatnonzero(np.isclose(temperatures,temp))
        if match.size:
            selected.append((index,temp,match[0]))
        else:
            print ('T = {0:.1f} K not found in the data, skipped'.format(temp))

    energy=energy*Ryd_to_eV   # Ry -> eV; returns a new array
    if args.elim[0]<args.elim[1]: xmin,xmax=args.elim[0],args.elim[1]
    else: xmin,xmax=np.min(energy),np.max(energy)
    in_window=np.logical_and(energy>=xmin,energy<=xmax)

    dirs=list(np.atleast_1d(args.dirs))
    fig,axes=plt.subplots(len(dirs),1,sharex=True,sharey=True,squeeze=False)
    axes=list(axes[:,0])
    visible=[]   # parts of the plotted curves inside the energy window
    for ax,idir in zip(axes,dirs):
        icomp=idir*4   # diagonal element of the row-major 3x3 tensor
        for index,temp,itemp in selected:
            curve=data[:,itemp,icomp]
            ax.plot(energy,curve,color=colors[index%len(colors)],ls='-',lw=2,label='{0:4.0f}K'.format(temp))
            visible.append(curve[in_window])
        ax.set_xlim(xmin,xmax)

    if visible:
        values=np.concatenate(visible)
        if values.size:
            ymin=0. if nonnegative else values.min()
            ymax=values.max()
            if ymax>ymin:
                for ax in axes: ax.set_ylim(ymin,ymax)

    axes[-1].set_xlabel('$E-E_f$ ($eV$)')
    plt.setp(axes,ylabel='${0}$ (${1}$)'.format(symbol,unit))
    axes[-1].legend()
    fig.savefig(quantity, dpi=600)
    return fig


def plot_v2dos(case='case'):
    """Plot the v^2-weighted DOS stored in ``<case>.v2dos``.

    The first line of the file is skipped. Column 1 is plotted against
    column 0 (x axis labelled in Ry) as the 'xx' curve.

    Returns the Figure, also saved as 'v2dos' (dpi 400).
    """
    # Passing the path lets numpy open *and close* the file.
    v2dos=np.loadtxt('{0}.v2dos'.format(case),skiprows=1)
    fig=plt.figure(figsize=(6,4))
    ax=fig.add_subplot(111)
    ax.plot(v2dos[:,0],v2dos[:,1],'r-',label='xx')
    ax.legend()
    # Raw strings: LaTeX backslashes are not Python escape sequences.
    ax.set_xlabel(r'$E\ \mathrm{(Ry)}$')
    ax.set_ylabel(r'$v^2\ \mathrm{DOS}$')
    fig.tight_layout()
    fig.savefig('v2dos',dpi=400)
    return fig


def plot_transdos(args,case='case'):
    """Plot the transport DOS from ``<case>.transdos``.

    The first line is a header. Its first token is skipped and the next five
    are parsed as floats (emin, emax, estep and two further values not used
    here); emin, emax and estep are echoed. Column 1 of the remaining rows is
    plotted against column 0.

    Parameters
    ----------
    args : argparse.Namespace
        ``en_unit``: 'Ry' plots column 0 unchanged, 'eV' multiplies it by
        Ryd_to_eV. ``elim``: x range, applied only when elim[1] > elim[0].
    case : str
        File prefix.

    Returns the Figure, also saved as 'transdos' (dpi 400).

    Raises
    ------
    ValueError
        If ``args.en_unit`` is neither 'Ry' nor 'eV'.
    """
    if args.en_unit=='Ry':
        to_unit,xlabel=1.0,r'$E\ \mathrm{(Ry)}$'
    elif args.en_unit=='eV':
        to_unit,xlabel=Ryd_to_eV,r'$E\ \mathrm{(eV)}$'
    else:
        raise ValueError("en_unit must be 'Ry' or 'eV', got {0!r}".format(args.en_unit))
    with open('{0}.transdos'.format(case)) as f:
        header=[float(x) for x in f.readline().split()[1:]]
        emin,emax,estep,_,_=header
        transdos=np.loadtxt(f,ndmin=2)
    print ( emin,emax,estep)
    fig=plt.figure(figsize=(6,4))
    ax=fig.add_subplot(111)
    ax.plot(transdos[:,0]*to_unit,transdos[:,1],'r-',label='xx')
    ax.set_xlabel(xlabel)
    ax.legend()
    if args.elim[1]>args.elim[0]: ax.set_xlim(*tuple(args.elim))
    ax.set_ylabel(r'$\mathrm{Transport\ DOS}$')
    fig.tight_layout()
    fig.savefig('transdos',dpi=400)
    return fig


def plot_wan_tdf(seedname='wannier90'):
    """Plot the transport distribution function from ``<seedname>_tdf.dat``.

    The first 6 lines are skipped. Columns 1, 3 and 6 are plotted against
    column 0 (x axis labelled mu in eV) as the xx, yy and zz curves.

    Returns the Figure, also saved as 'wan_tdf' (dpi 400).
    """
    print ( 'reading tdf from {0}_tdf.dat'.format(seedname))
    # Passing the path lets numpy open *and close* the file.
    data=np.loadtxt('{0}_tdf.dat'.format(seedname),skiprows=6)
    fig=plt.figure(figsize=(6,4))
    ax=fig.add_subplot(111)
    for col,color,label in ((1,'r-','xx'),(3,'g-','yy'),(6,'b-','zz')):
        ax.plot(data[:,0],data[:,col],color,label=label)
    ax.legend(loc='upper right')
    # Raw strings: LaTeX backslashes are not Python escape sequences.
    ax.set_ylabel(r'$\Xi\ \mathrm{(eV \cdot fs/\hbar^2/\AA)}$')
    ax.set_xlabel(r'$\mu\ \mathrm{(eV)}$')
    fig.tight_layout()
    fig.savefig('wan_tdf',dpi=400)
    return fig

def plot_wan_elcond(seedname='wannier90'):
    """Plot the electrical conductivity from ``<seedname>_elcond.dat``.

    The first 3 lines are skipped. Column 0 is the chemical potential (eV)
    and column 1 the temperature (K). Columns 2, 4 and 7 are plotted as the
    xx, yy and zz components, scaled by 1e-6 (axis unit 10^6 S/m).

    With a single temperature, all three components share one panel.
    Otherwise there is one panel per component with one curve per
    temperature.

    Returns the Figure, also saved as 'wan_elcond' (dpi 400).

    Raises
    ------
    ValueError
        If the row count is not a whole multiple of the number of temperatures.
    """
    print ( 'reading elcond from {0}_elcond.dat'.format(seedname))
    # Passing the path lets numpy open *and close* the file.
    data=np.loadtxt('{0}_elcond.dat'.format(seedname),skiprows=3,ndmin=2)
    ntemp=len(set(data[:,1]))
    print ( 'number of temperature points: {0}'.format(ntemp))
    if ntemp>1:
        nrow,ncol=data.shape
        if nrow%ntemp:
            raise ValueError('{0} rows cannot be split over {1} temperatures'.format(nrow,ntemp))
        # Assumes rows are ordered by chemical potential with temperature
        # varying fastest; the result is indexed [itemp, imu, column].
        # Integer division: reshape needs an int on Python 3.
        data=np.transpose(data.reshape(nrow//ntemp,ntemp,ncol),[1,0,2])
    components=((2,'r-','xx'),(4,'g-','yy'),(7,'b-','zz'))   # (column, style, label)
    fig=plt.figure(figsize=(4,6))
    if ntemp==1:
        ax=fig.add_subplot(111)
        for col,style,label in components:
            ax.plot(data[:,0],data[:,col]*1e-6,style,label=label)
        ax.legend(loc='upper right')
        ax.set_ylabel(r'$\sigma\ \mathrm{(10^6\ S/m)}$')
    else:
        for i,(col,_,label) in enumerate(components):
            ax=fig.add_subplot(3,1,i+1)
            for itemp in range(ntemp):
                ax.plot(data[itemp,:,0],data[itemp,:,col]*1e-6,label='{0:<5.1f} K'.format(data[itemp,0,1]),lw=1)
            ax.set_ylabel(r'$\sigma_{'+label+r'}\ \mathrm{(10^6\ S/m)}$')
            if i!=2: ax.set_xticklabels([])
            if i==0: ax.legend(loc='upper right',ncol=1)
    ax.set_xlabel(r'$\mathrm{\mu\ (eV)}$')
    fig.tight_layout()
    fig.savefig('wan_elcond',dpi=400)
    return fig


def plot_wan_seebeck(seedname='wannier90'):
    """Plot the Seebeck coefficient from ``<seedname>_seebeck.dat``.

    The first 6 lines are skipped. Column 0 is the chemical potential (eV)
    and column 1 the temperature (K). Columns 2, 6 and 10 are plotted as the
    xx, yy and zz components, scaled by 1e6 (axis unit uV/K).

    With a single temperature, all three components share one panel.
    Otherwise there is one panel per component with one curve per
    temperature.

    Returns the Figure, also saved as 'wan_seebeck' (dpi 400).

    Raises
    ------
    ValueError
        If the row count is not a whole multiple of the number of temperatures.
    """
    print ( 'reading Seebeck from {0}_seebeck.dat'.format(seedname))
    # Passing the path lets numpy open *and close* the file.
    data=np.loadtxt('{0}_seebeck.dat'.format(seedname),skiprows=6,ndmin=2)
    ntemp=len(set(data[:,1]))
    print ( 'number of temperature points: {0}'.format(ntemp))
    if ntemp>1:
        nrow,ncol=data.shape
        if nrow%ntemp:
            raise ValueError('{0} rows cannot be split over {1} temperatures'.format(nrow,ntemp))
        # Assumes rows are ordered by chemical potential with temperature
        # varying fastest; the result is indexed [itemp, imu, column].
        # Integer division: reshape needs an int on Python 3.
        data=np.transpose(data.reshape(nrow//ntemp,ntemp,ncol),[1,0,2])
    components=((2,'r-','xx'),(6,'g-','yy'),(10,'b-','zz'))   # (column, style, label)
    fig=plt.figure(figsize=(4,6))
    if ntemp==1:
        ax=fig.add_subplot(111)
        for col,style,label in components:
            ax.plot(data[:,0],data[:,col]*1e6,style,label=label)
        ax.legend(loc='upper right')
        ax.set_ylabel(r'$S\ \mathrm{(\mu V/K)}$')
    else:
        for i,(col,_,label) in enumerate(components):
            ax=fig.add_subplot(3,1,i+1)
            for itemp in range(ntemp):
                ax.plot(data[itemp,:,0],data[itemp,:,col]*1e6,label='{0:<5.1f} K'.format(data[itemp,0,1]),lw=1)
            ax.set_ylabel(r'$S_{'+label+r'}\ \mathrm{(\mu V/K)}$')
            if i!=2: ax.set_xticklabels([])
            if i==0: ax.legend(loc='upper right',ncol=1)
    ax.set_xlabel(r'$\mathrm{\mu\ (eV)}$')
    fig.tight_layout()
    fig.savefig('wan_seebeck',dpi=400)
    return fig


def plot_wan_boltzdos(args,seedname='wannier90'):
    """Plot the DOS from ``<seedname>_boltzdos.dat``.

    The first 3 lines are skipped and column 1 is plotted against column 0.
    ``args.elim`` sets the x range when elim[0] < elim[1].

    Returns the Figure, also saved as 'wan_boltzdos' (dpi 400).
    """
    # Passing the path lets numpy open *and close* the file.
    data=np.loadtxt('{0}_boltzdos.dat'.format(seedname),skiprows=3)
    fig=plt.figure(figsize=(6,4))
    ax=fig.add_subplot(111)
    ax.plot(data[:,0],data[:,1],'g-')
    # Raw strings: LaTeX backslashes are not Python escape sequences.
    ax.set_xlabel(r'$\mu\ (eV)$')
    ax.set_ylabel(r'$\mathrm{DOS}$')
    if args.elim[0]<args.elim[1]: ax.set_xlim(args.elim[0],args.elim[1])
    fig.tight_layout()
    fig.savefig('wan_boltzdos',dpi=400)
    return fig


def boltz_pre_process(args):
    """Prepare BoltzTraP input files for ``args.case``.

    Steps:
      1. Write ``<case>.intrans`` via write_intrans.
      2. Write ``<case>.struct`` from ``cryst_struct.load_poscar()``.
      3. Write the WIEN2k-style energy file via astk's band reader.

    For spin-polarized data (``_nspin > 1``), a fresh directory
    ``<case>_spin_<up|dn>`` is created per spin. Each receives copies of the
    .intrans and .struct files renamed to the directory name. Otherwise, a
    fresh directory ``<case>`` is created and every ``<case>.*`` file is
    moved into it.

    NOTE(review): in the spin-polarized branch only .intrans and .struct are
    copied. Confirm where ``_write_wien2k_energy`` puts the spin-resolved
    energy files.
    """
    import astk.core.bands as bands
    import glob
    import shutil
    write_intrans(case=args.case,nelect=args.nelect,dos_method=args.dos_method,efermi=args.efermi)
    struct=cryst_struct.load_poscar()
    struct.write_wien2k_struct(case=args.case)
    _,args1=bands.get_args('bands')
    ebands=bands.ebands(args1)
    ebands._write_wien2k_energy(case=args.case)
    if ebands._nspin>1:
        # One directory per spin channel. The spin index must come from this
        # loop; previously it was referenced before assignment (NameError).
        for ispin in range(ebands._nspin):
            spin_case='{0}_spin_{1}'.format(args.case,spin_dic[ispin])
            if os.path.isdir(spin_case): shutil.rmtree(spin_case)   # same clean-up as the non-spin branch
            os.mkdir(spin_case)
            shutil.copyfile('{0}.intrans'.format(args.case),os.path.join(spin_case,spin_case+'.intrans'))
            shutil.copyfile('{0}.struct'.format(args.case),os.path.join(spin_case,spin_case+'.struct'))
    else:
        if os.path.isdir(args.case): shutil.rmtree(args.case)
        os.mkdir(args.case)
        for item in glob.glob('{0}.*'.format(args.case)):
            shutil.move(item,args.case)


def get_args():
    """Parse the boltz_tools.py command line and return the argparse Namespace.

    Control options (including ``task``) are added by
    ``pysupercell.arguments.add_control_arguments``. The selected task is
    echoed to stdout.
    """
    import argparse
    from pysupercell.arguments import add_control_arguments
    # Defined locally: the module-level `desc` only exists when this file runs
    # as a script, so reading it here raised NameError on import.
    desc='pre- and post-processing for boltztrap'
    parser = argparse.ArgumentParser(prog='boltz_tools.py', description = desc)
    add_control_arguments(parser)
    # NOTE(security): --plot_temp, --dirs and --elim are parsed with eval(),
    # which executes arbitrary Python from the command line. Acceptable for a
    # local CLI run by its own user; never pass untrusted input to these options.
    parser.add_argument('--case',type=str,default='case',help='case name')
    parser.add_argument('--plot_temp',type=eval,default=None,help='plot temperatures')
    parser.add_argument('--seedname',type=str,default='wannier90',help='seedname of wannier90 output')
    parser.add_argument('--dirs',type=eval,default=(0,),help='components of quantities to plot')
    parser.add_argument('--nelect',type=int, default=0, help='No. of electrons, only integer allowed for BoltzTraP, non-integer will cause problems')
    parser.add_argument('--efermi',type=float,default=0,help='Fermi level for intrans')
    parser.add_argument('--dos_method',type=str,default='HISTO',help='method to calculate DOS, can be HISTO or TETRA')
    parser.add_argument('--en_unit',type=str,default='eV',help='unit for energy, can be eV or Ry')
    parser.add_argument('--elim',type=eval,default=(-1,1),help='Energy range for dos plots')
    args=parser.parse_args()
    print ('task = {0}'.format(args.task))
    return args

if __name__=='__main__':
    verbose_pkg_info(__version__)

    # Kept as a module global: get_args() may reference it.
    desc='pre- and post-processing for boltztrap'
    # basename, not lstrip('./'): lstrip removes a *set of characters*, not a prefix
    print ('\nrunning the script {0}\n'.format(os.path.basename(__file__)))
    args=get_args()
    if args.task=='plot':
        for quan in ['e_cond','seebeck','e_kappa']: plot_data(args,quantity=quan,plot_temp=args.plot_temp)
    elif args.task=='pre':          boltz_pre_process(args)
    elif args.task=='v2dos':        plot_v2dos(case=args.case)
    elif args.task=='wan_tdf':      plot_wan_tdf(seedname=args.seedname)
    elif args.task=='wan_seebeck':  plot_wan_seebeck(seedname=args.seedname)
    elif args.task=='wan_boltzdos': plot_wan_boltzdos(args,seedname=args.seedname)
    elif args.task=='wan_elcond':   plot_wan_elcond(seedname=args.seedname)
    elif args.task=='transdos':     plot_transdos(args,case=args.case)
    else: print ('unknown task {0!r}; nothing to do'.format(args.task))
