mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-25 22:45:46 +03:00
Added factory method for creating DNS upstreams
This commit is contained in:
parent
a6022fc198
commit
9bc4bf66ed
3 changed files with 112 additions and 41 deletions
|
@ -1,6 +1,84 @@
|
|||
package upstream
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Detects the upstream type from the specified url and creates a proper Upstream object
|
||||
func NewUpstream(url string, bootstrap string) (Upstream, error) {
|
||||
|
||||
proto := "udp"
|
||||
prefix := ""
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(url, "tcp://"):
|
||||
proto = "tcp"
|
||||
prefix = "tcp://"
|
||||
case strings.HasPrefix(url, "tls://"):
|
||||
proto = "tcp-tls"
|
||||
prefix = "tls://"
|
||||
case strings.HasPrefix(url, "https://"):
|
||||
return NewHttpsUpstream(url, bootstrap)
|
||||
}
|
||||
|
||||
hostname := strings.TrimPrefix(url, prefix)
|
||||
|
||||
host, port, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
// Set port depending on the protocol
|
||||
switch proto {
|
||||
case "udp":
|
||||
port = "53"
|
||||
case "tcp":
|
||||
port = "53"
|
||||
case "tcp-tls":
|
||||
port = "853"
|
||||
}
|
||||
|
||||
// Set host = hostname
|
||||
host = hostname
|
||||
}
|
||||
|
||||
// Try to resolve the host address (or check if it's an IP address)
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
|
||||
|
||||
if err != nil || len(ips) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := ips[0].String()
|
||||
endpoint := net.JoinHostPort(addr, port)
|
||||
tlsServerName := ""
|
||||
|
||||
if proto == "tcp-tls" && host != addr {
|
||||
// Check if we need to specify TLS server name
|
||||
tlsServerName = host
|
||||
}
|
||||
|
||||
return NewDnsUpstream(endpoint, proto, tlsServerName)
|
||||
}
|
||||
|
||||
func CreateResolver(bootstrap string) *net.Resolver {
|
||||
|
||||
bootstrapResolver := net.DefaultResolver
|
||||
|
||||
if bootstrap != "" {
|
||||
bootstrapResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
conn, err := d.DialContext(ctx, network, bootstrap)
|
||||
return conn, err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return bootstrapResolver
|
||||
}
|
||||
|
||||
// Performs a simple health-check of the specified upstream
|
||||
func IsAlive(u Upstream) (bool, error) {
|
||||
|
|
|
@ -27,7 +27,7 @@ type HttpsUpstream struct {
|
|||
endpoint *url.URL
|
||||
}
|
||||
|
||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname
|
||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
|
||||
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
|
@ -35,18 +35,7 @@ func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
|||
}
|
||||
|
||||
// Initialize bootstrap resolver
|
||||
bootstrapResolver := net.DefaultResolver
|
||||
if bootstrap != "" {
|
||||
bootstrapResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
conn, err := d.DialContext(ctx, network, bootstrap)
|
||||
return conn, err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultTimeout,
|
||||
KeepAlive: defaultKeepAlive,
|
||||
|
|
|
@ -9,16 +9,17 @@ import (
|
|||
func TestDnsUpstreamIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
endpoint string
|
||||
proto string
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "udp"},
|
||||
{"8.8.8.8:53", "tcp"},
|
||||
{"1.1.1.1:53", "udp"},
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewDnsUpstream(test.endpoint, test.proto, "")
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
|
@ -36,11 +37,11 @@ func TestHttpsUpstreamIsAlive(t *testing.T) {
|
|||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, // TODO: status 201??
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewHttpsUpstream(test.url, test.bootstrap)
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
|
@ -53,16 +54,17 @@ func TestHttpsUpstreamIsAlive(t *testing.T) {
|
|||
func TestDnsOverTlsIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
endpoint string
|
||||
tlsServerName string
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"1.1.1.1:853", ""},
|
||||
{"9.9.9.9:853", ""},
|
||||
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
|
@ -75,16 +77,17 @@ func TestDnsOverTlsIsAlive(t *testing.T) {
|
|||
func TestDnsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
endpoint string
|
||||
proto string
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "udp"},
|
||||
{"8.8.8.8:53", "tcp"},
|
||||
{"1.1.1.1:53", "udp"},
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewDnsUpstream(test.endpoint, test.proto, "")
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
|
@ -106,7 +109,7 @@ func TestHttpsUpstream(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewHttpsUpstream(test.url, test.bootstrap)
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
|
@ -119,16 +122,17 @@ func TestHttpsUpstream(t *testing.T) {
|
|||
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
endpoint string
|
||||
tlsServerName string
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"1.1.1.1:853", ""},
|
||||
{"9.9.9.9:853", ""},
|
||||
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName)
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
|
|
Loading…
Reference in a new issue