From d6f560ecafd18ea7a8b969abb801ccdf31e8ec73 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov <am@adguard.com> Date: Mon, 5 Nov 2018 20:40:10 +0300 Subject: [PATCH] Added persistent connections cache --- upstream/dns_upstream.go | 78 ++++++++++++-- upstream/https_upstream.go | 7 ++ upstream/persistent.go | 208 +++++++++++++++++++++++++++++++++++++ upstream/tls_upstream.go | 47 --------- upstream/upstream.go | 31 +++++- upstream/upstream_test.go | 53 ++++++---- 6 files changed, 344 insertions(+), 80 deletions(-) create mode 100644 upstream/persistent.go delete mode 100644 upstream/tls_upstream.go diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go index 779c059e..a40aec5a 100644 --- a/upstream/dns_upstream.go +++ b/upstream/dns_upstream.go @@ -1,6 +1,7 @@ package upstream import ( + "crypto/tls" "github.com/miekg/dns" "golang.org/x/net/context" "time" @@ -8,24 +9,40 @@ import ( // DnsUpstream is a very simple upstream implementation for plain DNS type DnsUpstream struct { - nameServer string // IP:port - timeout time.Duration // Max read and write timeout + endpoint string // IP:port + timeout time.Duration // Max read and write timeout + proto string // Protocol (tcp, tcp-tls, or udp) + transport *Transport // Persistent connections cache } -// NewDnsUpstream creates a new plain-DNS upstream -func NewDnsUpstream(nameServer string) (Upstream, error) { - return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil +// NewDnsUpstream creates a DNS upstream for endpoint over proto (udp, tcp or tcp-tls); tlsServerName is optional +func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) { + + u := &DnsUpstream{ + endpoint: endpoint, + timeout: defaultTimeout, + proto: proto, + } + + var tlsConfig *tls.Config + // TLS upstreams always get a config, even without a server name: the transport files yielded connections under "tcp-tls" only when tlsConfig is set, so a nil config would prevent their reuse + if proto == "tcp-tls" || tlsServerName != "" { + tlsConfig = new(tls.Config) + tlsConfig.ServerName = tlsServerName + } + + // Initialize the connections cache + u.transport = NewTransport(endpoint) + u.transport.tlsConfig = tlsConfig + u.transport.Start() + + return u, nil } // Exchange 
provides an implementation for the Upstream interface func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - dnsClient := &dns.Client{ - ReadTimeout: u.timeout, - WriteTimeout: u.timeout, - } - - resp, _, err := dnsClient.Exchange(query, u.nameServer) + resp, err := u.exchange(query) if err != nil { resp = &dns.Msg{} @@ -34,3 +51,42 @@ func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, e return resp, err } + +// Close stops the transport, which closes the cached connections +func (u *DnsUpstream) Close() error { + + // Close active connections + u.transport.Stop() + return nil +} + +// exchange performs a synchronous query over a connection taken from the transport. +// The connection is yielded back for reuse on success and closed on any error. +func (u *DnsUpstream) exchange(query *dns.Msg) (r *dns.Msg, err error) { + + // Establish a connection if needed (or reuse cached) + conn, err := u.transport.Dial(u.proto) + if err != nil { + return nil, err + } + + // Write the request with a timeout + conn.SetWriteDeadline(time.Now().Add(u.timeout)) + if err = conn.WriteMsg(query); err != nil { + conn.Close() // Not giving it back + return nil, err + } + + // Read the response with a timeout + conn.SetReadDeadline(time.Now().Add(u.timeout)) + r, err = conn.ReadMsg() + if err == nil && r.Id != query.Id { + err = dns.ErrId + } + if err != nil { + conn.Close() // Not giving it back: a closed conn must never reach the cache + return r, err + } + u.transport.Yield(conn) + return r, nil +} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go index 61c7a397..7daab106 100644 --- a/upstream/https_upstream.go +++ b/upstream/https_upstream.go @@ -18,6 +18,8 @@ const ( dnsMessageContentType = "application/dns-message" ) +// TODO: Add bootstrap DNS resolver field + // HttpsUpstream is the upstream implementation for DNS-over-HTTPS type HttpsUpstream struct { client *http.Client @@ -107,3 +109,8 @@ func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { return buf, nil } + +// Clear resources 
+func (u *HttpsUpstream) Close() error { + return nil +} diff --git a/upstream/persistent.go b/upstream/persistent.go new file mode 100644 index 00000000..5c28a10e --- /dev/null +++ b/upstream/persistent.go @@ -0,0 +1,208 @@ +package upstream + +import ( + "crypto/tls" + "net" + "sort" + "sync/atomic" + "time" + + "github.com/miekg/dns" +) + +const ( + defaultExpire = 10 * time.Second + minDialTimeout = 100 * time.Millisecond + maxDialTimeout = 30 * time.Second + defaultDialTimeout = 30 * time.Second + cumulativeAvgWeight = 4 +) + +// persistConn holds the dns.Conn and the last used time. +type persistConn struct { + c *dns.Conn + used time.Time +} + +// Transport holds the persistent connection cache. +type Transport struct { + avgDialTime int64 // running average of the dial time, in nanoseconds + conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. + expire time.Duration // After this duration a connection is expired. + addr string + tlsConfig *tls.Config + + dial chan string + yield chan *dns.Conn + ret chan *dns.Conn + stop chan bool +} + +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *Transport) Dial(proto string) (*dns.Conn, error) { + // If TLS has been configured, use it. + if t.tlsConfig != nil { + proto = "tcp-tls" + } + + t.dial <- proto + c := <-t.ret + + if c != nil { + return c, nil + } + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, err +} + +// Yield returns the connection to transport for reuse. +func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } + +// Start starts the transport's connection manager. +func (t *Transport) Start() { go t.connManager() } + +// Stop stops the transport's connection manager. 
+func (t *Transport) Stop() { close(t.stop) } + +// SetExpire sets the connection expire time in transport. +func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } + +// SetTLSConfig sets the TLS config in transport. +func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } + +func NewTransport(addr string) *Transport { + t := &Transport{ + avgDialTime: int64(defaultDialTimeout / 2), + conns: make(map[string][]*persistConn), + expire: defaultExpire, + addr: addr, + dial: make(chan string), + yield: make(chan *dns.Conn), + ret: make(chan *dns.Conn), + stop: make(chan bool), + } + return t +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *Transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} + +// limitTimeout is a utility function to auto-tune timeout values from the stored average, +// which is moved towards the last observed delay, moderated by a weight. +// The next timeout is minValue when the average is below minValue, otherwise double the average, capped at maxValue. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +// connManager manages the persistent connection cache for UDP and TCP. 
+func (t *Transport) connManager() { + ticker := time.NewTicker(t.expire) +Wait: + for { + select { + case proto := <-t.dial: + // take the last used conn - complexity O(1) + if stack := t.conns[proto]; len(stack) > 0 { + pc := stack[len(stack)-1] + if time.Since(pc.used) < t.expire { + // Found one, remove from pool and return this conn. + t.conns[proto] = stack[:len(stack)-1] + t.ret <- pc.c + continue Wait + } + // clear entire cache if the last conn is expired + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + } + + t.ret <- nil + + case conn := <-t.yield: + + // no proto here, infer from config and conn + if _, ok := conn.Conn.(*net.UDPConn); ok { + t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) + continue Wait + } + + if t.tlsConfig == nil { + t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) + continue Wait + } + + t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) + + case <-ticker.C: + t.cleanup(false) + + case <-t.stop: + t.cleanup(true) + close(t.ret) + return + } + } +} + +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *Transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for proto, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. 
So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[proto] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} diff --git a/upstream/tls_upstream.go b/upstream/tls_upstream.go deleted file mode 100644 index aed55829..00000000 --- a/upstream/tls_upstream.go +++ /dev/null @@ -1,47 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "github.com/miekg/dns" - "golang.org/x/net/context" - "time" -) - -// TODO: Use persistent connection here - -// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS -type DnsOverTlsUpstream struct { - endpoint string - tlsServerName string - timeout time.Duration -} - -// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name -func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) { - return &DnsOverTlsUpstream{ - endpoint: endpoint, - tlsServerName: tlsServerName, - timeout: defaultTimeout, - }, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - - dnsClient := &dns.Client{ - Net: "tcp-tls", - ReadTimeout: u.timeout, - WriteTimeout: u.timeout, - TLSConfig: new(tls.Config), - } - dnsClient.TLSConfig.ServerName = u.tlsServerName - - resp, _, err := dnsClient.Exchange(query, u.endpoint) - - if err != nil { - resp = &dns.Msg{} - resp.SetRcode(resp, dns.RcodeServerFailure) - } - - return resp, err -} diff --git a/upstream/upstream.go b/upstream/upstream.go index 6d2570c5..44d4e389 100644 --- a/upstream/upstream.go +++ 
b/upstream/upstream.go @@ -5,6 +5,8 @@ import ( "github.com/miekg/dns" "github.com/pkg/errors" "golang.org/x/net/context" + "log" + "runtime" "time" ) @@ -12,9 +14,12 @@ const ( defaultTimeout = 5 * time.Second ) +// TODO: Add a helper method for health-checking an upstream (see health.go in coredns) + // Upstream is a simplified interface for proxy destination type Upstream interface { Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) + Close() error } // UpstreamPlugin is a simplified DNS proxy using a generic upstream interface @@ -23,11 +28,21 @@ type UpstreamPlugin struct { Next plugin.Handler } +// Initialize the upstream plugin +func New() *UpstreamPlugin { + p := &UpstreamPlugin{} + + // Make sure all resources are cleaned up + runtime.SetFinalizer(p, (*UpstreamPlugin).finalizer) + return p +} + // ServeDNS implements interface for CoreDNS plugin -func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { +func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { var reply *dns.Msg var backendErr error + // TODO: Change the way we call upstreams for _, upstream := range p.Upstreams { reply, backendErr = upstream.Exchange(ctx, r) if backendErr == nil { @@ -40,4 +55,16 @@ func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *d } // Name implements interface for CoreDNS plugin -func (p UpstreamPlugin) Name() string { return "upstream" } +func (p *UpstreamPlugin) Name() string { return "upstream" } + +func (p *UpstreamPlugin) finalizer() { + + for i := range p.Upstreams { + + u := p.Upstreams[i] + err := u.Close() + if err != nil { + log.Printf("Error while closing the upstream: %s", err) + } + } +} diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index ca0df859..5e60b63d 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -2,14 +2,13 @@ package upstream import ( "github.com/miekg/dns" - "log" 
"net" "testing" ) func TestDnsUpstream(t *testing.T) { - u, err := NewDnsUpstream("8.8.8.8:53") + u, err := NewDnsUpstream("8.8.8.8:53", "udp", "") if err != nil { t.Errorf("cannot create a DNS upstream") @@ -44,12 +43,12 @@ func TestDnsOverTlsUpstream(t *testing.T) { tlsServerName string }{ {"1.1.1.1:853", ""}, - {"8.8.8.8:853", ""}, + {"9.9.9.9:853", ""}, {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, } for _, test := range tests { - u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName) + u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) if err != nil { t.Errorf("cannot create a DNS-over-TLS upstream") @@ -60,27 +59,41 @@ func TestDnsOverTlsUpstream(t *testing.T) { } func testUpstream(t *testing.T, u Upstream) { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + + var tests = []struct { + name string + expected net.IP + }{ + {"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)}, + {"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)}, } - resp, err := u.Exchange(nil, &req) + for _, test := range tests { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } - if err != nil { - t.Errorf("error while making an upstream request: %s", err) - } + resp, err := u.Exchange(nil, &req) - if len(resp.Answer) != 1 { - t.Errorf("no answer section in the response") - } - if answer, ok := resp.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(answer.A) { - t.Errorf("wrong IP in the response: %v", answer.A) + if err != nil { + t.Errorf("error while making an upstream request: %s", err) + } + + if len(resp.Answer) != 1 { + t.Errorf("no answer section in the response") + // Never index an empty answer section (e.g. the SERVFAIL fallback reply) + } else if answer, ok := resp.Answer[0].(*dns.A); ok { + if !test.expected.Equal(answer.A) { + 
t.Errorf("wrong IP in the response: %v", answer.A) + } } } - log.Printf("response: %v", resp) + err := u.Close() + if err != nil { + t.Errorf("Error while closing the upstream: %s", err) + } }