#!/usr/bin/env python

import os
import sys
from argparse import ArgumentParser
from binascii import hexlify, unhexlify
from configparser import ConfigParser
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.path.pardir))
import bna


class Authenticator(object):
	"""
	Command-line front end for the bna module

	Parses the command line, loads the config file and dispatches to the
	requested action (see run()). Each stored serial is a section of the
	config file holding its hex-encoded secret; the "bna" section holds
	the default serial.
	"""

	def __init__(self, args):
		"""
		Parses `args` (the command-line arguments, without the program name)
		and reads the config file.

		Exits through error() if the config file cannot be parsed.
		"""
		arguments = ArgumentParser(prog="bna")
		arguments.add_argument("serial", nargs="?")
		arguments.add_argument(
			"-n", "--new",
			action="store_true",
			dest="new",
			help="request a new authenticator"
		)
		arguments.add_argument(
			"--config",
			type=str,
			dest="config",
			help="specify config file to use"
		)
		arguments.add_argument(
			"-d", "--delete",
			action="store_true",
			dest="delete",
			help="delete a stored serial and its matching secret"
		)
		arguments.add_argument(
			"-i", "--interactive",
			action="store_true",
			dest="update",
			help="interactive mode: updates the token as soon as it expires"
		)
		arguments.add_argument(
			"-l", "--list",
			action="store_true",
			dest="list",
			help="list all your active serials and exit"
		)
		arguments.add_argument(
			"-r", "--region",
			type=str,
			dest="region",
			default="US",
			help="desired region for new authenticators"
		)
		arguments.add_argument(
			"--remaining", action="store_true",
			dest="remaining",
			help="also print the remaining time until the token expires"
		)
		arguments.add_argument(
			"--restore",
			nargs=2,
			type=str,
			dest="restore",
			metavar=("SERIAL", "CODE"),
			help="restores an existing authenticator"
		)
		arguments.add_argument(
			"--restore-code",
			action="store_true",
			dest="restorecode",
			help="prints a serial's restore code and exit"
		)
		arguments.add_argument(
			"--set-default",
			action="store_true",
			dest="setdefault",
			help="set authenticator as default (also works when requesting a new authenticator)"
		)
		arguments.add_argument(
			"-v", "--version",
			action="version",
			version="bna %s" % (bna.__version__)
		)
		self.args = arguments.parse_args(args)

		self.config = ConfigParser()
		try:
			# A missing config file is not an error: ConfigParser.read()
			# silently skips files it cannot open.
			self.config.read([self.configFile()])
		except Exception as e:
			self.error("Could not parse config file %r: %s" % (self.configFile(), e))

	def _serials(self):
		"""
		Returns the config sections that describe serials, i.e. every
		section except the global "bna" one
		"""
		return [x for x in self.config.sections() if x != "bna"]

	def error(self, msg):
		"""
		Writes `msg` to stderr and terminates the process with status 1
		"""
		sys.stderr.write("Error: %s\n" % (msg))
		# sys.exit rather than the bare exit(): the latter is injected by the
		# site module and is missing under "python -S" or in frozen builds.
		sys.exit(1)

	def print(self, msg):
		"""
		Writes `msg` to stdout
		"""
		print(msg)

	def add_serial(self, serial, secret):
		"""
		Stores `secret` (a hex string) for `serial` and makes the serial the
		default when --set-default was given or no default exists yet
		"""
		self.set_secret(serial, secret)

		# We set the serial as default if we don't have one set already
		# Otherwise, we check for --set-default
		if self.args.setdefault or not self.get_default_serial():
			self.set_default_serial(serial)

	def delete_serial(self, serial):
		"""
		Removes `serial` and its secret from the config file

		Exits through error() if the serial is not stored.
		"""
		serial = bna.normalize_serial(serial)
		if not self.config.has_section(serial):
			self.error("No such serial: %r" % (serial))
		self.config.remove_section(serial)

		# If it's the default serial, remove that
		if serial == self.get_default_serial():
			self.config.remove_option("bna", "default_serial")

		self.writeConfig()
		self.print("Successfully deleted serial %s" % (bna.prettify_serial(serial)))

	def restore_serial(self, serial, code):
		"""
		Restores an existing authenticator from its serial and restore code
		and stores the resulting secret

		Exits through error() if a secret is already stored for the serial.
		"""
		serial = bna.normalize_serial(serial)
		if self.config.has_option(serial, "secret"):
			self.error("A secret already exists for this serial. Try deleting it first with bna --delete %s" % (serial))

		try:
			secret = bna.restore(serial, code)
		except bna.HTTPError as e:
			# Same handling as query_new_authenticator().
			# NOTE(review): assumes bna.restore reports network failures
			# with bna.HTTPError - confirm against the bna module.
			self.error("Could not connect: %s" % (e))

		self.add_serial(serial, hexlify(secret).decode("ascii"))
		self.print("Restored serial %s" % (bna.prettify_serial(serial)))

	def list_serials(self):
		"""
		Prints every stored serial, marking the default one, followed by
		the number of serials
		"""
		default = self.get_default_serial()
		serials = self._serials()
		for serial in serials:
			if serial == default:
				self.print("%s (default)" % (serial))
			else:
				self.print(serial)

		self.print("%i items" % (len(serials)))

	def query_new_authenticator(self):
		"""
		Requests a new authenticator for the region given by --region and
		stores its serial and secret

		Exits through error() on bna.HTTPError.
		"""
		try:
			serial, secret = bna.request_new_serial(self.args.region)
		except bna.HTTPError as e:
			self.error("Could not connect: %s" % (e))

		serial = bna.normalize_serial(serial)
		secret = hexlify(secret).decode("ascii")

		self.add_serial(serial, secret)

		self.print("Success. Your new serial is: %s" % (bna.prettify_serial(serial)))

	def run_live(self):
		"""
		Interactive mode: rewrites the current token on the same terminal
		line once per second, until the process is interrupted

		Reads the secret from self._secret, which run() sets.
		"""
		from time import sleep
		self.print("Ctrl-C to exit")
		while True:
			token, time_remaining = bna.get_token(secret=self._secret)
			# "\r" moves back to the start of the line so the previous
			# token is overwritten instead of scrolling.
			if self.args.remaining:
				sys.stdout.write("\r%08i (%02is remaining)" % (token, time_remaining))
			else:
				sys.stdout.write("\r%08i" % (token))
			sys.stdout.flush()
			sleep(1)

	def configFile(self):
		"""
		Gets the path to the config file

		Returns the (user-expanded) --config path when given. Otherwise
		returns "bna.conf" inside the per-user config directory, creating
		that directory if it does not exist.
		"""
		if self.args.config:
			return os.path.expanduser(self.args.config)

		def configDir():
			configdir = "bna"
			# expanduser() instead of os.environ["HOME"]: it still yields a
			# usable path when HOME is unset.
			home = os.path.expanduser("~")
			if os.name == "posix":
				# An unset *or empty* XDG_CONFIG_HOME falls back to ~/.config
				base = os.environ.get("XDG_CONFIG_HOME") or os.path.join(home, ".config")
				path = os.path.join(base, configdir)
			elif os.name == "nt":
				base = os.environ.get("APPDATA") or home
				path = os.path.join(base, configdir)
			else:
				path = home

			# exist_ok avoids the check-then-create race
			os.makedirs(path, exist_ok=True)
			return path

		return os.path.join(configDir(), "bna.conf")

	def get_default_serial(self):
		"""
		Returns the default serial stored in the config, or None if unset
		"""
		if not self.config.has_option("bna", "default_serial"):
			return None
		return self.config.get("bna", "default_serial")

	def set_default_serial(self, serial):
		"""
		Stores `serial` as the default serial and writes the config file
		"""
		if not self.config.has_section("bna"):
			self.config.add_section("bna")
		self.config.set("bna", "default_serial", serial)
		self.writeConfig()

	def writeConfig(self):
		"""
		Writes the in-memory config to the config file

		Exits through error() if the file cannot be written.
		"""
		try:
			with open(self.configFile(), "w") as f:
				self.config.write(f)
		except OSError as e:
			self.error("Could not open %r for writing: %s" % (self.configFile(), e))

	def get_secret(self, serial):
		"""
		Returns the stored hex-encoded secret of `serial` as a bytearray,
		or None if no secret is stored for it
		"""
		# has_option() is also False when the section itself is missing, so
		# this covers both an unknown serial and a section without a secret.
		if not self.config.has_option(serial, "secret"):
			return None

		secret = self.config.get(serial, "secret")
		return bytearray(secret, "ascii")

	def set_secret(self, serial, secret):
		"""
		Stores `secret` (a hex string) for `serial` and writes the config file
		"""
		if not self.config.has_section(serial):
			self.config.add_section(serial)
		self.config.set(serial, "secret", secret)

		self.writeConfig()

	def run(self):
		"""
		Executes the action selected on the command line

		--new, --delete, --list and --restore are handled first, in that
		order, and are mutually exclusive with printing a token. Returns 0;
		every failure terminates the process through error().
		"""
		# Are we requesting a new authenticator?
		if self.args.new:
			self.query_new_authenticator()
			return 0

		if self.args.delete:
			if not self.args.serial:
				self.error("You must provide a serial with --delete")
			self.delete_serial(self.args.serial)
			return 0

		if self.args.list:
			self.list_serials()
			return 0

		if self.args.restore:
			self.restore_serial(*self.args.restore)
			return 0

		if not self.args.serial:
			serial = self.get_default_serial()
			if serial is None:
				if self._serials():
					self.error("You do not have a default serial set. You must provide an authenticator serial or set a default one with 'bna --set-default'.")
				else:
					self.error("You do not have any configured authenticators. Create a new one with 'bna --new' or try 'bna --help' for more information")
		else:
			serial = self.args.serial
		serial = bna.normalize_serial(serial)

		# Get the secret from the config file
		secret = self.get_secret(serial)
		if not secret:
			self.error("No such serial: %r" % (serial))
		try:
			self._secret = unhexlify(secret)
		except ValueError as e:
			# binascii.Error is a ValueError: the stored value is not valid hex
			self.error("Invalid secret stored for serial %r: %s" % (serial, e))

		# Are we setting a serial as default? Only done once the serial is
		# known to exist, so a mistyped serial never becomes the default.
		if self.args.setdefault:
			self.set_default_serial(serial)

		# Print the restore code if the user asked for it
		if self.args.restorecode:
			code = bna.get_restore_code(serial, self._secret)
			self.print(code)
			return 0

		# otherwise print the token
		if self.args.update:
			self.run_live()
		else:
			token, time_remaining = bna.get_token(secret=self._secret)
			if self.args.remaining:
				self.print("%08i (%02is remaining)" % (token, time_remaining))
			else:
				self.print("%08i" % (token))

		return 0


def main():
	"""
	Script entry point: builds an Authenticator from the command-line
	arguments, runs it and returns its exit status
	"""
	import signal

	# Install the default SIGINT disposition in place of Python's handler
	# (presumably so Ctrl-C ends the program without a KeyboardInterrupt
	# traceback).
	signal.signal(signal.SIGINT, signal.SIG_DFL)

	cli_args = sys.argv[1:]
	return Authenticator(cli_args).run()


if __name__ == "__main__":
	# sys.exit rather than the bare exit(): the latter is only injected by
	# the site module and is unavailable under "python -S" or when frozen.
	sys.exit(main())
