// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ext

import (
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
	"sync"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/common/ast"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/common/types/ref"
	"github.com/google/cel-go/common/types/traits"
	"github.com/google/cel-go/interpreter"
)

// Bindings returns a cel.EnvOption to configure support for local variable
// bindings in expressions.
//
// # Cel.Bind
//
// Binds a simple identifier to an initialization expression which may be used
// in a subsequenct result expression. Bindings may also be nested within each
// other.
//
//	cel.bind(<varName>, <initExpr>, <resultExpr>)
//
// Examples:
//
//	cel.bind(a, 'hello',
//	cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
//
//	// Avoid a list allocation within the exists comprehension.
//	cel.bind(valid_values, [a, b, c],
//	[d, e, f].exists(elem, elem in valid_values))
//
// Local bindings are not guaranteed to be evaluated before use.
func Bindings(options ...BindingsOption) cel.EnvOption {
	b := &celBindings{version: math.MaxUint32}
	for _, o := range options {
		b = o(b)
	}
	return cel.Lib(b)
}

const (
	celNamespace  = "cel"
	bindMacro     = "bind"
	blockFunc     = "@block"
	unusedIterVar = "#unused"
)

// BindingsOption declares a functional operator for configuring the Bindings library behavior.
type BindingsOption func(*celBindings) *celBindings

// BindingsVersion sets the version of the bindings library to an explicit version.
func BindingsVersion(version uint32) BindingsOption {
	return func(lib *celBindings) *celBindings {
		lib.version = version
		return lib
	}
}

type celBindings struct {
	version uint32
}

func (*celBindings) LibraryName() string {
	return "cel.lib.ext.cel.bindings"
}

func (lib *celBindings) CompileOptions() []cel.EnvOption {
	opts := []cel.EnvOption{
		cel.Macros(
			// cel.bind(var, <init>, <expr>)
			cel.ReceiverMacro(bindMacro, 3, celBind),
		),
	}
	if lib.version >= 1 {
		// The cel.@block signature takes a list of subexpressions and a typed expression which is
		// used as the output type.
		paramType := cel.TypeParamType("T")
		opts = append(opts,
			cel.Function("cel.@block",
				cel.Overload("cel_block_list",
					[]*cel.Type{cel.ListType(cel.DynType), paramType}, paramType)),
		)
		opts = append(opts, cel.ASTValidators(blockValidationExemption{}))
	}
	return opts
}

func (lib *celBindings) ProgramOptions() []cel.ProgramOption {
	if lib.version >= 1 {
		celBlockPlan := func(i interpreter.Interpretable) (interpreter.Interpretable, error) {
			call, ok := i.(interpreter.InterpretableCall)
			if !ok {
				return i, nil
			}
			switch call.Function() {
			case "cel.@block":
				args := call.Args()
				if len(args) != 2 {
					return nil, fmt.Errorf("cel.@block expects two arguments, but got %d", len(args))
				}
				expr := args[1]
				// Non-empty block
				if block, ok := args[0].(interpreter.InterpretableConstructor); ok {
					slotExprs := block.InitVals()
					return newDynamicBlock(slotExprs, expr), nil
				}
				// Constant valued block which can happen during runtime optimization.
				if cons, ok := args[0].(interpreter.InterpretableConst); ok {
					if cons.Value().Type() == types.ListType {
						l := cons.Value().(traits.Lister)
						if l.Size().Equal(types.IntZero) == types.True {
							return args[1], nil
						}
						return newConstantBlock(l, expr), nil
					}
				}
				return nil, errors.New("cel.@block expects a list constructor as the first argument")
			default:
				return i, nil
			}
		}
		return []cel.ProgramOption{cel.CustomDecorator(celBlockPlan)}
	}
	return []cel.ProgramOption{}
}

type blockValidationExemption struct{}

// Name returns the name of the validator.
func (blockValidationExemption) Name() string {
	return "cel.validator.cel_block"
}

// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
// during homogeneous aggregate literal type-checks.
func (blockValidationExemption) Configure(config cel.MutableValidatorConfig) error {
	functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string)
	functions = append(functions, "cel.@block")
	return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions)
}

// Validate is a no-op as the intent is to simply disable strong type-checks for list literals during
// when they occur within cel.@block calls as the arg types have already been validated.
func (blockValidationExemption) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
}

func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
	if !macroTargetMatchesNamespace(celNamespace, target) {
		return nil, nil
	}
	varIdent := args[0]
	varName := ""
	switch varIdent.Kind() {
	case ast.IdentKind:
		varName = varIdent.AsIdent()
	default:
		return nil, mef.NewError(varIdent.ID(), "cel.bind() variable names must be simple identifiers")
	}
	varInit := args[1]
	resultExpr := args[2]
	return mef.NewComprehension(
		mef.NewList(),
		unusedIterVar,
		varName,
		varInit,
		mef.NewLiteral(types.False),
		mef.NewIdent(varName),
		resultExpr,
	), nil
}

func newDynamicBlock(slotExprs []interpreter.Interpretable, expr interpreter.Interpretable) interpreter.Interpretable {
	bs := &dynamicBlock{
		slotExprs: slotExprs,
		expr:      expr,
	}
	bs.slotActivationPool = &sync.Pool{
		New: func() any {
			slotCount := len(slotExprs)
			sa := &dynamicSlotActivation{
				slotExprs: slotExprs,
				slotCount: slotCount,
				slotVals:  make([]*slotVal, slotCount),
			}
			for i := 0; i < slotCount; i++ {
				sa.slotVals[i] = &slotVal{}
			}
			return sa
		},
	}
	return bs
}

type dynamicBlock struct {
	slotExprs          []interpreter.Interpretable
	expr               interpreter.Interpretable
	slotActivationPool *sync.Pool
}

// ID implements the Interpretable interface method.
func (b *dynamicBlock) ID() int64 {
	return b.expr.ID()
}

// Eval implements the Interpretable interface method.
func (b *dynamicBlock) Eval(activation cel.Activation) ref.Val {
	sa := b.slotActivationPool.Get().(*dynamicSlotActivation)
	sa.Activation = activation
	defer b.clearSlots(sa)
	return b.expr.Eval(sa)
}

func (b *dynamicBlock) clearSlots(sa *dynamicSlotActivation) {
	sa.reset()
	b.slotActivationPool.Put(sa)
}

type slotVal struct {
	value   *ref.Val
	visited bool
}

type dynamicSlotActivation struct {
	cel.Activation
	slotExprs []interpreter.Interpretable
	slotCount int
	slotVals  []*slotVal
}

// Unwrap returns the underlying activation.
func (sa *dynamicSlotActivation) Unwrap() cel.Activation {
	return sa.Activation
}

// ResolveName implements the Activation interface method but handles variables prefixed with `@index`
// as special variables which exist within the slot-based memory of the cel.@block() where each slot
// refers to an expression which must be computed only once.
func (sa *dynamicSlotActivation) ResolveName(name string) (any, bool) {
	if idx, found := matchSlot(name, sa.slotCount); found {
		v := sa.slotVals[idx]
		if v.visited {
			// Return not found if the index expression refers to itself
			if v.value == nil {
				return nil, false
			}
			return *v.value, true
		}
		v.visited = true
		val := sa.slotExprs[idx].Eval(sa)
		v.value = &val
		return val, true
	}
	return sa.Activation.ResolveName(name)
}

func (sa *dynamicSlotActivation) reset() {
	sa.Activation = nil
	for _, sv := range sa.slotVals {
		sv.visited = false
		sv.value = nil
	}
}

func newConstantBlock(slots traits.Lister, expr interpreter.Interpretable) interpreter.Interpretable {
	count := slots.Size().(types.Int)
	return &constantBlock{slots: slots, slotCount: int(count), expr: expr}
}

type constantBlock struct {
	slots     traits.Lister
	slotCount int
	expr      interpreter.Interpretable
}

// ID implements the interpreter.Interpretable interface method.
func (b *constantBlock) ID() int64 {
	return b.expr.ID()
}

// Eval implements the interpreter.Interpretable interface method, and will proxy @index prefixed variable
// lookups into a set of constant slots determined from the plan step.
func (b *constantBlock) Eval(activation cel.Activation) ref.Val {
	vars := constantSlotActivation{Activation: activation, slots: b.slots, slotCount: b.slotCount}
	return b.expr.Eval(vars)
}

type constantSlotActivation struct {
	cel.Activation
	slots     traits.Lister
	slotCount int
}

// Unwrap returns the underlying activation.
func (sa *constantSlotActivation) Unwrap() cel.Activation {
	return sa.Activation
}

// ResolveName implements Activation interface method and proxies @index prefixed lookups into the slot
// activation associated with the block scope.
func (sa constantSlotActivation) ResolveName(name string) (any, bool) {
	if idx, found := matchSlot(name, sa.slotCount); found {
		return sa.slots.Get(types.Int(idx)), true
	}
	return sa.Activation.ResolveName(name)
}

func matchSlot(name string, slotCount int) (int, bool) {
	if idx, found := strings.CutPrefix(name, indexPrefix); found {
		idx, err := strconv.Atoi(idx)
		// Return not found if the index is not numeric
		if err != nil {
			return -1, false
		}
		// Return not found if the index is not a valid slot
		if idx < 0 || idx >= slotCount {
			return -1, false
		}
		return idx, true
	}
	return -1, false
}

var (
	indexPrefix = "@index"
)
