From 3aac7e7bc9b4bb3ecff697b7748499a14bc64a0d Mon Sep 17 00:00:00 2001
From: Eugene Bujak <hmage@hmage.net>
Date: Thu, 4 Oct 2018 00:20:53 +0300
Subject: [PATCH] Add a test to demonstrate huge memory usage due to having
 too many regexps

---
 dnsfilter/dnsfilter_test.go | 367 +++++++++++++++++++++++++++++++-----
 1 file changed, 325 insertions(+), 42 deletions(-)

diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go
index e2d1522a..61d0a4ae 100644
--- a/dnsfilter/dnsfilter_test.go
+++ b/dnsfilter/dnsfilter_test.go
@@ -1,9 +1,13 @@
 package dnsfilter
 
 import (
+	"archive/zip"
+	"bytes"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"path"
+	"runtime/pprof"
 	"strings"
 	"testing"
 	"time"
@@ -13,9 +17,175 @@ import (
 	"os"
 	"runtime"
 
+	"github.com/shirou/gopsutil/process"
 	"go.uber.org/goleak"
 )
 
+// first in file because it must be run first
+func TestLotsOfRulesMemoryUsage(t *testing.T) {
+	start := getRSS()
+	trace("RSS before loading rules - %d kB\n", start/1024)
+	dumpMemProfile(_Func() + "1.pprof")
+
+	d := NewForTest()
+	defer d.Destroy()
+	err := loadTestRules(d)
+	if err != nil {
+		t.Error(err)
+	}
+
+	afterLoad := getRSS()
+	trace("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
+	dumpMemProfile(_Func() + "2.pprof")
+
+	tests := []struct {
+		host  string
+		match bool
+	}{
+		{"asdasdasd_adsajdasda_asdasdjashdkasdasdasdasd_adsajdasda_asdasdjashdkasd.thisistesthost.com", false},
+		{"asdasdasd_adsajdasda_asdasdjashdkasdasdasdasd_adsajdasda_asdasdjashdkasd.ad.doubleclick.net", true},
+	}
+	for _, testcase := range tests {
+		ret, err := d.CheckHost(testcase.host)
+		if err != nil {
+			t.Errorf("Error while matching host %s: %s", testcase.host, err)
+		}
+		if !ret.IsFiltered && ret.IsFiltered != testcase.match {
+			t.Errorf("Expected hostname %s to not match", testcase.host)
+		}
+		if ret.IsFiltered && ret.IsFiltered != testcase.match {
+			t.Errorf("Expected hostname %s to match", testcase.host)
+		}
+	}
+	afterMatch := getRSS()
+	trace("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
+	dumpMemProfile(_Func() + "3.pprof")
+}
+
+func getRSS() uint64 {
+	proc, err := process.NewProcess(int32(os.Getpid()))
+	if err != nil {
+		panic(err)
+	}
+	minfo, err := proc.MemoryInfo()
+	return minfo.RSS
+}
+
+func dumpMemProfile(name string) {
+	runtime.GC()
+	f, err := os.Create(name)
+	if err != nil {
+		panic(err)
+	}
+	defer f.Close()
+	runtime.GC() // update the stats before writing them
+	err = pprof.WriteHeapProfile(f)
+	if err != nil {
+		panic(err)
+	}
+}
+
+const topHostsFilename = "../tests/top-1m.csv"
+
+func fetchTopHostsFromNet() {
+	trace("Fetching top hosts from network")
+	resp, err := http.Get("http://s3-us-west-1.amazonaws.com/umbrella-static/top-1m.csv.zip")
+	if err != nil {
+		panic(err)
+	}
+	defer resp.Body.Close()
+
+	trace("Reading zipfile body")
+	zipfile, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		panic(err)
+	}
+
+	trace("Opening zipfile")
+	r, err := zip.NewReader(bytes.NewReader(zipfile), int64(len(zipfile)))
+	if err != nil {
+		panic(err)
+	}
+
+	if len(r.File) != 1 {
+		panic(fmt.Errorf("zipfile must have only one entry: %+v", r))
+	}
+	f := r.File[0]
+	trace("Unpacking file %s from zipfile", f.Name)
+	rc, err := f.Open()
+	if err != nil {
+		panic(err)
+	}
+	trace("Reading file %s contents", f.Name)
+	body, err := ioutil.ReadAll(rc)
+	if err != nil {
+		panic(err)
+	}
+	rc.Close()
+
+	trace("Writing file %s contents to disk", f.Name)
+	err = ioutil.WriteFile(topHostsFilename+".tmp", body, 0644)
+	if err != nil {
+		panic(err)
+	}
+	err = os.Rename(topHostsFilename+".tmp", topHostsFilename)
+	if err != nil {
+		panic(err)
+	}
+}
+
+func getTopHosts() {
+	// if file doesn't exist, fetch it
+	if _, err := os.Stat(topHostsFilename); os.IsNotExist(err) {
+		// file does not exist, fetch it
+		fetchTopHostsFromNet()
+	}
+}
+
+func TestLotsOfRulesLotsOfHostsMemoryUsage(t *testing.T) {
+	start := getRSS()
+	trace("RSS before loading rules - %d kB\n", start/1024)
+	dumpMemProfile(_Func() + "1.pprof")
+
+	d := NewForTest()
+	defer d.Destroy()
+	mustLoadTestRules(d)
+	trace("Have %d rules", d.Count())
+
+	afterLoad := getRSS()
+	trace("RSS after loading rules - %d kB (%d kB diff)\n", afterLoad/1024, (afterLoad-start)/1024)
+	dumpMemProfile(_Func() + "2.pprof")
+
+	getTopHosts()
+	hostnames, err := os.Open(topHostsFilename)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer hostnames.Close()
+	afterHosts := getRSS()
+	trace("RSS after loading hosts - %d kB (%d kB diff)\n", afterHosts/1024, (afterHosts-afterLoad)/1024)
+	dumpMemProfile(_Func() + "2.pprof")
+
+	{
+		scanner := bufio.NewScanner(hostnames)
+		for scanner.Scan() {
+			line := scanner.Text()
+			records := strings.Split(line, ",")
+			ret, err := d.CheckHost(records[1] + "." + records[1])
+			if err != nil {
+				t.Error(err)
+			}
+			if ret.Reason.Matched() {
+				// log.Printf("host \"%s\" matched. Rule \"%s\", reason: %v", host, ret.Rule, ret.Reason)
+			}
+		}
+	}
+
+	afterMatch := getRSS()
+	trace("RSS after matching - %d kB (%d kB diff)\n", afterMatch/1024, (afterMatch-afterLoad)/1024)
+	dumpMemProfile(_Func() + "3.pprof")
+}
+
 func TestRuleToRegexp(t *testing.T) {
 	tests := []struct {
 		rule   string
@@ -114,6 +284,13 @@ func loadTestRules(d *Dnsfilter) error {
 	return err
 }
 
+func mustLoadTestRules(d *Dnsfilter) {
+	err := loadTestRules(d)
+	if err != nil {
+		panic(err)
+	}
+}
+
 func NewForTest() *Dnsfilter {
 	d := New()
 	purgeCaches()
@@ -126,7 +303,9 @@ func NewForTest() *Dnsfilter {
 func TestSanityCheck(t *testing.T) {
 	d := NewForTest()
 	defer d.Destroy()
+
 	d.checkAddRule(t, "||doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
 	d.checkMatch(t, "www.doubleclick.net")
 	d.checkMatchEmpty(t, "nodoubleclick.net")
 	d.checkMatchEmpty(t, "doubleclick.net.ru")
@@ -134,6 +313,72 @@ func TestSanityCheck(t *testing.T) {
 	d.checkAddRuleFail(t, "lkfaojewhoawehfwacoefawr$@#$@3413841384")
 }
 
+func TestSuffixMatching1(t *testing.T) {
+	d := NewForTest()
+	defer d.Destroy()
+
+	d.checkAddRule(t, "||doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
+	d.checkMatch(t, "www.doubleclick.net")
+	d.checkMatchEmpty(t, "nodoubleclick.net")
+	d.checkMatchEmpty(t, "doubleclick.net.ru")
+}
+
+func TestSuffixMatching2(t *testing.T) {
+	d := NewForTest()
+	defer d.Destroy()
+
+	d.checkAddRule(t, "|doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
+	d.checkMatchEmpty(t, "www.doubleclick.net")
+	d.checkMatchEmpty(t, "nodoubleclick.net")
+	d.checkMatchEmpty(t, "doubleclick.net.ru")
+}
+
+func TestSuffixMatching3(t *testing.T) {
+	d := NewForTest()
+	defer d.Destroy()
+
+	d.checkAddRule(t, "doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
+	d.checkMatch(t, "www.doubleclick.net")
+	d.checkMatch(t, "nodoubleclick.net")
+	d.checkMatchEmpty(t, "doubleclick.net.ru")
+}
+
+func TestSuffixMatching4(t *testing.T) {
+	d := NewForTest()
+	defer d.Destroy()
+
+	d.checkAddRule(t, "*doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
+	d.checkMatch(t, "www.doubleclick.net")
+	d.checkMatch(t, "nodoubleclick.net")
+	d.checkMatchEmpty(t, "doubleclick.net.ru")
+}
+
+func TestSuffixMatching5(t *testing.T) {
+	d := NewForTest()
+	defer d.Destroy()
+
+	d.checkAddRule(t, "|*doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
+	d.checkMatch(t, "www.doubleclick.net")
+	d.checkMatch(t, "nodoubleclick.net")
+	d.checkMatchEmpty(t, "doubleclick.net.ru")
+}
+
+func TestSuffixMatching6(t *testing.T) {
+	d := NewForTest()
+	defer d.Destroy()
+
+	d.checkAddRule(t, "||*doubleclick.net^")
+	d.checkMatch(t, "doubleclick.net")
+	d.checkMatch(t, "www.doubleclick.net")
+	d.checkMatch(t, "nodoubleclick.net")
+	d.checkMatchEmpty(t, "doubleclick.net.ru")
+}
+
 func TestCount(t *testing.T) {
 	d := NewForTest()
 	defer d.Destroy()
@@ -219,46 +464,6 @@ func TestAddRuleFail(t *testing.T) {
 	d.checkAddRuleFail(t, "lkfaojewhoawehfwacoefawr$@#$@3413841384")
 }
 
-func TestLotsOfRulesMemoryUsage(t *testing.T) {
-	var start, afterLoad, end runtime.MemStats
-	runtime.GC()
-	runtime.ReadMemStats(&start)
-	fmt.Printf("Memory usage before loading rules - %d kB alloc, %d kB sys\n", start.Alloc/1024, start.Sys/1024)
-
-	d := NewForTest()
-	defer d.Destroy()
-	err := loadTestRules(d)
-	if err != nil {
-		t.Error(err)
-	}
-	runtime.GC()
-	runtime.ReadMemStats(&afterLoad)
-	fmt.Printf("Memory usage after loading rules - %d kB alloc, %d kB sys\n", afterLoad.Alloc/1024, afterLoad.Sys/1024)
-
-	tests := []struct {
-		host  string
-		match bool
-	}{
-		{"asdasdasd_adsajdasda_asdasdjashdkasdasdasdasd_adsajdasda_asdasdjashdkasd.thisistesthost.com", false},
-		{"asdasdasd_adsajdasda_asdasdjashdkasdasdasdasd_adsajdasda_asdasdjashdkasd.ad.doubleclick.net", true},
-	}
-	for _, testcase := range tests {
-		ret, err := d.CheckHost(testcase.host)
-		if err != nil {
-			t.Errorf("Error while matching host %s: %s", testcase.host, err)
-		}
-		if !ret.IsFiltered && ret.IsFiltered != testcase.match {
-			t.Errorf("Expected hostname %s to not match", testcase.host)
-		}
-		if ret.IsFiltered && ret.IsFiltered != testcase.match {
-			t.Errorf("Expected hostname %s to match", testcase.host)
-		}
-	}
-	runtime.GC()
-	runtime.ReadMemStats(&end)
-	fmt.Printf("Memory usage after matching - %d kB alloc, %d kB sys\n", afterLoad.Alloc/1024, afterLoad.Sys/1024)
-}
-
 func TestSafeBrowsing(t *testing.T) {
 	testCases := []string{
 		"",
@@ -571,6 +776,80 @@ func BenchmarkLotsOfRulesMatchParallel(b *testing.B) {
 	})
 }
 
+func BenchmarkLotsOfRulesLotsOfHosts(b *testing.B) {
+	d := NewForTest()
+	defer d.Destroy()
+	mustLoadTestRules(d)
+
+	getTopHosts()
+	hostnames, err := os.Open(topHostsFilename)
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer hostnames.Close()
+
+	scanner := bufio.NewScanner(hostnames)
+	b.ResetTimer()
+	for n := 0; n < b.N; n++ {
+		havedata := scanner.Scan()
+		if !havedata {
+			hostnames.Seek(0, 0)
+			scanner = bufio.NewScanner(hostnames)
+			havedata = scanner.Scan()
+		}
+		if !havedata {
+			b.Fatal(scanner.Err())
+		}
+		line := scanner.Text()
+		records := strings.Split(line, ",")
+		ret, err := d.CheckHost(records[1] + "." + records[1])
+		if err != nil {
+			b.Error(err)
+		}
+		if ret.Reason.Matched() {
+			// log.Printf("host \"%s\" matched. Rule \"%s\", reason: %v", host, ret.Rule, ret.Reason)
+		}
+	}
+}
+
+func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) {
+	d := NewForTest()
+	defer d.Destroy()
+	mustLoadTestRules(d)
+
+	getTopHosts()
+
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		hostnames, err := os.Open(topHostsFilename)
+		if err != nil {
+			b.Fatal(err)
+		}
+		defer hostnames.Close()
+		scanner := bufio.NewScanner(hostnames)
+		for pb.Next() {
+			havedata := scanner.Scan()
+			if !havedata {
+				hostnames.Seek(0, 0)
+				scanner = bufio.NewScanner(hostnames)
+				havedata = scanner.Scan()
+			}
+			if !havedata {
+				b.Fatal(scanner.Err())
+			}
+			line := scanner.Text()
+			records := strings.Split(line, ",")
+			ret, err := d.CheckHost(records[1] + "." + records[1])
+			if err != nil {
+				b.Error(err)
+			}
+			if ret.Reason.Matched() {
+				// log.Printf("host \"%s\" matched. Rule \"%s\", reason: %v", host, ret.Rule, ret.Reason)
+			}
+		}
+	})
+}
+
 func BenchmarkSafeBrowsing(b *testing.B) {
 	d := NewForTest()
 	defer d.Destroy()
@@ -645,8 +924,12 @@ func TestMain(m *testing.M) {
 // helper functions for debugging and testing
 //
 func purgeCaches() {
-	safebrowsingCache.Purge()
-	parentalCache.Purge()
+	if safebrowsingCache != nil {
+		safebrowsingCache.Purge()
+	}
+	if parentalCache != nil {
+		parentalCache.Purge()
+	}
 }
 
 func _Func() string {