// Copyright (c) 2025 Tigera, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package intdataplane

import (
	"fmt"
	"net"
	"sort"
	"strings"

	"github.com/sirupsen/logrus"

	dpsets "github.com/projectcalico/calico/felix/dataplane/ipsets"
	"github.com/projectcalico/calico/felix/ipsets"
	"github.com/projectcalico/calico/felix/proto"
	"github.com/projectcalico/calico/felix/rules"
	"github.com/projectcalico/calico/felix/types"
)

// dscpManager tracks the DSCP (QoS) policies attached to workload and host
// endpoints for a single IP version. It programs them into the dataplane as an
// egress DSCP chain in the mangle table, plus one IP set holding the source
// addresses of every endpoint that has a DSCP rule.
type dscpManager struct {
	// Configuration fixed at construction: the IP version this instance
	// handles, the renderer for the DSCP chain and the table it is written to.
	ipVersion    uint8
	ruleRenderer rules.RuleRenderer
	mangleTable  Table

	// QoS policies: the DSCP rule currently derived for each workload/host
	// endpoint that has one, keyed by endpoint ID. dirty is set whenever either
	// map changes, and is cleared once CompleteDeferredWork has rewritten the IP
	// set and the chain.
	wepPolicies map[types.WorkloadEndpointID]*rules.DSCPRule
	hepPolicies map[types.HostEndpointID]*rules.DSCPRule
	dirty       bool

	// IPSet: the dataplane the DSCP IP set is written to, and that set's
	// metadata (ID, type and maximum size).
	ipsetsDataplane dpsets.IPSetsDataplane
	ipSetMetadata   ipsets.IPSetMetadata

	// Logger pre-tagged with this manager's IP version.
	logCtx *logrus.Entry
}

// newDSCPManager creates a DSCP manager for one IP version. The manager writes
// its egress DSCP chain to mangleTable using ruleRenderer. It writes the DSCP
// endpoint IP set through ipsetsDataplane, capped at dpConfig.MaxIPSetSize.
func newDSCPManager(
	ipsetsDataplane dpsets.IPSetsDataplane,
	mangleTable Table,
	ruleRenderer rules.RuleRenderer,
	ipVersion uint8,
	dpConfig Config,
) *dscpManager {
	// A single hash:net set holds the source addresses of all endpoints that
	// carry a DSCP rule.
	setMeta := ipsets.IPSetMetadata{
		SetID:   rules.IPSetIDDSCPEndpoints,
		Type:    ipsets.IPSetTypeHashNet,
		MaxSize: dpConfig.MaxIPSetSize,
	}

	mgr := &dscpManager{
		ipVersion:       ipVersion,
		ruleRenderer:    ruleRenderer,
		mangleTable:     mangleTable,
		ipsetsDataplane: ipsetsDataplane,
		ipSetMetadata:   setMeta,
		wepPolicies:     make(map[types.WorkloadEndpointID]*rules.DSCPRule),
		hepPolicies:     make(map[types.HostEndpointID]*rules.DSCPRule),
		logCtx:          logrus.WithField("ipVersion", ipVersion),
	}
	// Start dirty so the first CompleteDeferredWork call programs the IP set and
	// chain even before any endpoint update has arrived.
	mgr.dirty = true
	return mgr
}

// OnUpdate routes endpoint updates and removals to the matching per-endpoint
// handler. A removal is passed on as a nil update. All other message types are
// ignored.
func (m *dscpManager) OnUpdate(msg interface{}) {
	switch update := msg.(type) {
	case *proto.WorkloadEndpointUpdate:
		m.handleWEPUpdates(update.GetId(), update)
	case *proto.WorkloadEndpointRemove:
		m.handleWEPUpdates(update.GetId(), nil)
	case *proto.HostEndpointUpdate:
		m.handleHEPUpdates(update.GetId(), update)
	case *proto.HostEndpointRemove:
		m.handleHEPUpdates(update.GetId(), nil)
	}
}

// handleHEPUpdates records, replaces or removes the DSCP rule for one host
// endpoint. Any rule previously stored for the endpoint is cleared when:
//   - msg is nil (the endpoint was removed),
//   - the endpoint has no QoS policies,
//   - it has no expected addresses of this manager's IP version, or
//   - the update cannot be turned into a valid rule.
//
// Clearing means the dataplane never keeps programming addresses or a DSCP
// value the endpoint no longer has. Only the first QoS policy is honoured.
func (m *dscpManager) handleHEPUpdates(hepID *proto.HostEndpointID, msg *proto.HostEndpointUpdate) {
	id := types.ProtoToHostEndpointID(hepID)

	// forget drops any rule previously recorded for this endpoint.
	forget := func() {
		if _, exists := m.hepPolicies[id]; exists {
			delete(m.hepPolicies, id)
			m.dirty = true
		}
	}

	// Proto getters tolerate nil receivers (assuming standard generated proto
	// code), so a removal (msg == nil) or an update without an endpoint both
	// end up here with no policies.
	ep := msg.GetEndpoint()
	policies := ep.GetQosPolicies()
	if len(policies) == 0 {
		forget()
		return
	}

	ips := ep.GetExpectedIpv4Addrs()
	if m.ipVersion == 6 {
		ips = ep.GetExpectedIpv6Addrs()
	}
	if len(ips) == 0 {
		// Nothing to match on for this IP version, e.g. a single-stack endpoint
		// seen by the other IP version's manager. This is not an error.
		forget()
		return
	}

	// We only support one policy per endpoint at this point.
	r, err := convertUpdatesToDSCPRule(ips, policies[0].GetDscp())
	if err != nil {
		m.logCtx.WithField("hep", id).WithError(err).Error("Failed to handle DSCP from endpoint update - Skipping.")
		// Don't keep enforcing a rule built from the endpoint's previous state.
		forget()
		return
	}

	m.hepPolicies[id] = r
	m.dirty = true
}

// handleWEPUpdates records, replaces or removes the DSCP rule for one workload
// endpoint. Any rule previously stored for the endpoint is cleared when:
//   - msg is nil (the endpoint was removed),
//   - the endpoint has no QoS policies,
//   - it has no networks of this manager's IP version, or
//   - the update cannot be turned into a valid rule.
//
// Clearing means the dataplane never keeps programming addresses or a DSCP
// value the endpoint no longer has. Only the first QoS policy is honoured.
func (m *dscpManager) handleWEPUpdates(wepID *proto.WorkloadEndpointID, msg *proto.WorkloadEndpointUpdate) {
	id := types.ProtoToWorkloadEndpointID(wepID)

	// forget drops any rule previously recorded for this endpoint.
	forget := func() {
		if _, exists := m.wepPolicies[id]; exists {
			delete(m.wepPolicies, id)
			m.dirty = true
		}
	}

	// Proto getters tolerate nil receivers (assuming standard generated proto
	// code), so a removal (msg == nil) or an update without an endpoint both
	// end up here with no policies.
	ep := msg.GetEndpoint()
	policies := ep.GetQosPolicies()
	if len(policies) == 0 {
		forget()
		return
	}

	ips := ep.GetIpv4Nets()
	if m.ipVersion == 6 {
		ips = ep.GetIpv6Nets()
	}
	if len(ips) == 0 {
		// Nothing to match on for this IP version, e.g. a single-stack endpoint
		// seen by the other IP version's manager. This is not an error.
		forget()
		return
	}

	// We only support one policy per endpoint at this point.
	r, err := convertUpdatesToDSCPRule(ips, policies[0].GetDscp())
	if err != nil {
		m.logCtx.WithField("wep", id).WithError(err).Error("Failed to handle DSCP from endpoint update - Skipping.")
		// Don't keep enforcing a rule built from the endpoint's previous state.
		forget()
		return
	}

	m.wepPolicies[id] = r
	m.dirty = true
}

// convertUpdatesToDSCPRule builds the DSCP rule for one endpoint. It takes the
// endpoint's addresses for a single IP version and the DSCP value from its QoS
// policy; prefix lengths on the addresses are stripped.
//
// It returns an error, and a nil rule, if:
//   - dscp is outside [0, 63] (the 6-bit DSCP range),
//   - ips is empty, or
//   - any address cannot be normalised (the cause is wrapped).
func convertUpdatesToDSCPRule(ips []string, dscp int32) (*rules.DSCPRule, error) {
	if dscp < 0 || dscp > 63 {
		return nil, fmt.Errorf("invalid DSCP value %v", dscp)
	}
	if len(ips) == 0 {
		return nil, fmt.Errorf("no address provided")
	}
	srcAddrs, err := normaliseSourceAddr(ips)
	if err != nil {
		// Wrap (rather than format) so callers can still inspect the cause.
		return nil, fmt.Errorf("invalid address: %w", err)
	}
	return &rules.DSCPRule{
		SrcAddrs: srcAddrs,
		// The range check above guarantees the conversion cannot truncate.
		Value: uint8(dscp),
	}, nil
}

// normaliseSourceAddr strips the prefix length from every address in addrs
// (via removeSubnetMask). It returns the results joined with commas, in input
// order. An empty addrs yields "". The first error from removeSubnetMask is
// returned unchanged, with an empty string.
func normaliseSourceAddr(addrs []string) (string, error) {
	hosts := make([]string, 0, len(addrs))
	for i := range addrs {
		host, err := removeSubnetMask(addrs[i])
		if err != nil {
			return "", err
		}
		hosts = append(hosts, host)
	}
	return strings.Join(hosts, ","), nil
}

// removeSubnetMask returns the address part of addr. addr is either a bare IP
// ("10.0.0.1", "fd00::1") or an IP in CIDR notation ("10.0.0.1/32"); any prefix
// length is dropped. The address text is returned exactly as given, not
// canonicalised.
//
// An error is returned for an empty string, an invalid IP, or an invalid CIDR
// (bad prefix length, extra '/' separators, ...). This keeps garbage out of
// the DSCP IP set and the rendered rules.
func removeSubnetMask(addr string) (string, error) {
	slash := strings.IndexByte(addr, '/')
	if slash < 0 {
		if net.ParseIP(addr) == nil {
			return "", fmt.Errorf("malformed address %q", addr)
		}
		return addr, nil
	}
	if _, _, err := net.ParseCIDR(addr); err != nil {
		return "", fmt.Errorf("malformed address %q", addr)
	}
	return addr[:slash], nil
}

// CompleteDeferredWork flushes pending DSCP changes to the dataplane. If any
// endpoint policy changed since the last call, it rewrites the DSCP IP set. It
// then re-renders the egress DSCP chain from all known rules and hands the
// chain to the mangle table. It currently always returns nil.
func (m *dscpManager) CompleteDeferredWork() error {
	if !m.dirty {
		return nil
	}

	var dscpRules []*rules.DSCPRule
	for _, r := range m.wepPolicies {
		dscpRules = append(dscpRules, r)
	}
	for _, r := range m.hepPolicies {
		dscpRules = append(dscpRules, r)
	}
	// Map iteration order is random, so sort to render a stable chain. Ties on
	// SrcAddrs are possible because endpoints can (at least transiently) share
	// an address. Without a tie-break on the DSCP value, the relative order of
	// such rules would depend on map order and the chain would flap.
	sort.Slice(dscpRules, func(i, j int) bool {
		a, b := dscpRules[i], dscpRules[j]
		if a.SrcAddrs != b.SrcAddrs {
			return a.SrcAddrs < b.SrcAddrs
		}
		return a.Value < b.Value
	})

	m.updateIPSet()

	chain := m.ruleRenderer.EgressDSCPChain(dscpRules)
	m.mangleTable.UpdateChain(chain)
	m.dirty = false
	return nil
}

// updateIPSet replaces the DSCP IP set's contents. The new contents are every
// source address of every host and workload endpoint that currently has a
// DSCP rule.
//
// The whole set is rewritten on each change instead of applying deltas. This
// assumes that adding or removing DSCP-annotated endpoints is rare. Delta
// handling would need reference counting of IPs, because two endpoints can (at
// least transiently) share an address. That extra bookkeeping would make the
// code more complex.
func (m *dscpManager) updateIPSet() {
	m.logCtx.Debug("DSCP IP set out-of sync, refreshing it.")

	// One slot per endpoint is only a lower bound; endpoints may have several
	// addresses.
	members := make([]string, 0, len(m.hepPolicies)+len(m.wepPolicies))
	collect := func(rule *rules.DSCPRule) {
		members = append(members, strings.Split(rule.SrcAddrs, ",")...)
	}
	for _, rule := range m.hepPolicies {
		collect(rule)
	}
	for _, rule := range m.wepPolicies {
		collect(rule)
	}

	m.ipsetsDataplane.AddOrReplaceIPSet(m.ipSetMetadata, members)
}
