package dnsforward

import (
	"crypto/tls"
	"net"
	"testing"
	"time"

	"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
	"github.com/AdguardTeam/AdGuardHome/internal/filtering"
	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Fixtures shared by the access-settings tests in this file.
const (
	// testFQDN is the fully-qualified domain name requested in the test
	// cases that expect the query to be processed normally.
	testFQDN = "example.org."

	// blockedHost is the hostname put into the blocked hosts list by the
	// test cases that check host blocking.
	blockedHost = "blockedhost.org"

	// dnsClientTimeout is the timeout of the test DNS clients.  Some UDP
	// test cases expect their requests to time out, so this value also
	// bounds how long those cases take.
	dnsClientTimeout = 200 * time.Millisecond
)

// TestServer_HandleBefore_tls checks the access settings of the server for
// DNS-over-TLS requests.  The ClientID, if any, is sent as the first label of
// the TLS server name.  Requests from clients that aren't allowed, or for
// blocked hosts, must be answered with REFUSED, and a request with an invalid
// ClientID must be answered with SERVFAIL.
func TestServer_HandleBefore_tls(t *testing.T) {
	t.Parallel()

	const clientID = "client-1"

	testCases := []struct {
		clientSrvName     string
		name              string
		host              string
		allowedClients    []string
		disallowedClients []string
		blockedHosts      []string
		wantRCode         int
	}{{
		clientSrvName:     tlsServerName,
		name:              "allow_all",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantRCode:         dns.RcodeSuccess,
	}, {
		clientSrvName:     "%" + "." + tlsServerName,
		name:              "invalid_client_id",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantRCode:         dns.RcodeServerFailure,
	}, {
		clientSrvName:     clientID + "." + tlsServerName,
		name:              "allowed_client_allowed",
		host:              testFQDN,
		allowedClients:    []string{clientID},
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantRCode:         dns.RcodeSuccess,
	}, {
		clientSrvName:     "client-2." + tlsServerName,
		name:              "allowed_client_rejected",
		host:              testFQDN,
		allowedClients:    []string{clientID},
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantRCode:         dns.RcodeRefused,
	}, {
		clientSrvName:     tlsServerName,
		name:              "disallowed_client_allowed",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{clientID},
		blockedHosts:      []string{},
		wantRCode:         dns.RcodeSuccess,
	}, {
		clientSrvName:     clientID + "." + tlsServerName,
		name:              "disallowed_client_rejected",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{clientID},
		blockedHosts:      []string{},
		wantRCode:         dns.RcodeRefused,
	}, {
		clientSrvName:     tlsServerName,
		name:              "blocked_hosts_allowed",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{blockedHost},
		wantRCode:         dns.RcodeSuccess,
	}, {
		clientSrvName:     tlsServerName,
		name:              "blocked_hosts_rejected",
		host:              dns.Fqdn(blockedHost),
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{blockedHost},
		wantRCode:         dns.RcodeRefused,
	}}

	// localAns is the answer of the local upstream for every request.
	localAns := []dns.RR{&dns.A{
		Hdr: dns.RR_Header{
			Name:     testFQDN,
			Rrtype:   dns.TypeA,
			Class:    dns.ClassINET,
			Ttl:      3600,
			Rdlength: 4,
		},
		A: net.IP{1, 2, 3, 4},
	}}
	localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		resp := (&dns.Msg{}).SetReply(req)
		resp.Answer = localAns

		// The handler runs in a goroutine of the upstream server, not in the
		// goroutine running the test, so use assert here:  require calls
		// t.FailNow, which must only be called from the test goroutine.
		assert.NoError(t, w.WriteMsg(resp))
	})
	localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			s, _ := createTestTLS(t, TLSConfig{
				TLSListenAddrs: []*net.TCPAddr{{}},
				ServerName:     tlsServerName,
			})

			s.conf.UpstreamDNS = []string{localUpsAddr}

			s.conf.AllowedClients = tc.allowedClients
			s.conf.DisallowedClients = tc.disallowedClients
			s.conf.BlockedHosts = tc.blockedHosts

			err := s.Prepare(&s.conf)
			require.NoError(t, err)

			startDeferStop(t, s)

			// Only the server name matters here, since it carries the
			// ClientID, so the server's certificate isn't verified.
			tlsConfig := &tls.Config{
				InsecureSkipVerify: true,
				ServerName:         tc.clientSrvName,
			}

			client := &dns.Client{
				Net:       "tcp-tls",
				TLSConfig: tlsConfig,
				Timeout:   dnsClientTimeout,
			}

			req := createTestMessage(tc.host)
			addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()

			reply, _, err := client.Exchange(req, addr)
			require.NoError(t, err)
			require.NotNil(t, reply)

			assert.Equal(t, tc.wantRCode, reply.Rcode)
			if tc.wantRCode == dns.RcodeSuccess {
				assert.Equal(t, localAns, reply.Answer)
			} else {
				assert.Empty(t, reply.Answer)
			}
		})
	}
}

// TestServer_HandleBefore_udp checks the access settings of the server for
// plain DNS requests over UDP.  Requests from clients that aren't allowed, or
// for blocked hosts, must be left unanswered, which the client sees as a
// timeout.
func TestServer_HandleBefore_udp(t *testing.T) {
	t.Parallel()

	const (
		clientIPv4 = "127.0.0.1"
		clientIPv6 = "::1"
	)

	clientIPs := []string{clientIPv4, clientIPv6}

	testCases := []struct {
		name              string
		host              string
		allowedClients    []string
		disallowedClients []string
		blockedHosts      []string
		wantTimeout       bool
	}{{
		name:              "allow_all",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantTimeout:       false,
	}, {
		name:              "allowed_client_allowed",
		host:              testFQDN,
		allowedClients:    clientIPs,
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantTimeout:       false,
	}, {
		name:              "allowed_client_rejected",
		host:              testFQDN,
		allowedClients:    []string{"1:2:3::4"},
		disallowedClients: []string{},
		blockedHosts:      []string{},
		wantTimeout:       true,
	}, {
		name:              "disallowed_client_allowed",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{"1:2:3::4"},
		blockedHosts:      []string{},
		wantTimeout:       false,
	}, {
		name:              "disallowed_client_rejected",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: clientIPs,
		blockedHosts:      []string{},
		wantTimeout:       true,
	}, {
		name:              "blocked_hosts_allowed",
		host:              testFQDN,
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{blockedHost},
		wantTimeout:       false,
	}, {
		name:              "blocked_hosts_rejected",
		host:              dns.Fqdn(blockedHost),
		allowedClients:    []string{},
		disallowedClients: []string{},
		blockedHosts:      []string{blockedHost},
		wantTimeout:       true,
	}}

	// localAns is the answer of the local upstream for every request.
	localAns := []dns.RR{&dns.A{
		Hdr: dns.RR_Header{
			Name:     testFQDN,
			Rrtype:   dns.TypeA,
			Class:    dns.ClassINET,
			Ttl:      3600,
			Rdlength: 4,
		},
		A: net.IP{1, 2, 3, 4},
	}}
	localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		resp := (&dns.Msg{}).SetReply(req)
		resp.Answer = localAns

		// The handler runs in a goroutine of the upstream server, not in the
		// goroutine running the test, so use assert here:  require calls
		// t.FailNow, which must only be called from the test goroutine.
		assert.NoError(t, w.WriteMsg(resp))
	})
	localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			s := createTestServer(t, &filtering.Config{
				BlockingMode: filtering.BlockingModeDefault,
			}, ServerConfig{
				UDPListenAddrs: []*net.UDPAddr{{}},
				TCPListenAddrs: []*net.TCPAddr{{}},
				Config: Config{
					AllowedClients:    tc.allowedClients,
					DisallowedClients: tc.disallowedClients,
					BlockedHosts:      tc.blockedHosts,
					UpstreamDNS:       []string{localUpsAddr},
					UpstreamMode:      UpstreamModeLoadBalance,
					EDNSClientSubnet:  &EDNSClientSubnet{Enabled: false},
					ClientsContainer:  EmptyClientsContainer{},
				},
				ServePlainDNS: true,
			})

			startDeferStop(t, s)

			client := &dns.Client{
				Net:     "udp",
				Timeout: dnsClientTimeout,
			}

			req := createTestMessage(tc.host)
			addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()

			reply, _, err := client.Exchange(req, addr)
			if tc.wantTimeout {
				// The server is expected not to respond at all, so the
				// client must give up after its timeout.
				var opErr *net.OpError
				require.ErrorAs(t, err, &opErr)
				assert.True(t, opErr.Timeout())

				assert.Nil(t, reply)

				return
			}

			require.NoError(t, err)
			require.NotNil(t, reply)

			assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
			assert.Equal(t, localAns, reply.Answer)
		})
	}
}