#!python
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 20 19:30:07 2023
Last modified on Wed Dec 6 14:45:00 2023

@author: Sheng Zhang, Institute of Physics, Chinese Academy of Sciences

The Python script of VASP2KP
VASP->vasp2mat->mat2kp
"""

import sympy
import sympy as sp
from sympy import Matrix,sqrt,sin,cos,eye,Rational,exp
from sympy.core.numbers import I
import re
#import sympy.physics.matrices as sm
#from sympy.physics.quantum import TensorProduct
import sys
from VASP2KP import get_std_kp_Zeeman, get_O3_matrix, get_std_kp_Zeeman_no_coe

class SymmetryNotSetError(Exception):
    """Raised when the ``Symmetry`` parameter is not defined by the input."""

class PrintFlagPatchError(Exception):
    """Raised when ``print_flag`` is set to 0, i.e. no output file is written."""

# Mathematica notebook export is off unless the executed input turns it on.
mathematica_flag = False

# Alternative spellings of the boolean constants, available to the statements
# executed from the input file.
F = f = false = FALSE = False
T = t = true = TRUE = True

# Alternative spellings of the sympy helpers imported at the top of the file.
matrix = MATRIX = Matrix
i = j = I  # imaginary unit
Exp = EXP = exp
rational = RATIONAL = Rational
Sqrt = SQRT = sqrt
Eye = EYE = eye
Sin = SIN = sin
Cos = COS = cos
SP = sP = Sp = sp

# ``sp`` and ``sympy`` name the same module; attach the lowercase spellings to
# it so that ``sp.matrix`` / ``sympy.rational`` resolve as well.
sp.matrix = sympy.matrix = matrix
sp.rational = sympy.rational = rational

# Keywords that may be written in any letter case in the input file; every
# occurrence is rewritten to the lowercase spelling listed here.
str_list = ['exp','sin','cos','sqrt','rational','matrix','true','false','eye','repr_matrix','rotation_matrix','repr_has_cc']

# One case-insensitive pattern per keyword, compiled once instead of once per
# input line.  Substituting directly on the line avoids transferring match
# offsets from ``line.lower()`` to ``line``: lowercasing can change the length
# of a string containing certain non-ASCII characters, which would shift the
# offsets.  re.ASCII limits the case folding to ASCII letters, which is all the
# keywords consist of.
_keyword_patterns = [
    (re.compile(_keyword, re.IGNORECASE | re.ASCII), _keyword)
    for _keyword in str_list
]

# open input file
with open('mat2kp.in', 'r') as file:
    input_list = []

    for line in file:
        # Leading/trailing whitespace is dropped, so the input cannot rely on
        # indentation; blank lines are skipped entirely.
        line = line.strip()

        if line == '':
            continue

        # Keywords are applied one after another (not as a single alternation)
        # so that overlapping keywords are all normalised.
        for _pattern, _keyword in _keyword_patterns:
            line = _pattern.sub(_keyword, line)

        input_list.append(line)

    input_inf = '\n'.join(input_list)


# execute
# NOTE(security): the input file is run as Python code in this module's
# namespace.  Only use a 'mat2kp.in' that comes from a trusted source.
exec(input_inf)

# Snapshot of the module namespace taken right after the input was executed.
# Working on a copy keeps the lookups below independent of the names this
# section itself defines; the copy preserves definition order.
_defined = dict(locals())

# Sentinel marking "no value found and no default given".
_MISSING = object()


def _lookup_parameter(namespace, name, default=_MISSING):
    """Return the value of parameter ``name`` from ``namespace``.

    The exact spelling takes precedence.  Otherwise the first entry (in
    definition order) whose name equals ``name`` ignoring letter case is
    used.  If neither exists, ``default`` is returned.
    """
    if name in namespace:
        return namespace[name]

    wanted = name.lower()
    for key, value in namespace.items():
        if key.lower() == wanted:
            return value

    return default


# Symmetry is the only parameter without a default.
_symmetry = _lookup_parameter(_defined, 'Symmetry')
if _symmetry is _MISSING:
    raise SymmetryNotSetError("The parameter Symmetry not set!")
Symmetry = _symmetry

# Optional parameters and the defaults used when the input does not set them.
log = _lookup_parameter(_defined, 'log', 0)
no_vasp_kp = _lookup_parameter(_defined, 'no_vasp_kp', False)
acc = _lookup_parameter(_defined, 'acc', 0)
print_flag = _lookup_parameter(_defined, 'print_flag', 2)
gfactor = _lookup_parameter(_defined, 'gfactor', 1)
order = _lookup_parameter(_defined, 'order', 2)
kpmodel = _lookup_parameter(_defined, 'kpmodel', 1)
# A 'folder_path' parameter (default 'case_mat_eig') is disabled in this
# script and therefore not resolved here.
vaspMAT = _lookup_parameter(_defined, 'vaspMAT', 'mat')

# Validate print_flag before any calculation is started.
if print_flag == 0:
    # Raising already aborts the script, so no explicit sys.exit() is needed.
    raise PrintFlagPatchError("Parameter print_flag is set to 0, there will be no output file! Please reset this parameter.")

elif print_flag == 1:
    print(r"Warning!!! Parameter 'print_flag' is set to 1, all the result will be output on console, not into files!")
    
    
# calculate the result
if not no_vasp_kp:
    # This call additionally receives vaspMAT, kpmodel and acc.
    result_kp, result_Zeeman = get_std_kp_Zeeman(
        Symmetry,
        vaspMAT,
        kpmodel=kpmodel,
        order=order,
        gfactor=gfactor,
        print_flag=print_flag,
        log=log,
        acc=acc,
    )

else:
    # The "no_coe" variant is called without vaspMAT, kpmodel and acc.
    result_kp, result_Zeeman = get_std_kp_Zeeman_no_coe(
        Symmetry,
        order=order,
        gfactor=gfactor,
        print_flag=print_flag,
        log=log,
    )
    


def _mathematica_assignment(expression, target, numeric):
    """Render ``expression`` as the Mathematica assignment ``target=...``.

    Parameters
    ----------
    expression
        The sympy result to export.
    target : str
        Name of the Mathematica variable the expression is assigned to.
    numeric : bool
        If true, ``expression.evalf()`` is applied before simplification.
        If false, the expression is kept as it is and ``sqrt(3)`` is
        rewritten to Mathematica's ``Sqrt[3]``.

    Returns
    -------
    str
        ``str()`` of the simplified expression with ``[``/``]`` turned into
        ``{``/``}``, ``**`` into ``^`` and ``Matrix`` into ``target=``.
    """
    if numeric:
        expression = expression.evalf()

    text = str(sp.simplify(expression))
    # The brackets have to be converted before ``Sqrt[3]`` is introduced,
    # otherwise its square brackets would be turned into braces as well.
    text = text.replace('[', '{').replace(']', '}').replace('**', '^')

    if not numeric:
        text = text.replace('sqrt(3)', 'Sqrt[3]')

    return text.replace('Matrix', target + '=')


# A notebook is written only on request and only for gfactor 0 or 1.
if mathematica_flag and (gfactor == 1 or gfactor == 0):
    # Results computed with no_vasp_kp unset are evaluated numerically.
    _numeric = not no_vasp_kp

    # The Zeeman term is exported only when gfactor is 1; it then goes into
    # "kp-Zeeman.nb" instead of "kp.nb".
    _statements = [_mathematica_assignment(result_kp, 'Hkp', _numeric)]
    if gfactor == 1:
        _statements.append(_mathematica_assignment(result_Zeeman, 'HZ', _numeric))
        _notebook_name = "kp-Zeeman.nb"
    else:
        _notebook_name = "kp.nb"

    # The file is opened only after the conversion succeeded, so a failing
    # conversion does not truncate an existing notebook, and the context
    # manager closes the file even if a write fails.
    with open(_notebook_name, 'w') as mathematica_file:
        mathematica_file.write(r'Clear["Global`*"];')
        mathematica_file.write('\n')
        mathematica_file.write(';\n'.join(_statements))
        mathematica_file.write(";")
