#!/usr/bin/env python3

import argparse
import time
import logging
import upnpclient

from prometheus_client import start_http_server, Counter


def parse_args():
    """Parse the exporter's command-line arguments.

    Two positional arguments are required, in this order: ``ip`` and
    ``port``. Neither is converted, so both are returned as strings.

    Returns:
        argparse.Namespace: namespace exposing ``ip`` and ``port``.
    """
    positionals = (
        ("ip", "UPNP device supporting the WANCommonIFC1 service"),
        ("port", "Port that UPNP is running on"),
    )

    cli = argparse.ArgumentParser()
    for dest, description in positionals:
        cli.add_argument(dest, help=description)

    return cli.parse_args()


class WANCommonIFC1:
    """Report byte-count deltas from a device's WANCommonIFC1 service.

    The instance remembers the last cumulative sent/received totals it
    read, so each query returns the change since the previous reading.
    """

    def __init__(self, ip, port):
        """Connect to the device and record its current totals as baseline.

        Args:
            ip: Host part of the device description URL.
            port: Port part of the device description URL.
        """
        self.device = upnpclient.Device(f"http://{ip}:{port}/rootDesc.xml")
        self.service = self.device.WANCommonIFC1

        # Baselines: the first delta is measured from these readings.
        self.last_total_bytes_sent = self._read_sent_total()
        self.last_total_bytes_received = self._read_received_total()

    def _read_sent_total(self):
        """Query the service for its cumulative sent-bytes value as an int."""
        reply = self.service.GetTotalBytesSent()
        return int(reply["NewTotalBytesSent"])

    def _read_received_total(self):
        """Query the service for its cumulative received-bytes value as an int."""
        reply = self.service.GetTotalBytesReceived()
        return int(reply["NewTotalBytesReceived"])

    def total_bytes_sent(self):
        """
        Bytes sent since the previous reading.

        The stored baseline is replaced by the fresh reading. The result is
        a plain difference and is not clamped, so it is negative whenever
        the new reading is lower than the stored one.
        """
        current = self._read_sent_total()
        previous, self.last_total_bytes_sent = self.last_total_bytes_sent, current
        return current - previous

    def total_bytes_received(self):
        """
        Bytes received since the previous reading.

        The stored baseline is replaced by the fresh reading. The result is
        a plain difference and is not clamped, so it is negative whenever
        the new reading is lower than the stored one.
        """
        current = self._read_received_total()
        previous, self.last_total_bytes_received = (
            self.last_total_bytes_received,
            current,
        )
        return current - previous


def monitor_network(service, total_bytes_sent_counter, total_bytes_received):
    """Poll ``service`` forever and feed its byte deltas into two counters.

    Each cycle reads the sent delta, then the received delta, and calls
    ``inc`` on the matching counter only when the delta is strictly
    positive. The loop then sleeps for one second. This function never
    returns on its own; it ends only when something raises.

    Args:
        service: Object providing ``total_bytes_sent()`` and
            ``total_bytes_received()``.
        total_bytes_sent_counter: Object whose ``inc(amount)`` accumulates
            sent bytes.
        total_bytes_received: Object whose ``inc(amount)`` accumulates
            received bytes.
    """
    while True:
        sent_delta = service.total_bytes_sent()
        if sent_delta > 0:
            total_bytes_sent_counter.inc(sent_delta)

        received_delta = service.total_bytes_received()
        if received_delta > 0:
            total_bytes_received.inc(received_delta)

        time.sleep(1)


def main():
    """Run the exporter: configure logging, serve metrics, poll the device.

    Steps, in order: parse the command line, set up INFO-level logging,
    start the Prometheus HTTP server on port 8125, create the sent and
    received counters, connect to the device and enter the polling loop.
    The polling loop does not return by itself.
    """
    args = parse_args()

    log_settings = {
        "level": logging.INFO,
        "datefmt": "%Y-%m-%dT%H:%M:%S%z",
        "format": "%(asctime)s %(levelname)s %(message)s",
    }
    logging.basicConfig(**log_settings)

    logging.info(f"Starting WANCommonIFC1_exporter using {args.ip}:{args.port}")

    # The metrics endpoint comes up before the device is first contacted.
    start_http_server(8125)
    logging.info("Listening on port 8125")

    sent_counter = Counter("bytes_sent_total", "Total bytes sent")
    received_counter = Counter("bytes_received_total", "Total bytes received")

    # Constructing the service reads the device's current totals as baseline.
    monitor_network(
        WANCommonIFC1(args.ip, args.port),
        sent_counter,
        received_counter,
    )


if __name__ == "__main__":
    # Run the exporter only when this file is executed as a script.
    # KeyboardInterrupt (Ctrl-C) is swallowed so stopping the exporter
    # by hand does not print a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass