package dnsforward

import (
	"net"
	"testing"

	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
)

// fakeIpsetMgr is a test double for aghnet.IpsetManager that only remembers
// the addresses it has been asked to add.
type fakeIpsetMgr struct {
	// ip4s collects the IPv4 addresses from every Add call, in call order.
	ip4s []net.IP
	// ip6s collects the IPv6 addresses from every Add call, in call order.
	ip6s []net.IP
}

// Add implements the aghnet.IpsetManager interface for *fakeIpsetMgr.  host is
// not used.  The given addresses are appended to the recorded ones, and n is
// the total number of addresses passed in this call.  err is always nil.
func (f *fakeIpsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
	n = len(ip4s) + len(ip6s)

	f.ip4s = append(f.ip4s, ip4s...)
	f.ip6s = append(f.ip6s, ip6s...)

	return n, nil
}

// Close implements the aghnet.IpsetManager interface for *fakeIpsetMgr.  The
// fake holds nothing to release, so it never fails.
func (*fakeIpsetMgr) Close() (err error) { return nil }

func TestIpsetCtx_process(t *testing.T) {
	ip4 := net.IP{1, 2, 3, 4}
	ip6 := net.IP{
		0x12, 0x34, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x56, 0x78,
	}

	req4 := createTestMessageWithType("example.com", dns.TypeA)
	req6 := createTestMessageWithType("example.com", dns.TypeAAAA)

	resp4 := &dns.Msg{
		Answer: []dns.RR{&dns.A{
			A: ip4,
		}},
	}
	resp6 := &dns.Msg{
		Answer: []dns.RR{&dns.AAAA{
			AAAA: ip6,
		}},
	}

	t.Run("nil", func(t *testing.T) {
		dctx := &dnsContext{
			proxyCtx: &proxy.DNSContext{},

			responseFromUpstream: true,
		}

		ictx := &ipsetCtx{}
		rc := ictx.process(dctx)
		assert.Equal(t, resultCodeSuccess, rc)

		err := ictx.close()
		assert.NoError(t, err)
	})

	t.Run("ipv4", func(t *testing.T) {
		dctx := &dnsContext{
			proxyCtx: &proxy.DNSContext{
				Req: req4,
				Res: resp4,
			},

			responseFromUpstream: true,
		}

		m := &fakeIpsetMgr{}
		ictx := &ipsetCtx{
			ipsetMgr: m,
		}

		rc := ictx.process(dctx)
		assert.Equal(t, resultCodeSuccess, rc)
		assert.Equal(t, []net.IP{ip4}, m.ip4s)
		assert.Empty(t, m.ip6s)

		err := ictx.close()
		assert.NoError(t, err)
	})

	t.Run("ipv6", func(t *testing.T) {
		dctx := &dnsContext{
			proxyCtx: &proxy.DNSContext{
				Req: req6,
				Res: resp6,
			},

			responseFromUpstream: true,
		}

		m := &fakeIpsetMgr{}
		ictx := &ipsetCtx{
			ipsetMgr: m,
		}

		rc := ictx.process(dctx)
		assert.Equal(t, resultCodeSuccess, rc)
		assert.Empty(t, m.ip4s)
		assert.Equal(t, []net.IP{ip6}, m.ip6s)

		err := ictx.close()
		assert.NoError(t, err)
	})
}