#!/usr/bin/env python
# Copyright (C) 2013  Lukas Rist <glaslos@gmail.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import gevent.monkey
gevent.monkey.patch_all()

import logging
import os
import argparse
import sys
import pwd
import grp
import platform
import ast
from ConfigParser import ConfigParser

import gevent
from lxml import etree

import conpot
import conpot.core as conpot_core

from conpot.core.loggers.log_worker import LogWorker
from conpot.protocols.snmp.snmp_server import SNMPServer
from conpot.protocols.modbus.modbus_server import ModbusServer
from conpot.protocols.s7comm.s7_server import S7Server
from conpot.protocols.http.web_server import HTTPServer
from conpot.protocols.kamstrup.meter_protocol.kamstrup_server import KamstrupServer
from conpot.protocols.kamstrup.management_protocol.kamstrup_management_server import KamstrupManagementServer

from conpot.emulators.proxy import Proxy
from conpot.utils import ext_ip


# Root logger (getLogger() without a name); setup_logging() attaches its
# handlers to this same logger object.
logger = logging.getLogger()
# Absolute directory of the imported conpot package. Used below to locate the
# bundled template.xsd, templates/, www/ and conpot.cfg files.
package_directory = os.path.dirname(os.path.abspath(conpot.__file__))


def logo():
    """Print the Conpot ASCII-art banner together with the package version."""
    banner = """
                       _
   ___ ___ ___ ___ ___| |_
  |  _| . |   | . | . |  _|
  |___|___|_|_|  _|___|_|
              |_|

  Version {0}
  Glastopf Project
"""
    # A single parenthesised expression prints the same text as the bare
    # print statement would.
    print(banner.format(conpot.__version__))


def setup_logging(log_file):
    """Send log records to both the console and ``log_file``.

    Both handlers are attached to the root logger, accept everything from
    DEBUG upwards and share the same ``<timestamp> <message>`` format.

    :param log_file: path of the file the file handler writes to.
    """
    formatter = logging.Formatter('%(asctime)-15s %(message)s')
    root = logging.getLogger()

    def attach(handler):
        # Identical configuration for every handler we register.
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        root.addHandler(handler)

    # Console first, so it is already registered if opening the log file
    # fails afterwards.
    attach(logging.StreamHandler())
    logger.setLevel(logging.DEBUG)
    attach(logging.FileHandler(log_file))


def drop_privileges(uid_name='nobody', gid_name='nogroup'):
    """Switch the process to an unprivileged user and group.

    The group is changed before the user, because the process can no longer
    change its group once it has given up its user id.

    :param uid_name: name of the user to run as.
    :param gid_name: name of the group to run as.
    :raises KeyError: if the user or group name is unknown to the system.
    :raises OSError: if the process is not allowed to change its ids.
    """
    wanted_uid = pwd.getpwnam(uid_name)[2]

    # Special handling for OS X older than 10.9 (getgrnam has trouble with
    # gids below 0). mac_ver() reports the release as a string such as
    # '10.8.5' (empty on other platforms), so it has to be compared
    # numerically, component by component: comparing the string with a float
    # never selects this branch, and a float would also sort '10.10' before
    # '10.9'.
    mac_release = platform.mac_ver()[0]
    legacy_osx = False
    if mac_release:
        try:
            release = tuple(int(part) for part in mac_release.split('.')[:2])
        except ValueError:
            # Unparseable release string: fall back to the regular lookup.
            release = None
        legacy_osx = release is not None and release < (10, 9)

    if legacy_osx:
        wanted_gid = -2
    else:
        wanted_gid = grp.getgrnam(gid_name)[2]

    # NOTE(review): supplementary groups are left untouched here (no
    # os.setgroups call) - confirm whether they should be cleared as well.
    os.setgid(wanted_gid)
    os.setuid(wanted_uid)

    # Read the ids back from the OS so the log shows what is really in effect.
    new_uid_name = pwd.getpwuid(os.getuid())[0]
    new_gid_name = grp.getgrgid(os.getgid())[0]

    logger.info("Privileges dropped, running as {0}/{1}.".format(new_uid_name, new_gid_name))


def validate_template(xml_file):
    """Validate a template against the ``template.xsd`` shipped with conpot.

    Logs the schema error log and terminates the process with exit status 1
    when the validator recorded any error; returns ``None`` otherwise.

    :param xml_file: path of the XML template to check.
    """
    schema_path = os.path.join(package_directory, 'template.xsd')
    schema = etree.XMLSchema(etree.parse(schema_path))
    schema.validate(etree.parse(xml_file))

    if not schema.error_log:
        return

    logger.error('Error parsing XML template: {0}'.format(schema.error_log))
    sys.exit(1)

def _protocol_endpoint(dom, protocol):
    """Look up where a protocol server should listen, according to the template.

    :param dom: parsed template tree that is queried with xpath.
    :param protocol: element name below ``conpot_template/protocols``.
    :return: ``(host, port)`` if the element exists and its ``enabled``
             attribute evaluates to a true value, otherwise ``None``.
    """
    base = '//conpot_template/protocols/{0}'.format(protocol)
    if not dom.xpath(base):
        return None
    if not ast.literal_eval(dom.xpath(base + '/@enabled')[0]):
        return None
    host = dom.xpath(base + '/@host')[0]
    port = ast.literal_eval(dom.xpath(base + '/@port')[0])
    return host, port


def main():
    """Command line entry point: parse options, start all services and wait.

    Steps, in order: print the banner, parse the command line, set up
    logging, read the configuration, locate and validate the template, start
    the log worker, spawn one greenlet per enabled protocol server and per
    configured proxy, drop root privileges and finally block until the
    services end or the user interrupts the process.

    Exits with status 1 if the template or the configuration file cannot be
    found, or if the template does not validate.
    """
    logo()

    parser = argparse.ArgumentParser()

    parser.add_argument("-w", "--www",
                        help="public www path",
                        default=os.path.join(package_directory, 'www/'),
                        metavar="www/"
    )
    parser.add_argument("-r", "--www-root",
                        help="www root page",
                        default="index.htm",
                        metavar="page.htm"
    )
    parser.add_argument("-t", "--template",
                        help="the template to use",
                        default=os.path.join(package_directory, 'templates/default.xml'),
                        metavar="template.xml"
    )
    parser.add_argument("-c", "--config",
                        help="the configuration file to use",
                        default=os.path.join(package_directory, 'conpot.cfg'),
                        metavar="config.cfg"
    )
    parser.add_argument("-l", "--logfile",
                        help="the logfile to use",
                        default="conpot.log"
    )
    parser.add_argument("-a", "--raw_mib",
                        help="path to raw MIB files."
                             "(will automatically get compiled by build-pysnmp-mib)",
                        action='append',
                        default=[os.getcwd()])
    parser.add_argument("-m", "--mibpaths",
                        action='append',
                        help="path to compiled PySNMP MIB files. (must be compiled with build-pysnmp-mib)",
                        default=[])
    args = parser.parse_args()

    setup_logging(args.logfile)

    # Fall back to the configuration bundled with the package when the
    # requested file does not exist.
    config = ConfigParser()
    if not os.path.isfile(args.config):
        args.config = os.path.join(package_directory, 'conpot.cfg')
        logger.info('No conpot.cfg found in current directory, using default configuration: {0}'.format(args.config))
    config.read(args.config)

    # Use the template path as given if it exists, otherwise look for a
    # template of that name in the package's templates directory.
    if not os.path.isfile(args.template):
        if os.path.isfile(os.path.join(package_directory, 'templates', args.template)):
            template = os.path.join(package_directory, 'templates', args.template)
        else:
            logger.error('Template not found: {0}'.format(args.template))
            sys.exit(1)
    else:
        template = args.template
    if not os.path.isfile(args.config):
        logger.error('Config not found: {0}'.format(args.config))
        sys.exit(1)

    logger.info('Starting Conpot using template found in: {0}'.format(template))
    logger.info('Starting Conpot using configuration found in: {0}'.format(args.config))
    logger.info('Starting Conpot using www templates found in: {0}'.format(args.www))

    protocol_greenlets = list()
    servers = list()

    def launch(server):
        # Remember the server for shutdown and run it in its own greenlet.
        servers.append(server)
        protocol_greenlets.append(gevent.spawn(server.start))

    # Validate before anything consumes the template, so a malformed file is
    # reported as a schema error instead of failing somewhere downstream.
    validate_template(template)

    session_manager = conpot_core.get_sessionManager()
    session_manager.initialize_databus(template)

    public_ip = None
    if config.getboolean('fetch_public_ip', 'enabled'):
        public_ip = ext_ip.get_ext_ip(config)

    dom = etree.parse(template)

    log_worker = LogWorker(config, dom, session_manager, public_ip)
    gevent.spawn(log_worker.start)

    endpoint = _protocol_endpoint(dom, 'modbus')
    if endpoint is not None:
        modbus_host, modbus_port = endpoint
        modbus_instance = ModbusServer(template)
        launch(modbus_instance.get_server(modbus_host, modbus_port))

    endpoint = _protocol_endpoint(dom, 'snmp')
    if endpoint is not None:
        snmp_host, snmp_port = endpoint
        launch(SNMPServer(snmp_host, snmp_port, template, args.mibpaths, args.raw_mib))

    endpoint = _protocol_endpoint(dom, 's7comm')
    if endpoint is not None:
        s7_host, s7_port = endpoint
        s7_instance = S7Server(template)
        launch(s7_instance.get_server(s7_host, s7_port))

    endpoint = _protocol_endpoint(dom, 'http')
    if endpoint is not None:
        http_host, http_port = endpoint
        launch(HTTPServer(http_host, http_port, template, args.www))

    endpoint = _protocol_endpoint(dom, 'kamstrup')
    if endpoint is not None:
        host, port = endpoint
        kamstrup_instance = KamstrupServer(template)
        launch(kamstrup_instance.get_server(host, port))

    endpoint = _protocol_endpoint(dom, 'kamstrup_management')
    if endpoint is not None:
        host, port = endpoint
        kamstrup_management_instance = KamstrupManagementServer(template)
        launch(kamstrup_management_instance.get_server(host, port))

    template_dir = os.path.dirname(template)
    proxies = dom.xpath('//conpot_template/proxies/*')
    for p in proxies:
        name = p.attrib['name']
        host = p.attrib['host']
        keyfile = None
        certfile = None
        if 'keyfile' in p.attrib and 'certfile' in p.attrib:
            keyfile = p.attrib['keyfile']
            certfile = p.attrib['certfile']
            # A path that is not absolute is taken to be relative to the
            # directory of the template. Each path is checked on its own, so
            # an absolute key combined with a relative certificate works too.
            if not os.path.isabs(keyfile):
                keyfile = os.path.join(template_dir, keyfile)
            if not os.path.isabs(certfile):
                certfile = os.path.join(template_dir, certfile)
        port = ast.literal_eval(p.attrib['port'])
        proxy_host = p.xpath('./proxy_host/text()')[0]
        proxy_port = ast.literal_eval(p.xpath('./proxy_port/text()')[0])
        # The decoder element is optional.
        decoder = p.xpath('./decoder/text()')
        if decoder:
            decoder = decoder[0]
        else:
            decoder = None
        proxy_instance = Proxy(name, proxy_host, proxy_port, decoder, keyfile, certfile)
        proxy_server = proxy_instance.get_server(host, port)
        # The Proxy object (not the server it returned) is what gets stopped
        # on shutdown, while the returned server is what gets started.
        servers.append(proxy_instance)
        protocol_greenlets.append(gevent.spawn(proxy_server.start))

    # Wait for the services to bind ports before dropping privileges
    gevent.sleep(5)
    if os.getuid() == 0:
        drop_privileges()
    else:
        # Changing uid/gid is not permitted without root and would abort the
        # process here, after the services have already been spawned.
        logger.info('Not running as root, no privileges to drop.')
    try:
        if protocol_greenlets:
            gevent.wait()
    except KeyboardInterrupt:
        logger.info('Stopping Conpot')
        for server in servers:
            server.stop()
    finally:
        # Just being nice and cosy!
        gevent.joinall(protocol_greenlets, 2)


# Run the honeypot only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
