From d6f560ecafd18ea7a8b969abb801ccdf31e8ec73 Mon Sep 17 00:00:00 2001
From: Andrey Meshkov <am@adguard.com>
Date: Mon, 5 Nov 2018 20:40:10 +0300
Subject: [PATCH] Added persistent connections cache

---
 upstream/dns_upstream.go   |  78 ++++++++++++--
 upstream/https_upstream.go |   7 ++
 upstream/persistent.go     | 208 +++++++++++++++++++++++++++++++++++++
 upstream/tls_upstream.go   |  47 ---------
 upstream/upstream.go       |  31 +++++-
 upstream/upstream_test.go  |  53 ++++++----
 6 files changed, 344 insertions(+), 80 deletions(-)
 create mode 100644 upstream/persistent.go
 delete mode 100644 upstream/tls_upstream.go

diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go
index 779c059e..a40aec5a 100644
--- a/upstream/dns_upstream.go
+++ b/upstream/dns_upstream.go
@@ -1,6 +1,7 @@
 package upstream
 
 import (
+	"crypto/tls"
 	"github.com/miekg/dns"
 	"golang.org/x/net/context"
 	"time"
@@ -8,24 +9,40 @@ import (
 
 // DnsUpstream is a very simple upstream implementation for plain DNS
 type DnsUpstream struct {
-	nameServer string        // IP:port
-	timeout    time.Duration // Max read and write timeout
+	endpoint  string        // IP:port
+	timeout   time.Duration // Max read and write timeout
+	proto     string        // Protocol (tcp, tcp-tls, or udp)
+	transport *Transport    // Persistent connections cache
 }
 
-// NewDnsUpstream creates a new plain-DNS upstream
-func NewDnsUpstream(nameServer string) (Upstream, error) {
-	return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil
+// NewDnsUpstream creates a DNS upstream for endpoint (IP:port) over proto (tcp, tcp-tls, or udp) and starts its connections cache
+func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
+
+	u := &DnsUpstream{
+		endpoint: endpoint,
+		timeout:  defaultTimeout,
+		proto:    proto,
+	}
+
+	var tlsConfig *tls.Config
+	// A TLS config is built only for a non-empty server name; Transport.Dial forces tcp-tls whenever one is set
+	if tlsServerName != "" {
+		tlsConfig = new(tls.Config)
+		tlsConfig.ServerName = tlsServerName
+	}
+
+	// Initialize the connections cache
+	u.transport = NewTransport(endpoint)
+	u.transport.tlsConfig = tlsConfig
+	u.transport.Start()
+
+	return u, nil
 }
 
 // Exchange provides an implementation for the Upstream interface
 func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
 
-	dnsClient := &dns.Client{
-		ReadTimeout:  u.timeout,
-		WriteTimeout: u.timeout,
-	}
-
-	resp, _, err := dnsClient.Exchange(query, u.nameServer)
+	resp, err := u.exchange(query)
 
 	if err != nil {
 		resp = &dns.Msg{}
@@ -34,3 +51,42 @@ func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, e
 
 	return resp, err
 }
+
+// Close stops the transport's connection manager, which in turn closes the cached connections
+func (u *DnsUpstream) Close() error {
+
+	// Must be called at most once: Stop closes a channel, and closing it twice panics
+	u.transport.Stop()
+	return nil
+}
+
+// exchange performs a synchronous query: it sends the query over a cached or freshly
+// dialed conn and waits for a reply. A failed conn is closed, a healthy one is yielded back.
+func (u *DnsUpstream) exchange(query *dns.Msg) (r *dns.Msg, err error) {
+	// Establish a connection if needed (or reuse cached)
+	conn, err := u.transport.Dial(u.proto)
+	if err != nil {
+		return nil, err
+	}
+
+	// Write the request with a timeout
+	conn.SetWriteDeadline(time.Now().Add(u.timeout))
+	if err = conn.WriteMsg(query); err != nil {
+		conn.Close() // Not giving it back
+		return nil, err
+	}
+
+	// Read the response with a timeout
+	conn.SetReadDeadline(time.Now().Add(u.timeout))
+	r, err = conn.ReadMsg()
+	if err == nil && r.Id != query.Id {
+		err = dns.ErrId
+	}
+	if err != nil {
+		conn.Close() // Not giving it back: a closed conn must never reach the cache
+		return r, err
+	}
+
+	u.transport.Yield(conn)
+	return r, nil
+}
diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go
index 61c7a397..7daab106 100644
--- a/upstream/https_upstream.go
+++ b/upstream/https_upstream.go
@@ -18,6 +18,8 @@ const (
 	dnsMessageContentType = "application/dns-message"
 )
 
+// TODO: Add bootstrap DNS resolver field
+
 // HttpsUpstream is the upstream implementation for DNS-over-HTTPS
 type HttpsUpstream struct {
 	client   *http.Client
@@ -107,3 +109,8 @@ func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
 
 	return buf, nil
 }
+
+// Close implements the Upstream interface. It does nothing and always returns nil.
+func (u *HttpsUpstream) Close() error {
+	return nil
+}
diff --git a/upstream/persistent.go b/upstream/persistent.go
new file mode 100644
index 00000000..5c28a10e
--- /dev/null
+++ b/upstream/persistent.go
@@ -0,0 +1,208 @@
+package upstream
+
+import (
+	"crypto/tls"
+	"net"
+	"sort"
+	"sync/atomic"
+	"time"
+
+	"github.com/miekg/dns"
+)
+
+const (
+	defaultExpire       = 10 * time.Second
+	minDialTimeout      = 100 * time.Millisecond
+	maxDialTimeout      = 30 * time.Second
+	defaultDialTimeout  = 30 * time.Second
+	cumulativeAvgWeight = 4
+)
+
+// persistConn holds a dns.Conn together with the time it was last used.
+type persistConn struct {
+	c    *dns.Conn // The cached connection
+	used time.Time // When the connection was handed back to the cache
+}
+
+// Transport holds the persistent connections cache for a single address.
+type Transport struct {
+	avgDialTime int64                     // Running average of the dial time (a time.Duration value), accessed atomically
+	conns       map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
+	expire      time.Duration             // After this duration a connection is expired.
+	addr        string                    // Address passed to the dns dial functions
+	tlsConfig   *tls.Config               // When non-nil, Dial always uses tcp-tls with this config
+
+	dial  chan string    // Dial sends the requested proto to connManager
+	yield chan *dns.Conn // Yield hands a connection back to connManager
+	ret   chan *dns.Conn // connManager answers Dial with a cached conn, or nil if there is none
+	stop  chan bool      // Closed by Stop to terminate connManager
+}
+
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
+func (t *Transport) Dial(proto string) (*dns.Conn, error) {
+	// If tls has been configured; use it.
+	if t.tlsConfig != nil {
+		proto = "tcp-tls"
+	}
+
+	t.dial <- proto
+	c := <-t.ret
+
+	if c != nil {
+		return c, nil
+	}
+
+	reqTime := time.Now()
+	timeout := t.dialTimeout()
+	if proto == "tcp-tls" {
+		conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
+		t.updateDialTimeout(time.Since(reqTime))
+		return conn, err
+	}
+	conn, err := dns.DialTimeout(proto, t.addr, timeout)
+	t.updateDialTimeout(time.Since(reqTime))
+	return conn, err
+}
+
+// Yield hands the connection back to the connection manager for reuse; it blocks until the manager receives it.
+func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
+
+// Start starts the transport's connection manager.
+func (t *Transport) Start() { go t.connManager() }
+
+// Stop stops the transport's connection manager.
+func (t *Transport) Stop() { close(t.stop) }
+
+// SetExpire sets the connection expire time in transport.
+func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+// NewTransport creates a Transport for addr. Call Start to run its connection manager.
+func NewTransport(addr string) *Transport {
+	return &Transport{
+		addr:        addr,
+		expire:      defaultExpire,
+		avgDialTime: int64(defaultDialTimeout / 2),
+		conns:       make(map[string][]*persistConn),
+		dial:        make(chan string),
+		yield:       make(chan *dns.Conn),
+		ret:         make(chan *dns.Conn),
+		stop:        make(chan bool),
+	}
+}
+
+func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
+	dt := time.Duration(atomic.LoadInt64(currentAvg))
+	atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
+}
+
+func (t *Transport) dialTimeout() time.Duration {
+	return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
+}
+
+func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
+	averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
+}
+
+// limitTimeout is a utility function to auto-tune timeout values.
+// It atomically loads the average stored in currentAvg and returns twice that value, clamped:
+// minValue when the average is below minValue, maxValue once the average reaches maxValue/2.
+func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
+	rt := time.Duration(atomic.LoadInt64(currentAvg))
+	if rt < minValue {
+		return minValue
+	}
+	if rt < maxValue/2 {
+		return 2 * rt
+	}
+	return maxValue
+}
+
+// connManager manages the persistent connection cache for UDP, TCP and TCP-TLS until Stop is called.
+func (t *Transport) connManager() {
+	ticker := time.NewTicker(t.expire)
+	defer ticker.Stop() // release the ticker once the manager is stopped
+Wait:
+	for {
+		select {
+		case proto := <-t.dial:
+			// take the last used conn - complexity O(1)
+			if stack := t.conns[proto]; len(stack) > 0 {
+				pc := stack[len(stack)-1]
+				if time.Since(pc.used) < t.expire {
+					// Found one, remove from pool and return this conn.
+					t.conns[proto] = stack[:len(stack)-1]
+					t.ret <- pc.c
+					continue Wait
+				}
+				// clear entire cache if the last conn is expired
+				t.conns[proto] = nil
+				// now, the connections being passed to closeConns() are not reachable from
+				// transport methods anymore. So, it's safe to close them in a separate goroutine
+				go closeConns(stack)
+			}
+
+			t.ret <- nil
+
+		case conn := <-t.yield:
+
+			// No proto here: infer the bucket from the underlying connection type.
+			// Looking at tlsConfig is not enough, since "tcp-tls" can be dialed without a TLS
+			// config; such conns would land in the "tcp" bucket and never be reused.
+			proto := "tcp"
+			switch conn.Conn.(type) {
+			case *net.UDPConn:
+				proto = "udp"
+			case *tls.Conn:
+				proto = "tcp-tls"
+			}
+			t.conns[proto] = append(t.conns[proto], &persistConn{conn, time.Now()})
+
+		case <-ticker.C:
+			t.cleanup(false)
+
+		case <-t.stop:
+			t.cleanup(true)
+			close(t.ret)
+			return
+		}
+	}
+}
+
+// closeConns closes connections.
+func closeConns(conns []*persistConn) {
+	for _, pc := range conns {
+		pc.c.Close()
+	}
+}
+
+// cleanup removes connections from cache.
+func (t *Transport) cleanup(all bool) {
+	staleTime := time.Now().Add(-t.expire)
+	for proto, stack := range t.conns {
+		if len(stack) == 0 {
+			continue
+		}
+		if all {
+			t.conns[proto] = nil
+			// now, the connections being passed to closeConns() are not reachable from
+			// transport methods anymore. So, it's safe to close them in a separate goroutine
+			go closeConns(stack)
+			continue
+		}
+		if stack[0].used.After(staleTime) {
+			continue
+		}
+
+		// connections in stack are sorted by "used"
+		good := sort.Search(len(stack), func(i int) bool {
+			return stack[i].used.After(staleTime)
+		})
+		t.conns[proto] = stack[good:]
+		// now, the connections being passed to closeConns() are not reachable from
+		// transport methods anymore. So, it's safe to close them in a separate goroutine
+		go closeConns(stack[:good])
+	}
+}
diff --git a/upstream/tls_upstream.go b/upstream/tls_upstream.go
deleted file mode 100644
index aed55829..00000000
--- a/upstream/tls_upstream.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package upstream
-
-import (
-	"crypto/tls"
-	"github.com/miekg/dns"
-	"golang.org/x/net/context"
-	"time"
-)
-
-// TODO: Use persistent connection here
-
-// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS
-type DnsOverTlsUpstream struct {
-	endpoint      string
-	tlsServerName string
-	timeout       time.Duration
-}
-
-// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name
-func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) {
-	return &DnsOverTlsUpstream{
-		endpoint:      endpoint,
-		tlsServerName: tlsServerName,
-		timeout:       defaultTimeout,
-	}, nil
-}
-
-// Exchange provides an implementation for the Upstream interface
-func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
-
-	dnsClient := &dns.Client{
-		Net:          "tcp-tls",
-		ReadTimeout:  u.timeout,
-		WriteTimeout: u.timeout,
-		TLSConfig:    new(tls.Config),
-	}
-	dnsClient.TLSConfig.ServerName = u.tlsServerName
-
-	resp, _, err := dnsClient.Exchange(query, u.endpoint)
-
-	if err != nil {
-		resp = &dns.Msg{}
-		resp.SetRcode(resp, dns.RcodeServerFailure)
-	}
-
-	return resp, err
-}
diff --git a/upstream/upstream.go b/upstream/upstream.go
index 6d2570c5..44d4e389 100644
--- a/upstream/upstream.go
+++ b/upstream/upstream.go
@@ -5,6 +5,8 @@ import (
 	"github.com/miekg/dns"
 	"github.com/pkg/errors"
 	"golang.org/x/net/context"
+	"log"
+	"runtime"
 	"time"
 )
 
@@ -12,9 +14,12 @@ const (
 	defaultTimeout = 5 * time.Second
 )
 
+// TODO: Add a helper method for health-checking an upstream (see health.go in coredns)
+
 // Upstream is a simplified interface for proxy destination
 type Upstream interface {
 	Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
+	Close() error
 }
 
 // UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
@@ -23,11 +28,21 @@ type UpstreamPlugin struct {
 	Next      plugin.Handler
 }
 
+// New creates an UpstreamPlugin and registers a finalizer that closes its upstreams
+func New() *UpstreamPlugin {
+	p := &UpstreamPlugin{}
+
+	// Make sure all resources are cleaned up (the finalizer runs only when p is garbage collected)
+	runtime.SetFinalizer(p, (*UpstreamPlugin).finalizer)
+	return p
+}
+
 // ServeDNS implements interface for CoreDNS plugin
-func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 	var reply *dns.Msg
 	var backendErr error
 
+	// TODO: Change the way we call upstreams
 	for _, upstream := range p.Upstreams {
 		reply, backendErr = upstream.Exchange(ctx, r)
 		if backendErr == nil {
@@ -40,4 +55,16 @@ func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *d
 }
 
 // Name implements interface for CoreDNS plugin
-func (p UpstreamPlugin) Name() string { return "upstream" }
+func (p *UpstreamPlugin) Name() string { return "upstream" }
+
+func (p *UpstreamPlugin) finalizer() {
+
+	for i := range p.Upstreams {
+
+		u := p.Upstreams[i]
+		err := u.Close()
+		if err != nil {
+			log.Printf("Error while closing the upstream: %s", err)
+		}
+	}
+}
diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go
index ca0df859..5e60b63d 100644
--- a/upstream/upstream_test.go
+++ b/upstream/upstream_test.go
@@ -2,14 +2,13 @@ package upstream
 
 import (
 	"github.com/miekg/dns"
-	"log"
 	"net"
 	"testing"
 )
 
 func TestDnsUpstream(t *testing.T) {
 
-	u, err := NewDnsUpstream("8.8.8.8:53")
+	u, err := NewDnsUpstream("8.8.8.8:53", "udp", "")
 
 	if err != nil {
 		t.Errorf("cannot create a DNS upstream")
@@ -44,12 +43,12 @@ func TestDnsOverTlsUpstream(t *testing.T) {
 		tlsServerName string
 	}{
 		{"1.1.1.1:853", ""},
-		{"8.8.8.8:853", ""},
+		{"9.9.9.9:853", ""},
 		{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
 	}
 
 	for _, test := range tests {
-		u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName)
+		u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
 
 		if err != nil {
 			t.Errorf("cannot create a DNS-over-TLS upstream")
@@ -60,27 +59,41 @@ func TestDnsOverTlsUpstream(t *testing.T) {
 }
 
 func testUpstream(t *testing.T, u Upstream) {
-	req := dns.Msg{}
-	req.Id = dns.Id()
-	req.RecursionDesired = true
-	req.Question = []dns.Question{
-		{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+
+	var tests = []struct {
+		name     string
+		expected net.IP
+	}{
+		{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
+		{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
 	}
 
-	resp, err := u.Exchange(nil, &req)
+	for _, test := range tests {
+		req := dns.Msg{}
+		req.Id = dns.Id()
+		req.RecursionDesired = true
+		req.Question = []dns.Question{
+			{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
+		}
 
-	if err != nil {
-		t.Errorf("error while making an upstream request: %s", err)
-	}
+		resp, err := u.Exchange(nil, &req)
 
-	if len(resp.Answer) != 1 {
-		t.Errorf("no answer section in the response")
-	}
-	if answer, ok := resp.Answer[0].(*dns.A); ok {
-		if !net.IPv4(8, 8, 8, 8).Equal(answer.A) {
-			t.Errorf("wrong IP in the response: %v", answer.A)
+		if err != nil {
+			t.Errorf("error while making an upstream request: %s", err)
+		}
+
+		if len(resp.Answer) != 1 {
+			t.Fatalf("no answer section in the response") // stop here: indexing Answer[0] below would panic
+		}
+		if answer, ok := resp.Answer[0].(*dns.A); ok {
+			if !test.expected.Equal(answer.A) {
+				t.Errorf("wrong IP in the response: %v", answer.A)
+			}
 		}
 	}
 
-	log.Printf("response: %v", resp)
+	err := u.Close()
+	if err != nil {
+		t.Errorf("Error while closing the upstream: %s", err)
+	}
 }