package dhcpsvc

import (
	"encoding/binary"
	"fmt"
	"math"
	"math/big"
	"net/netip"

	"github.com/AdguardTeam/golibs/errors"
)

// ipRange describes a contiguous span of IP addresses in which both bounds are
// included.  The zero value holds two invalid addresses and thus doesn't
// contain any IP addresses.
//
// All of its methods use value receivers and never modify the bounds, so it is
// safe for concurrent use.
type ipRange struct {
	// start is the first address of the range and end is the last one.  Both
	// are expected to belong to the same address family.
	start, end netip.Addr
}

// maxRangeLen is the maximum IP range length, that is the greatest allowed
// difference between the end and the start of a range.  It equals the largest
// value of a 32-bit unsigned integer, since the bitsets used in servers only
// accept uints, which may be as small as 32 bits.
//
// TODO(a.garipov, e.burkov):  Reconsider the value for IPv6.
const maxRangeLen = math.MaxUint32

// newIPRange creates a new IP address range from start to end, both inclusive.
// Both addresses must be valid and belong to the same address family, start
// must be less than end, and the difference between end and start must not be
// greater than maxRangeLen.  If any of these requirements is violated, r is
// the zero range and err describes the violation.
func newIPRange(start, end netip.Addr) (r ipRange, err error) {
	// Add a common prefix to the error returned from any of the checks below.
	defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()

	// An invalid address reports false from Is4 and sorts before any valid
	// one, so an invalid start paired with an IPv6 end would otherwise pass
	// both the family and the ordering checks below.
	if !start.IsValid() || !end.IsValid() {
		return ipRange{}, fmt.Errorf("start %s and end %s must be valid addresses", start, end)
	}

	if start.Is4() != end.Is4() {
		return ipRange{}, fmt.Errorf("%s and %s must be within the same address family", start, end)
	}

	if !start.Less(end) {
		return ipRange{}, fmt.Errorf("start %s is greater than or equal to end %s", start, end)
	}

	// Use arbitrary-precision arithmetic, since the difference between two
	// IPv6 addresses may not fit into any built-in integer type.
	diff := new(big.Int).Sub(
		new(big.Int).SetBytes(end.AsSlice()),
		new(big.Int).SetBytes(start.AsSlice()),
	)
	if !diff.IsUint64() || diff.Uint64() > maxRangeLen {
		return ipRange{}, fmt.Errorf("range length must be within %d", uint32(maxRangeLen))
	}

	return ipRange{start: start, end: end}, nil
}

// contains reports whether ip lies between the bounds of r, both inclusive.
// An ip of another address family than the start of r is never contained.
func (r ipRange) contains(ip netip.Addr) (ok bool) {
	// Only the start is compared, since the end is assumed to have been
	// checked to be within the same address family during construction.
	if r.start.Is4() != ip.Is4() {
		return false
	}

	beforeStart := ip.Less(r.start)
	afterEnd := r.end.Less(ip)

	return !(beforeStart || afterEnd)
}

// ipPredicate is a function that is called on the IP addresses visited by
// [ipRange.find].  It reports whether ip is the address being searched for, in
// which case the search stops and ip is returned.
type ipPredicate func(ip netip.Addr) (ok bool)

// find returns the first IP address in r, going from the start to the end
// inclusive, for which p returns true.  It returns an empty [netip.Addr] if
// there are no addresses that satisfy p, which is always the case for a zero
// range.  p must not be nil.
//
// TODO(e.burkov):  Use.
func (r ipRange) find(p ipPredicate) (ip netip.Addr) {
	// The validity check serves two purposes.  It keeps p from being called
	// with the invalid address of a zero range.  It also stops the iteration
	// when r.end is the highest address of its family: Next then returns the
	// zero address, which sorts before any valid address and thus is never
	// greater than r.end, so without the check the loop would not terminate.
	for ip = r.start; ip.IsValid() && !r.end.Less(ip); ip = ip.Next() {
		if p(ip) {
			return ip
		}
	}

	return netip.Addr{}
}

// offset returns the distance between ip and the start of r, so that the start
// itself has the offset of 0.  If r doesn't contain ip, it returns 0 and
// false.
func (r ipRange) offset(ip netip.Addr) (offset uint64, ok bool) {
	if r.contains(ip) {
		ip16, start16 := ip.As16(), r.start.As16()

		lowIP := binary.BigEndian.Uint64(ip16[8:])
		lowStart := binary.BigEndian.Uint64(start16[8:])

		// Only the lower 64 bits of the addresses are subtracted.  Assuming
		// that the range length was checked against maxRangeLen during
		// construction, the real difference fits into 64 bits, and the
		// wrap-around of unsigned subtraction keeps the result correct even
		// if the upper halves of the addresses differ.
		return lowIP - lowStart, true
	}

	return 0, false
}

// String implements the [fmt.Stringer] interface for ipRange.  The result is
// the textual form of the start and the end of r joined by a hyphen.
func (r ipRange) String() (s string) {
	const format = "%s-%s"

	s = fmt.Sprintf(format, r.start, r.end)

	return s
}