#!/usr/bin/env python
#
# Sends LLDP packets out of one or more network interfaces (given as a
# comma-separated list), either once or periodically at a fixed interval.
# The process can optionally be daemonized.
#
# Previous version:
#
# GitHub Repository: bigswitch/bigswitchcontroller
# Branch: bigdb
# Path: bare-metal/hypervisor/pcap_start_lldp.py
#

import argparse
import binascii
import glob
import os
import socket
import time

# Destination MAC address placed in the Ethernet header of every LLDP frame
# (the LLDP "nearest bridge" multicast group address).
LLDP_DST_MAC = "01:80:c2:00:00:0e"
# EtherType value that identifies an LLDP frame.
LLDP_ETHERTYPE = 0x88cc
# Value carried in the TTL TLV, in seconds.
TTL = 120
# Default subtype byte for the Chassis ID TLV.
CHASSIS_ID_LOCALLY_ASSIGNED = 7
# Default subtype byte for the Port ID TLV.
PORT_ID_INTERFACE_ALIAS = 1
# sysfs directory holding one entry per network interface.
SYS_CLASS_NET = "/sys/class/net"

def parse_args(argv=None):
    """Parse the command-line arguments of the LLDP sender.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from sys.argv[1:].

    Returns:
        The argparse namespace with the attributes network_interface,
        system_name, system_desc, interval and daemonize.

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid or
            when the mandatory interface option is missing.
    """
    parser = argparse.ArgumentParser(
        description="Send LLDP packets out of one or more network "
                    "interfaces.")

    # LLDP packet arguments
    # The interface list is mandatory: without it there is nothing to send
    # on. The dashed spelling is accepted as an alias so the option matches
    # the style of --system-name / --system-desc.
    parser.add_argument("--network_interface", "--network-interface",
                        dest="network_interface", required=True,
                        help="comma-separated list of interfaces to send "
                             "LLDP packets on")
    parser.add_argument("--system-name",
                        help="value of the optional System Name TLV")
    parser.add_argument("--system-desc",
                        help="value of the optional System Description TLV")

    # Other arguments
    parser.add_argument("-i", "--interval", type=int, default=0,
                        help="seconds between transmissions; 0 (default) "
                             "sends once and exits")
    parser.add_argument("-d", "--daemonize", action="store_true",
                        default=False,
                        help="detach from the terminal before sending")

    return parser.parse_args(argv)

def validate_num_bits_of_int(int_value, num_bits, name=None):
    """Check that int_value fits in an unsigned field of num_bits bits.

    Args:
        int_value: The integer to check.
        num_bits: Width of the field, in bits.
        name: Optional label used in the error message.

    Raises:
        ValueError: If masking int_value to num_bits bits changes it,
            which is the case for negative or too large values.
    """
    all_ones = 2 ** num_bits - 1
    if int_value & all_ones == int_value:
        return

    label = name or "The integer value"
    message = "%s must be %d-bit long. Given: %d (%s)" % (
        label, num_bits, int_value, hex(int_value))
    raise ValueError(message)

def raw_bytes_of_hex_str(hex_str):
    """Convert a string of hexadecimal digits into the raw bytes it encodes.

    Args:
        hex_str: An even-length string of hex digits, e.g. "88cc".

    Returns:
        The decoded byte string (two hex digits per byte).

    Raises:
        TypeError or ValueError: If hex_str has an odd length or contains
            non-hexadecimal characters (the type depends on the Python
            version).
    """
    # str.decode("hex") only exists on Python 2; unhexlify yields the same
    # result there and also works on Python 3.
    return binascii.unhexlify(hex_str)

def raw_bytes_of_mac_str(mac_str):
    """Return the raw bytes of a colon-separated MAC address string."""
    hex_digits = mac_str.replace(":", "")
    return raw_bytes_of_hex_str(hex_digits)

def raw_bytes_of_int(int_value, num_bytes, name=None):
    """Encode int_value as exactly num_bytes bytes, most significant first.

    Args:
        int_value: The integer to encode.
        num_bytes: Number of bytes in the result.
        name: Optional label used in the validation error message.

    Raises:
        ValueError: If int_value does not fit in num_bytes * 8 bits.
    """
    validate_num_bits_of_int(int_value, num_bytes * 8, name)
    # Zero-pad to two hex digits per byte; the width is supplied via "*".
    hex_digits = "%0*x" % (num_bytes * 2, int_value)
    return raw_bytes_of_hex_str(hex_digits)

def get_mac_str(network_interface):
    """Read the MAC address string of an interface from sysfs.

    Args:
        network_interface: Name of the interface, e.g. "eth0".

    Returns:
        The content of the interface's sysfs "address" file with
        surrounding whitespace removed.

    Raises:
        IOError: If the sysfs file cannot be opened or read.
    """
    # Build the path from the shared SYS_CLASS_NET constant so every sysfs
    # lookup in this file uses the same base directory.
    path = "%s/%s/address" % (SYS_CLASS_NET, network_interface)
    with open(path) as f:
        return f.read().strip()

def lldp_ethertype():
    """Return the LLDP EtherType encoded as a 2-byte field."""
    return raw_bytes_of_int(int_value=LLDP_ETHERTYPE,
                            num_bytes=2,
                            name="LLDP ethertype")

def validate_tlv_type(type_):
    """Raise ValueError unless type_ fits in the 7-bit TLV type field."""
    validate_num_bits_of_int(int_value=type_, num_bits=7, name="TLV type")

def validate_tlv_length(length):
    """Raise ValueError unless length fits in the 9-bit TLV length field."""
    validate_num_bits_of_int(int_value=length, num_bits=9, name="TLV length")

def tlv_1st_2nd_bytes_of(type_, length):
    """Build the 2-byte TLV header from a type and a value length.

    The header packs the 7-bit type into the high bits and the 9-bit
    length into the low bits of a 16-bit field.

    Raises:
        ValueError: If type_ or length does not fit in its field.
    """
    validate_tlv_type(type_)
    validate_tlv_length(length)
    # Shift the type past the 9 length bits, then merge the length in.
    header = (type_ << 9) | length
    return raw_bytes_of_int(header, 2, "First 2 bytes of TLV")

def tlv_of(type_, str_value):
    """Build a complete TLV: the 2-byte header followed by the value.

    Args:
        type_: TLV type number (must fit in 7 bits).
        str_value: The TLV value. Byte strings are used as-is; text
            strings are encoded as UTF-8 first.

    Returns:
        The encoded TLV as a byte string.

    Raises:
        ValueError: If the type or the value length does not fit in the
            TLV header.
    """
    if not isinstance(str_value, (bytes, bytearray)):
        # Text cannot be appended to the binary header on Python 3, and the
        # length field must count encoded bytes rather than characters.
        # A Python 2 str is already bytes and is left untouched.
        str_value = str_value.encode("utf-8")
    return tlv_1st_2nd_bytes_of(type_, len(str_value)) + str_value

def chassis_id_tlv_of(chassis_id, subtype=CHASSIS_ID_LOCALLY_ASSIGNED):
    """Build the Chassis ID TLV (type 1).

    Args:
        chassis_id: The chassis identifier. Byte strings are used as-is;
            text strings are encoded as UTF-8 first.
        subtype: Chassis ID subtype, encoded as the first value byte.

    Returns:
        The encoded TLV as a byte string.
    """
    if not isinstance(chassis_id, (bytes, bytearray)):
        # The subtype byte is binary, so a text identifier must be encoded
        # before the two can be concatenated (required on Python 3; a
        # Python 2 str is already bytes and is left untouched).
        chassis_id = chassis_id.encode("utf-8")
    subtype_byte = raw_bytes_of_int(subtype, 1, "Chassis ID subtype")
    return tlv_of(1, subtype_byte + chassis_id)

def port_id_tlv_of(port_id, subtype=PORT_ID_INTERFACE_ALIAS):
    """Build the Port ID TLV (type 2).

    Args:
        port_id: The port identifier. Byte strings are used as-is; text
            strings are encoded as UTF-8 first.
        subtype: Port ID subtype, encoded as the first value byte.

    Returns:
        The encoded TLV as a byte string.
    """
    if not isinstance(port_id, (bytes, bytearray)):
        # The subtype byte is binary, so a text identifier must be encoded
        # before the two can be concatenated (required on Python 3; a
        # Python 2 str is already bytes and is left untouched).
        port_id = port_id.encode("utf-8")
    subtype_byte = raw_bytes_of_int(subtype, 1, "Port ID subtype")
    return tlv_of(2, subtype_byte + port_id)

def ttl_tlv_of(ttl_seconds):
    """Build the TTL TLV (type 3) with ttl_seconds as a 2-byte value."""
    ttl_bytes = raw_bytes_of_int(ttl_seconds, 2, "TTL (seconds)")
    return tlv_of(3, ttl_bytes)

def port_desc_tlv_of(port_desc):
    """Build the Port Description TLV (type 4) carrying port_desc."""
    return tlv_of(type_=4, str_value=port_desc)

def system_name_tlv_of(system_name):
    """Build the System Name TLV (type 5) carrying system_name."""
    return tlv_of(type_=5, str_value=system_name)

def system_desc_tlv_of(system_desc):
    """Build the System Description TLV (type 6) carrying system_desc."""
    return tlv_of(type_=6, str_value=system_desc)

def end_tlv():
    """Build the terminating TLV: type 0 with an empty value."""
    return tlv_of(type_=0, str_value="")

def lldp_frame_of(chassis_id,
                  network_interface,
                  ttl,
                  system_name=None,
                  system_desc=None):
    """Assemble a complete LLDP Ethernet frame for one interface.

    Args:
        chassis_id: Value of the Chassis ID TLV.
        network_interface: Name of the sending interface. It is used as the
            Port ID value, and its MAC address (read from sysfs) becomes
            both the Ethernet source address and the Port Description.
        ttl: Value of the TTL TLV, in seconds.
        system_name: Optional System Name TLV value; omitted when None.
        system_desc: Optional System Description TLV value; omitted when
            None.

    Returns:
        The frame (Ethernet header followed by the TLVs) as a byte string.

    Raises:
        IOError: If the MAC address of network_interface cannot be read.
        ValueError: If a field does not fit into its encoding.
    """
    port_mac_str = get_mac_str(network_interface)
    contents = [
        # Ethernet header
        raw_bytes_of_mac_str(LLDP_DST_MAC),
        raw_bytes_of_mac_str(port_mac_str),
        lldp_ethertype(),

        # Mandatory LLDP TLVs
        chassis_id_tlv_of(chassis_id),
        port_id_tlv_of(network_interface),
        ttl_tlv_of(ttl),

        # Port Description TLV (always sent; carries the port MAC string)
        port_desc_tlv_of(port_mac_str),
    ]

    # Optional LLDP TLVs
    if system_name is not None:
        contents.append(system_name_tlv_of(system_name))
    if system_desc is not None:
        contents.append(system_desc_tlv_of(system_desc))

    # End TLV
    contents.append(end_tlv())

    # All parts are binary; a bytes separator is identical to "" on
    # Python 2 and is required to join bytes on Python 3.
    return b"".join(contents)

def daemonize():
    """Detach the process using two forks with a setsid() in between.

    This is a deliberately minimal version that is just good enough for
    this script (it does not touch stdio, the working directory or the
    umask); do not reuse it for daemonizing elsewhere.
    """
    # First fork: the original process exits, the child carries on.
    if os.fork():
        os._exit(os.EX_OK)

    # Start a new session in the child.
    os.setsid()

    # Second fork: the intermediate process exits, the grandchild remains.
    if os.fork():
        os._exit(os.EX_OK)

def is_active_nic(interface_name):
    """Tell whether an interface looks like an active NIC.

    An interface qualifies when it is not 'lo', its sysfs addr_assign_type
    is 0, its carrier is 1 and its address is non-empty.

    Args:
        interface_name: Name of the entry under SYS_CLASS_NET.

    Returns:
        True if all conditions hold; False otherwise, including when any
        of the sysfs files cannot be read.
    """
    if interface_name == 'lo':
        return False

    def read_sysfs(path_template):
        # path_template contains one %s that receives the interface name.
        with open(SYS_CLASS_NET + path_template % interface_name, 'r') as f:
            return f.read().rstrip()

    try:
        addr_assign_type = int(read_sysfs('/%s/addr_assign_type'))
        carrier = int(read_sysfs('/%s/carrier'))
        address = read_sysfs('/%s/address')
    except IOError:
        # A missing or unreadable attribute means the NIC is not usable.
        return False

    return addr_assign_type == 0 and carrier == 1 and bool(address)

def ordered_active_nics():
    """List the active NICs, embedded ones first.

    Returns:
        The names of all entries under SYS_CLASS_NET accepted by
        is_active_nic: first those starting with 'em', 'eth' or 'eno'
        in sorted order, then the remaining ones in sorted order.
    """
    embedded_prefixes = ('em', 'eth', 'eno')
    candidates = (os.path.basename(path)
                  for path in glob.iglob(SYS_CLASS_NET + '/*'))
    active = [nic for nic in candidates if is_active_nic(nic)]

    embedded = sorted(
        nic for nic in active if nic.startswith(embedded_prefixes))
    others = sorted(
        nic for nic in active if not nic.startswith(embedded_prefixes))
    return embedded + others

def main():
    """Build one LLDP frame per requested interface and transmit them.

    The chassis ID is the MAC address of the first NIC returned by
    ordered_active_nics(), or "00:00:00:00:00:00" when there is none.
    The frames are sent once, or repeatedly every args.interval seconds
    when an interval is given.

    Raises:
        SystemExit: If no interface name was supplied.
    """
    args = parse_args()

    # Drop blank entries so input such as "eth0, ,eth1," does not turn
    # into a lookup of an empty interface name.
    interfaces = [name.strip()
                  for name in (args.network_interface or "").split(',')
                  if name.strip()]
    if not interfaces:
        raise SystemExit(
            "No network interface given (use --network_interface).")

    if args.daemonize:
        daemonize()

    chassis = "00:00:00:00:00:00"
    nics = ordered_active_nics()
    if nics:
        chassis = get_mac_str(nics[0])

    # (socket, frame) pairs, one per interface.
    senders = []
    try:
        for interface in interfaces:
            frame = lldp_frame_of(chassis_id=chassis,
                                  network_interface=interface,
                                  ttl=TTL,
                                  system_name=args.system_name,
                                  system_desc=args.system_desc)

            # Open a raw socket bound to the interface. It is registered
            # before bind() so that it is closed even if bind() fails.
            s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW)
            senders.append((s, frame))
            s.bind((interface, 0))

        while True:
            for s, frame in senders:
                s.send(frame)
            if not args.interval:
                break
            time.sleep(args.interval)
    finally:
        # Release every socket on normal exit as well as on errors.
        for s, _ in senders:
            s.close()

# Run the sender only when this file is executed as a script.
if __name__ == "__main__":
    main()

