#!/usr/bin/env python3
import os
import re
import site
import subprocess as sp
import time
import shlex
import argparse
import uuid

import requests
from glob import glob
from dataclasses import dataclass
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing import current_process
from subprocess import Popen, PIPE
from threading import Thread
from typing import List, Dict, Union, Callable, Any

import toml
from bs4 import BeautifulSoup


class Colors:
    """ANSI escape sequences used to decorate text printed to the terminal."""

    # text attributes
    BOLD = '\033[1m'
    END = '\033[0m'  # appended after decorated text to switch the styling off

    # foreground colours
    RED = '\033[91m'
    GREEN = '\033[92m'
    MAGENTA = '\033[95m'
    CYAN = '\033[96m'


class SuccessStatus:
    """String labels describing how far an attack has progressed."""
    # initial value of AutoHack.success_status
    EXPLOIT_UNSUCCESSFUL = 'EXPLOIT_UNSUCCESSFUL'
    EXPLOIT_SUCCESSFUL = 'EXPLOIT_SUCCESSFUL'
    PRIVILEGE_ESCALATE_SUCCESSFUL = 'PRIVILEGE_ESCALATE_SUCCESSFUL'


class Utils:
    """Namespace of stateless helper functions shared across the tool."""

    class AutoHackException(Exception):
        """Error raised by the tool itself (e.g. a config directory cannot be found)."""

    @staticmethod
    def make_directories(*args: str) -> None:
        """
        Creates every directory given, including missing parents.
        Directories that already exist are left untouched.
        :param args: directory paths to create
        :return:
        """
        for directory in args:
            # exist_ok avoids the race between an isdir() check and makedirs()
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def clear_file(_file: str) -> None:
        """
        Clear file by writing an empty string into it
        (the file is created when it does not exist yet)
        :param _file: path of the file to empty
        :return:
        """
        with open(_file, 'w') as f:
            f.write('')

    @staticmethod
    def print_time() -> None:
        """
        Prints the current unix time in whole seconds
        :return:
        """
        current_time_in_seconds = int(time.time())
        print(f"{Colors.MAGENTA}{Colors.BOLD}CURRENT TIME{Colors.END}: {current_time_in_seconds}")

    @staticmethod
    def determine_output_file(
            command: str,
            options_associated_with_output_file: List[str]
    ) -> Union[str, None]:
        """
        Parses command for output file and returns...
            (1) the output file if determined
                        or
            (2) None, if output file can not be determined
        :param command: shell command line, tokenised with shlex
        :param options_associated_with_output_file: lower-cased options whose
            following token is the output file (tokens are lower-cased before
            the comparison)
        :return: the token following the first recognised option, or None
        """
        fragments = shlex.split(command)

        # pair every token with its successor; a trailing option with no
        # successor never appears on the left-hand side, so it is ignored
        for option, following_fragment in zip(fragments, fragments[1:]):
            if option.lower() in options_associated_with_output_file:
                return following_fragment
        return None

    @staticmethod
    def should_run_command(command_output_file: Union[str, None]) -> bool:
        """
        Should run command if:
            (1) the output file is None
                            or
            (2) the command's output file doesn't exist
        :param command_output_file: expected output file of the command, or None
        :return: a boolean value
        """
        if command_output_file is None:
            return True

        return not os.path.isfile(command_output_file)

    @staticmethod
    def extract_matching_string(regex: str, text: str) -> Union[str, None]:
        """
        Extracts the string which matches regex within a text.
        The text is lower-cased before searching, so the regex should be
        written for lower-case input and the result is lower-case as well.
        :param regex: regular expression to search for
        :param text: text to search in
        :return: the matched portion of the (lower-cased) text, or None
        """
        match = re.search(regex, text.lower())
        if match is None:
            return None
        # group(0) is the matched substring; match.string would be the whole input
        return match.group(0)

    @staticmethod
    def alpha_to_int(text) -> Union[int, str]:
        """
        Casts alpha-numeric to an integer
        :param text: a string fragment
        :return: int(text) when text consists of digits only, otherwise text unchanged
        """
        if text.isdigit():
            return int(text)
        return text

    @staticmethod
    def natural_keys(text: str) -> List[Union[int, str]]:
        """
        Split text into fragments so it can be naturally sorted.
        Runs of digits become integers, so 'p9' sorts before 'p10'.
        :param text: text to build a sort key for
        :return: alternating list of string and integer fragments
        """
        return [Utils.alpha_to_int(c) for c in re.split(r'(\d+)', text)]

    @staticmethod
    def get_site_packages() -> List[str]:
        """
        Returns the site-packages directories reported by the site module
        :return:
        """
        return site.getsitepackages()

    @staticmethod
    def get_lock():
        """
        Returns a new multiprocessing lock
        :return:
        """
        return Lock()

    @staticmethod
    def process_safe_print(_str: str) -> None:
        """
        Locks process, prints, and releases process
        so that printing output appears in the correct order
        :param _str: text to print
        :return:
        """
        # NOTE(review): get_lock() hands out a brand-new lock on every call, so
        # this does not serialise anything against other callers -- confirm
        # whether a shared lock was intended.
        # The with-block guarantees the lock is released even if print() raises.
        with Utils.get_lock():
            print(_str)

    @staticmethod
    def get_chart(value, max_value, size=30, unit='seconds') -> str:
        """
        Renders chart based on input values

        :param value: current value; clamped into the range [0, max_value]
        :param max_value: value which corresponds to a full bar
        :param size: width of the bar in characters
        :param unit: label printed next to both numbers
        :return: the rendered (coloured) chart
        """
        if value > max_value:
            value = max_value
        elif value < 0:
            value = 0

        # a max_value of 0 (e.g. no timeout configured) would divide by zero;
        # render an empty bar in that case
        filled_ratio = value / max_value if max_value else 0
        scaled_value = int(filled_ratio * size)
        tick_marks = '▯' * scaled_value
        non_tick_marks = ' ' * (size - scaled_value)
        chart = (f'|{Colors.GREEN}{Colors.BOLD}{tick_marks}{Colors.END}'
                 f'{Colors.CYAN}({value} {unit}){Colors.END}'
                 f'{non_tick_marks}|{Colors.CYAN}({max_value} {unit}){Colors.END}')
        return chart

    @staticmethod
    def get_attacker_address(interface='tun0') -> Union[str, None]:
        """
        Gets ip address of attacker. This is used for
        obtaining a reverse shell.

        By default, it uses the tun0 network interface since the attacker
        is likely connected to the target through a vpn.

        :param interface: name of the network interface to inspect
        :return: the first IPv4 address reported for the interface, or None
            when it can not be determined
        """
        try:
            result = sp.check_output(
                ['ip', '-4', 'addr', 'show', interface],
                stderr=sp.STDOUT
            )
            result = result.decode().splitlines()
            # an "inet" line looks like "inet <address>/<prefix> ..."
            ipaddr = [line.split()[1].split('/')[0] for line in result if "inet" in line]
            return ipaddr[0]

        except Exception:
            # best effort: a missing `ip` binary, an unknown interface or
            # unexpected output all mean "address unknown"
            return None


@dataclass
class ProcessReference:
    """Bookkeeping record for one spawned worker process."""
    target: Callable
    args: List[Any]
    kill_after_time: int
    process_object: Process
    retry_max: int
    command_pids: List[str]

    def has_timed_out(self) -> bool:
        """Tells whether the clock has reached ``kill_after_time`` (unix seconds)."""
        now = int(time.time())
        return self.kill_after_time <= now

    @property
    def should_retry(self):
        """True while there are retries left."""
        return self.retry_max > 0

    @property
    def command(self):
        """
        The first recorded argument, which callers treat as the command.
        Raises AutoHackException when no arguments were recorded.
        """
        try:
            first_argument = self.args[0]
        except IndexError:
            raise Utils.AutoHackException(
                f'Process {self.process_object.pid} Does Not Have Any Arguments Recorded'
            )
        return first_argument

    def decrement_retry_max(self):
        """Uses up one retry."""
        self.retry_max = self.retry_max - 1

    def append_pid(self, pid: str):
        """
        Remembers another pid under which this command has run;
        the history is handy when debugging restarts
        :param pid:
        :return:
        """
        self.command_pids.append(pid)


class ProcessManager:
    """Keeps the list of spawned processes and polices them from a heart-beat thread."""

    def __init__(
        self,
        verbose: bool,
        time_between_heart_beats: int,
    ):
        self.verbose = verbose
        self.time_between_heart_beats = time_between_heart_beats
        self.processes: List[ProcessReference] = []
        # despite the attribute name, the heart beat runs in a Thread
        self.heart_beat_process = Thread(target=self.heart_beat)

    def append_process(self, process: ProcessReference) -> None:
        """Registers a process reference so the heart beat starts watching it."""
        self.processes.append(process)

    def print_active_processes_data(self, active_process: List[ProcessReference]) -> None:
        """
        Prints one status report covering every process handed in:
        its pid, a chart of the seconds left before its timeout and,
        in verbose mode, the command it was started with.

        :param active_process: the process references to report on
        :return:
        """
        now = int(time.time())
        report = [
            f'\n{Colors.MAGENTA}{Colors.BOLD}COMMANDS STILL RUNNING [at {now}]:{Colors.END}\n'
        ]

        for reference in active_process:
            seconds_left = int(reference.kill_after_time - now)

            report.append(f'\t{Colors.CYAN}{Colors.BOLD}PID:{Colors.END} {reference.process_object.pid}\n')
            timeout_chart = Utils.get_chart(
                value=seconds_left,
                max_value=AutoHack.command_timeout
            )
            report.append(f'\t{Colors.CYAN}{Colors.BOLD}SECONDS UNTIL TIMEOUT:{Colors.END} {timeout_chart}\n')

            if self.verbose:
                report.append(f'\t{Colors.CYAN}{Colors.BOLD}COMMAND CALLED:{Colors.END} {reference.command}\n')

            # blank line separating this process from the next one
            report.append('\n')

        Utils.process_safe_print(''.join(report))

    @staticmethod
    def retry_process(process: ProcessReference):
        """
        Starts a fresh process with the same target and args, then points
        the reference at it: new process object, pid appended to the
        history, new deadline and one retry less.

        :param process: reference of the process to start again
        :return:
        """
        replacement = Process(target=process.target, args=process.args)
        replacement.start()

        process.process_object = replacement
        process.append_pid(str(replacement.pid))
        process.kill_after_time = int(time.time()) + AutoHack.command_timeout
        process.decrement_retry_max()

    def terminate_all_active_processes(self, show_message:bool=True) -> None:
        """
        Terminates every tracked process that is still alive and has not
        passed its deadline.

        :param show_message: print a notice for each process terminated
        :return:
        """
        for reference in self.processes:
            worker = reference.process_object

            if not worker.is_alive() or reference.has_timed_out():
                continue

            worker.terminate()

            if show_message:
                notice = (f'\n{Colors.RED}{Colors.BOLD}KILLING PROCESS, PROGRAM ENDING: '
                          f'PID:{worker.pid}{Colors.END}')
                Utils.process_safe_print(notice)

    def heart_beat(self) -> None:
        """
        Loop body of the heart-beat thread. After every sleep it:
            * terminates processes which passed their deadline and restarts
              those which still have retries left
            * stops (terminating what is left) once
              AutoHack.should_kill_heart_beat is set
            * otherwise reports the processes which are still running
        :return:
        """
        while True:
            time.sleep(self.time_between_heart_beats)
            still_running = []

            for reference in self.processes:

                # finished processes need no attention
                if not reference.process_object.is_alive():
                    continue

                if reference.has_timed_out():
                    reference.process_object.terminate()
                    notice = (f'\n{Colors.RED}{Colors.BOLD}KILLING PROCESS DUE TO TIMEOUT: '
                              f'PID:{reference.process_object.pid}{Colors.END}')
                    Utils.process_safe_print(notice)

                    if reference.should_retry:
                        self.retry_process(reference)
                        pid_history = '->'.join(reference.command_pids)
                        notice = (
                            f'{Colors.MAGENTA}{Colors.BOLD}COMMAND RESTARTED:{Colors.END} '
                            f'{Colors.CYAN}{Colors.BOLD}COMMAND PIDS:{Colors.END}'
                            f'{Colors.GREEN}{Colors.BOLD}{pid_history}{Colors.END}\n'
                        )
                        Utils.process_safe_print(notice)
                    else:
                        # killed for good: not part of the running set
                        continue

                still_running.append(reference)

            if AutoHack.should_kill_heart_beat:
                self.terminate_all_active_processes()
                return

            if still_running:
                self.print_active_processes_data(still_running)

    def run(self):
        """Starts the heart-beat thread."""
        self.heart_beat_process.start()


class AutoHack:
    # --- class-level state, shared by all instances and read by ProcessManager ---
    # when set to True the heart-beat loop terminates the active processes and stops
    should_kill_heart_beat = False
    # seconds a command may run before it is killed; overwritten by __init__
    command_timeout = 0
    # base url of the api; AUTOHACK_API_SOURCE overrides the default
    # (an unset or empty variable falls back to the default)
    api_source = os.getenv('AUTOHACK_API_SOURCE') or 'https://api.autohack.com'
    # key sent in the API_KEY request header; falls back to 'FREE'
    api_key = os.getenv('AUTOHACK_API_KEY') or 'FREE'

    def __init__(
        self, ip_address, attacker_ip_address, output_directory=None, config_directory=None,
        port_scan_type=None, nmap_extra=None, ports=None,
        use_processes=None, verbose=None, command_timeout=None,
        heart_beat_time=None, retry_max=None, professional=None
    ) -> None:
        """
        Stores the run settings, loads the three config files, creates the
        output directory tree and, when processes are used, starts the
        heart-beat thread.

        :param ip_address: address of the target
        :param attacker_ip_address: address of the attacker (may be empty)
        :param output_directory: root of the results tree; defaults to
            <current-working-directory>/results/<ip_address>
        :param config_directory: directory holding the toml config files;
            located through get_config when not given
        :param port_scan_type: key selecting the scans in port-scans.toml
        :param nmap_extra: value made available to the command templates
        :param ports: value made available to the port scan command templates
        :param use_processes: run secondary commands in separate processes
        :param verbose: print additional information
        :param command_timeout: seconds before a command is killed;
            stored on the class, so it is shared by all instances
        :param heart_beat_time: seconds between two heart beats
        :param retry_max: number of restarts allowed for a timed out command
        :param professional: stored as-is
        :raises Utils.AutoHackException: when the config directory or one of
            its config files can not be found
        """
        AutoHack.command_timeout = command_timeout
        self.process_lock = Utils.get_lock()
        self.target_address = ip_address
        self.attacker_address = attacker_ip_address
        self.port_scan_type = port_scan_type
        self.nmap_extra = nmap_extra
        self.ports = ports
        self.current_working_directory = os.getcwd()
        self.use_processes = use_processes
        self.verbose = verbose
        self.retry_max = retry_max
        self.professional = professional
        self.process_manager = ProcessManager(
            verbose=verbose,
            time_between_heart_beats=heart_beat_time,
        )

        output_directory = output_directory or os.path.join(
            self.current_working_directory,
            'results', self.target_address
        )
        config_directory = self.get_config(config_directory)
        port_scan_config_file = os.path.join(config_directory, 'port-scans.toml')
        services_scan_config_file = os.path.join(config_directory, 'service-scans.toml')
        universal_pattern_config_file = os.path.join(config_directory, 'universal-patterns.toml')

        if not os.path.isfile(port_scan_config_file):
            raise Utils.AutoHackException(f'Can Not Find Port Scan Config File inside: {config_directory} Directory')
        if not os.path.isfile(services_scan_config_file):
            raise Utils.AutoHackException(f'Can Not Find Services Scan Config File: {config_directory} Directory')
        if not os.path.isfile(universal_pattern_config_file):
            raise Utils.AutoHackException(f'Can Not Find Universal Pattern Config File: {config_directory} Directory')

        self.port_scan_config = toml.load(port_scan_config_file)
        self.services_scan_config = toml.load(services_scan_config_file)
        self.universal_patterns = list(toml.load(universal_pattern_config_file).values())

        self.scan_directory = os.path.join(output_directory, 'scans')
        xml_directory = os.path.join(self.scan_directory, 'xml')
        log_directory = os.path.join(self.scan_directory, 'logs')
        self.exploit_directory = os.path.join(output_directory, 'exploit')
        self.exploit_commands_attempted_directory = os.path.join(self.exploit_directory, 'commands_attempted')
        self.priv_directory = os.path.join(output_directory, 'priv')
        loot_directory = os.path.join(output_directory, 'loot')
        Utils.make_directories(
            self.scan_directory, xml_directory, log_directory,
            self.exploit_directory, self.priv_directory, loot_directory,
            self.exploit_commands_attempted_directory
        )

        self.detected_services = set()
        self.pattern_matches = dict()
        self.exploit_scans = set()
        self.commands_run_log = os.path.join(log_directory, 'commands.log')
        self.detected_services_log = os.path.join(log_directory, 'detected_services.log')
        self.exploit_steps_log = os.path.join(log_directory, 'exploit_steps.log')
        self.patterns_detected_log = os.path.join(log_directory, 'patterns.log')
        self.session_id: str | None = None
        self.spearhead_index_unique_id: str | None = None
        self.expected_command_answers_to_verify_root_access: dict = {}
        self.success_status = SuccessStatus.EXPLOIT_UNSUCCESSFUL
        Utils.clear_file(self.exploit_steps_log)

        # Start the heart beat only after everything above succeeded. Its
        # thread loops until AutoHack.should_kill_heart_beat is set, so
        # starting it before the config checks left the interpreter hanging
        # whenever one of them raised.
        if use_processes:
            self.process_manager.run()

    def get_config(self, config_directory: str):
        """
        The config directory is needed for enumeration.
        This method checks for the config directory in all the
        expected places. If it cannot be found, it raises an exception

        :param config_directory:
        :return:
        """
        # use user-specified AutoHack config directory
        if config_directory:
            return config_directory

        # define <current-working-directory>/autohack_config
        local_config_directory = os.path.join(
            self.current_working_directory, 'autohack_config'
        )

        # define ~/.config/autohack_config
        home_directory = os.path.expanduser("~")
        global_config_directory = os.path.join(
            home_directory, '.config', 'autohack_config'
        )

        # define <possible site-packages>/autohack_config
        site_packages = Utils.get_site_packages()
        site_packages_autohack_directories = [
            f'{site_package_directory}/autohack_config' for site_package_directory in site_packages
        ]

        possible_config_directories_locations = [
            local_config_directory,
            global_config_directory
        ] + site_packages_autohack_directories

        for prospective_directory in possible_config_directories_locations:
            if self.verbose:
                print(f'{Colors.GREEN}Looking For AutoHack Config In: {prospective_directory}{Colors.END}')

            if os.path.exists(prospective_directory):
                if self.verbose:
                    print(f'{Colors.GREEN}Found AutoHack Config In: {prospective_directory}{Colors.END}\n\n')
                return prospective_directory

        # raise exception when config directory couldn't be found
        raise Utils.AutoHackException('Can Not Find Config Directory')

    def print(self, _str: str) -> None:
        """
        Locks process, prints, and releases process
        so that printing output appears in the correct order
        :param _str:
        :return:
        """
        self.process_lock.acquire()
        print(_str)
        self.process_lock.release()

    def process_print(self, _str: str) -> None:
        """
        Prints while prepending the process id
        so it's more obvious where message came from
        :param _str:
        :return:
        """
        if self.use_processes:
            process_id = current_process().pid
            _str = f"{Colors.CYAN}{Colors.BOLD}Process {process_id}:{Colors.END} {_str}"
        self.print(_str)

    def command_has_been_run(self, command) -> bool:
        try:
            f = open(self.commands_run_log, "r")
            all_commands_run = f.read()
            return command in all_commands_run
        except FileNotFoundError:
            return False

    def log_command(self, command: str) -> None:
        """
        Writes command to a log file
        :param command:
        :return:
        """
        process_id = current_process().pid
        command = f"[*] (Process {process_id}) {command}\n\n"
        self.process_lock.acquire()
        with open(self.commands_run_log, 'a') as f:
            f.write(command)
        self.process_lock.release()

    def log_detected_services(self) -> None:
        """
        Rewrites the detected services log with the services detected so
        far, one per line, in natural sort order
        :return:
        """
        detected_services_list = sorted(self.detected_services, key=Utils.natural_keys)

        # Take the lock before opening: mode 'w' truncates the file, which
        # previously happened outside the lock. The with-blocks also release
        # the lock when a write fails.
        with self.process_lock:
            with open(self.detected_services_log, 'w') as f:
                f.writelines(f"{detected_service}\n" for detected_service in detected_services_list)

    def log_exploit_steps(self, exploit_scans: List[Dict], service_scan_data: Dict) -> None:
        """
        Writes exploit scans to a log file
        :param exploit_scans:
        :param service_scan_data:
        :return:
        """
        with open(self.exploit_steps_log, 'a') as f:
            self.process_lock.acquire()
            for sub_scan in exploit_scans:
                message = f"[*] {sub_scan['description']}\n"
                wrote_command = False

                for command in sub_scan['commands']:
                    command = command.format(**service_scan_data)
                    if command in self.exploit_scans:
                        continue
                    wrote_command = True
                    self.exploit_scans.add(command)
                    message += f"\t{command}\n"
                if wrote_command:
                    f.write(f"{message}\n\n")
            self.process_lock.release()

    def record_pattern_matches(self, output_file: str, patterns: List[dict], service_data: dict) -> None:
        """
        Checks output file for certain regex patterns
        and then saves pattern description for future analysis.

        :param output_file:
        :param patterns:
        :param service_data:
        :return:
        """
        if 'port' in service_data:
            # since port scans don't specify a port, don't look at universal
            # patterns for port scans
            patterns_to_search = patterns + self.universal_patterns
        else:
            patterns_to_search = patterns

        if len(patterns_to_search) == 0:
            return

        with open(output_file, 'r') as f:
            for line in f:
                for pattern in patterns_to_search:
                    description = pattern['description']
                    pattern_regex = pattern['pattern'].lower()
                    match = Utils.extract_matching_string(pattern_regex, line)
                    if match:
                        description = description.format(**service_data, match=match)
                        if output_file in self.pattern_matches:
                            self.pattern_matches[output_file].add(description)
                        else:
                            self.pattern_matches[output_file] = {description}
        self.log_recorded_patterns()

    def log_recorded_patterns(self) -> None:
        """
        Records patterns which were saved
        :return:
        """
        if len(self.pattern_matches) == 0:
            return

        pattern_description_shown = set()

        self.process_lock.acquire()
        with open(self.patterns_detected_log, 'w') as f:
            for output_file in self.pattern_matches:
                unrepeated_pattern_present = False
                message = f"[*] Pattern/s detected in: '{output_file}'\n"
                for description in self.pattern_matches[output_file]:
                    if description in pattern_description_shown:
                        continue
                    pattern_description_shown.add(description)
                    message += f"\t[-] {description}\n"
                    unrepeated_pattern_present = True
                message += "\n"

                if unrepeated_pattern_present:
                    f.write(message)
        self.process_lock.release()

    def run_command(self, command: str, should_log=True) -> None:
        """
        Executes a command
        :param command:
        :param should_log:
        :return:
        """
        if should_log:
            self.log_command(command)

        process = Popen(command, shell=True, stdout=PIPE, stderr=PIPE)
        _, __ = process.communicate()

    def record_process_reference(self, target, arguments, pr):
        """
        Registers a started process with the process manager so its
        progress (timeout, retries) can be tracked

        :param target: the function being called
        :param arguments: the arguments function is being called with
        :param pr: process object
        :return:
        """
        deadline = int(time.time()) + self.command_timeout

        self.process_manager.append_process(
            ProcessReference(
                target=target,
                args=arguments,
                kill_after_time=deadline,
                process_object=pr,
                retry_max=self.retry_max,
                command_pids=[str(pr.pid)],
            )
        )

    def wait_for_all_processes(self):
        """
        This blocks execution until all the spawned processes and threads have completed
        :return:
        """
        if not self.use_processes:
            # if processes are not being used,
            # this can be skipped
            return

        for pr in self.process_manager.processes:
            pr.process_object.join()

        # let heart beat thread complete after process
        self.process_manager.heart_beat_process.join()
        print('\n\n')
        Utils.print_time()

    def execute_secondary_enumerate_command(
            self, command: str,
            service_data: dict,
            patterns: List[dict]
    ) -> None:
        """
        Executes command for secondary enumeration
        on text-based file outputs

        :param command:
        :param service_data:
        :param patterns:
        :return:
        """
        command = command.format(**service_data)
        if self.command_has_been_run(command):
            return

        text_file_related_parameters = ['-on', 'tee', '-o', '--simple-report']
        output_file = Utils.determine_output_file(command, text_file_related_parameters)
        if Utils.should_run_command(output_file):
            time_before_command = int(time.time())
            command_printed = (
                f"SECONDARY ENUMERATE: [at {time_before_command}]\n"
                f"\t{Colors.GREEN}{command}{Colors.END}"
            )
            self.process_print(command_printed)
            self.run_command(command)
            time_after_command = int(time.time())
            self.process_print(
                f"{Colors.MAGENTA}{Colors.BOLD}Command Completed{Colors.END} "
                f"[at {time_after_command}]\n"
            )

        if output_file is None:
            return

        self.record_pattern_matches(output_file, patterns, service_data)

    def execute_primary_enumerate_command(
            self, command: str,
            service_data: dict,
    ) -> None:
        """
        Executes command for primary enumeration
        on xml-based file outputs

        :param command:
        :param service_data:
        :return:
        """
        command = command.format(**service_data)
        if self.command_has_been_run(command):
            return

        xml_file_related_parameters = ['-ox']
        output_file = Utils.determine_output_file(command, xml_file_related_parameters)
        if Utils.should_run_command(output_file):
            time_before_command = int(time.time())
            command_printed = f"PRIMARY ENUMERATE: [at {time_before_command}]\n\t{Colors.GREEN}{command}{Colors.END}"
            self.process_print(command_printed)
            self.run_command(command)
            time_after_command = int(time.time())
            self.process_print(
                f"{Colors.MAGENTA}{Colors.BOLD}Primary Command Completed{Colors.END} "
                f"[at {time_after_command}]\n"
            )

        if output_file is None:
            return

        self.secondary_enumerate(output_file)

    def secondary_enumerate(self, xml_output_file: str) -> None:
        """
        Does secondary enumeration based on sub-scans in
        the service-scans.toml in config folder and xml-output files
        :param xml_output_file::
        :return:
        """
        if xml_output_file is None:
            return

        infile = open(xml_output_file, "r")
        contents = infile.read()
        soup = BeautifulSoup(contents, 'xml')
        ports = soup.find_all('port')

        for port in ports:
            state = port.find('state')
            if state['state'] == 'closed':
                continue

            service = port.find('service')
            if service is None:
                continue

            service_scan_data = {
                'nmap_extra': self.nmap_extra,
                'scandir': self.scan_directory,
                'target_address': self.target_address,
                'port': port['portid'],
                'protocol': port['protocol'],
                'name': service['name'],
                'username_wordlist': self.services_scan_config['username_wordlist'],
                'password_wordlist': self.services_scan_config['password_wordlist'],
                'secure': True if 'ssl' in service or 'tls' in service else False,
                'scheme': 'https' if 'https' in service or 'ssl' in service or 'tls' in service else 'http'
            }

            self.detected_services.add(
                f"[*] {port['portid']}/{port['protocol']}: {service['name']}    ({state['state']})"
            )
            for service_name, scan in self.services_scan_config.items():
                if type(scan) is not dict:
                    continue
                scan_regex_combined = "(" + ")|(".join(scan['service-names']) + ")"
                match = Utils.extract_matching_string(scan_regex_combined, service['name'])
                if match is None:
                    continue

                if 'exploit' in scan:
                    self.log_exploit_steps(scan['exploit'], service_scan_data)

                if 'scan' not in scan:
                    continue

                for sub_scan in scan['scan']:
                    command = sub_scan['command']
                    patterns = sub_scan.get('pattern', [])
                    if self.use_processes:
                        target = self.execute_secondary_enumerate_command
                        arguments = (command, service_scan_data, patterns)

                        pr = Process(target=target, args=arguments)
                        pr.start()

                        self.record_process_reference(
                            target=target, arguments=arguments, pr=pr)
                    else:
                        self.execute_secondary_enumerate_command(
                            command, service_scan_data, patterns
                        )

            self.log_detected_services()

    def enumerate(self) -> None:
        """
        Does primary enumeration: prints the time, then runs the command
        of every sub-scan configured in port-scans.toml under the
        user-selected scan type
        :return:
        """
        Utils.print_time()
        print('\n')

        # values the port scan command templates may refer to
        port_scan_data = dict(
            nmap_extra=self.nmap_extra,
            scandir=self.scan_directory,
            target_address=self.target_address,
            ports=self.ports,
        )

        selected_scans = self.port_scan_config[self.port_scan_type]
        for scan in selected_scans.values():
            for sub_scan in scan.values():
                self.execute_primary_enumerate_command(
                    sub_scan['command'], port_scan_data
                )

    def upload_enumeration_files(self):
        """
        Uploads the enumeration files (xml scan results and text outputs)
        to the api and stores the session_id the api answers with.

        When an xml file and a text file share the same name apart from
        the extension, only the xml file is uploaded.

        :raises Utils.AutoHackException: when the attacker address is
            neither configured nor detectable
        """
        xml_files = glob(f'{self.scan_directory}/xml/*.xml')
        text_files = glob(f'{self.scan_directory}/*.txt')
        files_to_upload = {}
        files_uploaded_without_extensions = set()

        # xml files come first so they win over text files of the same name
        for file_path in xml_files + text_files:
            file_name_with_extension = os.path.basename(file_path)
            file_name_without_extension = os.path.splitext(file_name_with_extension)[0]

            if file_name_without_extension in files_uploaded_without_extensions:
                continue

            files_uploaded_without_extensions.add(file_name_without_extension)
            # the with-block closes each file; they used to be left open
            with open(file_path, 'rb') as f:
                files_to_upload[file_name_with_extension] = f.read()

        url = f'{self.api_source}/analyze-enumerate-results'
        attacker_address = self.attacker_address or Utils.get_attacker_address()
        if attacker_address is None:
            raise Utils.AutoHackException('Attacker Address Must Be Defined')

        headers = {
            'API_KEY': self.api_key,
            'target_address': self.target_address,
            # reuse the validated address; detecting it a second time could
            # yield None after the check above had already passed
            'attacker_address': attacker_address,
            'scandir': self.scan_directory,
        }
        response = requests.post(url, files=files_to_upload, headers=headers)
        response_data = response.json()
        self.session_id = response_data['session_id']

    def get_best_attack(self) -> dict:
        """
        Asks the api for the current best attack of this session
        and returns the decoded json answer
        """
        request_headers = dict(
            API_KEY=self.api_key,
            session_id=self.session_id,
        )
        endpoint = f'{self.api_source}/get-best-attack'
        return requests.post(endpoint, headers=request_headers).json()

    def execute_attack(self, attack: dict):
        """
            Runs commands in attack via screen. Some commands
            may block subsequent commands.

            Every command is written to a file of its own and started in
            a detached screen session with a unique name. The session
            name of the command at the spearhead index is remembered in
            self.spearhead_index_unique_id.

            :param attack: attack description as returned by the api
        """
        commands = attack['attack_commands']['command_group']
        is_blocking = attack['attack_commands']['meta']['command_should_block']
        vul_name = attack['attack_commands']['vulnerability_name'].strip()
        spearhead_index = attack['attack_commands']['meta']['spearhead_index']

        print(f'{Colors.CYAN}RUNNING ATTACK COMMANDS FOR:{Colors.END} {vul_name}:\n')

        for index, command in enumerate(commands):
            unique_screen_name = str(uuid.uuid4())
            if spearhead_index == index:
                self.spearhead_index_unique_id = unique_screen_name

            command_file_path = os.path.join(
                self.exploit_commands_attempted_directory,
                f'{unique_screen_name}.txt'
            )
            with open(command_file_path, "w") as file:
                file.write(command)

            # quote the path: it derives from the output directory and broke
            # the `cat` below whenever it contained spaces
            quoted_command_file_path = shlex.quote(command_file_path)
            screen_command = f'screen -dmS {unique_screen_name} bash -c "$(cat {quoted_command_file_path})"'

            target = self.run_command
            arguments = (screen_command, False)

            time_before_command = int(time.time())
            command_printed = (
                f"\tEXPLOIT COMMAND-{index} [at {time_before_command}]\n"
                f"\t\t{Colors.GREEN}{command}{Colors.END}\n"
            )
            print(command_printed)

            p = Process(target=target, args=arguments)
            p.start()
            p.join()

            if is_blocking[index]:
                # wait for screen command to finish
                blocking_command = f"tail --pid=$(pgrep -f {unique_screen_name}) -f /dev/null"
                self.run_command(blocking_command, False)

            # remove command file to clean up directory
            os.remove(command_file_path)

    def attack_is_successful(self, attack) -> bool:
        """
        Examines main attack screen output to see if attack was successful.
            * Creates file with screen contents
            * Grabs file contents
            * Sends file contents to server
        :param attack: the attack that was executed (not read by this method)
        :return: the 'is_successful' value of the Api response
        """
        message = f"\t{Colors.MAGENTA}VERIFYING ATTACK SUCCESS{Colors.END} [at {int(time.time())}]"
        print(message)

        # slow execution to ensure attack ran
        time.sleep(2)

        output_file_path = os.path.join(self.exploit_directory, f'{self.spearhead_index_unique_id}.txt')
        grab_screen_command = f"screen -r {self.spearhead_index_unique_id} -X hardcopy {output_file_path}"

        # create file with contents
        self.run_command(grab_screen_command, False)

        # read inside a context manager so the handle is always closed
        with open(output_file_path, 'r') as output_file:
            main_screen_output = output_file.read()

        url = f'{self.api_source}/verify-best-attack-success'
        headers = {
            'API_KEY': self.api_key,
            'session_id': self.session_id,
        }
        response = requests.post(url, json=main_screen_output, headers=headers)
        response_data = response.json()

        return response_data['is_successful']

    def attack_repeatedly(self):
        """
        Uploads the enumeration output once, then keeps requesting and
        executing the best attack until either
            (1) an attack is verified as successful
                        or
            (2) the Api reports that there is nothing left to try
        :return:
        """
        self.upload_enumeration_files()

        candidate = self.get_best_attack()
        while candidate['should_continue']:
            self.execute_attack(candidate)
            if self.attack_is_successful(candidate):
                print(
                    f"\t{Colors.GREEN}{Colors.BOLD}ATTACK SUCCESSFUL{Colors.END} "
                    f"[at {int(time.time())}]\n"
                )
                self.success_status = SuccessStatus.EXPLOIT_SUCCESSFUL
                return
            candidate = self.get_best_attack()

        # the Api has no further attack to offer
        self.should_kill_heart_beat = True
        print('\n\nNo Exploit Discovered: Ending Program')

    def get_is_user_root_commands(self) -> dict:
        """
        Asks the Api for the commands, and their expected answers,
        used to verify root access
        :return: the json-decoded body of the Api response
        """
        auth_headers = {'API_KEY': self.api_key, 'session_id': self.session_id}
        endpoint = f'{self.api_source}/get-is-user-root-commands'
        return requests.post(endpoint, headers=auth_headers).json()

    def user_is_root(self) -> bool:
        for key_command in self.expected_command_answers_to_verify_root_access:
            """
                * Runs verification command on spearhead terminal session
                * Runs grab screen session command
                * Examines the output for expected command answers
            """
            data = self.expected_command_answers_to_verify_root_access[key_command]
            expected_output_when_root = data['answer']
            command_wait_time_in_seconds = data['wait']

            verify_command = f'screen -S {self.spearhead_index_unique_id} -X stuff "{key_command}"'
            self.run_command(verify_command, False)

            # delay so command will run
            time.sleep(command_wait_time_in_seconds)

            output_file_path = os.path.join(
                self.priv_directory,
                f'{self.spearhead_index_unique_id}.txt'
            )

            # use the current output file size so we can get the newly written content
            output_filesize_before_running_command = 0
            if os.path.exists(output_file_path):
                output_filesize_before_running_command = os.path.getsize(output_file_path)

            grab_screen_command = f"screen -r {self.spearhead_index_unique_id} -X hardcopy {output_file_path}"
            self.run_command(grab_screen_command, False)

            output_file_handle = open(output_file_path, 'r')
            output_file_handle.seek(output_filesize_before_running_command)
            output_produced_by_command = output_file_handle.read()

            if expected_output_when_root in output_produced_by_command:
                self.success_status = SuccessStatus.PRIVILEGE_ESCALATE_SUCCESSFUL
                return True

        return False

    @staticmethod
    def print_user_is_root_successful():
        """
        Prints the green, bold notice that root access was verified
        :return:
        """
        print(f'\t{Colors.GREEN}{Colors.BOLD}ROOT ACCESS SUCCESSFUL{Colors.END}')

    @staticmethod
    def print_user_is_root_unsuccessful():
        """
        Prints the red, bold notice that root access was not verified
        :return:
        """
        print(f'\t{Colors.RED}{Colors.BOLD}ROOT ACCESS UNSUCCESSFUL{Colors.END}')

    def privilege_escalate_repeatedly(self):
        """
        Loads the root verification commands, checks for root access
        and, when the user is not root yet, re-checks a bounded number
        of times (enumeration and attack steps are still todo)
        :return:
        """
        # load the expected root command-answers once
        self.expected_command_answers_to_verify_root_access = self.get_is_user_root_commands()

        print(f'\n\n{Colors.CYAN}VERIFYING ROOT ACCESS:{Colors.END}')

        if self.user_is_root():
            self.print_user_is_root_successful()
            return
        self.print_user_is_root_unsuccessful()

        # todo: enumerate target os
        # todo: upload target os enumeration data

        # repeatedly exploit and check for root access; root access is
        # checked at most seven times before giving up
        max_root_checks = 7
        for _ in range(max_root_checks):
            # todo: attack
            if self.user_is_root():
                break

    def handle_success(self):
        """
        Prints the screen command that attaches to the spearhead session,
        labelled as a root or non-root terminal depending on
        success_status. Prints nothing for any other status.
        :return:
        """
        status = self.success_status
        if status == SuccessStatus.PRIVILEGE_ESCALATE_SUCCESSFUL:
            terminal_kind = 'ROOT'
        elif status == SuccessStatus.EXPLOIT_SUCCESSFUL:
            terminal_kind = 'NON-ROOT'
        else:
            return

        headline = (
            f'{Colors.CYAN}{Colors.BOLD}'
            f'TO ACCESS TARGET {terminal_kind} TERMINAL, RUN COMMAND:'
            f'{Colors.END}'
        )
        attach_command = (
            f'{Colors.GREEN}{Colors.BOLD}'
            f'\tscreen -r {self.spearhead_index_unique_id}'
            f'{Colors.END}'
        )
        print(f'\n\n\n{headline}\n{attach_command}')

    def exploit(self) -> None:
        """
        When run in professional mode, code exploits based on
        commands generated by AutoHack's API
        :return:
        """
        if self.professional:
            self.attack_repeatedly()

            if self.success_status == SuccessStatus.EXPLOIT_SUCCESSFUL:
                self.privilege_escalate_repeatedly()
                self.handle_success()

        AutoHack.should_kill_heart_beat = True


if __name__ == "__main__":
    # Command line entry point: parses the arguments, builds an AutoHack
    # instance, then enumerates, exploits and waits for all processes.
    #
    # * Primary Enumeration is based on the:
    #     -user-selected port scan type
    #     -service scans in autohack_config.
    #
    #     This creates xml files which are:
    #         -(secondary) enumerated
    #
    # * Secondary Enumeration is based on the:
    #     -xml output from Primary Enumeration
    #     -service sub-scans in autohack_config.
    #
    #     This creates text files which are:
    #         -analyzed for regex patterns
    parser = argparse.ArgumentParser(description='Does Enumeration')
    parser.add_argument(
        '-i', '--ip_address', action='store', type=str,
        dest='ip_address', help='IP Address', required=True
    )
    parser.add_argument(
        '-ai', '--attacker_ip_address', action='store', type=str,
        dest='attacker_ip_address', required=False,
        help='Attackers IP Address, Used When Crafting Reverse Shell Payload'
    )
    parser.add_argument(
        '-c', '--config', action='store', type=str,
        dest='config_folder', help='Config Folder'
    )
    parser.add_argument(
        '-o', '--output', action='store', type=str,
        dest='output_folder', help='Output Folder',
    )
    parser.add_argument(
        '--port_scan_type', action='store', type=str,
        default="full", choices=["quick", "udp", "full"],
        dest='port_scan_type', help='Select The Type of Port Scan'
    )
    parser.add_argument(
        '-ne', '--nmap_extra', action='store', type=str,
        dest='nmap_extra', help='Additional Nmap Args', default='-Pn'
    )
    parser.add_argument(
        '-p', '--ports', action='store', type=str,
        dest='ports', help='Ports to Scan', default='80'
    )
    # BooleanOptionalAction flags take no value, so no `type` is passed:
    # that parameter is deprecated for this action since Python 3.12 and
    # rejected from Python 3.14 on.
    parser.add_argument(
        '--no_processes', action=argparse.BooleanOptionalAction,
        dest='no_processes', help='Should Not Use Processes For Secondary Enumeration',
        default=False
    )
    parser.add_argument(
        '-v', '--verbose', action=argparse.BooleanOptionalAction,
        dest='verbose', help='Run In Verbose Mode',
        default=False
    )
    parser.add_argument(
        '-ct', '--command_timeout', action='store', type=int,
        dest='command_timeout', help='Max Seconds Command Runs Before Being Killed',
        default=1800
    )
    parser.add_argument(
        '-r', '--retry_max', action='store', type=int,
        dest='retry_max', help='Max Number of Times To Retry A Command',
        default=1
    )
    parser.add_argument(
        '-hbt', '--heart_beat_time', action='store', type=int,
        dest='heart_beat_time', help='Seconds Between Heart Beats',
        default=60
    )
    parser.add_argument(
        '-pro', '--professional',
        action=argparse.BooleanOptionalAction, dest='professional',
        help="Uses AutoHack's API to Iteratively Generate and Run Commands "
             "that Attack Target Machine Based on Previous Command-Output Files. "
             "Only Use This If You Have Permission to Penetrate the Target Machine.",
    )
    args = parser.parse_args()
    autohack = AutoHack(
        ip_address=args.ip_address,
        attacker_ip_address=args.attacker_ip_address,
        output_directory=args.output_folder,
        config_directory=args.config_folder,
        port_scan_type=args.port_scan_type,
        nmap_extra=args.nmap_extra,
        ports=args.ports,
        use_processes=not args.no_processes,
        verbose=args.verbose,
        command_timeout=args.command_timeout,
        heart_beat_time=args.heart_beat_time,
        retry_max=args.retry_max,
        professional=args.professional
    )

    autohack.enumerate()
    autohack.exploit()
    autohack.wait_for_all_processes()

    print('AutoHack Complete')
