From a3b8d4d9237fcd726b7e4edfc729bd156efc1323 Mon Sep 17 00:00:00 2001
From: Andrey Meshkov <am@adguard.com>
Date: Tue, 4 Jun 2019 20:38:53 +0300
Subject: [PATCH] Fix #706 -- rDNS for DOH/DOT clients

---
 dns.go                 | 14 +++++++++++---
 dns_test.go            |  4 +++-
 dnsforward/helpers.go  | 14 ++++++++++++++
 dnsforward/querylog.go | 13 +------------
 4 files changed, 29 insertions(+), 16 deletions(-)
 create mode 100644 dnsforward/helpers.go

diff --git a/dns.go b/dns.go
index 28ebbbf5..2f41948b 100644
--- a/dns.go
+++ b/dns.go
@@ -154,10 +154,18 @@ func asyncRDNSLoop() {
 }
 
 func onDNSRequest(d *proxy.DNSContext) {
-	if d.Req.Question[0].Qtype == dns.TypeA {
-		ip, _, _ := net.SplitHostPort(d.Addr.String())
-		beginAsyncRDNS(ip)
+	qType := d.Req.Question[0].Qtype
+	if qType != dns.TypeA && qType != dns.TypeAAAA {
+		return
 	}
+
+	ip := dnsforward.GetIPString(d.Addr)
+	if ip == "" {
+		// This would be quite weird if we get here
+		return
+	}
+
+	beginAsyncRDNS(ip)
 }
 
 func generateServerConfig() dnsforward.ServerConfig {
diff --git a/dns_test.go b/dns_test.go
index 7e3b5788..0ea2cfb8 100644
--- a/dns_test.go
+++ b/dns_test.go
@@ -1,6 +1,8 @@
 package main
 
-import "testing"
+import (
+	"testing"
+)
 
 func TestResolveRDNS(t *testing.T) {
 	config.DNS.BindHost = "1.1.1.1"
diff --git a/dnsforward/helpers.go b/dnsforward/helpers.go
new file mode 100644
index 00000000..e7212355
--- /dev/null
+++ b/dnsforward/helpers.go
@@ -0,0 +1,14 @@
+package dnsforward
+
+import "net"
+
+// GetIPString is a helper function that extracts IP address from net.Addr
+func GetIPString(addr net.Addr) string {
+	switch addr := addr.(type) {
+	case *net.UDPAddr:
+		return addr.IP.String()
+	case *net.TCPAddr:
+		return addr.IP.String()
+	}
+	return ""
+}
diff --git a/dnsforward/querylog.go b/dnsforward/querylog.go
index 7387e43a..5dcdb5aa 100644
--- a/dnsforward/querylog.go
+++ b/dnsforward/querylog.go
@@ -61,7 +61,7 @@ func (l *queryLog) logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfil
 	var q []byte
 	var a []byte
 	var err error
-	ip := getIPString(addr)
+	ip := GetIPString(addr)
 
 	if question != nil {
 		q, err = question.Pack()
@@ -244,14 +244,3 @@ func answerToMap(a *dns.Msg) []map[string]interface{} {
 
 	return answers
 }
-
-// getIPString is a helper function that extracts IP address from net.Addr
-func getIPString(addr net.Addr) string {
-	switch addr := addr.(type) {
-	case *net.UDPAddr:
-		return addr.IP.String()
-	case *net.TCPAddr:
-		return addr.IP.String()
-	}
-	return ""
-}