mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-24 14:05:45 +03:00
Pull request: 2704 local resolvers vol.1
Merge in DNS/adguard-home from 2704-local-addresses-vol.1 to master Updates #2704. Updates #2829. Updates #2846. Squashed commit of the following: commit 9a49b3d27edcb30da7f16a065226907833b1dc81 Author: Eugene Burkov <e.burkov@adguard.com> Date: Mon Mar 22 15:39:17 2021 +0300 aghnet: imp docs and logging commit 74f95a29c55b9e732276601b0ecc63fb7c3a9f9e Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 20:56:51 2021 +0300 all: fix friday evening mistakes commit 0e2066bc5c16ed807fa601780b99e154502361a9 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 20:51:15 2021 +0300 all: upd testify, imp code quality commit 8237c50b670c58361ccf7adec3ff2452b1196677 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 20:19:29 2021 +0300 aghnet: imp test naming commit 14eb1e189339554c0a6d38e2ba7a93917774ebab Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 19:41:43 2021 +0300 aghnet: isolate windows-specific functionality commit d461ac8b18c187999da3e3aba116571b7ebe6785 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 14:50:05 2021 +0300 aghnet: imp code quality commit d0ee01cb1f8613de2085c0f2f2f396e46beb52a5 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 19 11:59:10 2021 +0300 all: mv funcs to agherr, mk system resolvers getter
This commit is contained in:
parent
eb9526cc92
commit
3b2f5d7842
19 changed files with 568 additions and 48 deletions
2
go.mod
2
go.mod
|
@ -29,7 +29,7 @@ require (
|
|||
github.com/satori/go.uuid v1.2.0
|
||||
github.com/sirupsen/logrus v1.8.1 // indirect
|
||||
github.com/spf13/cobra v1.1.3 // indirect
|
||||
github.com/stretchr/testify v1.6.1
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/ti-mo/netfilter v0.4.0
|
||||
github.com/u-root/u-root v7.0.0+incompatible
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
|
|
2
go.sum
2
go.sum
|
@ -413,6 +413,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
|
|||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
||||
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
||||
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
||||
|
|
|
@ -4,6 +4,8 @@ package agherr
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Error is the constant error type.
|
||||
|
@ -95,6 +97,8 @@ type wrapper interface {
|
|||
// }
|
||||
//
|
||||
// msg must contain the final ": %w" verb.
|
||||
//
|
||||
// TODO(a.garipov): Clearify the function usage.
|
||||
func Annotate(msg string, errPtr *error, args ...interface{}) {
|
||||
if errPtr == nil {
|
||||
return
|
||||
|
@ -107,3 +111,17 @@ func Annotate(msg string, errPtr *error, args ...interface{}) {
|
|||
*errPtr = fmt.Errorf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogPanic is a convinient helper function to log a panic in a goroutine. It
|
||||
// should not be used where proper error handling is required.
|
||||
func LogPanic(prefix string) {
|
||||
if v := recover(); v != nil {
|
||||
if prefix != "" {
|
||||
log.Error("%s: recovered from panic: %v", prefix, v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("recovered from panic: %v", v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package agherr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -120,3 +122,39 @@ func TestAnnotate(t *testing.T) {
|
|||
assert.Equal(t, wantMsg, err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogPanic(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, buf)
|
||||
|
||||
t.Run("prefix", func(t *testing.T) {
|
||||
const (
|
||||
panicMsg = "spooky!"
|
||||
prefix = "packagename"
|
||||
errWithNoPrefix = "[error] recovered from panic: spooky!"
|
||||
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
|
||||
)
|
||||
|
||||
panicFunc := func(prefix string) {
|
||||
defer LogPanic(prefix)
|
||||
|
||||
panic(panicMsg)
|
||||
}
|
||||
|
||||
panicFunc("")
|
||||
assert.Contains(t, buf.String(), errWithNoPrefix)
|
||||
buf.Reset()
|
||||
|
||||
panicFunc(prefix)
|
||||
assert.Contains(t, buf.String(), errWithPrefix)
|
||||
buf.Reset()
|
||||
})
|
||||
|
||||
t.Run("don't_panic", func(t *testing.T) {
|
||||
require.NotPanics(t, func() {
|
||||
defer LogPanic("")
|
||||
})
|
||||
|
||||
assert.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -251,3 +251,25 @@ func ErrorIsAddrInUse(err error) bool {
|
|||
|
||||
return errErrno == syscall.EADDRINUSE
|
||||
}
|
||||
|
||||
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
|
||||
// does not necessarily contain a port.
|
||||
func SplitHost(hostport string) (host string, err error) {
|
||||
host, _, err = net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
// Check for the missing port error. If it is that error, just
|
||||
// use the host as is.
|
||||
//
|
||||
// See the source code for net.SplitHostPort.
|
||||
const missingPort = "missing port in address"
|
||||
|
||||
addrErr := &net.AddrError{}
|
||||
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host = hostport
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
// hardwarePortInfo - information obtained using MacOS networksetup
|
||||
|
@ -47,7 +47,7 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
|||
// it returns a map where the key is the interface name, and the value is the "hardware port"
|
||||
// returns nil if it fails to parse the output
|
||||
func getNetworkSetupHardwareReports() map[string]string {
|
||||
_, out, err := util.RunCommand("networksetup", "-listallhardwareports")
|
||||
_, out, err := aghos.RunCommand("networksetup", "-listallhardwareports")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func getNetworkSetupHardwareReports() map[string]string {
|
|||
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
||||
h := hardwarePortInfo{}
|
||||
|
||||
_, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort)
|
||||
_, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort)
|
||||
if err != nil {
|
||||
return h, err
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
|||
args = append(args, dnsAddrs...)
|
||||
|
||||
// Setting DNS servers is necessary when configuring a static IP
|
||||
code, _, err := util.RunCommand("networksetup", args...)
|
||||
code, _, err := aghos.RunCommand("networksetup", args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
|||
}
|
||||
|
||||
// Actually configures hardware port to have static IP
|
||||
code, _, err = util.RunCommand("networksetup", "-setmanual",
|
||||
code, _, err = aghos.RunCommand("networksetup", "-setmanual",
|
||||
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
78
internal/aghnet/systemresolvers.go
Normal file
78
internal/aghnet/systemresolvers.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package aghnet
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// DefaultRefreshIvl is the default period of time between refreshing cached
|
||||
// addresses.
|
||||
// const DefaultRefreshIvl = 5 * time.Minute
|
||||
|
||||
// HostGenFunc is the signature for functions generating fake hostnames. The
|
||||
// implementation must be safe for concurrent use.
|
||||
type HostGenFunc func() (host string)
|
||||
|
||||
// unit is an alias for an existing map value.
|
||||
type unit = struct{}
|
||||
|
||||
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
|
||||
type SystemResolvers interface {
|
||||
// Get returns the slice of local resolvers' addresses.
|
||||
// It should be safe for concurrent use.
|
||||
Get() (rs []string)
|
||||
// Refresh refreshes the local resolvers' addresses cache. It should be
|
||||
// safe for concurrent use.
|
||||
Refresh() (err error)
|
||||
}
|
||||
|
||||
const (
|
||||
// fakeDialErr is an error which dialFunc is expected to return.
|
||||
fakeDialErr agherr.Error = "this error signals the successful dialFunc work"
|
||||
|
||||
// badAddrPassedErr is returned when dialFunc can't parse an IP address.
|
||||
badAddrPassedErr agherr.Error = "the passed string is not a valid IP address"
|
||||
)
|
||||
|
||||
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
|
||||
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
|
||||
defer agherr.LogPanic("systemResolvers")
|
||||
|
||||
// TODO(e.burkov): Implement a functionality to stop ticker.
|
||||
for range tickCh {
|
||||
err := sr.Refresh()
|
||||
if err != nil {
|
||||
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("systemResolvers: local addresses cache is refreshed")
|
||||
}
|
||||
}
|
||||
|
||||
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
|
||||
// nil is passed for hostGenFunc, the default generator will be used.
|
||||
func NewSystemResolvers(
|
||||
refreshIvl time.Duration,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers, err error) {
|
||||
sr = newSystemResolvers(refreshIvl, hostGenFunc)
|
||||
|
||||
// Fill cache.
|
||||
err = sr.Refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if refreshIvl > 0 {
|
||||
ticker := time.NewTicker(refreshIvl)
|
||||
|
||||
go refreshWithTicker(sr, ticker.C)
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
96
internal/aghnet/systemresolvers_others.go
Normal file
96
internal/aghnet/systemresolvers_others.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
)
|
||||
|
||||
// defaultHostGen is the default method of generating host for Refresh.
|
||||
func defaultHostGen() (host string) {
|
||||
// TODO(e.burkov): Use strings.Builder.
|
||||
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// systemResolvers is a default implementation of SystemResolvers interface.
|
||||
type systemResolvers struct {
|
||||
resolver *net.Resolver
|
||||
hostGenFunc HostGenFunc
|
||||
|
||||
// addrs is the map that contains cached local resolvers' addresses.
|
||||
addrs map[string]unit
|
||||
addrsLock sync.RWMutex
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Refresh() (err error) {
|
||||
defer agherr.Annotate("systemResolvers: %w", &err)
|
||||
|
||||
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
|
||||
dnserr := &net.DNSError{}
|
||||
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) {
|
||||
if hostGenFunc == nil {
|
||||
hostGenFunc = defaultHostGen
|
||||
}
|
||||
s := &systemResolvers{
|
||||
resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
hostGenFunc: hostGenFunc,
|
||||
addrs: make(map[string]unit),
|
||||
}
|
||||
s.resolver.Dial = s.dialFunc
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// dialFunc gets the resolver's address and puts it into internal cache.
|
||||
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
|
||||
// Just validate the passed address is a valid IP.
|
||||
var host string
|
||||
host, err = SplitHost(address)
|
||||
if err != nil {
|
||||
// TODO(e.burkov): Maybe use a structured badAddrPassedErr to
|
||||
// allow unwrapping of the real error.
|
||||
return nil, fmt.Errorf("%s: %w", err, badAddrPassedErr)
|
||||
}
|
||||
|
||||
if net.ParseIP(host) == nil {
|
||||
return nil, fmt.Errorf("parsing %q: %w", host, badAddrPassedErr)
|
||||
}
|
||||
|
||||
sr.addrsLock.Lock()
|
||||
defer sr.addrsLock.Unlock()
|
||||
|
||||
sr.addrs[address] = unit{}
|
||||
|
||||
return nil, fakeDialErr
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Get() (rs []string) {
|
||||
sr.addrsLock.RLock()
|
||||
defer sr.addrsLock.RUnlock()
|
||||
|
||||
addrs := sr.addrs
|
||||
rs = make([]string, len(addrs))
|
||||
var i int
|
||||
for addr := range addrs {
|
||||
rs[i] = addr
|
||||
i++
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
74
internal/aghnet/systemresolvers_others_test.go
Normal file
74
internal/aghnet/systemresolvers_others_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolversImp(
|
||||
t *testing.T,
|
||||
refreshDur time.Duration,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (imp *systemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)
|
||||
|
||||
var ok bool
|
||||
imp, ok = sr.(*systemResolvers)
|
||||
require.True(t, ok)
|
||||
|
||||
return imp
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Refresh(t *testing.T) {
|
||||
t.Run("expected_error", func(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, 0, nil)
|
||||
|
||||
assert.NoError(t, sr.Refresh())
|
||||
})
|
||||
|
||||
t.Run("unexpected_error", func(t *testing.T) {
|
||||
_, err := NewSystemResolvers(0, func() string {
|
||||
return "127.0.0.1::123"
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||
imp := createTestSystemResolversImp(t, 0, nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
address string
|
||||
want error
|
||||
}{{
|
||||
name: "valid",
|
||||
address: "127.0.0.1",
|
||||
want: fakeDialErr,
|
||||
}, {
|
||||
name: "invalid_split_host",
|
||||
address: "127.0.0.1::123",
|
||||
want: badAddrPassedErr,
|
||||
}, {
|
||||
name: "invalid_parse_ip",
|
||||
address: "not-ip",
|
||||
want: badAddrPassedErr,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn, err := imp.dialFunc(context.Background(), "", tc.address)
|
||||
|
||||
require.Nil(t, conn)
|
||||
assert.ErrorIs(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
33
internal/aghnet/systemresolvers_test.go
Normal file
33
internal/aghnet/systemresolvers_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package aghnet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolvers(
|
||||
t *testing.T,
|
||||
refreshDur time.Duration,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
sr, err = NewSystemResolvers(refreshDur, hostGenFunc)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sr)
|
||||
|
||||
return sr
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Get(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, 0, nil)
|
||||
assert.NotEmpty(t, sr.Get())
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Write tests for refreshWithTicker.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
158
internal/aghnet/systemresolvers_windows.go
Normal file
158
internal/aghnet/systemresolvers_windows.go
Normal file
|
@ -0,0 +1,158 @@
|
|||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// systemResolvers implementation differs for Windows since Go's resolver
|
||||
// doesn't work there.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/33097.
|
||||
type systemResolvers struct {
|
||||
// addrs is the slice of cached local resolvers' addresses.
|
||||
addrs []string
|
||||
addrsLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) {
|
||||
return &systemResolvers{}
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Get() (rs []string) {
|
||||
sr.addrsLock.RLock()
|
||||
defer sr.addrsLock.RUnlock()
|
||||
|
||||
addrs := sr.addrs
|
||||
rs = make([]string, len(addrs))
|
||||
copy(rs, addrs)
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
|
||||
//
|
||||
// TODO(e.burkov): This whole function needs more detailed research on getting
|
||||
// local resolvers addresses on Windows. We execute the external command for
|
||||
// now that is not the most accurate way.
|
||||
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||
cmd := exec.Command("nslookup")
|
||||
|
||||
var stdin io.WriteCloser
|
||||
stdin, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
var stdout io.ReadCloser
|
||||
stdout, err = cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
var stdoutLimited io.ReadCloser
|
||||
stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("limiting stdout reader: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer agherr.LogPanic("systemResolvers")
|
||||
defer func() {
|
||||
derr := stdin.Close()
|
||||
if derr != nil {
|
||||
log.Error("systemResolvers: closing stdin pipe: %s", derr)
|
||||
}
|
||||
}()
|
||||
|
||||
_, werr := io.WriteString(stdin, "exit")
|
||||
if werr != nil {
|
||||
log.Error("systemResolvers: writing to command pipe: %s", werr)
|
||||
}
|
||||
}()
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start command executing: %w", err)
|
||||
}
|
||||
|
||||
// The output of nslookup looks like this:
|
||||
//
|
||||
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||
// Address: 192.168.1.1
|
||||
|
||||
var possibleIPs []string
|
||||
s := bufio.NewScanner(stdoutLimited)
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 || fields[0] != "Address:" {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the address contains port then it is separated with '#'.
|
||||
ipStrs := strings.Split(fields[1], "#")
|
||||
if len(ipStrs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
possibleIPs = append(possibleIPs, ipStrs[0])
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("executing the command: %w", err)
|
||||
}
|
||||
|
||||
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
|
||||
//
|
||||
// See go doc os/exec.Cmd.StdoutPipe.
|
||||
|
||||
for _, addr := range possibleIPs {
|
||||
if net.ParseIP(addr) == nil {
|
||||
log.Debug("systemResolvers: %q is not a valid ip", addr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Refresh() (err error) {
|
||||
defer agherr.Annotate("systemResolvers: %w", &err)
|
||||
|
||||
got, err := sr.getAddrs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't get addresses: %w", err)
|
||||
}
|
||||
if len(got) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sr.addrsLock.Lock()
|
||||
defer sr.addrsLock.Unlock()
|
||||
|
||||
sr.addrs = got
|
||||
|
||||
return nil
|
||||
}
|
7
internal/aghnet/systemresolvers_windows_test.go
Normal file
7
internal/aghnet/systemresolvers_windows_test.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
// TODO(e.burkov): Write tests for Windows implementation.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
|
@ -1,7 +1,11 @@
|
|||
// Package aghos contains utilities for functions requiring system calls.
|
||||
package aghos
|
||||
|
||||
import "syscall"
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// CanBindPrivilegedPorts checks if current process can bind to privileged
|
||||
// ports.
|
||||
|
@ -24,3 +28,20 @@ func HaveAdminRights() (bool, error) {
|
|||
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||
return sendProcessSignal(pid, sig)
|
||||
}
|
||||
|
||||
// MaxCmdOutputSize is the maximum length of performed shell command output.
|
||||
const MaxCmdOutputSize = 2 * 1024
|
||||
|
||||
// RunCommand runs shell command.
|
||||
func RunCommand(command string, arguments ...string) (int, string, error) {
|
||||
cmd := exec.Command(command, arguments...)
|
||||
out, err := cmd.Output()
|
||||
if len(out) > MaxCmdOutputSize {
|
||||
out = out[:MaxCmdOutputSize]
|
||||
}
|
||||
if err != nil {
|
||||
return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out))
|
||||
}
|
||||
|
||||
return cmd.ProcessState.ExitCode(), string(out), nil
|
||||
}
|
||||
|
|
|
@ -3,12 +3,12 @@ package aghtest
|
|||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -166,7 +166,10 @@ type TestErrUpstream struct{}
|
|||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, agherr.Error("bad")
|
||||
// We don't use an agherr.Error to avoid the import cycle since aghtests
|
||||
// used to provide the utilities for testing which agherr (and any other
|
||||
// testable package) should be able to use.
|
||||
return nil, errors.New("bad")
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
|
|
|
@ -2,7 +2,6 @@ package home
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -11,6 +10,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -213,22 +213,11 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
|||
return true
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
host, err := aghnet.SplitHost(r.Host)
|
||||
if err != nil {
|
||||
// Check for the missing port error. If it is that error, just
|
||||
// use the host as is.
|
||||
//
|
||||
// See the source code for net.SplitHostPort.
|
||||
const missingPort = "missing port in address"
|
||||
httpError(w, http.StatusBadRequest, "bad host: %s", err)
|
||||
|
||||
addrErr := &net.AddrError{}
|
||||
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
|
||||
httpError(w, http.StatusBadRequest, "bad host: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
host = r.Host
|
||||
return false
|
||||
}
|
||||
|
||||
if r.TLS == nil && web.forceHTTPS {
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -270,14 +271,8 @@ func copyInstallSettings(dst, src *configuration) {
|
|||
// shutdownTimeout is the timeout for shutting HTTP server down operation.
|
||||
const shutdownTimeout = 5 * time.Second
|
||||
|
||||
func logPanic() {
|
||||
if v := recover(); v != nil {
|
||||
log.Error("recovered from panic: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
|
||||
defer logPanic()
|
||||
defer agherr.LogPanic("")
|
||||
|
||||
if srv == nil {
|
||||
return
|
||||
|
|
|
@ -98,7 +98,7 @@ func sendSigReload() {
|
|||
if os.IsNotExist(err) {
|
||||
var code int
|
||||
var psdata string
|
||||
code, psdata, err = util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
||||
code, psdata, err = aghos.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
||||
if err != nil || code != 0 {
|
||||
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
|
||||
return
|
||||
|
@ -301,7 +301,7 @@ func configureService(c *service.Config) {
|
|||
// returns command code or error if any
|
||||
func runInitdCommand(action string) (int, error) {
|
||||
confPath := "/etc/init.d/" + serviceName
|
||||
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action)
|
||||
code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action)
|
||||
return code, err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package querylog
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -116,7 +115,7 @@ func TestQLogReader_Seek(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
err = r.SeekTS(ts.UnixNano())
|
||||
assert.True(t, errors.Is(err, tc.want), err)
|
||||
assert.ErrorIs(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,25 +6,12 @@ package util
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RunCommand runs shell command.
|
||||
func RunCommand(command string, arguments ...string) (int, string, error) {
|
||||
cmd := exec.Command(command, arguments...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out))
|
||||
}
|
||||
|
||||
return cmd.ProcessState.ExitCode(), string(out), nil
|
||||
}
|
||||
|
||||
// SplitNext - split string by a byte and return the first chunk
|
||||
// Skip empty chunks
|
||||
// Whitespace is trimmed
|
||||
|
|
Loading…
Reference in a new issue