From 2c33905a79241e3f203f31ed5d9ca5a69d408c70 Mon Sep 17 00:00:00 2001
From: Eugene Bujak <hmage@hmage.net>
Date: Sun, 7 Oct 2018 00:51:44 +0300
Subject: [PATCH] Querylog -- Implement file writing and update /querylog
 handler for changed structures.

---
 coredns_plugin/coredns_plugin.go |  10 +++
 coredns_plugin/querylog.go       | 124 +++++++++++++++++++-------
 coredns_plugin/querylog_file.go  | 146 +++++++++++++++++++++++++++++++
 stats.go                         |   7 +-
 4 files changed, 255 insertions(+), 32 deletions(-)
 create mode 100644 coredns_plugin/querylog_file.go

diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go
index 8af7601e..530e2c2b 100644
--- a/coredns_plugin/coredns_plugin.go
+++ b/coredns_plugin/coredns_plugin.go
@@ -219,6 +219,7 @@ func setup(c *caddy.Controller) error {
 		return nil
 	})
 	c.OnShutdown(p.onShutdown)
+	c.OnFinalShutdown(p.onFinalShutdown)
 
 	return nil
 }
@@ -250,6 +251,15 @@ func (p *plug) onShutdown() error {
 	return nil
 }
 
+func (p *plug) onFinalShutdown() error {
+	err := flushToFile(logBuffer)
+	if err != nil {
+		log.Printf("failed to flush to file: %s", err)
+		return err
+	}
+	return nil
+}
+
 type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType)
 
 func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go
index 3e2ae239..808450db 100644
--- a/coredns_plugin/querylog.go
+++ b/coredns_plugin/querylog.go
@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"log"
 	"net/http"
+	"os"
+	"path"
 	"runtime"
 	"strconv"
 	"strings"
@@ -13,45 +15,93 @@ import (
 	"github.com/AdguardTeam/AdguardDNS/dnsfilter"
 	"github.com/coredns/coredns/plugin/pkg/response"
 	"github.com/miekg/dns"
-	"github.com/zfjagann/golang-ring"
 )
 
-const logBufferCap = 10000
+const (
+	logBufferCap = 1000 // maximum capacity of logBuffer before it's flushed to disk
+	queryLogAPI  = 1000 // maximum API response for /querylog
+)
 
-var logBuffer = ring.Ring{}
+var (
+	logBuffer     []logEntry
+)
 
 type logEntry struct {
-	Question *dns.Msg
-	Answer   *dns.Msg
+	Question []byte
+	Answer   []byte
 	Result   dnsfilter.Result
 	Time     time.Time
 	Elapsed  time.Duration
 	IP       string
 }
 
-func init() {
-	logBuffer.SetCapacity(logBufferCap)
-}
-
 func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) {
+	var q []byte
+	var a []byte
+	var err error
+
+	if question != nil {
+		q, err = question.Pack()
+		if err != nil {
+			log.Printf("failed to pack question for querylog: %s", err)
+			return
+		}
+	}
+	if answer != nil {
+		a, err = answer.Pack()
+		if err != nil {
+			log.Printf("failed to pack answer for querylog: %s", err)
+			return
+		}
+	}
+
 	entry := logEntry{
-		Question: question,
-		Answer:   answer,
+		Question: q,
+		Answer:   a,
 		Result:   result,
 		Time:     time.Now(),
 		Elapsed:  elapsed,
 		IP:       ip,
 	}
-	logBuffer.Enqueue(entry)
+	var flushBuffer []logEntry
+
+	logBuffer = append(logBuffer, entry)
+	if len(logBuffer) >= logBufferCap {
+		flushBuffer = logBuffer
+		logBuffer = nil
+	}
+	if len(flushBuffer) > 0 {
+		// write to file
+		// do it in separate goroutine -- we are stalling DNS response this whole time
+		go flushToFile(flushBuffer)
+	}
+	return
 }
 
-func handler(w http.ResponseWriter, r *http.Request) {
-	values := logBuffer.Values()
+func handleQueryLog(w http.ResponseWriter, r *http.Request) {
+	// TODO: fetch values from disk if len(logBuffer) < queryLogSize
+	// TODO: cache output
+	values := logBuffer
 	var data = []map[string]interface{}{}
-	for _, value := range values {
-		entry, ok := value.(logEntry)
-		if !ok {
-			continue
+	for _, entry := range values {
+		var q *dns.Msg
+		var a *dns.Msg
+
+		if len(entry.Question) > 0 {
+			q = new(dns.Msg)
+			if err := q.Unpack(entry.Question); err != nil {
+				// ignore, log and move on
+				log.Printf("Failed to unpack dns message question: %s", err)
+				q = nil
+			}
+		}
+		if len(entry.Answer) > 0 {
+			a = new(dns.Msg)
+			if err := a.Unpack(entry.Answer); err != nil {
+				// ignore, log and move on
+				log.Printf("Failed to unpack dns message question: %s", err)
+				a = nil
+			}
 		}
 
 		jsonentry := map[string]interface{}{
@@ -60,22 +110,25 @@ func handler(w http.ResponseWriter, r *http.Request) {
 			"time":       entry.Time.Format(time.RFC3339),
 			"client":     entry.IP,
 		}
-		question := map[string]interface{}{
-			"host":  strings.ToLower(strings.TrimSuffix(entry.Question.Question[0].Name, ".")),
-			"type":  dns.Type(entry.Question.Question[0].Qtype).String(),
-			"class": dns.Class(entry.Question.Question[0].Qclass).String(),
+		if q != nil {
+			jsonentry["question"] = map[string]interface{}{
+				"host":  strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")),
+				"type":  dns.Type(q.Question[0].Qtype).String(),
+				"class": dns.Class(q.Question[0].Qclass).String(),
+			}
 		}
-		jsonentry["question"] = question
 
-		status, _ := response.Typify(entry.Answer, time.Now().UTC())
-		jsonentry["status"] = status.String()
+		if a != nil {
+			status, _ := response.Typify(a, time.Now().UTC())
+			jsonentry["status"] = status.String()
+		}
 		if len(entry.Result.Rule) > 0 {
 			jsonentry["rule"] = entry.Result.Rule
 		}
 
-		if entry.Answer != nil && len(entry.Answer.Answer) > 0 {
+		if a != nil && len(a.Answer) > 0 {
 			var answers = []map[string]interface{}{}
-			for _, k := range entry.Answer.Answer {
+			for _, k := range a.Answer {
 				header := k.Header()
 				answer := map[string]interface{}{
 					"type": dns.TypeToString[header.Rrtype],
@@ -137,17 +190,26 @@ func handler(w http.ResponseWriter, r *http.Request) {
 }
 
 func startQueryLogServer() {
-	listenAddr := "127.0.0.1:8618" // sha512sum of "querylog" then each byte summed
+	listenAddr := "127.0.0.1:8618" // 8618 is sha512sum of "querylog" then each byte summed
 
-	http.HandleFunc("/querylog", handler)
+	go periodicQueryLogRotate(queryLogRotationPeriod)
+
+	http.HandleFunc("/querylog", handleQueryLog)
 	if err := http.ListenAndServe(listenAddr, nil); err != nil {
 		log.Fatalf("error in ListenAndServe: %s", err)
 	}
 }
 
-func trace(text string) {
+func trace(format string, args ...interface{}) {
 	pc := make([]uintptr, 10) // at least 1 entry needed
 	runtime.Callers(2, pc)
 	f := runtime.FuncForPC(pc[0])
-	log.Printf("%s(): %s\n", f.Name(), text)
+	var buf strings.Builder
+	buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
+	text := fmt.Sprintf(format, args...)
+	buf.WriteString(text)
+	if len(text) == 0 || text[len(text)-1] != '\n' {
+		buf.WriteRune('\n')
+	}
+	fmt.Fprint(os.Stderr, buf.String())
 }
diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go
new file mode 100644
index 00000000..1cfa93a7
--- /dev/null
+++ b/coredns_plugin/querylog_file.go
@@ -0,0 +1,146 @@
+package dnsfilter
+
+import (
+	"bytes"
+	"compress/gzip"
+	"encoding/json"
+	"fmt"
+	"log"
+	"os"
+	"sync"
+	"time"
+
+	"github.com/go-test/deep"
+)
+
+const (
+	queryLogRotationPeriod = time.Hour * 24  // rotate the log every 24 hours
+	queryLogFileName       = "querylog.json" // .gz added during compression
+)
+
+var (
+	fileWriteLock sync.Mutex
+)
+
+func flushToFile(buffer []logEntry) error {
+	if len(buffer) == 0 {
+		return nil
+	}
+	start := time.Now()
+
+	var b bytes.Buffer
+	e := json.NewEncoder(&b)
+	for _, entry := range buffer {
+		err := e.Encode(entry)
+		if err != nil {
+			log.Printf("Failed to marshal entry: %s", err)
+			return err
+		}
+	}
+
+	elapsed := time.Since(start)
+	log.Printf("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
+
+	err := checkBuffer(buffer, b)
+	if err != nil {
+		log.Printf("failed to check buffer: %s", err)
+		return err
+	}
+
+	filenamegz := queryLogFileName + ".gz"
+
+	var zb bytes.Buffer
+
+	zw := gzip.NewWriter(&zb)
+	zw.Name = queryLogFileName
+	zw.ModTime = time.Now()
+
+	_, err = zw.Write(b.Bytes())
+	if err != nil {
+		log.Printf("Couldn't compress to gzip: %s", err)
+		zw.Close()
+		return err
+	}
+
+	if err = zw.Close(); err != nil {
+		log.Printf("Couldn't close gzip writer: %s", err)
+		return err
+	}
+
+	fileWriteLock.Lock()
+	defer fileWriteLock.Unlock()
+	f, err := os.OpenFile(filenamegz, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
+	if err != nil {
+		log.Printf("failed to create file \"%s\": %s", filenamegz, err)
+		return err
+	}
+	defer f.Close()
+
+	n, err := f.Write(zb.Bytes())
+	if err != nil {
+		log.Printf("Couldn't write to file: %s", err)
+		return err
+	}
+
+	log.Printf("ok \"%s\": %v bytes written", filenamegz, n)
+
+	return nil
+}
+
+func checkBuffer(buffer []logEntry, b bytes.Buffer) error {
+	l := len(buffer)
+	d := json.NewDecoder(&b)
+
+	i := 0
+	for d.More() {
+		var entry logEntry
+		err := d.Decode(&entry)
+		if err != nil {
+			log.Printf("Failed to decode: %s", err)
+			return err
+		}
+		if diff := deep.Equal(entry, buffer[i]); diff != nil {
+			log.Printf("decoded buffer differs: %s", diff)
+			return fmt.Errorf("decoded buffer differs: %s", diff)
+		}
+		i++
+	}
+	if i != l {
+		err := fmt.Errorf("check fail: %d vs %d entries", l, i)
+		log.Print(err)
+		return err
+	}
+	log.Printf("check ok: %d entries", i)
+
+	return nil
+}
+
+func rotateQueryLog() error {
+	from := queryLogFileName + ".gz"
+	to := queryLogFileName + ".gz.1"
+
+	if _, err := os.Stat(from); os.IsNotExist(err) {
+		// do nothing, file doesn't exist
+		return nil
+	}
+
+	err := os.Rename(from, to)
+	if err != nil {
+		log.Printf("Failed to rename querylog: %s", err)
+		return err
+	}
+
+	log.Printf("Rotated from %s to %s successfully", from, to)
+
+	return nil
+}
+
+func periodicQueryLogRotate(t time.Duration) {
+	for range time.Tick(t) {
+		err := rotateQueryLog()
+		if err != nil {
+			log.Printf("Failed to rotate querylog: %s", err)
+			// do nothing, continue rotating
+		}
+	}
+}
diff --git a/stats.go b/stats.go
index aff78ee0..29f36fbf 100644
--- a/stats.go
+++ b/stats.go
@@ -62,7 +62,12 @@ type stats struct {
 var statistics stats
 
 func initPeriodicStats(periodic *periodicStats) {
-	*periodic = periodicStats{}
+	periodic.Entries = statsEntries{}
+	periodic.LastRotate = time.Time{}
+}
+
+func init() {
+	purgeStats()
 }
 
 func purgeStats() {