mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-22 04:55:33 +03:00
Pull request: all: imp cyclo in new code
Updates #2646, Squashed commit of the following: commit af6a6fa2b7229bc0f1c7c9083b0391a6bec7ae70 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon May 31 20:00:36 2021 +0300 all: imp code, docs commit 1cd4781b13e635a9e1bccb758104c1b76c78d34e Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon May 31 18:51:23 2021 +0300 all: imp cyclo in new code
This commit is contained in:
parent
c95acf73ab
commit
e17e1f20fb
10 changed files with 211 additions and 190 deletions
|
@ -2,7 +2,6 @@ package aghnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -231,14 +230,41 @@ func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHo
|
||||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read IP-hostname pairs from file
|
// parseHostsLine parses hosts from the fields.
|
||||||
// Multiple hostnames per line (per one IP) is supported.
|
func parseHostsLine(fields []string) (hosts []string) {
|
||||||
func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[string][]string, fn string) {
|
for _, f := range fields {
|
||||||
|
hashIdx := strings.IndexByte(f, '#')
|
||||||
|
if hashIdx == 0 {
|
||||||
|
// The rest of the fields are a part of the comment.
|
||||||
|
// Skip immediately.
|
||||||
|
return
|
||||||
|
} else if hashIdx > 0 {
|
||||||
|
// Only a part of the field is a comment.
|
||||||
|
hosts = append(hosts, f[:hashIdx])
|
||||||
|
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts = append(hosts, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
|
||||||
|
// line for one IP are supported.
|
||||||
|
func (ehc *EtcHostsContainer) load(
|
||||||
|
table map[string][]net.IP,
|
||||||
|
tableRev map[string][]string,
|
||||||
|
fn string,
|
||||||
|
) {
|
||||||
f, err := os.Open(fn)
|
f, err := os.Open(fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("etchostscontainer: %s", err)
|
log.Error("etchostscontainer: %s", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
derr := f.Close()
|
derr := f.Close()
|
||||||
if derr != nil {
|
if derr != nil {
|
||||||
|
@ -246,25 +272,11 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
r := bufio.NewReader(f)
|
|
||||||
log.Debug("etchostscontainer: loading hosts from file %s", fn)
|
log.Debug("etchostscontainer: loading hosts from file %s", fn)
|
||||||
|
|
||||||
for done := false; !done; {
|
s := bufio.NewScanner(f)
|
||||||
var line string
|
for s.Scan() {
|
||||||
line, err = r.ReadString('\n')
|
line := strings.TrimSpace(s.Text())
|
||||||
if err == io.EOF {
|
|
||||||
done = true
|
|
||||||
} else if err != nil {
|
|
||||||
log.Error("etchostscontainer: %s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if len(line) == 0 || line[0] == '#' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
if len(fields) < 2 {
|
if len(fields) < 2 {
|
||||||
continue
|
continue
|
||||||
|
@ -275,28 +287,17 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 1; i != len(fields); i++ {
|
hosts := parseHostsLine(fields[1:])
|
||||||
host := fields[i]
|
for _, host := range hosts {
|
||||||
if len(host) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
sharp := strings.IndexByte(host, '#')
|
|
||||||
if sharp == 0 {
|
|
||||||
// Skip the comments.
|
|
||||||
break
|
|
||||||
} else if sharp > 0 {
|
|
||||||
host = host[:sharp]
|
|
||||||
}
|
|
||||||
|
|
||||||
ehc.updateTable(table, host, ip)
|
ehc.updateTable(table, host, ip)
|
||||||
ehc.updateTableRev(tableRev, host, ip)
|
ehc.updateTableRev(tableRev, host, ip)
|
||||||
if sharp >= 0 {
|
|
||||||
// Skip the comments again.
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = s.Err()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("etchostscontainer: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
|
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
|
||||||
|
|
|
@ -23,10 +23,11 @@ func prepareTestFile(t *testing.T) (f *os.File) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
f, err := os.CreateTemp(dir, "")
|
f, err := os.CreateTemp(dir, "")
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, f)
|
require.NotNil(t, f)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.Nil(t, f.Close())
|
assert.NoError(t, f.Close())
|
||||||
})
|
})
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
@ -37,7 +38,7 @@ func assertWriting(t *testing.T, f *os.File, strs ...string) {
|
||||||
|
|
||||||
for _, str := range strs {
|
for _, str := range strs {
|
||||||
n, err := f.WriteString(str)
|
n, err := f.WriteString(str)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, n, len(str))
|
assert.Equal(t, n, len(str))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,16 +78,16 @@ func TestEtcHostsContainerResolution(t *testing.T) {
|
||||||
t.Run("ptr", func(t *testing.T) {
|
t.Run("ptr", func(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
wantIP string
|
wantIP string
|
||||||
wantLen int
|
|
||||||
wantHost string
|
wantHost string
|
||||||
|
wantLen int
|
||||||
}{
|
}{
|
||||||
{wantIP: "127.0.0.1", wantLen: 2, wantHost: "host"},
|
{wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
|
||||||
{wantIP: "::1", wantLen: 1, wantHost: "localhost"},
|
{wantIP: "::1", wantHost: "localhost", wantLen: 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
a, err := dns.ReverseAddr(tc.wantIP)
|
a, err := dns.ReverseAddr(tc.wantIP)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
a = strings.TrimSuffix(a, ".")
|
a = strings.TrimSuffix(a, ".")
|
||||||
hosts := ehc.ProcessReverse(a, dns.TypePTR)
|
hosts := ehc.ProcessReverse(a, dns.TypePTR)
|
||||||
|
@ -114,7 +115,7 @@ func TestEtcHostsContainerFSNotify(t *testing.T) {
|
||||||
t.Cleanup(ehc.Close)
|
t.Cleanup(ehc.Close)
|
||||||
|
|
||||||
assertWriting(t, f, "127.0.0.2 newhost\n")
|
assertWriting(t, f, "127.0.0.2 newhost\n")
|
||||||
require.Nil(t, f.Sync())
|
require.NoError(t, f.Sync())
|
||||||
|
|
||||||
// Wait until fsnotify has triggerred and processed the
|
// Wait until fsnotify has triggerred and processed the
|
||||||
// file-modification event.
|
// file-modification event.
|
|
@ -68,40 +68,41 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||||
return false, ErrNoStaticIPInfo
|
return false, ErrNoStaticIPInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findIfaceLine scans s until it finds the line that declares an interface with
|
||||||
|
// the given name. If findIfaceLine can't find the line, it returns false.
|
||||||
|
func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
|
||||||
|
for s.Scan() {
|
||||||
|
line := strings.TrimSpace(s.Text())
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) == 2 && fields[0] == "interface" && fields[1] == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
|
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
|
||||||
// have a static IP.
|
// have a static IP.
|
||||||
func dhcpcdStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
|
func dhcpcdStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
|
||||||
s := bufio.NewScanner(r)
|
s := bufio.NewScanner(r)
|
||||||
var withinInterfaceCtx bool
|
ifaceFound := findIfaceLine(s, ifaceName)
|
||||||
|
if !ifaceFound {
|
||||||
|
return false, s.Err()
|
||||||
|
}
|
||||||
|
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
line := strings.TrimSpace(s.Text())
|
line := strings.TrimSpace(s.Text())
|
||||||
|
|
||||||
if withinInterfaceCtx && len(line) == 0 {
|
|
||||||
// An empty line resets our state.
|
|
||||||
withinInterfaceCtx = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(line) == 0 || line[0] == '#' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) >= 2 &&
|
||||||
if withinInterfaceCtx {
|
fields[0] == "static" &&
|
||||||
if len(fields) >= 2 && fields[0] == "static" && strings.HasPrefix(fields[1], "ip_address=") {
|
strings.HasPrefix(fields[1], "ip_address=") {
|
||||||
return true, nil
|
return true, s.Err()
|
||||||
}
|
|
||||||
if len(fields) > 0 && fields[0] == "interface" {
|
|
||||||
// Another interface found.
|
|
||||||
withinInterfaceCtx = false
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fields) == 2 && fields[0] == "interface" && fields[1] == ifaceName {
|
if len(fields) > 0 && fields[0] == "interface" {
|
||||||
// The interface found.
|
// Another interface found.
|
||||||
withinInterfaceCtx = true
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ package aghnet
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,18 +24,6 @@ type SystemResolvers interface {
|
||||||
refresh() (err error)
|
refresh() (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
|
|
||||||
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
|
|
||||||
|
|
||||||
// errFakeDial is an error which dialFunc is expected to return.
|
|
||||||
errFakeDial errors.Error = "this error signals the successful dialFunc work"
|
|
||||||
|
|
||||||
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
|
|
||||||
// more than one percent sign.
|
|
||||||
errUnexpectedHostFormat errors.Error = "unexpected host format"
|
|
||||||
)
|
|
||||||
|
|
||||||
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
|
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
|
||||||
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
|
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
|
||||||
defer log.OnPanic("systemResolvers")
|
defer log.OnPanic("systemResolvers")
|
||||||
|
|
|
@ -32,6 +32,18 @@ type systemResolvers struct {
|
||||||
addrsLock sync.RWMutex
|
addrsLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
|
||||||
|
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
|
||||||
|
|
||||||
|
// errFakeDial is an error which dialFunc is expected to return.
|
||||||
|
errFakeDial errors.Error = "this error signals the successful dialFunc work"
|
||||||
|
|
||||||
|
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
|
||||||
|
// more than one percent sign.
|
||||||
|
errUnexpectedHostFormat errors.Error = "unexpected host format"
|
||||||
|
)
|
||||||
|
|
||||||
func (sr *systemResolvers) refresh() (err error) {
|
func (sr *systemResolvers) refresh() (err error) {
|
||||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,57 @@ func (sr *systemResolvers) Get() (rs []string) {
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeExit writes "exit" to w and closes it. It is supposed to be run in
|
||||||
|
// a goroutine.
|
||||||
|
func writeExit(w io.WriteCloser) {
|
||||||
|
defer log.OnPanic("systemResolvers: writeExit")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
derr := w.Close()
|
||||||
|
if derr != nil {
|
||||||
|
log.Error("systemResolvers: writeExit: closing: %s", derr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := io.WriteString(w, "exit")
|
||||||
|
if err != nil {
|
||||||
|
log.Error("systemResolvers: writeExit: writing: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanAddrs scans the DNS addresses from nslookup's output. The expected
|
||||||
|
// output of nslookup looks like this:
|
||||||
|
//
|
||||||
|
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||||
|
// Address: 192.168.1.1
|
||||||
|
//
|
||||||
|
func scanAddrs(s *bufio.Scanner) (addrs []string) {
|
||||||
|
for s.Scan() {
|
||||||
|
line := strings.TrimSpace(s.Text())
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) != 2 || fields[0] != "Address:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the address contains port then it is separated with '#'.
|
||||||
|
ipPort := strings.Split(fields[1], "#")
|
||||||
|
if len(ipPort) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := ipPort[0]
|
||||||
|
if net.ParseIP(addr) == nil {
|
||||||
|
log.Debug("systemResolvers: %q is not a valid ip", addr)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs = append(addrs, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addrs
|
||||||
|
}
|
||||||
|
|
||||||
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
|
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): This whole function needs more detailed research on getting
|
// TODO(e.burkov): This whole function needs more detailed research on getting
|
||||||
|
@ -71,73 +122,30 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||||
return nil, fmt.Errorf("limiting stdout reader: %w", err)
|
return nil, fmt.Errorf("limiting stdout reader: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go writeExit(stdin)
|
||||||
defer log.OnPanic("systemResolvers")
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
derr := stdin.Close()
|
|
||||||
if derr != nil {
|
|
||||||
log.Error("systemResolvers: closing stdin pipe: %s", derr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, werr := io.WriteString(stdin, "exit")
|
|
||||||
if werr != nil {
|
|
||||||
log.Error("systemResolvers: writing to command pipe: %s", werr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = cmd.Start()
|
err = cmd.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("start command executing: %w", err)
|
return nil, fmt.Errorf("start command executing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The output of nslookup looks like this:
|
|
||||||
//
|
|
||||||
// Default Server: 192-168-1-1.qualified.domain.ru
|
|
||||||
// Address: 192.168.1.1
|
|
||||||
|
|
||||||
var possibleIPs []string
|
|
||||||
s := bufio.NewScanner(stdoutLimited)
|
s := bufio.NewScanner(stdoutLimited)
|
||||||
for s.Scan() {
|
addrs = scanAddrs(s)
|
||||||
line := s.Text()
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) != 2 || fields[0] != "Address:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the address contains port then it is separated with '#'.
|
|
||||||
ipStrs := strings.Split(fields[1], "#")
|
|
||||||
if len(ipStrs) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
possibleIPs = append(possibleIPs, ipStrs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
err = cmd.Wait()
|
err = cmd.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("executing the command: %w", err)
|
return nil, fmt.Errorf("executing the command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = s.Err()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
|
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
|
||||||
//
|
//
|
||||||
// See go doc os/exec.Cmd.StdoutPipe.
|
// See go doc os/exec.Cmd.StdoutPipe.
|
||||||
|
|
||||||
for _, addr := range possibleIPs {
|
|
||||||
if net.ParseIP(addr) == nil {
|
|
||||||
log.Debug("systemResolvers: %q is not a valid ip", addr)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs = append(addrs, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,6 @@ import (
|
||||||
|
|
||||||
// TestUpstream is a mock of real upstream.
|
// TestUpstream is a mock of real upstream.
|
||||||
type TestUpstream struct {
|
type TestUpstream struct {
|
||||||
// Addr is the address for Address method.
|
|
||||||
Addr string
|
|
||||||
// CName is a map of hostname to canonical name.
|
// CName is a map of hostname to canonical name.
|
||||||
CName map[string]string
|
CName map[string]string
|
||||||
// IPv4 is a map of hostname to IPv4.
|
// IPv4 is a map of hostname to IPv4.
|
||||||
|
@ -23,9 +21,13 @@ type TestUpstream struct {
|
||||||
IPv6 map[string][]net.IP
|
IPv6 map[string][]net.IP
|
||||||
// Reverse is a map of address to domain name.
|
// Reverse is a map of address to domain name.
|
||||||
Reverse map[string][]string
|
Reverse map[string][]string
|
||||||
|
// Addr is the address for Address method.
|
||||||
|
Addr string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange implements upstream.Upstream interface for *TestUpstream.
|
// Exchange implements upstream.Upstream interface for *TestUpstream.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Split further into handlers.
|
||||||
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||||
resp = &dns.Msg{}
|
resp = &dns.Msg{}
|
||||||
resp.SetReply(m)
|
resp.SetReply(m)
|
||||||
|
@ -33,70 +35,69 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||||
if len(m.Question) == 0 {
|
if len(m.Question) == 0 {
|
||||||
return nil, fmt.Errorf("question should not be empty")
|
return nil, fmt.Errorf("question should not be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
name := m.Question[0].Name
|
name := m.Question[0].Name
|
||||||
|
|
||||||
if cname, ok := u.CName[name]; ok {
|
if cname, ok := u.CName[name]; ok {
|
||||||
resp.Answer = append(resp.Answer, &dns.CNAME{
|
ans := &dns.CNAME{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: name,
|
Name: name,
|
||||||
Rrtype: dns.TypeCNAME,
|
Rrtype: dns.TypeCNAME,
|
||||||
},
|
},
|
||||||
Target: cname,
|
Target: cname,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
resp.Answer = append(resp.Answer, ans)
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasRec bool
|
rrType := m.Question[0].Qtype
|
||||||
var rrType uint16
|
hdr := dns.RR_Header{
|
||||||
|
Name: name,
|
||||||
|
Rrtype: rrType,
|
||||||
|
}
|
||||||
|
|
||||||
|
var names []string
|
||||||
var ips []net.IP
|
var ips []net.IP
|
||||||
switch m.Question[0].Qtype {
|
switch m.Question[0].Qtype {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
rrType = dns.TypeA
|
ips = u.IPv4[name]
|
||||||
if ipv4addr, ok := u.IPv4[name]; ok {
|
|
||||||
hasRec = true
|
|
||||||
ips = ipv4addr
|
|
||||||
}
|
|
||||||
case dns.TypeAAAA:
|
case dns.TypeAAAA:
|
||||||
rrType = dns.TypeAAAA
|
ips = u.IPv6[name]
|
||||||
if ipv6addr, ok := u.IPv6[name]; ok {
|
|
||||||
hasRec = true
|
|
||||||
ips = ipv6addr
|
|
||||||
}
|
|
||||||
case dns.TypePTR:
|
case dns.TypePTR:
|
||||||
names, ok := u.Reverse[name]
|
names = u.Reverse[name]
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, n := range names {
|
|
||||||
resp.Answer = append(resp.Answer, &dns.PTR{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: n,
|
|
||||||
Rrtype: rrType,
|
|
||||||
},
|
|
||||||
Ptr: n,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
resp.Answer = append(resp.Answer, &dns.A{
|
var ans dns.RR
|
||||||
Hdr: dns.RR_Header{
|
if rrType == dns.TypeA {
|
||||||
Name: name,
|
ans = &dns.A{
|
||||||
Rrtype: rrType,
|
Hdr: hdr,
|
||||||
},
|
A: ip,
|
||||||
A: ip,
|
}
|
||||||
})
|
|
||||||
|
resp.Answer = append(resp.Answer, ans)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ans = &dns.AAAA{
|
||||||
|
Hdr: hdr,
|
||||||
|
AAAA: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Answer = append(resp.Answer, ans)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, n := range names {
|
||||||
|
ans := &dns.PTR{
|
||||||
|
Hdr: hdr,
|
||||||
|
Ptr: n,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Answer = append(resp.Answer, ans)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(resp.Answer) == 0 {
|
if len(resp.Answer) == 0 {
|
||||||
if hasRec {
|
|
||||||
// Set no error RCode if there are some records for
|
|
||||||
// given Qname but we didn't apply them.
|
|
||||||
resp.SetRcode(m, dns.RcodeSuccess)
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
// Set NXDomain RCode otherwise.
|
|
||||||
resp.SetRcode(m, dns.RcodeNameError)
|
resp.SetRcode(m, dns.RcodeNameError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,10 +112,13 @@ func (u *TestUpstream) Address() string {
|
||||||
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
||||||
// upstream in tests.
|
// upstream in tests.
|
||||||
type TestBlockUpstream struct {
|
type TestBlockUpstream struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
Block bool
|
|
||||||
requestsCount int
|
// lock protects reqNum.
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
reqNum int
|
||||||
|
|
||||||
|
Block bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
||||||
|
@ -122,7 +126,7 @@ type TestBlockUpstream struct {
|
||||||
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
u.requestsCount++
|
u.reqNum++
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(u.Hostname))
|
hash := sha256.Sum256([]byte(u.Hostname))
|
||||||
hashToReturn := hex.EncodeToString(hash[:])
|
hashToReturn := hex.EncodeToString(hash[:])
|
||||||
|
@ -156,7 +160,7 @@ func (u *TestBlockUpstream) RequestsCount() int {
|
||||||
u.lock.Lock()
|
u.lock.Lock()
|
||||||
defer u.lock.Unlock()
|
defer u.lock.Unlock()
|
||||||
|
|
||||||
return u.requestsCount
|
return u.reqNum
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
||||||
|
|
|
@ -326,7 +326,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||||
require.Len(t, pctx.Res.Answer, tc.wantLen)
|
require.Len(t, pctx.Res.Answer, tc.wantLen)
|
||||||
|
|
||||||
if tc.wantLen > 0 {
|
if tc.wantLen > 0 {
|
||||||
assert.Equal(t, tc.want, pctx.Res.Answer[0].Header().Name)
|
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -368,7 +368,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||||
require.Equal(t, resultCodeSuccess, rc)
|
require.Equal(t, resultCodeSuccess, rc)
|
||||||
require.NotEmpty(t, proxyCtx.Res.Answer)
|
require.NotEmpty(t, proxyCtx.Res.Answer)
|
||||||
|
|
||||||
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].Header().Name)
|
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("disabled", func(t *testing.T) {
|
t.Run("disabled", func(t *testing.T) {
|
||||||
|
|
|
@ -284,7 +284,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||||
StartTime: time.Now(),
|
StartTime: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var resolver *proxy.Proxy = s.internalProxy
|
resolver := s.internalProxy
|
||||||
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||||
if !s.conf.UsePrivateRDNS {
|
if !s.conf.UsePrivateRDNS {
|
||||||
return "", nil
|
return "", nil
|
||||||
|
|
|
@ -175,8 +175,15 @@ golint --set_exit_status ./...
|
||||||
|
|
||||||
"$GO" vet ./...
|
"$GO" vet ./...
|
||||||
|
|
||||||
# Here and below, don't use quotes to get word splitting.
|
# Apply more lax standards to the code we haven't properly refactored yet.
|
||||||
gocyclo --over 17 $go_files
|
gocyclo --over 17 ./internal/dhcpd/ ./internal/dnsforward/\
|
||||||
|
./internal/filtering/ ./internal/home/ ./internal/querylog/\
|
||||||
|
./internal/stats/ ./internal/updater/
|
||||||
|
|
||||||
|
# Apply stricter standards to new or vetted code
|
||||||
|
gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\
|
||||||
|
./internal/aghstrings/ ./internal/aghtest/ ./internal/tools/\
|
||||||
|
./internal/version/ ./main.go
|
||||||
|
|
||||||
gosec --quiet $go_files
|
gosec --quiet $go_files
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue