From 192b58b9d94db2f98704f46fa3fd0b20d95fb8dd Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Thu, 19 Sep 2019 18:27:13 +0300
Subject: [PATCH] * rDNS: refactor

---
 home/dns.go  | 14 +++---------
 home/rdns.go | 62 +++++++++++++++++++++++++++++++++-------------------
 2 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/home/dns.go b/home/dns.go
index 3b90ddfa..1712df9c 100644
--- a/home/dns.go
+++ b/home/dns.go
@@ -5,27 +5,19 @@ import (
 	"net"
 	"os"
 	"path/filepath"
-	"sync"
 
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/AdguardTeam/AdGuardHome/dnsforward"
 	"github.com/AdguardTeam/AdGuardHome/querylog"
 	"github.com/AdguardTeam/AdGuardHome/stats"
 	"github.com/AdguardTeam/dnsproxy/proxy"
-	"github.com/AdguardTeam/dnsproxy/upstream"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/joomcode/errorx"
 	"github.com/miekg/dns"
 )
 
 type dnsContext struct {
-	rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread
-	// contains IP addresses of clients to be resolved by rDNS
-	// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
-	rdnsIP   map[string]bool
-	rdnsLock sync.Mutex        // synchronize access to rdnsIP
-	upstream upstream.Upstream // Upstream object for our own DNS server
-
+	rdns  *RDNS
 	whois *Whois
 }
 
@@ -57,7 +49,7 @@ func initDNSServer(baseDir string) {
 	config.auth = InitAuth(sessFilename, config.Users)
 	config.Users = nil
 
-	initRDNS()
+	config.dnsctx.rdns = InitRDNS(&config.clients)
 	config.dnsctx.whois = initWhois(&config.clients)
 	initFiltering()
 }
@@ -133,7 +125,7 @@ func onDNSRequest(d *proxy.DNSContext) {
 
 	ipAddr := net.ParseIP(ip)
 	if !ipAddr.IsLoopback() {
-		beginAsyncRDNS(ip)
+		config.dnsctx.rdns.Begin(ip)
 	}
 	if isPublicIP(ipAddr) {
 		config.dnsctx.whois.Begin(ip)
diff --git a/home/rdns.go b/home/rdns.go
index 048dcde1..c8a39974 100644
--- a/home/rdns.go
+++ b/home/rdns.go
@@ -3,6 +3,7 @@ package home
 import (
 	"fmt"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/AdguardTeam/dnsproxy/upstream"
@@ -14,7 +15,21 @@ const (
 	rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
 )
 
-func initRDNS() {
+// RDNS - module context
+type RDNS struct {
+	clients   *clientsContainer
+	ipChannel chan string // pass data from DNS request handling thread to rDNS thread
+	// contains IP addresses of clients to be resolved by rDNS
+	// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
+	ips      map[string]bool
+	lock     sync.Mutex        // synchronize access to 'ips'
+	upstream upstream.Upstream // Upstream object for our own DNS server
+}
+
+// InitRDNS - create module context
+func InitRDNS(clients *clientsContainer) *RDNS {
+	r := RDNS{}
+	r.clients = clients
 	var err error
 
 	bindhost := config.DNS.BindHost
@@ -26,35 +41,36 @@ func initRDNS() {
 	opts := upstream.Options{
 		Timeout: rdnsTimeout,
 	}
-	config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
+	r.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
 	if err != nil {
 		log.Error("upstream.AddressToUpstream: %s", err)
-		return
+		return nil
 	}
 
-	config.dnsctx.rdnsIP = make(map[string]bool)
-	config.dnsctx.rdnsChannel = make(chan string, 256)
-	go asyncRDNSLoop()
+	r.ips = make(map[string]bool)
+	r.ipChannel = make(chan string, 256)
+	go r.workerLoop()
+	return &r
 }
 
-// Add IP address to the rDNS queue
-func beginAsyncRDNS(ip string) {
-	if config.clients.Exists(ip, ClientSourceRDNS) {
+// Begin - add IP address to rDNS queue
+func (r *RDNS) Begin(ip string) {
+	if r.clients.Exists(ip, ClientSourceRDNS) {
 		return
 	}
 
-	// add IP to rdnsIP, if not exists
-	config.dnsctx.rdnsLock.Lock()
-	defer config.dnsctx.rdnsLock.Unlock()
-	_, ok := config.dnsctx.rdnsIP[ip]
+	// add IP to ips, if not exists
+	r.lock.Lock()
+	defer r.lock.Unlock()
+	_, ok := r.ips[ip]
 	if ok {
 		return
 	}
-	config.dnsctx.rdnsIP[ip] = true
+	r.ips[ip] = true
 
 	log.Tracef("Adding %s for rDNS resolve", ip)
 	select {
-	case config.dnsctx.rdnsChannel <- ip:
+	case r.ipChannel <- ip:
 		//
 	default:
 		log.Tracef("rDNS queue is full")
@@ -62,7 +78,7 @@ func beginAsyncRDNS(ip string) {
 }
 
 // Use rDNS to get hostname by IP address
-func resolveRDNS(ip string) string {
+func (r *RDNS) resolve(ip string) string {
 	log.Tracef("Resolving host for %s", ip)
 
 	req := dns.Msg{}
@@ -81,7 +97,7 @@ func resolveRDNS(ip string) string {
 		return ""
 	}
 
-	resp, err := config.dnsctx.upstream.Exchange(&req)
+	resp, err := r.upstream.Exchange(&req)
 	if err != nil {
 		log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
 		return ""
@@ -106,19 +122,19 @@ func resolveRDNS(ip string) string {
 
 // Wait for a signal and then synchronously resolve hostname by IP address
 // Add the hostname:IP pair to "Clients" array
-func asyncRDNSLoop() {
+func (r *RDNS) workerLoop() {
 	for {
 		var ip string
-		ip = <-config.dnsctx.rdnsChannel
+		ip = <-r.ipChannel
 
-		host := resolveRDNS(ip)
+		host := r.resolve(ip)
 		if len(host) == 0 {
 			continue
 		}
 
-		config.dnsctx.rdnsLock.Lock()
-		delete(config.dnsctx.rdnsIP, ip)
-		config.dnsctx.rdnsLock.Unlock()
+		r.lock.Lock()
+		delete(r.ips, ip)
+		r.lock.Unlock()
 
 		_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
 	}