package certmonitor

import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"time"

	daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
	"github.com/k3s-io/k3s/pkg/daemons/control/deps"
	"github.com/k3s-io/k3s/pkg/daemons/executor"
	"github.com/k3s-io/k3s/pkg/metrics"
	"github.com/k3s-io/k3s/pkg/util"
	"github.com/k3s-io/k3s/pkg/util/services"
	"github.com/k3s-io/k3s/pkg/version"
	"github.com/prometheus/client_golang/prometheus"
	certutil "github.com/rancher/dynamiclistener/cert"
	"github.com/sirupsen/logrus"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/apimachinery/pkg/util/wait"
)

var (
	// certCheckInterval is how often the monitor re-examines certificates. Kubernetes
	// events expire after one hour by default, so checking every half hour keeps repeated
	// warnings inside that TTL, where the event recorder can aggregate and refresh them
	// rather than letting them lapse.
	certCheckInterval = 30 * time.Minute

	// controllerName identifies this monitor in log messages and is passed to the
	// event recorder when it is built.
	controllerName = version.Program + "-cert-monitor"

	// certificateExpirationSeconds exposes, per certificate subject and comma-joined
	// usage list, the number of seconds remaining until the certificate's NotAfter
	// time (negative once it has expired).
	certificateExpirationSeconds = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: version.Program + "_certificate_expiration_seconds",
			Help: "Remaining lifetime in seconds of the certificate, labeled by certificate subject and usages.",
		},
		[]string{"subject", "usages"},
	)
)

// Setup starts the certificate expiration monitor.
//
// It does the following, then returns without waiting for the first check:
//   - Registers the certificate expiration gauge with metrics.DefaultRegisterer.
//     MustRegister panics if registration fails, e.g. when Setup is called twice
//     in one process.
//   - Builds a client and event recorder from the kubelet kubeconfig.
//   - Resolves the certificate files to watch.
//   - Starts a background loop that runs every certCheckInterval until ctx is
//     cancelled.
//
// dataDir is the root data directory; server certificate paths are resolved
// under dataDir/server.
//
// Each loop iteration waits for the apiserver to become ready (or for ctx to be
// cancelled). It then records Warning events against this node's Node object
// for two cases:
//   - node certificates that are invalid or expire within
//     daemonconfig.CertificateRenewDays days;
//   - CA certificates that are invalid or expire within 365 days.
//
// After the first completed check, a single Normal event is recorded if no
// problems were found.
//
// Setup returns an error if the Kubernetes client cannot be created or the
// certificate file lists cannot be resolved.
func Setup(ctx context.Context, nodeConfig *daemonconfig.Node, dataDir string) error {
	logrus.Debugf("Starting %s with monitoring period %s", controllerName, certCheckInterval)
	metrics.DefaultRegisterer.MustRegister(certificateExpirationSeconds)

	client, err := util.GetClientSet(nodeConfig.AgentConfig.KubeConfigKubelet)
	if err != nil {
		return err
	}

	recorder := util.BuildControllerEventRecorder(client, controllerName, metav1.NamespaceDefault)

	// This is consistent with events attached to the node generated by the kubelet
	// https://github.com/kubernetes/kubernetes/blob/612130dd2f4188db839ea5c2dea07a96b0ad8d1c/pkg/kubelet/kubelet.go#L479-L485
	nodeRef := &corev1.ObjectReference{
		Kind:      "Node",
		Name:      nodeConfig.AgentConfig.NodeName,
		UID:       types.UID(nodeConfig.AgentConfig.NodeName),
		Namespace: "",
	}

	// Create a dummy controlConfig just to hold the paths for the server certs
	controlConfig := daemonconfig.Control{
		DataDir: filepath.Join(dataDir, "server"),
		Runtime: &daemonconfig.ControlRuntime{},
	}
	deps.CreateRuntimeCertFiles(&controlConfig)

	// If the server data directory exists (presumably only on server nodes), monitor
	// certificates for all services and the CA certificates; otherwise only the agent's.
	caMap := map[string][]string{}
	nodeList := services.Agent
	if _, statErr := os.Stat(controlConfig.DataDir); statErr == nil {
		nodeList = services.All
		caMap, err = services.FilesForServices(controlConfig, services.CA)
		if err != nil {
			return err
		}
	}

	nodeMap, err := services.FilesForServices(controlConfig, nodeList)
	if err != nil {
		return err
	}

	nodeWarningPeriod := time.Hour * 24 * daemonconfig.CertificateRenewDays
	caWarningPeriod := time.Hour * 24 * 365
	startupOnce := &sync.Once{}

	go wait.Until(func() {
		// Don't check and create events until after the apiserver is up, otherwise the events
		// may be lost. Also watch ctx, so this goroutine exits instead of blocking forever if
		// we are shut down before the apiserver ever becomes ready.
		select {
		case <-ctx.Done():
			return
		case <-executor.APIServerReadyChan():
		}

		logrus.Debugf("Running %s certificate expiration check", controllerName)
		var hasErr bool
		if err := checkCerts(nodeMap, nodeWarningPeriod); err != nil {
			message := fmt.Sprintf("Node certificates require attention - restart %s on this node to trigger automatic rotation: %v", version.Program, err)
			recorder.Event(nodeRef, corev1.EventTypeWarning, "CertificateExpirationWarning", message)
			hasErr = true
		}
		if err := checkCerts(caMap, caWarningPeriod); err != nil {
			message := fmt.Sprintf("Certificate Authority certificates require attention - check %s documentation and begin planning rotation: %v", version.Program, err)
			recorder.Event(nodeRef, corev1.EventTypeWarning, "CACertificateExpirationWarning", message)
			hasErr = true
		}
		// Only check for no errors and emit an OK event once, on the initial check after startup.
		startupOnce.Do(func() {
			if !hasErr {
				message := fmt.Sprintf("Node and Certificate Authority certificates managed by %s are OK", version.Program)
				recorder.Event(nodeRef, corev1.EventTypeNormal, "CertificateExpirationOK", message)
			}
		})
	}, certCheckInterval, ctx.Done())

	return nil
}

// checkCerts loads the certificates from every file listed in fileMap and updates
// the certificate expiration gauge for each certificate found. It returns a joined
// error with one entry per certificate that is not yet valid, already expired, or
// due to expire within warningPeriod, and nil when there are none.
//
// fileMap maps a service name to the files belonging to that service. Each reported
// problem is prefixed with the service name and the file's basename.
//
// Services are visited in sorted order so the combined error text is identical
// between runs for the same set of problems. Setup uses that text as an event
// message, and Go's randomized map iteration order would otherwise reshuffle it on
// every check, preventing repeated events from being recognised as the same event.
func checkCerts(fileMap map[string][]string, warningPeriod time.Duration) error {
	var errs []error
	now := time.Now()
	warn := now.Add(warningPeriod)

	serviceNames := make([]string, 0, len(fileMap))
	for service := range fileMap {
		serviceNames = append(serviceNames, service)
	}
	sort.Strings(serviceNames)

	for _, service := range serviceNames {
		for _, file := range fileMap[service] {
			basename := filepath.Base(file)
			// Load errors are deliberately ignored (best effort). Presumably not every listed
			// file holds certificates, e.g. key files or files absent on this node.
			// TODO(review): confirm against services.FilesForServices.
			certs, _ := certutil.CertsFromFile(file)
			for _, cert := range certs {
				usages := util.GetCertUsages(cert)
				// Remaining lifetime in seconds; negative once the certificate has expired.
				certificateExpirationSeconds.WithLabelValues(cert.Subject.String(), strings.Join(usages, ",")).Set(cert.NotAfter.Sub(now).Seconds())
				switch util.GetCertStatus(cert, now, warn) {
				case util.CertStatusNotYetValid:
					errs = append(errs, fmt.Errorf("%s/%s: certificate %s is not valid before %s", service, basename, cert.Subject, cert.NotBefore.Format(time.RFC3339)))
				case util.CertStatusExpired:
					errs = append(errs, fmt.Errorf("%s/%s: certificate %s expired at %s", service, basename, cert.Subject, cert.NotAfter.Format(time.RFC3339)))
				case util.CertStatusWarning:
					errs = append(errs, fmt.Errorf("%s/%s: certificate %s will expire within %d days at %s", service, basename, cert.Subject, int(warningPeriod.Hours()/24), cert.NotAfter.Format(time.RFC3339)))
				}
			}
		}
	}

	// errors.Join returns nil when errs is empty.
	return errors.Join(errs...)
}
