#! /usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015 Université Catholique de Louvain.
#
# This file is part of INGInious.
#
# INGInious is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# INGInious is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License along with INGInious.  If not, see <http://www.gnu.org/licenses/>.
import argparse
import os
import select
import shutil
import socket
import sys
import tempfile
from subprocess import Popen

if __name__ == "__main__":
    # Command-line interface: a single positional "host:port" argument.
    parser = argparse.ArgumentParser()
    parser.add_argument("host", help="The host to connect to. It can be an agent or a frontend, since both respect the same specifications. Should "
                                     "be in the form host:port")
    args = parser.parse_args()

    # Split the argument into the host name and the numeric port. Only the
    # two failures that a malformed value can cause are handled here.
    try:
        host_parts = args.host.split(":")
        remote_host = host_parts[0]
        remote_port = int(host_parts[1])
    except (IndexError, ValueError):
        print("Invalid value for host")
        sys.exit(1)

    # Standard input protocol: the first line is the connection id, the
    # following lines are the private key, terminated by the PEM end marker.
    conn_id = sys.stdin.readline().strip()
    private_key = []
    while len(private_key) == 0 or private_key[-1] != "-----END RSA PRIVATE KEY-----":
        line = sys.stdin.readline()
        if line == "":
            # readline() returns an empty string only at end of file; without
            # this check the loop would never terminate on truncated input.
            sys.stderr.write("Unexpected end of input while reading the private key\n")
            sys.exit(1)
        private_key.append(line.strip())
    private_key = "\n".join(private_key)

    # Store the key in a fresh private directory. The file is created with
    # owner-only permissions from the start instead of being chmod'ed later.
    tmpdir = tempfile.mkdtemp()
    key_path = os.path.join(tmpdir, "key")
    key_fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    with os.fdopen(key_fd, "w") as f:
        f.write(private_key)
    # Enforce the final mode explicitly, whatever the process umask is.
    os.chmod(key_path, 0o600)

    # Listen on an ephemeral loopback port; the local ssh client is pointed
    # at this port and its traffic is relayed to the remote host.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))
    ssh_port = server.getsockname()[1]
    server.listen(1)
    
    def run_relay(conn):
        """
        Relay data between a locally accepted connection and the remote host.

        Opens a TCP connection to (remote_host, remote_port), sends the
        connection id followed by a newline, and waits for a three-byte
        answer. If the answer is "ok\\n", bytes are then copied in both
        directions between ``conn`` and the remote socket until one side
        closes or reports an error. Any other answer (or a closed
        connection) terminates the script with exit status 1.

        Both sockets are always closed before this function returns or
        raises.

        :param conn: connected socket, as returned by ``server.accept()``
        """
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client.connect((remote_host, remote_port))
            print("Connected to frontend/agent")
            # sendall() keeps writing until the whole buffer is sent; a plain
            # send() may transmit only part of it.
            client.sendall(conn_id + "\n")
            print("Conn id sent to frontend/agent")

            # Read exactly the three bytes of the answer ("ok\n" or "ko\n"),
            # never more, so that no relayed payload is consumed here. An
            # empty read means the peer closed the connection.
            retval = ""
            while len(retval) < 3:
                chunk = client.recv(3 - len(retval))
                if not chunk:
                    break
                retval += chunk
            if retval != "ok\n":
                print("Cannot connect to the remote ssh server. Invalid connection id.")
                # SystemExit is not caught below, so the script really exits.
                sys.exit(1)
            print("Received ok")

            # Maps each socket to the one its data must be forwarded to.
            peers = {conn: client, client: conn}
            relaying = True
            while relaying:
                readable, _, errored = select.select([conn, client], [], [conn, client])
                if errored:
                    break
                for source in readable:
                    data = source.recv(1024)
                    if not data:
                        # Empty read: this side closed, stop relaying.
                        relaying = False
                        break
                    peers[source].sendall(data)
        except (socket.error, select.error) as err:
            # A network failure simply ends the relay; report it instead of
            # hiding it.
            sys.stderr.write("Relay stopped: %s\n" % err)
        finally:
            client.close()
            conn.close()

    # Command line of the local ssh client: it connects to the loopback
    # relay port using the key written to the temporary directory. Host key
    # checking is disabled and no known_hosts file is used.
    ssh_args = ['ssh', '-o', 'UserKnownHostsFile=/dev/null',
                '-o', 'StrictHostKeyChecking=no',
                '-p', str(ssh_port),
                '-i', os.path.join(tmpdir, "key"),
                'worker@127.0.0.1']

    try:
        # ssh shares this script's standard input and output.
        p = Popen(ssh_args, stdin=sys.stdin, stdout=sys.stdout)

        # Only one connection (the one from ssh) is ever accepted, so the
        # listening socket is closed right away, even if accept() fails.
        try:
            conn = server.accept()[0]
        finally:
            server.close()
        print("Connection accepted")

        run_relay(conn)

        p.wait()
    finally:
        # Never leave the private key behind on disk.
        shutil.rmtree(tmpdir, ignore_errors=True)
