package home

import (
	"bytes"
	"encoding/binary"
	"errors"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
	"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
	"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/log"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestRDNS_Begin(t *testing.T) {
	aghtest.ReplaceLogLevel(t, log.DEBUG)
	w := &bytes.Buffer{}
	aghtest.ReplaceLogWriter(t, w)

	ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}

	testCases := []struct {
		cliIDIndex    map[string]*Client
		customChan    chan net.IP
		name          string
		wantLog       string
		req           net.IP
		wantCacheHit  int
		wantCacheMiss int
	}{{
		cliIDIndex:    map[string]*Client{},
		customChan:    nil,
		name:          "cached",
		wantLog:       "",
		req:           ip1234,
		wantCacheHit:  1,
		wantCacheMiss: 0,
	}, {
		cliIDIndex:    map[string]*Client{},
		customChan:    nil,
		name:          "not_cached",
		wantLog:       "rdns: queue is full",
		req:           ip1235,
		wantCacheHit:  0,
		wantCacheMiss: 1,
	}, {
		cliIDIndex:    map[string]*Client{"1.2.3.5": {}},
		customChan:    nil,
		name:          "already_in_clients",
		wantLog:       "",
		req:           ip1235,
		wantCacheHit:  0,
		wantCacheMiss: 1,
	}, {
		cliIDIndex:    map[string]*Client{},
		customChan:    make(chan net.IP, 1),
		name:          "add_to_queue",
		wantLog:       `rdns: "1.2.3.5" added to queue`,
		req:           ip1235,
		wantCacheHit:  0,
		wantCacheMiss: 1,
	}}

	for _, tc := range testCases {
		w.Reset()

		ipCache := cache.New(cache.Config{
			EnableLRU: true,
			MaxCount:  defaultRDNSCacheSize,
		})
		ttl := make([]byte, binary.Size(uint64(0)))
		binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))

		rdns := &RDNS{
			ipCache: ipCache,
			clients: &clientsContainer{
				list:    map[string]*Client{},
				idIndex: tc.cliIDIndex,
				ipToRC:  map[string]*RuntimeClient{},
				allTags: map[string]bool{},
			},
		}
		ipCache.Clear()
		ipCache.Set(net.IP{1, 2, 3, 4}, ttl)

		if tc.customChan != nil {
			rdns.ipCh = tc.customChan
			defer close(tc.customChan)
		}

		t.Run(tc.name, func(t *testing.T) {
			rdns.Begin(tc.req)
			assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
			assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
			assert.Contains(t, w.String(), tc.wantLog)
		})
	}
}

func TestRDNS_Resolve(t *testing.T) {
	extUpstream := &aghtest.TestUpstream{
		Reverse: map[string][]string{
			"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
		},
	}
	locUpstream := &aghtest.TestUpstream{
		Reverse: map[string][]string{
			"1.1.168.192.in-addr.arpa.": {"local.domain"},
			"2.1.168.192.in-addr.arpa.": {},
		},
	}
	upstreamErr := errors.New("upstream error")
	errUpstream := &aghtest.TestErrUpstream{
		Err: upstreamErr,
	}
	nonPtrUpstream := &aghtest.TestBlockUpstream{
		Hostname: "some-host",
		Block:    true,
	}

	dns := dnsforward.NewCustomServer(&proxy.Proxy{
		Config: proxy.Config{
			UpstreamConfig: &proxy.UpstreamConfig{
				Upstreams: []upstream.Upstream{extUpstream},
			},
		},
	})

	cc := &clientsContainer{}

	snd, err := aghnet.NewSubnetDetector()
	require.NoError(t, err)

	localIP := net.IP{192, 168, 1, 1}
	testCases := []struct {
		name        string
		want        string
		wantErr     error
		locUpstream upstream.Upstream
		req         net.IP
	}{{
		name:        "external_good",
		want:        "one.one.one.one",
		wantErr:     nil,
		locUpstream: nil,
		req:         net.IP{1, 1, 1, 1},
	}, {
		name:        "local_good",
		want:        "local.domain",
		wantErr:     nil,
		locUpstream: locUpstream,
		req:         localIP,
	}, {
		name:        "upstream_error",
		want:        "",
		wantErr:     upstreamErr,
		locUpstream: errUpstream,
		req:         localIP,
	}, {
		name:        "empty_answer_error",
		want:        "",
		wantErr:     rDNSEmptyAnswerErr,
		locUpstream: locUpstream,
		req:         net.IP{192, 168, 1, 2},
	}, {
		name:        "not_ptr_error",
		want:        "",
		wantErr:     rDNSNotPTRErr,
		locUpstream: nonPtrUpstream,
		req:         localIP,
	}}

	for _, tc := range testCases {
		rdns := NewRDNS(dns, cc, snd, &aghtest.Exchanger{
			Ups: tc.locUpstream,
		})

		t.Run(tc.name, func(t *testing.T) {
			r, rerr := rdns.resolve(tc.req)
			require.ErrorIs(t, rerr, tc.wantErr)
			assert.Equal(t, tc.want, r)
		})
	}
}

func TestRDNS_WorkerLoop(t *testing.T) {
	aghtest.ReplaceLogLevel(t, log.DEBUG)
	w := &bytes.Buffer{}
	aghtest.ReplaceLogWriter(t, w)

	locUpstream := &aghtest.TestUpstream{
		Reverse: map[string][]string{
			"1.1.168.192.in-addr.arpa.": {"local.domain"},
		},
	}

	snd, err := aghnet.NewSubnetDetector()
	require.NoError(t, err)

	testCases := []struct {
		wantLog string
		name    string
		cliIP   net.IP
	}{{
		wantLog: "",
		name:    "all_good",
		cliIP:   net.IP{192, 168, 1, 1},
	}, {
		wantLog: `rdns: resolving "192.168.1.2": lookup for "2.1.168.192.in-addr.arpa.": ` +
			string(rDNSEmptyAnswerErr),
		name:  "resolve_error",
		cliIP: net.IP{192, 168, 1, 2},
	}}

	for _, tc := range testCases {
		w.Reset()

		lr := &aghtest.Exchanger{
			Ups: locUpstream,
		}
		cc := &clientsContainer{
			list:    map[string]*Client{},
			idIndex: map[string]*Client{},
			ipToRC:  map[string]*RuntimeClient{},
			allTags: map[string]bool{},
		}
		ch := make(chan net.IP)
		rdns := &RDNS{
			dnsServer:      nil,
			clients:        cc,
			subnetDetector: snd,
			localResolvers: lr,
			ipCh:           ch,
		}

		t.Run(tc.name, func(t *testing.T) {
			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				rdns.workerLoop()
				wg.Done()
			}()

			ch <- tc.cliIP
			close(ch)
			wg.Wait()

			if tc.wantLog != "" {
				assert.Contains(t, w.String(), tc.wantLog)

				return
			}

			assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS))
		})
	}
}