package dnsforward

import (
	"encoding/binary"
	"testing"
	"time"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
)

// TestRecursionDetector_Check verifies that rd.check reports a message as
// recursive only when a matching, unexpired signature has been stored.
func TestRecursionDetector_Check(t *testing.T) {
	// NOTE(review): the zero first argument appears to act as the TTL for
	// entries stored via rd.add (see the "expired" case below) — confirm
	// against newRecursionDetector.
	rd := newRecursionDetector(0, 2)

	const (
		recID    = 1234
		nonRecID = recID * 2
		recTTL   = time.Hour * 100
	)

	sampleQuestion := dns.Question{
		Name:  "some.domain",
		Qtype: dns.TypeAAAA,
	}

	// newMsg builds a fresh message for every use, so no subtest depends on
	// a variable mutated by the loop (which would break under t.Parallel).
	newMsg := func(id uint16, qs []dns.Question) (msg dns.Msg) {
		return dns.Msg{
			MsgHdr:   dns.MsgHdr{Id: id},
			Question: qs,
		}
	}

	// Manually store the signature of the recursive message with an expiry
	// far in the future, bypassing rd.add.
	key := msgToSignature(newMsg(recID, []dns.Question{sampleQuestion}))
	expire := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
	rd.recentRequests.Set(key, expire)

	// Add a message through rd.add; with the detector configured above it is
	// expected to be already expired by the time it is checked.
	rd.add(newMsg(nonRecID, []dns.Question{sampleQuestion}))

	testCases := []struct {
		name      string
		questions []dns.Question
		id        uint16
		want      bool
	}{{
		name:      "recurrent",
		questions: []dns.Question{sampleQuestion},
		id:        recID,
		want:      true,
	}, {
		name:      "not_suspected",
		questions: []dns.Question{sampleQuestion},
		id:        recID + 1,
		want:      false,
	}, {
		name:      "expired",
		questions: []dns.Question{sampleQuestion},
		id:        nonRecID,
		want:      false,
	}, {
		name:      "empty",
		questions: []dns.Question{},
		id:        nonRecID,
		want:      false,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Build the message inside the subtest so it is owned by it.
			msg := newMsg(tc.id, tc.questions)

			assert.Equal(t, tc.want, rd.check(msg))
		})
	}
}

// TestRecursionDetector_Suspect checks that rd.add records a message that
// carries a question and records nothing for an empty message.
func TestRecursionDetector_Suspect(t *testing.T) {
	rd := newRecursionDetector(0, 1)

	withQuestion := dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id: 1234,
		},
		Question: []dns.Question{{
			Name:  "some.domain",
			Qtype: dns.TypeA,
		}},
	}

	testCases := []struct {
		name string
		msg  dns.Msg
		want bool
	}{
		{name: "simple", msg: withQuestion, want: true},
		{name: "unencumbered", msg: dns.Msg{}, want: false},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Run rd.clear after this subtest so the next case does not see
			// entries left over from this one.
			t.Cleanup(rd.clear)

			rd.add(tc.msg)

			// want reports whether exactly one entry should have been stored.
			wantCount := 0
			if tc.want {
				wantCount = 1
			}

			assert.Equal(t, wantCount, rd.recentRequests.Stats().Count)
		})
	}
}

// sink is a package-level destination for benchmark results; writing to it
// keeps the measured calls from presumably being optimized away as dead code.
var sink []byte

// BenchmarkMsgToSignature compares msgToSignature ("efficient") with
// msgToSignatureSlow ("inefficient") on the same single-question AAAA message.
func BenchmarkMsgToSignature(b *testing.B) {
	const host = "some.not.very.long.host.name"

	q := dns.Question{
		Name:  host,
		Qtype: dns.TypeAAAA,
	}
	msg := dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id: 1234,
		},
		Question: []dns.Question{q},
	}

	b.Run("efficient", func(b *testing.B) {
		b.ReportAllocs()

		for n := 0; n < b.N; n++ {
			sink = msgToSignature(msg)
		}

		// Sanity check: the measured call must have produced a signature.
		assert.NotEmpty(b, sink)
	})

	b.Run("inefficient", func(b *testing.B) {
		b.ReportAllocs()

		for n := 0; n < b.N; n++ {
			sink = msgToSignatureSlow(msg)
		}

		// Sanity check: the measured call must have produced a signature.
		assert.NotEmpty(b, sink)
	})
}