Filters are now saved to a file
Also, they're loaded from the file on startup
Filter ID is not passed to the CoreDNS plugin config (server-side AG DNS must be changed accordingly)
Some minor refactoring, unused functions removed
This commit is contained in:
Andrey Meshkov 2018-10-30 02:17:24 +03:00
parent 30f3eb446c
commit 32d4e80c93
8 changed files with 339 additions and 190 deletions

1
.gitignore vendored
View file

@ -4,6 +4,7 @@
debug debug
/AdGuardHome /AdGuardHome
/AdGuardHome.yaml /AdGuardHome.yaml
/data/
/build/ /build/
/client/node_modules/ /client/node_modules/
/coredns /coredns

117
app.go
View file

@ -25,10 +25,18 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
config.ourBinaryDir = filepath.Dir(executable)
}
doConfigRename := true executableName := filepath.Base(executable)
if executableName == "AdGuardHome" {
// Binary build
config.ourBinaryDir = filepath.Dir(executable)
} else {
// Most likely we're debugging -- using current working directory in this case
workDir, _ := os.Getwd()
config.ourBinaryDir = workDir
}
log.Printf("Current working directory is %s", config.ourBinaryDir)
}
// config can be specified, which reads options from there, but other command line flags have to override config values // config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib // therefore, we must do it manually instead of using a lib
@ -98,18 +106,9 @@ func main() {
} }
} }
if configFilename != nil { if configFilename != nil {
// config was manually specified, don't do anything
doConfigRename = false
config.ourConfigFilename = *configFilename config.ourConfigFilename = *configFilename
} }
if doConfigRename {
err := renameOldConfigIfNeccessary()
if err != nil {
panic(err)
}
}
err := askUsernamePasswordIfPossible() err := askUsernamePasswordIfPossible()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -128,16 +127,32 @@ func main() {
} }
} }
// eat all args so that coredns can start happily // Eat all args so that coredns can start happily
if len(os.Args) > 1 { if len(os.Args) > 1 {
os.Args = os.Args[:1] os.Args = os.Args[:1]
} }
err := writeConfig() // Do the upgrade if necessary
err := upgradeConfig()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// Save the updated config
err = writeConfig()
if err != nil {
log.Fatal(err)
}
// Load filters from the disk
for i := range config.Filters {
filter := &config.Filters[i]
err = filter.load()
if err != nil {
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
}
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
runFilterRefreshers() runFilterRefreshers()
@ -240,27 +255,71 @@ func askUsernamePasswordIfPossible() error {
return nil return nil
} }
func renameOldConfigIfNeccessary() error { // Performs necessary upgrade operations if needed
oldConfigFile := filepath.Join(config.ourBinaryDir, "AdguardDNS.yaml") func upgradeConfig() error {
_, err := os.Stat(oldConfigFile)
if os.IsNotExist(err) { if config.SchemaVersion == SchemaVersion {
// do nothing, file doesn't exist // No upgrade, do nothing
trace("File %s doesn't exist, nothing to do", oldConfigFile)
return nil return nil
} }
newConfigFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) if config.SchemaVersion > SchemaVersion {
_, err = os.Stat(newConfigFile) // Unexpected -- config file is newer than the
if !os.IsNotExist(err) { return fmt.Errorf("configuration file is supposed to be used with a newer version of AdGuard Home, schema=%d", config.SchemaVersion)
// do nothing, file doesn't exist
trace("File %s already exists, will not overwrite", newConfigFile)
return nil
} }
err = os.Rename(oldConfigFile, newConfigFile) // Perform upgrade operations for each consecutive version upgrade
for oldVersion, newVersion := config.SchemaVersion, config.SchemaVersion+1; newVersion <= SchemaVersion; {
err := upgradeConfigSchema(oldVersion, newVersion)
if err != nil { if err != nil {
log.Printf("Failed to rename %s to %s: %s", oldConfigFile, newConfigFile, err) log.Fatal(err)
return err }
// Increment old and new versions
oldVersion++
newVersion++
}
// Save the current schema version
config.SchemaVersion = SchemaVersion
return nil
}
// Upgrade from oldVersion to newVersion
func upgradeConfigSchema(oldVersion int, newVersion int) error {
if oldVersion == 0 && newVersion == 1 {
log.Printf("Updating schema from %d to %d", oldVersion, newVersion)
// The first schema upgrade:
// Added "ID" field to "filter" -- we need to populate this field now
// Added "config.ourDataDir" -- where we will now store filters contents
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy
log.Printf("Seting ID=%d for filter %s", i, filter.URL)
filter.ID = i + 1 // start with ID=1
// Forcibly update the filter
_, err := filter.update(true)
if err != nil {
log.Fatal(err)
}
}
// No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
_, err := os.Stat(dnsFilterPath)
if !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
err = os.Remove(dnsFilterPath)
if err != nil {
log.Printf("Cannot remove %s due to %s", dnsFilterPath, err)
}
}
} }
return nil return nil

104
config.go
View file

@ -2,6 +2,7 @@ package main
import ( import (
"bytes" "bytes"
"gopkg.in/yaml.v2"
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
@ -10,15 +11,24 @@ import (
"sync" "sync"
"text/template" "text/template"
"time" "time"
"gopkg.in/yaml.v2"
) )
// Current schema version. We compare it with the value from
// the configuration file and perform necessary upgrade operations if needed
const SchemaVersion = 1
// Directory where we'll store all downloaded filters contents
const FiltersDir = "filters"
// configuration is loaded from YAML // configuration is loaded from YAML
type configuration struct { type configuration struct {
ourConfigFilename string ourConfigFilename string
ourBinaryDir string ourBinaryDir string
// Directory to store data (i.e. filters contents)
ourDataDir string
// Schema version of the config file. This value is used when performing the app updates.
SchemaVersion int `yaml:"schema_version"`
BindHost string `yaml:"bind_host"` BindHost string `yaml:"bind_host"`
BindPort int `yaml:"bind_port"` BindPort int `yaml:"bind_port"`
AuthName string `yaml:"auth_name"` AuthName string `yaml:"auth_name"`
@ -30,10 +40,15 @@ type configuration struct {
sync.RWMutex `yaml:"-"` sync.RWMutex `yaml:"-"`
} }
type coreDnsFilter struct {
ID int `yaml:"-"`
Path string `yaml:"-"`
}
type coreDNSConfig struct { type coreDNSConfig struct {
binaryFile string binaryFile string
coreFile string coreFile string
FilterFile string `yaml:"-"` Filters []coreDnsFilter `yaml:"-"`
Port int `yaml:"port"` Port int `yaml:"port"`
ProtectionEnabled bool `yaml:"protection_enabled"` ProtectionEnabled bool `yaml:"protection_enabled"`
FilteringEnabled bool `yaml:"filtering_enabled"` FilteringEnabled bool `yaml:"filtering_enabled"`
@ -50,6 +65,7 @@ type coreDNSConfig struct {
} }
type filter struct { type filter struct {
ID int `json:"ID"` // auto-assigned when filter is added
URL string `json:"url"` URL string `json:"url"`
Name string `json:"name" yaml:"name"` Name string `json:"name" yaml:"name"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@ -63,13 +79,13 @@ var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
// initialize to default values, will be changed later when reading config or parsing command line // initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{ var config = configuration{
ourConfigFilename: "AdGuardHome.yaml", ourConfigFilename: "AdGuardHome.yaml",
ourDataDir: "data",
BindPort: 3000, BindPort: 3000,
BindHost: "127.0.0.1", BindHost: "127.0.0.1",
CoreDNS: coreDNSConfig{ CoreDNS: coreDNSConfig{
Port: 53, Port: 53,
binaryFile: "coredns", // only filename, no path binaryFile: "coredns", // only filename, no path
coreFile: "Corefile", // only filename, no path coreFile: "Corefile", // only filename, no path
FilterFile: "dnsfilter.txt", // only filename, no path
ProtectionEnabled: true, ProtectionEnabled: true,
FilteringEnabled: true, FilteringEnabled: true,
SafeBrowsingEnabled: false, SafeBrowsingEnabled: false,
@ -80,13 +96,33 @@ var config = configuration{
Prometheus: "prometheus :9153", Prometheus: "prometheus :9153",
}, },
Filters: []filter{ Filters: []filter{
{Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt"}, {ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
{Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, {ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
{Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, {ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
{Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, {ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
}, },
} }
// Creates a helper object for working with the user rules
func getUserFilter() filter {
// TODO: This should be calculated when UserRules are set
contents := []byte{}
for _, rule := range config.UserRules {
contents = append(contents, []byte(rule)...)
contents = append(contents, '\n')
}
userFilter := filter{
// User filter always has ID=0
ID: 0,
contents: contents,
Enabled: true,
}
return userFilter
}
func parseConfig() error { func parseConfig() error {
configfile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) configfile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename)
log.Printf("Reading YAML file: %s", configfile) log.Printf("Reading YAML file: %s", configfile)
@ -117,16 +153,19 @@ func writeConfig() error {
log.Printf("Couldn't generate YAML file: %s", err) log.Printf("Couldn't generate YAML file: %s", err)
return err return err
} }
err = ioutil.WriteFile(configfile+".tmp", yamlText, 0644) err = writeFileSafe(configfile, yamlText)
if err != nil { if err != nil {
log.Printf("Couldn't write YAML config: %s", err) log.Printf("Couldn't save YAML config: %s", err)
return err return err
} }
err = os.Rename(configfile+".tmp", configfile)
userFilter := getUserFilter()
err = userFilter.save()
if err != nil { if err != nil {
log.Printf("Couldn't rename YAML config: %s", err) log.Printf("Couldn't save the user filter: %s", err)
return err return err
} }
return nil return nil
} }
@ -141,15 +180,12 @@ func writeCoreDNSConfig() error {
log.Printf("Couldn't generate DNS config: %s", err) log.Printf("Couldn't generate DNS config: %s", err)
return err return err
} }
err = ioutil.WriteFile(corefile+".tmp", []byte(configtext), 0644) err = writeFileSafe(corefile, []byte(configtext))
if err != nil { if err != nil {
log.Printf("Couldn't write DNS config: %s", err) log.Printf("Couldn't save DNS config: %s", err)
}
err = os.Rename(corefile+".tmp", corefile)
if err != nil {
log.Printf("Couldn't rename DNS config: %s", err)
}
return err return err
}
return nil
} }
func writeAllConfigs() error { func writeAllConfigs() error {
@ -167,12 +203,17 @@ func writeAllConfigs() error {
} }
const coreDNSConfigTemplate = `.:{{.Port}} { const coreDNSConfigTemplate = `.:{{.Port}} {
{{if .ProtectionEnabled}}dnsfilter {{if .FilteringEnabled}}{{.FilterFile}}{{end}} { {{if .ProtectionEnabled}}dnsfilter {
{{if .SafeBrowsingEnabled}}safebrowsing{{end}} {{if .SafeBrowsingEnabled}}safebrowsing{{end}}
{{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}} {{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}}
{{if .SafeSearchEnabled}}safesearch{{end}} {{if .SafeSearchEnabled}}safesearch{{end}}
{{if .QueryLogEnabled}}querylog{{end}} {{if .QueryLogEnabled}}querylog{{end}}
blocked_ttl {{.BlockedResponseTTL}} blocked_ttl {{.BlockedResponseTTL}}
{{if .FilteringEnabled}}
{{range .Filters}}
filter {{.ID}} "{{.Path}}"
{{end}}
{{end}}
}{{end}} }{{end}}
{{.Pprof}} {{.Pprof}}
hosts { hosts {
@ -196,7 +237,28 @@ func generateCoreDNSConfigText() (string, error) {
var configBytes bytes.Buffer var configBytes bytes.Buffer
temporaryConfig := config.CoreDNS temporaryConfig := config.CoreDNS
temporaryConfig.FilterFile = filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile)
// fill the list of filters
filters := make([]coreDnsFilter, 0)
// first of all, append the user filter
userFilter := getUserFilter()
// TODO: Don't add if empty
//if len(userFilter.contents) > 0 {
filters = append(filters, coreDnsFilter{ID: userFilter.ID, Path: userFilter.getFilterFilePath()})
//}
// then go through other filters
for i := range config.Filters {
filter := &config.Filters[i]
if filter.Enabled && len(filter.contents) > 0 {
filters = append(filters, coreDnsFilter{ID: filter.ID, Path: filter.getFilterFilePath()})
}
}
temporaryConfig.Filters = filters
// run the template // run the template
err = t.Execute(&configBytes, &temporaryConfig) err = t.Execute(&configBytes, &temporaryConfig)
if err != nil { if err != nil {

View file

@ -16,7 +16,6 @@ import (
"time" "time"
coredns_plugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" coredns_plugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns" "github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4" "gopkg.in/asaskevich/govalidator.v4"
) )
@ -423,7 +422,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
} }
} }
ok, err := filter.update(time.Now()) ok, err := filter.update(true)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err) errortext := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err)
log.Println(errortext) log.Println(errortext)
@ -452,14 +451,9 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errortext, http.StatusInternalServerError)
return return
} }
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
tellCoreDNSToReload() tellCoreDNSToReload()
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errortext := fmt.Sprintf("Couldn't write body: %s", err)
@ -468,6 +462,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
} }
} }
// TODO: Start using filter ID
func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body) parameters, err := parseParametersFromBody(r.Body)
if err != nil { if err != nil {
@ -493,19 +488,22 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
for _, filter := range config.Filters { for _, filter := range config.Filters {
if filter.URL != url { if filter.URL != url {
newFilters = append(newFilters, filter) newFilters = append(newFilters, filter)
} } else {
} // Remove the filter file
config.Filters = newFilters err := os.Remove(filter.getFilterFilePath())
err = writeFilterFile()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err) errortext := fmt.Sprintf("Couldn't remove the filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errortext, http.StatusInternalServerError)
return return
} }
}
}
// Update the configuration after removing filter files
config.Filters = newFilters
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
// TODO: Start using filter ID
func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body) parameters, err := parseParametersFromBody(r.Body)
if err != nil { if err != nil {
@ -542,16 +540,10 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
// kick off refresh of rules from new URLs // kick off refresh of rules from new URLs
refreshFiltersIfNeccessary() refreshFiltersIfNeccessary()
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
// TODO: Start using filter ID
func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body) parameters, err := parseParametersFromBody(r.Body)
if err != nil { if err != nil {
@ -586,13 +578,6 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
return return
} }
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
@ -606,13 +591,6 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
} }
config.UserRules = strings.Split(string(body), "\n") config.UserRules = strings.Split(string(body), "\n")
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
@ -639,7 +617,6 @@ func runFilterRefreshers() {
} }
func refreshFiltersIfNeccessary() int { func refreshFiltersIfNeccessary() int {
now := time.Now()
config.Lock() config.Lock()
// deduplicate // deduplicate
@ -663,7 +640,7 @@ func refreshFiltersIfNeccessary() int {
updateCount := 0 updateCount := 0
for i := range config.Filters { for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy filter := &config.Filters[i] // otherwise we will be operating on a copy
updated, err := filter.update(now) updated, err := filter.update(false)
if err != nil { if err != nil {
log.Printf("Failed to update filter %s: %s\n", filter.URL, err) log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
continue continue
@ -675,27 +652,25 @@ func refreshFiltersIfNeccessary() int {
config.Unlock() config.Unlock()
if updateCount > 0 { if updateCount > 0 {
err := writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
}
tellCoreDNSToReload() tellCoreDNSToReload()
} }
return updateCount return updateCount
} }
func (filter *filter) update(now time.Time) (bool, error) { // Checks for filters updates
// If "force" is true -- does not check the filter's LastUpdated field
func (filter *filter) update(force bool) (bool, error) {
if !filter.Enabled { if !filter.Enabled {
return false, nil return false, nil
} }
elapsed := time.Since(filter.LastUpdated) if !force && time.Since(filter.LastUpdated) <= updatePeriod {
if elapsed <= updatePeriod {
return false, nil return false, nil
} }
log.Printf("Downloading update for filter %d", filter.ID)
// use same update period for failed filter downloads to avoid flooding with requests // use same update period for failed filter downloads to avoid flooding with requests
filter.LastUpdated = now filter.LastUpdated = time.Now()
resp, err := client.Get(filter.URL) resp, err := client.Get(filter.URL)
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
@ -706,9 +681,15 @@ func (filter *filter) update(now time.Time) (bool, error) {
return false, err return false, err
} }
if resp.StatusCode >= 400 { if resp.StatusCode != 200 {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
return false, fmt.Errorf("Got status code >= 400: %d", resp.StatusCode) return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
contentType := strings.ToLower(resp.Header.Get("content-type"))
if !strings.HasPrefix(contentType, "text/plain") {
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
return false, fmt.Errorf("non-text response %s", contentType)
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
@ -717,11 +698,12 @@ func (filter *filter) update(now time.Time) (bool, error) {
return false, err return false, err
} }
// extract filter name and count number of rules // Extract filter name and count number of rules
lines := strings.Split(string(body), "\n") lines := strings.Split(string(body), "\n")
rulesCount := 0 rulesCount := 0
seenTitle := false seenTitle := false
d := dnsfilter.New()
// Count lines in the filter
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if len(line) > 0 && line[0] == '!' { if len(line) > 0 && line[0] == '!' {
@ -730,61 +712,73 @@ func (filter *filter) update(now time.Time) (bool, error) {
seenTitle = true seenTitle = true
} }
} else if len(line) != 0 { } else if len(line) != 0 {
err = d.AddRule(line, 0)
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
continue
}
if err != nil {
log.Printf("Cannot add rule %s from %s: %s", line, filter.URL, err)
// Just ignore invalid rules
continue
}
rulesCount++ rulesCount++
} }
} }
// Check if the filter was really changed
if bytes.Equal(filter.contents, body) { if bytes.Equal(filter.contents, body) {
return false, nil return false, nil
} }
log.Printf("Filter %s updated: %d bytes, %d rules", filter.URL, len(body), rulesCount) log.Printf("Filter %s updated: %d bytes, %d rules", filter.URL, len(body), rulesCount)
filter.RulesCount = rulesCount filter.RulesCount = rulesCount
filter.contents = body filter.contents = body
// Saving it to the filters dir now
err = filter.save()
if err != nil {
return false, nil
}
return true, nil return true, nil
} }
// write filter file // saves filter contents to the file in config.ourDataDir
func writeFilterFile() error { func (filter *filter) save() error {
filterpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile)
log.Printf("Writing filter file: %s", filterpath) filterFilePath := filter.getFilterFilePath()
// TODO: check if file contents have modified log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
data := []byte{}
config.RLock() err := writeFileSafe(filterFilePath, filter.contents)
filters := config.Filters
for _, filter := range filters {
if !filter.Enabled {
continue
}
data = append(data, filter.contents...)
data = append(data, '\n')
}
for _, rule := range config.UserRules {
data = append(data, []byte(rule)...)
data = append(data, '\n')
}
config.RUnlock()
err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
if err != nil { if err != nil {
log.Printf("Couldn't write filter file: %s", err)
return err return err
} }
err = os.Rename(filterpath+".tmp", filterpath) return nil;
if err != nil { }
log.Printf("Couldn't rename filter file: %s", err)
// loads filter contents from the file in config.ourDataDir
func (filter *filter) load() error {
if !filter.Enabled {
// No need to load a filter that is not enabled
return nil
}
filterFilePath := filter.getFilterFilePath()
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
// do nothing, file doesn't exist
return err return err
} }
filterFile, err := ioutil.ReadFile(filterFilePath)
if err != nil {
return err
}
log.Printf("Filter %d length is %d", filter.ID, len(filterFile))
filter.contents = filterFile
return nil return nil
} }
// Path to the filter contents
func (filter *filter) getFilterFilePath() string {
return filepath.Join(config.ourBinaryDir, config.ourDataDir, FiltersDir, strconv.Itoa(filter.ID) + ".txt")
}
// ------------ // ------------
// safebrowsing // safebrowsing
// ------------ // ------------

View file

@ -120,12 +120,6 @@ func startDNSServer() error {
log.Println(errortext) log.Println(errortext)
return errortext return errortext
} }
err = writeFilterFile()
if err != nil {
errortext := fmt.Errorf("Couldn't write filter file: %s", err)
log.Println(errortext)
return errortext
}
go coremain.Run() go coremain.Run()
return nil return nil

View file

@ -51,11 +51,17 @@ var (
lookupCache = map[string]cacheEntry{} lookupCache = map[string]cacheEntry{}
) )
type plugFilter struct {
ID uint32
Path string
}
type plugSettings struct { type plugSettings struct {
SafeBrowsingBlockHost string SafeBrowsingBlockHost string
ParentalBlockHost string ParentalBlockHost string
QueryLogEnabled bool QueryLogEnabled bool
BlockedTTL uint32 // in seconds, default 3600 BlockedTTL uint32 // in seconds, default 3600
Filters []plugFilter
} }
type plug struct { type plug struct {
@ -71,6 +77,7 @@ var defaultPluginSettings = plugSettings{
SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com",
ParentalBlockHost: "family.block.dns.adguard.com", ParentalBlockHost: "family.block.dns.adguard.com",
BlockedTTL: 3600, // in seconds BlockedTTL: 3600, // in seconds
Filters: make([]plugFilter, 0),
} }
// //
@ -83,14 +90,12 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
d: dnsfilter.New(), d: dnsfilter.New(),
} }
filterFileNames := []string{} log.Println("Initializing the CoreDNS plugin")
for c.Next() { for c.Next() {
args := c.RemainingArgs()
if len(args) > 0 {
filterFileNames = append(filterFileNames, args...)
}
for c.NextBlock() { for c.NextBlock() {
switch c.Val() { blockValue := c.Val()
switch blockValue {
case "safebrowsing": case "safebrowsing":
p.d.EnableSafeBrowsing() p.d.EnableSafeBrowsing()
if c.NextArg() { if c.NextArg() {
@ -130,17 +135,38 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
p.settings.BlockedTTL = uint32(blockttl) p.settings.BlockedTTL = uint32(blockttl)
case "querylog": case "querylog":
p.settings.QueryLogEnabled = true p.settings.QueryLogEnabled = true
case "filter":
if !c.NextArg() {
return nil, c.ArgErr()
}
filterId, err := strconv.Atoi(c.Val())
if err != nil {
return nil, c.ArgErr()
}
if !c.NextArg() {
return nil, c.ArgErr()
}
filterPath := c.Val()
// Initialize filter and add it to the list
p.settings.Filters = append(p.settings.Filters, plugFilter{
ID: uint32(filterId),
Path: filterPath,
})
} }
} }
} }
log.Printf("filterFileNames = %+v", filterFileNames) for _, filter := range p.settings.Filters {
log.Printf("Loading rules from %s", filter.Path)
for i, filterFileName := range filterFileNames { file, err := os.Open(filter.Path)
file, err := os.Open(filterFileName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
//noinspection GoDeferInLoop
defer file.Close() defer file.Close()
count := 0 count := 0
@ -148,7 +174,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
err = p.d.AddRule(text, uint32(i)) err = p.d.AddRule(text, filter.ID)
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
continue continue
} }
@ -159,7 +185,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
} }
count++ count++
} }
log.Printf("Added %d rules from %s", count, filterFileName) log.Printf("Added %d rules from %d", count, filter.ID)
if err = scanner.Err(); err != nil { if err = scanner.Err(); err != nil {
return nil, err return nil, err
@ -250,6 +276,7 @@ func (p *plug) onFinalShutdown() error {
type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType)
//noinspection GoUnusedParameter
func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
realch, ok := ch.(chan<- *prometheus.Desc) realch, ok := ch.(chan<- *prometheus.Desc)
if !ok { if !ok {
@ -391,7 +418,7 @@ func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.M
func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) {
if len(r.Question) != 1 { if len(r.Question) != 1 {
// google DNS, bind and others do the same // google DNS, bind and others do the same
return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("Got DNS request with != 1 questions") return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question")
} }
for _, question := range r.Question { for _, question := range r.Question {
host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) host := strings.ToLower(strings.TrimSuffix(question.Name, "."))

View file

@ -15,6 +15,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// TODO: Change tests -- there's new config template now
func TestSetup(t *testing.T) { func TestSetup(t *testing.T) {
for i, testcase := range []struct { for i, testcase := range []struct {
config string config string

View file

@ -5,21 +5,39 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path" "path"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
) )
func clamp(value, low, high int) int { // ----------------------------------
if value < low { // helper functions for working with files
return low // ----------------------------------
// Writes data first to a temporary file and then renames it to what's specified in path
func writeFileSafe(path string, data []byte) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0755)
if err != nil {
return err
} }
if value > high {
return high tmpPath := path + ".tmp"
err = ioutil.WriteFile(tmpPath, data, 0644)
if err != nil {
return err
} }
return value err = os.Rename(tmpPath, path)
if err != nil {
return err
}
return nil
} }
// ---------------------------------- // ----------------------------------
@ -117,13 +135,6 @@ func parseParametersFromBody(r io.Reader) (map[string]string, error) {
// --------------------- // ---------------------
// debug logging helpers // debug logging helpers
// --------------------- // ---------------------
func _Func() string {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
return path.Base(f.Name())
}
func trace(format string, args ...interface{}) { func trace(format string, args ...interface{}) {
pc := make([]uintptr, 10) // at least 1 entry needed pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc) runtime.Callers(2, pc)