cherry-pick: 3157 excessive ptrs

Merge in DNS/adguard-home from 3157-excessive-ptrs to master

Updates #3157.

Squashed commit of the following:

commit 6803988240dca2f147bb80a5b3f78d7749d2fa14
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 19 14:50:01 2022 +0300

    aghnet: and again

commit 1a7f4d1dbc8fd4d3ae620349917526a75fa71b47
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 19 14:49:20 2022 +0300

    aghnet: docs again

commit d88da1fc7135f3cd03aff10b02d9957c8ffdfd30
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 19 14:47:36 2022 +0300

    aghnet: imp docs

commit c45dbc7800e882c6c4110aab640c32b03046f89a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 19 14:41:19 2022 +0300

    aghnet: keep alphabetical order

commit b61781785d096ef43f60fb4f1905a4ed3cdf7c68
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 19 13:50:56 2022 +0300

    aghnet: imp code quality

commit 578dbd71ed2f2089c69343d7d4bf8bbc29150ace
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 12 17:02:38 2022 +0300

    aghnet: imp arp container
This commit is contained in:
Eugene Burkov 2022-04-19 15:01:49 +03:00 committed by Ainar Garipov
parent 723279121a
commit c4a13b92d2
39 changed files with 2111 additions and 510 deletions

View file

@ -34,10 +34,6 @@ and this project adheres to
### Changed ### Changed
- Reverse DNS now has a greater priority as the source of runtime clients'
information than ARP neighborhood.
- Improved detection of runtime clients through more resilient ARP processing
([#3597]).
- The TTL of responses served from the optimistic cache is now lowered to 10 - The TTL of responses served from the optimistic cache is now lowered to 10
seconds. seconds.
- Domain-specific private reverse DNS upstream servers are now validated to - Domain-specific private reverse DNS upstream servers are now validated to
@ -114,7 +110,6 @@ In this release, the schema version has changed from 12 to 14.
[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 [#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367
[#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381 [#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381
[#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503 [#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503
[#3597]: https://github.com/AdguardTeam/AdGuardHome/issues/3597
[#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238 [#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238
[ddr-draft-06]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html [ddr-draft-06]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html
@ -150,6 +145,10 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7].
### Changed ### Changed
- Reverse DNS now has a greater priority as the source of runtime clients'
information than ARP neighborhood.
- Improved detection of runtime clients through more resilient ARP processing
([#3597]).
- On OpenBSD, the daemon script now uses the recommended `/bin/ksh` shell - On OpenBSD, the daemon script now uses the recommended `/bin/ksh` shell
instead of the `/bin/sh` one ([#4533]). To apply this change, backup your instead of the `/bin/sh` one ([#4533]). To apply this change, backup your
data and run `AdGuardHome -s uninstall && AdGuardHome -s install`. data and run `AdGuardHome -s uninstall && AdGuardHome -s install`.
@ -169,6 +168,7 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7].
[#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730 [#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730
[#3157]: https://github.com/AdguardTeam/AdGuardHome/issues/3157 [#3157]: https://github.com/AdguardTeam/AdGuardHome/issues/3157
[#3597]: https://github.com/AdguardTeam/AdGuardHome/issues/3597
[#3978]: https://github.com/AdguardTeam/AdGuardHome/issues/3978 [#3978]: https://github.com/AdguardTeam/AdGuardHome/issues/3978
[#4166]: https://github.com/AdguardTeam/AdGuardHome/issues/4166 [#4166]: https://github.com/AdguardTeam/AdGuardHome/issues/4166
[#4213]: https://github.com/AdguardTeam/AdGuardHome/issues/4213 [#4213]: https://github.com/AdguardTeam/AdGuardHome/issues/4213

211
internal/aghnet/arpdb.go Normal file
View file

@ -0,0 +1,211 @@
package aghnet
import (
"bufio"
"bytes"
"fmt"
"net"
"sync"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
)
// ARPDB: The Network Neighborhood Database
// ARPDB stores and refreshes the network neighborhood reported by ARP (Address
// Resolution Protocol).
type ARPDB interface {
// Refresh updates the stored data. It must be safe for concurrent use.
Refresh() (err error)
// Neighbors returnes the last set of data reported by ARP. Both the method
// and it's result must be safe for concurrent use.
Neighbors() (ns []Neighbor)
}
// NewARPDB returns the ARPDB properly initialized for the OS.
func NewARPDB() (arp ARPDB) {
return newARPDB()
}
// Empty ARPDB implementation
// EmptyARPDB is the ARPDB implementation that does nothing.
type EmptyARPDB struct{}
// type check
var _ ARPDB = EmptyARPDB{}
// Refresh implements the ARPDB interface for EmptyARPContainer. It does
// nothing and always returns nil error.
func (EmptyARPDB) Refresh() (err error) { return nil }
// Neighbors implements the ARPDB interface for EmptyARPContainer. It always
// returns nil.
func (EmptyARPDB) Neighbors() (ns []Neighbor) { return nil }
// ARPDB Helper Types
// Neighbor is the pair of IP address and MAC address reported by ARP.
type Neighbor struct {
// Name is the hostname of the neighbor. Empty name is valid since not each
// implementation of ARP is able to retrieve that.
Name string
// IP contains either IPv4 or IPv6.
IP net.IP
// MAC contains the hardware address.
MAC net.HardwareAddr
}
// Clone returns the deep copy of n.
func (n Neighbor) Clone() (clone Neighbor) {
return Neighbor{
Name: n.Name,
IP: netutil.CloneIP(n.IP),
MAC: netutil.CloneMAC(n.MAC),
}
}
// neighs is the helper type that stores neighbors to avoid copying its methods
// among all the ARPDB implementations.
type neighs struct {
mu *sync.RWMutex
ns []Neighbor
}
// len returns the length of the neighbors slice. It's safe for concurrent use.
func (ns *neighs) len() (l int) {
ns.mu.RLock()
defer ns.mu.RUnlock()
return len(ns.ns)
}
// clone returns a deep copy of the underlying neighbors slice. It's safe for
// concurrent use.
func (ns *neighs) clone() (cloned []Neighbor) {
ns.mu.RLock()
defer ns.mu.RUnlock()
cloned = make([]Neighbor, len(ns.ns))
for i, n := range ns.ns {
cloned[i] = n.Clone()
}
return cloned
}
// reset replaces the underlying slice with the new one. It's safe for
// concurrent use.
func (ns *neighs) reset(with []Neighbor) {
ns.mu.Lock()
defer ns.mu.Unlock()
ns.ns = with
}
// Command ARPDB
// parseNeighsFunc parses the text from sc as if it'd be an output of some
// ARP-related command. lenHint is a hint for the size of the allocated slice
// of Neighbors.
type parseNeighsFunc func(sc *bufio.Scanner, lenHint int) (ns []Neighbor)
// cmdARPDB is the implementation of the ARPDB that uses command line to
// retrieve data.
type cmdARPDB struct {
parse parseNeighsFunc
ns *neighs
cmd string
args []string
}
// type check
var _ ARPDB = (*cmdARPDB)(nil)
// Refresh implements the ARPDB interface for *cmdARPDB.
func (arp *cmdARPDB) Refresh() (err error) {
defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }()
code, out, err := aghosRunCommand(arp.cmd, arp.args...)
if err != nil {
return fmt.Errorf("running command: %w", err)
} else if code != 0 {
return fmt.Errorf("running command: unexpected exit code %d", code)
}
sc := bufio.NewScanner(bytes.NewReader(out))
ns := arp.parse(sc, arp.ns.len())
if err = sc.Err(); err != nil {
// TODO(e.burkov): This error seems unreachable. Investigate.
return fmt.Errorf("scanning the output: %w", err)
}
arp.ns.reset(ns)
return nil
}
// Neighbors implements the ARPDB interface for *cmdARPDB.
func (arp *cmdARPDB) Neighbors() (ns []Neighbor) {
return arp.ns.clone()
}
// Composite ARPDB
// arpdbs is the ARPDB that combines several ARPDB implementations and
// consequently switches between those.
type arpdbs struct {
// arps is the set of ARPDB implementations to range through.
arps []ARPDB
neighs
}
// newARPDBs returns a properly initialized *arpdbs. It begins refreshing from
// the first of arps.
func newARPDBs(arps ...ARPDB) (arp *arpdbs) {
return &arpdbs{
arps: arps,
neighs: neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
}
// type check
var _ ARPDB = (*arpdbs)(nil)
// Refresh implements the ARPDB interface for *arpdbs.
func (arp *arpdbs) Refresh() (err error) {
var errs []error
for _, a := range arp.arps {
err = a.Refresh()
if err != nil {
errs = append(errs, err)
continue
}
arp.reset(a.Neighbors())
return nil
}
if len(errs) > 0 {
err = errors.List("each arpdb failed", errs...)
}
return err
}
// Neighbors implements the ARPDB interface for *arpdbs.
//
// TODO(e.burkov): Think of a way to avoid cloning the slice twice.
func (arp *arpdbs) Neighbors() (ns []Neighbor) {
return arp.clone()
}

View file

@ -0,0 +1,76 @@
//go:build darwin || freebsd
// +build darwin freebsd
package aghnet
import (
"bufio"
"net"
"strings"
"sync"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
func newARPDB() (arp *cmdARPDB) {
return &cmdARPDB{
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors. By
// default ARP attempts to resolve the hostnames via DNS. See man 8
// arp.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
args: []string{"-a", "-n"},
}
}
// parseArpA parses the output of the "arp -a -n" command on macOS and FreeBSD.
// The expected input format:
//
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
//
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
ln := sc.Text()
fields := strings.Fields(ln)
if len(fields) < 4 {
continue
}
n := Neighbor{}
if ipStr := fields[1]; len(ipStr) < 2 {
continue
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
continue
} else {
n.MAC = mac
}
host := fields[0]
if err := netutil.ValidateDomainName(host); err != nil {
log.Debug("parsing arp output: %s", err)
} else {
n.Name = host
}
ns = append(ns, n)
}
return ns
}

View file

@ -0,0 +1,31 @@
//go:build darwin || freebsd
// +build darwin freebsd
package aghnet
import (
"net"
)
const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet]
? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet]
`
var wantNeighs = []Neighbor{{
Name: "hostname.one",
IP: net.IPv4(192, 168, 1, 2),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
Name: "hostname.two",
IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}, {
Name: "",
IP: net.ParseIP("::1234"),
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
}}

View file

@ -0,0 +1,243 @@
//go:build linux
// +build linux
package aghnet
import (
"bufio"
"fmt"
"io/fs"
"net"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
)
func newARPDB() (arp *arpdbs) {
// Use the common storage among the implementations.
ns := &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
}
var parseF parseNeighsFunc
if aghos.IsOpenWrt() {
parseF = parseArpAWrt
} else {
parseF = parseArpA
}
return newARPDBs(
// Try /proc/net/arp first.
&fsysARPDB{
ns: ns,
fsys: rootDirFS,
filename: "proc/net/arp",
},
// Then, try "arp -a -n".
&cmdARPDB{
parse: parseF,
ns: ns,
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors.
// By default ARP attempts to resolve the hostnames via DNS. See
// man 8 arp.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
args: []string{"-a", "-n"},
},
// Finally, try "ip neigh".
&cmdARPDB{
parse: parseIPNeigh,
ns: ns,
cmd: "ip",
args: []string{"neigh"},
},
)
}
// fsysARPDB accesses the ARP cache file to update the database.
type fsysARPDB struct {
ns *neighs
fsys fs.FS
filename string
}
// type check
var _ ARPDB = (*fsysARPDB)(nil)
// Refresh implements the ARPDB interface for *fsysARPDB.
func (arp *fsysARPDB) Refresh() (err error) {
var f fs.File
f, err = arp.fsys.Open(arp.filename)
if err != nil {
return fmt.Errorf("opening %q: %w", arp.filename, err)
}
sc := bufio.NewScanner(f)
// Skip the header.
if !sc.Scan() {
return nil
} else if err = sc.Err(); err != nil {
return err
}
ns := make([]Neighbor, 0, arp.ns.len())
for sc.Scan() {
ln := sc.Text()
fields := stringutil.SplitTrimmed(ln, " ")
if len(fields) != 6 {
continue
}
n := Neighbor{}
if n.IP = net.ParseIP(fields[0]); n.IP == nil || n.IP.IsUnspecified() {
continue
} else if n.MAC, err = net.ParseMAC(fields[3]); err != nil {
continue
}
ns = append(ns, n)
}
arp.ns.reset(ns)
return nil
}
// Neighbors implements the ARPDB interface for *fsysARPDB.
func (arp *fsysARPDB) Neighbors() (ns []Neighbor) {
return arp.ns.clone()
}
// parseArpAWrt parses the output of the "arp -a -n" command on OpenWrt. The
// expected input format:
//
// IP address HW type Flags HW address Mask Device
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
//
func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if !sc.Scan() {
// Skip the header.
return
}
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
ln := sc.Text()
fields := strings.Fields(ln)
if len(fields) < 4 {
continue
}
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil || n.IP.IsUnspecified() {
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
log.Debug("parsing arp output: %s", err)
continue
} else {
n.MAC = mac
}
ns = append(ns, n)
}
return ns
}
// parseArpA parses the output of the "arp -a -n" command on Linux. The
// expected input format:
//
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
//
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
ln := sc.Text()
fields := strings.Fields(ln)
if len(fields) < 4 {
continue
}
n := Neighbor{}
if ipStr := fields[1]; len(ipStr) < 2 {
continue
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
log.Debug("parsing arp output: %s", err)
continue
} else {
n.MAC = mac
}
host := fields[0]
if verr := netutil.ValidateDomainName(host); verr != nil {
log.Debug("parsing arp output: %s", verr)
} else {
n.Name = host
}
ns = append(ns, n)
}
return ns
}
// parseIPNeigh parses the output of the "ip neigh" command on Linux. The
// expected input format:
//
// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE
//
func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
ln := sc.Text()
fields := strings.Fields(ln)
if len(fields) < 5 {
continue
}
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[4]); err != nil {
log.Debug("parsing arp output: %s", err)
continue
} else {
n.MAC = mac
}
ns = append(ns, n)
}
return ns
}

View file

@ -0,0 +1,102 @@
//go:build linux
// +build linux
package aghnet
import (
"net"
"sync"
"testing"
"testing/fstest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const arpAOutputWrt = `
IP address HW type Flags HW address Mask Device
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan`
const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]`
const ipNeighOutput = `
1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY
1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY
192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}
func TestFSysARPDB(t *testing.T) {
require.NoError(t, fstest.TestFS(testdata, "proc_net_arp"))
a := &fsysARPDB{
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
fsys: testdata,
filename: "proc_net_arp",
}
err := a.Refresh()
require.NoError(t, err)
ns := a.Neighbors()
assert.Equal(t, wantNeighs, ns)
}
func TestCmdARPDB_linux(t *testing.T) {
sh := mapShell{
"arp -a": {err: nil, out: arpAOutputWrt, code: 0},
"ip neigh": {err: nil, out: ipNeighOutput, code: 0},
}
substShell(t, sh.RunCmd)
t.Run("wrt", func(t *testing.T) {
a := &cmdARPDB{
parse: parseArpAWrt,
cmd: "arp",
args: []string{"-a"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
})
t.Run("ip_neigh", func(t *testing.T) {
a := &cmdARPDB{
parse: parseIPNeigh,
cmd: "ip",
args: []string{"neigh"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
})
}

View file

@ -0,0 +1,73 @@
//go:build openbsd
// +build openbsd
package aghnet
import (
"bufio"
"net"
"strings"
"sync"
"github.com/AdguardTeam/golibs/log"
)
func newARPDB() (arp *cmdARPDB) {
return &cmdARPDB{
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors. By
// default ARP attempts to resolve the hostnames via DNS. See man 8
// arp.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
args: []string{"-a", "-n"},
}
}
// parseArpA parses the output of the "arp -a -n" command on OpenBSD. The
// expected input format:
//
// Host Ethernet Address Netif Expire Flags
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
//
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
// Skip the header.
if !sc.Scan() {
return nil
}
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
ln := sc.Text()
fields := strings.Fields(ln)
if len(fields) < 2 {
continue
}
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[1]); err != nil {
log.Debug("parsing arp output: %s", err)
continue
} else {
n.MAC = mac
}
ns = append(ns, n)
}
return ns
}

View file

@ -0,0 +1,24 @@
//go:build openbsd
// +build openbsd
package aghnet
import (
"net"
)
const arpAOutput = `
Host Ethernet Address Netif Expire Flags
1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent
1.2.3.4 12:34:56:78:910 em0 permanent
192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s
::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l
`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View file

@ -0,0 +1,216 @@
package aghnet
import (
"net"
"sync"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewARPDB(t *testing.T) {
var a ARPDB
require.NotPanics(t, func() { a = NewARPDB() })
assert.NotNil(t, a)
}
// TestARPDB is the mock implementation of ARPDB to use in tests.
type TestARPDB struct {
OnRefresh func() (err error)
OnNeighbors func() (ns []Neighbor)
}
// Refresh implements the ARPDB interface for *TestARPDB.
func (arp *TestARPDB) Refresh() (err error) {
return arp.OnRefresh()
}
// Neighbors implements the ARPDB interface for *TestARPDB.
func (arp *TestARPDB) Neighbors() (ns []Neighbor) {
return arp.OnNeighbors()
}
func TestARPDBS(t *testing.T) {
knownIP := net.IP{1, 2, 3, 4}
knownMAC := net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}
succRefrCount, failRefrCount := 0, 0
clnp := func() {
succRefrCount, failRefrCount = 0, 0
}
succDB := &TestARPDB{
OnRefresh: func() (err error) { succRefrCount++; return nil },
OnNeighbors: func() (ns []Neighbor) {
return []Neighbor{{Name: "abc", IP: knownIP, MAC: knownMAC}}
},
}
failDB := &TestARPDB{
OnRefresh: func() (err error) { failRefrCount++; return errors.Error("refresh failed") },
OnNeighbors: func() (ns []Neighbor) { return nil },
}
t.Run("begin_with_success", func(t *testing.T) {
t.Cleanup(clnp)
a := newARPDBs(succDB, failDB)
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
assert.Zero(t, failRefrCount)
assert.NotEmpty(t, a.Neighbors())
})
t.Run("begin_with_fail", func(t *testing.T) {
t.Cleanup(clnp)
a := newARPDBs(failDB, succDB)
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
assert.Equal(t, 1, failRefrCount)
assert.NotEmpty(t, a.Neighbors())
})
t.Run("fail_only", func(t *testing.T) {
t.Cleanup(clnp)
wantMsg := `each arpdb failed: 2 errors: "refresh failed", "refresh failed"`
a := newARPDBs(failDB, failDB)
err := a.Refresh()
require.Error(t, err)
testutil.AssertErrorMsg(t, wantMsg, err)
assert.Equal(t, 2, failRefrCount)
assert.Empty(t, a.Neighbors())
})
t.Run("fail_after_success", func(t *testing.T) {
t.Cleanup(clnp)
shouldFail := false
unstableDB := &TestARPDB{
OnRefresh: func() (err error) {
if shouldFail {
err = errors.Error("unstable failed")
}
shouldFail = !shouldFail
return err
},
OnNeighbors: func() (ns []Neighbor) {
if !shouldFail {
return failDB.OnNeighbors()
}
return succDB.OnNeighbors()
},
}
a := newARPDBs(unstableDB, succDB)
// Unstable ARPDB should refresh successfully.
err := a.Refresh()
require.NoError(t, err)
assert.Zero(t, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
// Unstable ARPDB should fail and the succDB should be used.
err = a.Refresh()
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
// Unstable ARPDB should refresh successfully again.
err = a.Refresh()
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
})
t.Run("empty", func(t *testing.T) {
a := newARPDBs()
require.NoError(t, a.Refresh())
assert.Empty(t, a.Neighbors())
})
}
func TestCmdARPDB_arpa(t *testing.T) {
a := &cmdARPDB{
cmd: "cmd",
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
t.Run("arp_a", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, arpAOutput, nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
})
t.Run("runcmd_error", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run"))
substShell(t, sh.RunCmd)
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
})
t.Run("bad_code", func(t *testing.T) {
sh := theOnlyCmd("cmd", 1, "", nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
})
t.Run("empty", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
require.NoError(t, err)
assert.Empty(t, a.Neighbors())
})
}
func TestEmptyARPDB(t *testing.T) {
a := EmptyARPDB{}
t.Run("refresh", func(t *testing.T) {
var err error
require.NotPanics(t, func() {
err = a.Refresh()
})
assert.NoError(t, err)
})
t.Run("neighbors", func(t *testing.T) {
var ns []Neighbor
require.NotPanics(t, func() {
ns = a.Neighbors()
})
assert.Empty(t, ns)
})
}

View file

@ -0,0 +1,65 @@
//go:build windows
// +build windows
package aghnet
import (
"bufio"
"net"
"strings"
"sync"
)
func newARPDB() (arp *cmdARPDB) {
return &cmdARPDB{
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
cmd: "arp",
args: []string{"/a"},
}
}
// parseArpA parses the output of the "arp /a" command on Windows. The expected
// input format (the first line is empty):
//
//
// Interface: 192.168.56.16 --- 0x7
// Internet Address Physical Address Type
// 192.168.56.1 0a-00-27-00-00-00 dynamic
// 192.168.56.255 ff-ff-ff-ff-ff-ff static
//
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
ln := sc.Text()
if ln == "" {
continue
}
fields := strings.Fields(ln)
if len(fields) != 3 {
continue
}
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[1]); err != nil {
continue
} else {
n.MAC = mac
}
ns = append(ns, n)
}
return ns
}

View file

@ -0,0 +1,23 @@
//go:build windows
// +build windows
package aghnet
import (
"net"
)
const arpAOutput = `
Interface: 192.168.1.1 --- 0x7
Internet Address Physical Address Type
192.168.1.2 ab-cd-ef-ab-cd-ef dynamic
::ffff:ffff ef-cd-ab-ef-cd-ab static`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View file

@ -11,7 +11,7 @@ const (
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010") ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
) )
// generateIPv4Hostname generates the hostname for specific IP version. // generateIPv4Hostname generates the hostname by IP address version 4.
func generateIPv4Hostname(ipv4 net.IP) (hostname string) { func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv4HostnameMaxLen) hnData := make([]byte, 0, ipv4HostnameMaxLen)
for i, part := range ipv4 { for i, part := range ipv4 {
@ -24,7 +24,7 @@ func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
return string(hnData) return string(hnData)
} }
// generateIPv6Hostname generates the hostname for specific IP version. // generateIPv6Hostname generates the hostname by IP address version 6.
func generateIPv6Hostname(ipv6 net.IP) (hostname string) { func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv6HostnameMaxLen) hnData := make([]byte, 0, ipv6HostnameMaxLen)
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ { for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
@ -51,12 +51,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
// //
// ff80-f076-0000-0000-0000-0000-0000-0010 // ff80-f076-0000-0000-0000-0000-0000-0010
// //
// ip must be either an IPv4 or an IPv6.
func GenerateHostname(ip net.IP) (hostname string) { func GenerateHostname(ip net.IP) (hostname string) {
if ipv4 := ip.To4(); ipv4 != nil { if ipv4 := ip.To4(); ipv4 != nil {
return generateIPv4Hostname(ipv4) return generateIPv4Hostname(ipv4)
} else if ipv6 := ip.To16(); ipv6 != nil {
return generateIPv6Hostname(ipv6)
} }
return "" return generateIPv6Hostname(ip)
} }

View file

@ -8,6 +8,7 @@ import (
) )
func TestGenerateHostName(t *testing.T) { func TestGenerateHostName(t *testing.T) {
t.Run("valid", func(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
want string want string
@ -16,27 +17,14 @@ func TestGenerateHostName(t *testing.T) {
name: "good_ipv4", name: "good_ipv4",
want: "127-0-0-1", want: "127-0-0-1",
ip: net.IP{127, 0, 0, 1}, ip: net.IP{127, 0, 0, 1},
}, {
name: "bad_ipv4",
want: "",
ip: net.IP{127, 0, 0, 1, 0},
}, { }, {
name: "good_ipv6", name: "good_ipv6",
want: "fe00-0000-0000-0000-0000-0000-0000-0001", want: "fe00-0000-0000-0000-0000-0000-0000-0001",
ip: net.ParseIP("fe00::1"), ip: net.ParseIP("fe00::1"),
}, { }, {
name: "bad_ipv6", name: "4to6",
want: "", want: "1-2-3-4",
ip: net.IP{ ip: net.ParseIP("::ffff:1.2.3.4"),
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
want: "",
ip: nil,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -45,4 +33,32 @@ func TestGenerateHostName(t *testing.T) {
assert.Equal(t, tc.want, hostname) assert.Equal(t, tc.want, hostname)
}) })
} }
})
t.Run("invalid", func(t *testing.T) {
testCases := []struct {
name string
ip net.IP
}{{
name: "bad_ipv4",
ip: net.IP{127, 0, 0, 1, 0},
}, {
name: "bad_ipv6",
ip: net.IP{
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
ip: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Panics(t, func() { GenerateHostname(tc.ip) })
})
}
})
} }

View file

@ -368,8 +368,8 @@ func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
} }
} }
// writeRules writes the actual rule for the qtype and the PTR for the // writeRules writes the actual rule for the qtype and the PTR for the host-ip
// host-ip pair into internal builders. // pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip) arpa, err := netutil.IPToReversedAddr(ip)
if err != nil { if err != nil {

View file

@ -2,19 +2,31 @@
package aghnet package aghnet
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net" "net"
"os/exec"
"strings"
"syscall" "syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
) )
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
// netInterfaces is the function to get the available network interfaces.
netInterfaceAddrs = net.InterfaceAddrs
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = aghos.RootDirFS()
)
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
// the IP being static is available. // the IP being static is available.
const ErrNoStaticIPInfo errors.Error = "no information about static ip" const ErrNoStaticIPInfo errors.Error = "no information about static ip"
@ -32,39 +44,29 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
} }
// GatewayIP returns IP address of interface's gateway. // GatewayIP returns IP address of interface's gateway.
func GatewayIP(ifaceName string) net.IP { //
cmd := exec.Command("ip", "route", "show", "dev", ifaceName) // TODO(e.burkov): Investigate if the gateway address may be fetched in another
log.Tracef("executing %s %v", cmd.Path, cmd.Args) // way since not every machine has the software installed.
d, err := cmd.Output() func GatewayIP(ifaceName string) (ip net.IP) {
if err != nil || cmd.ProcessState.ExitCode() != 0 { code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
if err != nil {
log.Debug("%s", err)
return nil
} else if code != 0 {
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
return nil return nil
} }
fields := strings.Fields(string(d)) fields := bytes.Fields(out)
// The meaningful "ip route" command output should contain the word // The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third field. // "default" at first field and default gateway IP address at third field.
if len(fields) < 3 || fields[0] != "default" { if len(fields) < 3 || string(fields[0]) != "default" {
return nil return nil
} }
return net.ParseIP(fields[2]) return net.ParseIP(string(fields[2]))
}
// CanBindPort checks if we can bind to the given port.
func CanBindPort(port int) (can bool, err error) {
var addr *net.TCPAddr
addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return false, err
}
var listener *net.TCPListener
listener, err = net.ListenTCP("tcp", addr)
if err != nil {
return false, err
}
_ = listener.Close()
return true, nil
} }
// CanBindPrivilegedPorts checks if current process can bind to privileged // CanBindPrivilegedPorts checks if current process can bind to privileged
@ -99,19 +101,19 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
}) })
} }
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only // GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
// we do not return link-local addresses here // WEB only we do not return link-local addresses here.
func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { //
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err) return nil, fmt.Errorf("couldn't get interfaces: %w", err)
} } else if len(ifaces) == 0 {
if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface") return nil, errors.Error("couldn't find any legible interface")
} }
var netInterfaces []*NetInterface
for _, iface := range ifaces { for _, iface := range ifaces {
var addrs []net.Addr var addrs []net.Addr
addrs, err = iface.Addrs() addrs, err = iface.Addrs()
@ -131,26 +133,30 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
ipNet, ok := addr.(*net.IPNet) ipNet, ok := addr.(*net.IPNet)
if !ok { if !ok {
// Should be net.IPNet, this is weird. // Should be net.IPNet, this is weird.
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
} }
// Ignore link-local. // Ignore link-local.
if ipNet.IP.IsLinkLocalUnicast() { if ipNet.IP.IsLinkLocalUnicast() {
continue continue
} }
netIface.Addresses = append(netIface.Addresses, ipNet.IP) netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet) netIface.Subnets = append(netIface.Subnets, ipNet)
} }
// Discard interfaces with no addresses. // Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 { if len(netIface.Addresses) != 0 {
netInterfaces = append(netInterfaces, netIface) netIfaces = append(netIfaces, netIface)
} }
} }
return netInterfaces, nil return netIfaces, nil
} }
// GetInterfaceByIP returns the name of interface containing provided ip. // GetInterfaceByIP returns the name of interface containing provided ip.
//
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
func GetInterfaceByIP(ip net.IP) string { func GetInterfaceByIP(ip net.IP) string {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
@ -170,6 +176,8 @@ func GetInterfaceByIP(ip net.IP) string {
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if // GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// the search fails. // the search fails.
//
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet { func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb() netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
@ -220,17 +228,10 @@ func IsAddrInUse(err error) (ok bool) {
// CollectAllIfacesAddrs returns the slice of all network interfaces IP // CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number. // addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) { func CollectAllIfacesAddrs() (addrs []string, err error) {
var ifaces []net.Interface
ifaces, err = net.Interfaces()
if err != nil {
return nil, fmt.Errorf("getting network interfaces: %w", err)
}
for _, iface := range ifaces {
var ifaceAddrs []net.Addr var ifaceAddrs []net.Addr
ifaceAddrs, err = iface.Addrs() ifaceAddrs, err = netInterfaceAddrs()
if err != nil { if err != nil {
return nil, fmt.Errorf("getting addresses for %q: %w", iface.Name, err) return nil, fmt.Errorf("getting interfaces addresses: %w", err)
} }
for _, addr := range ifaceAddrs { for _, addr := range ifaceAddrs {
@ -243,7 +244,6 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
addrs = append(addrs, ip.String()) addrs = append(addrs, ip.String())
} }
}
return addrs, nil return addrs, nil
} }

View file

@ -4,10 +4,11 @@
package aghnet package aghnet
import ( import (
"bufio"
"bytes"
"fmt" "fmt"
"os" "io"
"regexp" "regexp"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -23,7 +24,7 @@ type hardwarePortInfo struct {
static bool static bool
} }
func ifaceHasStaticIP(ifaceName string) (bool, error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
portInfo, err := getCurrentHardwarePortInfo(ifaceName) portInfo, err := getCurrentHardwarePortInfo(ifaceName)
if err != nil { if err != nil {
return false, err return false, err
@ -32,9 +33,10 @@ func ifaceHasStaticIP(ifaceName string) (bool, error) {
return portInfo.static, nil return portInfo.static, nil
} }
// getCurrentHardwarePortInfo gets information for the specified network interface. // getCurrentHardwarePortInfo gets information for the specified network
// interface.
func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
// First of all we should find hardware port name // First of all we should find hardware port name.
m := getNetworkSetupHardwareReports() m := getNetworkSetupHardwareReports()
hardwarePort, ok := m[ifaceName] hardwarePort, ok := m[ifaceName]
if !ok { if !ok {
@ -44,6 +46,10 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
return getHardwarePortInfo(hardwarePort) return getHardwarePortInfo(hardwarePort)
} }
// hardwareReportsReg is the regular expression matching the lines of
// networksetup command output lines containing the interface information.
var hardwareReportsReg = regexp.MustCompile("Hardware Port: (.*?)\nDevice: (.*?)\n")
// getNetworkSetupHardwareReports parses the output of the `networksetup // getNetworkSetupHardwareReports parses the output of the `networksetup
// -listallhardwareports` command it returns a map where the key is the // -listallhardwareports` command it returns a map where the key is the
// interface name, and the value is the "hardware port" returns nil if it fails // interface name, and the value is the "hardware port" returns nil if it fails
@ -52,54 +58,44 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
// TODO(e.burkov): There should be more proper approach than parsing the // TODO(e.burkov): There should be more proper approach than parsing the
// command output. For example, see // command output. For example, see
// https://developer.apple.com/documentation/systemconfiguration. // https://developer.apple.com/documentation/systemconfiguration.
func getNetworkSetupHardwareReports() map[string]string { func getNetworkSetupHardwareReports() (reports map[string]string) {
_, out, err := aghos.RunCommand("networksetup", "-listallhardwareports") _, out, err := aghosRunCommand("networksetup", "-listallhardwareports")
if err != nil { if err != nil {
return nil return nil
} }
re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n") reports = make(map[string]string)
if err != nil {
return nil matches := hardwareReportsReg.FindAllSubmatch(out, -1)
for _, m := range matches {
reports[string(m[2])] = string(m[1])
} }
m := make(map[string]string) return reports
matches := re.FindAllStringSubmatch(out, -1)
for i := range matches {
port := matches[i][1]
device := matches[i][2]
m[device] = port
} }
return m // hardwarePortReg is the regular expression matching the lines of networksetup
} // command output lines containing the port information.
var hardwarePortReg = regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { func getHardwarePortInfo(hardwarePort string) (h hardwarePortInfo, err error) {
h := hardwarePortInfo{} _, out, err := aghosRunCommand("networksetup", "-getinfo", hardwarePort)
_, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort)
if err != nil { if err != nil {
return h, err return h, err
} }
re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") match := hardwarePortReg.FindSubmatch(out)
if len(match) != 4 {
match := re.FindStringSubmatch(out)
if len(match) == 0 {
return h, errors.Error("could not find hardware port info") return h, errors.Error("could not find hardware port info")
} }
h.name = hardwarePort return hardwarePortInfo{
h.ip = match[1] name: hardwarePort,
h.subnet = match[2] ip: string(match[1]),
h.gatewayIP = match[3] subnet: string(match[2]),
gatewayIP: string(match[3]),
if strings.Index(out, "Manual Configuration") == 0 { static: bytes.Index(out, []byte("Manual Configuration")) == 0,
h.static = true }, nil
}
return h, nil
} }
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {
@ -109,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
} }
if portInfo.static { if portInfo.static {
return errors.Error("IP address is already static") return errors.Error("ip address is already static")
} }
dnsAddrs, err := getEtcResolvConfServers() dnsAddrs, err := getEtcResolvConfServers()
@ -117,50 +113,62 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
return err return err
} }
args := make([]string, 0) args := append([]string{"-setdnsservers", portInfo.name}, dnsAddrs...)
args = append(args, "-setdnsservers", portInfo.name)
args = append(args, dnsAddrs...)
// Setting DNS servers is necessary when configuring a static IP // Setting DNS servers is necessary when configuring a static IP
code, _, err := aghos.RunCommand("networksetup", args...) code, _, err := aghosRunCommand("networksetup", args...)
if err != nil { if err != nil {
return err return err
} } else if code != 0 {
if code != 0 {
return fmt.Errorf("failed to set DNS servers, code=%d", code) return fmt.Errorf("failed to set DNS servers, code=%d", code)
} }
// Actually configures hardware port to have static IP // Actually configures hardware port to have static IP
code, _, err = aghos.RunCommand("networksetup", "-setmanual", code, _, err = aghosRunCommand(
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) "networksetup",
"-setmanual",
portInfo.name,
portInfo.ip,
portInfo.subnet,
portInfo.gatewayIP,
)
if err != nil { if err != nil {
return err return err
} } else if code != 0 {
if code != 0 {
return fmt.Errorf("failed to set DNS servers, code=%d", code) return fmt.Errorf("failed to set DNS servers, code=%d", code)
} }
return nil return nil
} }
// etcResolvConfReg is the regular expression matching the lines of resolv.conf
// file containing a name server information.
var etcResolvConfReg = regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
// getEtcResolvConfServers returns a list of nameservers configured in // getEtcResolvConfServers returns a list of nameservers configured in
// /etc/resolv.conf. // /etc/resolv.conf.
func getEtcResolvConfServers() ([]string, error) { func getEtcResolvConfServers() (addrs []string, err error) {
body, err := os.ReadFile("/etc/resolv.conf") const filename = "etc/resolv.conf"
if err != nil {
return nil, err
}
re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") _, err = aghos.FileWalker(func(r io.Reader) (_ []string, _ bool, err error) {
sc := bufio.NewScanner(r)
matches := re.FindAllStringSubmatch(string(body), -1) for sc.Scan() {
matches := etcResolvConfReg.FindAllStringSubmatch(sc.Text(), -1)
if len(matches) == 0 { if len(matches) == 0 {
return nil, errors.Error("found no DNS servers in /etc/resolv.conf") continue
} }
addrs := make([]string, 0) for _, m := range matches {
for i := range matches { addrs = append(addrs, m[1])
addrs = append(addrs, matches[i][1]) }
}
return nil, false, sc.Err()
}).Walk(rootDirFS, filename)
if err != nil {
return nil, fmt.Errorf("parsing etc/resolv.conf file: %w", err)
} else if len(addrs) == 0 {
return nil, fmt.Errorf("found no dns servers in %s", filename)
} }
return addrs, nil return addrs, nil

View file

@ -0,0 +1,261 @@
package aghnet
import (
"io/fs"
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestIfaceHasStaticIP(t *testing.T) {
testCases := []struct {
name string
shell mapShell
ifaceName string
wantHas assert.BoolAssertionFunc
wantErrMsg string
}{{
name: "success",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: ``,
}, {
name: "success_static",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "Manual Configuration\nIP address: 1.2.3.4\n" +
"Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.True,
wantErrMsg: ``,
}, {
name: "reports_error",
shell: theOnlyCmd(
"networksetup -listallhardwareports",
0,
"",
errors.Error("can't list"),
),
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: `could not find hardware port for en0`,
}, {
name: "port_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: errors.Error("can't get"),
out: ``,
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: `can't get`,
}, {
name: "port_bad_output",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "nothing meaningful",
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: `could not find hardware port info`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
has, err := IfaceHasStaticIP(tc.ifaceName)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
tc.wantHas(t, has)
})
}
}
func TestIfaceSetStaticIP(t *testing.T) {
succFsys := fstest.MapFS{
"etc/resolv.conf": &fstest.MapFile{
Data: []byte(`nameserver 1.1.1.1`),
},
}
panicFsys := &aghtest.FS{
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
}
testCases := []struct {
name string
shell mapShell
fsys fs.FS
wantErrMsg string
}{{
name: "success",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
"networksetup -setdnsservers hwport 1.1.1.1": {
err: nil,
out: "",
code: 0,
},
"networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": {
err: nil,
out: "",
code: 0,
},
},
fsys: succFsys,
wantErrMsg: ``,
}, {
name: "static_already",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "Manual Configuration\nIP address: 1.2.3.4\n" +
"Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
fsys: panicFsys,
wantErrMsg: `ip address is already static`,
}, {
name: "reports_error",
shell: theOnlyCmd(
"networksetup -listallhardwareports",
0,
"",
errors.Error("can't list"),
),
fsys: panicFsys,
wantErrMsg: `could not find hardware port for en0`,
}, {
name: "resolv_conf_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
fsys: fstest.MapFS{
"etc/resolv.conf": &fstest.MapFile{
Data: []byte("this resolv.conf is invalid"),
},
},
wantErrMsg: `found no dns servers in etc/resolv.conf`,
}, {
name: "set_dns_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
"networksetup -setdnsservers hwport 1.1.1.1": {
err: errors.Error("can't set"),
out: "",
code: 0,
},
},
fsys: succFsys,
wantErrMsg: `can't set`,
}, {
name: "set_manual_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
"networksetup -setdnsservers hwport 1.1.1.1": {
err: nil,
out: "",
code: 0,
},
"networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": {
err: errors.Error("can't set"),
out: "",
code: 0,
},
},
fsys: succFsys,
wantErrMsg: `can't set`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
substRootDirFS(t, tc.fsys)
err := IfaceSetStaticIP("en0")
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View file

@ -18,7 +18,7 @@ func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig) walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
return walker.Walk(aghos.RootDirFS(), rcConfFilename) return walker.Walk(rootDirFS, rcConfFilename)
} }
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to

View file

@ -4,56 +4,74 @@
package aghnet package aghnet
import ( import (
"strings" "io/fs"
"testing" "testing"
"testing/fstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestRcConfStaticConfig(t *testing.T) { func TestIfaceHasStaticIP(t *testing.T) {
const iface interfaceName = `em0` const (
const nl = "\n" ifaceName = `em0`
rcConf = "etc/rc.conf"
)
testCases := []struct { testCases := []struct {
name string name string
rcconfData string rootFsys fs.FS
wantCont bool wantHas assert.BoolAssertionFunc
}{{ }{{
name: "simple", name: "simple",
rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
wantCont: false, Data: []byte(`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl),
}},
wantHas: assert.True,
}, { }, {
name: "case_insensitiveness", name: "case_insensitiveness",
rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl, rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
wantCont: false, Data: []byte(`ifconfig_` + ifaceName + `="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl),
}},
wantHas: assert.True,
}, { }, {
name: "comments_and_trash", name: "comments_and_trash",
rcconfData: `# comment 1` + nl + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
Data: []byte(`# comment 1` + nl +
`` + nl + `` + nl +
`# comment 2` + nl + `# comment 2` + nl +
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, `ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantCont: false, ),
}},
wantHas: assert.True,
}, { }, {
name: "aliases", name: "aliases",
rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, Data: []byte(`ifconfig_` + ifaceName + `_alias="inet 127.0.0.1/24"` + nl +
wantCont: false, `ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl,
),
}},
wantHas: assert.True,
}, { }, {
name: "incorrect_config", name: "incorrect_config",
rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
`ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl + Data: []byte(
`ifconfig_em0=""` + nl, `ifconfig_` + ifaceName + `="inet6 127.0.0.253 netmask 0xffffffff"` + nl +
wantCont: true, `ifconfig_` + ifaceName + `="inet 256.256.256.256 netmask 0xffffffff"` + nl +
`ifconfig_` + ifaceName + `=""` + nl,
),
}},
wantHas: assert.False,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, cont, err := iface.rcConfStaticConfig(r) substRootDirFS(t, tc.rootFsys)
has, err := IfaceHasStaticIP(ifaceName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.wantCont, cont) tc.wantHas(t, has)
}) })
} }
} }

View file

@ -13,16 +13,33 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem.
const dhcpcdConf = "etc/dhcpcd.conf"
func canBindPrivilegedPorts() (can bool, err error) {
cnbs, err := unix.PrctlRetInt(
unix.PR_CAP_AMBIENT,
unix.PR_CAP_AMBIENT_IS_SET,
unix.CAP_NET_BIND_SERVICE,
0,
0,
)
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return cnbs == 1 || adm, err
}
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to // dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// have a static IP. // have a static IP.
func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) { func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
ifaceFound := findIfaceLine(s, string(n)) if !findIfaceLine(s, string(n)) {
if !ifaceFound {
return nil, true, s.Err() return nil, true, s.Err()
} }
@ -61,9 +78,9 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool,
fields := strings.Fields(line) fields := strings.Fields(line)
fieldsNum := len(fields) fieldsNum := len(fields)
// Man page interfaces(5) declares that interface definition // Man page interfaces(5) declares that interface definition should
// should consist of the key word "iface" followed by interface // consist of the key word "iface" followed by interface name, and
// name, and method at fourth field. // method at fourth field.
if fieldsNum >= 4 && if fieldsNum >= 4 &&
fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" { fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" {
return nil, false, nil return nil, false, nil
@ -78,10 +95,10 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool,
} }
func ifaceHasStaticIP(ifaceName string) (has bool, err error) { func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
// TODO(a.garipov): Currently, this function returns the first // TODO(a.garipov): Currently, this function returns the first definitive
// definitive result. So if /etc/dhcpcd.conf has a static IP while // result. So if /etc/dhcpcd.conf has and /etc/network/interfaces has no
// /etc/network/interfaces doesn't, it will return true. Perhaps this // static IP configuration, it will return true. Perhaps this is not the
// is not the most desirable behavior. // most desirable behavior.
iface := interfaceName(ifaceName) iface := interfaceName(ifaceName)
@ -90,17 +107,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
filename string filename string
}{{ }{{
FileWalker: iface.dhcpcdStaticConfig, FileWalker: iface.dhcpcdStaticConfig,
filename: "etc/dhcpcd.conf", filename: dhcpcdConf,
}, { }, {
FileWalker: iface.ifacesStaticConfig, FileWalker: iface.ifacesStaticConfig,
filename: "etc/network/interfaces", filename: "etc/network/interfaces",
}} { }} {
has, err = pair.Walk(aghos.RootDirFS(), pair.filename) has, err = pair.Walk(rootDirFS, pair.filename)
if err != nil { if err != nil {
return false, err return false, err
} } else if has {
if has {
return true, nil return true, nil
} }
} }
@ -108,14 +123,6 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
return false, ErrNoStaticIPInfo return false, ErrNoStaticIPInfo
} }
func canBindPrivilegedPorts() (can bool, err error) {
cnbs, err := unix.PrctlRetInt(unix.PR_CAP_AMBIENT, unix.PR_CAP_AMBIENT_IS_SET, unix.CAP_NET_BIND_SERVICE, 0, 0)
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return cnbs == 1 || adm, err
}
// findIfaceLine scans s until it finds the line that declares an interface with // findIfaceLine scans s until it finds the line that declares an interface with
// the given name. If findIfaceLine can't find the line, it returns false. // the given name. If findIfaceLine can't find the line, it returns false.
func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
@ -131,23 +138,23 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
} }
// ifaceSetStaticIP configures the system to retain its current IP on the // ifaceSetStaticIP configures the system to retain its current IP on the
// interface through dhcpdc.conf. // interface through dhcpcd.conf.
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName) ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil { if ipNet.IP == nil {
return errors.Error("can't get IP address") return errors.Error("can't get IP address")
} }
gatewayIP := GatewayIP(ifaceName) body, err := os.ReadFile(dhcpcdConf)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP, ipNet.IP)
body, err := os.ReadFile("/etc/dhcpcd.conf")
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
} }
gatewayIP := GatewayIP(ifaceName)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP)
body = append(body, []byte(add)...) body = append(body, []byte(add)...)
err = maybe.WriteFile("/etc/dhcpcd.conf", body, 0o644) err = maybe.WriteFile(dhcpcdConf, body, 0o644)
if err != nil { if err != nil {
return fmt.Errorf("writing conf: %w", err) return fmt.Errorf("writing conf: %w", err)
} }
@ -157,22 +164,24 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that // dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
// configure the interface to have a static IP. // configure the interface to have a static IP.
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gatewayIP, dnsIP net.IP) (conf string) { func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
var body []byte b := &strings.Builder{}
stringutil.WriteToBuilder(
add := fmt.Sprintf( b,
"\n# %[1]s added by AdGuard Home.\ninterface %[1]s\nstatic ip_address=%s\n", "\n# ",
ifaceName, ifaceName,
ipNet) " added by AdGuard Home.\ninterface ",
body = append(body, []byte(add)...) ifaceName,
"\nstatic ip_address=",
ipNet.String(),
"\n",
)
if gatewayIP != nil { if gwIP != nil {
add = fmt.Sprintf("static routers=%s\n", gatewayIP) stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
body = append(body, []byte(add)...)
} }
add = fmt.Sprintf("static domain_name_servers=%s\n\n", dnsIP) stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
body = append(body, []byte(add)...)
return string(body) return b.String()
} }

View file

@ -4,152 +4,124 @@
package aghnet package aghnet
import ( import (
"bytes" "io/fs"
"net"
"testing" "testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestDHCPCDStaticConfig(t *testing.T) { func TestHasStaticIP(t *testing.T) {
const iface interfaceName = `wlan0` const ifaceName = "wlan0"
const (
dhcpcd = "etc/dhcpcd.conf"
netifaces = "etc/network/interfaces"
)
testCases := []struct { testCases := []struct {
rootFsys fs.FS
name string name string
data []byte wantHas assert.BoolAssertionFunc
wantCont bool wantErrMsg string
}{{ }{{
name: "has_not", rootFsys: fstest.MapFS{
data: []byte(`#comment` + nl + dhcpcd: &fstest.MapFile{
Data: []byte(`#comment` + nl +
`# comment` + nl + `# comment` + nl +
`interface eth0` + nl + `interface eth0` + nl +
`static ip_address=192.168.0.1/24` + nl + `static ip_address=192.168.0.1/24` + nl +
`# interface ` + iface + nl + `# interface ` + ifaceName + nl +
`static ip_address=192.168.1.1/24` + nl + `static ip_address=192.168.1.1/24` + nl +
`# comment` + nl, `# comment` + nl,
), ),
wantCont: true, },
},
name: "dhcpcd_has_not",
wantHas: assert.False,
wantErrMsg: `no information about static ip`,
}, { }, {
name: "has", rootFsys: fstest.MapFS{
data: []byte(`#comment` + nl + dhcpcd: &fstest.MapFile{
Data: []byte(`#comment` + nl +
`# comment` + nl + `# comment` + nl +
`interface eth0` + nl + `interface ` + ifaceName + nl +
`static ip_address=192.168.0.1/24` + nl + `static ip_address=192.168.0.1/24` + nl +
`# interface ` + iface + nl + `# interface ` + ifaceName + nl +
`static ip_address=192.168.1.1/24` + nl + `static ip_address=192.168.1.1/24` + nl +
`# comment` + nl + `# comment` + nl,
`interface ` + iface + nl +
`# comment` + nl +
`static ip_address=192.168.2.1/24` + nl,
), ),
wantCont: false, },
}} },
name: "dhcpcd_has",
for _, tc := range testCases { wantHas: assert.True,
t.Run(tc.name, func(t *testing.T) { wantErrMsg: ``,
r := bytes.NewReader(tc.data) }, {
_, cont, err := iface.dhcpcdStaticConfig(r) rootFsys: fstest.MapFS{
require.NoError(t, err) netifaces: &fstest.MapFile{
Data: []byte(`allow-hotplug ` + ifaceName + nl +
assert.Equal(t, tc.wantCont, cont)
})
}
}
func TestIfacesStaticConfig(t *testing.T) {
const iface interfaceName = `enp0s3`
testCases := []struct {
name string
data []byte
wantCont bool
wantPatterns []string
}{{
name: "has_not",
data: []byte(`allow-hotplug ` + iface + nl +
`#iface enp0s3 inet static` + nl + `#iface enp0s3 inet static` + nl +
`# address 192.168.0.200` + nl + `# address 192.168.0.200` + nl +
`# netmask 255.255.255.0` + nl + `# netmask 255.255.255.0` + nl +
`# gateway 192.168.0.1` + nl + `# gateway 192.168.0.1` + nl +
`iface ` + iface + ` inet dhcp` + nl, `iface ` + ifaceName + ` inet dhcp` + nl,
), ),
wantCont: true, },
wantPatterns: []string{}, },
name: "netifaces_has_not",
wantHas: assert.False,
wantErrMsg: `no information about static ip`,
}, { }, {
name: "has", rootFsys: fstest.MapFS{
data: []byte(`allow-hotplug ` + iface + nl + netifaces: &fstest.MapFile{
`iface ` + iface + ` inet static` + nl + Data: []byte(`allow-hotplug ` + ifaceName + nl +
`iface ` + ifaceName + ` inet static` + nl +
` address 192.168.0.200` + nl + ` address 192.168.0.200` + nl +
` netmask 255.255.255.0` + nl + ` netmask 255.255.255.0` + nl +
` gateway 192.168.0.1` + nl + ` gateway 192.168.0.1` + nl +
`#iface ` + iface + ` inet dhcp` + nl, `#iface ` + ifaceName + ` inet dhcp` + nl,
), ),
wantCont: false, },
wantPatterns: []string{}, },
name: "netifaces_has",
wantHas: assert.True,
wantErrMsg: ``,
}, { }, {
name: "return_patterns", rootFsys: fstest.MapFS{
data: []byte(`source hello` + nl + netifaces: &fstest.MapFile{
`source world` + nl + Data: []byte(`source hello` + nl +
`#iface ` + iface + ` inet static` + nl, `#iface ` + ifaceName + ` inet static` + nl,
), ),
wantCont: true, },
wantPatterns: []string{"hello", "world"}, "hello": &fstest.MapFile{
Data: []byte(`iface ` + ifaceName + ` inet static` + nl),
},
},
name: "netifaces_another_file",
wantHas: assert.True,
wantErrMsg: ``,
}, { }, {
// This one tests if the first found valid interface prevents rootFsys: fstest.MapFS{
// checking files under the `source` directive. netifaces: &fstest.MapFile{
name: "ignore_patterns", Data: []byte(`source hello` + nl +
data: []byte(`source hello` + nl + `iface ` + ifaceName + ` inet static` + nl,
`source world` + nl +
`iface ` + iface + ` inet static` + nl,
), ),
wantCont: false, },
wantPatterns: []string{}, },
name: "netifaces_ignore_another",
wantHas: assert.True,
wantErrMsg: ``,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
r := bytes.NewReader(tc.data)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
patterns, has, err := iface.ifacesStaticConfig(r) substRootDirFS(t, tc.rootFsys)
require.NoError(t, err)
assert.Equal(t, tc.wantCont, has) has, err := IfaceHasStaticIP(ifaceName)
assert.ElementsMatch(t, tc.wantPatterns, patterns) testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
} tc.wantHas(t, has)
}
func TestSetStaticIPdhcpcdConf(t *testing.T) {
testCases := []struct {
name string
dhcpcdConf string
routers net.IP
}{{
name: "with_gateway",
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
`interface wlan0` + nl +
`static ip_address=192.168.0.2/24` + nl +
`static routers=192.168.0.1` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl,
routers: net.IP{192, 168, 0, 1},
}, {
name: "without_gateway",
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
`interface wlan0` + nl +
`static ip_address=192.168.0.2/24` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl,
routers: nil,
}}
ipNet := &net.IPNet{
IP: net.IP{192, 168, 0, 2},
Mask: net.IPMask{255, 255, 255, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := dhcpcdConfIface("wlan0", ipNet, tc.routers, net.IP{192, 168, 0, 2})
assert.Equal(t, tc.dhcpcdConf, s)
}) })
} }
} }

View file

@ -16,7 +16,7 @@ import (
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
filename := fmt.Sprintf("etc/hostname.%s", ifaceName) filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename) return aghos.FileWalker(hostnameIfStaticConfig).Walk(rootDirFS, filename)
} }
// hostnameIfStaticConfig checks if the interface is configured by // hostnameIfStaticConfig checks if the interface is configured by

View file

@ -4,49 +4,69 @@
package aghnet package aghnet
import ( import (
"strings" "fmt"
"io/fs"
"testing" "testing"
"testing/fstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestHostnameIfStaticConfig(t *testing.T) { func TestIfaceHasStaticIP(t *testing.T) {
const nl = "\n" const ifaceName = "em0"
confFile := fmt.Sprintf("etc/hostname.%s", ifaceName)
testCases := []struct { testCases := []struct {
name string name string
rcconfData string rootFsys fs.FS
wantHas bool wantHas assert.BoolAssertionFunc
}{{ }{{
name: "simple", name: "simple",
rcconfData: `inet 127.0.0.253` + nl, rootFsys: fstest.MapFS{
wantHas: true, confFile: &fstest.MapFile{
Data: []byte(`inet 127.0.0.253` + nl),
},
},
wantHas: assert.True,
}, { }, {
name: "case_sensitiveness", name: "case_sensitiveness",
rcconfData: `InEt 127.0.0.253` + nl, rootFsys: fstest.MapFS{
wantHas: false, confFile: &fstest.MapFile{
Data: []byte(`InEt 127.0.0.253` + nl),
},
},
wantHas: assert.False,
}, { }, {
name: "comments_and_trash", name: "comments_and_trash",
rcconfData: `# comment 1` + nl + rootFsys: fstest.MapFS{
`` + nl + confFile: &fstest.MapFile{
Data: []byte(`# comment 1` + nl + nl +
`# inet 127.0.0.253` + nl + `# inet 127.0.0.253` + nl +
`inet` + nl, `inet` + nl,
wantHas: false, ),
},
},
wantHas: assert.False,
}, { }, {
name: "incorrect_config", name: "incorrect_config",
rcconfData: `inet6 127.0.0.253` + nl + rootFsys: fstest.MapFS{
`inet 256.256.256.256` + nl, confFile: &fstest.MapFile{
wantHas: false, Data: []byte(`inet6 127.0.0.253` + nl + `inet 256.256.256.256` + nl),
},
},
wantHas: assert.False,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, has, err := hostnameIfStaticConfig(r) substRootDirFS(t, tc.rootFsys)
has, err := IfaceHasStaticIP(ifaceName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.wantHas, has) tc.wantHas(t, has)
}) })
} }
} }

View file

@ -1,12 +1,17 @@
package aghnet package aghnet
import ( import (
"bytes"
"encoding/json"
"fmt"
"io/fs" "io/fs"
"net" "net"
"os" "os"
"strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -20,6 +25,113 @@ func TestMain(m *testing.M) {
// testdata is the filesystem containing data for testing the package. // testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata") var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
name string
shell mapShell
want net.IP
}{{
name: "success_v4",
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: net.IP{1, 2, 3, 4}.To16(),
}, {
name: "success_v6",
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: net.IP{
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0xFF, 0xFF,
},
}, {
name: "bad_output",
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: nil,
}, {
name: "err_runcmd",
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: nil,
}, {
name: "bad_code",
shell: theOnlyCmd(cmd, 1, "", nil),
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestGetInterfaceByIP(t *testing.T) { func TestGetInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err) require.NoError(t, err)
@ -130,3 +242,107 @@ func TestCheckPort(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
iface := &NetInterface{
Addresses: []net.IP{ip4, ip6},
Subnets: []*net.IPNet{{
IP: ip4.Mask(mask4),
Mask: mask4,
}, {
IP: ip6.Mask(mask6),
Mask: mask6,
}},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View file

@ -1,11 +1,5 @@
package aghnet package aghnet
import (
"time"
"github.com/AdguardTeam/golibs/log"
)
// DefaultRefreshIvl is the default period of time between refreshing cached // DefaultRefreshIvl is the default period of time between refreshing cached
// addresses. // addresses.
// const DefaultRefreshIvl = 5 * time.Minute // const DefaultRefreshIvl = 5 * time.Minute
@ -16,39 +10,21 @@ type HostGenFunc func() (host string)
// SystemResolvers helps to work with local resolvers' addresses provided by OS. // SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface { type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses. It should be // Get returns the slice of local resolvers' addresses. It must be safe for
// safe for concurrent use. // concurrent use.
Get() (rs []string) Get() (rs []string)
// refresh refreshes the local resolvers' addresses cache. It should be // refresh refreshes the local resolvers' addresses cache. It must be safe
// safe for concurrent use. // for concurrent use.
refresh() (err error) refresh() (err error)
} }
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer log.OnPanic("systemResolvers")
// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {
err := sr.refresh()
if err != nil {
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
continue
}
log.Debug("systemResolvers: local addresses cache is refreshed")
}
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate // NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If // defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used. // nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers( func NewSystemResolvers(
refreshIvl time.Duration,
hostGenFunc HostGenFunc, hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) { ) (sr SystemResolvers, err error) {
sr = newSystemResolvers(refreshIvl, hostGenFunc) sr = newSystemResolvers(hostGenFunc)
// Fill cache. // Fill cache.
err = sr.refresh() err = sr.refresh()
@ -56,11 +32,5 @@ func NewSystemResolvers(
return nil, err return nil, err
} }
if refreshIvl > 0 {
ticker := time.NewTicker(refreshIvl)
go refreshWithTicker(sr, ticker.C)
}
return sr, nil return sr, nil
} }

View file

@ -24,12 +24,15 @@ func defaultHostGen() (host string) {
// systemResolvers is a default implementation of SystemResolvers interface. // systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct { type systemResolvers struct {
resolver *net.Resolver // addrsLock protects addrs.
hostGenFunc HostGenFunc addrsLock sync.RWMutex
// addrs is the set that contains cached local resolvers' addresses. // addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set addrs *stringutil.Set
addrsLock sync.RWMutex
// resolver is used to fetch the resolvers' addresses.
resolver *net.Resolver
// hostGenFunc generates hosts to resolve.
hostGenFunc HostGenFunc
} }
const ( const (
@ -44,6 +47,7 @@ const (
errUnexpectedHostFormat errors.Error = "unexpected host format" errUnexpectedHostFormat errors.Error = "unexpected host format"
) )
// refresh implements the SystemResolvers interface for *systemResolvers.
func (sr *systemResolvers) refresh() (err error) { func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) {
return err return err
} }
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) { func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil { if hostGenFunc == nil {
hostGenFunc = defaultHostGen hostGenFunc = defaultHostGen
} }
@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
func validateDialedHost(host string) (err error) { func validateDialedHost(host string) (err error) {
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }() defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
var ipStr string
parts := strings.Split(host, "%") parts := strings.Split(host, "%")
switch len(parts) { switch len(parts) {
case 1: case 1:
ipStr = host // host
case 2: case 2:
// Remove the zone and check the IP address part. // Remove the zone and check the IP address part.
ipStr = parts[0] host = parts[0]
default: default:
return errUnexpectedHostFormat return errUnexpectedHostFormat
} }
if net.ParseIP(ipStr) == nil { if _, err = netutil.ParseIP(host); err != nil {
return errBadAddrPassed return errBadAddrPassed
} }

View file

@ -6,37 +6,32 @@ package aghnet
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func createTestSystemResolversImp( func createTestSystemResolversImpl(
t *testing.T, t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc, hostGenFunc HostGenFunc,
) (imp *systemResolvers) { ) (imp *systemResolvers) {
t.Helper() t.Helper()
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc) sr := createTestSystemResolvers(t, hostGenFunc)
require.IsType(t, (*systemResolvers)(nil), sr)
var ok bool return sr.(*systemResolvers)
imp, ok = sr.(*systemResolvers)
require.True(t, ok)
return imp
} }
func TestSystemResolvers_Refresh(t *testing.T) { func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) { t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil) sr := createTestSystemResolvers(t, nil)
assert.NoError(t, sr.refresh()) assert.NoError(t, sr.refresh())
}) })
t.Run("unexpected_error", func(t *testing.T) { t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(0, func() string { _, err := NewSystemResolvers(func() string {
return "127.0.0.1::123" return "127.0.0.1::123"
}) })
assert.Error(t, err) assert.Error(t, err)
@ -44,7 +39,7 @@ func TestSystemResolvers_Refresh(t *testing.T) {
} }
func TestSystemResolvers_DialFunc(t *testing.T) { func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil) imp := createTestSystemResolversImpl(t, nil)
testCases := []struct { testCases := []struct {
want error want error
@ -52,7 +47,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
address string address string
}{{ }{{
want: errFakeDial, want: errFakeDial,
name: "valid", name: "valid_ipv4",
address: "127.0.0.1", address: "127.0.0.1",
}, { }, {
want: errFakeDial, want: errFakeDial,

View file

@ -2,7 +2,6 @@ package aghnet
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -10,13 +9,12 @@ import (
func createTestSystemResolvers( func createTestSystemResolvers(
t *testing.T, t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc, hostGenFunc HostGenFunc,
) (sr SystemResolvers) { ) (sr SystemResolvers) {
t.Helper() t.Helper()
var err error var err error
sr, err = NewSystemResolvers(refreshDur, hostGenFunc) sr, err = NewSystemResolvers(hostGenFunc)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, sr) require.NotNil(t, sr)
@ -24,8 +22,14 @@ func createTestSystemResolvers(
} }
func TestSystemResolvers_Get(t *testing.T) { func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil) sr := createTestSystemResolvers(t, nil)
assert.NotEmpty(t, sr.Get())
var rs []string
require.NotPanics(t, func() {
rs = sr.Get()
})
assert.NotEmpty(t, rs)
} }
// TODO(e.burkov): Write tests for refreshWithTicker. // TODO(e.burkov): Write tests for refreshWithTicker.

View file

@ -11,7 +11,6 @@ import (
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -27,7 +26,7 @@ type systemResolvers struct {
addrsLock sync.RWMutex addrsLock sync.RWMutex
} }
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) { func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
return &systemResolvers{} return &systemResolvers{}
} }

6
internal/aghnet/testdata/proc_net_arp vendored Normal file
View file

@ -0,0 +1,6 @@
IP address HW type Flags HW address Mask Device
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan

View file

@ -52,24 +52,27 @@ func HaveAdminRights() (bool, error) {
return haveAdminRights() return haveAdminRights()
} }
// MaxCmdOutputSize is the maximum length of performed shell command output. // MaxCmdOutputSize is the maximum length of performed shell command output in
const MaxCmdOutputSize = 2 * 1024 // bytes.
const MaxCmdOutputSize = 64 * 1024
// RunCommand runs shell command. // RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (int, string, error) { func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
cmd := exec.Command(command, arguments...) cmd := exec.Command(command, arguments...)
out, err := cmd.Output() out, err := cmd.Output()
if len(out) > MaxCmdOutputSize { if len(out) > MaxCmdOutputSize {
out = out[:MaxCmdOutputSize] out = out[:MaxCmdOutputSize]
} }
if errors.As(err, new(*exec.ExitError)) { if err != nil {
return cmd.ProcessState.ExitCode(), string(out), nil if eerr := new(exec.ExitError); errors.As(err, &eerr) {
} else if err != nil { return eerr.ExitCode(), eerr.Stderr, nil
return 1, "", fmt.Errorf("exec.Command(%s) failed: %w: %s", command, err, string(out))
} }
return cmd.ProcessState.ExitCode(), string(out), nil return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out)
}
return cmd.ProcessState.ExitCode(), out, nil
} }
// PIDByCommand searches for process named command and returns its PID ignoring // PIDByCommand searches for process named command and returns its PID ignoring
@ -172,3 +175,13 @@ func RootDirFS() (fsys fs.FS) {
// behavior is undocumented but it currently works. // behavior is undocumented but it currently works.
return os.DirFS("") return os.DirFS("")
} }
// NotifyShutdownSignal notifies c on receiving shutdown signals.
func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c)
}
// IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig)
}

27
internal/aghos/os_unix.go Normal file
View file

@ -0,0 +1,27 @@
//go:build darwin || freebsd || linux || openbsd
// +build darwin freebsd linux openbsd
package aghos
import (
"os"
"os/signal"
"golang.org/x/sys/unix"
)
func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
unix.SIGINT,
unix.SIGQUIT,
unix.SIGTERM:
return true
default:
return false
}
}

View file

@ -4,6 +4,10 @@
package aghos package aghos
import ( import (
"os"
"os/signal"
"syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@ -35,3 +39,20 @@ func haveAdminRights() (bool, error) {
func isOpenWrt() (ok bool) { func isOpenWrt() (ok bool) {
return false return false
} }
func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows.
signal.Notify(c, os.Interrupt)
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
os.Interrupt,
syscall.SIGTERM:
return true
default:
return false
}
}

View file

@ -135,7 +135,6 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
pctx.Res = s.genNXDomain(pctx.Req) pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
return resultCodeSuccess return resultCodeSuccess

View file

@ -173,7 +173,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Enable the refresher after the actual implementation // TODO(e.burkov): Enable the refresher after the actual implementation
// passes the public testing. // passes the public testing.
s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil) s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("initializing system resolvers: %w", err) return nil, fmt.Errorf("initializing system resolvers: %w", err)
} }

View file

@ -83,7 +83,7 @@ func TestRecursionDetector_Suspect(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
msg dns.Msg msg dns.Msg
want bool want int
}{{ }{{
name: "simple", name: "simple",
msg: dns.Msg{ msg: dns.Msg{
@ -95,24 +95,18 @@ func TestRecursionDetector_Suspect(t *testing.T) {
Qtype: dns.TypeA, Qtype: dns.TypeA,
}}, }},
}, },
want: true, want: 1,
}, { }, {
name: "unencumbered", name: "unencumbered",
msg: dns.Msg{}, msg: dns.Msg{},
want: false, want: 0,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear) t.Cleanup(rd.clear)
rd.add(tc.msg) rd.add(tc.msg)
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
if tc.want {
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
} else {
assert.Zero(t, rd.recentRequests.Stats().Count)
}
}) })
} }
} }

View file

@ -518,44 +518,31 @@ func StartMods() error {
func checkPermissions() { func checkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions") log.Info("Checking if AdGuard Home has necessary permissions")
if runtime.GOOS == "windows" { if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
// On Windows we need to have admin rights to run properly
admin, _ := aghos.HaveAdminRights()
if admin {
return
}
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.") log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
} }
// We should check if AdGuard Home is able to bind to port 53 // We should check if AdGuard Home is able to bind to port 53
ok, err := aghnet.CanBindPort(53) err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
if err != nil {
if ok {
log.Info("AdGuard Home can bind to port 53")
return
}
if errors.Is(err, os.ErrPermission) { if errors.Is(err, os.ErrPermission) {
msg := `Permission check failed. log.Fatal(`Permission check failed.
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53). AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
Please note, that this is crucial for a server to be able to use privileged ports. Please note, that this is crucial for a server to be able to use privileged ports.
You have two options: You have two options:
1. Run AdGuard Home with root privileges 1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability: 2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser` https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`)
log.Fatal(msg)
} }
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v log.Info(
"AdGuard failed to bind to port 53: %s\n\n"+
"Please note, that this is crucial for a DNS server to be able to use that port.",
err,
)
}
Please note, that this is crucial for a DNS server to be able to use that port.`, err) log.Info("AdGuard Home can bind to port 53")
log.Info(msg)
} }
// Write PID to a file // Write PID to a file

View file

@ -16,18 +16,17 @@ type RDNS struct {
exchanger dnsforward.RDNSExchanger exchanger dnsforward.RDNSExchanger
clients *clientsContainer clients *clientsContainer
// usePrivate is used to store the state of current private RDNS // usePrivate is used to store the state of current private RDNS resolving
// resolving settings and to react to it's changes. // settings and to react to it's changes.
usePrivate uint32 usePrivate uint32
// ipCh used to pass client's IP to rDNS workerLoop. // ipCh used to pass client's IP to rDNS workerLoop.
ipCh chan net.IP ipCh chan net.IP
// ipCache caches the IP addresses to be resolved by rDNS. The resolved // ipCache caches the IP addresses to be resolved by rDNS. The resolved
// address stays here while it's inside clients. After leaving clients // address stays here while it's inside clients. After leaving clients the
// the address will be resolved once again. If the address couldn't be // address will be resolved once again. If the address couldn't be
// resolved, cache prevents further attempts to resolve it for some // resolved, cache prevents further attempts to resolve it for some time.
// time.
ipCache cache.Cache ipCache cache.Cache
} }

View file

@ -314,12 +314,13 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) {
// TODO(e.burkov): It's possible that os.ErrNotExist is caused by // TODO(e.burkov): It's possible that os.ErrNotExist is caused by
// something different than the service script's non-existence. Keep it // something different than the service script's non-existence. Keep it
// in mind, when replace the aghos.RunCommand. // in mind, when replace the aghos.RunCommand.
_, out, err = aghos.RunCommand(scriptPath, cmd) var outData []byte
_, outData, err = aghos.RunCommand(scriptPath, cmd)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return "", service.ErrNotInstalled return "", service.ErrNotInstalled
} }
return out, err return string(outData), err
} }
// Status implements service.Service interface for *openbsdRunComService. // Status implements service.Service interface for *openbsdRunComService.