Added factory method for creating DNS upstreams

This commit is contained in:
Andrey Meshkov 2018-11-05 22:11:13 +03:00
parent a6022fc198
commit 9bc4bf66ed
3 changed files with 112 additions and 41 deletions

View file

@ -1,6 +1,84 @@
package upstream
import "github.com/miekg/dns"
import (
"github.com/miekg/dns"
"golang.org/x/net/context"
"net"
"strings"
)
// Detects the upstream type from the specified url and creates a proper Upstream object
func NewUpstream(url string, bootstrap string) (Upstream, error) {
proto := "udp"
prefix := ""
switch {
case strings.HasPrefix(url, "tcp://"):
proto = "tcp"
prefix = "tcp://"
case strings.HasPrefix(url, "tls://"):
proto = "tcp-tls"
prefix = "tls://"
case strings.HasPrefix(url, "https://"):
return NewHttpsUpstream(url, bootstrap)
}
hostname := strings.TrimPrefix(url, prefix)
host, port, err := net.SplitHostPort(hostname)
if err != nil {
// Set port depending on the protocol
switch proto {
case "udp":
port = "53"
case "tcp":
port = "53"
case "tcp-tls":
port = "853"
}
// Set host = hostname
host = hostname
}
// Try to resolve the host address (or check if it's an IP address)
bootstrapResolver := CreateResolver(bootstrap)
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
if err != nil || len(ips) == 0 {
return nil, err
}
addr := ips[0].String()
endpoint := net.JoinHostPort(addr, port)
tlsServerName := ""
if proto == "tcp-tls" && host != addr {
// Check if we need to specify TLS server name
tlsServerName = host
}
return NewDnsUpstream(endpoint, proto, tlsServerName)
}
func CreateResolver(bootstrap string) *net.Resolver {
bootstrapResolver := net.DefaultResolver
if bootstrap != "" {
bootstrapResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
conn, err := d.DialContext(ctx, network, bootstrap)
return conn, err
},
}
}
return bootstrapResolver
}
// Performs a simple health-check of the specified upstream
func IsAlive(u Upstream) (bool, error) {

View file

@ -27,7 +27,7 @@ type HttpsUpstream struct {
endpoint *url.URL
}
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
u, err := url.Parse(endpoint)
if err != nil {
@ -35,18 +35,7 @@ func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
}
// Initialize bootstrap resolver
bootstrapResolver := net.DefaultResolver
if bootstrap != "" {
bootstrapResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
conn, err := d.DialContext(ctx, network, bootstrap)
return conn, err
},
}
}
bootstrapResolver := CreateResolver(bootstrap)
dialer := &net.Dialer{
Timeout: defaultTimeout,
KeepAlive: defaultKeepAlive,

View file

@ -9,16 +9,17 @@ import (
func TestDnsUpstreamIsAlive(t *testing.T) {
var tests = []struct {
endpoint string
proto string
url string
bootstrap string
}{
{"8.8.8.8:53", "udp"},
{"8.8.8.8:53", "tcp"},
{"1.1.1.1:53", "udp"},
{"8.8.8.8:53", "8.8.8.8:53"},
{"1.1.1.1", ""},
{"tcp://1.1.1.1:53", ""},
{"176.103.130.130:5353", ""},
}
for _, test := range tests {
u, err := NewDnsUpstream(test.endpoint, test.proto, "")
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS upstream")
@ -36,11 +37,11 @@ func TestHttpsUpstreamIsAlive(t *testing.T) {
}{
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
{"https://dns.google.com/experimental", "8.8.8.8:53"},
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, // TODO: status 201??
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
}
for _, test := range tests {
u, err := NewHttpsUpstream(test.url, test.bootstrap)
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
@ -53,16 +54,17 @@ func TestHttpsUpstreamIsAlive(t *testing.T) {
func TestDnsOverTlsIsAlive(t *testing.T) {
var tests = []struct {
endpoint string
tlsServerName string
url string
bootstrap string
}{
{"1.1.1.1:853", ""},
{"9.9.9.9:853", ""},
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
{"tls://1.1.1.1", ""},
{"tls://9.9.9.9:853", ""},
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
}
for _, test := range tests {
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")
@ -75,16 +77,17 @@ func TestDnsOverTlsIsAlive(t *testing.T) {
func TestDnsUpstream(t *testing.T) {
var tests = []struct {
endpoint string
proto string
url string
bootstrap string
}{
{"8.8.8.8:53", "udp"},
{"8.8.8.8:53", "tcp"},
{"1.1.1.1:53", "udp"},
{"8.8.8.8:53", "8.8.8.8:53"},
{"1.1.1.1", ""},
{"tcp://1.1.1.1:53", ""},
{"176.103.130.130:5353", ""},
}
for _, test := range tests {
u, err := NewDnsUpstream(test.endpoint, test.proto, "")
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS upstream")
@ -106,7 +109,7 @@ func TestHttpsUpstream(t *testing.T) {
}
for _, test := range tests {
u, err := NewHttpsUpstream(test.url, test.bootstrap)
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
@ -119,16 +122,17 @@ func TestHttpsUpstream(t *testing.T) {
func TestDnsOverTlsUpstream(t *testing.T) {
var tests = []struct {
endpoint string
tlsServerName string
url string
bootstrap string
}{
{"1.1.1.1:853", ""},
{"9.9.9.9:853", ""},
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
{"tls://1.1.1.1", ""},
{"tls://9.9.9.9:853", ""},
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
}
for _, test := range tests {
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")