From bfc7e16d84037fab30116fcc30931fb8c05980fb Mon Sep 17 00:00:00 2001
From: Ainar Garipov <a.garipov@adguard.com>
Date: Wed, 31 Mar 2021 12:36:57 +0300
Subject: [PATCH] Pull request: dhcpd: do not assume mac addrs of 6 bytes

Closes #2828.

Squashed commit of the following:

commit 26c6cf81c32469e1c4955aafb40490c29b4d1a99
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Mar 30 17:43:53 2021 +0300

    dhcpd: do not assume mac addrs of 6 bytes
---
 CHANGELOG.md                     |  2 ++
 internal/aghnet/addr.go          | 23 +++++++++++++
 internal/aghnet/addr_test.go     | 57 +++++++++++++++++++++++++++++++
 internal/dhcpd/routeradv.go      | 58 +++++++++++++++++++++++++++-----
 internal/dhcpd/routeradv_test.go | 31 +++++++++++------
 internal/dhcpd/v4.go             | 26 +++++++++-----
 internal/dhcpd/v6.go             | 55 +++++++++++++++++++-----------
 7 files changed, 204 insertions(+), 48 deletions(-)
 create mode 100644 internal/aghnet/addr.go
 create mode 100644 internal/aghnet/addr_test.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef25039f..5a3c1f64 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,7 @@ and this project adheres to
 
 ### Fixed
 
+- Assumption that MAC addresses always have the length of 6 octets ([#2828]).
 - Support for more than one `/24` subnet in DHCP ([#2541]).
 - Invalid filenames in the `mobileconfig` API responses ([#2835]).
 
@@ -47,6 +48,7 @@ and this project adheres to
 [#2498]: https://github.com/AdguardTeam/AdGuardHome/issues/2498
 [#2533]: https://github.com/AdguardTeam/AdGuardHome/issues/2533
 [#2541]: https://github.com/AdguardTeam/AdGuardHome/issues/2541
+[#2828]: https://github.com/AdguardTeam/AdGuardHome/issues/2828
 [#2835]: https://github.com/AdguardTeam/AdGuardHome/issues/2835
 [#2838]: https://github.com/AdguardTeam/AdGuardHome/issues/2838
 
diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go
new file mode 100644
index 00000000..559c9b46
--- /dev/null
+++ b/internal/aghnet/addr.go
@@ -0,0 +1,23 @@
+package aghnet
+
+import (
+	"fmt"
+	"net"
+
+	"github.com/AdguardTeam/AdGuardHome/internal/agherr"
+)
+
+// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
+// EUI-64, or 20-octet InfiniBand link-layer address.
+func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
+	defer agherr.Annotate("validating hardware address %q: %w", &err, hwa)
+
+	switch l := len(hwa); l {
+	case 0:
+		return agherr.Error("address is empty")
+	case 6, 8, 20:
+		return nil
+	default:
+		return fmt.Errorf("bad len: %d", l)
+	}
+}
diff --git a/internal/aghnet/addr_test.go b/internal/aghnet/addr_test.go
new file mode 100644
index 00000000..0b3eb48d
--- /dev/null
+++ b/internal/aghnet/addr_test.go
@@ -0,0 +1,57 @@
+package aghnet
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestValidateHardwareAddress(t *testing.T) {
+	testCases := []struct {
+		name       string
+		wantErrMsg string
+		in         net.HardwareAddr
+	}{{
+		name:       "success_eui_48",
+		wantErrMsg: "",
+		in:         net.HardwareAddr{0x00, 0x01, 0x02, 0x03, 0x04, 0x05},
+	}, {
+		name:       "success_eui_64",
+		wantErrMsg: "",
+		in:         net.HardwareAddr{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07},
+	}, {
+		name:       "success_infiniband",
+		wantErrMsg: "",
+		in: net.HardwareAddr{
+			0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+			0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+			0x10, 0x11, 0x12, 0x13,
+		},
+	}, {
+		name:       "error_nil",
+		wantErrMsg: `validating hardware address "": address is empty`,
+		in:         nil,
+	}, {
+		name:       "error_empty",
+		wantErrMsg: `validating hardware address "": address is empty`,
+		in:         net.HardwareAddr{},
+	}, {
+		name:       "error_bad",
+		wantErrMsg: `validating hardware address "00:01:02:03": bad len: 4`,
+		in:         net.HardwareAddr{0x00, 0x01, 0x02, 0x03},
+	}}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			err := ValidateHardwareAddress(tc.in)
+			if tc.wantErrMsg == "" {
+				assert.NoError(t, err)
+			} else {
+				require.Error(t, err)
+				assert.Equal(t, tc.wantErrMsg, err.Error())
+			}
+		})
+	}
+}
diff --git a/internal/dhcpd/routeradv.go b/internal/dhcpd/routeradv.go
index 01c119e6..52d64bce 100644
--- a/internal/dhcpd/routeradv.go
+++ b/internal/dhcpd/routeradv.go
@@ -7,6 +7,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
 	"github.com/AdguardTeam/golibs/log"
 	"golang.org/x/net/icmp"
 	"golang.org/x/net/ipv6"
@@ -36,6 +37,33 @@ type icmpv6RA struct {
 	mtu                         uint32
 }
 
+// hwAddrToLinkLayerAddr converts a hardware address into a form required by
+// RFC4861.  That is, a byte slice of length divisible by 8.
+//
+// See https://tools.ietf.org/html/rfc4861#section-4.6.1.
+func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
+	err = aghnet.ValidateHardwareAddress(hwa)
+	if err != nil {
+		// Don't wrap the error, because it already contains enough
+		// context.
+		return nil, err
+	}
+
+	if len(hwa) == 6 || len(hwa) == 8 {
+		lla = make([]byte, 8)
+		copy(lla, hwa)
+
+		return lla, nil
+	}
+
+	// Assume that aghnet.ValidateHardwareAddress prevents lengths other
+	// than 20 by now.
+	lla = make([]byte, 24)
+	copy(lla, hwa)
+
+	return lla, nil
+}
+
 // Create an ICMPv6.RouterAdvertisement packet with all necessary options.
 //
 // ICMPv6:
@@ -63,15 +91,23 @@ type icmpv6RA struct {
 //     Reserved[2]
 //     MTU[4]
 //   Option=Source link-layer address(1):
-//     Link-Layer Address[6]
+//     Link-Layer Address[8/24]
 //   Option=Recursive DNS Server(25):
 //     Type[1]
 //     Length * 8bytes[1]
 //     Reserved[2]
 //     Lifetime[4]
 //     Addresses of IPv6 Recursive DNS Servers[16]
-func createICMPv6RAPacket(params icmpv6RA) []byte {
-	data := make([]byte, 88)
+func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
+	var lla []byte
+	lla, err = hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)
+	if err != nil {
+		return nil, fmt.Errorf("converting source link layer address: %w", err)
+	}
+
+	// TODO(a.garipov): Don't use a magic constant here.  Refactor the code
+	// and make all constants named instead of all those comments..
+	data = make([]byte, 82+len(lla))
 	i := 0
 
 	// ICMPv6:
@@ -138,8 +174,9 @@ func createICMPv6RAPacket(params icmpv6RA) []byte {
 	data[i] = 1   // Type
 	data[i+1] = 1 // Length
 	i += 2
-	copy(data[i:], params.sourceLinkLayerAddress) // Link-Layer Address[6]
-	i += 6
+
+	copy(data[i:], lla) // Link-Layer Address[8/24]
+	i += len(lla)
 
 	// Option=Recursive DNS Server:
 
@@ -152,11 +189,11 @@ func createICMPv6RAPacket(params icmpv6RA) []byte {
 	i += 4
 	copy(data[i:], params.recursiveDNSServer) // Addresses of IPv6 Recursive DNS Servers[16]
 
-	return data
+	return data, nil
 }
 
 // Init - initialize RA module
-func (ra *raCtx) Init() error {
+func (ra *raCtx) Init() (err error) {
 	ra.stop.Store(0)
 	ra.conn = nil
 	if !(ra.raAllowSLAAC || ra.raSLAACOnly) {
@@ -177,9 +214,12 @@ func (ra *raCtx) Init() error {
 	params.prefix = make([]byte, 16)
 	copy(params.prefix, ra.prefixIPAddr[:8]) // /64
 
-	data := createICMPv6RAPacket(params)
+	var data []byte
+	data, err = createICMPv6RAPacket(params)
+	if err != nil {
+		return fmt.Errorf("creating packet: %w", err)
+	}
 
-	var err error
 	success := false
 	ipAndScope := ra.ipAddr.String() + "%" + ra.ifaceName
 	ra.conn, err = icmp.ListenPacket("ip6:ipv6-icmp", ipAndScope)
diff --git a/internal/dhcpd/routeradv_test.go b/internal/dhcpd/routeradv_test.go
index 4a0f4c5b..94f7afd3 100644
--- a/internal/dhcpd/routeradv_test.go
+++ b/internal/dhcpd/routeradv_test.go
@@ -7,8 +7,23 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-func TestRA(t *testing.T) {
-	data := createICMPv6RAPacket(icmpv6RA{
+func TestCreateICMPv6RAPacket(t *testing.T) {
+	wantData := []byte{
+		0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10,
+		0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
+		0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc,
+		0x01, 0x01, 0x0a, 0x00, 0x27, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x19, 0x03, 0x00, 0x00, 0x00, 0x00,
+		0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x08, 0x00, 0x27, 0xff, 0xfe, 0x00,
+		0x00, 0x00,
+	}
+
+	gotData, err := createICMPv6RAPacket(icmpv6RA{
 		managedAddressConfiguration: false,
 		otherConfiguration:          true,
 		mtu:                         1500,
@@ -17,13 +32,7 @@ func TestRA(t *testing.T) {
 		recursiveDNSServer:          net.ParseIP("fe80::800:27ff:fe00:0"),
 		sourceLinkLayerAddress:      []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
 	})
-	dataCorrect := []byte{
-		0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
-		0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc, 0x01, 0x01, 0x0a, 0x00, 0x27, 0x00, 0x00, 0x00,
-		0x19, 0x03, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x08, 0x00, 0x27, 0xff, 0xfe, 0x00, 0x00, 0x00,
-	}
-	assert.Equal(t, dataCorrect, data)
+
+	assert.NoError(t, err)
+	assert.Equal(t, wantData, gotData)
 }
diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go
index 2457606d..a4ff66fb 100644
--- a/internal/dhcpd/v4.go
+++ b/internal/dhcpd/v4.go
@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/AdguardTeam/AdGuardHome/internal/agherr"
+	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/go-ping/ping"
 	"github.com/insomniacslk/dhcp/dhcpv4"
@@ -289,8 +290,9 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) {
 		return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
 	}
 
-	if len(l.HWAddr) != 6 {
-		return fmt.Errorf("invalid mac %q, only EUI-48 is supported", l.HWAddr)
+	err = aghnet.ValidateHardwareAddress(l.HWAddr)
+	if err != nil {
+		return fmt.Errorf("validating lease: %w", err)
 	}
 
 	l.Expiry = time.Unix(leaseExpireStatic, 0)
@@ -330,17 +332,21 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) {
 	return nil
 }
 
-// RemoveStaticLease removes a static lease (thread-safe)
-func (s *v4Server) RemoveStaticLease(l Lease) error {
+// RemoveStaticLease removes a static lease.  It is safe for concurrent use.
+func (s *v4Server) RemoveStaticLease(l Lease) (err error) {
+	defer agherr.Annotate("dhcpv4: %w", &err)
+
 	if len(l.IP) != 4 {
 		return fmt.Errorf("invalid IP")
 	}
-	if len(l.HWAddr) != 6 {
-		return fmt.Errorf("invalid MAC")
+
+	err = aghnet.ValidateHardwareAddress(l.HWAddr)
+	if err != nil {
+		return fmt.Errorf("validating lease: %w", err)
 	}
 
 	s.leasesLock.Lock()
-	err := s.rmLease(l)
+	err = s.rmLease(l)
 	if err != nil {
 		s.leasesLock.Unlock()
 
@@ -688,8 +694,10 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
 		return
 	}
 
-	if len(req.ClientHWAddr) != 6 {
-		log.Debug("dhcpv4: Invalid ClientHWAddr")
+	err = aghnet.ValidateHardwareAddress(req.ClientHWAddr)
+	if err != nil {
+		log.Error("dhcpv4: invalid ClientHWAddr: %s", err)
+
 		return
 	}
 
diff --git a/internal/dhcpd/v6.go b/internal/dhcpd/v6.go
index aff6d3e3..9b6d113b 100644
--- a/internal/dhcpd/v6.go
+++ b/internal/dhcpd/v6.go
@@ -9,6 +9,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/AdguardTeam/AdGuardHome/internal/agherr"
+	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/insomniacslk/dhcp/dhcpv6"
 	"github.com/insomniacslk/dhcp/dhcpv6/server6"
@@ -158,19 +160,23 @@ func (s *v6Server) rmDynamicLease(lease Lease) error {
 	return nil
 }
 
-// AddStaticLease - add a static lease
-func (s *v6Server) AddStaticLease(l Lease) error {
+// AddStaticLease adds a static lease.  It is safe for concurrent use.
+func (s *v6Server) AddStaticLease(l Lease) (err error) {
+	defer agherr.Annotate("dhcpv6: %w", &err)
+
 	if len(l.IP) != 16 {
 		return fmt.Errorf("invalid IP")
 	}
-	if len(l.HWAddr) != 6 {
-		return fmt.Errorf("invalid MAC")
+
+	err = aghnet.ValidateHardwareAddress(l.HWAddr)
+	if err != nil {
+		return fmt.Errorf("validating lease: %w", err)
 	}
 
 	l.Expiry = time.Unix(leaseExpireStatic, 0)
 
 	s.leasesLock.Lock()
-	err := s.rmDynamicLease(l)
+	err = s.rmDynamicLease(l)
 	if err != nil {
 		s.leasesLock.Unlock()
 		return err
@@ -183,17 +189,21 @@ func (s *v6Server) AddStaticLease(l Lease) error {
 	return nil
 }
 
-// RemoveStaticLease - remove a static lease
-func (s *v6Server) RemoveStaticLease(l Lease) error {
+// RemoveStaticLease removes a static lease.  It is safe for concurrent use.
+func (s *v6Server) RemoveStaticLease(l Lease) (err error) {
+	defer agherr.Annotate("dhcpv6: %w", &err)
+
 	if len(l.IP) != 16 {
 		return fmt.Errorf("invalid IP")
 	}
-	if len(l.HWAddr) != 6 {
-		return fmt.Errorf("invalid MAC")
+
+	err = aghnet.ValidateHardwareAddress(l.HWAddr)
+	if err != nil {
+		return fmt.Errorf("validating lease: %w", err)
 	}
 
 	s.leasesLock.Lock()
-	err := s.rmLease(l)
+	err = s.rmLease(l)
 	if err != nil {
 		s.leasesLock.Unlock()
 		return err
@@ -271,8 +281,10 @@ func (s *v6Server) findFreeIP() net.IP {
 
 // Reserve lease for MAC
 func (s *v6Server) reserveLease(mac net.HardwareAddr) *Lease {
-	l := Lease{}
-	l.HWAddr = make([]byte, 6)
+	l := Lease{
+		HWAddr: make([]byte, len(mac)),
+	}
+
 	copy(l.HWAddr, mac)
 
 	s.leasesLock.Lock()
@@ -564,7 +576,9 @@ func (s *v6Server) initRA(iface *net.Interface) error {
 }
 
 // Start starts the IPv6 DHCP server.
-func (s *v6Server) Start() error {
+func (s *v6Server) Start() (err error) {
+	defer agherr.Annotate("dhcpv6: %w", &err)
+
 	if !s.conf.Enabled {
 		return nil
 	}
@@ -572,14 +586,14 @@ func (s *v6Server) Start() error {
 	ifaceName := s.conf.InterfaceName
 	iface, err := net.InterfaceByName(ifaceName)
 	if err != nil {
-		return fmt.Errorf("dhcpv6: finding interface %s by name: %w", ifaceName, err)
+		return fmt.Errorf("finding interface %s by name: %w", ifaceName, err)
 	}
 
 	log.Debug("dhcpv6: starting...")
 
 	dnsIPAddrs, err := ifaceDNSIPAddrs(iface, ipVersion6, defaultMaxAttempts, defaultBackoff)
 	if err != nil {
-		return fmt.Errorf("dhcpv6: interface %s: %w", ifaceName, err)
+		return fmt.Errorf("interface %s: %w", ifaceName, err)
 	}
 
 	if len(dnsIPAddrs) == 0 {
@@ -596,15 +610,18 @@ func (s *v6Server) Start() error {
 
 	// don't initialize DHCPv6 server if we must force the clients to use SLAAC
 	if s.conf.RASLAACOnly {
-		log.Debug("DHCPv6: not starting DHCPv6 server due to ra_slaac_only=true")
+		log.Debug("not starting dhcpv6 server due to ra_slaac_only=true")
+
 		return nil
 	}
 
 	log.Debug("dhcpv6: listening...")
 
-	if len(iface.HardwareAddr) != 6 {
-		return fmt.Errorf("dhcpv6: invalid MAC %s", iface.HardwareAddr)
+	err = aghnet.ValidateHardwareAddress(iface.HardwareAddr)
+	if err != nil {
+		return fmt.Errorf("validating interface %s: %w", iface.Name, err)
 	}
+
 	s.sid = dhcpv6.Duid{
 		Type:          dhcpv6.DUID_LLT,
 		HwType:        iana.HWTypeEthernet,
@@ -623,7 +640,7 @@ func (s *v6Server) Start() error {
 
 	go func() {
 		err = s.srv.Serve()
-		log.Debug("DHCPv6: srv.Serve: %s", err)
+		log.Error("dhcpv6: srv.Serve: %s", err)
 	}()
 
 	return nil