From e10a3fa4b3c416e655af88d472c3184fce6c7b4b Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Wed, 24 Mar 2021 14:52:37 +0300 Subject: [PATCH] Pull request: home: fix dns address fallback Closes #2868. Squashed commit of the following: commit 7497b0d80233fa0f0fbdc94a85007d39566eea73 Author: Ainar Garipov Date: Wed Mar 24 14:23:41 2021 +0300 home: fix specified ip collecting commit 7b1dfa69f4edeb3e07cd1067f77ff8519bcdbe1c Author: Ainar Garipov Date: Wed Mar 24 14:01:25 2021 +0300 home: fix dns address fallback --- internal/home/control.go | 94 +++++++++++++++++++++++++++++++++++----- internal/home/dns.go | 43 +----------------- 2 files changed, 86 insertions(+), 51 deletions(-) diff --git a/internal/home/control.go b/internal/home/control.go index 57e8e731..3a9efa1b 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -34,15 +34,79 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface http.Error(w, text, code) } -// --------------- -// dns run control -// --------------- -func addDNSAddress(dnsAddresses *[]string, addr net.IP) { - hostport := addr.String() - if config.DNS.Port != 53 { - hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port)) +// appendDNSAddrs is a convenient helper for appending a formatted form of DNS +// addresses to a slice of strings. +func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) { + for _, addr := range addrs { + hostport := addr.String() + if config.DNS.Port != 53 { + hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port)) + } + + dst = append(dst, hostport) } - *dnsAddresses = append(*dnsAddresses, hostport) + + return dst +} + +// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to +// dst. It also adds the IP addresses of all network interfaces if src contains +// an unspecified IP addresss. +func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) { + ifacesAdded := false + for _, h := range src { + if !h.IsUnspecified() { + dst = appendDNSAddrs(dst, h) + + continue + } else if ifacesAdded { + continue + } + + // Add addresses of all network interfaces for addresses like + // "0.0.0.0" and "::". + var ifaces []*aghnet.NetInterface + ifaces, err = aghnet.GetValidNetInterfacesForWeb() + if err != nil { + return nil, fmt.Errorf("cannot get network interfaces: %w", err) + } + + for _, iface := range ifaces { + dst = appendDNSAddrs(dst, iface.Addresses...) + } + + ifacesAdded = true + } + + return dst, nil +} + +// collectDNSAddresses returns the list of DNS addresses the server is listening +// on, including the addresses on all interfaces in cases of unspecified IPs. +func collectDNSAddresses() (addrs []string, err error) { + if hosts := config.DNS.BindHosts; len(hosts) == 0 { + addrs = appendDNSAddrs(addrs, net.IP{127, 0, 0, 1}) + } else { + addrs, err = appendDNSAddrsWithIfaces(addrs, hosts) + if err != nil { + return nil, fmt.Errorf("collecting dns addresses: %w", err) + } + } + + de := getDNSEncryption() + if de.https != "" { + addrs = append(addrs, de.https) + } + + if de.tls != "" { + addrs = append(addrs, de.tls) + } + + if de.quic != "" { + addrs = append(addrs, de.quic) + } + + return addrs, nil } // statusResponse is a response for /control/status endpoint. @@ -60,8 +124,17 @@ type statusResponse struct { } func handleStatus(w http.ResponseWriter, _ *http.Request) { + dnsAddrs, err := collectDNSAddresses() + if err != nil { + // Don't add a lot of formatting, since the error is already + // wrapped by collectDNSAddresses. + httpError(w, http.StatusInternalServerError, "%s", err) + + return + } + resp := statusResponse{ - DNSAddrs: getDNSAddresses(), + DNSAddrs: dnsAddrs, DNSPort: config.DNS.Port, HTTPPort: config.BindPort, IsRunning: isRunning(), @@ -82,9 +155,10 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) { } w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(resp) + err = json.NewEncoder(w).Encode(resp) if err != nil { httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return } } diff --git a/internal/home/dns.go b/internal/home/dns.go index 9692dbc4..c3ef8c09 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -9,7 +9,6 @@ import ( "strconv" "github.com/AdguardTeam/AdGuardHome/internal/agherr" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -144,10 +143,8 @@ func ipsToUDPAddrs(ips []net.IP, port int) (udpAddrs []*net.UDPAddr) { func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { dnsConf := config.DNS hosts := dnsConf.BindHosts - for i, h := range hosts { - if h.IsUnspecified() { - hosts[i] = net.IP{127, 0, 0, 1} - } + if len(hosts) == 0 { + hosts = []net.IP{{127, 0, 0, 1}} } newConf = dnsforward.ServerConfig{ @@ -268,42 +265,6 @@ func getDNSEncryption() (de dnsEncryption) { return de } -// Get the list of DNS addresses the server is listening on -func getDNSAddresses() (dnsAddrs []string) { - if hosts := config.DNS.BindHosts; len(hosts) == 0 || hosts[0].IsUnspecified() { - ifaces, e := aghnet.GetValidNetInterfacesForWeb() - if e != nil { - log.Error("Couldn't get network interfaces: %v", e) - return []string{} - } - - for _, iface := range ifaces { - for _, addr := range iface.Addresses { - addDNSAddress(&dnsAddrs, addr) - } - } - } else { - for _, h := range hosts { - addDNSAddress(&dnsAddrs, h) - } - } - - de := getDNSEncryption() - if de.https != "" { - dnsAddrs = append(dnsAddrs, de.https) - } - - if de.tls != "" { - dnsAddrs = append(dnsAddrs, de.tls) - } - - if de.quic != "" { - dnsAddrs = append(dnsAddrs, de.quic) - } - - return dnsAddrs -} - // applyAdditionalFiltering adds additional client information and settings if // the client has them. func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {