gometalinter

This commit is contained in:
Andrey Meshkov 2019-01-24 20:11:01 +03:00 committed by Eugene Bujak
parent c9d627ea71
commit d078851246
21 changed files with 454 additions and 409 deletions

4
.gitignore vendored
View file

@ -12,5 +12,5 @@
/scripts/translations/oneskyapp.json /scripts/translations/oneskyapp.json
# Test output # Test output
dnsfilter/dnsfilter.TestLotsOfRules*.pprof dnsfilter/tests/top-1m.csv
tests/top-1m.csv dnsfilter/tests/dnsfilter.TestLotsOfRules*.pprof

40
.gometalinter.json Normal file
View file

@ -0,0 +1,40 @@
{
"Vendor": true,
"Test": true,
"Deadline": "2m",
"Sort": ["linter", "severity", "path", "line"],
"Exclude": [
".*generated.*",
"dnsfilter/rule_to_regexp.go"
],
"EnableGC": true,
"Linters": {
"nakedret": {
"Command": "nakedret",
"Pattern": "^(?P<path>.*?\\.go):(?P<line>\\d+)\\s*(?P<message>.*)$"
}
},
"WarnUnmatchedDirective": true,
"DisableAll": true,
"Enable": [
"deadcode",
"gocyclo",
"gofmt",
"goimports",
"golint",
"gosimple",
"ineffassign",
"interfacer",
"lll",
"misspell",
"nakedret",
"unconvert",
"unparam",
"unused",
"vet"
],
"Cyclo": 20,
"LineLength": 200
}

163
app.go
View file

@ -43,7 +43,89 @@ func main() {
// config can be specified, which reads options from there, but other command line flags have to override config values // config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib // therefore, we must do it manually instead of using a lib
{ loadOptions()
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNecessary(false)
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}()
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
go periodicallyRefreshFilters()
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
registerControlHandlers()
err = startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
URL := fmt.Sprintf("http://%s", address)
log.Println("Go to " + URL)
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
}
// loadOptions reads command line arguments and initializes configuration
func loadOptions() {
var printHelp func() var printHelp func()
var configFilename *string var configFilename *string
var bindHost *string var bindHost *string
@ -128,85 +210,6 @@ func main() {
if bindPort != nil { if bindPort != nil {
config.BindPort = *bindPort config.BindPort = *bindPort
} }
}
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Rules) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNeccessary(false)
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}()
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
os.Exit(0)
}()
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
go periodicallyRefreshFilters()
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
registerControlHandlers()
err = startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
URL := fmt.Sprintf("http://%s", address)
log.Println("Go to " + URL)
log.Fatal(http.ListenAndServe(address, nil))
}
func cleanup() {
err := stopDNSServer()
if err != nil {
log.Printf("Couldn't stop DNS server: %s", err)
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
} }
func promptAndGet(prompt string) (string, error) { func promptAndGet(prompt string) (string, error) {

View file

@ -21,7 +21,7 @@ const (
// configuration is loaded from YAML // configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type configuration struct { type configuration struct {
ourConfigFilename string // Config filename (can be overriden via the command line arguments) ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else
BindHost string `yaml:"bind_host"` BindHost string `yaml:"bind_host"`

View file

@ -40,27 +40,26 @@ func writeAllConfigsAndReloadDNS() error {
log.Printf("Couldn't write all configs: %s", err) log.Printf("Couldn't write all configs: %s", err)
return err return err
} }
reconfigureDNSServer() return reconfigureDNSServer()
return nil
} }
func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
err := writeAllConfigsAndReloadDNS() err := writeAllConfigsAndReloadDNS()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err) errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
returnOK(w, r) returnOK(w)
} }
func returnOK(w http.ResponseWriter, r *http.Request) { func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }
@ -79,17 +78,17 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -147,7 +146,13 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
http.Error(w, errorText, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
reconfigureDNSServer() err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil { if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
@ -206,7 +211,7 @@ func checkDNS(input string) error {
log.Printf("Checking if DNS %s works...", input) log.Printf("Checking if DNS %s works...", input)
u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout) u, err := upstream.AddressToUpstream(input, "", dnsforward.DefaultTimeout)
if err != nil { if err != nil {
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) return fmt.Errorf("failed to choose upstream for %s: %s", input, err)
} }
req := dns.Msg{} req := dns.Msg{}
@ -243,9 +248,9 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp, err := client.Get(versionCheckURL) resp, err := client.Get(versionCheckURL)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) errorText := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadGateway) http.Error(w, errorText, http.StatusBadGateway)
return return
} }
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
@ -255,18 +260,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// read the body entirely // read the body entirely
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err) errorText := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadGateway) http.Error(w, errorText, http.StatusBadGateway)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(body) _, err = w.Write(body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
versionCheckLastTime = now versionCheckLastTime = now
@ -299,18 +304,18 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
config.RUnlock() config.RUnlock()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -378,7 +383,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
} }
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
// TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary // TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary
config.Filters = append(config.Filters, filter) config.Filters = append(config.Filters, filter)
err = writeAllConfigs() err = writeAllConfigs()
if err != nil { if err != nil {
@ -473,7 +478,7 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
} }
// kick off refresh of rules from new URLs // kick off refresh of rules from new URLs
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
@ -529,7 +534,7 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
force := r.URL.Query().Get("force") force := r.URL.Query().Get("force")
updated := refreshFiltersIfNeccessary(force != "") updated := refreshFiltersIfNecessary(force != "")
fmt.Fprintf(w, "OK %d filters updated\n", updated) fmt.Fprintf(w, "OK %d filters updated\n", updated)
} }
@ -553,17 +558,17 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -574,9 +579,9 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
func handleParentalEnable(w http.ResponseWriter, r *http.Request) { func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body) parameters, err := parseParametersFromBody(r.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("failed to parse parameters from body: %s", err) errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
@ -631,18 +636,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -667,18 +672,18 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
} }
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }

10
dhcp.go
View file

@ -58,7 +58,10 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
} }
} }
if !newconfig.Enabled { if !newconfig.Enabled {
dhcpServer.Stop() err := dhcpServer.Stop()
if err != nil {
log.Printf("failed to stop the DHCP server: %s", err)
}
} }
config.DHCP = newconfig config.DHCP = newconfig
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
@ -73,11 +76,6 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
return return
} }
type address struct {
IP string
Netmask string
}
type responseInterface struct { type responseInterface struct {
Name string `json:"name"` Name string `json:"name"`
MTU int `json:"mtu"` MTU int `json:"mtu"`

View file

@ -13,6 +13,8 @@ import (
"github.com/krolaw/dhcp4" "github.com/krolaw/dhcp4"
) )
// CheckIfOtherDHCPServersPresent sends a DHCP request to the specified network interface,
// and waits for a response for a period defined by defaultDiscoverTime
func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) { func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {
@ -30,8 +32,8 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
dst := "255.255.255.255:67" dst := "255.255.255.255:67"
// form a DHCP request packet, try to emulate existing client as much as possible // form a DHCP request packet, try to emulate existing client as much as possible
xId := make([]byte, 8) xID := make([]byte, 8)
n, err := rand.Read(xId) n, err := rand.Read(xID)
if n != 8 && err == nil { if n != 8 && err == nil {
err = fmt.Errorf("Generated less than 8 bytes") err = fmt.Errorf("Generated less than 8 bytes")
} }
@ -60,13 +62,13 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
leaseTime := uint32(math.RoundToEven(time.Duration(time.Hour * 24 * 90).Seconds())) leaseTime := uint32(math.RoundToEven(time.Duration(time.Hour * 24 * 90).Seconds()))
binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime) binary.BigEndian.PutUint32(leaseTimeRaw, leaseTime)
options := []dhcp4.Option{ options := []dhcp4.Option{
{dhcp4.OptionParameterRequestList, requestList}, {Code: dhcp4.OptionParameterRequestList, Value: requestList},
{dhcp4.OptionMaximumDHCPMessageSize, maxUDPsizeRaw}, {Code: dhcp4.OptionMaximumDHCPMessageSize, Value: maxUDPsizeRaw},
{dhcp4.OptionClientIdentifier, append([]byte{0x01}, iface.HardwareAddr...)}, {Code: dhcp4.OptionClientIdentifier, Value: append([]byte{0x01}, iface.HardwareAddr...)},
{dhcp4.OptionIPAddressLeaseTime, leaseTimeRaw}, {Code: dhcp4.OptionIPAddressLeaseTime, Value: leaseTimeRaw},
{dhcp4.OptionHostName, []byte(hostname)}, {Code: dhcp4.OptionHostName, Value: []byte(hostname)},
} }
packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xId, false, options) packet := dhcp4.RequestPacket(dhcp4.Discover, iface.HardwareAddr, nil, xID, false, options)
// resolve 0.0.0.0:68 // resolve 0.0.0.0:68
udpAddr, err := net.ResolveUDPAddr("udp4", src) udpAddr, err := net.ResolveUDPAddr("udp4", src)
@ -98,7 +100,7 @@ func CheckIfOtherDHCPServersPresent(ifaceName string) (bool, error) {
} }
// send to 255.255.255.255:67 // send to 255.255.255.255:67
n, err = c.WriteTo(packet, dstAddr) _, err = c.WriteTo(packet, dstAddr)
// spew.Dump(n, err) // spew.Dump(n, err)
if err != nil { if err != nil {
return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst) return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst)

View file

@ -13,6 +13,7 @@ import (
const defaultDiscoverTime = time.Second * 3 const defaultDiscoverTime = time.Second * 3
// Lease contains the necessary information about a DHCP lease
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type Lease struct { type Lease struct {
HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"` HWAddr net.HardwareAddr `json:"mac" yaml:"hwaddr"`
@ -21,6 +22,7 @@ type Lease struct {
Expiry time.Time `json:"expires"` Expiry time.Time `json:"expires"`
} }
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type ServerConfig struct { type ServerConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"` Enabled bool `json:"enabled" yaml:"enabled"`
@ -32,6 +34,7 @@ type ServerConfig struct {
LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds
} }
// Server - the current state of the DHCP server
type Server struct { type Server struct {
conn *filterConn // listening UDP socket conn *filterConn // listening UDP socket
@ -137,6 +140,7 @@ func (s *Server) Start(config *ServerConfig) error {
return nil return nil
} }
// Stop closes the listening UDP socket
func (s *Server) Stop() error { func (s *Server) Stop() error {
if s.conn == nil { if s.conn == nil {
// nothing to do, return silently // nothing to do, return silently
@ -249,6 +253,7 @@ func (s *Server) unreserveIP(ip net.IP) {
delete(s.IPpool, IP4) delete(s.IPpool, IP4)
} }
// ServeDHCP handles an incoming DHCP request
func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet { func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got %v message", msgType) log.Tracef("Got %v message", msgType)
log.Tracef("Leases:") log.Tracef("Leases:")
@ -259,27 +264,6 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
for ip, hwaddr := range s.IPpool { for ip, hwaddr := range s.IPpool {
log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr) log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr)
} }
// spew.Dump(s.leases, s.IPpool)
// log.Printf("Called with msgType = %v, options = %+v", msgType, options)
// spew.Dump(p)
// log.Printf("%14s %v", "p.Broadcast", p.Broadcast()) // false
// log.Printf("%14s %v", "p.CHAddr", p.CHAddr()) // 2c:f0:a2:f2:31:00
// log.Printf("%14s %v", "p.CIAddr", p.CIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.Cookie", p.Cookie()) // [99 130 83 99]
// log.Printf("%14s %v", "p.File", p.File()) // []
// log.Printf("%14s %v", "p.Flags", p.Flags()) // [0 0]
// log.Printf("%14s %v", "p.GIAddr", p.GIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.HLen", p.HLen()) // 6
// log.Printf("%14s %v", "p.HType", p.HType()) // 1
// log.Printf("%14s %v", "p.Hops", p.Hops()) // 0
// log.Printf("%14s %v", "p.OpCode", p.OpCode()) // BootRequest
// log.Printf("%14s %v", "p.Options", p.Options()) // [53 1 1 55 10 1 121 3 6 15 119 252 95 44 46 57 2 5 220 61 7 1 44 240 162 242 49 0 51 4 0 118 167 0 12 4 119 104 109 100 255 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
// log.Printf("%14s %v", "p.ParseOptions", p.ParseOptions()) // map[OptionParameterRequestList:[1 121 3 6 15 119 252 95 44 46] OptionDHCPMessageType:[1] OptionMaximumDHCPMessageSize:[5 220] OptionClientIdentifier:[1 44 240 162 242 49 0] OptionIPAddressLeaseTime:[0 118 167 0] OptionHostName:[119 104 109 100]]
// log.Printf("%14s %v", "p.SIAddr", p.SIAddr()) // 0.0.0.0
// log.Printf("%14s %v", "p.SName", p.SName()) // []
// log.Printf("%14s %v", "p.Secs", p.Secs()) // [0 8]
// log.Printf("%14s %v", "p.XId", p.XId()) // [211 184 20 44]
// log.Printf("%14s %v", "p.YIAddr", p.YIAddr()) // 0.0.0.0
switch msgType { switch msgType {
case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP? case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP?
@ -297,6 +281,32 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals) case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals)
// start/renew a lease -- update lease time // start/renew a lease -- update lease time
// some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request // some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request
return s.handleDHCP4Request(p, msgType, options)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline")
case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release")
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform")
// do nothing
// from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer")
case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK")
case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK")
default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType)
return nil
}
return nil
}
func (s *Server) handleDHCP4Request(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got from client: Request") log.Tracef("Got from client: Request")
if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) { if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) {
log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP) log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP)
@ -366,30 +376,9 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh
// requested IP is outside of DHCP range // requested IP is outside of DHCP range
log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr()) log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil) return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline")
case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release")
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform")
// do nothing
// from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer")
case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK")
case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK")
default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType)
return nil
}
return nil
} }
// Leases returns the list of current DHCP leases
func (s *Server) Leases() []*Lease { func (s *Server) Leases() []*Lease {
s.RLock() s.RLock()
result := s.leases result := s.leases

View file

@ -8,7 +8,7 @@ import (
) )
// filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface // filterConn listens to 0.0.0.0:67, but accepts packets only from specific interface
// This is neccessary for DHCP daemon to work, since binding to IP address doesn't // This is necessary for DHCP daemon to work, since binding to IP address doesn't
// us access to see Discover/Request packets from clients. // us access to see Discover/Request packets from clients.
// //
// TODO: on windows, controlmessage does not work, try to find out another way // TODO: on windows, controlmessage does not work, try to find out another way
@ -49,7 +49,6 @@ func (f *filterConn) ReadFrom(b []byte) (int, net.Addr, error) {
} }
// packet doesn't match criteria, drop it // packet doesn't match criteria, drop it
} }
return 0, nil, nil
} }
func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (f *filterConn) WriteTo(b []byte, addr net.Addr) (int, error) {

View file

@ -3,7 +3,6 @@ package dhcpd
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"github.com/hmage/golibs/log" "github.com/hmage/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
@ -45,22 +44,6 @@ func getIfaceIPv4(iface *net.Interface) *net.IPNet {
return nil return nil
} }
func isConnClosed(err error) bool {
if err == nil {
return false
}
nerr, ok := err.(*net.OpError)
if !ok {
return false
}
if strings.Contains(nerr.Err.Error(), "use of closed network connection") {
return true
}
return false
}
func wrapErrPrint(err error, message string, args ...interface{}) error { func wrapErrPrint(err error, message string, args ...interface{}) error {
var errx error var errx error
if err == nil { if err == nil {

4
dns.go
View file

@ -39,13 +39,13 @@ func generateServerConfig() dnsforward.ServerConfig {
} }
for _, u := range config.DNS.UpstreamDNS { for _, u := range config.DNS.UpstreamDNS {
upstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout) dnsUpstream, err := upstream.AddressToUpstream(u, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err != nil { if err != nil {
log.Printf("Couldn't get upstream: %s", err) log.Printf("Couldn't get upstream: %s", err)
// continue, just ignore the upstream // continue, just ignore the upstream
continue continue
} }
newconfig.Upstreams = append(newconfig.Upstreams, upstream) newconfig.Upstreams = append(newconfig.Upstreams, dnsUpstream)
} }
return newconfig return newconfig
} }

View file

@ -35,7 +35,7 @@ const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&se
// ErrInvalidSyntax is returned by AddRule when the rule is invalid // ErrInvalidSyntax is returned by AddRule when the rule is invalid
var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
// ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter // ErrAlreadyExists is returned by AddRule when the rule was already added to the filter
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added") var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
@ -115,6 +115,7 @@ type Dnsfilter struct {
privateConfig privateConfig
} }
// Filter represents a filter list
type Filter struct { type Filter struct {
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
Rules []string `json:"-" yaml:"-"` // not in yaml or json Rules []string `json:"-" yaml:"-"` // not in yaml or json
@ -127,16 +128,26 @@ type Reason int
const ( const (
// reasons for not filtering // reasons for not filtering
NotFilteredNotFound Reason = iota // host was not find in any checks, default value for result
NotFilteredWhiteList // the host is explicitly whitelisted // NotFilteredNotFound - host was not find in any checks, default value for result
NotFilteredError // there was a transitive error during check NotFilteredNotFound Reason = iota
// NotFilteredWhiteList - the host is explicitly whitelisted
NotFilteredWhiteList
// NotFilteredError - there was a transitive error during check
NotFilteredError
// reasons for filtering // reasons for filtering
FilteredBlackList // the host was matched to be advertising host
FilteredSafeBrowsing // the host was matched to be malicious/phishing // FilteredBlackList - the host was matched to be advertising host
FilteredParental // the host was matched to be outside of parental control settings FilteredBlackList
FilteredInvalid // the request was invalid and was not processed // FilteredSafeBrowsing - the host was matched to be malicious/phishing
FilteredSafeSearch // the host was replaced with safesearch variant FilteredSafeBrowsing
// FilteredParental - the host was matched to be outside of parental control settings
FilteredParental
// FilteredInvalid - the request was invalid and was not processed
FilteredInvalid
// FilteredSafeSearch - the host was replaced with safesearch variant
FilteredSafeSearch
) )
// these variables need to survive coredns reload // these variables need to survive coredns reload
@ -151,7 +162,7 @@ type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered IsFiltered bool `json:",omitempty"` // True if the host name is filtered
Reason Reason `json:",omitempty"` // Reason for blocking / unblocking Reason Reason `json:",omitempty"` // Reason for blocking / unblocking
Rule string `json:",omitempty"` // Original rule text Rule string `json:",omitempty"` // Original rule text
Ip net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax IP net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax
FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to
} }
@ -228,7 +239,6 @@ func newRulesTable() *rulesTable {
func (r *rulesTable) Add(rule *rule) { func (r *rulesTable) Add(rule *rule) {
r.Lock() r.Lock()
if rule.ip != nil { if rule.ip != nil {
// Hosts syntax // Hosts syntax
r.rulesByHost[rule.text] = rule r.rulesByHost[rule.text] = rule
@ -476,7 +486,7 @@ func (rule *rule) match(host string) (Result, error) {
IsFiltered: true, IsFiltered: true,
Reason: FilteredBlackList, Reason: FilteredBlackList,
Rule: rule.originalText, Rule: rule.originalText,
Ip: rule.ip, IP: rule.ip,
FilterID: rule.listID, FilterID: rule.listID,
}, nil }, nil
} }
@ -661,8 +671,11 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
return result, err return result, err
} }
type formatHandler func(hashparam string) string
type handleBodyHandler func(body []byte, hashes map[string]bool) (Result, error)
// real implementation of lookup/check // real implementation of lookup/check
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format func(hashparam string) string, handleBody func(body []byte, hashes map[string]bool) (Result, error)) (Result, error) { func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody handleBodyHandler) (Result, error) {
// if host ends with a dot, trim it // if host ends with a dot, trim it
host = strings.ToLower(strings.Trim(host, ".")) host = strings.ToLower(strings.Trim(host, "."))
@ -913,15 +926,6 @@ func (d *Dnsfilter) Destroy() {
// config manipulation helpers // config manipulation helpers
// //
// IsParentalSensitivityValid checks if sensitivity is valid value
func IsParentalSensitivityValid(sensitivity int) bool {
switch sensitivity {
case 3, 10, 13, 17:
return true
}
return false
}
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup // SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
func (d *Dnsfilter) SetSafeBrowsingServer(host string) { func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
if len(host) == 0 { if len(host) == 0 {

View file

@ -26,7 +26,7 @@ import (
func TestLotsOfRulesMemoryUsage(t *testing.T) { func TestLotsOfRulesMemoryUsage(t *testing.T) {
start := getRSS() start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024) log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof") dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest() d := NewForTest()
defer d.Destroy() defer d.Destroy()
@ -37,7 +37,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
afterLoad := getRSS() afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024) log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
tests := []struct { tests := []struct {
host string host string
@ -60,7 +60,7 @@ func TestLotsOfRulesMemoryUsage(t *testing.T) {
} }
afterMatch := getRSS() afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024) log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof") dumpMemProfile("tests/" + _Func() + "3.pprof")
} }
func getRSS() uint64 { func getRSS() uint64 {
@ -69,6 +69,9 @@ func getRSS() uint64 {
panic(err) panic(err)
} }
minfo, err := proc.MemoryInfo() minfo, err := proc.MemoryInfo()
if err != nil {
panic(err)
}
return minfo.RSS return minfo.RSS
} }
@ -86,7 +89,7 @@ func dumpMemProfile(name string) {
} }
} }
const topHostsFilename = "../tests/top-1m.csv" const topHostsFilename = "tests/top-1m.csv"
func fetchTopHostsFromNet() { func fetchTopHostsFromNet() {
log.Tracef("Fetching top hosts from network") log.Tracef("Fetching top hosts from network")
@ -146,7 +149,7 @@ func getTopHosts() {
func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) { func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
start := getRSS() start := getRSS()
log.Tracef("RSS before loading rules - %d kB\n", start/1024) log.Tracef("RSS before loading rules - %d kB\n", start/1024)
dumpMemProfile(_Func() + "1.pprof") dumpMemProfile("tests/" + _Func() + "1.pprof")
d := NewForTest() d := NewForTest()
defer d.Destroy() defer d.Destroy()
@ -155,7 +158,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterLoad := getRSS() afterLoad := getRSS()
log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024) log.Tracef("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
getTopHosts() getTopHosts()
hostnames, err := os.Open(topHostsFilename) hostnames, err := os.Open(topHostsFilename)
@ -165,7 +168,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
defer hostnames.Close() defer hostnames.Close()
afterHosts := getRSS() afterHosts := getRSS()
log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024) log.Tracef("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024)
dumpMemProfile(_Func() + "2.pprof") dumpMemProfile("tests/" + _Func() + "2.pprof")
{ {
scanner := bufio.NewScanner(hostnames) scanner := bufio.NewScanner(hostnames)
@ -184,7 +187,7 @@ func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
afterMatch := getRSS() afterMatch := getRSS()
log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024) log.Tracef("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
dumpMemProfile(_Func() + "3.pprof") dumpMemProfile("tests/" + _Func() + "3.pprof")
} }
func TestRuleToRegexp(t *testing.T) { func TestRuleToRegexp(t *testing.T) {
@ -282,7 +285,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
} }
} }
func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) { func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string) {
t.Helper() t.Helper()
ret, err := d.CheckHost(hostname) ret, err := d.CheckHost(hostname)
if err != nil { if err != nil {
@ -291,8 +294,8 @@ func (d *Dnsfilter) checkMatchIp(t *testing.T, hostname string, ip string) {
if !ret.IsFiltered { if !ret.IsFiltered {
t.Errorf("Expected hostname %s to match", hostname) t.Errorf("Expected hostname %s to match", hostname)
} }
if ret.Ip == nil || ret.Ip.String() != ip { if ret.IP == nil || ret.IP.String() != ip {
t.Errorf("Expected ip %s to match, actual: %v", ip, ret.Ip) t.Errorf("Expected ip %s to match, actual: %v", ip, ret.IP)
} }
} }
@ -308,7 +311,7 @@ func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
} }
func loadTestRules(d *Dnsfilter) error { func loadTestRules(d *Dnsfilter) error {
filterFileName := "../tests/dns.txt" filterFileName := "tests/dns.txt"
file, err := os.Open(filterFileName) file, err := os.Open(filterFileName)
if err != nil { if err != nil {
return err return err
@ -368,8 +371,8 @@ func TestEtcHostsMatching(t *testing.T) {
text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr) text := fmt.Sprintf(" %s google.com www.google.com # enforce google's safesearch ", addr)
d.checkAddRule(t, text) d.checkAddRule(t, text)
d.checkMatchIp(t, "google.com", addr) d.checkMatchIP(t, "google.com", addr)
d.checkMatchIp(t, "www.google.com", addr) d.checkMatchIP(t, "www.google.com", addr)
d.checkMatchEmpty(t, "subdomain.google.com") d.checkMatchEmpty(t, "subdomain.google.com")
d.checkMatchEmpty(t, "example.org") d.checkMatchEmpty(t, "example.org")
} }

View file

@ -310,8 +310,8 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
case dnsfilter.FilteredParental: case dnsfilter.FilteredParental:
return s.genBlockedHost(m, parentalBlockHost, d.Upstream) return s.genBlockedHost(m, parentalBlockHost, d.Upstream)
default: default:
if result.Ip != nil { if result.IP != nil {
return s.genARecord(m, result.Ip) return s.genARecord(m, result.IP)
} }
return s.genNXDomain(m) return s.genNXDomain(m)

View file

@ -112,6 +112,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
} }
} }
// HandleQueryLog handles query log web request
func HandleQueryLog(w http.ResponseWriter, r *http.Request) { func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
queryLogLock.RLock() queryLogLock.RLock()
values := make([]*logEntry, len(queryLogCache)) values := make([]*logEntry, len(queryLogCache))
@ -123,6 +124,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
values[left], values[right] = values[right], values[left] values[left], values[right] = values[right], values[left]
} }
// iterate
var data = []map[string]interface{}{} var data = []map[string]interface{}{}
for _, entry := range values { for _, entry := range values {
var q *dns.Msg var q *dns.Msg
@ -167,7 +169,36 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
jsonEntry["filterId"] = entry.Result.FilterID jsonEntry["filterId"] = entry.Result.FilterID
} }
if a != nil && len(a.Answer) > 0 { answers := answerToMap(a)
if answers != nil {
jsonEntry["answer"] = answers
}
data = append(data, jsonEntry)
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{} var answers = []map[string]interface{}{}
for _, k := range a.Answer { for _, k := range a.Answer {
header := k.Header() header := k.Header()
@ -207,27 +238,8 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
} }
answers = append(answers, answer) answers = append(answers, answer)
} }
jsonEntry["answer"] = answers
}
data = append(data, jsonEntry) return answers
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
} }
// getIPString is a helper function that extracts IP address from net.Addr // getIPString is a helper function that extracts IP address from net.Addr

View file

@ -156,7 +156,7 @@ func periodicQueryLogRotate() {
func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
now := time.Now() now := time.Now()
// read from querylog files, try newest file first // read from querylog files, try newest file first
files := []string{} var files []string
if enableGzip { if enableGzip {
files = []string{ files = []string{

View file

@ -26,10 +26,10 @@ type hourTop struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (top *hourTop) init() { func (h *hourTop) init() {
top.domains = gcache.New(queryLogTopSize).LRU().Build() h.domains = gcache.New(queryLogTopSize).LRU().Build()
top.blocked = gcache.New(queryLogTopSize).LRU().Build() h.blocked = gcache.New(queryLogTopSize).LRU().Build()
top.clients = gcache.New(queryLogTopSize).LRU().Build() h.clients = gcache.New(queryLogTopSize).LRU().Build()
} }
type dayTop struct { type dayTop struct {
@ -69,9 +69,9 @@ func periodicHourlyTopRotate() {
} }
} }
func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { func (h *hourTop) incrementValue(key string, cache gcache.Cache) error {
top.Lock() h.Lock()
defer top.Unlock() defer h.Unlock()
ivalue, err := cache.Get(key) ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError { if err == gcache.KeyNotFoundError {
// we just set it and we're done // we just set it and we're done
@ -103,20 +103,20 @@ func (top *hourTop) incrementValue(key string, cache gcache.Cache) error {
return nil return nil
} }
func (top *hourTop) incrementDomains(key string) error { func (h *hourTop) incrementDomains(key string) error {
return top.incrementValue(key, top.domains) return h.incrementValue(key, h.domains)
} }
func (top *hourTop) incrementBlocked(key string) error { func (h *hourTop) incrementBlocked(key string) error {
return top.incrementValue(key, top.blocked) return h.incrementValue(key, h.blocked)
} }
func (top *hourTop) incrementClients(key string) error { func (h *hourTop) incrementClients(key string) error {
return top.incrementValue(key, top.clients) return h.incrementValue(key, h.clients)
} }
// if does not exist -- return 0 // if does not exist -- return 0
func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { func (h *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) {
ivalue, err := cache.Get(key) ivalue, err := cache.Get(key)
if err == gcache.KeyNotFoundError { if err == gcache.KeyNotFoundError {
return 0, nil return 0, nil
@ -137,19 +137,19 @@ func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error)
return value, nil return value, nil
} }
func (top *hourTop) lockedGetDomains(key string) (int, error) { func (h *hourTop) lockedGetDomains(key string) (int, error) {
return top.lockedGetValue(key, top.domains) return h.lockedGetValue(key, h.domains)
} }
func (top *hourTop) lockedGetBlocked(key string) (int, error) { func (h *hourTop) lockedGetBlocked(key string) (int, error) {
return top.lockedGetValue(key, top.blocked) return h.lockedGetValue(key, h.blocked)
} }
func (top *hourTop) lockedGetClients(key string) (int, error) { func (h *hourTop) lockedGetClients(key string) (int, error) {
return top.lockedGetValue(key, top.clients) return h.lockedGetValue(key, h.clients)
} }
func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
// figure out which hour bucket it belongs to // figure out which hour bucket it belongs to
hour := int(now.Sub(entry.Time).Hours()) hour := int(now.Sub(entry.Time).Hours())
if hour >= 24 { if hour >= 24 {
@ -252,6 +252,7 @@ func fillStatsFromQueryLog() error {
return nil return nil
} }
// HandleStatsTop returns the current top stats
func HandleStatsTop(w http.ResponseWriter, r *http.Request) { func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
domains := map[string]int{} domains := map[string]int{}
blocked := map[string]int{} blocked := map[string]int{}
@ -320,9 +321,9 @@ func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err := w.Write(json.Bytes()) _, err := w.Write(json.Bytes())
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

View file

@ -17,7 +17,6 @@ var (
filteredLists = newDNSCounter("filtered_lists_total") filteredLists = newDNSCounter("filtered_lists_total")
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total") filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
filteredParental = newDNSCounter("filtered_parental_total") filteredParental = newDNSCounter("filtered_parental_total")
filteredInvalid = newDNSCounter("filtered_invalid_total")
whitelisted = newDNSCounter("whitelisted_total") whitelisted = newDNSCounter("whitelisted_total")
safesearch = newDNSCounter("safesearch_total") safesearch = newDNSCounter("safesearch_total")
errorsTotal = newDNSCounter("errors_total") errorsTotal = newDNSCounter("errors_total")
@ -93,7 +92,7 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) {
currentValues := p.Entries[countname] currentValues := p.Entries[countname]
value := currentValues[elapsed] value := currentValues[elapsed]
// log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) // log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
value += 1 value++
currentValues[elapsed] = value currentValues[elapsed] = value
p.Entries[countname] = currentValues p.Entries[countname] = currentValues
} }
@ -224,6 +223,7 @@ func incrementCounters(entry *logEntry) {
elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
} }
// HandleStats returns aggregated stats data for the 24 hours
func HandleStats(w http.ResponseWriter, r *http.Request) { func HandleStats(w http.ResponseWriter, r *http.Request) {
const numHours = 24 const numHours = 24
histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
@ -252,17 +252,17 @@ func HandleStats(w http.ResponseWriter, r *http.Request) {
json, err := json.Marshal(summed) json, err := json.Marshal(summed)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json) _, err = w.Write(json)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
@ -296,6 +296,7 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i
return result return result
} }
// HandleStatsHistory returns historical stats data for the 24 hours
func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// handle time unit and prepare our time window size // handle time unit and prepare our time window size
now := time.Now() now := time.Now()
@ -323,16 +324,16 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
// parse start and end time // parse start and end time
startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
@ -360,28 +361,29 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
data := generateMapFromStats(stats, start, end) data := generateMapFromStats(stats, start, end)
json, err := json.Marshal(data) json, err := json.Marshal(data)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(json) _, err = w.Write(json)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 500) http.Error(w, errorText, 500)
return return
} }
} }
// HandleStatsReset resets the stats caches
func HandleStatsReset(w http.ResponseWriter, r *http.Request) { func HandleStatsReset(w http.ResponseWriter, r *http.Request) {
purgeStats() purgeStats()
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

View file

@ -70,20 +70,20 @@ func updateUniqueFilterID(filters []filter) {
func assignUniqueFilterID() int64 { func assignUniqueFilterID() int64 {
value := nextFilterID value := nextFilterID
nextFilterID += 1 nextFilterID++
return value return value
} }
// Sets up a timer that will be checking for filters updates periodically // Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() { func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) { for range time.Tick(time.Minute) {
refreshFiltersIfNeccessary(false) refreshFiltersIfNecessary(false)
} }
} }
// Checks filters updates if necessary // Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value // If force is true, it ignores the filter.LastUpdated field value
func refreshFiltersIfNeccessary(force bool) int { func refreshFiltersIfNecessary(force bool) int {
config.Lock() config.Lock()
// fetch URLs // fetch URLs
@ -113,8 +113,12 @@ func refreshFiltersIfNeccessary(force bool) int {
} }
config.Unlock() config.Unlock()
if updateCount > 0 { if updateCount > 0 && isRunning() {
reconfigureDNSServer() err := reconfigureDNSServer()
if err != nil {
msg := fmt.Sprintf("SHOULD NOT HAPPEN: cannot reconfigure DNS server with the new filters: %s", err)
panic(msg)
}
} }
return updateCount return updateCount
} }

View file

@ -34,9 +34,9 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("config.Language is %s", config.Language) log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language) _, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to write response json: %s", err) errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
} }