From 298f74ba814770306501bc1a594f18d5ba59d1ff Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Fri, 6 Nov 2020 17:34:40 +0300 Subject: [PATCH] Pull request: * all: allow multiple hosts in reverse lookups Merge in DNS/adguard-home from 2269-multiple-hosts to master For #2269. Squashed commit of the following: commit f8ae452540b106f2d5b130b8edb08c4e76b003f4 Merge: 8dd06f7cc 3e1f92225 Author: Ainar Garipov Date: Fri Nov 6 17:28:12 2020 +0300 Merge branch 'master' into 2269-multiple-hosts commit 8dd06f7cca27ec32a4690e2673603b166f82af0a Author: Ainar Garipov Date: Thu Nov 5 20:28:33 2020 +0300 * all: allow multiple hosts in reverse lookups --- internal/dnsfilter/dnsfilter.go | 43 +++++++++++++------ internal/dnsforward/filter.go | 28 ++++++------ internal/home/clients.go | 19 ++++---- internal/util/auto_hosts.go | 74 ++++++++++++++++++++------------ internal/util/auto_hosts_test.go | 17 +++++--- 5 files changed, 116 insertions(+), 65 deletions(-) diff --git a/internal/dnsfilter/dnsfilter.go b/internal/dnsfilter/dnsfilter.go index 6e733364..a4ac31b8 100644 --- a/internal/dnsfilter/dnsfilter.go +++ b/internal/dnsfilter/dnsfilter.go @@ -1,3 +1,4 @@ +// Package dnsfilter implements a DNS filter. package dnsfilter import ( @@ -281,7 +282,7 @@ type Result struct { CanonName string `json:",omitempty"` // CNAME value // for RewriteEtcHosts: - ReverseHost string `json:",omitempty"` + ReverseHosts []string `json:",omitempty"` // for ReasonRewrite & RewriteEtcHosts: IPList []net.IP `json:",omitempty"` // list of IP addresses @@ -325,18 +326,9 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering // Now check the hosts file -- do we have any rules for it? // just like DNS rewrites, it has higher priority than filtering rules. if d.Config.AutoHosts != nil { - ips := d.Config.AutoHosts.Process(host, qtype) - if ips != nil { - result.Reason = RewriteEtcHosts - result.IPList = ips - return result, nil - } - - revHost := d.Config.AutoHosts.ProcessReverse(host, qtype) - if len(revHost) != 0 { - result.Reason = RewriteEtcHosts - result.ReverseHost = revHost + "." - return result, nil + matched, err := d.checkAutoHosts(host, qtype, &result) + if matched { + return result, err } } @@ -401,6 +393,31 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering return Result{}, nil } +func (d *Dnsfilter) checkAutoHosts(host string, qtype uint16, result *Result) (matched bool, err error) { + ips := d.Config.AutoHosts.Process(host, qtype) + if ips != nil { + result.Reason = RewriteEtcHosts + result.IPList = ips + + return true, nil + } + + revHosts := d.Config.AutoHosts.ProcessReverse(host, qtype) + if len(revHosts) != 0 { + result.Reason = RewriteEtcHosts + + // TODO(a.garipov): Optimize this with a buffer. + result.ReverseHosts = make([]string, len(revHosts)) + for i := range revHosts { + result.ReverseHosts[i] = revHosts[i] + "." + } + + return true, nil + } + + return false, nil +} + // Process rewrites table // . Find CNAME for a domain name (exact match or by wildcard) // . if found and CNAME equals to domain name - this is an exception; exit diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index d2990360..11267adf 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -54,24 +54,28 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { } else if res.IsFiltered { log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rule) d.Res = s.genDNSFilterMessage(d, &res) - } else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 && len(res.IPList) == 0 { ctx.origQuestion = d.Req.Question[0] // resolve canonical name, not the original host name d.Req.Question[0].Name = dns.Fqdn(res.CanonName) - - } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 { - + } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHosts) != 0 { resp := s.makeResponse(req) - ptr := &dns.PTR{} - ptr.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, + for _, h := range res.ReverseHosts { + hdr := dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + + ptr := &dns.PTR{ + Hdr: hdr, + Ptr: h, + } + + resp.Answer = append(resp.Answer, ptr) } - ptr.Ptr = res.ReverseHost - resp.Answer = append(resp.Answer, ptr) + d.Res = resp } else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts { resp := s.makeResponse(req) diff --git a/internal/home/clients.go b/internal/home/clients.go index 03a69db3..538bfe7e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -606,22 +606,25 @@ func (clients *clientsContainer) rmHosts(source clientSource) int { return n } -// Fill clients array from system hosts-file +// addFromHostsFile fills the clients hosts list from the system's hosts files. func (clients *clientsContainer) addFromHostsFile() { hosts := clients.autoHosts.List() clients.lock.Lock() defer clients.lock.Unlock() + _ = clients.rmHosts(ClientSourceHostsFile) n := 0 - for ip, name := range hosts { - ok, err := clients.addHost(ip, name, ClientSourceHostsFile) - if err != nil { - log.Debug("Clients: %s", err) - } - if ok { - n++ + for ip, names := range hosts { + for _, name := range names { + ok, err := clients.addHost(ip, name, ClientSourceHostsFile) + if err != nil { + log.Debug("Clients: %s", err) + } + if ok { + n++ + } } } diff --git a/internal/util/auto_hosts.go b/internal/util/auto_hosts.go index 4230eae0..ec2f3579 100644 --- a/internal/util/auto_hosts.go +++ b/internal/util/auto_hosts.go @@ -20,9 +20,14 @@ type onChangedT func() // AutoHosts - automatic DNS records type AutoHosts struct { - lock sync.Mutex // serialize access to table - table map[string][]net.IP // 'hostname -> IP' table - tableReverse map[string]string // "IP -> hostname" table for reverse lookup + // lock protects table and tableReverse. + lock sync.Mutex + // table is the host-to-IPs map. + table map[string][]net.IP + // tableReverse is the IP-to-hosts map. + // + // TODO(a.garipov): Make better use of newtypes. Perhaps a custom map. + tableReverse map[string][]string hostsFn string // path to the main hosts-file hostsDirs []string // paths to OS-specific directories with hosts-files @@ -127,40 +132,44 @@ func (a *AutoHosts) Process(host string, qtype uint16) []net.IP { return ipsCopy } -// ProcessReverse - process PTR request -// Return "" if not found or an error occurred -func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string { +// ProcessReverse processes a PTR request. It returns nil if nothing is found. +func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) (hosts []string) { if qtype != dns.TypePTR { - return "" + return nil } ipReal := DNSUnreverseAddr(addr) if ipReal == nil { - return "" // invalid IP in question + return nil } + ipStr := ipReal.String() a.lock.Lock() - host := a.tableReverse[ipStr] - a.lock.Unlock() + defer a.lock.Unlock() - if len(host) == 0 { - return "" // not found + hosts = a.tableReverse[ipStr] + + if len(hosts) == 0 { + return nil // not found } - log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host) - return host + log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, hosts) + + return hosts } -// List - get "IP -> hostname" table. Thread-safe. -func (a *AutoHosts) List() map[string]string { - table := make(map[string]string) +// List returns an IP-to-hostnames table. It is safe for concurrent use. +func (a *AutoHosts) List() (ipToHosts map[string][]string) { a.lock.Lock() + defer a.lock.Unlock() + + ipToHosts = make(map[string][]string, len(a.tableReverse)) for k, v := range a.tableReverse { - table[k] = v + ipToHosts[k] = v } - a.lock.Unlock() - return table + + return ipToHosts } // update table @@ -187,19 +196,30 @@ func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr n } } -// update "reverse" table -func (a *AutoHosts) updateTableRev(tableRev map[string]string, host string, ipAddr net.IP) { +// updateTableRev updates the reverse address table. +func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) { ipStr := ipAddr.String() - _, ok := tableRev[ipStr] + hosts, ok := tableRev[ipStr] if !ok { - tableRev[ipStr] = host - log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, host) + tableRev[ipStr] = []string{newHost} + log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost) + + return } + + for _, host := range hosts { + if host == newHost { + return + } + } + + tableRev[ipStr] = append(tableRev[ipStr], newHost) + log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost) } // Read IP-hostname pairs from file // Multiple hostnames per line (per one IP) is supported. -func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string]string, fn string) { +func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string, fn string) { f, err := os.Open(fn) if err != nil { log.Error("AutoHosts: %s", err) @@ -306,7 +326,7 @@ func (a *AutoHosts) updateLoop() { // updateHosts - loads system hosts func (a *AutoHosts) updateHosts() { table := make(map[string][]net.IP) - tableRev := make(map[string]string) + tableRev := make(map[string][]string) a.load(table, tableRev, a.hostsFn) diff --git a/internal/util/auto_hosts_test.go b/internal/util/auto_hosts_test.go index efd94b99..9084daba 100644 --- a/internal/util/auto_hosts_test.go +++ b/internal/util/auto_hosts_test.go @@ -15,7 +15,7 @@ import ( func prepareTestDir() string { const dir = "./agh-test" _ = os.RemoveAll(dir) - _ = os.MkdirAll(dir, 0755) + _ = os.MkdirAll(dir, 0o755) return dir } @@ -50,17 +50,24 @@ func TestAutoHostsResolution(t *testing.T) { // Test hosts file table := ah.List() - name, ok := table["127.0.0.1"] + names, ok := table["127.0.0.1"] assert.True(t, ok) - assert.Equal(t, "host", name) + assert.Equal(t, []string{"host", "localhost"}, names) // Test PTR a, _ := dns.ReverseAddr("127.0.0.1") a = strings.TrimSuffix(a, ".") - assert.True(t, ah.ProcessReverse(a, dns.TypePTR) == "host") + hosts := ah.ProcessReverse(a, dns.TypePTR) + if assert.Len(t, hosts, 2) { + assert.Equal(t, hosts[0], "host") + } + a, _ = dns.ReverseAddr("::1") a = strings.TrimSuffix(a, ".") - assert.True(t, ah.ProcessReverse(a, dns.TypePTR) == "localhost") + hosts = ah.ProcessReverse(a, dns.TypePTR) + if assert.Len(t, hosts, 1) { + assert.Equal(t, hosts[0], "localhost") + } } func TestAutoHostsFSNotify(t *testing.T) {