#!python
import os
import io
from os import listdir
from os.path import isfile, join
import random
import string
import argparse
import re
import sys
from datetime import datetime
from time import gmtime, strftime


# Python 2 compatibility: io.open accepts the encoding= argument used by the
# open(...) calls in this file, and the default str/unicode codec is switched
# to UTF-8. reload() and sys.setdefaultencoding() exist only on Python 2
# (reload(sys) is the usual trick to get setdefaultencoding back), hence the
# version guard.
if sys.version_info[0] < 3:  # Python 2 only
    from io import open
    reload(sys)
    sys.setdefaultencoding('utf8')




def read_in_chunks(file_object, chunk_size=2000):
    """Lazily yield consecutive pieces of *file_object*.

    Each piece is the result of ``file_object.read(chunk_size)``. Iteration
    stops at the first empty (falsy) read, i.e. at end of file.

    :param file_object: any object with a ``read(size)`` method.
    :param chunk_size: maximum size of each piece (default 2000).
    """
    chunk = file_object.read(chunk_size)
    while chunk:
        yield chunk
        chunk = file_object.read(chunk_size)

# Date layouts recognised by timeStampVaalidator, tried in order. Each entry
# pairs a pattern for the (stripped) input with the strftime layout used for
# the replacement, so the masked value keeps the layout of the original date.
# The optional " HH:MM:SS" (or ISO "THH:MM:SS") part is accepted, but only a
# date is ever returned. Raw strings avoid invalid-escape warnings, and no
# re.M flag is used: with re.M, "$" also matched before an inner newline, so
# multi-line text starting with a date was replaced and its remainder lost.
_TIMESTAMP_FORMATS = (
    (re.compile(r"^([0-9]{2,4})/([0-1][0-9])/([0-3][0-9])(?:( [0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%Y/%m/%d"),
    (re.compile(r"^([0-9]{2,4})-([0-1][0-9])-([0-3][0-9])(?:( [0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%Y-%m-%d"),
    (re.compile(r"^([0-3][0-9])-([0-1][0-9])-([0-9]{2,4})(?:( [0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%d-%m-%Y"),
    (re.compile(r"^([0-3][0-9])/([0-1][0-9])/([0-9]{2,4})(?:( [0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%d/%m/%Y"),
    (re.compile(r"^([0-3][0-9])\.([0-1][0-9])\.([0-9]{2,4})(?:( [0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%d.%m.%Y"),
    # The space before the hour group was missing here, unlike every other
    # layout, so "YYYY.mm.dd HH:MM:SS" was never recognised.
    (re.compile(r"^([0-9]{2,4})\.([0-1][0-9])\.([0-3][0-9])(?:( [0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%Y.%m.%d"),
    # ISO-8601 style "YYYY-mm-ddT[HH:MM:SS]": replaced in year-first order
    # (it used to come out as dd-mm-YYYY).
    (re.compile(r"^(([0-9]{2,4})-([0-1][0-9]|[0-9])-[0-3][0-9]T)(?:([0-2][0-9]):([0-5][0-9]):([0-5][0-9]))?$", re.I), "%Y-%m-%d"),
)


def timeStampVaalidator(word):
    """Check whether *word* is a date/timestamp and build a replacement for it.

    :param word: candidate text. Surrounding whitespace is ignored for the
        check, but the whole remaining text must be the date (optionally
        followed by a time of day).
    :returns: ``(True, today)``, where *today* is the current UTC date
        formatted in the same layout as *word* (date part only). Otherwise
        ``(False, word)`` with *word* unchanged.
    """
    candidate = word.strip()
    for pattern, layout in _TIMESTAMP_FORMATS:
        if pattern.match(candidate):
            # All layouts use UTC; the "-" year-first layout previously used
            # local time via datetime.now().
            return True, strftime(layout, gmtime())
    return False, word


def maskGenerator(word, flag=False):
    """Return a randomised copy of *word* that keeps its character layout.

    If the whole (stripped) word is a date recognised by timeStampVaalidator,
    that function's replacement date is returned instead. Otherwise each
    character is replaced independently:

    * digit             -> random digit 1-9 (0 is never produced)
    * upper-case letter -> random ASCII upper-case letter
    * lower-case letter -> random ASCII lower-case letter
    * whitespace and every other character -> copied unchanged

    :param word: text to mask.
    :param flag: when true, XML entity references such as ``&amp;`` are
        copied verbatim (from ``&`` through ``;``) so the output stays valid
        XML. An ``&`` with no ``;`` after it is copied as a plain character.
    :returns: the masked string.
    """
    is_timestamp, replacement = timeStampVaalidator(word)
    if is_timestamp:
        return replacement

    masked = []
    i = 0
    length = len(word)
    while i < length:
        ch = word[i]
        if ch.isdigit():
            masked.append(str(random.randint(1, 9)))
        elif ch.isupper():
            masked.append(random.choice(string.ascii_letters).upper())
        elif ch.isspace():
            masked.append(ch)
        elif ch.islower():
            masked.append(random.choice(string.ascii_letters).lower())
        elif flag and ch == "&":
            end = word.find(";", i)
            if end == -1:
                # Unterminated '&' (e.g. "AT&T"): the previous scan ran past
                # the end of the string and raised IndexError.
                masked.append(ch)
            else:
                masked.append(word[i:end + 1])  # keep the whole entity
                i = end
        else:
            masked.append(ch)
        i += 1
    return "".join(masked)




def xmlTagMask(s):  # masks the xml attributes
    """Mask the attribute values inside one XML tag fragment.

    *s* is a fragment as produced by xmlParse (the text between two '>'
    characters, without the '>'), e.g. ``<user name="bob"``.

    * Fragments containing "xsi", "Namespace" or "<?" (schema/namespace
      declarations, processing instructions such as the XML header) are
      returned unchanged.
    * CDATA fragments are masked with maskGenerator, keeping every "[CDATA"
      marker intact.
    * Fragments without both an '=' and a double quote are returned unchanged.
    * Otherwise every character from an '=' up to (not including) the second
      double quote after it is passed through maskGenerator. Values of
      double-quoted attributes are therefore masked, while attribute names and
      the quotes themselves are kept.

    :returns: the (possibly) masked fragment.
    """
    # exceptions: xml header tag and schema/namespace declarations
    if "xsi" in s or "Namespace" in s or "<?" in s:
        return s
    if "[CDATA" in s:
        # Mask the pieces around the marker instead of swapping it for a '*'
        # placeholder, which turned every '*' in the data into "[CDATA".
        return "[CDATA".join(maskGenerator(part) for part in s.split("[CDATA"))

    # masking xml attributes
    if "=" not in s or '"' not in s:
        return s

    out = []
    i = 0
    length = len(s)
    while i < length:
        if s[i] == "=":
            opening = s.find('"', i + 1)
            closing = s.find('"', opening + 1) if opening != -1 else -1
            # If the value is not closed within this fragment (e.g. it held a
            # '>' and xmlParse split it), mask to the end instead of raising
            # IndexError as the old scan did.
            stop = closing if closing != -1 else length
            out.extend(maskGenerator(ch) for ch in s[i:stop])
            i = stop
            if i >= length:
                break
        out.append(s[i])  # ordinary character, or the closing quote
        i += 1
    return "".join(out)


def xmlSpecificTagMask(s, userInput=None, xpath=""):
    """Mask an opening-tag fragment when its xpath was selected by the user.

    :param s: one fragment from xmlParse that starts with '<' (no '>').
    :param userInput: absolute xpaths (e.g. "/root/user") selected for
        masking. When empty or None, every element is masked.
    :param xpath: xpath of the element enclosing *s* ("" or None at root).
    :returns: ``(fragment, shouldMask, xpath)``. *fragment* has its attribute
        values masked via xmlTagMask when selected. *shouldMask* says whether
        the text that follows should be masked. *xpath* is the path to use
        for the next fragment.
    """
    if not userInput:
        return xmlTagMask(s), True, xpath
    if "<?" in s or "!--" in s:
        # processing instructions and comments
        return xmlTagMask(s), True, xpath

    xpath = xpath or ""
    r = s.strip()

    if r.startswith("<!"):
        # <!DOCTYPE ...>, <![CDATA[...]]> and similar are not elements. No
        # end tag will ever pop them, so they must not be pushed onto the
        # xpath; their content belongs to the enclosing element.
        matched = xpath in userInput
        return (xmlTagMask(s) if matched else s), matched, xpath

    body = r[1:]
    self_closing = body.endswith("/")
    if self_closing:
        body = body[:-1]
    # The tag name is the first token, split on any whitespace (tabs and
    # newlines included) and cut at a stray '='.
    tokens = body.split(None, 1)
    xmlTag = tokens[0].split("=", 1)[0] if tokens else ""
    if "ns:" in xmlTag:
        xmlTag = xmlTag.split(":")[1]

    element_xpath = xpath + "/" + xmlTag
    matched = element_xpath in userInput
    masked = xmlTagMask(s) if matched else s
    if self_closing:
        # <tag .../> has no end tag to pop it again, so keep the parent's
        # xpath; the text that follows belongs to the parent element.
        return masked, xpath in userInput, xpath
    return masked, matched, element_xpath




def xpathModification(xpath, endTag):
    """Pop the last step of *xpath* when *endTag* closes that element.

    :param xpath: current xpath, e.g. "/root/item" ("" or None at root).
    :param endTag: end-tag fragment without '<' and '>', e.g. "/item".
        Surrounding whitespace is ignored.
    :returns: *xpath* without its last step if that step names the element
        closed by *endTag*; otherwise *xpath* unchanged ("" for an empty
        xpath).
    """
    if not xpath:
        return ""

    name = endTag.strip()[1:]
    # xmlSpecificTagMask strips an "ns:" prefix from opening tags before
    # pushing them, so do the same here; otherwise "</ns:Body>" never pops
    # "Body".
    if "ns:" in name:
        name = name.split(":")[1]

    steps = xpath.split("/")
    if steps[-1] == name:
        del steps[-1]
    return "/".join(steps)












  
def jsonmaskgenerator(s):  # function to mask json content
    """Mask the value in a JSON fragment that starts at a key's ':'.

    jsonParse passes text such as ``: "alice"}`` (from a ':' up to the next
    ','). Everything before the first ':' is dropped. Everything after it is
    masked with maskGenerator, which leaves punctuation such as quotes, '}'
    and ']' unchanged.

    :param s: fragment to mask.
    :returns: ``":" + masked value``, or *s* unchanged when it has no ':'.
    """
    _key, colon, value = s.partition(":")
    if not colon:
        return s
    # Split on the first ':' only. Values such as "http://host" or "12:30"
    # contain more colons, and the old split(':')[1] cut them short (dropping
    # the closing quote and producing invalid JSON).
    return ":" + maskGenerator(value)








def xmlParse (input,out,bufferbyte,mask = ""):
    """Stream an XML file and append a masked copy of it to *out*.

    :param input: path of the XML file to read (UTF-8).
    :param out: path of the output file. It is opened in append mode for
        every chunk that gets written, so an existing file is extended rather
        than replaced.
    :param bufferbyte: buffer size passed to open(). The file is read in
        chunks of ``int(bufferbyte / 2)`` characters.
    :param mask: comma-separated xpaths (e.g. "/root/user,/root/id") whose
        elements should be masked. "" or "None" (the str() of None) means
        every element is masked.

    Each chunk is cut back to its last '>'. The incomplete remainder is
    carried over to the next chunk, and a chunk with no '<' at all is carried
    over whole. The rest is split on '>':

    * fragments that start with an opening tag go through xmlSpecificTagMask
      (attribute masking plus xpath tracking) and are written back with '>';
    * for every other fragment, the text before '<' is masked with
      maskGenerator (keeping XML entities) when shouldMask is set. The end
      tag after '<' is then written back and the xpath updated with
      xpathModification.

    NOTE(review): limitations visible in this code, not yet fixed --
      * whatever is still in additional_content after the last chunk (e.g. a
        newline after the root element's end tag) is never written;
      * a fragment without '<' (text that itself contains '>') is written
        without the '>' it was split on;
      * an opening tag preceded by text in the same fragment ("text<b")
        takes the else-branch, so it is never pushed onto the xpath;
      * shouldMask is not reset by an end tag, so text after a masked
        element may be masked too;
      * int(bufferbyte / 2) is 0 for bufferbyte < 2, so nothing is read.
    """
    listmask = []
    xpath = ""
    mask = str(mask)
    # Build the list of xpaths selected for masking.
    if ',' in mask:
        listmask = mask.split(',')
    elif mask == "" or mask == "None":
        listmask = []
    else:
        listmask.append(mask)



    # With an xpath filter, nothing is masked until a selected element is
    # opened; without one, all text is masked.
    if len(listmask) > 0:
        shouldMask = False
    else:
        shouldMask = True
    additional_content = "" #contents carried over to the next chunk because the current one does not end with '>'
    with open(input, buffering=bufferbyte,encoding="utf-8") as file:
        for file_content in read_in_chunks(file, ((int(bufferbyte / 2)))): #iterating through chunks of file
           flag = True # whether this chunk contains a '<' and should be written now
           #file_content = file_content.strip()
           c = 0
           if additional_content != "":#prepend what was carried over from the previous chunk
               file_content = additional_content + file_content
               additional_content = ""
           if file_content.endswith('>'):
               file_content = file_content

           else:
               additional_content = ""
               c = 0 #number of characters after the last '<' or '>'
               n = True #True if the last bracket in the chunk is '<' (an unfinished tag)
               if '<' in file_content:
                   # Scan backwards for the last '<' or '>'.
                   for k in (reversed(file_content)):
                       if k == '<':
                           n = True
                           break
                       elif k == '>':
                           n = False
                           break
                       c = c + 1
                   if n == True:
                         additional_content = "<" + file_content[(len(file_content) - c):]#carry over the unfinished last tag, including its '<'
                         file_content = file_content[0:(len(file_content) - c - 1)]#drop what was moved into additional_content (and the '<')
                   else:
                       # carry over the text after the last '>'
                       additional_content = file_content[-c:]
                       file_content = file_content[0:((len(file_content))-c)]

                   flag = True
               else:
                   # no tag in this chunk at all: carry it over whole
                   additional_content = file_content
                   file_content = ""
                   flag = False

           if flag == True:
                lines = file_content.split('>')
                with open(out, "a", encoding="utf-8") as file2:
                    for i in lines:#iterating through the fragments between '>' characters
                        if (i.strip().startswith('<') and i.strip()[1] != "/"):
                                # opening tag (also <?...?>, comments, CDATA, DOCTYPE)
                                i,shouldMask,xpath = xmlSpecificTagMask(i,listmask,xpath)  # masking xml atrribute
                                file2.write(i + ">")  # writing in the output file

                        else:
                            # text, optionally followed by '<' and an end tag
                            b = i.split("<")#splitting based on opening tag
                            if(sys.version_info[0] < 3):#Python 2: decode the byte strings to unicode before writing
                                if shouldMask == True:
                                      file2.write(maskGenerator(b[0],True).decode('utf-8'))#masking the element text
                                else:
                                    file2.write(b[0].decode('utf-8'))
                                if len(b) > 1:
                                    xpath = xpathModification(xpath, xmlTagMask(b[1]))
                                    file2.write("<" + xmlTagMask(b[1]) + ">")  # appending the closing tags
                            else:
                                if shouldMask == True:
                                     file2.write(maskGenerator(b[0], True))  # masking the element text
                                else:
                                    file2.write(b[0])
                                if len(b) > 1:
                                    xpath = xpathModification(xpath,xmlTagMask(b[1]))
                                    file2.write("<" + xmlTagMask(b[1]) + ">")  # appending the closing tags









def jsonParse(input,out,bufferbyte):
    """Stream a JSON file and append a copy with its values masked to *out*.

    This is a character scanner, not a JSON parser.

    :param input: path of the JSON file to read (UTF-8).
    :param out: path of the output file. It is opened in append mode.
    :param bufferbyte: buffer size passed to open(). The file is read in
        chunks of ``int(bufferbyte / 2)`` characters.

    Each chunk, prefixed with whatever was carried over from the previous
    one, is cut at its last ':'. The tail starting at that ':' is carried
    over and the rest is scanned:

    * characters outside values are copied;
    * from every ':', the text up to the next ',' (or up to and including a
      '{' or '[') is collected and passed to jsonmaskgenerator;
    * a '[' reached inside that span starts a list scan that masks characters
      up to the next ']', unless a '{' is found first.

    The text still carried over after the last chunk is passed through
    jsonmaskgenerator and written last.

    NOTE(review): problems visible in this code, not yet fixed --
      * if a chunk ends exactly with ':', the rest of that chunk is discarded
        (only ':' is carried over and nothing is written for it);
      * a ',' inside a string value ends the masked span early, so the rest
        of that string is copied unmasked;
      * after a list of scalars ("key": [1, 2]), the scan continues past the
        following ',' into the next key, and jsonmaskgenerator then drops
        everything before that key's ':'. This looks like it produces invalid
        JSON; verify;
      * content[i] / content[check] can raise IndexError when a value or list
        runs to the end of the chunk without a ',' or ']';
      * int(bufferbyte / 2) is 0 for bufferbyte < 2, so nothing is read.
    """

    additiona_content = "" #text carried over to the next chunk (starts at the last ':')
    with open(input, buffering= bufferbyte, encoding="utf-8") as file:
        for content in read_in_chunks(file,(int(bufferbyte / 2))) :

            if additiona_content != "":
                content = additiona_content + content #adding remaining content to the new content
            if content[-1] == ':':
                # NOTE(review): the rest of this chunk is not written anywhere.
                content = content[:-1]
                additiona_content = ':'

            else:
                # c = number of characters after the last ':'
                c = 0
                for k in reversed(content):
                    if k == ':':
                        break
                    c += 1
                additiona_content = content[-c - 1:]#carry over from the last ':' to the end
                content = content[:-c - 1]#keep only what comes before the last ':'
                with open(out, "a",encoding="utf-8") as file2:
                    i = 0
                    s = ""  # masked output for this chunk
                    while (i < len(content)):
                        if content[i] == ":":
                            r = ""  # text of the current value, starting at ':'
                            while (content[i] != ","):#collect the value up to the next ','
                                check = i# #scan index for a list; i is moved past the list if it holds no objects
                                flag = False  # True when the list contains '{' (objects)
                                if content[i] == '[':#does the list hold objects or plain values?

                                    k = ""
                                    while content[check] != ']':

                                        if content[check] == '{':
                                            flag = True
                                            break
                                        else:
                                            k = k + maskGenerator(content[check])
                                            check += 1
                                    if flag == False:
                                        # list of plain values: already masked into k
                                        i = check + 1
                                        s = s+":"+k+"]"
                                        r = ""





                                if flag == True or i < len(content):

                                        if content[i] == '{' or content[i] == '[' :
                                            # nested object/list: stop here, its keys are handled by the outer loop
                                            r = r + content[i]
                                            i = i + 1
                                            break

                                        else:
                                            # s = s + content[i]
                                            r = r + content[i]
                                            i = i + 1

                            if r != "":
                                # NOTE(review): true unless r contains both '{' and '['
                                if ('{' not in r) or ('[' not in r):
                                    s = s + jsonmaskgenerator(r)
                                else:
                                        s = s + r

                        else:
                            s = s + content[i]
                            i += 1


                    file2.write(s)

    # flush the text still carried over after the last chunk
    with open(out, "a", encoding="utf-8") as file2:
        file2.write(jsonmaskgenerator(additiona_content))


def commandLine(args):

    if args.input != "":

         if (os.path.isdir(args.input)):
             inputDirectory = args.input + "/"
             if not os.path.exists(args.outputDir):
                 try :
                     os.makedirs(args.outputDir)
                 except :
                     sys.exit("Exiting no write access")
             if not os.access(args.outputDir, os.W_OK):
                 sys.exit('NExiting no write access' )
             files = []
             for subdir, dirs, allfiles in os.walk(inputDirectory):
                 for file in allfiles:
                      files.append(os.path.join(subdir, file))
             for names in files:
                 name, ext = os.path.splitext(names)
                 fileName = name.split("/").pop()
                 output = args.outputDir + "/" + fileName + str(random.randint(1,1000))+ext
                 if ext == ".xml":
                     print("Processing " + names)
                     xmlParse((names),output,args.byteSize)
                     print("------------------------------------------------")
                     print("Completed Processing " + inputDirectory + names + " Maksed file location "+output)
                     print("------------------------------------------------")
                     print("------------------------------------------------")

                 elif ext == ".json":
                     print("Processing " + names)
                     jsonParse((names),output,args.byteSize)
                     print("------------------------------------------------")
                     print("Completed Processing " + names + " Maksed file location " + output)
                     print("------------------------------------------------")
                     print("------------------------------------------------")


         elif os.path.isfile(args.input) :
             name, ext = os.path.splitext(args.input)
             fileName = name.split("/").pop()
             if not os.path.exists(args.outputDir ):
                 try :
                     os.makedirs(args.outputDir)
                 except :
                     sys.exit("Exiting no write access")
             if not os.access(args.outputDir, os.W_OK):
                 sys.exit('Exiting no write access ' )
             output = args.outputDir + "/" + fileName + str(random.randint(1, 1000)) + ext

             if ext == ".xml":
                 print("Processing " + fileName)
                 xmlParse(args.input,output,args.byteSize,args.mask)
                 print("------------------------------------------------")
                 print("Completed Processing " + fileName + " Masked file location " + output)
                 print("------------------------------------------------")
                 print("------------------------------------------------")

             elif ext == ".json":
                 print("Processing " + fileName)
                 jsonParse(args.input,output,args.byteSize)
                 print("------------------------------------------------")
                 print("Completed Processing " + fileName + " Maksed file location " + output)
                 print("------------------------------------------------")
                 print("------------------------------------------------")

             else:
                 sys.exit("Please provide valid xml/json file")
         else:
             sys.exit("Invalid directory/ File ")

    else:
       sys.exit("Please provide input")


def main():
    """Parse the command line and run commandLine() with the result.

    Options: -i input file or directory (required), -o output directory
    (required), -b buffer size (default 2000000), -l comma-separated xpaths
    to mask (XML only; every element is masked when omitted).
    Exits with status 2 (argparse usage error) when -b is below 2.
    """
    parser = argparse.ArgumentParser(description="Data Masking")
    parser.add_argument("-i", help="Input Directory Name / File Name", dest="input", type=str, default="", required=True)
    parser.add_argument("-b", help="Provide byte size to buffer", dest="byteSize", type=int, default=2000000)
    parser.add_argument("-o", help="Output Directory Name", dest="outputDir", type=str, required=True)
    parser.add_argument("-l", help="Input xpath or xpaths seperated by ,", dest="mask", type=str)
    args = parser.parse_args()
    # The parsers read chunks of int(byteSize / 2) characters. Below 2 that
    # is 0, an empty read, so nothing would be processed; a buffer size of 0
    # is also rejected by open() in text mode.
    if args.byteSize < 2:
        parser.error("-b must be at least 2")
    commandLine(args)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()













































