diff --git a/internal/aghnet/dhcp_unix.go b/internal/aghnet/dhcp_unix.go index 479d9926..554d68c6 100644 --- a/internal/aghnet/dhcp_unix.go +++ b/internal/aghnet/dhcp_unix.go @@ -19,7 +19,8 @@ import ( "github.com/insomniacslk/dhcp/iana" ) -// defaultDiscoverTime is the +// defaultDiscoverTime is the default timeout of checking another DHCP server +// response. const defaultDiscoverTime = 3 * time.Second func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) { diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 70f2d00f..150e8c19 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -343,113 +343,93 @@ func TestHostsContainer(t *testing.T) { testdata := os.DirFS("./testdata") - nRewrites := func(t *testing.T, res *urlfilter.DNSResult, n int) (rws []*rules.DNSRewrite) { - rewrites := res.DNSRewrites() - require.Len(t, rewrites, n) - - for _, rewrite := range rewrites { - require.Equal(t, listID, rewrite.FilterListID) - - rw := rewrite.DNSRewrite - require.NotNil(t, rw) - - rws = append(rws, rw) - } - - return rws - } - testCases := []struct { - testTail func(t *testing.T, res *urlfilter.DNSResult) - name string - req urlfilter.DNSRequest + want []*rules.DNSRewrite + name string + req urlfilter.DNSRequest }{{ + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + Value: net.IPv4(1, 0, 0, 1), + RRType: dns.TypeA, + }, { + RCode: dns.RcodeSuccess, + Value: net.IP(append((&[15]byte{})[:], byte(1))), + RRType: dns.TypeAAAA, + }}, name: "simple", req: urlfilter.DNSRequest{ Hostname: "simplehost", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - rws := nRewrites(t, res, 2) - - v, ok := rws[0].Value.(net.IP) - require.True(t, ok) - - assert.True(t, net.IP{1, 0, 0, 1}.Equal(v)) - - v, ok = rws[1].Value.(net.IP) - require.True(t, ok) - - // It's ::1. - assert.True(t, net.IP(append((&[15]byte{})[:], byte(1))).Equal(v)) - }, }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + NewCNAME: "hello", + }}, name: "hello_alias", req: urlfilter.DNSRequest{ Hostname: "hello.world", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - assert.Equal(t, "hello", nRewrites(t, res, 1)[0].NewCNAME) - }, }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + NewCNAME: "hello", + }}, name: "other_line_alias", req: urlfilter.DNSRequest{ Hostname: "hello.world.again", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - assert.Equal(t, "hello", nRewrites(t, res, 1)[0].NewCNAME) - }, }, { + want: []*rules.DNSRewrite{}, name: "hello_subdomain", req: urlfilter.DNSRequest{ Hostname: "say.hello", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - assert.Empty(t, res.DNSRewrites()) - }, }, { + want: []*rules.DNSRewrite{}, name: "hello_alias_subdomain", req: urlfilter.DNSRequest{ Hostname: "say.hello.world", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - assert.Empty(t, res.DNSRewrites()) - }, }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + NewCNAME: "a.whole", + }}, name: "lots_of_aliases", req: urlfilter.DNSRequest{ Hostname: "for.testing", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - assert.Equal(t, "a.whole", nRewrites(t, res, 1)[0].NewCNAME) - }, }, { + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + RRType: dns.TypePTR, + Value: "simplehost.", + }}, name: "reverse", req: urlfilter.DNSRequest{ Hostname: "1.0.0.1.in-addr.arpa", DNSType: dns.TypePTR, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - rws := nRewrites(t, res, 1) - - assert.Equal(t, dns.TypePTR, rws[0].RRType) - assert.Equal(t, "simplehost.", rws[0].Value) - }, }, { + want: []*rules.DNSRewrite{}, name: "non-existing", req: urlfilter.DNSRequest{ Hostname: "nonexisting", DNSType: dns.TypeA, }, - testTail: func(t *testing.T, res *urlfilter.DNSResult) { - require.NotNil(t, res) - - assert.Nil(t, res.DNSRewrites()) + }, { + want: nil, + name: "bad_type", + req: urlfilter.DNSRequest{ + Hostname: "1.0.0.1.in-addr.arpa", + DNSType: dns.TypeSRV, }, }} @@ -466,9 +446,26 @@ func TestHostsContainer(t *testing.T) { t.Run(tc.name, func(t *testing.T) { res, ok := hc.MatchRequest(tc.req) require.False(t, ok) + + if tc.want == nil { + assert.Nil(t, res) + + return + } + require.NotNil(t, res) - tc.testTail(t, res) + rewrites := res.DNSRewrites() + require.Len(t, rewrites, len(tc.want)) + + for i, rewrite := range rewrites { + require.Equal(t, listID, rewrite.FilterListID) + + rw := rewrite.DNSRewrite + require.NotNil(t, rw) + + assert.Equal(t, tc.want[i], rw) + } }) } } diff --git a/internal/aghnet/interfaces.go b/internal/aghnet/interfaces.go index a5095919..a667a1f3 100644 --- a/internal/aghnet/interfaces.go +++ b/internal/aghnet/interfaces.go @@ -25,6 +25,13 @@ type NetIface interface { // IfaceIPAddrs returns the interface's IP addresses. func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) { + switch ipv { + case IPVersion4, IPVersion6: + // Go on. + default: + return nil, fmt.Errorf("invalid ip version %d", ipv) + } + addrs, err := iface.Addrs() if err != nil { return nil, err @@ -41,20 +48,16 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) { continue } - // Assume that net.(*Interface).Addrs can only return valid IPv4 - // and IPv6 addresses. Thus, if it isn't an IPv4 address, it - // must be an IPv6 one. - switch ipv { - case IPVersion4: - if ip4 := ip.To4(); ip4 != nil { + // Assume that net.(*Interface).Addrs can only return valid IPv4 and + // IPv6 addresses. Thus, if it isn't an IPv4 address, it must be an + // IPv6 one. + ip4 := ip.To4() + if ipv == IPVersion4 { + if ip4 != nil { ips = append(ips, ip4) } - case IPVersion6: - if ip6 := ip.To4(); ip6 == nil { - ips = append(ips, ip) - } - default: - return nil, fmt.Errorf("invalid ip version %d", ipv) + } else if ip4 == nil { + ips = append(ips, ip) } } @@ -96,16 +99,16 @@ func IfaceDNSIPAddrs( switch len(addrs) { case 0: - // Don't return errors in case the users want to try and enable - // the DHCP server later. + // Don't return errors in case the users want to try and enable the DHCP + // server later. t := time.Duration(n) * backoff log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t) return nil, nil case 1: - // Some Android devices use 8.8.8.8 if there is not a secondary - // DNS server. Fix that by setting the secondary DNS address to - // the same address. + // Some Android devices use 8.8.8.8 if there is not a secondary DNS + // server. Fix that by setting the secondary DNS address to the same + // address. // // See https://github.com/AdguardTeam/AdGuardHome/issues/1708. log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv) diff --git a/internal/aghnet/interfaces_test.go b/internal/aghnet/interfaces_test.go index 2b70429c..ca829fb1 100644 --- a/internal/aghnet/interfaces_test.go +++ b/internal/aghnet/interfaces_test.go @@ -5,13 +5,15 @@ import ( "testing" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// fakeIface is a stub implementation of aghnet.NetIface to simplify testing. type fakeIface struct { - addrs []net.Addr err error + addrs []net.Addr } // Addrs implements the NetIface interface for *fakeIface. @@ -33,61 +35,86 @@ func TestIfaceIPAddrs(t *testing.T) { addr6 := &net.IPNet{IP: ip6} testCases := []struct { - name string - iface NetIface - ipv IPVersion - want []net.IP - wantErr error + iface NetIface + name string + wantErrMsg string + want []net.IP + ipv IPVersion }{{ - name: "ipv4_success", - iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, - ipv: IPVersion4, - want: []net.IP{ip4}, - wantErr: nil, + iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, + name: "ipv4_success", + wantErrMsg: "", + want: []net.IP{ip4}, + ipv: IPVersion4, }, { - name: "ipv4_success_with_ipv6", - iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, - ipv: IPVersion4, - want: []net.IP{ip4}, - wantErr: nil, + iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, + name: "ipv4_success_with_ipv6", + wantErrMsg: "", + want: []net.IP{ip4}, + ipv: IPVersion4, }, { - name: "ipv4_error", - iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest}, - ipv: IPVersion4, - want: nil, - wantErr: errTest, + iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest}, + name: "ipv4_error", + wantErrMsg: errTest.Error(), + want: nil, + ipv: IPVersion4, }, { - name: "ipv6_success", - iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil}, - ipv: IPVersion6, - want: []net.IP{ip6}, - wantErr: nil, + iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil}, + name: "ipv6_success", + wantErrMsg: "", + want: []net.IP{ip6}, + ipv: IPVersion6, }, { - name: "ipv6_success_with_ipv4", - iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, - ipv: IPVersion6, - want: []net.IP{ip6}, - wantErr: nil, + iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, + name: "ipv6_success_with_ipv4", + wantErrMsg: "", + want: []net.IP{ip6}, + ipv: IPVersion6, }, { - name: "ipv6_error", - iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest}, - ipv: IPVersion6, - want: nil, - wantErr: errTest, + iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest}, + name: "ipv6_error", + wantErrMsg: errTest.Error(), + want: nil, + ipv: IPVersion6, + }, { + iface: &fakeIface{addrs: nil, err: nil}, + name: "bad_proto", + wantErrMsg: "invalid ip version 10", + want: nil, + ipv: IPVersion6 + IPVersion4, + }, { + iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip4}}, err: nil}, + name: "ipaddr_v4", + wantErrMsg: "", + want: []net.IP{ip4}, + ipv: IPVersion4, + }, { + iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip6, Zone: ""}}, err: nil}, + name: "ipaddr_v6", + wantErrMsg: "", + want: []net.IP{ip6}, + ipv: IPVersion6, + }, { + iface: &fakeIface{addrs: []net.Addr{&net.UnixAddr{}}, err: nil}, + name: "non-ipv4", + wantErrMsg: "", + want: nil, + ipv: IPVersion4, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, gotErr := IfaceIPAddrs(tc.iface, tc.ipv) - require.True(t, errors.Is(gotErr, tc.wantErr)) + got, err := IfaceIPAddrs(tc.iface, tc.ipv) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + assert.Equal(t, tc.want, got) }) } } type waitingFakeIface struct { - addrs []net.Addr err error + addrs []net.Addr n int } @@ -116,11 +143,11 @@ func TestIfaceDNSIPAddrs(t *testing.T) { addr6 := &net.IPNet{IP: ip6} testCases := []struct { - name string iface NetIface - ipv IPVersion - want []net.IP wantErr error + name string + want []net.IP + ipv IPVersion }{{ name: "ipv4_success", iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, @@ -169,12 +196,25 @@ func TestIfaceDNSIPAddrs(t *testing.T) { ipv: IPVersion6, want: []net.IP{ip6, ip6}, wantErr: nil, + }, { + name: "empty", + iface: &fakeIface{addrs: nil, err: nil}, + ipv: IPVersion4, + want: nil, + wantErr: nil, + }, { + name: "many", + iface: &fakeIface{addrs: []net.Addr{addr4, addr4}}, + ipv: IPVersion4, + want: []net.IP{ip4, ip4}, + wantErr: nil, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, gotErr := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) - require.True(t, errors.Is(gotErr, tc.wantErr)) + got, err := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) + require.ErrorIs(t, err, tc.wantErr) + assert.Equal(t, tc.want, got) }) } diff --git a/internal/aghnet/ipmut_test.go b/internal/aghnet/ipmut_test.go new file mode 100644 index 00000000..51fc16ba --- /dev/null +++ b/internal/aghnet/ipmut_test.go @@ -0,0 +1,44 @@ +package aghnet + +import ( + "net" + "testing" + + "github.com/AdguardTeam/golibs/netutil" + "github.com/stretchr/testify/assert" +) + +func TestIPMut(t *testing.T) { + testIPs := []net.IP{{ + 127, 0, 0, 1, + }, { + 192, 168, 0, 1, + }, { + 8, 8, 8, 8, + }} + + t.Run("nil_no_mut", func(t *testing.T) { + ipmut := NewIPMut(nil) + + ips := netutil.CloneIPs(testIPs) + for i := range ips { + ipmut.Load()(ips[i]) + assert.True(t, ips[i].Equal(testIPs[i])) + } + }) + + t.Run("not_nil_mut", func(t *testing.T) { + ipmut := NewIPMut(func(ip net.IP) { + for i := range ip { + ip[i] = 0 + } + }) + want := netutil.IPv4Zero() + + ips := netutil.CloneIPs(testIPs) + for i := range ips { + ipmut.Load()(ips[i]) + assert.True(t, ips[i].Equal(want)) + } + }) +} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 77bdcc63..ecb70fa8 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -42,8 +42,7 @@ func GatewayIP(ifaceName string) net.IP { fields := strings.Fields(string(d)) // The meaningful "ip route" command output should contain the word - // "default" at first field and default gateway IP address at third - // field. + // "default" at first field and default gateway IP address at third field. if len(fields) < 3 || fields[0] != "default" { return nil } @@ -218,28 +217,6 @@ func IsAddrInUse(err error) (ok bool) { return isAddrInUse(sysErr) } -// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport -// does not necessarily contain a port. -func SplitHost(hostport string) (host string, err error) { - host, _, err = net.SplitHostPort(hostport) - if err != nil { - // Check for the missing port error. If it is that error, just - // use the host as is. - // - // See the source code for net.SplitHostPort. - const missingPort = "missing port in address" - - addrErr := &net.AddrError{} - if !errors.As(err, &addrErr) || addrErr.Err != missingPort { - return "", err - } - - host = hostport - } - - return host, nil -} - // CollectAllIfacesAddrs returns the slice of all network interfaces IP // addresses without port number. func CollectAllIfacesAddrs() (addrs []string, err error) { diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 2e3f54be..b5bf2297 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -15,12 +15,20 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -func TestGetValidNetInterfacesForWeb(t *testing.T) { +func TestGetInterfaceByIP(t *testing.T) { ifaces, err := GetValidNetInterfacesForWeb() - require.NoErrorf(t, err, "cannot get net interfaces: %s", err) - require.NotEmpty(t, ifaces, "no net interfaces found") + require.NoError(t, err) + require.NotEmpty(t, ifaces) + for _, iface := range ifaces { - require.NotEmptyf(t, iface.Addresses, "no addresses found for %s", iface.Name) + t.Run(iface.Name, func(t *testing.T) { + require.NotEmpty(t, iface.Addresses) + + for _, ip := range iface.Addresses { + ifaceName := GetInterfaceByIP(ip) + require.Equal(t, iface.Name, ifaceName) + } + }) } } @@ -73,18 +81,47 @@ func TestBroadcastFromIPNet(t *testing.T) { } func TestCheckPort(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, l.Close) + t.Run("tcp_bound", func(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, l.Close) - ipp := netutil.IPPortFromAddr(l.Addr()) - require.NotNil(t, ipp) - require.NotNil(t, ipp.IP) - require.NotZero(t, ipp.Port) + ipp := netutil.IPPortFromAddr(l.Addr()) + require.NotNil(t, ipp) + require.NotNil(t, ipp.IP) + require.NotZero(t, ipp.Port) - err = CheckPort("tcp", ipp.IP, ipp.Port) - target := &net.OpError{} - require.ErrorAs(t, err, &target) + err = CheckPort("tcp", ipp.IP, ipp.Port) + target := &net.OpError{} + require.ErrorAs(t, err, &target) - assert.Equal(t, "listen", target.Op) + assert.Equal(t, "listen", target.Op) + }) + + t.Run("udp_bound", func(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, conn.Close) + + ipp := netutil.IPPortFromAddr(conn.LocalAddr()) + require.NotNil(t, ipp) + require.NotNil(t, ipp.IP) + require.NotZero(t, ipp.Port) + + err = CheckPort("udp", ipp.IP, ipp.Port) + target := &net.OpError{} + require.ErrorAs(t, err, &target) + + assert.Equal(t, "listen", target.Op) + }) + + t.Run("bad_network", func(t *testing.T) { + err := CheckPort("bad_network", nil, 0) + assert.NoError(t, err) + }) + + t.Run("can_bind", func(t *testing.T) { + err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0) + assert.NoError(t, err) + }) }