diff --git a/internal/dhcpd/config.go b/internal/dhcpd/config.go index f7919fdf..0721ecd1 100644 --- a/internal/dhcpd/config.go +++ b/internal/dhcpd/config.go @@ -23,9 +23,12 @@ type ServerConfig struct { Enabled bool `yaml:"enabled"` InterfaceName string `yaml:"interface_name"` - // LocalDomainName is the domain name used for DHCP hosts. For example, - // a DHCP client with the hostname "myhost" can be addressed as "myhost.lan" + // LocalDomainName is the domain name used for DHCP hosts. For example, a + // DHCP client with the hostname "myhost" can be addressed as "myhost.lan" // when LocalDomainName is "lan". + // + // TODO(e.burkov): Probably, remove this field. See the TODO on + // [Interface.Enabled]. LocalDomainName string `yaml:"local_domain_name"` Conf4 V4ServerConf `yaml:"dhcpv4"` @@ -58,6 +61,14 @@ type DHCPServer interface { // there is one. FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) + // HostByIP returns a hostname by the IP address of its lease, if there is + // one. + HostByIP(ip netip.Addr) (host string) + + // IPByHost returns an IP address by the hostname of its lease, if there is + // one. + IPByHost(host string) (ip netip.Addr) + // WriteDiskConfig4 - copy disk configuration WriteDiskConfig4(c *V4ServerConf) // WriteDiskConfig6 - copy disk configuration diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 3945c038..e4ee14a9 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -9,6 +9,7 @@ import ( "path/filepath" "time" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/timeutil" "golang.org/x/exp/slices" @@ -31,6 +32,8 @@ const ( // Lease contains the necessary information about a DHCP lease. It's used as is // in the database, so don't change it until it's absolutely necessary, see // [dataVersion]. +// +// TODO(e.burkov): Unexport it and use [dhcpsvc.Lease]. type Lease struct { // Expiry is the expiration time of the lease. Expiry time.Time `json:"expires"` @@ -153,53 +156,37 @@ const ( type Interface interface { Start() (err error) Stop() (err error) + + // Enabled returns true if the DHCP server is running. + // + // TODO(e.burkov): Currently, we need this method to determine whether the + // local domain suffix should be considered while resolving A/AAAA requests. + // This is because other parts of the code aren't aware of the DNS suffixes + // in DHCP clients names and caller is responsible for trimming it. This + // behavior should be changed in the future. Enabled() (ok bool) - Leases(flags GetLeasesFlags) (leases []*Lease) - SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) - FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) + // Leases returns all the leases in the database. + Leases() (leases []*dhcpsvc.Lease) + + // MacByIP returns the MAC address of a client with ip. It returns nil if + // there is no such client, due to an assumption that a DHCP client must + // always have a HardwareAddr. + MACByIP(ip netip.Addr) (mac net.HardwareAddr) + + // HostByIP returns the hostname of the DHCP client with the given IP + // address. The address will be netip.Addr{} if there is no such client, + // due to an assumption that a DHCP client must always have an IP address. + HostByIP(ip netip.Addr) (host string) + + // IPByHost returns the IP address of the DHCP client with the given + // hostname. The address will be netip.Addr{} if there is no such client, + // due to an assumption that a DHCP client must always have an IP address. + IPByHost(host string) (ip netip.Addr) WriteDiskConfig(c *ServerConfig) } -// MockInterface is a mock Interface implementation. -// -// TODO(e.burkov): Move to aghtest when the API stabilized. -type MockInterface struct { - OnStart func() (err error) - OnStop func() (err error) - OnEnabled func() (ok bool) - OnLeases func(flags GetLeasesFlags) (leases []*Lease) - OnSetOnLeaseChanged func(f OnLeaseChangedT) - OnFindMACbyIP func(ip netip.Addr) (mac net.HardwareAddr) - OnWriteDiskConfig func(c *ServerConfig) -} - -var _ Interface = (*MockInterface)(nil) - -// Start implements the Interface for *MockInterface. -func (s *MockInterface) Start() (err error) { return s.OnStart() } - -// Stop implements the Interface for *MockInterface. -func (s *MockInterface) Stop() (err error) { return s.OnStop() } - -// Enabled implements the Interface for *MockInterface. -func (s *MockInterface) Enabled() (ok bool) { return s.OnEnabled() } - -// Leases implements the Interface for *MockInterface. -func (s *MockInterface) Leases(flags GetLeasesFlags) (ls []*Lease) { return s.OnLeases(flags) } - -// SetOnLeaseChanged implements the Interface for *MockInterface. -func (s *MockInterface) SetOnLeaseChanged(f OnLeaseChangedT) { s.OnSetOnLeaseChanged(f) } - -// FindMACbyIP implements the [Interface] for *MockInterface. -func (s *MockInterface) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) { - return s.OnFindMACbyIP(ip) -} - -// WriteDiskConfig implements the Interface for *MockInterface. -func (s *MockInterface) WriteDiskConfig(c *ServerConfig) { s.OnWriteDiskConfig(c) } - // server is the DHCP service that handles DHCPv4, DHCPv6, and HTTP API. type server struct { srv4 DHCPServer @@ -269,7 +256,8 @@ func Create(conf *ServerConfig) (s *server, err error) { } // setServers updates DHCPv4 and DHCPv6 servers created from the provided -// configuration conf. +// configuration conf. It returns the status of both the DHCPv4 and the DHCPv6 +// servers, which is always false for corresponding server on any error. func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err error) { v4conf := conf.Conf4 v4conf.InterfaceName = s.conf.InterfaceName @@ -279,7 +267,7 @@ func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err s.srv4, err = v4Create(&v4conf) if err != nil { if v4conf.Enabled { - return true, false, fmt.Errorf("creating dhcpv4 srv: %w", err) + return false, false, fmt.Errorf("creating dhcpv4 srv: %w", err) } log.Debug("dhcpd: warning: creating dhcpv4 srv: %s", err) @@ -288,14 +276,11 @@ func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err v6conf := conf.Conf6 v6conf.InterfaceName = s.conf.InterfaceName v6conf.notify = s.onNotify - v6conf.Enabled = s.conf.Enabled - if len(v6conf.RangeStart) == 0 { - v6conf.Enabled = false - } + v6conf.Enabled = s.conf.Enabled && len(v6conf.RangeStart) != 0 s.srv6, err = v6Create(v6conf) if err != nil { - return v4conf.Enabled, v6conf.Enabled, fmt.Errorf("creating dhcpv6 srv: %w", err) + return v4conf.Enabled, false, fmt.Errorf("creating dhcpv6 srv: %w", err) } return v4conf.Enabled, v6conf.Enabled, nil @@ -337,11 +322,6 @@ func (s *server) onNotify(flags uint32) { s.notify(int(flags)) } -// SetOnLeaseChanged - set callback -func (s *server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) { - s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged) -} - func (s *server) notify(flags int) { for _, f := range s.onLeaseChanged { f(flags) @@ -388,15 +368,26 @@ func (s *server) Stop() (err error) { return nil } -// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for -// concurrent use. -func (s *server) Leases(flags GetLeasesFlags) (leases []*Lease) { - return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...) +// Leases returns the list of active DHCP leases. +func (s *server) Leases() (leases []*dhcpsvc.Lease) { + ls := append(s.srv4.GetLeases(LeasesAll), s.srv6.GetLeases(LeasesAll)...) + leases = make([]*dhcpsvc.Lease, len(ls)) + for i, l := range ls { + leases[i] = &dhcpsvc.Lease{ + Expiry: l.Expiry, + Hostname: l.Hostname, + HWAddr: l.HWAddr, + IP: l.IP, + IsStatic: l.IsStatic, + } + } + + return leases } -// FindMACbyIP returns a MAC address by the IP address of its lease, if there is +// MACByIP returns a MAC address by the IP address of its lease, if there is // one. -func (s *server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) { +func (s *server) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { if ip.Is4() { return s.srv4.FindMACbyIP(ip) } @@ -404,6 +395,24 @@ func (s *server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) { return s.srv6.FindMACbyIP(ip) } +// HostByIP implements the [Interface] interface for *server. +// +// TODO(e.burkov): Implement this method for DHCPv6. +func (s *server) HostByIP(ip netip.Addr) (host string) { + if ip.Is4() { + return s.srv4.HostByIP(ip) + } + + return "" +} + +// IPByHost implements the [Interface] interface for *server. +// +// TODO(e.burkov): Implement this method for DHCPv6. +func (s *server) IPByHost(host string) (ip netip.Addr) { + return s.srv4.IPByHost(host) +} + // AddStaticLease - add static v4 lease func (s *server) AddStaticLease(l *Lease) error { return s.srv4.AddStaticLease(l) diff --git a/internal/dhcpd/http_unix.go b/internal/dhcpd/http_unix.go index 2ba693cc..ecc0c9e3 100644 --- a/internal/dhcpd/http_unix.go +++ b/internal/dhcpd/http_unix.go @@ -14,9 +14,11 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "golang.org/x/exp/slices" ) type v4ServerConfJSON struct { @@ -75,7 +77,7 @@ type leaseStatic struct { } // leasesToStatic converts list of leases to their JSON form. -func leasesToStatic(leases []*Lease) (static []*leaseStatic) { +func leasesToStatic(leases []*dhcpsvc.Lease) (static []*leaseStatic) { static = make([]*leaseStatic, len(leases)) for i, l := range leases { @@ -113,7 +115,7 @@ type leaseDynamic struct { } // leasesToDynamic converts list of leases to their JSON form. -func leasesToDynamic(leases []*Lease) (dynamic []*leaseDynamic) { +func leasesToDynamic(leases []*dhcpsvc.Lease) (dynamic []*leaseDynamic) { dynamic = make([]*leaseDynamic, len(leases)) for i, l := range leases { @@ -143,8 +145,27 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) { s.srv4.WriteDiskConfig4(&status.V4) s.srv6.WriteDiskConfig6(&status.V6) - status.Leases = leasesToDynamic(s.Leases(LeasesDynamic)) - status.StaticLeases = leasesToStatic(s.Leases(LeasesStatic)) + leases := s.Leases() + slices.SortFunc(leases, func(a, b *dhcpsvc.Lease) (res int) { + if a.IsStatic == b.IsStatic { + return 0 + } else if a.IsStatic { + return -1 + } else { + return 1 + } + }) + + dynamicIdx := slices.IndexFunc(leases, func(l *dhcpsvc.Lease) (ok bool) { + return !l.IsStatic + }) + + if dynamicIdx == -1 { + dynamicIdx = len(leases) + } + + status.Leases = leasesToDynamic(leases[dynamicIdx:]) + status.StaticLeases = leasesToStatic(leases[:dynamicIdx]) aghhttp.WriteJSONResponseOK(w, r, status) } diff --git a/internal/dhcpd/v46_windows.go b/internal/dhcpd/v46_windows.go index dcdb3caf..dbe22055 100644 --- a/internal/dhcpd/v46_windows.go +++ b/internal/dhcpd/v46_windows.go @@ -24,6 +24,8 @@ func (winServer) WriteDiskConfig4(_ *V4ServerConf) {} func (winServer) WriteDiskConfig6(_ *V6ServerConf) {} func (winServer) Start() (err error) { return nil } func (winServer) Stop() (err error) { return nil } +func (winServer) HostByIP(_ netip.Addr) (host string) { return "" } +func (winServer) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} } func v4Create(_ *V4ServerConf) (s DHCPServer, err error) { return winServer{}, nil } func v6Create(_ V6ServerConf) (s DHCPServer, err error) { return winServer{}, nil } diff --git a/internal/dhcpd/v4_unix.go b/internal/dhcpd/v4_unix.go index 6a973b24..270a6072 100644 --- a/internal/dhcpd/v4_unix.go +++ b/internal/dhcpd/v4_unix.go @@ -15,7 +15,6 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/go-ping/ping" "github.com/insomniacslk/dhcp/dhcpv4" @@ -46,11 +45,14 @@ type v4Server struct { // leased. leasedOffsets *bitSet - // leaseHosts is the set of all hostnames of all known DHCP clients. - leaseHosts *stringutil.Set - // leases contains all dynamic and static leases. leases []*Lease + + // hostsIndex is the set of all hostnames of all known DHCP clients. + hostsIndex map[string]*Lease + + // ipIndex is an index of leases by their IP addresses. + ipIndex map[netip.Addr]*Lease } func (s *v4Server) enabled() (ok bool) { @@ -114,6 +116,30 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip netip.Addr) (ho return hostname } +// HostByIP implements the [Interface] interface for *v4Server. +func (s *v4Server) HostByIP(ip netip.Addr) (host string) { + s.leasesLock.Lock() + defer s.leasesLock.Unlock() + + if l, ok := s.ipIndex[ip]; ok { + return l.Hostname + } + + return "" +} + +// IPByHost implements the [Interface] interface for *v4Server. +func (s *v4Server) IPByHost(host string) (ip netip.Addr) { + s.leasesLock.Lock() + defer s.leasesLock.Unlock() + + if l, ok := s.hostsIndex[host]; ok { + return l.IP + } + + return netip.Addr{} +} + // ResetLeases resets leases. func (s *v4Server) ResetLeases(leases []*Lease) (err error) { defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() @@ -123,7 +149,8 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) { } s.leasedOffsets = newBitSet() - s.leaseHosts = stringutil.NewSet() + s.hostsIndex = make(map[string]*Lease, len(leases)) + s.ipIndex = make(map[netip.Addr]*Lease, len(leases)) s.leases = nil for _, l := range leases { @@ -199,20 +226,18 @@ func (s *v4Server) GetLeases(flags GetLeasesFlags) (leases []*Lease) { // FindMACbyIP implements the [Interface] for *v4Server. func (s *v4Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) { + if !ip.Is4() { + return nil + } + now := time.Now() s.leasesLock.Lock() defer s.leasesLock.Unlock() - if !ip.Is4() { - return nil - } - - for _, l := range s.leases { - if l.IP == ip { - if l.IsStatic || l.Expiry.After(now) { - return l.HWAddr - } + if l, ok := s.ipIndex[ip]; ok { + if l.IsStatic || l.Expiry.After(now) { + return l.HWAddr } } @@ -249,7 +274,8 @@ func (s *v4Server) rmLeaseByIndex(i int) { s.leasedOffsets.set(offset, false) } - s.leaseHosts.Del(l.Hostname) + delete(s.hostsIndex, l.Hostname) + delete(s.ipIndex, l.IP) log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr) } @@ -303,13 +329,15 @@ func (s *v4Server) addLease(l *Lease) (err error) { return fmt.Errorf("lease %s (%s) out of range, not adding", l.IP, l.HWAddr) } + // TODO(e.burkov): l must have a valid hostname here, investigate. if l.Hostname != "" { - if s.leaseHosts.Has(l.Hostname) { + if _, ok := s.hostsIndex[l.Hostname]; ok { return ErrDupHostname } - s.leaseHosts.Add(l.Hostname) + s.hostsIndex[l.Hostname] = l } + s.ipIndex[l.IP] = l s.leases = append(s.leases, l) s.leasedOffsets.set(offset, true) @@ -574,7 +602,7 @@ func (s *v4Server) commitLease(l *Lease, hostname string) { prev := l.Hostname hostname = s.validHostnameForClient(hostname, l.IP) - if s.leaseHosts.Has(hostname) { + if _, ok := s.hostsIndex[hostname]; ok { log.Info("dhcpv4: hostname %q already exists", hostname) if prev == "" { @@ -590,11 +618,12 @@ func (s *v4Server) commitLease(l *Lease, hostname string) { l.Expiry = time.Now().Add(s.conf.leaseTime) if prev != "" && prev != l.Hostname { - s.leaseHosts.Del(prev) + delete(s.hostsIndex, prev) } if l.Hostname != "" { - s.leaseHosts.Add(l.Hostname) + s.hostsIndex[l.Hostname] = l } + s.ipIndex[l.IP] = l } // allocateLease allocates a new lease for the MAC address. If there are no IP @@ -1292,7 +1321,8 @@ func (s *v4Server) Stop() (err error) { // Create DHCPv4 server func v4Create(conf *V4ServerConf) (srv *v4Server, err error) { s := &v4Server{ - leaseHosts: stringutil.NewSet(), + hostsIndex: map[string]*Lease{}, + ipIndex: map[netip.Addr]*Lease{}, } err = conf.Validate() diff --git a/internal/dhcpd/v4_unix_test.go b/internal/dhcpd/v4_unix_test.go index 162b5b88..5f6cac1b 100644 --- a/internal/dhcpd/v4_unix_test.go +++ b/internal/dhcpd/v4_unix_test.go @@ -791,6 +791,14 @@ func TestV4Server_FindMACbyIP(t *testing.T) { IP: anotherIP, }}, } + s.ipIndex = map[netip.Addr]*Lease{ + staticIP: s.leases[0], + anotherIP: s.leases[1], + } + s.hostsIndex = map[string]*Lease{ + staticName: s.leases[0], + anotherName: s.leases[1], + } testCases := []struct { want net.HardwareAddr diff --git a/internal/dhcpd/v6_unix.go b/internal/dhcpd/v6_unix.go index fa3640f9..f08ea19e 100644 --- a/internal/dhcpd/v6_unix.go +++ b/internal/dhcpd/v6_unix.go @@ -26,15 +26,14 @@ const valueIAID = "ADGH" // value for IANA.ID // // TODO(a.garipov): Think about unifying this and v4Server. type v6Server struct { - srv *server6.Server - leasesLock sync.Mutex - leases []*Lease - ipAddrs [256]byte - sid dhcpv6.DUID - - ra raCtx // RA module - + ra raCtx conf V6ServerConf + sid dhcpv6.DUID + srv *server6.Server + + leases []*Lease + leasesLock sync.Mutex + ipAddrs [256]byte } // WriteDiskConfig4 - write configuration @@ -59,6 +58,34 @@ func ip6InRange(start, ip net.IP) bool { return start[15] <= ip[15] } +// HostByIP implements the [Interface] interface for *v6Server. +func (s *v6Server) HostByIP(ip netip.Addr) (host string) { + s.leasesLock.Lock() + defer s.leasesLock.Unlock() + + for _, l := range s.leases { + if l.IP == ip { + return l.Hostname + } + } + + return "" +} + +// IPByHost implements the [Interface] interface for *v6Server. +func (s *v6Server) IPByHost(host string) (ip netip.Addr) { + s.leasesLock.Lock() + defer s.leasesLock.Unlock() + + for _, l := range s.leases { + if l.Hostname == host { + return l.IP + } + } + + return netip.Addr{} +} + // ResetLeases resets leases. func (s *v6Server) ResetLeases(leases []*Lease) (err error) { defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() diff --git a/internal/dhcpsvc/dhcpsvc.go b/internal/dhcpsvc/dhcpsvc.go index 0c5d1bad..4b3f5c21 100644 --- a/internal/dhcpsvc/dhcpsvc.go +++ b/internal/dhcpsvc/dhcpsvc.go @@ -53,7 +53,7 @@ type Interface interface { // IPByHost returns the IP address of the DHCP client with the given // hostname. The hostname will be an empty string if there is no such // client, due to an assumption that a DHCP client must always have a - // hostname, either set by the client or assigned automatically. + // hostname, either set or generated. IPByHost(host string) (ip netip.Addr) // Leases returns all the DHCP leases. @@ -104,6 +104,9 @@ func (Empty) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil } // IPByHost implements the [Interface] interface for Empty. func (Empty) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} } +// type check +var _ Interface = Empty{} + // Leases implements the [Interface] interface for Empty. func (Empty) Leases() (leases []*Lease) { return nil } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index c1ff6751..f6f1ef1d 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -15,7 +15,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/rdns" @@ -48,19 +47,7 @@ var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"} var webRegistered bool -// hostToIPTable is a convenient type alias for tables of host names to an IP -// address. -// -// TODO(e.burkov): Use the [DHCP] interface instead. -type hostToIPTable = map[string]netip.Addr - -// ipToHostTable is a convenient type alias for tables of IP addresses to their -// host names. For example, for use with PTR queries. -// -// TODO(e.burkov): Use the [DHCP] interface instead. -type ipToHostTable = map[netip.Addr]string - -// DHCP is an interface for accessing DHCP lease data needed in this package. +// DHCP is an interface for accesing DHCP lease data needed in this package. type DHCP interface { // HostByIP returns the hostname of the DHCP client with the given IP // address. The address will be netip.Addr{} if there is no such client, @@ -89,18 +76,34 @@ type DHCP interface { // // The zero Server is empty and ready for use. type Server struct { - dnsProxy *proxy.Proxy // DNS proxy instance - dnsFilter *filtering.DNSFilter // DNS filter instance - dhcpServer dhcpd.Interface // DHCP server instance (optional) - queryLog querylog.QueryLog // Query log instance - stats stats.Interface - access *accessManager + // dnsProxy is the DNS proxy for forwarding client's DNS requests. + dnsProxy *proxy.Proxy + + // dnsFilter is the DNS filter for filtering client's DNS requests and + // responses. + dnsFilter *filtering.DNSFilter + + // dhcpServer is the DHCP server for accessing lease data. + dhcpServer DHCP + + // queryLog is the query log for client's DNS requests, responses and + // filtering results. + queryLog querylog.QueryLog + + // stats is the statistics collector for client's DNS usage data. + stats stats.Interface + + // access drops unallowed clients. + access *accessManager // localDomainSuffix is the suffix used to detect internal hosts. It // must be a valid domain name plus dots on each side. localDomainSuffix string - ipset ipsetCtx + // ipset processes DNS requests using ipset data. + ipset ipsetCtx + + // privateNets is the configured set of IP networks considered private. privateNets netutil.SubnetSet // addrProc, if not nil, is used to process clients' IP addresses with rDNS, @@ -112,7 +115,10 @@ type Server struct { // // TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy. localResolvers *proxy.Proxy - sysResolvers aghnet.SystemResolvers + + // sysResolvers used to fetch system resolvers to use by default for private + // PTR resolving. + sysResolvers aghnet.SystemResolvers // recDetector is a cache for recursive requests. It is used to detect // and prevent recursive requests only for private upstreams. @@ -128,12 +134,6 @@ type Server struct { // anonymizer masks the client's IP addresses if needed. anonymizer *aghnet.IPMut - tableHostToIP hostToIPTable - tableHostToIPLock sync.Mutex - - tableIPToHost ipToHostTable - tableIPToHostLock sync.Mutex - // clientIDCache is a temporary storage for ClientIDs that were extracted // during the BeforeRequestHandler stage. clientIDCache cache.Cache @@ -142,13 +142,16 @@ type Server struct { // We don't Start() it and so no listen port is required. internalProxy *proxy.Proxy + // isRunning is true if the DNS server is running. isRunning bool // protectionUpdateInProgress is used to make sure that only one goroutine // updating the protection configuration after a pause is running at a time. protectionUpdateInProgress atomic.Bool + // conf is the current configuration of the server. conf ServerConfig + // serverLock protects Server. serverLock sync.RWMutex } @@ -164,7 +167,7 @@ type DNSCreateParams struct { DNSFilter *filtering.DNSFilter Stats stats.Interface QueryLog querylog.QueryLog - DHCPServer dhcpd.Interface + DHCPServer DHCP PrivateNets netutil.SubnetSet Anonymizer *aghnet.IPMut LocalDomain string @@ -200,11 +203,12 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { p.Anonymizer = aghnet.NewIPMut(nil) } s = &Server{ - dnsFilter: p.DNSFilter, - stats: p.Stats, - queryLog: p.QueryLog, - privateNets: p.PrivateNets, - localDomainSuffix: localDomainSuffix, + dnsFilter: p.DNSFilter, + stats: p.Stats, + queryLog: p.QueryLog, + privateNets: p.PrivateNets, + // TODO(e.burkov): Use some case-insensitive string comparison. + localDomainSuffix: strings.ToLower(localDomainSuffix), recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), clientIDCache: cache.New(cache.Config{ EnableLRU: true, @@ -220,11 +224,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { return nil, fmt.Errorf("initializing system resolvers: %w", err) } - if p.DHCPServer != nil { - s.dhcpServer = p.DHCPServer - s.dhcpServer.SetOnLeaseChanged(s.onDHCPLeaseChanged) - s.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded) - } + s.dhcpServer = p.DHCPServer if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { // Use plain DNS on MIPS, encryption is too slow diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index fda64807..46d475b2 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -22,7 +22,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" @@ -94,8 +93,13 @@ func createTestServer( f.SetEnabled(true) + dhcp := &testDHCP{ + OnEnabled: func() (ok bool) { return false }, + OnHostByIP: func(ip netip.Addr) (host string) { return "" }, + OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") }, + } s, err = NewServer(DNSCreateParams{ - DHCPServer: testDHCP, + DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) @@ -863,8 +867,13 @@ func TestBlockedCustomIP(t *testing.T) { f, err := filtering.New(&filtering.Config{}, filters) require.NoError(t, err) + dhcp := &testDHCP{ + OnEnabled: func() (ok bool) { return false }, + OnHostByIP: func(_ netip.Addr) (host string) { panic("not implemented") }, + OnIPByHost: func(_ string) (ip netip.Addr) { panic("not implemented") }, + } s, err := NewServer(DNSCreateParams{ - DHCPServer: testDHCP, + DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) @@ -1016,8 +1025,13 @@ func TestRewrite(t *testing.T) { f.SetEnabled(true) + dhcp := &testDHCP{ + OnEnabled: func() (ok bool) { return false }, + OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") }, + OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") }, + } s, err := NewServer(DNSCreateParams{ - DHCPServer: testDHCP, + DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) @@ -1112,22 +1126,25 @@ func publicKey(priv any) any { } } -var testDHCP = &dhcpd.MockInterface{ - OnStart: func() (err error) { panic("not implemented") }, - OnStop: func() (err error) { panic("not implemented") }, - OnEnabled: func() (ok bool) { return true }, - OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) { - return []*dhcpd.Lease{{ - IP: netip.MustParseAddr("192.168.12.34"), - HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, - Hostname: "myhost", - }} - }, - OnSetOnLeaseChanged: func(olct dhcpd.OnLeaseChangedT) {}, - OnFindMACbyIP: func(ip netip.Addr) (mac net.HardwareAddr) { panic("not implemented") }, - OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") }, +// testDHCP is a mock implementation of the [DHCP] interface. +type testDHCP struct { + OnHostByIP func(ip netip.Addr) (host string) + OnIPByHost func(host string) (ip netip.Addr) + OnEnabled func() (ok bool) } +// type check +var _ DHCP = (*testDHCP)(nil) + +// HostByIP implements the [DHCP] interface for *testDHCP. +func (d *testDHCP) HostByIP(ip netip.Addr) (host string) { return d.OnHostByIP(ip) } + +// IPByHost implements the [DHCP] interface for *testDHCP. +func (d *testDHCP) IPByHost(host string) (ip netip.Addr) { return d.OnIPByHost(host) } + +// IsClientHost implements the [DHCP] interface for *testDHCP. +func (d *testDHCP) Enabled() (ok bool) { return d.OnEnabled() } + func TestPTRResponseFromDHCPLeases(t *testing.T) { const localDomain = "lan" @@ -1135,8 +1152,14 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { require.NoError(t, err) s, err := NewServer(DNSCreateParams{ - DNSFilter: flt, - DHCPServer: testDHCP, + DNSFilter: flt, + DHCPServer: &testDHCP{ + OnEnabled: func() (ok bool) { return true }, + OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") }, + OnHostByIP: func(ip netip.Addr) (host string) { + return "myhost" + }, + }, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), LocalDomain: localDomain, }) @@ -1185,6 +1208,12 @@ func TestPTRResponseFromHosts(t *testing.T) { `)}, } + dhcp := &testDHCP{ + OnEnabled: func() (ok bool) { return false }, + OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") }, + OnHostByIP: func(ip netip.Addr) (host string) { return "" }, + } + var eventsCalledCounter uint32 hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{ OnEvents: func() (e <-chan struct{}) { @@ -1213,7 +1242,7 @@ func TestPTRResponseFromHosts(t *testing.T) { var s *Server s, err = NewServer(DNSCreateParams{ - DHCPServer: testDHCP, + DHCPServer: dhcp, DNSFilter: flt, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 7229cef1..842401f6 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -47,7 +47,11 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { f.SetEnabled(true) s, err := NewServer(DNSCreateParams{ - DHCPServer: testDHCP, + DHCPServer: &testDHCP{ + OnEnabled: func() (ok bool) { return false }, + OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") }, + OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") }, + }, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) @@ -219,7 +223,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { f.SetEnabled(true) s, err := NewServer(DNSCreateParams{ - DHCPServer: testDHCP, + DHCPServer: &testDHCP{}, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 2abfbac4..3741a635 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" @@ -70,20 +69,26 @@ type dnsContext struct { // isLocalClient shows if client's IP address is from locally served // network. isLocalClient bool + + // isDHCPHost is true if the request for a local domain name and the DHCP is + // available for this request. + isDHCPHost bool } // resultCode is the result of a request processing function. type resultCode int const ( - // resultCodeSuccess is returned when a handler performed successfully, - // and the next handler must be called. + // resultCodeSuccess is returned when a handler performed successfully, and + // the next handler must be called. resultCodeSuccess resultCode = iota - // resultCodeFinish is returned when a handler performed successfully, - // and the processing of the request must be stopped. + + // resultCodeFinish is returned when a handler performed successfully, and + // the processing of the request must be stopped. resultCodeFinish - // resultCodeError is returned when a handler failed, and the processing - // of the request must be stopped. + + // resultCodeError is returned when a handler failed, and the processing of + // the request must be stopped. resultCodeError ) @@ -239,70 +244,6 @@ func (s *Server) processClientIP(addr net.Addr) { s.addrProc.Process(clientIP) } -func (s *Server) setTableHostToIP(t hostToIPTable) { - s.tableHostToIPLock.Lock() - defer s.tableHostToIPLock.Unlock() - - s.tableHostToIP = t -} - -func (s *Server) setTableIPToHost(t ipToHostTable) { - s.tableIPToHostLock.Lock() - defer s.tableIPToHostLock.Unlock() - - s.tableIPToHost = t -} - -func (s *Server) onDHCPLeaseChanged(flags int) { - switch flags { - case dhcpd.LeaseChangedAdded, - dhcpd.LeaseChangedAddedStatic, - dhcpd.LeaseChangedRemovedStatic: - // Go on. - case dhcpd.LeaseChangedRemovedAll: - s.setTableHostToIP(nil) - s.setTableIPToHost(nil) - - return - default: - return - } - - ll := s.dhcpServer.Leases(dhcpd.LeasesAll) - hostToIP := make(hostToIPTable, len(ll)) - ipToHost := make(ipToHostTable, len(ll)) - - for _, l := range ll { - // TODO(a.garipov): Remove this after we're finished with the client - // hostname validations in the DHCP server code. - err := netutil.ValidateHostname(l.Hostname) - if err != nil { - log.Debug("dnsforward: skipping invalid hostname %q from dhcp: %s", l.Hostname, err) - - continue - } - - lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix) - - // Assume that we only process IPv4 now. - if !l.IP.Is4() { - log.Debug("dnsforward: skipping invalid ip from dhcp: bad ipv4 net.IP %v", l.IP) - - continue - } - - leaseIP := l.IP - - ipToHost[leaseIP] = lowhost - hostToIP[lowhost] = leaseIP - } - - s.setTableHostToIP(hostToIP) - s.setTableIPToHost(ipToHost) - - log.Debug("dnsforward: added %d a and ptr entries from dhcp", len(ipToHost)) -} - // processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB // queries. The response contains different types of encryption supported by // current user configuration. @@ -420,18 +361,6 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { return rc } -// dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of -// address since the data inside the internal table may be changed while request -// processing. It's safe for concurrent use. -func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) { - s.tableHostToIPLock.Lock() - defer s.tableHostToIPLock.Unlock() - - ip, ok = s.tableHostToIP[host] - - return ip, ok -} - // processDHCPHosts respond to A requests if the target hostname is known to // the server. It responds with a mapped IP address if the DNS64 is enabled and // the request is for AAAA. @@ -443,30 +372,31 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { pctx := dctx.proxyCtx req := pctx.Req - q := req.Question[0] - reqHost, ok := s.isDHCPClientHostQ(q) - if !ok { + + q := &req.Question[0] + dhcpHost := s.dhcpHostFromRequest(q) + if dctx.isDHCPHost = dhcpHost != ""; !dctx.isDHCPHost { return resultCodeSuccess } if !dctx.isLocalClient { - log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, reqHost) + log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) pctx.Res = s.genNXDomain(req) // Do not even put into query log. return resultCodeFinish } - ip, ok := s.dhcpHostToIP(reqHost) - if !ok { + ip := s.dhcpServer.IPByHost(dhcpHost) + if ip == (netip.Addr{}) { // Go on and process them with filters, including dnsrewrite ones, and // possibly route them to a domain-specific upstream. - log.Debug("dnsforward: no dhcp record for %q", reqHost) + log.Debug("dnsforward: no dhcp record for %q", dhcpHost) return resultCodeSuccess } - log.Debug("dnsforward: dhcp record for %q is %s", reqHost, ip) + log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) resp := s.makeResponse(req) switch q.Qtype { @@ -638,17 +568,6 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } -// ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for -// concurrent use. -func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) { - s.tableIPToHostLock.Lock() - defer s.tableIPToHostLock.Unlock() - - host, ok = s.tableIPToHost[ip] - - return host, ok -} - // processDHCPAddrs responds to PTR requests if the target IP is leased by the // DHCP server. func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { @@ -673,12 +592,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - host, ok := s.ipToDHCPHost(ipAddr) - if !ok { + host := s.dhcpServer.HostByIP(ipAddr) + if host == "" { return resultCodeSuccess } - log.Debug("dnsforward: dhcp reverse record for %s is %q", ip, host) + log.Debug("dnsforward: dhcp client %s is %q", ip, host) req := pctx.Req resp := s.makeResponse(req) @@ -686,10 +605,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { Hdr: dns.RR_Header{ Name: req.Question[0].Name, Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, + // TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See + // https://github.com/AdguardTeam/AdGuardHome/issues/3932. + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, }, - Ptr: dns.Fqdn(host), + Ptr: dns.Fqdn(strings.Join([]string{host, s.localDomainSuffix}, ".")), } resp.Answer = append(resp.Answer, ptr) pctx.Res = resp @@ -788,17 +709,18 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) { pctx := dctx.proxyCtx req := pctx.Req - q := req.Question[0] + if pctx.Res != nil { // The response has already been set. return resultCodeSuccess - } else if reqHost, ok := s.isDHCPClientHostQ(q); ok { + } else if dctx.isDHCPHost { // A DHCP client hostname query that hasn't been handled or filtered. // Respond with an NXDOMAIN. // // TODO(a.garipov): Route such queries to a custom upstream for the // local domain name if there is one. - log.Debug("dnsforward: dhcp client hostname %q was not filtered", reqHost) + name := req.Question[0].Name + log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) pctx.Res = s.genNXDomain(req) return resultCodeFinish @@ -885,26 +807,26 @@ func (s *Server) setRespAD(pctx *proxy.DNSContext, reqWantsDNSSEC bool) { } } -// isDHCPClientHostQ returns true if q is from a request for a DHCP client -// hostname. If ok is true, reqHost contains the requested hostname. -func (s *Server) isDHCPClientHostQ(q dns.Question) (reqHost string, ok bool) { +// dhcpHostFromRequest returns a hostname from question, if the request is for a +// DHCP client's hostname when DHCP is enabled, and an empty string otherwise. +func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) { if !s.dhcpServer.Enabled() { - return "", false + return "" } // Include AAAA here, because despite the fact that we don't support it yet, // the expected behavior here is to respond with an empty answer and not // NXDOMAIN. if qt := q.Qtype; qt != dns.TypeA && qt != dns.TypeAAAA { - return "", false + return "" } reqHost = strings.ToLower(q.Name[:len(q.Name)-1]) - if strings.HasSuffix(reqHost, s.localDomainSuffix) { - return reqHost, true + if !netutil.IsImmediateSubdomain(reqHost, s.localDomainSuffix) { + return "" } - return "", false + return reqHost[:len(reqHost)-len(s.localDomainSuffix)-1] } // setCustomUpstream sets custom upstream settings in pctx, if necessary. diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index e5903439..014035c3 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -416,47 +416,58 @@ func TestServer_ProcessDetermineLocal(t *testing.T) { } func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { + const ( + localDomainSuffix = "lan" + dhcpClient = "example" + + knownHost = dhcpClient + "." + localDomainSuffix + unknownHost = "wronghost." + localDomainSuffix + ) + knownIP := netip.MustParseAddr("1.2.3.4") + dhcp := &testDHCP{ + OnEnabled: func() (_ bool) { return true }, + OnIPByHost: func(host string) (ip netip.Addr) { + if host == dhcpClient { + ip = knownIP + } + + return ip + }, + } + testCases := []struct { wantIP netip.Addr name string host string - wantRes resultCode isLocalCli bool }{{ wantIP: knownIP, name: "local_client_success", - host: "example.lan", - wantRes: resultCodeSuccess, + host: knownHost, isLocalCli: true, }, { wantIP: netip.Addr{}, name: "local_client_unknown_host", - host: "wronghost.lan", - wantRes: resultCodeSuccess, + host: unknownHost, isLocalCli: true, }, { wantIP: netip.Addr{}, name: "external_client_known_host", - host: "example.lan", - wantRes: resultCodeFinish, + host: knownHost, isLocalCli: false, }, { wantIP: netip.Addr{}, name: "external_client_unknown_host", - host: "wronghost.lan", - wantRes: resultCodeFinish, + host: unknownHost, isLocalCli: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := &Server{ - dhcpServer: testDHCP, - localDomainSuffix: defaultLocalDomainSuffix, - tableHostToIP: hostToIPTable{ - "example." + defaultLocalDomainSuffix: knownIP, - }, + dhcpServer: dhcp, + localDomainSuffix: localDomainSuffix, } req := &dns.Msg{ @@ -478,43 +489,52 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { } res := s.processDHCPHosts(dctx) - require.Equal(t, tc.wantRes, res) + pctx := dctx.proxyCtx - if tc.wantRes == resultCodeFinish { + if !tc.isLocalCli { + require.Equal(t, resultCodeFinish, res) require.NotNil(t, pctx.Res) assert.Equal(t, dns.RcodeNameError, pctx.Res.Rcode) - assert.Len(t, pctx.Res.Answer, 0) + assert.Empty(t, pctx.Res.Answer) return } + require.Equal(t, resultCodeSuccess, res) + if tc.wantIP == (netip.Addr{}) { assert.Nil(t, pctx.Res) - } else { - require.NotNil(t, pctx.Res) - ans := pctx.Res.Answer - require.Len(t, ans, 1) - - a := testutil.RequireTypeAssert[*dns.A](t, ans[0]) - - ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4) - require.NoError(t, err) - - assert.Equal(t, tc.wantIP, ip) + return } + + require.NotNil(t, pctx.Res) + + ans := pctx.Res.Answer + require.Len(t, ans, 1) + + a := testutil.RequireTypeAssert[*dns.A](t, ans[0]) + + ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4) + require.NoError(t, err) + + assert.Equal(t, tc.wantIP, ip) }) } } func TestServer_ProcessDHCPHosts(t *testing.T) { const ( - examplecom = "example.com" - examplelan = "example." + defaultLocalDomainSuffix + localTLD = "lan" + + knownClient = "example" + externalHost = knownClient + ".com" + clientHost = knownClient + "." + localTLD ) knownIP := netip.MustParseAddr("1.2.3.4") + testCases := []struct { wantIP netip.Addr name string @@ -524,55 +544,64 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { qtyp uint16 }{{ wantIP: netip.Addr{}, - name: "success_external", - host: examplecom, - suffix: defaultLocalDomainSuffix, + name: "external", + host: externalHost, + suffix: localTLD, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { wantIP: netip.Addr{}, - name: "success_external_non_a", - host: examplecom, - suffix: defaultLocalDomainSuffix, + name: "external_non_a", + host: externalHost, + suffix: localTLD, wantRes: resultCodeSuccess, qtyp: dns.TypeCNAME, }, { wantIP: knownIP, - name: "success_internal", - host: examplelan, - suffix: defaultLocalDomainSuffix, + name: "internal", + host: clientHost, + suffix: localTLD, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { wantIP: netip.Addr{}, - name: "success_internal_unknown", + name: "internal_unknown", host: "example-new.lan", - suffix: defaultLocalDomainSuffix, + suffix: localTLD, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { wantIP: netip.Addr{}, - name: "success_internal_aaaa", - host: examplelan, - suffix: defaultLocalDomainSuffix, + name: "internal_aaaa", + host: clientHost, + suffix: localTLD, wantRes: resultCodeSuccess, qtyp: dns.TypeAAAA, }, { wantIP: knownIP, - name: "success_custom_suffix", - host: "example.custom", + name: "custom_suffix", + host: knownClient + ".custom", suffix: "custom", wantRes: resultCodeSuccess, qtyp: dns.TypeA, }} for _, tc := range testCases { + testDHCP := &testDHCP{ + OnEnabled: func() (_ bool) { return true }, + OnIPByHost: func(host string) (ip netip.Addr) { + if host == knownClient { + ip = knownIP + } + + return ip + }, + OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") }, + } + s := &Server{ dhcpServer: testDHCP, localDomainSuffix: tc.suffix, - tableHostToIP: hostToIPTable{ - "example." + tc.suffix: knownIP, - }, } req := &dns.Msg{ @@ -597,13 +626,6 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { res := s.processDHCPHosts(dctx) pctx := dctx.proxyCtx assert.Equal(t, tc.wantRes, res) - if tc.wantRes == resultCodeFinish { - require.NotNil(t, pctx.Res) - assert.Equal(t, dns.RcodeNameError, pctx.Res.Rcode) - - return - } - require.NoError(t, dctx.err) if tc.qtyp == dns.TypeAAAA { diff --git a/internal/home/clients.go b/internal/home/clients.go index 06aec5c6..6a0ab9fd 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -12,7 +12,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -35,7 +34,7 @@ type DHCP interface { // HostByIP returns the hostname of the DHCP client with the given IP // address. The address will be netip.Addr{} if there is no such client, - // due to an assumption that a DHCP client must always have an IP address. + // due to an assumption that a DHCP client must always have a hostname. HostByIP(ip netip.Addr) (host string) // MACByIP returns the MAC address for the given IP address leased. It @@ -56,8 +55,8 @@ type clientsContainer struct { allTags *stringutil.Set - // dhcpServer is used for looking up clients IP addresses by MAC addresses - dhcpServer dhcpd.Interface + // dhcp is the DHCP service implementation. + dhcp DHCP // dnsServer is used for checking clients IP status access list status dnsServer *dnsforward.Server @@ -94,7 +93,7 @@ type clientsContainer struct { // Note: this function must be called only once func (clients *clientsContainer) Init( objects []*clientObject, - dhcpServer dhcpd.Interface, + dhcpServer DHCP, etcHosts *aghnet.HostsContainer, arpDB arpdb.Interface, filteringConf *filtering.Config, @@ -109,7 +108,9 @@ func (clients *clientsContainer) Init( clients.allTags = stringutil.NewSet(clientTags...) - clients.dhcpServer = dhcpServer + // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. + clients.dhcp = dhcpServer + clients.etcHosts = etcHosts clients.arpDB = arpDB err = clients.addFromConfig(objects, filteringConf) @@ -125,11 +126,6 @@ func (clients *clientsContainer) Init( return nil } - if clients.dhcpServer != nil { - clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) - clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded) - } - if clients.etcHosts != nil { go clients.handleHostsUpdates() } @@ -310,40 +306,6 @@ func (clients *clientsContainer) periodicUpdate() { } } -// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of -// runtime clients using the DHCP server's leases. -// -// TODO(e.burkov): Remove when switched to dhcpsvc. -func (clients *clientsContainer) onDHCPLeaseChanged(flags int) { - if clients.dhcpServer == nil || !config.Clients.Sources.DHCP { - return - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - clients.rmHostsBySrc(ClientSourceDHCP) - - if flags == dhcpd.LeaseChangedRemovedAll { - return - } - - leases := clients.dhcpServer.Leases(dhcpd.LeasesAll) - n := 0 - for _, l := range leases { - if l.Hostname == "" { - continue - } - - ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP) - if ok { - n++ - } - } - - log.Debug("clients: added %d client aliases from dhcp", n) -} - // clientSource checks if client with this IP address already exists and returns // the source which updated it last. It returns [ClientSourceNone] if the // client doesn't exist. @@ -358,10 +320,14 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource) rc, ok := clients.ipToRC[ip] if ok { - return rc.Source + src = rc.Source } - return ClientSourceNone + if src < ClientSourceDHCP && clients.dhcp.HostByIP(ip) != "" { + src = ClientSourceDHCP + } + + return src } // findMultiple is a wrapper around Find to make it a valid client finder for @@ -522,17 +488,14 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) { } } - if clients.dhcpServer != nil { - return clients.findDHCP(ip) - } - - return nil, false + // TODO(e.burkov): Iterate through clients.list only once. + return clients.findDHCP(ip) } // findDHCP searches for a client by its MAC, if the DHCP server is active and // there is such client. clients.lock is expected to be locked. func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) { - foundMAC := clients.dhcpServer.FindMACbyIP(ip) + foundMAC := clients.dhcp.MACByIP(ip) if foundMAC == nil { return nil, false } @@ -553,8 +516,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) { return nil, false } -// findRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) { +// runtimeClient returns a runtime client from internal index. Note that it +// doesn't include DHCP clients. +func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) { if ip == (netip.Addr{}) { return nil, false } @@ -567,6 +531,24 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeCl return rc, ok } +// findRuntimeClient finds a runtime client by their IP. +func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) { + if rc, ok = clients.runtimeClient(ip); ok && rc.Source > ClientSourceDHCP { + return rc, ok + } + + host := clients.dhcp.HostByIP(ip) + if host == "" { + return rc, ok + } + + return &RuntimeClient{ + Host: host, + Source: ClientSourceDHCP, + WHOIS: &whois.Info{}, + }, true +} + // check validates the client. func (clients *clientsContainer) check(c *Client) (err error) { switch { @@ -824,10 +806,15 @@ func (clients *clientsContainer) addHostLocked( ) (ok bool) { rc, ok := clients.ipToRC[ip] if !ok { + if src < ClientSourceDHCP { + if clients.dhcp.HostByIP(ip) != "" { + return false + } + } + rc = &RuntimeClient{ WHOIS: &whois.Info{}, } - clients.ipToRC[ip] = rc } else if src < rc.Source { return false diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index d3ff2a57..5c92a896 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -8,21 +8,44 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type testDHCP struct { + OnLeases func() (leases []*dhcpsvc.Lease) + OnHostBy func(ip netip.Addr) (host string) + OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) +} + +// Lease implements the [DHCP] interface for testDHCP. +func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } + +// HostByIP implements the [DHCP] interface for testDHCP. +func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } + +// MACByIP implements the [DHCP] interface for testDHCP. +func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } + // newClientsContainer is a helper that creates a new clients container for // tests. func newClientsContainer(t *testing.T) (c *clientsContainer) { + t.Helper() + c = &clientsContainer{ testing: true, } - err := c.Init(nil, nil, nil, nil, &filtering.Config{}) - require.NoError(t, err) + dhcp := &testDHCP{ + OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, + OnHostBy: func(ip netip.Addr) (host string) { return "" }, + OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil }, + } + + require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{})) return c } @@ -288,7 +311,7 @@ func TestClientsAddExisting(t *testing.T) { dhcpServer, err := dhcpd.Create(config) require.NoError(t, err) - clients.dhcpServer = dhcpServer + clients.dhcp = dhcpServer err = dhcpServer.AddStaticLease(&dhcpd.Lease{ HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 240ef4de..6b49881c 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -123,6 +123,17 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http data.RuntimeClients = append(data.RuntimeClients, cj) } + for _, l := range clients.dhcp.Leases() { + cj := runtimeClientJSON{ + Name: l.Hostname, + Source: ClientSourceDHCP, + IP: l.IP, + WHOIS: &whois.Info{}, + } + + data.RuntimeClients = append(data.RuntimeClients, cj) + } + data.Tags = clientTags aghhttp.WriteJSONResponseOK(w, r, data) diff --git a/internal/home/dns.go b/internal/home/dns.go index bc485e36..cf1a1198 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -14,7 +14,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -123,7 +122,7 @@ func initDNSServer( filters *filtering.DNSFilter, sts stats.Interface, qlog querylog.QueryLog, - dhcpSrv dhcpd.Interface, + dhcpSrv dnsforward.DHCP, anonymizer *aghnet.IPMut, httpReg aghhttp.RegisterFunc, tlsConf *tlsConfigSettings, diff --git a/internal/home/home_test.go b/internal/home/home_test.go index 2ce1d76d..c56f3955 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -7,6 +7,6 @@ import ( ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) initCmdLineOpts() + testutil.DiscardLogOutput(m) }