Merge pull request #117 in DNS/adguard-dns from no_coredns to master

* commit '253d8a4016d66863ecee426b8f7d74841c4ed4de': (58 commits)
  Pointer for dnsfilter.Result in querylog didn't make things simpler, revert that change and all related changes.
  Fixup of previous commit -- remove unused import.
  Remove unused code.
  Use filter deduplication function.
  Small code review update -- use CamelCase
  readme -- Update config field descriptions and clarify about coredns.
  dnsforward -- fix panic on ANY request
  dnsfilter -- fix broken tests
  config -- Avoid deleting existing dns section if someone removes schema_version from yaml file.
  Rename coredns.go to dns.go
  Add support for bootstrapping upstream DNS servers by hostname.
  dnsforward -- support tcp:// schema
  dnsforward -- add upstream tests.
  Don't omit empty user rules in configfile -- otherwise users might not be able to find that it's customizable in configfile.
  Get rid of mentions of CoreDNS in code except for upgrading and in readme. Add config upgrade.
  dnsforward -- add a simple test that launches a server and queries well-known value through it
  Remove old entries from .gitignore
  Remove unused code. Goodbye CoreDNS.
  Use dnsforward for checking if upstream DNS server is working.
  dnsforward -- implement ratelimit and refuseany
  ...
This commit is contained in:
Andrey Meshkov 2018-12-06 17:29:36 +03:00
commit b5121c5754
46 changed files with 2453 additions and 2870 deletions

8
.gitignore vendored
View file

@ -1,15 +1,11 @@
.DS_Store
.vscode
.idea
debug
/.vscode
/.idea
/AdGuardHome
/AdGuardHome.yaml
/data/
/build/
/client/node_modules/
/coredns
/Corefile
/dnsfilter.txt
/querylog.json
/querylog.json.1
/scripts/translations/node_modules

View file

@ -19,9 +19,12 @@ client/node_modules: client/package.json client/package-lock.json
$(STATIC): $(JSFILES) client/node_modules
npm --prefix client run build-prod
$(TARGET): $(STATIC) *.go coredns_plugin/*.go dnsfilter/*.go
GOPATH=$(GOPATH) GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/...
GOPATH=$(GOPATH) PATH=$(GOPATH)/bin:$(PATH) packr build -ldflags="-X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" -o $(TARGET)
$(TARGET): $(STATIC) *.go dnsfilter/*.go dnsforward/*.go
go get -d .
GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/...
PATH=$(GOPATH)/bin:$(PATH) packr -z
CGO_ENABLED=0 go build -ldflags="-s -w -X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)"
PATH=$(GOPATH)/bin:$(PATH) packr clean
clean:
$(MAKE) cleanfast

View file

@ -90,7 +90,7 @@ Now open the browser and navigate to http://localhost:3000/ to control your AdGu
You can run AdGuard Home without superuser privileges, but you need to instruct it to use a different port rather than 53. You can do that by editing `AdGuardHome.yaml` and finding these two lines:
```yaml
coredns:
dns:
port: 53
```
@ -104,25 +104,32 @@ Upon the first execution, a file named `AdGuardHome.yaml` will be created, with
Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possible parameters that you can configure are listed below:
* `bind_host` — Web interface IP address to listen on
* `bind_port` — Web interface IP port to listen on
* `auth_name` — Web interface optional authorization username
* `auth_pass` — Web interface optional authorization password
* `coredns` — CoreDNS configuration section
* `port` — DNS server port to listen on
* `filtering_enabled` — Filtering of DNS requests based on filter lists
* `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing
* `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible
* `parental_enabled` — Parental control-based DNS requests filtering
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
* `upstream_dns` — List of upstream DNS servers
* `bind_host` — Web interface IP address to listen on.
* `bind_port` — Web interface IP port to listen on.
* `auth_name` — Web interface optional authorization username.
* `auth_pass` — Web interface optional authorization password.
* `dns` — DNS configuration section.
* `port` — DNS server port to listen on.
* `protection_enabled` — Whether any kind of filtering and protection should be done, when off it works as a plain dns forwarder.
* `filtering_enabled` — Filtering of DNS requests based on filter lists.
* `blocked_response_ttl` — For how many seconds the clients should cache a filtered response. Low values are useful on LAN if you change filters very often, high values are useful to increase performance and save traffic.
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistical purposes).
* `ratelimit` — DDoS protection, specifies in how many packets per second a client should receive. Anything above that is silently dropped. To disable set 0, default is 20. Safe to disable if DNS server is not available from internet.
* `ratelimit_whitelist` — If you want exclude some IP addresses from ratelimiting but keep ratelimiting on for others, put them here.
* `refuse_any` — Another DDoS protection mechanism. Requests of type ANY are rarely needed, so refusing to serve them mitigates against attackers trying to use your DNS as a reflection. Safe to disable if DNS server is not available from internet.
* `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname.
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 if enabled.
* `parental_enabled` — Parental control-based DNS requests filtering.
* `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible.
* `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing.
* `upstream_dns` — List of upstream DNS servers.
* `filters` — List of filters, each filter has the following values:
* `ID` - filter ID (must be unique)
* `url` — URL pointing to the filter contents (filtering rules)
* `enabled` — Current filter's status (enabled/disabled)
* `user_rules` — User-specified filtering rules
* `enabled` — Current filter's status (enabled/disabled).
* `url` — URL pointing to the filter contents (filtering rules).
* `name` — Name of the filter. If it's an adguard syntax filter it will get updated automatically, otherwise it stays unchanged.
* `last_updated` — Time when the filter was last updated from server.
* `ID` - filter ID (must be unique).
* `user_rules` — User-specified filtering rules.
Removing an entry from settings file will reset it to the default value. Deleting the file will reset all settings to the default values.
@ -151,7 +158,15 @@ cd AdGuardHome
make
```
## How to update translations
## Contributing
You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls
### How to update translations
If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations
Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384
Before updating translations you need to install dependencies:
```
@ -181,14 +196,6 @@ node upload.js
node download.js
```
## Contributing
You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls
If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations
Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384
## Reporting issues
If you run into any problem or have a suggestion, head to [this page](https://github.com/AdguardTeam/AdGuardHome/issues) and click on the `New issue` button.
@ -198,7 +205,6 @@ If you run into any problem or have a suggestion, head to [this page](https://gi
This software wouldn't have been possible without:
* [Go](https://golang.org/dl/) and it's libraries:
* [CoreDNS](https://coredns.io)
* [packr](https://github.com/gobuffalo/packr)
* [gcache](https://github.com/bluele/gcache)
* [miekg's dns](https://github.com/miekg/dns)
@ -209,4 +215,6 @@ This software wouldn't have been possible without:
* And many more node.js packages.
* [whotracks.me data](https://github.com/cliqz-oss/whotracks.me)
You might have seen that [CoreDNS](https://coredns.io) was mentioned here before — we've stopped using it in AdGuardHome. While we still use it on our servers for [AdGuard DNS](https://adguard.com/adguard-dns/overview.html) service, it seemed like an overkill for Home as it impeded with Home features that we plan to implement.
For a full list of all node.js packages in use, please take a look at [client/package.json](https://github.com/AdguardTeam/AdGuardHome/blob/master/client/package.json) file.

22
app.go
View file

@ -7,8 +7,10 @@ import (
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"strconv"
"syscall"
"time"
"github.com/gobuffalo/packr"
@ -149,7 +151,7 @@ func main() {
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Contents) == 0 {
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
@ -164,10 +166,13 @@ func main() {
}
}()
// Eat all args so that coredns can start happily
if len(os.Args) > 1 {
os.Args = os.Args[:1]
}
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
@ -192,6 +197,13 @@ func main() {
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()

224
config.go
View file

@ -1,29 +1,22 @@
package main
import (
"bytes"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
"sync"
"text/template"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"gopkg.in/yaml.v2"
)
const (
currentSchemaVersion = 1 // used for upgrading from old configs to new config
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
userFilterID = 0 // special filter ID, always 0
)
// Just a counter that we use for incrementing the filter ID
var nextFilterID int64 = time.Now().Unix()
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
@ -35,9 +28,9 @@ type configuration struct {
AuthName string `yaml:"auth_name"`
AuthPass string `yaml:"auth_pass"`
Language string `yaml:"language"` // two-letter ISO 639-1 language code
CoreDNS coreDNSConfig `yaml:"coredns"`
DNS dnsConfig `yaml:"dns"`
Filters []filter `yaml:"filters"`
UserRules []string `yaml:"user_rules,omitempty"`
UserRules []string `yaml:"user_rules"`
sync.RWMutex `yaml:"-"`
@ -45,40 +38,14 @@ type configuration struct {
}
// field ordering is important -- yaml fields will mirror ordering from here
type coreDNSConfig struct {
binaryFile string
coreFile string
Filters []filter `yaml:"-"`
type dnsConfig struct {
Port int `yaml:"port"`
ProtectionEnabled bool `yaml:"protection_enabled"`
FilteringEnabled bool `yaml:"filtering_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
ParentalEnabled bool `yaml:"parental_enabled"`
ParentalSensitivity int `yaml:"parental_sensitivity"`
BlockedResponseTTL int `yaml:"blocked_response_ttl"`
QueryLogEnabled bool `yaml:"querylog_enabled"`
Ratelimit int `yaml:"ratelimit"`
RefuseAny bool `yaml:"refuse_any"`
Pprof string `yaml:"-"`
Cache string `yaml:"-"`
Prometheus string `yaml:"-"`
BootstrapDNS string `yaml:"bootstrap_dns"`
dnsforward.FilteringConfig `yaml:",inline"`
UpstreamDNS []string `yaml:"upstream_dns"`
}
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name" yaml:"name"`
RulesCount int `json:"rulesCount" yaml:"-"`
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
Contents []byte `json:"-" yaml:"-"` // not in yaml or json
}
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
// initialize to default values, will be changed later when reading config or parsing command line
@ -86,47 +53,26 @@ var config = configuration{
ourConfigFilename: "AdGuardHome.yaml",
BindPort: 3000,
BindHost: "127.0.0.1",
CoreDNS: coreDNSConfig{
DNS: dnsConfig{
Port: 53,
binaryFile: "coredns", // only filename, no path
coreFile: "Corefile", // only filename, no path
ProtectionEnabled: true,
FilteringEnabled: true,
SafeBrowsingEnabled: false,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of dnsfilter features
FilteringEnabled: true, // whether or not use filter lists
BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true,
Ratelimit: 20,
RefuseAny: true,
BootstrapDNS: "8.8.8.8:53",
},
UpstreamDNS: defaultDNS,
Cache: "cache",
Prometheus: "prometheus :9153",
},
Filters: []filter{
{ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
{ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
{ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
{ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
{Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
{Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
{Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
},
}
// Creates a helper object for working with the user rules
func userFilter() filter {
// TODO: This should be calculated when UserRules are set
var contents []byte
for _, rule := range config.UserRules {
contents = append(contents, []byte(rule)...)
contents = append(contents, '\n')
}
userFilter := filter{
// User filter always has constant ID=0
ID: userFilterID,
Contents: contents,
Enabled: true,
}
return userFilter
SchemaVersion: currentSchemaVersion,
}
// Loads configuration from the YAML file
@ -150,20 +96,7 @@ func parseConfig() error {
}
// Deduplicate filters
{
i := 0 // output index, used for deletion later
urls := map[string]bool{}
for _, filter := range config.Filters {
if _, ok := urls[filter.URL]; !ok {
// we didn't see it before, keep it
urls[filter.URL] = true // remember the URL
config.Filters[i] = filter
i++
}
}
// all entries we want to keep are at front, delete the rest
config.Filters = config.Filters[:i]
}
deduplicateFilters()
updateUniqueFilterID(config.Filters)
@ -187,6 +120,16 @@ func (c *configuration) write() error {
return err
}
return nil
}
func writeAllConfigs() error {
err := config.write()
if err != nil {
log.Printf("Couldn't write config: %s", err)
return err
}
userFilter := userFilter()
err = userFilter.save()
if err != nil {
@ -196,112 +139,3 @@ func (c *configuration) write() error {
return nil
}
// --------------
// coredns config
// --------------
func writeCoreDNSConfig() error {
coreFile := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile)
log.Printf("Writing DNS config: %s", coreFile)
configText, err := generateCoreDNSConfigText()
if err != nil {
log.Printf("Couldn't generate DNS config: %s", err)
return err
}
err = safeWriteFile(coreFile, []byte(configText))
if err != nil {
log.Printf("Couldn't save DNS config: %s", err)
return err
}
return nil
}
func writeAllConfigs() error {
err := config.write()
if err != nil {
log.Printf("Couldn't write our config: %s", err)
return err
}
err = writeCoreDNSConfig()
if err != nil {
log.Printf("Couldn't write DNS config: %s", err)
return err
}
return nil
}
const coreDNSConfigTemplate = `.:{{.Port}} {
{{if .ProtectionEnabled}}dnsfilter {
{{if .SafeBrowsingEnabled}}safebrowsing{{end}}
{{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}}
{{if .SafeSearchEnabled}}safesearch{{end}}
{{if .QueryLogEnabled}}querylog{{end}}
blocked_ttl {{.BlockedResponseTTL}}
{{if .FilteringEnabled}}{{range .Filters}}{{if and .Enabled .Contents}}
filter {{.ID}} "{{.Path}}"
{{end}}{{end}}{{end}}
}{{end}}
{{.Pprof}}
{{if .RefuseAny}}refuseany{{end}}
{{if gt .Ratelimit 0}}ratelimit {{.Ratelimit}}{{end}}
hosts {
fallthrough
}
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
{{.Cache}}
{{.Prometheus}}
}
`
var removeEmptyLines = regexp.MustCompile("([\t ]*\n)+")
// generate CoreDNS config text
func generateCoreDNSConfigText() (string, error) {
t, err := template.New("config").Parse(coreDNSConfigTemplate)
if err != nil {
log.Printf("Couldn't generate DNS config: %s", err)
return "", err
}
var configBytes bytes.Buffer
temporaryConfig := config.CoreDNS
// generate temporary filter list, needed to put userfilter in coredns config
filters := []filter{}
// first of all, append the user filter
userFilter := userFilter()
filters = append(filters, userFilter)
// then go through other filters
filters = append(filters, config.Filters...)
temporaryConfig.Filters = filters
// run the template
err = t.Execute(&configBytes, &temporaryConfig)
if err != nil {
log.Printf("Couldn't generate DNS config: %s", err)
return "", err
}
configText := configBytes.String()
// remove empty lines from generated config
configText = removeEmptyLines.ReplaceAllString(configText, "\n")
return configText, nil
}
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
}
}
}
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID += 1
return value
}

View file

@ -1,29 +1,25 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/upstream"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/miekg/dns"
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
"gopkg.in/asaskevich/govalidator.v4"
)
const updatePeriod = time.Minute * 30
var filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
// cached version.json to avoid hammering github.io for each page reload
var versionCheckJSON []byte
var versionCheckLastTime time.Time
@ -36,24 +32,20 @@ var client = &http.Client{
}
// -------------------
// coredns run control
// dns run control
// -------------------
func tellCoreDNSToReload() {
corednsplugin.Reload <- true
}
func writeAllConfigsAndReloadCoreDNS() error {
func writeAllConfigsAndReloadDNS() error {
err := writeAllConfigs()
if err != nil {
log.Printf("Couldn't write all configs: %s", err)
return err
}
tellCoreDNSToReload()
reconfigureDNSServer()
return nil
}
func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
err := writeAllConfigsAndReloadCoreDNS()
err := writeAllConfigsAndReloadDNS()
if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext)
@ -75,12 +67,12 @@ func returnOK(w http.ResponseWriter, r *http.Request) {
func handleStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"dns_address": config.BindHost,
"dns_port": config.CoreDNS.Port,
"protection_enabled": config.CoreDNS.ProtectionEnabled,
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
"dns_port": config.DNS.Port,
"protection_enabled": config.DNS.ProtectionEnabled,
"querylog_enabled": config.DNS.QueryLogEnabled,
"running": isRunning(),
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
"upstream_dns": config.CoreDNS.UpstreamDNS,
"bootstrap_dns": config.DNS.BootstrapDNS,
"upstream_dns": config.DNS.UpstreamDNS,
"version": VersionString,
"language": config.Language,
}
@ -103,12 +95,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
}
func handleProtectionEnable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.ProtectionEnabled = true
config.DNS.ProtectionEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleProtectionDisable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.ProtectionEnabled = false
config.DNS.ProtectionEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
@ -116,12 +108,12 @@ func handleProtectionDisable(w http.ResponseWriter, r *http.Request) {
// stats
// -----
func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.QueryLogEnabled = true
config.DNS.QueryLogEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.QueryLogEnabled = false
config.DNS.QueryLogEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
@ -143,9 +135,9 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
config.CoreDNS.UpstreamDNS = defaultDNS
config.DNS.UpstreamDNS = defaultDNS
} else {
config.CoreDNS.UpstreamDNS = hosts
config.DNS.UpstreamDNS = hosts
}
err = writeAllConfigs()
@ -155,7 +147,7 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
http.Error(w, errorText, http.StatusInternalServerError)
return
}
tellCoreDNSToReload()
reconfigureDNSServer()
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
@ -211,23 +203,32 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
func checkDNS(input string) error {
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
log.Printf("Checking if DNS %s works...", input)
u, err := dnsforward.AddressToUpstream(input, "")
if err != nil {
return err
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err)
}
defer u.Close()
alive, err := upstream.IsAlive(u)
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := u.Exchange(&req)
if err != nil {
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
}
if !alive {
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
if len(reply.Answer) != 1 {
return fmt.Errorf("DNS server %s returned wrong answer", input)
}
if t, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
}
}
log.Printf("DNS %s works OK", input)
return nil
}
@ -242,7 +243,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp, err := client.Get(versionCheckURL)
if err != nil {
errortext := fmt.Sprintf("Couldn't get querylog from coredns: %T %s\n", err, err)
errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Println(errortext)
http.Error(w, errortext, http.StatusBadGateway)
return
@ -254,7 +255,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// read the body entirely
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
errortext := fmt.Sprintf("Couldn't read response body: %s", err)
errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Println(errortext)
http.Error(w, errortext, http.StatusBadGateway)
return
@ -277,18 +278,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// ---------
func handleFilteringEnable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.FilteringEnabled = true
config.DNS.FilteringEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleFilteringDisable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.FilteringEnabled = false
config.DNS.FilteringEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.CoreDNS.FilteringEnabled,
"enabled": config.DNS.FilteringEnabled,
}
config.RLock()
@ -376,7 +377,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
return
}
// URL is deemed valid, append it to filters, update config, write new filter file and tell coredns to reload it
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
// TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary
config.Filters = append(config.Filters, filter)
err = writeAllConfigs()
if err != nil {
@ -386,7 +388,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
return
}
tellCoreDNSToReload()
reconfigureDNSServer()
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount)
if err != nil {
@ -531,199 +533,23 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "OK %d filters updated\n", updated)
}
// Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) {
refreshFiltersIfNeccessary(false)
}
}
// Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value
func refreshFiltersIfNeccessary(force bool) int {
config.Lock()
// fetch URLs
updateCount := 0
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy
if filter.ID == 0 { // protect against users modifying the yaml and removing the ID
filter.ID = assignUniqueFilterID()
}
updated, err := filter.update(force)
if err != nil {
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
continue
}
if updated {
// Saving it to the filters dir now
err = filter.save()
if err != nil {
log.Printf("Failed to save the updated filter %d: %s", filter.ID, err)
continue
}
updateCount++
}
}
config.Unlock()
if updateCount > 0 {
tellCoreDNSToReload()
}
return updateCount
}
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func parseFilterContents(contents []byte) (int, string) {
lines := strings.Split(string(contents), "\n")
rulesCount := 0
name := ""
seenTitle := false
// Count lines in the filter
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) > 0 && line[0] == '!' {
if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
}
} else if len(line) != 0 {
rulesCount++
}
}
return rulesCount, name
}
// Checks for filters updates
// If "force" is true -- does not check the filter's LastUpdated field
// Call "save" to persist the filter contents
func (filter *filter) update(force bool) (bool, error) {
if filter.ID == 0 { // protect against users deleting the ID
filter.ID = assignUniqueFilterID()
}
if !filter.Enabled {
return false, nil
}
if !force && time.Since(filter.LastUpdated) <= updatePeriod {
return false, nil
}
log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL)
// use the same update period for failed filter downloads to avoid flooding with requests
filter.LastUpdated = time.Now()
resp, err := client.Get(filter.URL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
return false, err
}
if resp.StatusCode != 200 {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
contentType := strings.ToLower(resp.Header.Get("content-type"))
if !strings.HasPrefix(contentType, "text/plain") {
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
return false, fmt.Errorf("non-text response %s", contentType)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
return false, err
}
// Extract filter name and count number of rules
rulesCount, filterName := parseFilterContents(body)
if filterName != "" {
filter.Name = filterName
}
// Check if the filter has been really changed
if bytes.Equal(filter.Contents, body) {
log.Printf("The filter %d text has not changed", filter.ID)
return false, nil
}
log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
filter.RulesCount = rulesCount
filter.Contents = body
return true, nil
}
// saves filter contents to the file in dataDir
func (filter *filter) save() error {
filterFilePath := filter.Path()
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
return safeWriteFile(filterFilePath, filter.Contents)
}
// loads filter contents from the file in dataDir
func (filter *filter) load() error {
if !filter.Enabled {
// No need to load a filter that is not enabled
return nil
}
filterFilePath := filter.Path()
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
// do nothing, file doesn't exist
return err
}
filterFileContents, err := ioutil.ReadFile(filterFilePath)
if err != nil {
return err
}
log.Printf("Filter %d length is %d", filter.ID, len(filterFileContents))
filter.Contents = filterFileContents
// Now extract the rules count
rulesCount, _ := parseFilterContents(filter.Contents)
filter.RulesCount = rulesCount
return nil
}
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
// ------------
// safebrowsing
// ------------
func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.SafeBrowsingEnabled = true
config.DNS.SafeBrowsingEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.SafeBrowsingEnabled = false
config.DNS.SafeBrowsingEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.CoreDNS.SafeBrowsingEnabled,
"enabled": config.DNS.SafeBrowsingEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
@ -786,22 +612,22 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
config.CoreDNS.ParentalSensitivity = i
config.CoreDNS.ParentalEnabled = true
config.DNS.ParentalSensitivity = i
config.DNS.ParentalEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleParentalDisable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.ParentalEnabled = false
config.DNS.ParentalEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.CoreDNS.ParentalEnabled,
"enabled": config.DNS.ParentalEnabled,
}
if config.CoreDNS.ParentalEnabled {
data["sensitivity"] = config.CoreDNS.ParentalSensitivity
if config.DNS.ParentalEnabled {
data["sensitivity"] = config.DNS.ParentalSensitivity
}
jsonVal, err := json.Marshal(data)
if err != nil {
@ -826,18 +652,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
// ------------
func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.SafeSearchEnabled = true
config.DNS.SafeSearchEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
config.CoreDNS.SafeSearchEnabled = false
config.DNS.SafeSearchEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.CoreDNS.SafeSearchEnabled,
"enabled": config.DNS.SafeSearchEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
@ -861,17 +687,17 @@ func registerControlHandlers() {
http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus)))
http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable)))
http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable)))
http.HandleFunc("/control/querylog", optionalAuth(ensureGET(corednsplugin.HandleQueryLog)))
http.HandleFunc("/control/querylog", optionalAuth(ensureGET(dnsforward.HandleQueryLog)))
http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable)))
http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable)))
http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS)))
http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS)))
http.HandleFunc("/control/i18n/change_language", optionalAuth(ensurePOST(handleI18nChangeLanguage)))
http.HandleFunc("/control/i18n/current_language", optionalAuth(ensureGET(handleI18nCurrentLanguage)))
http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(corednsplugin.HandleStatsTop)))
http.HandleFunc("/control/stats", optionalAuth(ensureGET(corednsplugin.HandleStats)))
http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(corednsplugin.HandleStatsHistory)))
http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(corednsplugin.HandleStatsReset)))
http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(dnsforward.HandleStatsTop)))
http.HandleFunc("/control/stats", optionalAuth(ensureGET(dnsforward.HandleStats)))
http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(dnsforward.HandleStatsHistory)))
http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(dnsforward.HandleStatsReset)))
http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON))
http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable)))
http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable)))

View file

@ -1,132 +0,0 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
"sync" // Include all plugins.
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/ratelimit"
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/refuseany"
_ "github.com/AdguardTeam/AdGuardHome/upstream"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/coremain"
_ "github.com/coredns/coredns/plugin/auto"
_ "github.com/coredns/coredns/plugin/autopath"
_ "github.com/coredns/coredns/plugin/bind"
_ "github.com/coredns/coredns/plugin/cache"
_ "github.com/coredns/coredns/plugin/chaos"
_ "github.com/coredns/coredns/plugin/debug"
_ "github.com/coredns/coredns/plugin/dnssec"
_ "github.com/coredns/coredns/plugin/dnstap"
_ "github.com/coredns/coredns/plugin/erratic"
_ "github.com/coredns/coredns/plugin/errors"
_ "github.com/coredns/coredns/plugin/file"
_ "github.com/coredns/coredns/plugin/forward"
_ "github.com/coredns/coredns/plugin/health"
_ "github.com/coredns/coredns/plugin/hosts"
_ "github.com/coredns/coredns/plugin/loadbalance"
_ "github.com/coredns/coredns/plugin/log"
_ "github.com/coredns/coredns/plugin/loop"
_ "github.com/coredns/coredns/plugin/metadata"
_ "github.com/coredns/coredns/plugin/metrics"
_ "github.com/coredns/coredns/plugin/nsid"
_ "github.com/coredns/coredns/plugin/pprof"
_ "github.com/coredns/coredns/plugin/proxy"
_ "github.com/coredns/coredns/plugin/reload"
_ "github.com/coredns/coredns/plugin/rewrite"
_ "github.com/coredns/coredns/plugin/root"
_ "github.com/coredns/coredns/plugin/secondary"
_ "github.com/coredns/coredns/plugin/template"
_ "github.com/coredns/coredns/plugin/tls"
_ "github.com/coredns/coredns/plugin/whoami"
_ "github.com/mholt/caddy/onevent"
)
// Directives are registered in the order they should be
// executed.
//
// Ordering is VERY important. Every plugin will
// feel the effects of all other plugin below
// (after) them during a request, but they must not
// care what plugin above them are doing.
var directives = []string{
"metadata",
"tls",
"reload",
"nsid",
"root",
"bind",
"debug",
"health",
"pprof",
"prometheus",
"errors",
"log",
"refuseany",
"ratelimit",
"dnsfilter",
"dnstap",
"chaos",
"loadbalance",
"cache",
"rewrite",
"dnssec",
"autopath",
"template",
"hosts",
"file",
"auto",
"secondary",
"loop",
"forward",
"proxy",
"upstream",
"erratic",
"whoami",
"on",
}
func init() {
dnsserver.Directives = directives
}
var (
isCoreDNSRunningLock sync.Mutex
isCoreDNSRunning = false
)
func isRunning() bool {
isCoreDNSRunningLock.Lock()
value := isCoreDNSRunning
isCoreDNSRunningLock.Unlock()
return value
}
func startDNSServer() error {
isCoreDNSRunningLock.Lock()
if isCoreDNSRunning {
isCoreDNSRunningLock.Unlock()
return fmt.Errorf("Unable to start coreDNS: Already running")
}
isCoreDNSRunning = true
isCoreDNSRunningLock.Unlock()
configpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile)
os.Args = os.Args[:1]
os.Args = append(os.Args, "-conf")
os.Args = append(os.Args, configpath)
err := writeCoreDNSConfig()
if err != nil {
errortext := fmt.Errorf("Unable to write coredns config: %s", err)
log.Println(errortext)
return errortext
}
go coremain.Run()
return nil
}

View file

@ -1,557 +0,0 @@
package dnsfilter
import (
"bufio"
"errors"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/mholt/caddy"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
)
var defaultSOA = &dns.SOA{
// values copied from verisign's nonexistent .com domain
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
}
func init() {
caddy.RegisterPlugin("dnsfilter", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
type plugFilter struct {
ID int64
Path string
}
type plugSettings struct {
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600
Filters []plugFilter
}
type plug struct {
d *dnsfilter.Dnsfilter
Next plugin.Handler
upstream upstream.Upstream
settings plugSettings
sync.RWMutex
}
var defaultPluginSettings = plugSettings{
SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com",
ParentalBlockHost: "family.block.dns.adguard.com",
BlockedTTL: 3600, // in seconds
Filters: make([]plugFilter, 0),
}
//
// coredns handling functions
//
func setupPlugin(c *caddy.Controller) (*plug, error) {
// create new Plugin and copy default values
p := &plug{
settings: defaultPluginSettings,
d: dnsfilter.New(),
}
log.Println("Initializing the CoreDNS plugin")
for c.Next() {
for c.NextBlock() {
blockValue := c.Val()
switch blockValue {
case "safebrowsing":
log.Println("Browsing security service is enabled")
p.d.EnableSafeBrowsing()
if c.NextArg() {
if len(c.Val()) == 0 {
return nil, c.ArgErr()
}
p.d.SetSafeBrowsingServer(c.Val())
}
case "safesearch":
log.Println("Safe search is enabled")
p.d.EnableSafeSearch()
case "parental":
if !c.NextArg() {
return nil, c.ArgErr()
}
sensitivity, err := strconv.Atoi(c.Val())
if err != nil {
return nil, c.ArgErr()
}
log.Println("Parental control is enabled")
err = p.d.EnableParental(sensitivity)
if err != nil {
return nil, c.ArgErr()
}
if c.NextArg() {
if len(c.Val()) == 0 {
return nil, c.ArgErr()
}
p.settings.ParentalBlockHost = c.Val()
}
case "blocked_ttl":
if !c.NextArg() {
return nil, c.ArgErr()
}
blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32)
if err != nil {
return nil, c.ArgErr()
}
log.Printf("Blocked request TTL is %d", blockedTtl)
p.settings.BlockedTTL = uint32(blockedTtl)
case "querylog":
log.Println("Query log is enabled")
p.settings.QueryLogEnabled = true
case "filter":
if !c.NextArg() {
return nil, c.ArgErr()
}
filterId, err := strconv.ParseInt(c.Val(), 10, 64)
if err != nil {
return nil, c.ArgErr()
}
if !c.NextArg() {
return nil, c.ArgErr()
}
filterPath := c.Val()
// Initialize filter and add it to the list
p.settings.Filters = append(p.settings.Filters, plugFilter{
ID: filterId,
Path: filterPath,
})
}
}
}
for _, filter := range p.settings.Filters {
log.Printf("Loading rules from %s", filter.Path)
file, err := os.Open(filter.Path)
if err != nil {
return nil, err
}
defer file.Close()
count := 0
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
err = p.d.AddRule(text, filter.ID)
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
continue
}
if err != nil {
log.Printf("Cannot add rule %s: %s", text, err)
// Just ignore invalid rules
continue
}
count++
}
log.Printf("Added %d rules from filter ID=%d", count, filter.ID)
if err = scanner.Err(); err != nil {
return nil, err
}
}
log.Printf("Loading stats from querylog")
err := fillStatsFromQueryLog()
if err != nil {
log.Printf("Failed to load stats from querylog: %s", err)
return nil, err
}
if p.settings.QueryLogEnabled {
onceQueryLog.Do(func() {
go periodicQueryLogRotate()
go periodicHourlyTopRotate()
go statsRotator()
})
}
onceHook.Do(func() {
caddy.RegisterEventHook("dnsfilter-reload", hook)
})
p.upstream, err = upstream.New(nil)
if err != nil {
return nil, err
}
return p, nil
}
func setup(c *caddy.Controller) error {
p, err := setupPlugin(c)
if err != nil {
return err
}
config := dnsserver.GetConfig(c)
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
p.Next = next
return p
})
c.OnStartup(func() error {
m := dnsserver.GetConfig(c).Handler("prometheus")
if m == nil {
return nil
}
if x, ok := m.(*metrics.Metrics); ok {
x.MustRegister(requests)
x.MustRegister(filtered)
x.MustRegister(filteredLists)
x.MustRegister(filteredSafebrowsing)
x.MustRegister(filteredParental)
x.MustRegister(whitelisted)
x.MustRegister(safesearch)
x.MustRegister(errorsTotal)
x.MustRegister(elapsedTime)
x.MustRegister(p)
}
return nil
})
c.OnShutdown(p.onShutdown)
c.OnFinalShutdown(p.onFinalShutdown)
return nil
}
func (p *plug) onShutdown() error {
p.Lock()
p.d.Destroy()
p.d = nil
p.Unlock()
return nil
}
func (p *plug) onFinalShutdown() error {
logBufferLock.Lock()
err := flushToFile(logBuffer)
if err != nil {
log.Printf("failed to flush to file: %s", err)
return err
}
logBufferLock.Unlock()
return nil
}
type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType)
func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
realch, ok := ch.(chan<- *prometheus.Desc)
if !ok {
log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n")
return
}
realch <- prometheus.NewDesc(name, text, nil, nil)
}
func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
realch, ok := ch.(chan<- prometheus.Metric)
if !ok {
log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n")
return
}
desc := prometheus.NewDesc(name, text, nil, nil)
realch <- prometheus.MustNewConstMetric(desc, valueType, value)
}
func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) {
doFunc(ch, name, text, value, valueType)
}
func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) {
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_requests", name), fmt.Sprintf("Number of %s HTTP requests that were sent", name), float64(lookupstats.Requests), prometheus.CounterValue)
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_cachehits", name), fmt.Sprintf("Number of %s lookups that didn't need HTTP requests", name), float64(lookupstats.CacheHits), prometheus.CounterValue)
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending", name), fmt.Sprintf("Number of currently pending %s HTTP requests", name), float64(lookupstats.Pending), prometheus.GaugeValue)
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue)
}
func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
p.RLock()
stats := p.d.GetStats()
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
doStatsLookup(ch, doFunc, "parental", &stats.Parental)
p.RUnlock()
}
// Describe is called by prometheus handler to know stat types
func (p *plug) Describe(ch chan<- *prometheus.Desc) {
p.doStats(ch, doDesc)
}
// Collect is called by prometheus handler to collect stats
func (p *plug) Collect(ch chan<- prometheus.Metric) {
p.doStats(ch, doMetric)
}
func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) {
// check if it's a domain name or IP address
addr := net.ParseIP(val)
var records []dns.RR
// log.Println("Will give", val, "instead of", host) // debug logging
if addr != nil {
// this is an IP address, return it
result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val))
if err != nil {
log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
}
records = append(records, result)
} else {
// this is a domain name, need to look it up
req := new(dns.Msg)
req.SetQuestion(dns.Fqdn(val), question.Qtype)
req.RecursionDesired = true
reqstate := request.Request{W: w, Req: req, Context: ctx}
result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType())
if err != nil {
log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
}
if result != nil {
for _, answer := range result.Answer {
answer.Header().Name = question.Name
}
records = result.Answer
}
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
m.Answer = append(m.Answer, records...)
state := request.Request{W: w, Req: r, Context: ctx}
state.SizeAndDo(m)
err := state.W.WriteMsg(m)
if err != nil {
log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
}
return dns.RcodeSuccess, nil
}
// generate SOA record that makes DNS clients cache NXdomain results
// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant
func (p *plug) genSOA(r *dns.Msg) []dns.RR {
zone := r.Question[0].Name
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET}
Mbox := "hostmaster."
if zone[0] != '.' {
Mbox += zone
}
Ns := "fake-for-negative-caching.adguard.com."
soa := *defaultSOA
soa.Hdr = header
soa.Mbox = Mbox
soa.Ns = Ns
soa.Serial = 100500 // faster than uint32(time.Now().Unix())
return []dns.RR{&soa}
}
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r, Context: ctx}
m := new(dns.Msg)
m.SetRcode(state.Req, dns.RcodeNameError)
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
m.Ns = p.genSOA(r)
state.SizeAndDo(m)
err := state.W.WriteMsg(m)
if err != nil {
log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, err
}
return dns.RcodeNameError, nil
}
func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) {
if len(r.Question) != 1 {
// google DNS, bind and others do the same
return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question")
}
for _, question := range r.Question {
host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
// is it a safesearch domain?
p.RLock()
if val, ok := p.d.SafeSearchDomain(host); ok {
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil {
p.RUnlock()
return rcode, dnsfilter.Result{}, err
}
p.RUnlock()
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
}
p.RUnlock()
// needs to be filtered instead
p.RLock()
result, err := p.d.CheckHost(host)
if err != nil {
log.Printf("plugin/dnsfilter: %s\n", err)
p.RUnlock()
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
}
p.RUnlock()
if result.IsFiltered {
switch result.Reason {
case dnsfilter.FilteredSafeBrowsing:
// return cname safebrowsing.block.dns.adguard.com
val := p.settings.SafeBrowsingBlockHost
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil {
return rcode, dnsfilter.Result{}, err
}
return rcode, result, err
case dnsfilter.FilteredParental:
// return cname family.block.dns.adguard.com
val := p.settings.ParentalBlockHost
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
if err != nil {
return rcode, dnsfilter.Result{}, err
}
return rcode, result, err
case dnsfilter.FilteredBlackList:
if result.Ip == nil {
// return NXDomain
rcode, err := p.writeNXdomain(ctx, w, r)
if err != nil {
return rcode, dnsfilter.Result{}, err
}
return rcode, result, err
} else {
// This is a hosts-syntax rule
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, result.Ip.String(), question)
if err != nil {
return rcode, dnsfilter.Result{}, err
}
return rcode, result, err
}
case dnsfilter.FilteredInvalid:
// return NXdomain
rcode, err := p.writeNXdomain(ctx, w, r)
if err != nil {
return rcode, dnsfilter.Result{}, err
}
return rcode, result, err
default:
log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering host \"%s\": %v, %+v", host, result.Reason, result)
}
} else {
switch result.Reason {
case dnsfilter.NotFilteredWhiteList:
rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
return rcode, result, err
case dnsfilter.NotFilteredNotFound:
// do nothing, pass through to lower code
default:
log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering host \"%s\": %v, %+v", host, result.Reason, result)
}
}
}
rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
return rcode, dnsfilter.Result{}, err
}
// ServeDNS handles the DNS request and refuses if it's in filterlists
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
start := time.Now()
requests.Inc()
state := request.Request{W: w, Req: r}
ip := state.IP()
// capture the written answer
rrw := dnstest.NewRecorder(w)
rcode, result, err := p.serveDNSInternal(ctx, rrw, r)
if rcode > 0 {
// actually send the answer if we have one
answer := new(dns.Msg)
answer.SetRcode(r, rcode)
state.SizeAndDo(answer)
err = w.WriteMsg(answer)
if err != nil {
return dns.RcodeServerFailure, err
}
}
// increment counters
switch {
case err != nil:
errorsTotal.Inc()
case result.Reason == dnsfilter.FilteredBlackList:
filtered.Inc()
filteredLists.Inc()
case result.Reason == dnsfilter.FilteredSafeBrowsing:
filtered.Inc()
filteredSafebrowsing.Inc()
case result.Reason == dnsfilter.FilteredParental:
filtered.Inc()
filteredParental.Inc()
case result.Reason == dnsfilter.FilteredInvalid:
filtered.Inc()
filteredInvalid.Inc()
case result.Reason == dnsfilter.FilteredSafeSearch:
// the request was passsed through but not filtered, don't increment filtered
safesearch.Inc()
case result.Reason == dnsfilter.NotFilteredWhiteList:
whitelisted.Inc()
case result.Reason == dnsfilter.NotFilteredNotFound:
// do nothing
case result.Reason == dnsfilter.NotFilteredError:
text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!"
log.Println(text)
err = errors.New(text)
rcode = dns.RcodeServerFailure
}
// log
elapsed := time.Since(start)
elapsedTime.Observe(elapsed.Seconds())
if p.settings.QueryLogEnabled {
logRequest(r, rrw.Msg, result, time.Since(start), ip)
}
return rcode, err
}
// Name returns name of the plugin as seen in Corefile and plugin.cfg
func (p *plug) Name() string { return "dnsfilter" }
var onceHook sync.Once
var onceQueryLog sync.Once

View file

@ -1,131 +0,0 @@
package dnsfilter
import (
"context"
"fmt"
"io/ioutil"
"net"
"os"
"testing"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
"github.com/mholt/caddy"
"github.com/miekg/dns"
)
func TestSetup(t *testing.T) {
for i, testcase := range []struct {
config string
failing bool
}{
{`dnsfilter`, false},
{`dnsfilter {
filter 0 /dev/nonexistent/abcdef
}`, true},
{`dnsfilter {
filter 0 ../tests/dns.txt
}`, false},
{`dnsfilter {
safebrowsing
filter 0 ../tests/dns.txt
}`, false},
{`dnsfilter {
parental
filter 0 ../tests/dns.txt
}`, true},
} {
c := caddy.NewTestController("dns", testcase.config)
err := setup(c)
if err != nil {
if !testcase.failing {
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
}
continue
}
if testcase.failing {
t.Fatalf("Test #%d expected to fail but it didn't", i)
}
}
}
func TestEtcHostsFilter(t *testing.T) {
text := []byte("127.0.0.1 doubleclick.net\n" + "127.0.0.1 example.org example.net www.example.org www.example.net")
tmpfile, err := ioutil.TempFile("", "")
if err != nil {
t.Fatal(err)
}
if _, err = tmpfile.Write(text); err != nil {
t.Fatal(err)
}
if err = tmpfile.Close(); err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name())
c := caddy.NewTestController("dns", configText)
p, err := setupPlugin(c)
if err != nil {
t.Fatal(err)
}
p.Next = zeroTTLBackend()
ctx := context.TODO()
for _, testcase := range []struct {
host string
filtered bool
}{
{"www.doubleclick.net", false},
{"doubleclick.net", true},
{"www2.example.org", false},
{"www2.example.net", false},
{"test.www.example.org", false},
{"test.www.example.net", false},
{"example.org", true},
{"example.net", true},
{"www.example.org", true},
{"www.example.net", true},
} {
req := new(dns.Msg)
req.SetQuestion(testcase.host+".", dns.TypeA)
resp := test.ResponseWriter{}
rrw := dnstest.NewRecorder(&resp)
rcode, err := p.ServeDNS(ctx, rrw, req)
if err != nil {
t.Fatalf("ServeDNS returned error: %s", err)
}
if rcode != rrw.Rcode {
t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode)
}
A, ok := rrw.Msg.Answer[0].(*dns.A)
if !ok {
t.Fatalf("Host %s expected to have result A", testcase.host)
}
ip := net.IPv4(127, 0, 0, 1)
filtered := ip.Equal(A.A)
if testcase.filtered && testcase.filtered != filtered {
t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host)
}
if !testcase.filtered && testcase.filtered != filtered {
t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host)
}
}
}
func zeroTTLBackend() plugin.Handler {
return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetReply(r)
m.Response, m.RecursionAvailable = true, true
m.Answer = []dns.RR{test.A("example.org. 0 IN A 127.0.0.53")}
w.WriteMsg(m)
return dns.RcodeSuccess, nil
})
}

View file

@ -1,182 +0,0 @@
package ratelimit
import (
"errors"
"log"
"sort"
"strconv"
"time"
// ratelimiting and per-ip buckets
"github.com/beefsack/go-rate"
"github.com/patrickmn/go-cache"
// coredns plugin
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/request"
"github.com/mholt/caddy"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
)
const defaultRatelimit = 30
const defaultResponseSize = 1000
var (
tokenBuckets = cache.New(time.Hour, time.Hour)
)
// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
ip := state.IP()
allow, err := p.allowRequest(ip)
if err != nil {
return 0, err
}
if !allow {
ratelimited.Inc()
return 0, nil
}
// Record response to get status code and size of the reply.
rw := dnstest.NewRecorder(w)
status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
size := rw.Len
if size > defaultResponseSize && state.Proto() == "udp" {
// For large UDP responses we call allowRequest more times
// The exact number of times depends on the response size
for i := 0; i < size/defaultResponseSize; i++ {
p.allowRequest(ip)
}
}
return status, err
}
func (p *plug) allowRequest(ip string) (bool, error) {
if len(p.whitelist) > 0 {
i := sort.SearchStrings(p.whitelist, ip)
if i < len(p.whitelist) && p.whitelist[i] == ip {
return true, nil
}
}
if _, found := tokenBuckets.Get(ip); !found {
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
}
value, found := tokenBuckets.Get(ip)
if !found {
// should not happen since we've just inserted it
text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared"
log.Println(text)
err := errors.New(text)
return true, err
}
rl, ok := value.(*rate.RateLimiter)
if !ok {
text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache"
log.Println(text)
err := errors.New(text)
return true, err
}
allow, _ := rl.Try()
return allow, nil
}
//
// helper functions
//
func init() {
caddy.RegisterPlugin("ratelimit", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
type plug struct {
Next plugin.Handler
// configuration for creating above
ratelimit int // in requests per second per IP
whitelist []string // a list of whitelisted IP addresses
}
func setupPlugin(c *caddy.Controller) (*plug, error) {
p := &plug{ratelimit: defaultRatelimit}
for c.Next() {
args := c.RemainingArgs()
if len(args) > 0 {
ratelimit, err := strconv.Atoi(args[0])
if err != nil {
return nil, c.ArgErr()
}
p.ratelimit = ratelimit
}
for c.NextBlock() {
switch c.Val() {
case "whitelist":
p.whitelist = c.RemainingArgs()
if len(p.whitelist) > 0 {
sort.Strings(p.whitelist)
}
}
}
}
return p, nil
}
func setup(c *caddy.Controller) error {
p, err := setupPlugin(c)
if err != nil {
return err
}
config := dnsserver.GetConfig(c)
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
p.Next = next
return p
})
c.OnStartup(func() error {
m := dnsserver.GetConfig(c).Handler("prometheus")
if m == nil {
return nil
}
if x, ok := m.(*metrics.Metrics); ok {
x.MustRegister(ratelimited)
}
return nil
})
return nil
}
func newDNSCounter(name string, help string) prometheus.Counter {
return prometheus.NewCounter(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "ratelimit",
Name: name,
Help: help,
})
}
var (
ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit")
)
// Name returns name of the plugin as seen in Corefile and plugin.cfg
func (p *plug) Name() string { return "ratelimit" }

View file

@ -1,80 +0,0 @@
package ratelimit
import (
"testing"
"github.com/mholt/caddy"
)
func TestSetup(t *testing.T) {
for i, testcase := range []struct {
config string
failing bool
}{
{`ratelimit`, false},
{`ratelimit 100`, false},
{`ratelimit {
whitelist 127.0.0.1
}`, false},
{`ratelimit 50 {
whitelist 127.0.0.1 176.103.130.130
}`, false},
{`ratelimit test`, true},
} {
c := caddy.NewTestController("dns", testcase.config)
err := setup(c)
if err != nil {
if !testcase.failing {
t.Fatalf("Test #%d expected no errors, but got: %v", i, err)
}
continue
}
if testcase.failing {
t.Fatalf("Test #%d expected to fail but it didn't", i)
}
}
}
func TestRatelimiting(t *testing.T) {
// rate limit is 1 per sec
c := caddy.NewTestController("dns", `ratelimit 1`)
p, err := setupPlugin(c)
if err != nil {
t.Fatal("Failed to initialize the plugin")
}
allowed, err := p.allowRequest("127.0.0.1")
if err != nil || !allowed {
t.Fatal("First request must have been allowed")
}
allowed, err = p.allowRequest("127.0.0.1")
if err != nil || allowed {
t.Fatal("Second request must have been ratelimited")
}
}
func TestWhitelist(t *testing.T) {
// rate limit is 1 per sec
c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`)
p, err := setupPlugin(c)
if err != nil {
t.Fatal("Failed to initialize the plugin")
}
allowed, err := p.allowRequest("127.0.0.1")
if err != nil || !allowed {
t.Fatal("First request must have been allowed")
}
allowed, err = p.allowRequest("127.0.0.1")
if err != nil || !allowed {
t.Fatal("Second request must have been allowed due to whitelist")
}
}

View file

@ -1,91 +0,0 @@
package refuseany
import (
"fmt"
"log"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/request"
"github.com/mholt/caddy"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
)
type plug struct {
Next plugin.Handler
}
// ServeDNS handles the DNS request and refuses if it's an ANY request
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if len(r.Question) != 1 {
// google DNS, bind and others do the same
return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions")
}
q := r.Question[0]
if q.Qtype == dns.TypeANY {
state := request.Request{W: w, Req: r, Context: ctx}
rcode := dns.RcodeNotImplemented
m := new(dns.Msg)
m.SetRcode(r, rcode)
state.SizeAndDo(m)
err := state.W.WriteMsg(m)
if err != nil {
log.Printf("Got error %s\n", err)
return dns.RcodeServerFailure, err
}
return rcode, nil
}
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
}
func init() {
caddy.RegisterPlugin("refuseany", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
p := &plug{}
config := dnsserver.GetConfig(c)
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
p.Next = next
return p
})
c.OnStartup(func() error {
m := dnsserver.GetConfig(c).Handler("prometheus")
if m == nil {
return nil
}
if x, ok := m.(*metrics.Metrics); ok {
x.MustRegister(ratelimited)
}
return nil
})
return nil
}
func newDNSCounter(name string, help string) prometheus.Counter {
return prometheus.NewCounter(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "refuseany",
Name: name,
Help: help,
})
}
var (
ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped")
)
// Name returns name of the plugin as seen in Corefile and plugin.cfg
func (p *plug) Name() string { return "refuseany" }

View file

@ -1,36 +0,0 @@
package dnsfilter
import (
"log"
"github.com/mholt/caddy"
)
var Reload = make(chan bool)
func hook(event caddy.EventName, info interface{}) error {
if event != caddy.InstanceStartupEvent {
return nil
}
// this should be an instance. ok to panic if not
instance := info.(*caddy.Instance)
go func() {
for range Reload {
corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType())
if err != nil {
continue
}
_, err = instance.Restart(corefile)
if err != nil {
log.Printf("Corefile changed but reload failed: %s", err)
continue
}
// hook will be called again from new instance
return
}
}()
return nil
}

89
dns.go Normal file
View file

@ -0,0 +1,89 @@
package main
import (
"fmt"
"log"
"net"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/joomcode/errorx"
)
var dnsServer = dnsforward.Server{}
func isRunning() bool {
return dnsServer.IsRunning()
}
func generateServerConfig() dnsforward.ServerConfig {
filters := []dnsfilter.Filter{}
userFilter := userFilter()
filters = append(filters, dnsfilter.Filter{
ID: userFilter.ID,
Rules: userFilter.Rules,
})
for _, filter := range config.Filters {
filters = append(filters, dnsfilter.Filter{
ID: filter.ID,
Rules: filter.Rules,
})
}
newconfig := dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig,
Filters: filters,
}
for _, u := range config.DNS.UpstreamDNS {
upstream, err := dnsforward.AddressToUpstream(u, config.DNS.BootstrapDNS)
if err != nil {
log.Printf("Couldn't get upstream: %s", err)
// continue, just ignore the upstream
continue
}
newconfig.Upstreams = append(newconfig.Upstreams, upstream)
}
return newconfig
}
func startDNSServer() error {
if isRunning() {
return fmt.Errorf("Unable to start forwarding DNS server: Already running")
}
newconfig := generateServerConfig()
err := dnsServer.Start(&newconfig)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
return nil
}
func reconfigureDNSServer() error {
if !isRunning() {
return fmt.Errorf("Refusing to reconfigure forwarding DNS server: not running")
}
err := dnsServer.Reconfigure(generateServerConfig())
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
return nil
}
func stopDNSServer() error {
if !isRunning() {
return fmt.Errorf("Refusing to stop forwarding DNS server: not running")
}
err := dnsServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
}
return nil
}

View file

@ -38,21 +38,22 @@ var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
// ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
// ErrInvalidParental is returned by EnableParental when sensitivity is not a valid value
var ErrInvalidParental = errors.New("dnsfilter: invalid parental sensitivity, must be either 3, 10, 13 or 17")
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
const enableFastLookup = true // flag for debugging, must be true in production for faster performance
const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance
type config struct {
parentalServer string
parentalSensitivity int // must be either 3, 10, 13 or 17
parentalEnabled bool
safeSearchEnabled bool
safeBrowsingEnabled bool
safeBrowsingServer string
// Config allows you to configure DNS filtering with New() or just change variables directly.
type Config struct {
ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
}
type privateConfig struct {
parentalServer string // access via methods
safeBrowsingServer string // access via methods
}
type rule struct {
@ -110,7 +111,13 @@ type Dnsfilter struct {
client http.Client // handle for http client -- single instance as recommended by docs
transport *http.Transport // handle for http transport used by http client
config config
Config // for direct access by library users, even a = assignment
privateConfig
}
type Filter struct {
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
Rules []string `json:"-" yaml:"-"` // not in yaml or json
}
//go:generate stringer -type=Reason
@ -171,7 +178,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
}
// check safebrowsing if no match
if d.config.safeBrowsingEnabled {
if d.SafeBrowsingEnabled {
result, err = d.checkSafeBrowsing(host)
if err != nil {
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
@ -184,7 +191,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
}
// check parental if no match
if d.config.parentalEnabled {
if d.ParentalEnabled {
result, err = d.checkParental(host)
if err != nil {
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
@ -569,11 +576,11 @@ func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) {
func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
// prevent recursion -- checking the host of safebrowsing server makes no sense
if host == d.config.safeBrowsingServer {
if host == d.safeBrowsingServer {
return Result{}, nil
}
format := func(hashparam string) string {
url := fmt.Sprintf(defaultSafebrowsingURL, d.config.safeBrowsingServer, hashparam)
url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam)
return url
}
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
@ -610,11 +617,11 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
func (d *Dnsfilter) checkParental(host string) (Result, error) {
// prevent recursion -- checking the host of parental safety server makes no sense
if host == d.config.parentalServer {
if host == d.parentalServer {
return Result{}, nil
}
format := func(hashparam string) string {
url := fmt.Sprintf(defaultParentalURL, d.config.parentalServer, hashparam, d.config.parentalSensitivity)
url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity)
return url
}
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
@ -727,6 +734,24 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc
// Adding rule and matching against the rules
//
// AddRules is a convinience function to add an array of filters in one call
func (d *Dnsfilter) AddRules(filters []Filter) error {
for _, f := range filters {
for _, rule := range f.Rules {
err := d.AddRule(rule, f.ID)
if err == ErrAlreadyExists || err == ErrInvalidSyntax {
continue
}
if err != nil {
log.Printf("Cannot add rule %s: %s", rule, err)
// Just ignore invalid rules
continue
}
}
}
return nil
}
// AddRule adds a rule, checking if it is a valid rule first and if it wasn't added already
func (d *Dnsfilter) AddRule(input string, filterListID int64) error {
input = strings.TrimSpace(input)
@ -846,7 +871,7 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
//
// New creates properly initialized DNS Filter that is ready to be used
func New() *Dnsfilter {
func New(c *Config) *Dnsfilter {
d := new(Dnsfilter)
d.storage = make(map[string]bool)
@ -867,8 +892,11 @@ func New() *Dnsfilter {
Transport: d.transport,
Timeout: defaultHTTPTimeout,
}
d.config.safeBrowsingServer = defaultSafebrowsingServer
d.config.parentalServer = defaultParentalServer
d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer
if c != nil {
d.Config = *c
}
return d
}
@ -885,35 +913,21 @@ func (d *Dnsfilter) Destroy() {
// config manipulation helpers
//
// EnableSafeBrowsing turns on checking hostnames in malware/phishing database
func (d *Dnsfilter) EnableSafeBrowsing() {
d.config.safeBrowsingEnabled = true
}
// EnableParental turns on checking hostnames for containing adult content
func (d *Dnsfilter) EnableParental(sensitivity int) error {
// IsParentalSensitivityValid checks if sensitivity is valid value
func IsParentalSensitivityValid(sensitivity int) bool {
switch sensitivity {
case 3, 10, 13, 17:
d.config.parentalSensitivity = sensitivity
d.config.parentalEnabled = true
return nil
default:
return ErrInvalidParental
return true
}
}
// EnableSafeSearch turns on enforcing safesearch in search engines
// only used in coredns plugin and requires caller to use SafeSearchDomain()
func (d *Dnsfilter) EnableSafeSearch() {
d.config.safeSearchEnabled = true
return false
}
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
if len(host) == 0 {
d.config.safeBrowsingServer = defaultSafebrowsingServer
d.safeBrowsingServer = defaultSafebrowsingServer
} else {
d.config.safeBrowsingServer = host
d.safeBrowsingServer = host
}
}
@ -929,7 +943,7 @@ func (d *Dnsfilter) ResetHTTPTimeout() {
// SafeSearchDomain returns replacement address for search engine
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
if d.config.safeSearchEnabled {
if d.SafeSearchEnabled {
val, ok := safeSearchDomains[host]
return val, ok
}

View file

@ -338,7 +338,7 @@ func mustLoadTestRules(d *Dnsfilter) {
}
func NewForTest() *Dnsfilter {
d := New()
d := New(nil)
purgeCaches()
return d
}
@ -542,7 +542,7 @@ func TestSafeBrowsing(t *testing.T) {
t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) {
d := NewForTest()
defer d.Destroy()
d.EnableSafeBrowsing()
d.SafeBrowsingEnabled = true
stats.Safebrowsing.Requests = 0
d.checkMatch(t, "wmconvirus.narod.ru")
d.checkMatch(t, "wmconvirus.narod.ru")
@ -570,7 +570,7 @@ func TestSafeBrowsing(t *testing.T) {
func TestParallelSB(t *testing.T) {
d := NewForTest()
defer d.Destroy()
d.EnableSafeBrowsing()
d.SafeBrowsingEnabled = true
t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
@ -597,7 +597,7 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) {
defer ts.Close()
address := ts.Listener.Addr().String()
d.EnableSafeBrowsing()
d.SafeBrowsingEnabled = true
d.SetHTTPTimeout(time.Second * 5)
d.SetSafeBrowsingServer(address) // this will ensure that test fails
d.checkMatchEmpty(t, "wmconvirus.narod.ru")
@ -606,7 +606,8 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) {
func TestParentalControl(t *testing.T) {
d := NewForTest()
defer d.Destroy()
d.EnableParental(3)
d.ParentalEnabled = true
d.ParentalSensitivity = 3
d.checkMatch(t, "pornhub.com")
d.checkMatch(t, "pornhub.com")
if stats.Parental.Requests != 1 {
@ -637,7 +638,7 @@ func TestSafeSearch(t *testing.T) {
if ok {
t.Errorf("Expected safesearch to error when disabled")
}
d.EnableSafeSearch()
d.SafeSearchEnabled = true
val, ok := d.SafeSearchDomain("www.google.com")
if !ok {
t.Errorf("Expected safesearch to find result for www.google.com")
@ -924,7 +925,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) {
func BenchmarkSafeBrowsing(b *testing.B) {
d := NewForTest()
defer d.Destroy()
d.EnableSafeBrowsing()
d.SafeBrowsingEnabled = true
for n := 0; n < b.N; n++ {
hostname := "wmconvirus.narod.ru"
ret, err := d.CheckHost(hostname)
@ -940,7 +941,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := NewForTest()
defer d.Destroy()
d.EnableSafeBrowsing()
d.SafeBrowsingEnabled = true
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
hostname := "wmconvirus.narod.ru"
@ -958,7 +959,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
func BenchmarkSafeSearch(b *testing.B) {
d := NewForTest()
defer d.Destroy()
d.EnableSafeSearch()
d.SafeSearchEnabled = true
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
if !ok {
@ -973,7 +974,7 @@ func BenchmarkSafeSearch(b *testing.B) {
func BenchmarkSafeSearchParallel(b *testing.B) {
d := NewForTest()
defer d.Destroy()
d.EnableSafeSearch()
d.SafeSearchEnabled = true
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
val, ok := d.SafeSearchDomain("www.google.com")
@ -1009,17 +1010,3 @@ func _Func() string {
f := runtime.FuncForPC(pc[0])
return path.Base(f.Name())
}
func trace(format string, args ...interface{}) {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
var buf strings.Builder
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
text := fmt.Sprintf(format, args...)
buf.WriteString(text)
if len(text) == 0 || text[len(text)-1] != '\n' {
buf.WriteRune('\n')
}
fmt.Print(buf.String())
}

View file

@ -1,6 +1,10 @@
package dnsfilter
import (
"fmt"
"os"
"path"
"runtime"
"strings"
"sync/atomic"
)
@ -58,3 +62,17 @@ func updateMax(valuePtr *int64, maxPtr *int64) {
// swapping failed because value has changed after reading, try again
}
}
func trace(format string, args ...interface{}) {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
var buf strings.Builder
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
text := fmt.Sprintf(format, args...)
buf.WriteString(text)
if len(text) == 0 || text[len(text)-1] != '\n' {
buf.WriteRune('\n')
}
fmt.Fprint(os.Stderr, buf.String())
}

107
dnsforward/bootstrap.go Normal file
View file

@ -0,0 +1,107 @@
package dnsforward
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
"strings"
"sync"
"github.com/joomcode/errorx"
)
type bootstrapper struct {
address string // in form of "tls://one.one.one.one:853"
resolver *net.Resolver // resolver to use to resolve hostname, if neccessary
resolved string // in form "IP:port"
resolvedConfig *tls.Config
sync.Mutex
}
func toBoot(address, bootstrapAddr string) bootstrapper {
var resolver *net.Resolver
if bootstrapAddr != "" {
resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, network, bootstrapAddr)
},
}
}
return bootstrapper{
address: address,
resolver: resolver,
}
}
// will get usable IP address from Address field, and caches the result
func (n *bootstrapper) get() (string, *tls.Config, error) {
// TODO: RLock() here but atomically upgrade to Lock() if fast path doesn't work
n.Lock()
if n.resolved != "" { // fast path
retval, tlsconfig := n.resolved, n.resolvedConfig
n.Unlock()
return retval, tlsconfig, nil
}
//
// slow path
//
defer n.Unlock()
justHostPort := n.address
if strings.Contains(n.address, "://") {
url, err := url.Parse(n.address)
if err != nil {
return "", nil, errorx.Decorate(err, "Failed to parse %s", n.address)
}
justHostPort = url.Host
}
// convert host to IP if neccessary, we know that it's scheme://hostname:port/
// get a host without port
host, port, err := net.SplitHostPort(justHostPort)
if err != nil {
return "", nil, fmt.Errorf("bootstrapper requires port in address %s", n.address)
}
// if it's an IP
ip := net.ParseIP(host)
if ip != nil {
n.resolved = justHostPort
return n.resolved, nil, nil
}
//
// if it's a hostname
//
resolver := n.resolver // no need to check for nil resolver -- documented that nil is default resolver
addrs, err := resolver.LookupIPAddr(context.TODO(), host)
if err != nil {
return "", nil, errorx.Decorate(err, "Failed to lookup %s", host)
}
for _, addr := range addrs {
// TODO: support ipv6, support multiple ipv4
if addr.IP.To4() == nil {
continue
}
ip = addr.IP
break
}
if ip == nil {
// couldn't find any suitable IP address
return "", nil, fmt.Errorf("Couldn't find any suitable IP address for host %s", host)
}
n.resolved = net.JoinHostPort(ip.String(), port)
n.resolvedConfig = &tls.Config{ServerName: host}
return n.resolved, n.resolvedConfig, nil
}

225
dnsforward/cache.go Normal file
View file

@ -0,0 +1,225 @@
package dnsforward
import (
"encoding/binary"
"log"
"math"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
type item struct {
m *dns.Msg
when time.Time
}
type cache struct {
items map[string]item
sync.RWMutex
}
func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) {
if request == nil {
return nil, false
}
ok, key := key(request)
if !ok {
log.Printf("Get(): key returned !ok")
return nil, false
}
c.RLock()
item, ok := c.items[key]
c.RUnlock()
if !ok {
return nil, false
}
// get item's TTL
ttl := findLowestTTL(item.m)
// zero TTL? delete and don't serve it
if ttl == 0 {
c.Lock()
delete(c.items, key)
c.Unlock()
return nil, false
}
// too much time has passed? delete and don't serve it
if time.Since(item.when) >= time.Duration(ttl)*time.Second {
c.Lock()
delete(c.items, key)
c.Unlock()
return nil, false
}
response := item.fromItem(request)
return response, true
}
func (c *cache) Set(m *dns.Msg) {
if m == nil {
return // no-op
}
if !isRequestCacheable(m) {
return
}
if !isResponseCacheable(m) {
return
}
ok, key := key(m)
if !ok {
return
}
i := toItem(m)
c.Lock()
if c.items == nil {
c.items = map[string]item{}
}
c.items[key] = i
c.Unlock()
}
// check only request fields
func isRequestCacheable(m *dns.Msg) bool {
// truncated messages aren't valid
if m.Truncated {
log.Printf("Refusing to cache truncated message")
return false
}
// if has wrong number of questions, also don't cache
if len(m.Question) != 1 {
log.Printf("Refusing to cache message with wrong number of questions")
return false
}
// only OK or NXdomain replies are cached
switch m.Rcode {
case dns.RcodeSuccess:
case dns.RcodeNameError: // that's an NXDomain
case dns.RcodeServerFailure:
return false // quietly refuse, don't log
default:
log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode])
return false
}
return true
}
func isResponseCacheable(m *dns.Msg) bool {
ttl := findLowestTTL(m)
if ttl == 0 {
return false
}
return true
}
func findLowestTTL(m *dns.Msg) uint32 {
var ttl uint32 = math.MaxUint32
found := false
if m.Answer != nil {
for _, r := range m.Answer {
if r.Header().Ttl < ttl {
ttl = r.Header().Ttl
found = true
}
}
}
if m.Ns != nil {
for _, r := range m.Ns {
if r.Header().Ttl < ttl {
ttl = r.Header().Ttl
found = true
}
}
}
if m.Extra != nil {
for _, r := range m.Extra {
if r.Header().Rrtype == dns.TypeOPT {
continue // OPT records use TTL for other purposes
}
if r.Header().Ttl < ttl {
ttl = r.Header().Ttl
found = true
}
}
}
if found == false {
return 0
}
return ttl
}
// key is binary little endian in sequence:
// uint16(qtype) then uint16(qclass) then name
func key(m *dns.Msg) (bool, string) {
if len(m.Question) != 1 {
log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question))
return false, ""
}
bb := strings.Builder{}
b := make([]byte, 2)
binary.LittleEndian.PutUint16(b, m.Question[0].Qtype)
bb.Write(b)
binary.LittleEndian.PutUint16(b, m.Question[0].Qclass)
bb.Write(b)
name := strings.ToLower(m.Question[0].Name)
bb.WriteString(name)
return true, bb.String()
}
func toItem(m *dns.Msg) item {
return item{
m: m,
when: time.Now(),
}
}
func (i *item) fromItem(request *dns.Msg) *dns.Msg {
response := &dns.Msg{}
response.SetReply(request)
response.Authoritative = false
response.AuthenticatedData = i.m.AuthenticatedData
response.RecursionAvailable = i.m.RecursionAvailable
response.Rcode = i.m.Rcode
ttl := findLowestTTL(i.m)
timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds())
var newttl uint32
if timeleft > 0 {
newttl = uint32(timeleft)
}
for _, r := range i.m.Answer {
answer := dns.Copy(r)
answer.Header().Ttl = newttl
response.Answer = append(response.Answer, answer)
}
for _, r := range i.m.Ns {
ns := dns.Copy(r)
ns.Header().Ttl = newttl
response.Ns = append(response.Ns, ns)
}
for _, r := range i.m.Extra {
// don't return OPT records as these are hop-by-hop
if r.Header().Rrtype == dns.TypeOPT {
continue
}
extra := dns.Copy(r)
extra.Header().Ttl = newttl
response.Extra = append(response.Extra, extra)
}
return response
}

144
dnsforward/cache_test.go Normal file
View file

@ -0,0 +1,144 @@
package dnsforward
import (
"strings"
"testing"
"github.com/go-test/deep"
"github.com/miekg/dns"
)
func RR(rr string) dns.RR {
r, err := dns.NewRR(rr)
if err != nil {
panic(err)
}
return r
}
// deepEqual is same as deep.Equal, except:
// * ignores Id when comparing
// * question names are not case sensetive
func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string {
temp := *left
temp.Id = right.Id
for i := range left.Question {
left.Question[i].Name = strings.ToLower(left.Question[i].Name)
}
for i := range right.Question {
right.Question[i].Name = strings.ToLower(right.Question[i].Name)
}
return deep.Equal(&temp, right)
}
func TestCacheSanity(t *testing.T) {
cache := cache{}
request := dns.Msg{}
request.SetQuestion("google.com.", dns.TypeA)
_, ok := cache.Get(&request)
if ok {
t.Fatal("empty cache replied with positive response")
}
}
type tests struct {
cache []testEntry
cases []testCase
}
type testEntry struct {
q string
t uint16
a []dns.RR
}
type testCase struct {
q string
t uint16
a []dns.RR
ok bool
}
func TestCache(t *testing.T) {
tests := tests{
cache: []testEntry{
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
},
cases: []testCase{
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "google.com.", t: dns.TypeMX, ok: false},
},
}
runTests(t, tests)
}
func TestCacheMixedCase(t *testing.T) {
tests := tests{
cache: []testEntry{
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
},
cases: []testCase{
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "gOOgle.com.", t: dns.TypeMX, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
{q: "GOOGLE.COM.", t: dns.TypeMX, ok: false},
},
}
runTests(t, tests)
}
func TestZeroTTL(t *testing.T) {
tests := tests{
cache: []testEntry{
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}},
},
cases: []testCase{
{q: "google.com.", t: dns.TypeA, ok: false},
{q: "google.com.", t: dns.TypeA, ok: false},
{q: "google.com.", t: dns.TypeA, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
},
}
runTests(t, tests)
}
func runTests(t *testing.T, tests tests) {
t.Helper()
cache := cache{}
for _, tc := range tests.cache {
reply := dns.Msg{}
reply.SetQuestion(tc.q, tc.t)
reply.Response = true
reply.Answer = tc.a
cache.Set(&reply)
}
for _, tc := range tests.cases {
request := dns.Msg{}
request.SetQuestion(tc.q, tc.t)
val, ok := cache.Get(&request)
if diff := deep.Equal(ok, tc.ok); diff != nil {
t.Error(diff)
}
if tc.a != nil {
if ok == false {
continue
}
reply := dns.Msg{}
reply.SetQuestion(tc.q, tc.t)
reply.Response = true
reply.Answer = tc.a
cache.Set(&reply)
if diff := deepEqualMsg(val, &reply); diff != nil {
t.Error(diff)
} else {
if diff := deep.Equal(val, reply); diff == nil {
t.Error("different message ID were not caught")
}
}
}
}
}

594
dnsforward/dnsforward.go Normal file
View file

@ -0,0 +1,594 @@
package dnsforward
import (
"fmt"
"log"
"net"
"reflect"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/joomcode/errorx"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
)
// Server is the main way to start a DNS server.
//
// Example:
// s := dnsforward.Server{}
// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine
// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535
// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines
// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine
//
// The zero Server is empty and ready for use.
type Server struct {
udpListen *net.UDPConn
dnsFilter *dnsfilter.Dnsfilter
cache cache
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
sync.RWMutex
ServerConfig
}
// uncomment this block to have tracing of locks
/*
func (s *Server) Lock() {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
file, line := f.FileLine(pc[0])
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
s.RWMutex.Lock()
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> done\n", path.Base(file), line, path.Base(f.Name()))
}
func (s *Server) RLock() {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
file, line := f.FileLine(pc[0])
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
s.RWMutex.RLock()
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> done\n", path.Base(file), line, path.Base(f.Name()))
}
func (s *Server) Unlock() {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
file, line := f.FileLine(pc[0])
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
s.RWMutex.Unlock()
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> done\n", path.Base(file), line, path.Base(f.Name()))
}
func (s *Server) RUnlock() {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
file, line := f.FileLine(pc[0])
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
s.RWMutex.RUnlock()
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> done\n", path.Base(file), line, path.Base(f.Name()))
}
*/
type FilteringConfig struct {
ProtectionEnabled bool `yaml:"protection_enabled"`
FilteringEnabled bool `yaml:"filtering_enabled"`
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
QueryLogEnabled bool `yaml:"querylog_enabled"`
Ratelimit int `yaml:"ratelimit"`
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
RefuseAny bool `yaml:"refuse_any"`
BootstrapDNS string `yaml:"bootstrap_dns"`
dnsfilter.Config `yaml:",inline"`
}
// The zero ServerConfig is empty and ready for use.
type ServerConfig struct {
UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *)
Upstreams []Upstream
Filters []dnsfilter.Filter
FilteringConfig
}
// if any of ServerConfig values are zero, then default values from below are used
var defaultValues = ServerConfig{
UDPListenAddr: &net.UDPAddr{Port: 53},
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
Upstreams: []Upstream{
//// dns over HTTPS
// &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")},
// &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")},
// &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")},
// &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")},
// &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")},
// &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")},
//// dns over TLS
// &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")},
// &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")},
// &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")},
// &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")},
//// plainDNS
&plainDNS{boot: toBoot("8.8.8.8:53", "")},
&plainDNS{boot: toBoot("8.8.4.4:53", "")},
&plainDNS{boot: toBoot("1.1.1.1:53", "")},
&plainDNS{boot: toBoot("1.0.0.1:53", "")},
},
}
//
// packet loop
//
func (s *Server) packetLoop() {
log.Printf("Entering packet handle loop")
b := make([]byte, dns.MaxMsgSize)
for {
s.RLock()
conn := s.udpListen
s.RUnlock()
if conn == nil {
log.Printf("udp socket has disappeared, exiting loop")
break
}
n, addr, err := conn.ReadFrom(b)
// documentation says to handle the packet even if err occurs, so do that first
if n > 0 {
// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
// we need the contents to survive the call because we're handling them in goroutine
p := make([]byte, n)
copy(p, b)
go s.handlePacket(p, addr, conn) // ignore errors
}
if err != nil {
if isConnClosed(err) {
log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop")
// don't try to nullify s.udpListen here, because s.udpListen could be already re-bound to listen
break
}
log.Printf("Got error when reading from udp listen: %s", err)
}
}
}
//
// Control functions
//
func (s *Server) Start(config *ServerConfig) error {
s.Lock()
defer s.Unlock()
if config != nil {
s.ServerConfig = *config
}
// TODO: handle being called Start() second time after Stop()
if s.udpListen == nil {
log.Printf("Creating UDP socket")
var err error
addr := s.UDPListenAddr
if addr == nil {
addr = defaultValues.UDPListenAddr
}
s.udpListen, err = net.ListenUDP("udp", addr)
if err != nil {
s.udpListen = nil
return errorx.Decorate(err, "Couldn't listen to UDP socket")
}
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
}
if s.dnsFilter == nil {
log.Printf("Creating dnsfilter")
s.dnsFilter = dnsfilter.New(&s.Config)
// add rules only if they are enabled
if s.FilteringEnabled {
s.dnsFilter.AddRules(s.Filters)
}
}
log.Printf("Loading stats from querylog")
err := fillStatsFromQueryLog()
if err != nil {
log.Printf("Failed to load stats from querylog: %s", err)
return err
}
once.Do(func() {
go periodicQueryLogRotate()
go periodicHourlyTopRotate()
go statsRotator()
})
go s.packetLoop()
return nil
}
func (s *Server) Stop() error {
s.Lock()
defer s.Unlock()
if s.udpListen != nil {
err := s.udpListen.Close()
s.udpListen = nil
if err != nil {
return errorx.Decorate(err, "Couldn't close UDP listening socket")
}
}
// flush remainder to file
logBufferLock.Lock()
flushBuffer := logBuffer
logBuffer = nil
logBufferLock.Unlock()
err := flushToFile(flushBuffer)
if err != nil {
log.Printf("Saving querylog to file failed: %s", err)
return err
}
return nil
}
func (s *Server) IsRunning() bool {
s.RLock()
isRunning := true
if s.udpListen == nil {
isRunning = false
}
s.RUnlock()
return isRunning
}
//
// Server reconfigure
//
func (s *Server) reconfigureListenAddr(new ServerConfig) error {
oldAddr := s.UDPListenAddr
if oldAddr == nil {
oldAddr = defaultValues.UDPListenAddr
}
newAddr := new.UDPListenAddr
if newAddr == nil {
newAddr = defaultValues.UDPListenAddr
}
if newAddr.Port == 0 {
return errorx.IllegalArgument.New("new port cannot be 0")
}
if reflect.DeepEqual(oldAddr, newAddr) {
// do nothing, the addresses are exactly the same
log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr)
return nil
}
// rebind, using a strategy:
// * if ports are different, bind new first, then close old
// * if ports are same, close old first, then bind new
var newListen *net.UDPConn
var err error
if oldAddr.Port != newAddr.Port {
log.Printf("Rebinding -- ports are different so bind first then close")
newListen, err = net.ListenUDP("udp", newAddr)
if err != nil {
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
}
s.Lock()
if s.udpListen != nil {
err = s.udpListen.Close()
s.udpListen = nil
}
s.Unlock()
if err != nil {
return errorx.Decorate(err, "Couldn't close UDP listening socket")
}
} else {
log.Printf("Rebinding -- ports are same so close first then bind")
s.Lock()
if s.udpListen != nil {
err = s.udpListen.Close()
s.udpListen = nil
}
s.Unlock()
if err != nil {
return errorx.Decorate(err, "Couldn't close UDP listening socket")
}
newListen, err = net.ListenUDP("udp", newAddr)
if err != nil {
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
}
}
s.Lock()
s.udpListen = newListen
s.UDPListenAddr = new.UDPListenAddr
s.Unlock()
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
go s.packetLoop() // the old one has quit, use new one
return nil
}
func (s *Server) reconfigureBlockedResponseTTL(new ServerConfig) {
newVal := new.BlockedResponseTTL
if newVal == 0 {
newVal = defaultValues.BlockedResponseTTL
}
oldVal := s.BlockedResponseTTL
if oldVal == 0 {
oldVal = defaultValues.BlockedResponseTTL
}
if newVal != oldVal {
s.BlockedResponseTTL = new.BlockedResponseTTL
}
}
func (s *Server) reconfigureUpstreams(new ServerConfig) {
newVal := new.Upstreams
if len(newVal) == 0 {
newVal = defaultValues.Upstreams
}
oldVal := s.Upstreams
if len(oldVal) == 0 {
oldVal = defaultValues.Upstreams
}
if reflect.DeepEqual(newVal, oldVal) {
// they're exactly the same, do nothing
return
}
s.Upstreams = new.Upstreams
}
func (s *Server) reconfigureFiltering(new ServerConfig) {
newFilters := new.Filters
if len(newFilters) == 0 {
newFilters = defaultValues.Filters
}
oldFilters := s.Filters
if len(oldFilters) == 0 {
oldFilters = defaultValues.Filters
}
needUpdate := false
if !reflect.DeepEqual(newFilters, oldFilters) {
needUpdate = true
}
if !reflect.DeepEqual(new.FilteringConfig, s.FilteringConfig) {
needUpdate = true
}
if !needUpdate {
// nothing to do, everything is same
return
}
// TODO: instead of creating new dnsfilter, change existing one's settings and filters
dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental
// add rules only if they are enabled
if new.FilteringEnabled {
dnsFilter.AddRules(newFilters)
}
s.Lock()
oldDNSFilter := s.dnsFilter
s.dnsFilter = dnsFilter
s.FilteringConfig = new.FilteringConfig
s.Unlock()
oldDNSFilter.Destroy()
}
func (s *Server) Reconfigure(new ServerConfig) error {
s.reconfigureBlockedResponseTTL(new)
s.reconfigureUpstreams(new)
s.reconfigureFiltering(new)
err := s.reconfigureListenAddr(new)
if err != nil {
return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr)
}
return nil
}
//
// packet handling functions
//
// handlePacketInternal processes the incoming packet bytes and returns with an optional response packet.
//
// If an empty dns.Msg is returned, do not try to send anything back to client, otherwise send contents of dns.Msg.
//
// If an error is returned, log it, don't try to generate data based on that error.
func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDPConn) (*dns.Msg, *dnsfilter.Result, Upstream, error) {
// log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p)
//
// DNS packet byte format is valid
//
// any errors below here require a response to client
// log.Printf("Unpacked: %v", msg.String())
if len(msg.Question) != 1 {
log.Printf("Got invalid number of questions: %v", len(msg.Question))
return s.genServerFailure(msg), nil, nil, nil
}
if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
return s.genNotImpl(msg), nil, nil, nil
}
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
host := strings.TrimSuffix(msg.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host)
if err != nil {
log.Printf("dnsfilter failed to check host '%s': %s", host, err)
return s.genServerFailure(msg), &res, nil, err
} else if res.IsFiltered {
log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
return s.genNXDomain(msg), &res, nil, nil
}
{
val, ok := s.cache.Get(msg)
if ok && val != nil {
return val, &res, nil, nil
}
}
// TODO: replace with single-socket implementation
upstream := s.chooseUpstream()
reply, err := upstream.Exchange(msg)
if err != nil {
log.Printf("talking to upstream failed for host '%s': %s", host, err)
return s.genServerFailure(msg), &res, upstream, err
}
if reply == nil {
log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String())
return s.genServerFailure(msg), &res, upstream, nil
}
s.cache.Set(reply)
return reply, &res, upstream, nil
}
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
start := time.Now()
ip, _, err := net.SplitHostPort(addr.String())
if err != nil {
log.Printf("Failed to split %v into host/port: %s", addr, err)
// not a fatal error, move on
}
// ratelimit based on IP only, protects CPU cycles and outbound connections
if s.isRatelimited(ip) {
// log.Printf("Ratelimiting %s based on IP only", ip)
return // do nothing, don't reply, we got ratelimited
}
msg := &dns.Msg{}
err = msg.Unpack(p)
if err != nil {
log.Printf("got invalid DNS packet: %s", err)
return // do nothing
}
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
if reply != nil {
// ratelimit based on reply size now
replysize := reply.Len()
if s.isRatelimitedForReply(ip, replysize) {
log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize)
return // do nothing, don't reply, we got ratelimited
}
// we're good to respond
rerr := s.respond(reply, addr, conn)
if rerr != nil {
log.Printf("Couldn't respond to UDP packet: %s", err)
}
}
// query logging and stats counters
if s.QueryLogEnabled {
elapsed := time.Since(start)
upstreamAddr := ""
if upstream != nil {
upstreamAddr = upstream.Address()
}
logRequest(msg, reply, result, elapsed, ip, upstreamAddr)
}
}
//
// packet sending functions
//
func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
// log.Printf("Replying to %s with %s", addr, resp)
resp.Compress = true
bytes, err := resp.Pack()
if err != nil {
return errorx.Decorate(err, "Couldn't convert message into wire format")
}
n, err := conn.WriteTo(bytes, addr)
if n == 0 && isConnClosed(err) {
return err
}
if n != len(bytes) {
return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes))
}
if err != nil {
return errorx.Decorate(err, "WriteTo() returned error")
}
return nil
}
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNotImplemented)
resp.RecursionAvailable = true
resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it
return &resp
}
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request)
return &resp
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := ""
if len(request.Question) > 0 {
zone = request.Question[0].Name
}
soa := dns.SOA{
// values copied from verisign's nonexistent .com domain
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
// copied from AdGuard DNS
Ns: "fake-for-negative-caching.adguard.com.",
Serial: 100500,
// rest is request-specific
Hdr: dns.RR_Header{
Name: zone,
Rrtype: dns.TypeSOA,
Ttl: s.BlockedResponseTTL,
Class: dns.ClassINET,
},
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
}
if soa.Hdr.Ttl == 0 {
soa.Hdr.Ttl = defaultValues.BlockedResponseTTL
}
if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone
}
return []dns.RR{&soa}
}
var once sync.Once

View file

@ -0,0 +1,49 @@
package dnsforward
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestServer(t *testing.T) {
s := Server{}
s.UDPListenAddr = &net.UDPAddr{Port: 0}
err := s.Start(nil)
if err != nil {
t.Fatalf("Failed to start server: %s", err)
}
if s.udpListen == nil {
t.Fatal("Started server has nil udpListen")
}
// server is running, send a message
addr := s.udpListen.LocalAddr()
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := dns.Exchange(&req, addr.String())
if err != nil {
t.Fatalf("Couldn't talk to server %s: %s", addr, err)
}
if len(reply.Answer) != 1 {
t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
}
if a, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(a.A) {
t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
}
} else {
t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
}
err = s.Stop()
if err != nil {
t.Fatalf("DNS server %s failed to stop: %s", addr, err)
}
}

50
dnsforward/helpers.go Normal file
View file

@ -0,0 +1,50 @@
package dnsforward
import (
"fmt"
"net"
"os"
"path"
"runtime"
"strings"
)
func isConnClosed(err error) bool {
if err == nil {
return false
}
nerr, ok := err.(*net.OpError)
if !ok {
return false
}
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
return true
}
return false
}
// ---------------------
// debug logging helpers
// ---------------------
func _Func() string {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
return path.Base(f.Name())
}
func trace(format string, args ...interface{}) {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
var buf strings.Builder
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
text := fmt.Sprintf(format, args...)
buf.WriteString(text)
if len(text) == 0 || text[len(text)-1] != '\n' {
buf.WriteRune('\n')
}
fmt.Fprint(os.Stderr, buf.String())
}

View file

@ -1,20 +1,16 @@
package dnsfilter
package dnsforward
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"path"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns"
)
@ -42,9 +38,10 @@ type logEntry struct {
Time time.Time
Elapsed time.Duration
IP string
Upstream string `json:",omitempty"` // if empty, means it was cached
}
func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) {
func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, ip string, upstream string) {
var q []byte
var a []byte
var err error
@ -64,14 +61,19 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
}
}
if result == nil {
result = &dnsfilter.Result{}
}
now := time.Now()
entry := logEntry{
Question: q,
Answer: a,
Result: result,
Result: *result,
Time: now,
Elapsed: elapsed,
IP: ip,
Upstream: upstream,
}
var flushBuffer []*logEntry
@ -97,6 +99,8 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
// don't do failure, just log
}
incrementCounters(&entry)
// if buffer needs to be flushed to disk, do it now
if len(flushBuffer) > 0 {
// write to file
@ -153,8 +157,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
}
if a != nil {
status, _ := response.Typify(a, time.Now().UTC())
jsonEntry["status"] = status.String()
jsonEntry["status"] = dns.RcodeToString[a.Rcode]
}
if len(entry.Result.Rule) > 0 {
jsonEntry["rule"] = entry.Result.Rule
@ -223,17 +226,3 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func trace(format string, args ...interface{}) {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
var buf strings.Builder
buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
text := fmt.Sprintf(format, args...)
buf.WriteString(text)
if len(text) == 0 || text[len(text)-1] != '\n' {
buf.WriteRune('\n')
}
fmt.Fprint(os.Stderr, buf.String())
}

View file

@ -1,4 +1,4 @@
package dnsfilter
package dnsforward
import (
"bytes"
@ -251,41 +251,3 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
}
return nil
}
func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry {
a := []*logEntry{}
onEntry := func(entry *logEntry) error {
a = append(a, entry)
if len(a) > maxLen {
toskip := len(a) - maxLen
a = a[toskip:]
}
return nil
}
needMore := func() bool {
return true
}
err := genericLoader(onEntry, needMore, timeWindow)
if err != nil {
log.Printf("Failed to load entries from querylog: %s", err)
return values
}
// now that we've read all eligible entries, reverse the slice to make it go from newest->oldest
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
// append it to values
values = append(values, a...)
// then cut off of it is bigger than maxLen
if len(values) > maxLen {
values = values[:maxLen]
}
return values
}

View file

@ -1,4 +1,4 @@
package dnsfilter
package dnsforward
import (
"bytes"
@ -14,7 +14,6 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/bluele/gcache"
"github.com/miekg/dns"
)
@ -231,27 +230,7 @@ func fillStatsFromQueryLog() error {
}
queryLogLock.Unlock()
requests.IncWithTime(entry.Time)
if entry.Result.IsFiltered {
filtered.IncWithTime(entry.Time)
}
switch entry.Result.Reason {
case dnsfilter.NotFilteredWhiteList:
whitelisted.IncWithTime(entry.Time)
case dnsfilter.NotFilteredError:
errorsTotal.IncWithTime(entry.Time)
case dnsfilter.FilteredBlackList:
filteredLists.IncWithTime(entry.Time)
case dnsfilter.FilteredSafeBrowsing:
filteredSafebrowsing.IncWithTime(entry.Time)
case dnsfilter.FilteredParental:
filteredParental.IncWithTime(entry.Time)
case dnsfilter.FilteredInvalid:
// do nothing
case dnsfilter.FilteredSafeSearch:
safesearch.IncWithTime(entry.Time)
}
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
incrementCounters(entry)
return nil
}

80
dnsforward/ratelimit.go Normal file
View file

@ -0,0 +1,80 @@
package dnsforward
import (
"log"
"sort"
"time"
"github.com/beefsack/go-rate"
gocache "github.com/patrickmn/go-cache"
)
func (s *Server) limiterForIP(ip string) interface{} {
if s.ratelimitBuckets == nil {
s.ratelimitBuckets = gocache.New(time.Hour, time.Hour)
}
// check if ratelimiter for that IP already exists, if not, create
value, found := s.ratelimitBuckets.Get(ip)
if !found {
value = rate.New(s.Ratelimit, time.Second)
s.ratelimitBuckets.Set(ip, value, time.Hour)
}
return value
}
func (s *Server) isRatelimited(ip string) bool {
if s.Ratelimit == 0 { // 0 -- disabled
return false
}
if len(s.RatelimitWhitelist) > 0 {
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
// found, don't ratelimit
return false
}
}
value := s.limiterForIP(ip)
rl, ok := value.(*rate.RateLimiter)
if !ok {
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
return false
}
allow, _ := rl.Try()
return !allow
}
func (s *Server) isRatelimitedForReply(ip string, size int) bool {
if s.Ratelimit == 0 { // 0 -- disabled
return false
}
if len(s.RatelimitWhitelist) > 0 {
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
// found, don't ratelimit
return false
}
}
value := s.limiterForIP(ip)
rl, ok := value.(*rate.RateLimiter)
if !ok {
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
return false
}
// For large UDP responses we try more times, effectively limiting per bandwidth
// The exact number of times depends on the response size
for i := 0; i < size/1000; i++ {
allow, _ := rl.Try()
if !allow { // not allowed -> ratelimited
return true
}
}
return false
}

View file

@ -0,0 +1,42 @@
package dnsforward
import (
"testing"
)
func TestRatelimiting(t *testing.T) {
// rate limit is 1 per sec
p := Server{}
p.Ratelimit = 1
limited := p.isRatelimited("127.0.0.1")
if limited {
t.Fatal("First request must have been allowed")
}
limited = p.isRatelimited("127.0.0.1")
if !limited {
t.Fatal("Second request must have been ratelimited")
}
}
func TestWhitelist(t *testing.T) {
// rate limit is 1 per sec with whitelist
p := Server{}
p.Ratelimit = 1
p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"}
limited := p.isRatelimited("127.0.0.1")
if limited {
t.Fatal("First request must have been allowed")
}
limited = p.isRatelimited("127.0.0.1")
if limited {
t.Fatal("Second request must have been allowed due to whitelist")
}
}

1
dnsforward/standalone/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/standalone

View file

@ -0,0 +1,51 @@
package main
import (
"log"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
)
//
// main function
//
func main() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
go func() {
for range time.Tick(time.Second) {
log.Printf("goroutines = %d", runtime.NumGoroutine())
}
}()
s := dnsforward.Server{}
err := s.Start(nil)
if err != nil {
panic(err)
}
time.Sleep(time.Second)
err = s.Stop()
if err != nil {
panic(err)
}
err = s.Start(&dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}})
if err != nil {
panic(err)
}
err = s.Reconfigure(dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53, IP: net.ParseIP("0.0.0.0")}})
if err != nil {
panic(err)
}
log.Printf("Now serving DNS")
signal_channel := make(chan os.Signal)
signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM)
<-signal_channel
}

View file

@ -1,4 +1,4 @@
package dnsfilter
package dnsforward
import (
"encoding/json"
@ -8,21 +8,20 @@ import (
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/prometheus/client_golang/prometheus"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
)
var (
requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
elapsedTime = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.")
requests = newDNSCounter("requests_total")
filtered = newDNSCounter("filtered_total")
filteredLists = newDNSCounter("filtered_lists_total")
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
filteredParental = newDNSCounter("filtered_parental_total")
filteredInvalid = newDNSCounter("filtered_invalid_total")
whitelisted = newDNSCounter("whitelisted_total")
safesearch = newDNSCounter("safesearch_total")
errorsTotal = newDNSCounter("errors_total")
elapsedTime = newDNSHistogram("request_duration")
)
// entries for single time period (for example all per-second entries)
@ -143,21 +142,13 @@ func statsRotator() {
type counter struct {
name string // used as key in periodic stats
value int64
prom prometheus.Counter
}
func newDNSCounter(name string, help string) *counter {
func newDNSCounter(name string) *counter {
// trace("called")
c := &counter{}
c.prom = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "dnsfilter",
Name: name,
Help: help,
})
c.name = name
return c
return &counter{
name: name,
}
}
func (c *counter) IncWithTime(when time.Time) {
@ -166,40 +157,22 @@ func (c *counter) IncWithTime(when time.Time) {
statistics.PerHour.Inc(c.name, when)
statistics.PerDay.Inc(c.name, when)
c.value++
c.prom.Inc()
}
func (c *counter) Inc() {
c.IncWithTime(time.Now())
}
func (c *counter) Describe(ch chan<- *prometheus.Desc) {
c.prom.Describe(ch)
}
func (c *counter) Collect(ch chan<- prometheus.Metric) {
c.prom.Collect(ch)
}
type histogram struct {
name string // used as key in periodic stats
count int64
total float64
prom prometheus.Histogram
}
func newDNSHistogram(name string, help string) *histogram {
// trace("called")
h := &histogram{}
h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: plugin.Namespace,
Subsystem: "dnsfilter",
Name: name,
Help: help,
})
h.name = name
return h
func newDNSHistogram(name string) *histogram {
return &histogram{
name: name,
}
}
func (h *histogram) ObserveWithTime(value float64, when time.Time) {
@ -209,24 +182,40 @@ func (h *histogram) ObserveWithTime(value float64, when time.Time) {
statistics.PerDay.Observe(h.name, when, value)
h.count++
h.total += value
h.prom.Observe(value)
}
func (h *histogram) Observe(value float64) {
h.ObserveWithTime(value, time.Now())
}
func (h *histogram) Describe(ch chan<- *prometheus.Desc) {
h.prom.Describe(ch)
}
func (h *histogram) Collect(ch chan<- prometheus.Metric) {
h.prom.Collect(ch)
}
// -----
// stats
// -----
func incrementCounters(entry *logEntry) {
requests.IncWithTime(entry.Time)
if entry.Result.IsFiltered {
filtered.IncWithTime(entry.Time)
}
switch entry.Result.Reason {
case dnsfilter.NotFilteredWhiteList:
whitelisted.IncWithTime(entry.Time)
case dnsfilter.NotFilteredError:
errorsTotal.IncWithTime(entry.Time)
case dnsfilter.FilteredBlackList:
filteredLists.IncWithTime(entry.Time)
case dnsfilter.FilteredSafeBrowsing:
filteredSafebrowsing.IncWithTime(entry.Time)
case dnsfilter.FilteredParental:
filteredParental.IncWithTime(entry.Time)
case dnsfilter.FilteredInvalid:
// do nothing
case dnsfilter.FilteredSafeSearch:
safesearch.IncWithTime(entry.Time)
}
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
}
func HandleStats(w http.ResponseWriter, r *http.Request) {
const numHours = 24
histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)

239
dnsforward/upstream.go Normal file
View file

@ -0,0 +1,239 @@
package dnsforward
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/joomcode/errorx"
"github.com/miekg/dns"
)
const defaultTimeout = time.Second * 10
type Upstream interface {
Exchange(m *dns.Msg) (*dns.Msg, error)
Address() string
}
//
// plain DNS
//
type plainDNS struct {
boot bootstrapper
preferTCP bool
}
var defaultUDPClient = dns.Client{
Timeout: defaultTimeout,
UDPSize: dns.MaxMsgSize,
}
var defaultTCPClient = dns.Client{
Net: "tcp",
UDPSize: dns.MaxMsgSize,
Timeout: defaultTimeout,
}
// Address returns the original address that we've put in initially, not resolved one
func (p *plainDNS) Address() string { return p.boot.address }
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
addr, _, err := p.boot.get()
if err != nil {
return nil, err
}
if p.preferTCP {
reply, _, err := defaultTCPClient.Exchange(m, addr)
return reply, err
}
reply, _, err := defaultUDPClient.Exchange(m, addr)
if err != nil && reply != nil && reply.Truncated {
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
reply, _, err = defaultTCPClient.Exchange(m, addr)
}
return reply, err
}
//
// DNS-over-TLS
//
type dnsOverTLS struct {
boot bootstrapper
pool *TLSPool
sync.RWMutex // protects pool
}
func (p *dnsOverTLS) Address() string { return p.boot.address }
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
var pool *TLSPool
p.RLock()
pool = p.pool
p.RUnlock()
if pool == nil {
p.Lock()
// lazy initialize it
p.pool = &TLSPool{boot: &p.boot}
p.Unlock()
}
p.RLock()
poolConn, err := p.pool.Get()
p.RUnlock()
if err != nil {
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address())
}
c := dns.Conn{Conn: poolConn}
err = c.WriteMsg(m)
if err != nil {
poolConn.Close()
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address())
}
reply, err := c.ReadMsg()
if err != nil {
poolConn.Close()
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address())
}
p.RLock()
p.pool.Put(poolConn)
p.RUnlock()
return reply, nil
}
//
// DNS-over-https
//
type dnsOverHTTPS struct {
boot bootstrapper
}
func (p *dnsOverHTTPS) Address() string { return p.boot.address }
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
addr, tlsConfig, err := p.boot.get()
if err != nil {
return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address)
}
buf, err := m.Pack()
if err != nil {
return nil, errorx.Decorate(err, "Couldn't pack request msg")
}
bb := bytes.NewBuffer(buf)
// set up a custom request with custom URL
url, err := url.Parse(p.boot.address)
if err != nil {
return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address)
}
req := http.Request{
Method: "POST",
URL: url,
Body: ioutil.NopCloser(bb),
Header: make(http.Header),
Host: url.Host,
}
url.Host = addr
req.Header.Set("Content-Type", "application/dns-message")
client := http.Client{
Transport: &http.Transport{TLSClientConfig: tlsConfig},
}
resp, err := client.Do(&req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr)
}
if len(body) == 0 {
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr)
}
response := dns.Msg{}
err = response.Unpack(body)
if err != nil {
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body))
}
return &response, nil
}
func (s *Server) chooseUpstream() Upstream {
upstreams := s.Upstreams
if upstreams == nil {
upstreams = defaultValues.Upstreams
}
if len(upstreams) == 0 {
panic("SHOULD NOT HAPPEN: no default upstreams specified")
}
if len(upstreams) == 1 {
return upstreams[0]
}
n := rand.Intn(len(upstreams))
upstream := upstreams[n]
return upstream
}
func AddressToUpstream(address string, bootstrap string) (Upstream, error) {
if strings.Contains(address, "://") {
url, err := url.Parse(address)
if err != nil {
return nil, errorx.Decorate(err, "Failed to parse %s", address)
}
switch url.Scheme {
case "dns":
if url.Port() == "" {
url.Host += ":53"
}
return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil
case "tcp":
if url.Port() == "" {
url.Host += ":53"
}
return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil
case "tls":
if url.Port() == "" {
url.Host += ":853"
}
return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil
case "https":
if url.Port() == "" {
url.Host += ":443"
}
return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil
default:
// assume it's plain DNS
if url.Port() == "" {
url.Host += ":53"
}
return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil
}
}
// we don't have scheme in the url, so it's just a plain DNS host:port
_, _, err := net.SplitHostPort(address)
if err != nil {
// doesn't have port, default to 53
address = net.JoinHostPort(address, "53")
}
return &plainDNS{boot: toBoot(address, bootstrap)}, nil
}

View file

@ -0,0 +1,74 @@
package dnsforward
import (
"crypto/tls"
"net"
"sync"
"github.com/joomcode/errorx"
)
// Upstream TLS pool.
//
// Example:
// pool := TLSPool{Address: "tls://1.1.1.1:853"}
// netConn, err := pool.Get()
// if err != nil {panic(err)}
// c := dns.Conn{Conn: netConn}
// q := dns.Msg{}
// q.SetQuestion("google.com.", dns.TypeA)
// log.Println(q)
// err = c.WriteMsg(&q)
// if err != nil {panic(err)}
// r, err := c.ReadMsg()
// if err != nil {panic(err)}
// log.Println(r)
// pool.Put(c.Conn)
type TLSPool struct {
boot *bootstrapper
// connections
conns []net.Conn
connsMutex sync.Mutex // protects conns
}
func (n *TLSPool) Get() (net.Conn, error) {
address, tlsConfig, err := n.boot.get()
if err != nil {
return nil, err
}
// get the connection from the slice inside the lock
var c net.Conn
n.connsMutex.Lock()
num := len(n.conns)
if num > 0 {
last := num - 1
c = n.conns[last]
n.conns = n.conns[:last]
}
n.connsMutex.Unlock()
// if we got connection from the slice, return it
if c != nil {
// log.Printf("Returning existing connection to %s", host)
return c, nil
}
// we'll need a new connection, dial now
// log.Printf("Dialing to %s", address)
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return nil, errorx.Decorate(err, "Failed to connect to %s", address)
}
return conn, nil
}
func (n *TLSPool) Put(c net.Conn) {
if c == nil {
return
}
n.connsMutex.Lock()
n.conns = append(n.conns, c)
n.connsMutex.Unlock()
}

View file

@ -0,0 +1,96 @@
package dnsforward
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestUpstreams(t *testing.T) {
upstreams := []struct {
address string
bootstrap string
}{
{
address: "8.8.8.8:53",
bootstrap: "8.8.8.8:53",
},
{
address: "1.1.1.1",
bootstrap: "",
},
{
address: "tcp://1.1.1.1:53",
bootstrap: "",
},
{
address: "176.103.130.130:5353",
bootstrap: "",
},
{
address: "tls://1.1.1.1",
bootstrap: "",
},
{
address: "tls://9.9.9.9:853",
bootstrap: "",
},
{
address: "tls://security-filter-dns.cleanbrowsing.org",
bootstrap: "8.8.8.8:53",
},
{
address: "tls://adult-filter-dns.cleanbrowsing.org:853",
bootstrap: "8.8.8.8:53",
},
{
address: "https://cloudflare-dns.com/dns-query",
bootstrap: "8.8.8.8:53",
},
{
address: "https://dns.google.com/experimental",
bootstrap: "8.8.8.8:53",
},
{
address: "https://doh.cleanbrowsing.org/doh/security-filter/",
bootstrap: "",
},
}
for _, test := range upstreams {
t.Run(test.address, func(t *testing.T) {
u, err := AddressToUpstream(test.address, test.bootstrap)
if err != nil {
t.Fatalf("Failed to generate upstream from address %s: %s", test.address, err)
}
checkUpstream(t, u, test.address)
})
}
}
func checkUpstream(t *testing.T, u Upstream, addr string) {
t.Helper()
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := u.Exchange(&req)
if err != nil {
t.Fatalf("Couldn't talk to upstream %s: %s", addr, err)
}
if len(reply.Answer) != 1 {
t.Fatalf("DNS upstream %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
}
if a, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(a.A) {
t.Fatalf("DNS upstream %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
}
} else {
t.Fatalf("DNS upstream %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
}
}

251
filter.go Normal file
View file

@ -0,0 +1,251 @@
package main
import (
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
)
var (
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
)
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name" yaml:"name"`
RulesCount int `json:"rulesCount" yaml:"-"`
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
dnsfilter.Filter `yaml:",inline"`
}
// Creates a helper object for working with the user rules
func userFilter() filter {
return filter{
// User filter always has constant ID=0
Enabled: true,
Filter: dnsfilter.Filter{
Rules: config.UserRules,
},
}
}
func deduplicateFilters() {
// Deduplicate filters
i := 0 // output index, used for deletion later
urls := map[string]bool{}
for _, filter := range config.Filters {
if _, ok := urls[filter.URL]; !ok {
// we didn't see it before, keep it
urls[filter.URL] = true // remember the URL
config.Filters[i] = filter
i++
}
}
// all entries we want to keep are at front, delete the rest
config.Filters = config.Filters[:i]
}
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
}
}
}
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID += 1
return value
}
// Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) {
refreshFiltersIfNeccessary(false)
}
}
// Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value
func refreshFiltersIfNeccessary(force bool) int {
config.Lock()
// fetch URLs
updateCount := 0
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy
if filter.ID == 0 { // protect against users modifying the yaml and removing the ID
filter.ID = assignUniqueFilterID()
}
updated, err := filter.update(force)
if err != nil {
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
continue
}
if updated {
// Saving it to the filters dir now
err = filter.save()
if err != nil {
log.Printf("Failed to save the updated filter %d: %s", filter.ID, err)
continue
}
updateCount++
}
}
config.Unlock()
if updateCount > 0 {
reconfigureDNSServer()
}
return updateCount
}
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func parseFilterContents(contents []byte) (int, string, []string) {
lines := strings.Split(string(contents), "\n")
rulesCount := 0
name := ""
seenTitle := false
// Count lines in the filter
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) > 0 && line[0] == '!' {
if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
}
} else if len(line) != 0 {
rulesCount++
}
}
return rulesCount, name, lines
}
// Checks for filters updates
// If "force" is true -- does not check the filter's LastUpdated field
// Call "save" to persist the filter contents
func (filter *filter) update(force bool) (bool, error) {
if filter.ID == 0 { // protect against users deleting the ID
filter.ID = assignUniqueFilterID()
}
if !filter.Enabled {
return false, nil
}
if !force && time.Since(filter.LastUpdated) <= updatePeriod {
return false, nil
}
log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL)
// use the same update period for failed filter downloads to avoid flooding with requests
filter.LastUpdated = time.Now()
resp, err := client.Get(filter.URL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
return false, err
}
if resp.StatusCode != 200 {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
contentType := strings.ToLower(resp.Header.Get("content-type"))
if !strings.HasPrefix(contentType, "text/plain") {
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
return false, fmt.Errorf("non-text response %s", contentType)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
return false, err
}
// Extract filter name and count number of rules
rulesCount, filterName, rules := parseFilterContents(body)
if filterName != "" {
filter.Name = filterName
}
// Check if the filter has been really changed
if reflect.DeepEqual(filter.Rules, rules) {
log.Printf("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
return false, nil
}
log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
filter.RulesCount = rulesCount
filter.Rules = rules
return true, nil
}
// saves filter contents to the file in dataDir
func (filter *filter) save() error {
filterFilePath := filter.Path()
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
body := []byte(strings.Join(filter.Rules, "\n"))
return safeWriteFile(filterFilePath, body)
}
// loads filter contents from the file in dataDir
func (filter *filter) load() error {
if !filter.Enabled {
// No need to load a filter that is not enabled
return nil
}
filterFilePath := filter.Path()
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
// do nothing, file doesn't exist
return err
}
filterFileContents, err := ioutil.ReadFile(filterFilePath)
if err != nil {
return err
}
log.Printf("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents))
rulesCount, _, rules := parseFilterContents(filterFileContents)
filter.RulesCount = rulesCount
filter.Rules = rules
return nil
}
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}

17
go.mod
View file

@ -3,34 +3,19 @@ module github.com/AdguardTeam/AdGuardHome
require (
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7
github.com/coredns/coredns v1.2.6
github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 // indirect
github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 // indirect
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/go-ole/go-ole v1.2.1 // indirect
github.com/go-test/deep v1.0.1
github.com/gobuffalo/packr v1.19.0
github.com/google/uuid v1.0.0 // indirect
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mholt/caddy v0.11.0
github.com/joomcode/errorx v0.1.0
github.com/miekg/dns v1.0.15
github.com/opentracing/opentracing-go v1.0.2 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.8.0
github.com/prometheus/client_golang v0.9.0-pre1
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect
github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 // indirect
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect
github.com/shirou/gopsutil v2.18.10+incompatible
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect
go.uber.org/goleak v0.10.0
golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd
golang.org/x/net v0.0.0-20181108082009-03003ca0c849
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 // indirect
google.golang.org/grpc v1.16.0 // indirect
gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477
gopkg.in/yaml.v2 v2.2.1
)

52
go.sum
View file

@ -1,23 +1,11 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY=
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg=
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I=
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 h1:NpQ+gkFOH27AyDypSCJ/LdsIi/b4rdnEb1N5+IpFfYs=
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/coredns/coredns v1.2.6 h1:QIAOkBqVE44Zx0ttrFqgE5YhCEn64XPIngU60JyuTGM=
github.com/coredns/coredns v1.2.6/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 h1:m8nX8hsUghn853BJ5qB0lX+VvS6LTJPksWyILFZRYN4=
github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11/go.mod h1:s1PfVYYVmTMgCSPtho4LKBDecEHJWtiVDPNv78Z985U=
github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 h1:QdyRyGZWLEvJG5Kw3VcVJvhXJ5tZ1MkRgqpJOEZSySM=
github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710/go.mod h1:eNde4IQyEiA5br02AouhEHCu3p3UzrCdFR4LuQHklMI=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E=
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8=
github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg=
@ -28,44 +16,21 @@ github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264 h1:roWyi0eEdiFreSq
github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264/go.mod h1:Yf2toFaISlyQrr5TfO3h6DB9pl9mZRmyvBGQb/aQ/pI=
github.com/gobuffalo/packr v1.19.0 h1:3UDmBDxesCOPF8iZdMDBBWKfkBoYujIMIZePnobqIUI=
github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU=
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/joomcode/errorx v0.1.0 h1:QmJMiI1DE1UFje2aI1ZWO/VMT5a32qBoXUclGOt8vsc=
github.com/joomcode/errorx v0.1.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ=
github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4 h1:Mlji5gkcpzkqTROyE4ZxZ8hN7osunMb2RuGVrbvMvCc=
github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mholt/caddy v0.11.0 h1:cuhEyR7So/SBBRiAaiRBe9BoccDu6uveIPuM9FMMavg=
github.com/mholt/caddy v0.11.0/go.mod h1:Wb1PlT4DAYSqOEd03MsqkdkXnTxA8v9pKjdpxbqM1kY=
github.com/miekg/dns v1.0.15 h1:9+UupePBQCG6zf1q/bGmTO1vumoG13jsrbWOSX1W6Tw=
github.com/miekg/dns v1.0.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM=
github.com/prometheus/client_golang v0.9.0-pre1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 h1:MVbUQq1a49hMEISI29UcAUjywT3FyvDwx5up90OvVa4=
github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs=
github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U=
@ -80,29 +45,16 @@ go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4=
go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI=
golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd h1:VtIkGDhk0ph3t+THbvXHfMZ8QHgsBO39Nh52+74pq7w=
golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181102091132-c10e9556a7bc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181108082009-03003ca0c849 h1:FSqE2GGG7wzsYUsWiQ8MZrvEd1EOyU3NCF0AW3Wtltg=
golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 h1:YoY1wS6JYVRpIfFngRf2HHo9R9dAne3xbkGOQ5rJXjU=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/grpc v1.16.0 h1:dz5IJGuC2BB7qXR5AyHNwAUBhZscK2xVez7mznh72sY=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 h1:5xUJw+lg4zao9W4HIDzlFbMYgSgtvNVHh00MEHvbGpQ=
gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477/go.mod h1:QDV1vrFSrowdoOba0UM8VJPUZONT7dnfdLsM+GG53Z8=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -10,6 +10,8 @@ import (
"gopkg.in/yaml.v2"
)
const currentSchemaVersion = 2 // used for upgrading from old configs to new config
// Performs necessary upgrade operations if needed
func upgradeConfig() error {
// read a config file into an interface map, so we can manipulate values without losing any
@ -57,7 +59,12 @@ func upgradeConfig() error {
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
switch oldVersion {
case 0:
err := upgradeSchema0to1(diskConfig)
err := upgradeSchema0to2(diskConfig)
if err != nil {
return err
}
case 1:
err := upgradeSchema1to2(diskConfig)
if err != nil {
return err
}
@ -83,14 +90,13 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
return nil
}
// The first schema upgrade:
// No more "dnsfilter.txt", filters are now kept in data/filters/
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
// The first schema upgrade:
// No more "dnsfilter.txt", filters are now kept in data/filters/
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
_, err := os.Stat(dnsFilterPath)
if !os.IsNotExist(err) {
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
err = os.Remove(dnsFilterPath)
if err != nil {
@ -103,3 +109,38 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
return nil
}
// Second schema upgrade:
// coredns is now dns in config
// delete 'Corefile', since we don't use that anymore
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
coreFilePath := filepath.Join(config.ourBinaryDir, "Corefile")
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
err = os.Remove(coreFilePath)
if err != nil {
log.Printf("Cannot remove %s due to %s", coreFilePath, err)
// not fatal, move on
}
}
if _, ok := (*diskConfig)["dns"]; !ok {
(*diskConfig)["dns"] = (*diskConfig)["coredns"]
delete((*diskConfig), "coredns")
}
(*diskConfig)["schema_version"] = 2
return nil
}
// jump two schemas at once -- this time we just do it sequentially
func upgradeSchema0to2(diskConfig *map[string]interface{}) error {
err := upgradeSchema0to1(diskConfig)
if err != nil {
return err
}
return upgradeSchema1to2(diskConfig)
}

View file

@ -1,105 +0,0 @@
package upstream
import (
"crypto/tls"
"time"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// DnsUpstream is a very simple upstream implementation for plain DNS
type DnsUpstream struct {
endpoint string // IP:port
timeout time.Duration // Max read and write timeout
proto string // Protocol (tcp, tcp-tls, or udp)
transport *Transport // Persistent connections cache
}
// NewDnsUpstream creates a new DNS upstream
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
u := &DnsUpstream{
endpoint: endpoint,
timeout: defaultTimeout,
proto: proto,
}
var tlsConfig *tls.Config
if proto == "tcp-tls" {
tlsConfig = new(tls.Config)
tlsConfig.ServerName = tlsServerName
}
// Initialize the connections cache
u.transport = NewTransport(endpoint)
u.transport.tlsConfig = tlsConfig
u.transport.Start()
return u, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
resp, err := u.exchange(u.proto, query)
// Retry over TCP if response is truncated
if err == dns.ErrTruncated && u.proto == "udp" {
resp, err = u.exchange("tcp", query)
} else if err == dns.ErrTruncated && resp != nil {
// Reassemble something to be sent to client
m := new(dns.Msg)
m.SetReply(query)
m.Truncated = true
m.Authoritative = true
m.Rcode = dns.RcodeSuccess
return m, nil
}
if err != nil {
resp = &dns.Msg{}
resp.SetRcode(resp, dns.RcodeServerFailure)
}
return resp, err
}
// Clear resources
func (u *DnsUpstream) Close() error {
// Close active connections
u.transport.Stop()
return nil
}
// Performs a synchronous query. It sends the message m via the conn
// c and waits for a reply. The conn c is not closed.
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
// Establish a connection if needed (or reuse cached)
conn, err := u.transport.Dial(proto)
if err != nil {
return nil, err
}
// Write the request with a timeout
conn.SetWriteDeadline(time.Now().Add(u.timeout))
if err = conn.WriteMsg(query); err != nil {
conn.Close() // Not giving it back
return nil, err
}
// Write response with a timeout
conn.SetReadDeadline(time.Now().Add(u.timeout))
r, err = conn.ReadMsg()
if err != nil {
conn.Close() // Not giving it back
} else if err == nil && r.Id != query.Id {
err = dns.ErrId
conn.Close() // Not giving it back
}
if err == nil {
// Return it back to the connections cache if there were no errors
u.transport.Yield(conn)
}
return r, err
}

View file

@ -1,98 +0,0 @@
package upstream
import (
"net"
"strings"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Detects the upstream type from the specified url and creates a proper Upstream object
func NewUpstream(url string, bootstrap string) (Upstream, error) {
proto := "udp"
prefix := ""
switch {
case strings.HasPrefix(url, "tcp://"):
proto = "tcp"
prefix = "tcp://"
case strings.HasPrefix(url, "tls://"):
proto = "tcp-tls"
prefix = "tls://"
case strings.HasPrefix(url, "https://"):
return NewHttpsUpstream(url, bootstrap)
}
hostname := strings.TrimPrefix(url, prefix)
host, port, err := net.SplitHostPort(hostname)
if err != nil {
// Set port depending on the protocol
switch proto {
case "udp":
port = "53"
case "tcp":
port = "53"
case "tcp-tls":
port = "853"
}
// Set host = hostname
host = hostname
}
// Try to resolve the host address (or check if it's an IP address)
bootstrapResolver := CreateResolver(bootstrap)
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
if err != nil || len(ips) == 0 {
return nil, err
}
addr := ips[0].String()
endpoint := net.JoinHostPort(addr, port)
tlsServerName := ""
if proto == "tcp-tls" && host != addr {
// Check if we need to specify TLS server name
tlsServerName = host
}
return NewDnsUpstream(endpoint, proto, tlsServerName)
}
func CreateResolver(bootstrap string) *net.Resolver {
bootstrapResolver := net.DefaultResolver
if bootstrap != "" {
bootstrapResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, bootstrap)
},
}
}
return bootstrapResolver
}
// Performs a simple health-check of the specified upstream
func IsAlive(u Upstream) (bool, error) {
// Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere
ping := new(dns.Msg)
ping.SetQuestion("ipv4only.arpa.", dns.TypeA)
resp, err := u.Exchange(context.Background(), ping)
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
if err != nil && resp != nil {
// Silly check, something sane came back.
if resp.Rcode != dns.RcodeServerFailure {
err = nil
}
}
return err == nil, err
}

View file

@ -1,128 +0,0 @@
package upstream
import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/url"
"time"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
"golang.org/x/net/http2"
)
const (
dnsMessageContentType = "application/dns-message"
defaultKeepAlive = 30 * time.Second
)
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
type HttpsUpstream struct {
client *http.Client
endpoint *url.URL
}
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
// Initialize bootstrap resolver
bootstrapResolver := CreateResolver(bootstrap)
dialer := &net.Dialer{
Timeout: defaultTimeout,
KeepAlive: defaultKeepAlive,
DualStack: true,
Resolver: bootstrapResolver,
}
// Update TLS and HTTP client configuration
tlsConfig := &tls.Config{ServerName: u.Hostname()}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
DisableCompression: true,
MaxIdleConns: 1,
DialContext: dialer.DialContext,
}
http2.ConfigureTransport(transport)
client := &http.Client{
Timeout: defaultTimeout,
Transport: transport,
}
return &HttpsUpstream{client: client, endpoint: u}, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
queryBuf, err := query.Pack()
if err != nil {
return nil, errors.Wrap(err, "failed to pack DNS query")
}
// No content negotiation for now, use DNS wire format
buf, backendErr := u.exchangeWireformat(queryBuf)
if backendErr == nil {
response := &dns.Msg{}
if err := response.Unpack(buf); err != nil {
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
}
response.Id = query.Id
return response, nil
}
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
return nil, backendErr
}
// Perform message exchange with the default UDP wireformat defined in current draft
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
if err != nil {
return nil, errors.Wrap(err, "failed to create an HTTPS request")
}
req.Header.Add("Content-Type", dnsMessageContentType)
req.Header.Add("Accept", dnsMessageContentType)
req.Host = u.endpoint.Hostname()
resp, err := u.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
}
// Check response status code
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if contentType != dnsMessageContentType {
return nil, fmt.Errorf("return wrong content type %s", contentType)
}
// Read application/dns-message response from the body
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read the response body")
}
return buf, nil
}
// Clear resources
func (u *HttpsUpstream) Close() error {
return nil
}

View file

@ -1,210 +0,0 @@
package upstream
import (
"crypto/tls"
"net"
"sort"
"sync/atomic"
"time"
"github.com/miekg/dns"
)
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
const (
defaultExpire = 10 * time.Second
minDialTimeout = 100 * time.Millisecond
maxDialTimeout = 30 * time.Second
defaultDialTimeout = 30 * time.Second
cumulativeAvgWeight = 4
)
// a persistConn hold the dns.Conn and the last used time.
type persistConn struct {
c *dns.Conn
used time.Time
}
// Transport hold the persistent cache.
type Transport struct {
avgDialTime int64 // kind of average time of dial time
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired.
addr string
tlsConfig *tls.Config
dial chan string
yield chan *dns.Conn
ret chan *dns.Conn
stop chan bool
}
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"
}
t.dial <- proto
c := <-t.ret
if c != nil {
return c, nil
}
reqTime := time.Now()
timeout := t.dialTimeout()
if proto == "tcp-tls" {
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
t.updateDialTimeout(time.Since(reqTime))
return conn, err
}
conn, err := dns.DialTimeout(proto, t.addr, timeout)
t.updateDialTimeout(time.Since(reqTime))
return conn, err
}
// Yield return the connection to transport for reuse.
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
// Start starts the transport's connection manager.
func (t *Transport) Start() { go t.connManager() }
// Stop stops the transport's connection manager.
func (t *Transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport.
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport.
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
func NewTransport(addr string) *Transport {
t := &Transport{
avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn),
expire: defaultExpire,
addr: addr,
dial: make(chan string),
yield: make(chan *dns.Conn),
ret: make(chan *dns.Conn),
stop: make(chan bool),
}
return t
}
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
dt := time.Duration(atomic.LoadInt64(currentAvg))
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
}
func (t *Transport) dialTimeout() time.Duration {
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
}
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
}
// limitTimeout is a utility function to auto-tune timeout values
// average observed time is moved towards the last observed delay moderated by a weight
// next timeout to use will be the double of the computed average, limited by min and max frame.
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
rt := time.Duration(atomic.LoadInt64(currentAvg))
if rt < minValue {
return minValue
}
if rt < maxValue/2 {
return 2 * rt
}
return maxValue
}
// connManagers manages the persistent connection cache for UDP and TCP.
func (t *Transport) connManager() {
ticker := time.NewTicker(t.expire)
Wait:
for {
select {
case proto := <-t.dial:
// take the last used conn - complexity O(1)
if stack := t.conns[proto]; len(stack) > 0 {
pc := stack[len(stack)-1]
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[proto] = stack[:len(stack)-1]
t.ret <- pc.c
continue Wait
}
// clear entire cache if the last conn is expired
t.conns[proto] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
}
t.ret <- nil
case conn := <-t.yield:
// no proto here, infer from config and conn
if _, ok := conn.Conn.(*net.UDPConn); ok {
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
continue Wait
}
if t.tlsConfig == nil {
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
continue Wait
}
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
case <-ticker.C:
t.cleanup(false)
case <-t.stop:
t.cleanup(true)
close(t.ret)
return
}
}
}
// closeConns closes connections.
func closeConns(conns []*persistConn) {
for _, pc := range conns {
pc.c.Close()
}
}
// cleanup removes connections from cache.
func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire)
for proto, stack := range t.conns {
if len(stack) == 0 {
continue
}
if all {
t.conns[proto] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
continue
}
if stack[0].used.After(staleTime) {
continue
}
// connections in stack are sorted by "used"
good := sort.Search(len(stack), func(i int) bool {
return stack[i].used.After(staleTime)
})
t.conns[proto] = stack[good:]
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack[:good])
}
}

View file

@ -1,81 +0,0 @@
package upstream
import (
"log"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/mholt/caddy"
)
func init() {
caddy.RegisterPlugin("upstream", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
// Read the configuration and initialize upstreams
func setup(c *caddy.Controller) error {
p, err := setupPlugin(c)
if err != nil {
return err
}
config := dnsserver.GetConfig(c)
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
p.Next = next
return p
})
c.OnShutdown(p.onShutdown)
return nil
}
// Read the configuration
func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) {
p := New()
log.Println("Initializing the Upstream plugin")
bootstrap := ""
upstreamUrls := []string{}
for c.Next() {
args := c.RemainingArgs()
if len(args) > 0 {
upstreamUrls = append(upstreamUrls, args...)
}
for c.NextBlock() {
switch c.Val() {
case "bootstrap":
if !c.NextArg() {
return nil, c.ArgErr()
}
bootstrap = c.Val()
}
}
}
for _, url := range upstreamUrls {
u, err := NewUpstream(url, bootstrap)
if err != nil {
log.Printf("Cannot initialize upstream %s", url)
return nil, err
}
p.Upstreams = append(p.Upstreams, u)
}
return p, nil
}
func (p *UpstreamPlugin) onShutdown() error {
for i := range p.Upstreams {
u := p.Upstreams[i]
err := u.Close()
if err != nil {
log.Printf("Error while closing the upstream: %s", err)
}
}
return nil
}

View file

@ -1,29 +0,0 @@
package upstream
import (
"testing"
"github.com/mholt/caddy"
)
func TestSetup(t *testing.T) {
var tests = []struct {
config string
}{
{`upstream 8.8.8.8`},
{`upstream 8.8.8.8 {
bootstrap 8.8.8.8:53
}`},
{`upstream tls://1.1.1.1 8.8.8.8 {
bootstrap 1.1.1.1
}`},
}
for _, test := range tests {
c := caddy.NewTestController("dns", test.config)
err := setup(c)
if err != nil {
t.Fatalf("Test failed")
}
}
}

View file

@ -1,57 +0,0 @@
package upstream
import (
"time"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
)
const (
defaultTimeout = 5 * time.Second
)
// Upstream is a simplified interface for proxy destination
type Upstream interface {
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
Close() error
}
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
type UpstreamPlugin struct {
Upstreams []Upstream
Next plugin.Handler
}
// Initialize the upstream plugin
func New() *UpstreamPlugin {
p := &UpstreamPlugin{
Upstreams: []Upstream{},
}
return p
}
// ServeDNS implements interface for CoreDNS plugin
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
var reply *dns.Msg
var backendErr error
for i := range p.Upstreams {
upstream := p.Upstreams[i]
reply, backendErr = upstream.Exchange(ctx, r)
if backendErr == nil {
w.WriteMsg(reply)
return 0, nil
}
}
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
}
// Name implements interface for CoreDNS plugin
func (p *UpstreamPlugin) Name() string {
return "upstream"
}

View file

@ -1,187 +0,0 @@
package upstream
import (
"net"
"testing"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func TestDnsUpstreamIsAlive(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"8.8.8.8:53", "8.8.8.8:53"},
{"1.1.1.1", ""},
{"tcp://1.1.1.1:53", ""},
{"176.103.130.130:5353", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS upstream")
}
testUpstreamIsAlive(t, u)
}
}
func TestHttpsUpstreamIsAlive(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
{"https://dns.google.com/experimental", "8.8.8.8:53"},
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
}
testUpstreamIsAlive(t, u)
}
}
func TestDnsOverTlsIsAlive(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"tls://1.1.1.1", ""},
{"tls://9.9.9.9:853", ""},
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")
}
testUpstreamIsAlive(t, u)
}
}
func TestDnsUpstream(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"8.8.8.8:53", "8.8.8.8:53"},
{"1.1.1.1", ""},
{"tcp://1.1.1.1:53", ""},
{"176.103.130.130:5353", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS upstream")
}
testUpstream(t, u)
}
}
func TestHttpsUpstream(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
{"https://dns.google.com/experimental", "8.8.8.8:53"},
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
}
testUpstream(t, u)
}
}
func TestDnsOverTlsUpstream(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"tls://1.1.1.1", ""},
{"tls://9.9.9.9:853", ""},
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")
}
testUpstream(t, u)
}
}
func testUpstreamIsAlive(t *testing.T, u Upstream) {
alive, err := IsAlive(u)
if !alive || err != nil {
t.Errorf("Upstream is not alive")
}
u.Close()
}
func testUpstream(t *testing.T, u Upstream) {
var tests = []struct {
name string
expected net.IP
}{
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
}
for _, test := range tests {
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
resp, err := u.Exchange(context.Background(), &req)
if err != nil {
t.Fatalf("error while making an upstream request: %s", err)
}
if len(resp.Answer) != 1 {
t.Fatalf("no answer section in the response")
}
if answer, ok := resp.Answer[0].(*dns.A); ok {
if !test.expected.Equal(answer.A) {
t.Errorf("wrong IP in the response: %v", answer.A)
}
}
}
err := u.Close()
if err != nil {
t.Errorf("Error while closing the upstream: %s", err)
}
}