MITM proxy

This commit is contained in:
Simon Zolin 2020-08-18 19:23:33 +03:00
parent c3123473cf
commit f85de51452
21 changed files with 2116 additions and 491 deletions

View file

@ -52,15 +52,21 @@ Contents:
* API: Get query log
* API: Set querylog parameters
* API: Get querylog parameters
* Filtering
* DNS Filtering
* Filters update mechanism
* API: Get filtering parameters
* API: Set filtering parameters
* API: Refresh filters
* API: Add Filter
* API: Set URL parameters
* API: Delete URL
* API: Set Filter parameters
* API: Delete Filter
* API: Domain Check
* HTTP Proxy
* API: Get Proxy settings
* API: Set Proxy settings
* API: Get Proxy filtering parameters
* API: Add Proxy Filter
* API: Delete Proxy Filter
* Log-in page
* API: Log in
* API: Log out
@ -1477,7 +1483,7 @@ Response:
}
## Filtering
## DNS Filtering
![](doc/agh-filtering.png)
@ -1548,7 +1554,19 @@ Response:
}
...
],
"user_rules":["...", ...]
"user_rules":["...", ...],
"proxy_filtering_enabled": true | false
"proxy_filters":[
{
"enabled":true,
"url":"https://...",
"name":"...",
"rules_count":1234,
"last_updated":"2019-09-04T18:29:30+00:00",
}
...
],
}
For both arrays `filters` and `whitelist_filters` there are unique values: id, url.
@ -1563,6 +1581,7 @@ Request:
{
"enabled": true | false
"proxy_filtering_enabled": true | false
"interval": 0 | 1 | 12 | 1*24 || 3*24 || 7*24
}
@ -1578,7 +1597,7 @@ Request:
POST /control/filtering/refresh
{
"whitelist": true
"type": blocklist | whitelist | proxylist
}
Response:
@ -1599,7 +1618,7 @@ Request:
{
"name": "..."
"url": "..." // URL or an absolute file path
"whitelist": true
"type": blocklist | whitelist | proxylist
}
Response:
@ -1607,7 +1626,7 @@ Response:
200 OK
### API: Set URL parameters
### API: Set Filter parameters
Request:
@ -1615,11 +1634,11 @@ Request:
{
"url": "..."
"whitelist": true
"type": blocklist | whitelist | proxylist
"data": {
"name": "..."
"url": "..."
"enabled": true | false
"enabled": true
}
}
@ -1628,7 +1647,7 @@ Response:
200 OK
### API: Delete URL
### API: Delete Filter
Request:
@ -1636,7 +1655,7 @@ Request:
{
"url": "..."
"whitelist": true
"type": blocklist | whitelist | proxylist
}
Response:
@ -1668,6 +1687,60 @@ Response:
}
## HTTP Proxy
Browser <-(HTTP)-> AGH Proxy <-(HTTP)-> Internet Server
HTTPS MITM:
. Browser --(CONNECT...)-> AGH Proxy --(handshake)-> Internet Server
. Browser <-(handshake,cert/AGH)-- AGH Proxy <-(cert/issuer)-- Internet Server
. Browser <-(TLS/session2)-> AGH Proxy <-(TLS/session1)-> Internet Server
### API: Get Proxy settings
Request:
GET /control/proxy_info
Response:
200 OK
{
"enabled": true|false,
"listen_address": "ip",
"listen_port": 12345,
"auth_username": "",
"auth_password": ""
}
### API: Set Proxy settings
Request:
POST /control/proxy_config
{
"enabled": true|false,
"listen_address": "ip",
"listen_port": 12345,
"auth_username": "",
"auth_password": "",
"cert_data":"...", // user-specified certificate. "": generate new
"pkey_data":"...",
}
Response:
200 OK
## Log-in page
After user completes the steps of installation wizard, he must log in into dashboard using his name and password. After user successfully logs in, he gets the Cookie which allows the server to authenticate him next time without password. After the Cookie is expired, user needs to perform log-in operation again.

247
filters/filter_file.go Normal file
View file

@ -0,0 +1,247 @@
package filters
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
)
// Allows printable UTF-8 text with CR, LF, TAB characters
func isPrintableText(data []byte) bool {
for _, c := range data {
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
continue
}
return false
}
return true
}
// Download filter data
// Return nil on success. Set f.Path to a file path, or "" if the file was not modified
func (fs *filterStg) downloadFilter(f *Filter) error {
log.Debug("Filters: Downloading filter from %s", f.URL)
// create temp file
tmpFile, err := ioutil.TempFile(filepath.Join(fs.conf.FilterDir), "")
if err != nil {
return err
}
defer func() {
if tmpFile != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpFile.Name())
}
}()
// create data reader object
var reader io.Reader
if filepath.IsAbs(f.URL) {
f, err := os.Open(f.URL)
if err != nil {
return fmt.Errorf("open file: %s", err)
}
defer f.Close()
reader = f
} else {
req, err := http.NewRequest("GET", f.URL, nil)
if err != nil {
return err
}
if len(f.LastModified) != 0 {
req.Header.Add("If-Modified-Since", f.LastModified)
}
resp, err := fs.conf.HTTPClient.Do(req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
f.networkError = true
return err
}
if resp.StatusCode == 304 { // "NOT_MODIFIED"
log.Debug("Filters: filter %s isn't modified since %s",
f.URL, f.LastModified)
f.LastUpdated = time.Now()
f.Path = ""
return nil
} else if resp.StatusCode != 200 {
err := fmt.Errorf("Filters: Couldn't download filter from %s: status code: %d",
f.URL, resp.StatusCode)
return err
}
f.LastModified = resp.Header.Get("Last-Modified")
reader = resp.Body
}
// parse and validate data, write to a file
err = writeFile(f, reader, tmpFile)
if err != nil {
return err
}
// Closing the file before renaming it is necessary on Windows
_ = tmpFile.Close()
fname := fs.filePath(*f)
err = os.Rename(tmpFile.Name(), fname)
if err != nil {
return err
}
tmpFile = nil // prevent from deleting this file in "defer" handler
log.Debug("Filters: saved filter %s at %s", f.URL, fname)
f.Path = fname
f.LastUpdated = time.Now()
return nil
}
func gatherUntil(dst []byte, dstLen int, src []byte, until int) int {
num := util.MinInt(len(src), until-dstLen)
return copy(dst[dstLen:], src[:num])
}
func isHTML(buf []byte) bool {
s := strings.ToLower(string(buf))
return strings.Contains(s, "<html") ||
strings.Contains(s, "<!doctype")
}
// Read file data and count the number of rules
func parseFilter(f *Filter, reader io.Reader) error {
ruleCount := 0
r := bufio.NewReader(reader)
log.Debug("Filters: parsing %s", f.URL)
var err error
for err == nil {
var line string
line, err = r.ReadString('\n')
if err != nil && err != io.EOF {
return err
}
line = strings.TrimSpace(line)
if len(line) == 0 ||
line[0] == '#' ||
line[0] == '!' {
continue
}
ruleCount++
}
log.Debug("Filters: %s: %d rules", f.URL, ruleCount)
f.RuleCount = uint64(ruleCount)
return nil
}
// Read data, parse, write to a file
func writeFile(f *Filter, reader io.Reader, outFile *os.File) error {
ruleCount := 0
buf := make([]byte, 64*1024)
total := 0
var chunk []byte
firstChunk := make([]byte, 4*1024)
firstChunkLen := 0
for {
n, err := reader.Read(buf)
if err != nil && err != io.EOF {
return err
}
total += n
if !isPrintableText(buf[:n]) {
return fmt.Errorf("data contains non-printable characters")
}
if firstChunk != nil {
// gather full buffer firstChunk and perform its data tests
firstChunkLen += gatherUntil(firstChunk, firstChunkLen, buf[:n], len(firstChunk))
if firstChunkLen == len(firstChunk) ||
err == io.EOF {
if isHTML(firstChunk[:firstChunkLen]) {
return fmt.Errorf("data is HTML, not plain text")
}
firstChunk = nil
}
}
_, err2 := outFile.Write(buf[:n])
if err2 != nil {
return err2
}
chunk = append(chunk, buf[:n]...)
s := string(chunk)
for len(s) != 0 {
i, line := splitNext(&s, '\n')
if i < 0 && err != io.EOF {
// no more lines in the current chunk
break
}
chunk = []byte(s)
if len(line) == 0 ||
line[0] == '#' ||
line[0] == '!' {
continue
}
ruleCount++
}
if err == io.EOF {
break
}
}
log.Debug("Filters: updated filter %s: %d bytes, %d rules",
f.URL, total, ruleCount)
f.RuleCount = uint64(ruleCount)
return nil
}
// SplitNext - split string by a byte
// Whitespace is trimmed
// Return byte position and the first chunk
func splitNext(data *string, by byte) (int, string) {
s := *data
i := strings.IndexByte(s, by)
var chunk string
if i < 0 {
chunk = s
s = ""
} else {
chunk = s[:i]
s = s[i+1:]
}
*data = s
chunk = strings.TrimSpace(chunk)
return i, chunk
}

329
filters/filter_http.go Normal file
View file

@ -0,0 +1,329 @@
package filters
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
// Print to log and set HTTP error message
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("Filters: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
// IsValidURL - return TRUE if URL or file path is valid
func IsValidURL(rawurl string) bool {
if filepath.IsAbs(rawurl) {
// this is a file path
return util.FileExists(rawurl)
}
url, err := url.ParseRequestURI(rawurl)
if err != nil {
return false //Couldn't even parse the rawurl
}
if len(url.Scheme) == 0 {
return false //No Scheme found
}
return true
}
func (f *Filtering) getFilterModule(t string) Filters {
switch t {
case "blocklist":
return f.dnsBlocklist
case "whitelist":
return f.dnsAllowlist
case "proxylist":
return f.Proxylist
default:
return nil
}
}
func (f *Filtering) restartMods(t string) {
fN := f.getFilterModule(t)
fN.NotifyObserver(EventBeforeUpdate)
fN.NotifyObserver(EventAfterUpdate)
}
func (f *Filtering) handleFilterAdd(w http.ResponseWriter, r *http.Request) {
type reqJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Type string `json:"type"`
}
req := reqJSON{}
_, err := jsonutil.DecodeObject(&req, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
filterN := f.getFilterModule(req.Type)
if filterN == nil {
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
return
}
filt := Filter{
Enabled: true,
Name: req.Name,
URL: req.URL,
}
err = filterN.Add(filt)
if err != nil {
httpError(r, w, http.StatusBadRequest, "add filter: %s", err)
return
}
f.conf.ConfigModified()
f.restartMods(req.Type)
}
func (f *Filtering) handleFilterRemove(w http.ResponseWriter, r *http.Request) {
type reqJSON struct {
URL string `json:"url"`
Type string `json:"type"`
}
req := reqJSON{}
_, err := jsonutil.DecodeObject(&req, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
filterN := f.getFilterModule(req.Type)
if filterN == nil {
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
return
}
removed := filterN.Delete(req.URL)
if removed == nil {
httpError(r, w, http.StatusInternalServerError, "no filter with such URL")
return
}
f.conf.ConfigModified()
if removed.Enabled {
f.restartMods(req.Type)
}
err = os.Remove(removed.Path)
if err != nil {
log.Error("os.Remove: %s", err)
}
}
func (f *Filtering) handleFilterModify(w http.ResponseWriter, r *http.Request) {
type propsJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Enabled bool `json:"enabled"`
}
type reqJSON struct {
URL string `json:"url"`
Type string `json:"type"`
Data propsJSON `json:"data"`
}
req := reqJSON{}
_, err := jsonutil.DecodeObject(&req, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
filterN := f.getFilterModule(req.Type)
if filterN == nil {
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
return
}
st, _, err := filterN.Modify(req.URL, req.Data.Enabled, req.Data.Name, req.Data.URL)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
f.conf.ConfigModified()
if st == StatusChangedEnabled ||
st == StatusChangedURL {
// TODO StatusChangedURL: delete old file
f.restartMods(req.Type)
}
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
f.conf.UserRules = strings.Split(string(body), "\n")
f.conf.ConfigModified()
f.restartMods("blocklist")
}
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type reqJSON struct {
Type string `json:"type"`
}
req := reqJSON{}
_, err := jsonutil.DecodeObject(&req, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
filterN := f.getFilterModule(req.Type)
if filterN == nil {
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
return
}
filterN.Refresh(0)
}
type filterJSON struct {
ID int64 `json:"id"`
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name"`
RulesCount uint32 `json:"rules_count"`
LastUpdated string `json:"last_updated"`
}
func filterToJSON(f Filter) filterJSON {
fj := filterJSON{
ID: int64(f.ID),
Enabled: f.Enabled,
URL: f.URL,
Name: f.Name,
RulesCount: uint32(f.RuleCount),
}
if !f.LastUpdated.IsZero() {
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
}
return fj
}
// Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
type respJSON struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"` // in hours
Filters []filterJSON `json:"filters"`
WhitelistFilters []filterJSON `json:"whitelist_filters"`
UserRules []string `json:"user_rules"`
Proxylist []filterJSON `json:"proxy_filters"`
}
resp := respJSON{}
resp.Enabled = f.conf.Enabled
resp.Interval = f.conf.UpdateIntervalHours
resp.UserRules = f.conf.UserRules
f0 := f.dnsBlocklist.List(0)
f1 := f.dnsAllowlist.List(0)
f2 := f.Proxylist.List(0)
for _, filt := range f0 {
fj := filterToJSON(filt)
resp.Filters = append(resp.Filters, fj)
}
for _, filt := range f1 {
fj := filterToJSON(filt)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
}
for _, filt := range f2 {
fj := filterToJSON(filt)
resp.Proxylist = append(resp.Proxylist, fj)
}
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jsonVal)
}
// Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
type reqJSON struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
}
req := reqJSON{}
_, err := jsonutil.DecodeObject(&req, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
if !CheckFiltersUpdateIntervalHours(req.Interval) {
httpError(r, w, http.StatusBadRequest, "Unsupported interval")
return
}
restart := false
if f.conf.Enabled != req.Enabled {
restart = true
}
f.conf.Enabled = req.Enabled
f.conf.UpdateIntervalHours = req.Interval
c := Conf{}
c.UpdateIntervalHours = req.Interval
f.dnsBlocklist.SetConfig(c)
f.dnsAllowlist.SetConfig(c)
f.Proxylist.SetConfig(c)
f.conf.ConfigModified()
if restart {
f.restartMods("blocklist")
}
}
// registerWebHandlers - register handlers
func (f *Filtering) registerWebHandlers() {
f.conf.HTTPRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
f.conf.HTTPRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
f.conf.HTTPRegister("POST", "/control/filtering/add_url", f.handleFilterAdd)
f.conf.HTTPRegister("POST", "/control/filtering/remove_url", f.handleFilterRemove)
f.conf.HTTPRegister("POST", "/control/filtering/set_url", f.handleFilterModify)
f.conf.HTTPRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
f.conf.HTTPRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
}
// CheckFiltersUpdateIntervalHours - verify update interval
func CheckFiltersUpdateIntervalHours(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
}

118
filters/filter_module.go Normal file
View file

@ -0,0 +1,118 @@
package filters
import (
"net/http"
"path/filepath"
)
// Filtering - module object
type Filtering struct {
dnsBlocklist Filters // DNS blocklist filters
dnsAllowlist Filters // DNS allowlist filters
Proxylist Filters // MITM Proxy filtering module
conf ModuleConf
}
// ModuleConf - module config
type ModuleConf struct {
Enabled bool
UpdateIntervalHours uint32 // 0: disabled
HTTPClient *http.Client
DataDir string
DNSBlocklist []Filter
DNSAllowlist []Filter
Proxylist []Filter
UserRules []string
// Called when the configuration is changed by HTTP request
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
}
// NewModule - create module
func NewModule(conf ModuleConf) *Filtering {
f := Filtering{}
f.conf = conf
fconf := Conf{}
fconf.FilterDir = filepath.Join(conf.DataDir, "filters_dnsblock")
fconf.List = conf.DNSBlocklist
fconf.UpdateIntervalHours = conf.UpdateIntervalHours
fconf.HTTPClient = conf.HTTPClient
f.dnsBlocklist = New(fconf)
fconf = Conf{}
fconf.FilterDir = filepath.Join(conf.DataDir, "filters_dnsallow")
fconf.List = conf.DNSAllowlist
fconf.UpdateIntervalHours = conf.UpdateIntervalHours
fconf.HTTPClient = conf.HTTPClient
f.dnsAllowlist = New(fconf)
fconf = Conf{}
fconf.FilterDir = filepath.Join(conf.DataDir, "filters_mitmproxy")
fconf.List = conf.Proxylist
fconf.UpdateIntervalHours = conf.UpdateIntervalHours
fconf.HTTPClient = conf.HTTPClient
f.Proxylist = New(fconf)
return &f
}
const (
DNSBlocklist = iota
DNSAllowlist
Proxylist
)
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// WriteDiskConfig - write configuration data
func (f *Filtering) WriteDiskConfig(mc *ModuleConf) {
mc.Enabled = f.conf.Enabled
mc.UpdateIntervalHours = f.conf.UpdateIntervalHours
mc.UserRules = stringArrayDup(f.conf.UserRules)
c := Conf{}
f.dnsBlocklist.WriteDiskConfig(&c)
mc.DNSBlocklist = c.List
c = Conf{}
f.dnsAllowlist.WriteDiskConfig(&c)
mc.DNSAllowlist = c.List
c = Conf{}
f.Proxylist.WriteDiskConfig(&c)
mc.Proxylist = c.List
}
// GetList - get specific filter list
func (f *Filtering) GetList(t uint32) Filters {
switch t {
case DNSBlocklist:
return f.dnsBlocklist
case DNSAllowlist:
return f.dnsAllowlist
case Proxylist:
return f.Proxylist
}
return nil
}
// Start - start module
func (f *Filtering) Start() {
f.dnsBlocklist.Start()
f.dnsAllowlist.Start()
f.Proxylist.Start()
f.registerWebHandlers()
}
// Close - close the module
func (f *Filtering) Close() {
}

246
filters/filter_storage.go Normal file
View file

@ -0,0 +1,246 @@
package filters
import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"go.uber.org/atomic"
)
// filter storage object
type filterStg struct {
updateTaskRunning bool
updated []Filter // list of filters that were downloaded during update procedure
updateChan chan bool // signal for the update goroutine
conf *Conf
confLock sync.Mutex
nextID atomic.Uint64 // next filter ID
observer EventHandler // user function that receives notifications
}
// initialize the module
func newFiltersObj(conf Conf) Filters {
fs := filterStg{}
fs.conf = &Conf{}
*fs.conf = conf
fs.nextID.Store(uint64(time.Now().Unix()))
fs.updateChan = make(chan bool, 2)
return &fs
}
// Start - start module
func (fs *filterStg) Start() {
_ = os.MkdirAll(fs.conf.FilterDir, 0755)
// Load all enabled filters
// On error, RuleCount is set to 0 - users won't try to use such filters
// and in the future the update procedure will re-download the file
for i := range fs.conf.List {
f := &fs.conf.List[i]
fname := fs.filePath(*f)
st, err := os.Stat(fname)
if err != nil {
log.Debug("Filters: os.Stat: %s %s", fname, err)
continue
}
f.LastUpdated = st.ModTime()
if !f.Enabled {
continue
}
file, err := os.OpenFile(fname, os.O_RDONLY, 0)
if err != nil {
log.Error("Filters: os.OpenFile: %s %s", fname, err)
continue
}
_ = parseFilter(f, file)
file.Close()
f.nextUpdate = f.LastUpdated.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour)
}
if !fs.updateTaskRunning {
fs.updateTaskRunning = true
go fs.updateBySignal()
go fs.updateByTimer()
}
}
// Close - close the module
func (fs *filterStg) Close() {
fs.updateChan <- false
close(fs.updateChan)
}
// Duplicate filter array
func arrayFilterDup(f []Filter) []Filter {
nf := make([]Filter, len(f))
copy(nf, f)
return nf
}
// WriteDiskConfig - write configuration on disk
func (fs *filterStg) WriteDiskConfig(c *Conf) {
fs.confLock.Lock()
*c = *fs.conf
c.List = arrayFilterDup(fs.conf.List)
fs.confLock.Unlock()
}
// SetConfig - set new configuration settings
func (fs *filterStg) SetConfig(c Conf) {
fs.conf.UpdateIntervalHours = c.UpdateIntervalHours
}
// SetObserver - set user handler for notifications
func (fs *filterStg) SetObserver(handler EventHandler) {
fs.observer = handler
}
// NotifyObserver - notify users about the event
func (fs *filterStg) NotifyObserver(flags uint) {
if fs.observer == nil {
return
}
fs.observer(flags)
}
// List (thread safe)
func (fs *filterStg) List(flags uint) []Filter {
fs.confLock.Lock()
list := make([]Filter, len(fs.conf.List))
for i, f := range fs.conf.List {
nf := f
nf.Path = fs.filePath(f)
list[i] = nf
}
fs.confLock.Unlock()
return list
}
// Add - add filter (thread safe)
func (fs *filterStg) Add(nf Filter) error {
fs.confLock.Lock()
defer fs.confLock.Unlock()
for _, f := range fs.conf.List {
if f.Name == nf.Name || f.URL == nf.URL {
return fmt.Errorf("filter with this Name or URL already exists")
}
}
nf.ID = fs.nextFilterID()
nf.Enabled = true
err := fs.downloadFilter(&nf)
if err != nil {
log.Debug("%s", err)
return err
}
fs.conf.List = append(fs.conf.List, nf)
log.Debug("Filters: added filter %s", nf.URL)
return nil
}
// Delete - remove filter (thread safe)
func (fs *filterStg) Delete(url string) *Filter {
fs.confLock.Lock()
defer fs.confLock.Unlock()
nf := []Filter{}
var found *Filter
for i := range fs.conf.List {
f := &fs.conf.List[i]
if f.URL == url {
found = f
continue
}
nf = append(nf, *f)
}
if found == nil {
return nil
}
fs.conf.List = nf
log.Debug("Filters: removed filter %s", url)
found.Path = fs.filePath(*found) // the caller will delete the file
return found
}
// Modify - set filter properties (thread safe)
// Return Status* bitarray
func (fs *filterStg) Modify(url string, enabled bool, name string, newURL string) (int, Filter, error) {
fs.confLock.Lock()
defer fs.confLock.Unlock()
st := 0
for i := range fs.conf.List {
f := &fs.conf.List[i]
if f.URL == url {
backup := *f
f.Name = name
if f.Enabled != enabled {
f.Enabled = enabled
st |= StatusChangedEnabled
}
if f.URL != newURL {
f.URL = newURL
st |= StatusChangedURL
}
needDownload := false
if (st & StatusChangedURL) != 0 {
f.ID = fs.nextFilterID()
needDownload = true
} else if (st&StatusChangedEnabled) != 0 && enabled {
fname := fs.filePath(*f)
file, err := os.OpenFile(fname, os.O_RDONLY, 0)
if err != nil {
log.Debug("Filters: os.OpenFile: %s %s", fname, err)
needDownload = true
} else {
_ = parseFilter(f, file)
file.Close()
}
}
if needDownload {
f.LastModified = ""
f.RuleCount = 0
err := fs.downloadFilter(f)
if err != nil {
*f = backup
return 0, Filter{}, err
}
}
return st, backup, nil
}
}
return 0, Filter{}, fmt.Errorf("filter %s not found", url)
}
// Get filter file name
func (fs *filterStg) filePath(f Filter) string {
return filepath.Join(fs.conf.FilterDir, fmt.Sprintf("%d.txt", f.ID))
}
// Get next filter ID
func (fs *filterStg) nextFilterID() uint64 {
return fs.nextID.Inc()
}

154
filters/filter_test.go Normal file
View file

@ -0,0 +1,154 @@
package filters
import (
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
)
func testStartFilterListener(counter *atomic.Uint32) net.Listener {
mux := http.NewServeMux()
mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
(*counter).Inc()
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
_, _ = w.Write([]byte(content))
})
mux.HandleFunc("/filters/2.txt", func(w http.ResponseWriter, r *http.Request) {
(*counter).Inc()
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
1.1.1.1 example1.com
`
_, _ = w.Write([]byte(content))
})
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
go func() {
_ = http.Serve(listener, mux)
}()
return listener
}
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
var updateStatus atomic.Uint32
func onFiltersUpdate(flags uint) {
switch flags {
case EventBeforeUpdate:
updateStatus.Store(updateStatus.Load() | 1)
case EventAfterUpdate:
updateStatus.Store(updateStatus.Load() | 2)
}
}
func TestFilters(t *testing.T) {
counter := atomic.Uint32{}
lhttp := testStartFilterListener(&counter)
defer func() { _ = lhttp.Close() }()
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fconf := Conf{}
fconf.UpdateIntervalHours = 1
fconf.FilterDir = dir
fconf.HTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
fs := New(fconf)
fs.SetObserver(onFiltersUpdate)
fs.Start()
port := lhttp.Addr().(*net.TCPAddr).Port
URL := fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", port)
// add and download
f := Filter{
URL: URL,
}
err := fs.Add(f)
assert.Equal(t, nil, err)
// check
l := fs.List(0)
assert.Equal(t, 1, len(l))
assert.Equal(t, URL, l[0].URL)
assert.True(t, l[0].Enabled)
assert.Equal(t, uint64(3), l[0].RuleCount)
assert.True(t, l[0].ID != 0)
// disable
st, _, err := fs.Modify(f.URL, false, "name", f.URL)
assert.Equal(t, StatusChangedEnabled, st)
// check: disabled
l = fs.List(0)
assert.Equal(t, 1, len(l))
assert.True(t, !l[0].Enabled)
// modify URL
newURL := fmt.Sprintf("http://127.0.0.1:%d/filters/2.txt", port)
st, modified, err := fs.Modify(URL, false, "name", newURL)
assert.Equal(t, StatusChangedURL, st)
_ = os.Remove(modified.Path)
// check: new ID, new URL
l = fs.List(0)
assert.Equal(t, 1, len(l))
assert.Equal(t, newURL, l[0].URL)
assert.Equal(t, uint64(4), l[0].RuleCount)
assert.True(t, modified.ID != l[0].ID)
// enable
st, _, err = fs.Modify(newURL, true, "name", newURL)
assert.Equal(t, StatusChangedEnabled, st)
// update
cnt := counter.Load()
fs.Refresh(0)
for i := 0; ; i++ {
if i == 2 {
assert.True(t, false)
break
}
if cnt != counter.Load() {
// filter was updated
break
}
time.Sleep(time.Second)
}
assert.Equal(t, uint32(1|2), updateStatus.Load())
// delete
removed := fs.Delete(newURL)
assert.NotNil(t, removed)
_ = os.Remove(removed.Path)
fs.Close()
}

176
filters/filter_update.go Normal file
View file

@ -0,0 +1,176 @@
package filters
import (
"os"
"time"
"github.com/AdguardTeam/golibs/log"
)
// Refresh - begin filters update procedure
func (fs *filterStg) Refresh(flags uint) {
fs.confLock.Lock()
defer fs.confLock.Unlock()
for i := range fs.conf.List {
f := &fs.conf.List[i]
f.nextUpdate = time.Time{}
}
fs.updateChan <- true
}
// Start update procedure periodically
func (fs *filterStg) updateByTimer() {
const maxPeriod = 1 * 60 * 60
period := 5 // use a dynamically increasing time interval, while network or DNS is down
for {
if fs.conf.UpdateIntervalHours == 0 {
period = maxPeriod
// update is disabled
time.Sleep(time.Duration(period) * time.Second)
continue
}
fs.updateChan <- true
time.Sleep(time.Duration(period) * time.Second)
period += period
if period > maxPeriod {
period = maxPeriod
}
}
}
// Begin update procedure by signal
func (fs *filterStg) updateBySignal() {
for {
select {
case ok := <-fs.updateChan:
if !ok {
return
}
fs.updateAll()
}
}
}
// Update filters
// Algorithm:
// . Get next filter to update:
// . Download data from Internet and store on disk (in a new file)
// . Add new filter to the special list
// . Repeat for next filter
// (All filters are downloaded)
// . Stop modules that use filters
// . For each updated filter:
// . Rename "new file name" -> "old file name"
// . Update meta data
// . Restart modules that use filters
func (fs *filterStg) updateAll() {
log.Debug("Filters: updating...")
for {
var uf Filter
fs.confLock.Lock()
f := fs.getNextToUpdate()
if f != nil {
uf = *f
}
fs.confLock.Unlock()
if f == nil {
fs.applyUpdate()
return
}
uf.ID = fs.nextFilterID()
err := fs.downloadFilter(&uf)
if err != nil {
if uf.networkError {
fs.confLock.Lock()
f.nextUpdate = time.Now().Add(10 * time.Second)
fs.confLock.Unlock()
}
continue
}
// add new filter to the list
fs.updated = append(fs.updated, uf)
}
}
// Get next filter to update
func (fs *filterStg) getNextToUpdate() *Filter {
now := time.Now()
for i := range fs.conf.List {
f := &fs.conf.List[i]
if f.Enabled &&
f.nextUpdate.Unix() <= now.Unix() {
f.nextUpdate = now.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour)
return f
}
}
return nil
}
// Replace filter files
func (fs *filterStg) applyUpdate() {
if len(fs.updated) == 0 {
log.Debug("Filters: no filters were updated")
return
}
fs.NotifyObserver(EventBeforeUpdate)
nUpdated := 0
fs.confLock.Lock()
for _, uf := range fs.updated {
found := false
for i := range fs.conf.List {
f := &fs.conf.List[i]
if uf.URL == f.URL {
found = true
fpath := fs.filePath(*f)
f.LastUpdated = uf.LastUpdated
if len(uf.Path) == 0 {
// the data hasn't changed - just update file mod time
err := os.Chtimes(fpath, f.LastUpdated, f.LastUpdated)
if err != nil {
log.Error("Filters: os.Chtimes: %s", err)
}
continue
}
err := os.Rename(uf.Path, fpath)
if err != nil {
log.Error("Filters: os.Rename:%s", err)
}
f.RuleCount = uf.RuleCount
nUpdated++
break
}
}
if !found {
// the updated filter was downloaded,
// but it's already removed from the main list
_ = os.Remove(fs.filePath(uf))
}
}
fs.confLock.Unlock()
log.Debug("Filters: %d filters were updated", nUpdated)
fs.updated = nil
fs.NotifyObserver(EventAfterUpdate)
}

93
filters/filters.go Normal file
View file

@ -0,0 +1,93 @@
package filters
import (
"net/http"
"time"
)
// Filters - main interface
type Filters interface {
// Start - start module
Start()
// Close - close the module
Close()
// WriteDiskConfig - write configuration on disk
WriteDiskConfig(c *Conf)
// SetConfig - set new configuration settings
// Currently only UpdateIntervalHours is supported
SetConfig(c Conf)
// SetObserver - set user handler for notifications
SetObserver(handler EventHandler)
// NotifyObserver - notify users about the event
NotifyObserver(flags uint)
// List (thread safe)
List(flags uint) []Filter
// Add - add filter (thread safe)
Add(nf Filter) error
// Delete - remove filter (thread safe)
Delete(url string) *Filter
// Modify - set filter properties (thread safe)
// Return Status* bitarray, old filter properties and error
Modify(url string, enabled bool, name string, newURL string) (int, Filter, error)
// Refresh - begin filters update procedure
Refresh(flags uint)
}
// Filter - filter object
type Filter struct {
ID uint64 `yaml:"id"`
Enabled bool `yaml:"enabled"`
Name string `yaml:"name"`
URL string `yaml:"url"`
LastModified string `yaml:"last_modified"` // value of Last-Modified HTTP header field
Path string `yaml:"-"`
// number of rules
// 0 means the file isn't loaded - user shouldn't use this filter
RuleCount uint64 `yaml:"-"`
LastUpdated time.Time `yaml:"-"` // time of the last update (= file modification time)
nextUpdate time.Time // time of the next update
networkError bool // network error during download
}
const (
// EventBeforeUpdate - this event is signalled before the update procedure renames/removes old filter files
EventBeforeUpdate = iota
// EventAfterUpdate - this event is signalled after the update procedure is finished
EventAfterUpdate
)
// EventHandler - event handler function
type EventHandler func(flags uint)
const (
// StatusChangedEnabled - changed 'Enabled'
StatusChangedEnabled = 2
// StatusChangedURL - changed 'URL'
StatusChangedURL = 4
)
// Conf - configuration
type Conf struct {
FilterDir string
UpdateIntervalHours uint32 // 0: disabled
HTTPClient *http.Client
List []Filter
}
// New - create object
func New(conf Conf) Filters {
return newFiltersObj(conf)
}

2
go.mod
View file

@ -5,6 +5,7 @@ go 1.14
require (
github.com/AdguardTeam/dnsproxy v0.30.1
github.com/AdguardTeam/golibs v0.4.2
github.com/AdguardTeam/gomitmproxy v0.2.0
github.com/AdguardTeam/urlfilter v0.11.2
github.com/NYTimes/gziphandler v1.1.1
github.com/fsnotify/fsnotify v1.4.7
@ -17,6 +18,7 @@ require (
github.com/sparrc/go-ping v0.0.0-20190613174326-4e5b6552494c
github.com/stretchr/testify v1.5.1
go.etcd.io/bbolt v1.3.4
go.uber.org/atomic v1.6.0
golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
golang.org/x/sys v0.0.0-20200331124033-c3d80250170d

8
go.sum
View file

@ -3,6 +3,7 @@ github.com/AdguardTeam/dnsproxy v0.30.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPq
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.11.2 h1:gCrWGh63Yqw3z4yi9pgikfsbshIEyvAu/KYV3MvTBlc=
github.com/AdguardTeam/urlfilter v0.11.2/go.mod h1:aMuejlNxpWppOVjiEV87X6z0eMf7wsXHTAIWQuylfZY=
@ -113,6 +114,8 @@ github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljT
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.etcd.io/bbolt v1.3.4 h1:hi1bXHMVrlQh6WwxAy+qZCV/SYIlqo+Ushwdpa4tAKg=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -120,6 +123,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8 h1:fpnn/HnJONpIu6hkXi1u/7rR0NzilgWr4T0JmWkEitk=
golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
@ -145,9 +150,12 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190624180213-70d37148ca0c/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View file

@ -9,6 +9,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/AdGuardHome/mitmproxy"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/golibs/file"
@ -17,8 +19,7 @@ import (
)
const (
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
dataDir = "data" // data storage
)
// logSettings
@ -54,9 +55,13 @@ type configuration struct {
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"`
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
MITM mitmproxy.Config `yaml:"mitmproxy"`
Filters []filters.Filter `yaml:"filters"`
WhitelistFilters []filters.Filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
ProxyFilters []filters.Filter `yaml:"proxy_filters"`
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
@ -155,7 +160,43 @@ func initConfig() {
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.Filters = defaultDNSBlocklistFilters()
config.ProxyFilters = defaultContentFilters()
}
func defaultDNSBlocklistFilters() []filters.Filter {
return []filters.Filter{
{
ID: 1,
Enabled: true,
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
Name: "AdGuard Simplified Domain Names filter",
},
{
ID: 2,
Enabled: false,
URL: "https://adaway.org/hosts.txt",
Name: "AdAway",
},
{
ID: 3,
Enabled: false,
URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt",
Name: "MalwareDomainList.com Hosts List",
},
}
}
func defaultContentFilters() []filters.Filter {
return []filters.Filter{
{
ID: 1,
Enabled: true,
URL: "https://filters.adtidy.org/extension/chromium/filters/2.txt",
Name: "AdGuard Base filter",
},
}
}
// getConfigFilename returns path to the current config file
@ -203,7 +244,7 @@ func parseConfig() error {
return err
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
if !filters.CheckFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24
}
@ -263,6 +304,17 @@ func (c *configuration) write() error {
config.DNS.DnsfilterConf = c
}
if Context.filters != nil {
fconf := filters.ModuleConf{}
Context.filters.WriteDiskConfig(&fconf)
config.DNS.FilteringEnabled = fconf.Enabled
config.DNS.FiltersUpdateIntervalHours = fconf.UpdateIntervalHours
config.Filters = fconf.DNSBlocklist
config.WhitelistFilters = fconf.DNSAllowlist
config.ProxyFilters = fconf.Proxylist
config.UserRules = fconf.UserRules
}
if Context.dnsServer != nil {
c := dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(&c)
@ -275,6 +327,12 @@ func (c *configuration) write() error {
config.DHCP = c
}
if Context.mitmProxy != nil {
c := mitmproxy.Config{}
Context.mitmProxy.WriteDiskConfig(&c)
config.MITM = c
}
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)

View file

@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/miekg/dns"
)
// ----------------
@ -87,6 +88,48 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(data)
}
type checkHostResp struct {
Reason string `json:"reason"`
FilterID int64 `json:"filter_id"`
Rule string `json:"rule"`
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for ReasonRewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
}
func handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.FilterID = result.FilterID
resp.Rule = result.Rule
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
// ------------------------
// registration of handlers
// ------------------------
@ -96,6 +139,7 @@ func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister("GET", "/control/filtering/check_host", handleCheckHost)
httpRegister("GET", "/control/profile", handleGetProfile)
RegisterAuthHandlers()

View file

@ -1,391 +0,0 @@
package home
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// isValidURL - return TRUE if URL or file path is valid
func isValidURL(rawurl string) bool {
if filepath.IsAbs(rawurl) {
// this is a file path
return util.FileExists(rawurl)
}
url, err := url.ParseRequestURI(rawurl)
if err != nil {
return false //Couldn't even parse the rawurl
}
if len(url.Scheme) == 0 {
return false //No Scheme found
}
return true
}
type filterAddJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
if !isValidURL(fj.URL) {
http.Error(w, "Invalid URL or file path", http.StatusBadRequest)
return
}
// Check for duplicates
if filterExists(fj.URL) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return
}
// Set necessary properties
filt := filter{
Enabled: true,
URL: fj.URL,
Name: fj.Name,
white: fj.Whitelist,
}
filt.ID = assignUniqueFilterID()
// Download the filter contents
ok, err := f.update(&filt)
if err != nil {
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
return
}
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
if !filterAdd(filt) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return
}
onConfigModified()
enableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
type request struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
// go through each element and delete if url matches
config.Lock()
newFilters := []filter{}
filters := &config.Filters
if req.Whitelist {
filters = &config.WhitelistFilters
}
for _, filter := range *filters {
if filter.URL != req.URL {
newFilters = append(newFilters, filter)
} else {
err := os.Rename(filter.Path(), filter.Path()+".old")
if err != nil {
log.Error("os.Rename: %s: %s", filter.Path(), err)
}
}
}
// Update the configuration after removing filter files
*filters = newFilters
config.Unlock()
onConfigModified()
enableFilters(true)
// Note: the old files "filter.txt.old" aren't deleted - it's not really necessary,
// but will require the additional code to run after enableFilters() is finished: i.e. complicated
}
type filterURLJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Enabled bool `json:"enabled"`
}
type filterURLReq struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
Data filterURLJSON `json:"data"`
}
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !isValidURL(fj.Data.URL) {
http.Error(w, "invalid URL or file path", http.StatusBadRequest)
return
}
filt := filter{
Enabled: fj.Data.Enabled,
Name: fj.Data.Name,
URL: fj.Data.URL,
}
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
if (status & statusFound) == 0 {
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
return
}
if (status & statusURLExists) != 0 {
http.Error(w, "URL already exists", http.StatusBadRequest)
return
}
onConfigModified()
restart := false
if (status & statusEnabledChanged) != 0 {
// we must add or remove filter rules
restart = true
}
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules
flags := FilterRefreshBlocklists
if fj.Whitelist {
flags = FilterRefreshAllowlists
}
nUpdated, _ := f.refreshFilters(flags, true)
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
// if not - we restart the filtering ourselves
restart = false
if nUpdated == 0 {
restart = true
}
}
if restart {
enableFilters(true)
}
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
config.UserRules = strings.Split(string(body), "\n")
onConfigModified()
enableFilters(true)
}
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type Req struct {
White bool `json:"whitelist"`
}
type Resp struct {
Updated int `json:"updated"`
}
resp := Resp{}
var err error
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
Context.controlLock.Unlock()
flags := FilterRefreshBlocklists
if req.White {
flags = FilterRefreshAllowlists
}
resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false)
Context.controlLock.Lock()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
type filterJSON struct {
ID int64 `json:"id"`
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name"`
RulesCount uint32 `json:"rules_count"`
LastUpdated string `json:"last_updated"`
}
type filteringConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"` // in hours
Filters []filterJSON `json:"filters"`
WhitelistFilters []filterJSON `json:"whitelist_filters"`
UserRules []string `json:"user_rules"`
}
func filterToJSON(f filter) filterJSON {
fj := filterJSON{
ID: f.ID,
Enabled: f.Enabled,
URL: f.URL,
Name: f.Name,
RulesCount: uint32(f.RulesCount),
}
if !f.LastUpdated.IsZero() {
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
}
return fj
}
// Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
resp := filteringConfig{}
config.RLock()
resp.Enabled = config.DNS.FilteringEnabled
resp.Interval = config.DNS.FiltersUpdateIntervalHours
for _, f := range config.Filters {
fj := filterToJSON(f)
resp.Filters = append(resp.Filters, fj)
}
for _, f := range config.WhitelistFilters {
fj := filterToJSON(f)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
}
resp.UserRules = config.UserRules
config.RUnlock()
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkFiltersUpdateIntervalHours(req.Interval) {
httpError(w, http.StatusBadRequest, "Unsupported interval")
return
}
config.DNS.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval
onConfigModified()
enableFilters(true)
}
type checkHostResp struct {
Reason string `json:"reason"`
FilterID int64 `json:"filter_id"`
Rule string `json:"rule"`
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for ReasonRewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
}
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.FilterID = result.FilterID
resp.Rule = result.Rule
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
// RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() {
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
}
func checkFiltersUpdateIntervalHours(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
}

View file

@ -4,9 +4,11 @@ import (
"fmt"
"net"
"path/filepath"
"strings"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/AdGuardHome/util"
@ -77,8 +79,6 @@ func initDNSServer() error {
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
Context.whois = initWhois(&Context.clients)
Context.filters.Init()
return nil
}
@ -277,6 +277,8 @@ func startDNSServer() error {
Context.dnsFilter.Start()
Context.filters.Start()
Context.filters.GetList(filters.DNSBlocklist).SetObserver(onFiltersChanged)
Context.filters.GetList(filters.DNSAllowlist).SetObserver(onFiltersChanged)
Context.stats.Start()
Context.queryLog.Start()
@ -345,3 +347,58 @@ func closeDNSServer() {
log.Debug("Closed all DNS modules")
}
func onFiltersChanged(flags uint) {
switch flags {
case filters.EventBeforeUpdate:
//
case filters.EventAfterUpdate:
enableFilters(true)
}
}
// Activate new DNS filters
// async: do it asynchronously (the function returns immediately)
func enableFilters(async bool) {
var blockFilters []dnsfilter.Filter
var allowFilters []dnsfilter.Filter
if config.DNS.FilteringEnabled {
// convert array of filters
// add user filter
userFilter := dnsfilter.Filter{
ID: 0,
Data: []byte(strings.Join(config.UserRules, "\n")),
}
blockFilters = append(blockFilters, userFilter)
// add blocklist filters
list := Context.filters.GetList(filters.DNSBlocklist).List(0)
for _, f := range list {
if !f.Enabled || f.RuleCount == 0 {
continue
}
f := dnsfilter.Filter{
ID: int64(f.ID),
FilePath: f.Path,
}
blockFilters = append(blockFilters, f)
}
// add allowlist filters
list = Context.filters.GetList(filters.DNSAllowlist).List(0)
for _, f := range list {
if !f.Enabled || f.RuleCount == 0 {
continue
}
f := dnsfilter.Filter{
ID: int64(f.ID),
FilePath: f.Path,
}
allowFilters = append(allowFilters, f)
}
}
_ = Context.dnsFilter.SetFilters(blockFilters, allowFilters, async)
}

View file

@ -1,65 +0,0 @@
package home
import (
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func testStartFilterListener() net.Listener {
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
_, _ = w.Write([]byte(content))
})
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
go func() { _ = http.Serve(listener, nil) }()
return listener
}
func TestFilters(t *testing.T) {
l := testStartFilterListener()
defer func() { _ = l.Close() }()
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
Context = homeContext{}
Context.workDir = dir
Context.client = &http.Client{
Timeout: 5 * time.Second,
}
Context.filters.Init()
f := filter{
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
}
// download
ok, err := Context.filters.update(&f)
assert.Equal(t, nil, err)
assert.True(t, ok)
assert.Equal(t, 3, f.RulesCount)
// refresh
ok, err = Context.filters.update(&f)
assert.True(t, !ok && err == nil)
err = Context.filters.load(&f)
assert.True(t, err == nil)
f.unload()
_ = os.Remove(f.Path())
}

View file

@ -20,6 +20,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/AdGuardHome/update"
"github.com/AdguardTeam/AdGuardHome/util"
@ -30,6 +31,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/mitmproxy"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/golibs/log"
@ -62,12 +64,14 @@ type homeContext struct {
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
filters *filters.Filtering // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
updater *update.Updater
mitmProxy *mitmproxy.MITMProxy // MITM proxy module
// Runtime properties
// --
@ -279,6 +283,28 @@ func run(args options) {
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
}
fconf := filters.ModuleConf{}
fconf.Enabled = config.DNS.FilteringEnabled
fconf.UpdateIntervalHours = config.DNS.FiltersUpdateIntervalHours
fconf.DataDir = Context.getDataDir()
fconf.DNSBlocklist = config.Filters
fconf.DNSAllowlist = config.WhitelistFilters
fconf.UserRules = config.UserRules
fconf.Proxylist = config.ProxyFilters
fconf.HTTPClient = Context.client
fconf.ConfigModified = onConfigModified
fconf.HTTPRegister = httpRegister
Context.filters = filters.NewModule(fconf)
config.MITM.CertDir = Context.getDataDir()
config.MITM.ConfigModified = onConfigModified
config.MITM.HTTPRegister = httpRegister
config.MITM.Filter = Context.filters.GetList(filters.Proxylist)
Context.mitmProxy = mitmproxy.New(config.MITM)
if Context.mitmProxy == nil {
os.Exit(1)
}
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = args.glinetMode
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
@ -317,6 +343,13 @@ func run(args options) {
}
}()
if Context.mitmProxy != nil {
err = Context.mitmProxy.Start()
if err != nil {
log.Fatal(err)
}
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
@ -501,10 +534,16 @@ func cleanup() {
Context.auth = nil
}
if Context.mitmProxy != nil {
Context.mitmProxy.Close()
Context.mitmProxy = nil
}
err := stopDNSServer()
if err != nil {
log.Error("Couldn't stop DNS server: %s", err)
}
err = stopDHCPServer()
if err != nil {
log.Error("Couldn't stop DHCP server: %s", err)

View file

@ -170,7 +170,7 @@ func TestHome(t *testing.T) {
assert.True(t, haveIP)
for i := 1; ; i++ {
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
st, err := os.Stat(filepath.Join(dir, "data", "filters_dnsblock", "1.txt"))
if err == nil && st.Size() != 0 {
break
}

View file

@ -3,6 +3,8 @@ package home
import (
"fmt"
"testing"
"github.com/AdguardTeam/AdGuardHome/filters"
)
func TestUpgrade1to2(t *testing.T) {
@ -148,13 +150,13 @@ func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{})
if v != (*oldConfig)[k].(bool) {
t.Fatalf("wrong boolean value for %s", k)
}
case []filter:
if len((*oldConfig)[k].([]filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
case []filters.Filter:
if len((*oldConfig)[k].([]filters.Filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filters.Filter)), len(value))
}
for i, newFilter := range value {
oldFilter := (*oldConfig)[k].([]filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
oldFilter := (*oldConfig)[k].([]filters.Filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RuleCount != newFilter.RuleCount {
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
}
}
@ -179,16 +181,16 @@ func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVers
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
diskConfig = make(map[string]interface{})
diskConfig["language"] = "en"
diskConfig["filters"] = []filter{
diskConfig["filters"] = []filters.Filter{
{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RulesCount: 100,
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RuleCount: 100,
},
{
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RulesCount: 200,
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RuleCount: 200,
},
}
diskConfig["user_rules"] = []string{}

99
mitmproxy/mitm_http.go Normal file
View file

@ -0,0 +1,99 @@
package mitmproxy
import (
"encoding/json"
"fmt"
"net"
"net/http"
"strconv"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
// Print to log and set HTTP error message
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("MITM: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
type mitmConfigJSON struct {
Enabled bool `json:"enabled"`
ListenAddr string `json:"listen_address"`
ListenPort int `json:"listen_port"`
UserName string `json:"auth_username"`
Password string `json:"auth_password"`
CertData string `json:"cert_data"`
PKeyData string `json:"pkey_data"`
}
func (p *MITMProxy) handleGetConfig(w http.ResponseWriter, r *http.Request) {
resp := mitmConfigJSON{}
p.confLock.Lock()
resp.Enabled = p.conf.Enabled
host, port, _ := net.SplitHostPort(p.conf.ListenAddr)
resp.ListenAddr = host
resp.ListenPort, _ = strconv.Atoi(port)
resp.UserName = p.conf.UserName
resp.Password = p.conf.Password
p.confLock.Unlock()
js, err := json.Marshal(resp)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
func (p *MITMProxy) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := mitmConfigJSON{}
_, err := jsonutil.DecodeObject(&req, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
if !((len(req.CertData) != 0 && len(req.PKeyData) != 0) ||
(len(req.CertData) == 0 && len(req.PKeyData) == 0)) {
httpError(r, w, http.StatusBadRequest, "certificate & private key must be both empty or specified")
return
}
p.confLock.Lock()
if len(req.CertData) != 0 {
err = p.storeCert([]byte(req.CertData), []byte(req.PKeyData))
if err != nil {
httpError(r, w, http.StatusInternalServerError, "%s", err)
p.confLock.Unlock()
return
}
p.conf.RegenCert = false
} else {
p.conf.RegenCert = true
}
p.conf.Enabled = req.Enabled
p.conf.ListenAddr = net.JoinHostPort(req.ListenAddr, strconv.Itoa(req.ListenPort))
p.conf.UserName = req.UserName
p.conf.Password = req.Password
p.confLock.Unlock()
p.conf.ConfigModified()
p.Close()
err = p.Restart()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "%s", err)
return
}
}
// Initialize web handlers
func (p *MITMProxy) initWeb() {
p.conf.HTTPRegister("GET", "/control/proxy_info", p.handleGetConfig)
p.conf.HTTPRegister("POST", "/control/proxy_config", p.handleSetConfig)
}

57
mitmproxy/mitm_test.go Normal file
View file

@ -0,0 +1,57 @@
package mitmproxy
import (
"net/http"
"net/url"
"os"
"testing"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
func TestMITM(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fconf := filters.Conf{}
fconf.FilterDir = dir
fconf.HTTPClient = http.DefaultClient
filters := filters.New(fconf)
conf := Config{}
conf.Enabled = true
conf.CertDir = dir
conf.RegenCert = true
conf.ListenAddr = "127.0.0.1:8081"
conf.Filter = filters
s := New(conf)
assert.NotNil(t, s)
err := s.Start()
assert.Nil(t, err)
proxyURL, _ := url.Parse("http://127.0.0.1:8081")
transport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
c := http.Client{
Transport: transport,
}
resp, err := c.Get("http://example.com/")
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
resp, err = c.Get("http://adguardhome.api/cert.crt")
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
s.Close()
}

279
mitmproxy/mitmproxy.go Normal file
View file

@ -0,0 +1,279 @@
package mitmproxy
import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"net/http"
"path/filepath"
"strconv"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/gomitmproxy/mitm"
"github.com/AdguardTeam/urlfilter/proxy"
)
// MITMProxy - MITM proxy structure
type MITMProxy struct {
proxy *proxy.Server
conf Config
confLock sync.Mutex
}
// Config - module configuration
type Config struct {
Enabled bool `yaml:"enabled"`
ListenAddr string `yaml:"listen_address"`
UserName string `yaml:"auth_username"`
Password string `yaml:"auth_password"`
// TLS:
RegenCert bool `yaml:"regenerate_cert"` // Regenerate certificate on cert loading failure
CertDir string `yaml:"-"` // Directory where Root certificate & pkey is stored
certFileName string
pkeyFileName string
certData []byte
pkeyData []byte
Filter filters.Filters `yaml:"-"`
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
}
// New - create a new instance of the query log
func New(conf Config) *MITMProxy {
p := MITMProxy{}
p.conf = conf
p.conf.certFileName = filepath.Join(p.conf.CertDir, "/http_proxy.crt")
p.conf.pkeyFileName = filepath.Join(p.conf.CertDir, "/http_proxy.key")
err := p.create()
if err != nil {
log.Error("MITM: %s", err)
return nil
}
if p.conf.HTTPRegister != nil {
p.initWeb()
}
p.conf.Filter.SetObserver(p.onFiltersChanged)
return &p
}
// Close - close the object
func (p *MITMProxy) Close() {
if p.proxy != nil {
p.proxy.Close()
p.proxy = nil
log.Debug("MITM: Closed proxy")
}
}
// WriteDiskConfig - write configuration on disk
func (p *MITMProxy) WriteDiskConfig(c *Config) {
p.confLock.Lock()
*c = p.conf
p.confLock.Unlock()
}
// Start - start proxy server
func (p *MITMProxy) Start() error {
if !p.conf.Enabled {
return nil
}
err := p.proxy.Start()
if err != nil {
return err
}
log.Debug("MITM: Running...")
return nil
}
// Restart - restart proxy server after Close()
func (p *MITMProxy) Restart() error {
err := p.create()
if err != nil {
return err
}
return p.Start()
}
// Create a gomitmproxy object
func (p *MITMProxy) create() error {
if !p.conf.Enabled {
return nil
}
c := proxy.Config{}
c.ProxyConfig.APIHost = "adguardhome.api"
addr, port, err := net.SplitHostPort(p.conf.ListenAddr)
if err != nil {
return fmt.Errorf("net.SplitHostPort: %s", err)
}
c.CompressContentScript = true
c.ProxyConfig.ListenAddr = &net.TCPAddr{}
c.ProxyConfig.ListenAddr.IP = net.ParseIP(addr)
if c.ProxyConfig.ListenAddr.IP == nil {
return fmt.Errorf("invalid IP: %s", addr)
}
c.ProxyConfig.ListenAddr.Port, err = strconv.Atoi(port)
if c.ProxyConfig.ListenAddr.Port < 0 || c.ProxyConfig.ListenAddr.Port > 0xffff || err != nil {
return fmt.Errorf("invalid port number: %s", port)
}
c.ProxyConfig.Username = p.conf.UserName
c.ProxyConfig.Password = p.conf.Password
err = p.loadCert()
if err != nil {
if !p.conf.RegenCert {
return err
}
log.Debug("%s", err)
// certificate or private key file doesn't exist - generate new
err = p.createRootCert()
if err != nil {
return err
}
}
c.ProxyConfig.MITMConfig, err = p.prepareMITMConfig()
if err != nil {
if !p.conf.RegenCert {
return err
}
// certificate or private key is invalid - generate new
err = p.createRootCert()
if err != nil {
return err
}
c.ProxyConfig.MITMConfig, err = p.prepareMITMConfig()
if err != nil {
return err
}
}
c.FiltersPaths = make(map[int]string)
filtrs := p.conf.Filter.List(0)
i := 0
for _, f := range filtrs {
if !f.Enabled ||
f.RuleCount == 0 { // not loaded
continue
}
c.FiltersPaths[i] = f.Path
i++
}
p.proxy, err = proxy.NewServer(c)
if err != nil {
return fmt.Errorf("proxy.NewServer: %s", err)
}
return nil
}
// Load cert and pkey from file
func (p *MITMProxy) loadCert() error {
var err error
p.conf.certData, err = ioutil.ReadFile(p.conf.certFileName)
if err != nil {
return err
}
p.conf.pkeyData, err = ioutil.ReadFile(p.conf.pkeyFileName)
if err != nil {
return err
}
return nil
}
// Create Root certificate and pkey and store it on disk
func (p *MITMProxy) createRootCert() error {
cert, key, err := mitm.NewAuthority("AdGuardHome Root", "AdGuard", 365*24*time.Hour)
if err != nil {
return err
}
p.conf.certData = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
p.conf.pkeyData = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
log.Debug("MITM: Created root certificate and key")
err = p.storeCert(p.conf.certData, p.conf.pkeyData)
if err != nil {
return err
}
return nil
}
// Store cert & pkey on disk
func (p *MITMProxy) storeCert(certData []byte, pkeyData []byte) error {
err := file.SafeWrite(p.conf.certFileName, certData)
if err != nil {
return err
}
err = file.SafeWrite(p.conf.pkeyFileName, pkeyData)
if err != nil {
return err
}
log.Debug("MITM: stored root certificate and key: %s, %s", p.conf.certFileName, p.conf.pkeyFileName)
return nil
}
// Fill TLSConfig & MITMConfig objects
func (p *MITMProxy) prepareMITMConfig() (*mitm.Config, error) {
tlsCert, err := tls.X509KeyPair(p.conf.certData, p.conf.pkeyData)
if err != nil {
return nil, fmt.Errorf("failed to load root CA: %v", err)
}
privateKey := tlsCert.PrivateKey.(*rsa.PrivateKey)
x509c, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return nil, fmt.Errorf("invalid certificate: %v", err)
}
mitmConfig, err := mitm.NewConfig(x509c, privateKey, nil)
if err != nil {
return nil, fmt.Errorf("failed to create MITM config: %v", err)
}
mitmConfig.SetValidity(time.Hour * 24 * 7) // generate certs valid for 7 days
mitmConfig.SetOrganization("AdGuard") // cert organization
return mitmConfig, nil
}
func (p *MITMProxy) onFiltersChanged(flags uint) {
switch flags {
case filters.EventBeforeUpdate:
p.Close()
case filters.EventAfterUpdate:
err := p.Restart()
if err != nil {
log.Error("MITM: %s", err)
}
}
}