mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-03-14 22:48:35 +03:00
MITM proxy
This commit is contained in:
parent
c3123473cf
commit
f85de51452
21 changed files with 2116 additions and 491 deletions
|
@ -52,15 +52,21 @@ Contents:
|
|||
* API: Get query log
|
||||
* API: Set querylog parameters
|
||||
* API: Get querylog parameters
|
||||
* Filtering
|
||||
* DNS Filtering
|
||||
* Filters update mechanism
|
||||
* API: Get filtering parameters
|
||||
* API: Set filtering parameters
|
||||
* API: Refresh filters
|
||||
* API: Add Filter
|
||||
* API: Set URL parameters
|
||||
* API: Delete URL
|
||||
* API: Set Filter parameters
|
||||
* API: Delete Filter
|
||||
* API: Domain Check
|
||||
* HTTP Proxy
|
||||
* API: Get Proxy settings
|
||||
* API: Set Proxy settings
|
||||
* API: Get Proxy filtering parameters
|
||||
* API: Add Proxy Filter
|
||||
* API: Delete Proxy Filter
|
||||
* Log-in page
|
||||
* API: Log in
|
||||
* API: Log out
|
||||
|
@ -1477,7 +1483,7 @@ Response:
|
|||
}
|
||||
|
||||
|
||||
## Filtering
|
||||
## DNS Filtering
|
||||
|
||||

|
||||
|
||||
|
@ -1548,7 +1554,19 @@ Response:
|
|||
}
|
||||
...
|
||||
],
|
||||
"user_rules":["...", ...]
|
||||
"user_rules":["...", ...],
|
||||
|
||||
"proxy_filtering_enabled": true | false
|
||||
"proxy_filters":[
|
||||
{
|
||||
"enabled":true,
|
||||
"url":"https://...",
|
||||
"name":"...",
|
||||
"rules_count":1234,
|
||||
"last_updated":"2019-09-04T18:29:30+00:00",
|
||||
}
|
||||
...
|
||||
],
|
||||
}
|
||||
|
||||
For both arrays `filters` and `whitelist_filters` there are unique values: id, url.
|
||||
|
@ -1563,6 +1581,7 @@ Request:
|
|||
|
||||
{
|
||||
"enabled": true | false
|
||||
"proxy_filtering_enabled": true | false
|
||||
"interval": 0 | 1 | 12 | 1*24 || 3*24 || 7*24
|
||||
}
|
||||
|
||||
|
@ -1578,7 +1597,7 @@ Request:
|
|||
POST /control/filtering/refresh
|
||||
|
||||
{
|
||||
"whitelist": true
|
||||
"type": blocklist | whitelist | proxylist
|
||||
}
|
||||
|
||||
Response:
|
||||
|
@ -1599,7 +1618,7 @@ Request:
|
|||
{
|
||||
"name": "..."
|
||||
"url": "..." // URL or an absolute file path
|
||||
"whitelist": true
|
||||
"type": blocklist | whitelist | proxylist
|
||||
}
|
||||
|
||||
Response:
|
||||
|
@ -1607,7 +1626,7 @@ Response:
|
|||
200 OK
|
||||
|
||||
|
||||
### API: Set URL parameters
|
||||
### API: Set Filter parameters
|
||||
|
||||
Request:
|
||||
|
||||
|
@ -1615,11 +1634,11 @@ Request:
|
|||
|
||||
{
|
||||
"url": "..."
|
||||
"whitelist": true
|
||||
"type": blocklist | whitelist | proxylist
|
||||
"data": {
|
||||
"name": "..."
|
||||
"url": "..."
|
||||
"enabled": true | false
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1628,7 +1647,7 @@ Response:
|
|||
200 OK
|
||||
|
||||
|
||||
### API: Delete URL
|
||||
### API: Delete Filter
|
||||
|
||||
Request:
|
||||
|
||||
|
@ -1636,7 +1655,7 @@ Request:
|
|||
|
||||
{
|
||||
"url": "..."
|
||||
"whitelist": true
|
||||
"type": blocklist | whitelist | proxylist
|
||||
}
|
||||
|
||||
Response:
|
||||
|
@ -1668,6 +1687,60 @@ Response:
|
|||
}
|
||||
|
||||
|
||||
## HTTP Proxy
|
||||
|
||||
Browser <-(HTTP)-> AGH Proxy <-(HTTP)-> Internet Server
|
||||
|
||||
HTTPS MITM:
|
||||
|
||||
. Browser --(CONNECT...)-> AGH Proxy --(handshake)-> Internet Server
|
||||
. Browser <-(handshake,cert/AGH)-- AGH Proxy <-(cert/issuer)-- Internet Server
|
||||
. Browser <-(TLS/session2)-> AGH Proxy <-(TLS/session1)-> Internet Server
|
||||
|
||||
|
||||
### API: Get Proxy settings
|
||||
|
||||
Request:
|
||||
|
||||
GET /control/proxy_info
|
||||
|
||||
Response:
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"enabled": true|false,
|
||||
"listen_address": "ip",
|
||||
"listen_port": 12345,
|
||||
|
||||
"auth_username": "",
|
||||
"auth_password": ""
|
||||
}
|
||||
|
||||
|
||||
### API: Set Proxy settings
|
||||
|
||||
Request:
|
||||
|
||||
POST /control/proxy_config
|
||||
|
||||
{
|
||||
"enabled": true|false,
|
||||
"listen_address": "ip",
|
||||
"listen_port": 12345,
|
||||
|
||||
"auth_username": "",
|
||||
"auth_password": "",
|
||||
|
||||
"cert_data":"...", // user-specified certificate. "": generate new
|
||||
"pkey_data":"...",
|
||||
}
|
||||
|
||||
Response:
|
||||
|
||||
200 OK
|
||||
|
||||
|
||||
## Log-in page
|
||||
|
||||
After user completes the steps of installation wizard, he must log in into dashboard using his name and password. After user successfully logs in, he gets the Cookie which allows the server to authenticate him next time without password. After the Cookie is expired, user needs to perform log-in operation again.
|
||||
|
|
247
filters/filter_file.go
Normal file
247
filters/filter_file.go
Normal file
|
@ -0,0 +1,247 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||
func isPrintableText(data []byte) bool {
|
||||
for _, c := range data {
|
||||
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Download filter data
|
||||
// Return nil on success. Set f.Path to a file path, or "" if the file was not modified
|
||||
func (fs *filterStg) downloadFilter(f *Filter) error {
|
||||
log.Debug("Filters: Downloading filter from %s", f.URL)
|
||||
|
||||
// create temp file
|
||||
tmpFile, err := ioutil.TempFile(filepath.Join(fs.conf.FilterDir), "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if tmpFile != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
}()
|
||||
|
||||
// create data reader object
|
||||
var reader io.Reader
|
||||
if filepath.IsAbs(f.URL) {
|
||||
f, err := os.Open(f.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %s", err)
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
} else {
|
||||
req, err := http.NewRequest("GET", f.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(f.LastModified) != 0 {
|
||||
req.Header.Add("If-Modified-Since", f.LastModified)
|
||||
}
|
||||
|
||||
resp, err := fs.conf.HTTPClient.Do(req)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
f.networkError = true
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode == 304 { // "NOT_MODIFIED"
|
||||
log.Debug("Filters: filter %s isn't modified since %s",
|
||||
f.URL, f.LastModified)
|
||||
f.LastUpdated = time.Now()
|
||||
f.Path = ""
|
||||
return nil
|
||||
|
||||
} else if resp.StatusCode != 200 {
|
||||
err := fmt.Errorf("Filters: Couldn't download filter from %s: status code: %d",
|
||||
f.URL, resp.StatusCode)
|
||||
return err
|
||||
}
|
||||
|
||||
f.LastModified = resp.Header.Get("Last-Modified")
|
||||
|
||||
reader = resp.Body
|
||||
}
|
||||
|
||||
// parse and validate data, write to a file
|
||||
err = writeFile(f, reader, tmpFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Closing the file before renaming it is necessary on Windows
|
||||
_ = tmpFile.Close()
|
||||
fname := fs.filePath(*f)
|
||||
err = os.Rename(tmpFile.Name(), fname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpFile = nil // prevent from deleting this file in "defer" handler
|
||||
|
||||
log.Debug("Filters: saved filter %s at %s", f.URL, fname)
|
||||
f.Path = fname
|
||||
f.LastUpdated = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func gatherUntil(dst []byte, dstLen int, src []byte, until int) int {
|
||||
num := util.MinInt(len(src), until-dstLen)
|
||||
return copy(dst[dstLen:], src[:num])
|
||||
}
|
||||
|
||||
func isHTML(buf []byte) bool {
|
||||
s := strings.ToLower(string(buf))
|
||||
return strings.Contains(s, "<html") ||
|
||||
strings.Contains(s, "<!doctype")
|
||||
}
|
||||
|
||||
// Read file data and count the number of rules
|
||||
func parseFilter(f *Filter, reader io.Reader) error {
|
||||
ruleCount := 0
|
||||
r := bufio.NewReader(reader)
|
||||
|
||||
log.Debug("Filters: parsing %s", f.URL)
|
||||
|
||||
var err error
|
||||
for err == nil {
|
||||
var line string
|
||||
line, err = r.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if len(line) == 0 ||
|
||||
line[0] == '#' ||
|
||||
line[0] == '!' {
|
||||
continue
|
||||
}
|
||||
|
||||
ruleCount++
|
||||
}
|
||||
|
||||
log.Debug("Filters: %s: %d rules", f.URL, ruleCount)
|
||||
|
||||
f.RuleCount = uint64(ruleCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read data, parse, write to a file
|
||||
func writeFile(f *Filter, reader io.Reader, outFile *os.File) error {
|
||||
ruleCount := 0
|
||||
buf := make([]byte, 64*1024)
|
||||
total := 0
|
||||
var chunk []byte
|
||||
|
||||
firstChunk := make([]byte, 4*1024)
|
||||
firstChunkLen := 0
|
||||
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
total += n
|
||||
|
||||
if !isPrintableText(buf[:n]) {
|
||||
return fmt.Errorf("data contains non-printable characters")
|
||||
}
|
||||
|
||||
if firstChunk != nil {
|
||||
// gather full buffer firstChunk and perform its data tests
|
||||
firstChunkLen += gatherUntil(firstChunk, firstChunkLen, buf[:n], len(firstChunk))
|
||||
|
||||
if firstChunkLen == len(firstChunk) ||
|
||||
err == io.EOF {
|
||||
|
||||
if isHTML(firstChunk[:firstChunkLen]) {
|
||||
return fmt.Errorf("data is HTML, not plain text")
|
||||
}
|
||||
|
||||
firstChunk = nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err2 := outFile.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
|
||||
chunk = append(chunk, buf[:n]...)
|
||||
s := string(chunk)
|
||||
for len(s) != 0 {
|
||||
i, line := splitNext(&s, '\n')
|
||||
if i < 0 && err != io.EOF {
|
||||
// no more lines in the current chunk
|
||||
break
|
||||
}
|
||||
chunk = []byte(s)
|
||||
|
||||
if len(line) == 0 ||
|
||||
line[0] == '#' ||
|
||||
line[0] == '!' {
|
||||
continue
|
||||
}
|
||||
|
||||
ruleCount++
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Filters: updated filter %s: %d bytes, %d rules",
|
||||
f.URL, total, ruleCount)
|
||||
|
||||
f.RuleCount = uint64(ruleCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SplitNext - split string by a byte
|
||||
// Whitespace is trimmed
|
||||
// Return byte position and the first chunk
|
||||
func splitNext(data *string, by byte) (int, string) {
|
||||
s := *data
|
||||
i := strings.IndexByte(s, by)
|
||||
var chunk string
|
||||
if i < 0 {
|
||||
chunk = s
|
||||
s = ""
|
||||
|
||||
} else {
|
||||
chunk = s[:i]
|
||||
s = s[i+1:]
|
||||
}
|
||||
|
||||
*data = s
|
||||
chunk = strings.TrimSpace(chunk)
|
||||
return i, chunk
|
||||
}
|
329
filters/filter_http.go
Normal file
329
filters/filter_http.go
Normal file
|
@ -0,0 +1,329 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/golibs/jsonutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Print to log and set HTTP error message
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("Filters: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// IsValidURL - return TRUE if URL or file path is valid
|
||||
func IsValidURL(rawurl string) bool {
|
||||
if filepath.IsAbs(rawurl) {
|
||||
// this is a file path
|
||||
return util.FileExists(rawurl)
|
||||
}
|
||||
|
||||
url, err := url.ParseRequestURI(rawurl)
|
||||
if err != nil {
|
||||
return false //Couldn't even parse the rawurl
|
||||
}
|
||||
if len(url.Scheme) == 0 {
|
||||
return false //No Scheme found
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Filtering) getFilterModule(t string) Filters {
|
||||
switch t {
|
||||
case "blocklist":
|
||||
return f.dnsBlocklist
|
||||
|
||||
case "whitelist":
|
||||
return f.dnsAllowlist
|
||||
|
||||
case "proxylist":
|
||||
return f.Proxylist
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) restartMods(t string) {
|
||||
fN := f.getFilterModule(t)
|
||||
fN.NotifyObserver(EventBeforeUpdate)
|
||||
fN.NotifyObserver(EventAfterUpdate)
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilterAdd(w http.ResponseWriter, r *http.Request) {
|
||||
type reqJSON struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
req := reqJSON{}
|
||||
_, err := jsonutil.DecodeObject(&req, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
filterN := f.getFilterModule(req.Type)
|
||||
if filterN == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
|
||||
return
|
||||
}
|
||||
|
||||
filt := Filter{
|
||||
Enabled: true,
|
||||
Name: req.Name,
|
||||
URL: req.URL,
|
||||
}
|
||||
err = filterN.Add(filt)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "add filter: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
f.conf.ConfigModified()
|
||||
|
||||
f.restartMods(req.Type)
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilterRemove(w http.ResponseWriter, r *http.Request) {
|
||||
type reqJSON struct {
|
||||
URL string `json:"url"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
req := reqJSON{}
|
||||
_, err := jsonutil.DecodeObject(&req, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
filterN := f.getFilterModule(req.Type)
|
||||
if filterN == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
|
||||
return
|
||||
}
|
||||
|
||||
removed := filterN.Delete(req.URL)
|
||||
if removed == nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "no filter with such URL")
|
||||
return
|
||||
}
|
||||
|
||||
f.conf.ConfigModified()
|
||||
|
||||
if removed.Enabled {
|
||||
f.restartMods(req.Type)
|
||||
}
|
||||
|
||||
err = os.Remove(removed.Path)
|
||||
if err != nil {
|
||||
log.Error("os.Remove: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilterModify(w http.ResponseWriter, r *http.Request) {
|
||||
type propsJSON struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
type reqJSON struct {
|
||||
URL string `json:"url"`
|
||||
Type string `json:"type"`
|
||||
Data propsJSON `json:"data"`
|
||||
}
|
||||
req := reqJSON{}
|
||||
_, err := jsonutil.DecodeObject(&req, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
filterN := f.getFilterModule(req.Type)
|
||||
if filterN == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
|
||||
return
|
||||
}
|
||||
|
||||
st, _, err := filterN.Modify(req.URL, req.Data.Enabled, req.Data.Name, req.Data.URL)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
f.conf.ConfigModified()
|
||||
|
||||
if st == StatusChangedEnabled ||
|
||||
st == StatusChangedURL {
|
||||
|
||||
// TODO StatusChangedURL: delete old file
|
||||
|
||||
f.restartMods(req.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
f.conf.UserRules = strings.Split(string(body), "\n")
|
||||
f.conf.ConfigModified()
|
||||
f.restartMods("blocklist")
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
type reqJSON struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
req := reqJSON{}
|
||||
_, err := jsonutil.DecodeObject(&req, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
filterN := f.getFilterModule(req.Type)
|
||||
if filterN == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
|
||||
return
|
||||
}
|
||||
|
||||
filterN.Refresh(0)
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
}
|
||||
|
||||
func filterToJSON(f Filter) filterJSON {
|
||||
fj := filterJSON{
|
||||
ID: int64(f.ID),
|
||||
Enabled: f.Enabled,
|
||||
URL: f.URL,
|
||||
Name: f.Name,
|
||||
RulesCount: uint32(f.RuleCount),
|
||||
}
|
||||
|
||||
if !f.LastUpdated.IsZero() {
|
||||
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
return fj
|
||||
}
|
||||
|
||||
// Get filtering configuration
|
||||
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
type respJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
|
||||
Filters []filterJSON `json:"filters"`
|
||||
WhitelistFilters []filterJSON `json:"whitelist_filters"`
|
||||
UserRules []string `json:"user_rules"`
|
||||
|
||||
Proxylist []filterJSON `json:"proxy_filters"`
|
||||
}
|
||||
resp := respJSON{}
|
||||
|
||||
resp.Enabled = f.conf.Enabled
|
||||
resp.Interval = f.conf.UpdateIntervalHours
|
||||
resp.UserRules = f.conf.UserRules
|
||||
|
||||
f0 := f.dnsBlocklist.List(0)
|
||||
f1 := f.dnsAllowlist.List(0)
|
||||
f2 := f.Proxylist.List(0)
|
||||
|
||||
for _, filt := range f0 {
|
||||
fj := filterToJSON(filt)
|
||||
resp.Filters = append(resp.Filters, fj)
|
||||
}
|
||||
for _, filt := range f1 {
|
||||
fj := filterToJSON(filt)
|
||||
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
||||
}
|
||||
for _, filt := range f2 {
|
||||
fj := filterToJSON(filt)
|
||||
resp.Proxylist = append(resp.Proxylist, fj)
|
||||
}
|
||||
|
||||
jsonVal, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jsonVal)
|
||||
}
|
||||
|
||||
// Set filtering configuration
|
||||
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
type reqJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Interval uint32 `json:"interval"`
|
||||
}
|
||||
req := reqJSON{}
|
||||
_, err := jsonutil.DecodeObject(&req, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
if !CheckFiltersUpdateIntervalHours(req.Interval) {
|
||||
httpError(r, w, http.StatusBadRequest, "Unsupported interval")
|
||||
return
|
||||
}
|
||||
|
||||
restart := false
|
||||
if f.conf.Enabled != req.Enabled {
|
||||
restart = true
|
||||
}
|
||||
f.conf.Enabled = req.Enabled
|
||||
f.conf.UpdateIntervalHours = req.Interval
|
||||
|
||||
c := Conf{}
|
||||
c.UpdateIntervalHours = req.Interval
|
||||
f.dnsBlocklist.SetConfig(c)
|
||||
f.dnsAllowlist.SetConfig(c)
|
||||
f.Proxylist.SetConfig(c)
|
||||
|
||||
f.conf.ConfigModified()
|
||||
|
||||
if restart {
|
||||
f.restartMods("blocklist")
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebHandlers - register handlers
|
||||
func (f *Filtering) registerWebHandlers() {
|
||||
f.conf.HTTPRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
|
||||
f.conf.HTTPRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
|
||||
f.conf.HTTPRegister("POST", "/control/filtering/add_url", f.handleFilterAdd)
|
||||
f.conf.HTTPRegister("POST", "/control/filtering/remove_url", f.handleFilterRemove)
|
||||
f.conf.HTTPRegister("POST", "/control/filtering/set_url", f.handleFilterModify)
|
||||
f.conf.HTTPRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
f.conf.HTTPRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
}
|
||||
|
||||
// CheckFiltersUpdateIntervalHours - verify update interval
|
||||
func CheckFiltersUpdateIntervalHours(i uint32) bool {
|
||||
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
||||
}
|
118
filters/filter_module.go
Normal file
118
filters/filter_module.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Filtering - module object
|
||||
type Filtering struct {
|
||||
dnsBlocklist Filters // DNS blocklist filters
|
||||
dnsAllowlist Filters // DNS allowlist filters
|
||||
Proxylist Filters // MITM Proxy filtering module
|
||||
|
||||
conf ModuleConf
|
||||
}
|
||||
|
||||
// ModuleConf - module config
|
||||
type ModuleConf struct {
|
||||
Enabled bool
|
||||
UpdateIntervalHours uint32 // 0: disabled
|
||||
HTTPClient *http.Client
|
||||
DataDir string
|
||||
DNSBlocklist []Filter
|
||||
DNSAllowlist []Filter
|
||||
Proxylist []Filter
|
||||
UserRules []string
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func()
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||
}
|
||||
|
||||
// NewModule - create module
|
||||
func NewModule(conf ModuleConf) *Filtering {
|
||||
f := Filtering{}
|
||||
f.conf = conf
|
||||
|
||||
fconf := Conf{}
|
||||
fconf.FilterDir = filepath.Join(conf.DataDir, "filters_dnsblock")
|
||||
fconf.List = conf.DNSBlocklist
|
||||
fconf.UpdateIntervalHours = conf.UpdateIntervalHours
|
||||
fconf.HTTPClient = conf.HTTPClient
|
||||
f.dnsBlocklist = New(fconf)
|
||||
|
||||
fconf = Conf{}
|
||||
fconf.FilterDir = filepath.Join(conf.DataDir, "filters_dnsallow")
|
||||
fconf.List = conf.DNSAllowlist
|
||||
fconf.UpdateIntervalHours = conf.UpdateIntervalHours
|
||||
fconf.HTTPClient = conf.HTTPClient
|
||||
f.dnsAllowlist = New(fconf)
|
||||
|
||||
fconf = Conf{}
|
||||
fconf.FilterDir = filepath.Join(conf.DataDir, "filters_mitmproxy")
|
||||
fconf.List = conf.Proxylist
|
||||
fconf.UpdateIntervalHours = conf.UpdateIntervalHours
|
||||
fconf.HTTPClient = conf.HTTPClient
|
||||
f.Proxylist = New(fconf)
|
||||
|
||||
return &f
|
||||
}
|
||||
|
||||
const (
|
||||
DNSBlocklist = iota
|
||||
DNSAllowlist
|
||||
Proxylist
|
||||
)
|
||||
|
||||
func stringArrayDup(a []string) []string {
|
||||
a2 := make([]string, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration data
|
||||
func (f *Filtering) WriteDiskConfig(mc *ModuleConf) {
|
||||
mc.Enabled = f.conf.Enabled
|
||||
mc.UpdateIntervalHours = f.conf.UpdateIntervalHours
|
||||
mc.UserRules = stringArrayDup(f.conf.UserRules)
|
||||
|
||||
c := Conf{}
|
||||
f.dnsBlocklist.WriteDiskConfig(&c)
|
||||
mc.DNSBlocklist = c.List
|
||||
|
||||
c = Conf{}
|
||||
f.dnsAllowlist.WriteDiskConfig(&c)
|
||||
mc.DNSAllowlist = c.List
|
||||
|
||||
c = Conf{}
|
||||
f.Proxylist.WriteDiskConfig(&c)
|
||||
mc.Proxylist = c.List
|
||||
}
|
||||
|
||||
// GetList - get specific filter list
|
||||
func (f *Filtering) GetList(t uint32) Filters {
|
||||
switch t {
|
||||
case DNSBlocklist:
|
||||
return f.dnsBlocklist
|
||||
case DNSAllowlist:
|
||||
return f.dnsAllowlist
|
||||
case Proxylist:
|
||||
return f.Proxylist
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start - start module
|
||||
func (f *Filtering) Start() {
|
||||
f.dnsBlocklist.Start()
|
||||
f.dnsAllowlist.Start()
|
||||
f.Proxylist.Start()
|
||||
f.registerWebHandlers()
|
||||
}
|
||||
|
||||
// Close - close the module
|
||||
func (f *Filtering) Close() {
|
||||
}
|
246
filters/filter_storage.go
Normal file
246
filters/filter_storage.go
Normal file
|
@ -0,0 +1,246 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// filter storage object
|
||||
type filterStg struct {
|
||||
updateTaskRunning bool
|
||||
updated []Filter // list of filters that were downloaded during update procedure
|
||||
updateChan chan bool // signal for the update goroutine
|
||||
|
||||
conf *Conf
|
||||
confLock sync.Mutex
|
||||
nextID atomic.Uint64 // next filter ID
|
||||
|
||||
observer EventHandler // user function that receives notifications
|
||||
}
|
||||
|
||||
// initialize the module
|
||||
func newFiltersObj(conf Conf) Filters {
|
||||
fs := filterStg{}
|
||||
fs.conf = &Conf{}
|
||||
*fs.conf = conf
|
||||
fs.nextID.Store(uint64(time.Now().Unix()))
|
||||
fs.updateChan = make(chan bool, 2)
|
||||
return &fs
|
||||
}
|
||||
|
||||
// Start - start module
|
||||
func (fs *filterStg) Start() {
|
||||
_ = os.MkdirAll(fs.conf.FilterDir, 0755)
|
||||
|
||||
// Load all enabled filters
|
||||
// On error, RuleCount is set to 0 - users won't try to use such filters
|
||||
// and in the future the update procedure will re-download the file
|
||||
for i := range fs.conf.List {
|
||||
f := &fs.conf.List[i]
|
||||
|
||||
fname := fs.filePath(*f)
|
||||
st, err := os.Stat(fname)
|
||||
if err != nil {
|
||||
log.Debug("Filters: os.Stat: %s %s", fname, err)
|
||||
continue
|
||||
}
|
||||
f.LastUpdated = st.ModTime()
|
||||
|
||||
if !f.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(fname, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
log.Error("Filters: os.OpenFile: %s %s", fname, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_ = parseFilter(f, file)
|
||||
file.Close()
|
||||
|
||||
f.nextUpdate = f.LastUpdated.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour)
|
||||
}
|
||||
|
||||
if !fs.updateTaskRunning {
|
||||
fs.updateTaskRunning = true
|
||||
go fs.updateBySignal()
|
||||
go fs.updateByTimer()
|
||||
}
|
||||
}
|
||||
|
||||
// Close - close the module
|
||||
func (fs *filterStg) Close() {
|
||||
fs.updateChan <- false
|
||||
close(fs.updateChan)
|
||||
}
|
||||
|
||||
// Duplicate filter array
|
||||
func arrayFilterDup(f []Filter) []Filter {
|
||||
nf := make([]Filter, len(f))
|
||||
copy(nf, f)
|
||||
return nf
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration on disk
|
||||
func (fs *filterStg) WriteDiskConfig(c *Conf) {
|
||||
fs.confLock.Lock()
|
||||
*c = *fs.conf
|
||||
c.List = arrayFilterDup(fs.conf.List)
|
||||
fs.confLock.Unlock()
|
||||
}
|
||||
|
||||
// SetConfig - set new configuration settings
|
||||
func (fs *filterStg) SetConfig(c Conf) {
|
||||
fs.conf.UpdateIntervalHours = c.UpdateIntervalHours
|
||||
}
|
||||
|
||||
// SetObserver - set user handler for notifications
|
||||
func (fs *filterStg) SetObserver(handler EventHandler) {
|
||||
fs.observer = handler
|
||||
}
|
||||
|
||||
// NotifyObserver - notify users about the event
|
||||
func (fs *filterStg) NotifyObserver(flags uint) {
|
||||
if fs.observer == nil {
|
||||
return
|
||||
}
|
||||
fs.observer(flags)
|
||||
}
|
||||
|
||||
// List (thread safe)
|
||||
func (fs *filterStg) List(flags uint) []Filter {
|
||||
fs.confLock.Lock()
|
||||
list := make([]Filter, len(fs.conf.List))
|
||||
for i, f := range fs.conf.List {
|
||||
nf := f
|
||||
nf.Path = fs.filePath(f)
|
||||
list[i] = nf
|
||||
}
|
||||
fs.confLock.Unlock()
|
||||
return list
|
||||
}
|
||||
|
||||
// Add - add filter (thread safe)
|
||||
func (fs *filterStg) Add(nf Filter) error {
|
||||
fs.confLock.Lock()
|
||||
defer fs.confLock.Unlock()
|
||||
|
||||
for _, f := range fs.conf.List {
|
||||
if f.Name == nf.Name || f.URL == nf.URL {
|
||||
return fmt.Errorf("filter with this Name or URL already exists")
|
||||
}
|
||||
}
|
||||
|
||||
nf.ID = fs.nextFilterID()
|
||||
nf.Enabled = true
|
||||
err := fs.downloadFilter(&nf)
|
||||
if err != nil {
|
||||
log.Debug("%s", err)
|
||||
return err
|
||||
}
|
||||
fs.conf.List = append(fs.conf.List, nf)
|
||||
log.Debug("Filters: added filter %s", nf.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete - remove filter (thread safe)
|
||||
func (fs *filterStg) Delete(url string) *Filter {
|
||||
fs.confLock.Lock()
|
||||
defer fs.confLock.Unlock()
|
||||
|
||||
nf := []Filter{}
|
||||
var found *Filter
|
||||
for i := range fs.conf.List {
|
||||
f := &fs.conf.List[i]
|
||||
|
||||
if f.URL == url {
|
||||
found = f
|
||||
continue
|
||||
}
|
||||
nf = append(nf, *f)
|
||||
}
|
||||
if found == nil {
|
||||
return nil
|
||||
}
|
||||
fs.conf.List = nf
|
||||
log.Debug("Filters: removed filter %s", url)
|
||||
found.Path = fs.filePath(*found) // the caller will delete the file
|
||||
return found
|
||||
}
|
||||
|
||||
// Modify - set filter properties (thread safe)
|
||||
// Return Status* bitarray
|
||||
func (fs *filterStg) Modify(url string, enabled bool, name string, newURL string) (int, Filter, error) {
|
||||
fs.confLock.Lock()
|
||||
defer fs.confLock.Unlock()
|
||||
|
||||
st := 0
|
||||
|
||||
for i := range fs.conf.List {
|
||||
f := &fs.conf.List[i]
|
||||
if f.URL == url {
|
||||
|
||||
backup := *f
|
||||
f.Name = name
|
||||
|
||||
if f.Enabled != enabled {
|
||||
f.Enabled = enabled
|
||||
st |= StatusChangedEnabled
|
||||
}
|
||||
|
||||
if f.URL != newURL {
|
||||
f.URL = newURL
|
||||
st |= StatusChangedURL
|
||||
}
|
||||
|
||||
needDownload := false
|
||||
|
||||
if (st & StatusChangedURL) != 0 {
|
||||
f.ID = fs.nextFilterID()
|
||||
needDownload = true
|
||||
|
||||
} else if (st&StatusChangedEnabled) != 0 && enabled {
|
||||
fname := fs.filePath(*f)
|
||||
file, err := os.OpenFile(fname, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
log.Debug("Filters: os.OpenFile: %s %s", fname, err)
|
||||
needDownload = true
|
||||
} else {
|
||||
_ = parseFilter(f, file)
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if needDownload {
|
||||
f.LastModified = ""
|
||||
f.RuleCount = 0
|
||||
err := fs.downloadFilter(f)
|
||||
if err != nil {
|
||||
*f = backup
|
||||
return 0, Filter{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return st, backup, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, Filter{}, fmt.Errorf("filter %s not found", url)
|
||||
}
|
||||
|
||||
// Get filter file name
|
||||
func (fs *filterStg) filePath(f Filter) string {
|
||||
return filepath.Join(fs.conf.FilterDir, fmt.Sprintf("%d.txt", f.ID))
|
||||
}
|
||||
|
||||
// Get next filter ID
|
||||
func (fs *filterStg) nextFilterID() uint64 {
|
||||
return fs.nextID.Inc()
|
||||
}
|
154
filters/filter_test.go
Normal file
154
filters/filter_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func testStartFilterListener(counter *atomic.Uint32) net.Listener {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
(*counter).Inc()
|
||||
content := `||example.org^$third-party
|
||||
# Inline comment example
|
||||
||example.com^$third-party
|
||||
0.0.0.0 example.com
|
||||
`
|
||||
_, _ = w.Write([]byte(content))
|
||||
})
|
||||
|
||||
mux.HandleFunc("/filters/2.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
(*counter).Inc()
|
||||
content := `||example.org^$third-party
|
||||
# Inline comment example
|
||||
||example.com^$third-party
|
||||
0.0.0.0 example.com
|
||||
1.1.1.1 example1.com
|
||||
`
|
||||
_, _ = w.Write([]byte(content))
|
||||
})
|
||||
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = http.Serve(listener, mux)
|
||||
}()
|
||||
return listener
|
||||
}
|
||||
|
||||
func prepareTestDir() string {
|
||||
const dir = "./agh-test"
|
||||
_ = os.RemoveAll(dir)
|
||||
_ = os.MkdirAll(dir, 0755)
|
||||
return dir
|
||||
}
|
||||
|
||||
var updateStatus atomic.Uint32
|
||||
|
||||
func onFiltersUpdate(flags uint) {
|
||||
switch flags {
|
||||
case EventBeforeUpdate:
|
||||
updateStatus.Store(updateStatus.Load() | 1)
|
||||
|
||||
case EventAfterUpdate:
|
||||
updateStatus.Store(updateStatus.Load() | 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
counter := atomic.Uint32{}
|
||||
lhttp := testStartFilterListener(&counter)
|
||||
defer func() { _ = lhttp.Close() }()
|
||||
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
fconf := Conf{}
|
||||
fconf.UpdateIntervalHours = 1
|
||||
fconf.FilterDir = dir
|
||||
fconf.HTTPClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
fs := New(fconf)
|
||||
fs.SetObserver(onFiltersUpdate)
|
||||
fs.Start()
|
||||
|
||||
port := lhttp.Addr().(*net.TCPAddr).Port
|
||||
URL := fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", port)
|
||||
|
||||
// add and download
|
||||
f := Filter{
|
||||
URL: URL,
|
||||
}
|
||||
err := fs.Add(f)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// check
|
||||
l := fs.List(0)
|
||||
assert.Equal(t, 1, len(l))
|
||||
assert.Equal(t, URL, l[0].URL)
|
||||
assert.True(t, l[0].Enabled)
|
||||
assert.Equal(t, uint64(3), l[0].RuleCount)
|
||||
assert.True(t, l[0].ID != 0)
|
||||
|
||||
// disable
|
||||
st, _, err := fs.Modify(f.URL, false, "name", f.URL)
|
||||
assert.Equal(t, StatusChangedEnabled, st)
|
||||
|
||||
// check: disabled
|
||||
l = fs.List(0)
|
||||
assert.Equal(t, 1, len(l))
|
||||
assert.True(t, !l[0].Enabled)
|
||||
|
||||
// modify URL
|
||||
newURL := fmt.Sprintf("http://127.0.0.1:%d/filters/2.txt", port)
|
||||
st, modified, err := fs.Modify(URL, false, "name", newURL)
|
||||
assert.Equal(t, StatusChangedURL, st)
|
||||
|
||||
_ = os.Remove(modified.Path)
|
||||
|
||||
// check: new ID, new URL
|
||||
l = fs.List(0)
|
||||
assert.Equal(t, 1, len(l))
|
||||
assert.Equal(t, newURL, l[0].URL)
|
||||
assert.Equal(t, uint64(4), l[0].RuleCount)
|
||||
assert.True(t, modified.ID != l[0].ID)
|
||||
|
||||
// enable
|
||||
st, _, err = fs.Modify(newURL, true, "name", newURL)
|
||||
assert.Equal(t, StatusChangedEnabled, st)
|
||||
|
||||
// update
|
||||
cnt := counter.Load()
|
||||
fs.Refresh(0)
|
||||
for i := 0; ; i++ {
|
||||
if i == 2 {
|
||||
assert.True(t, false)
|
||||
break
|
||||
}
|
||||
if cnt != counter.Load() {
|
||||
// filter was updated
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
assert.Equal(t, uint32(1|2), updateStatus.Load())
|
||||
|
||||
// delete
|
||||
removed := fs.Delete(newURL)
|
||||
assert.NotNil(t, removed)
|
||||
_ = os.Remove(removed.Path)
|
||||
|
||||
fs.Close()
|
||||
}
|
176
filters/filter_update.go
Normal file
176
filters/filter_update.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Refresh - begin filters update procedure
|
||||
func (fs *filterStg) Refresh(flags uint) {
|
||||
fs.confLock.Lock()
|
||||
defer fs.confLock.Unlock()
|
||||
|
||||
for i := range fs.conf.List {
|
||||
f := &fs.conf.List[i]
|
||||
f.nextUpdate = time.Time{}
|
||||
}
|
||||
|
||||
fs.updateChan <- true
|
||||
}
|
||||
|
||||
// Start update procedure periodically
|
||||
func (fs *filterStg) updateByTimer() {
|
||||
const maxPeriod = 1 * 60 * 60
|
||||
period := 5 // use a dynamically increasing time interval, while network or DNS is down
|
||||
for {
|
||||
if fs.conf.UpdateIntervalHours == 0 {
|
||||
period = maxPeriod
|
||||
// update is disabled
|
||||
time.Sleep(time.Duration(period) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
fs.updateChan <- true
|
||||
|
||||
time.Sleep(time.Duration(period) * time.Second)
|
||||
period += period
|
||||
if period > maxPeriod {
|
||||
period = maxPeriod
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Begin update procedure by signal
|
||||
func (fs *filterStg) updateBySignal() {
|
||||
for {
|
||||
select {
|
||||
case ok := <-fs.updateChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fs.updateAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update filters
|
||||
// Algorithm:
|
||||
// . Get next filter to update:
|
||||
// . Download data from Internet and store on disk (in a new file)
|
||||
// . Add new filter to the special list
|
||||
// . Repeat for next filter
|
||||
// (All filters are downloaded)
|
||||
// . Stop modules that use filters
|
||||
// . For each updated filter:
|
||||
// . Rename "new file name" -> "old file name"
|
||||
// . Update meta data
|
||||
// . Restart modules that use filters
|
||||
func (fs *filterStg) updateAll() {
|
||||
log.Debug("Filters: updating...")
|
||||
|
||||
for {
|
||||
var uf Filter
|
||||
fs.confLock.Lock()
|
||||
f := fs.getNextToUpdate()
|
||||
if f != nil {
|
||||
uf = *f
|
||||
}
|
||||
fs.confLock.Unlock()
|
||||
|
||||
if f == nil {
|
||||
fs.applyUpdate()
|
||||
return
|
||||
}
|
||||
|
||||
uf.ID = fs.nextFilterID()
|
||||
err := fs.downloadFilter(&uf)
|
||||
if err != nil {
|
||||
if uf.networkError {
|
||||
fs.confLock.Lock()
|
||||
f.nextUpdate = time.Now().Add(10 * time.Second)
|
||||
fs.confLock.Unlock()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// add new filter to the list
|
||||
fs.updated = append(fs.updated, uf)
|
||||
}
|
||||
}
|
||||
|
||||
// Get next filter to update
|
||||
func (fs *filterStg) getNextToUpdate() *Filter {
|
||||
now := time.Now()
|
||||
|
||||
for i := range fs.conf.List {
|
||||
f := &fs.conf.List[i]
|
||||
|
||||
if f.Enabled &&
|
||||
f.nextUpdate.Unix() <= now.Unix() {
|
||||
|
||||
f.nextUpdate = now.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour)
|
||||
return f
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Replace filter files
|
||||
func (fs *filterStg) applyUpdate() {
|
||||
if len(fs.updated) == 0 {
|
||||
log.Debug("Filters: no filters were updated")
|
||||
return
|
||||
}
|
||||
|
||||
fs.NotifyObserver(EventBeforeUpdate)
|
||||
|
||||
nUpdated := 0
|
||||
|
||||
fs.confLock.Lock()
|
||||
for _, uf := range fs.updated {
|
||||
found := false
|
||||
|
||||
for i := range fs.conf.List {
|
||||
f := &fs.conf.List[i]
|
||||
|
||||
if uf.URL == f.URL {
|
||||
found = true
|
||||
fpath := fs.filePath(*f)
|
||||
f.LastUpdated = uf.LastUpdated
|
||||
|
||||
if len(uf.Path) == 0 {
|
||||
// the data hasn't changed - just update file mod time
|
||||
err := os.Chtimes(fpath, f.LastUpdated, f.LastUpdated)
|
||||
if err != nil {
|
||||
log.Error("Filters: os.Chtimes: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
err := os.Rename(uf.Path, fpath)
|
||||
if err != nil {
|
||||
log.Error("Filters: os.Rename:%s", err)
|
||||
}
|
||||
|
||||
f.RuleCount = uf.RuleCount
|
||||
nUpdated++
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
// the updated filter was downloaded,
|
||||
// but it's already removed from the main list
|
||||
_ = os.Remove(fs.filePath(uf))
|
||||
}
|
||||
}
|
||||
fs.confLock.Unlock()
|
||||
|
||||
log.Debug("Filters: %d filters were updated", nUpdated)
|
||||
|
||||
fs.updated = nil
|
||||
fs.NotifyObserver(EventAfterUpdate)
|
||||
}
|
93
filters/filters.go
Normal file
93
filters/filters.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Filters - main interface
|
||||
type Filters interface {
|
||||
// Start - start module
|
||||
Start()
|
||||
|
||||
// Close - close the module
|
||||
Close()
|
||||
|
||||
// WriteDiskConfig - write configuration on disk
|
||||
WriteDiskConfig(c *Conf)
|
||||
|
||||
// SetConfig - set new configuration settings
|
||||
// Currently only UpdateIntervalHours is supported
|
||||
SetConfig(c Conf)
|
||||
|
||||
// SetObserver - set user handler for notifications
|
||||
SetObserver(handler EventHandler)
|
||||
|
||||
// NotifyObserver - notify users about the event
|
||||
NotifyObserver(flags uint)
|
||||
|
||||
// List (thread safe)
|
||||
List(flags uint) []Filter
|
||||
|
||||
// Add - add filter (thread safe)
|
||||
Add(nf Filter) error
|
||||
|
||||
// Delete - remove filter (thread safe)
|
||||
Delete(url string) *Filter
|
||||
|
||||
// Modify - set filter properties (thread safe)
|
||||
// Return Status* bitarray, old filter properties and error
|
||||
Modify(url string, enabled bool, name string, newURL string) (int, Filter, error)
|
||||
|
||||
// Refresh - begin filters update procedure
|
||||
Refresh(flags uint)
|
||||
}
|
||||
|
||||
// Filter - filter object
|
||||
type Filter struct {
|
||||
ID uint64 `yaml:"id"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Name string `yaml:"name"`
|
||||
URL string `yaml:"url"`
|
||||
LastModified string `yaml:"last_modified"` // value of Last-Modified HTTP header field
|
||||
|
||||
Path string `yaml:"-"`
|
||||
|
||||
// number of rules
|
||||
// 0 means the file isn't loaded - user shouldn't use this filter
|
||||
RuleCount uint64 `yaml:"-"`
|
||||
|
||||
LastUpdated time.Time `yaml:"-"` // time of the last update (= file modification time)
|
||||
nextUpdate time.Time // time of the next update
|
||||
networkError bool // network error during download
|
||||
}
|
||||
|
||||
const (
|
||||
// EventBeforeUpdate - this event is signalled before the update procedure renames/removes old filter files
|
||||
EventBeforeUpdate = iota
|
||||
// EventAfterUpdate - this event is signalled after the update procedure is finished
|
||||
EventAfterUpdate
|
||||
)
|
||||
|
||||
// EventHandler - event handler function
|
||||
type EventHandler func(flags uint)
|
||||
|
||||
const (
|
||||
// StatusChangedEnabled - changed 'Enabled'
|
||||
StatusChangedEnabled = 2
|
||||
// StatusChangedURL - changed 'URL'
|
||||
StatusChangedURL = 4
|
||||
)
|
||||
|
||||
// Conf - configuration
|
||||
type Conf struct {
|
||||
FilterDir string
|
||||
UpdateIntervalHours uint32 // 0: disabled
|
||||
HTTPClient *http.Client
|
||||
List []Filter
|
||||
}
|
||||
|
||||
// New - create object
|
||||
func New(conf Conf) Filters {
|
||||
return newFiltersObj(conf)
|
||||
}
|
2
go.mod
2
go.mod
|
@ -5,6 +5,7 @@ go 1.14
|
|||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.30.1
|
||||
github.com/AdguardTeam/golibs v0.4.2
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.0
|
||||
github.com/AdguardTeam/urlfilter v0.11.2
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
github.com/fsnotify/fsnotify v1.4.7
|
||||
|
@ -17,6 +18,7 @@ require (
|
|||
github.com/sparrc/go-ping v0.0.0-20190613174326-4e5b6552494c
|
||||
github.com/stretchr/testify v1.5.1
|
||||
go.etcd.io/bbolt v1.3.4
|
||||
go.uber.org/atomic v1.6.0
|
||||
golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
|
||||
golang.org/x/sys v0.0.0-20200331124033-c3d80250170d
|
||||
|
|
8
go.sum
8
go.sum
|
@ -3,6 +3,7 @@ github.com/AdguardTeam/dnsproxy v0.30.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPq
|
|||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
|
||||
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||
github.com/AdguardTeam/urlfilter v0.11.2 h1:gCrWGh63Yqw3z4yi9pgikfsbshIEyvAu/KYV3MvTBlc=
|
||||
github.com/AdguardTeam/urlfilter v0.11.2/go.mod h1:aMuejlNxpWppOVjiEV87X6z0eMf7wsXHTAIWQuylfZY=
|
||||
|
@ -113,6 +114,8 @@ github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljT
|
|||
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
|
||||
go.etcd.io/bbolt v1.3.4 h1:hi1bXHMVrlQh6WwxAy+qZCV/SYIlqo+Ushwdpa4tAKg=
|
||||
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
|
||||
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -120,6 +123,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8 h1:fpnn/HnJONpIu6hkXi1u/7rR0NzilgWr4T0JmWkEitk=
|
||||
golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
|
@ -145,9 +150,12 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190624180213-70d37148ca0c/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/filters"
|
||||
"github.com/AdguardTeam/AdGuardHome/mitmproxy"
|
||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
|
@ -17,8 +19,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
dataDir = "data" // data storage
|
||||
)
|
||||
|
||||
// logSettings
|
||||
|
@ -54,9 +55,13 @@ type configuration struct {
|
|||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
|
||||
Filters []filter `yaml:"filters"`
|
||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
MITM mitmproxy.Config `yaml:"mitmproxy"`
|
||||
|
||||
Filters []filters.Filter `yaml:"filters"`
|
||||
WhitelistFilters []filters.Filter `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
ProxyFilters []filters.Filter `yaml:"proxy_filters"`
|
||||
|
||||
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
|
||||
|
@ -155,7 +160,43 @@ func initConfig() {
|
|||
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.CacheTime = 30
|
||||
config.Filters = defaultFilters()
|
||||
config.Filters = defaultDNSBlocklistFilters()
|
||||
|
||||
config.ProxyFilters = defaultContentFilters()
|
||||
}
|
||||
|
||||
func defaultDNSBlocklistFilters() []filters.Filter {
|
||||
return []filters.Filter{
|
||||
{
|
||||
ID: 1,
|
||||
Enabled: true,
|
||||
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
|
||||
Name: "AdGuard Simplified Domain Names filter",
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Enabled: false,
|
||||
URL: "https://adaway.org/hosts.txt",
|
||||
Name: "AdAway",
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Enabled: false,
|
||||
URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt",
|
||||
Name: "MalwareDomainList.com Hosts List",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultContentFilters() []filters.Filter {
|
||||
return []filters.Filter{
|
||||
{
|
||||
ID: 1,
|
||||
Enabled: true,
|
||||
URL: "https://filters.adtidy.org/extension/chromium/filters/2.txt",
|
||||
Name: "AdGuard Base filter",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
|
@ -203,7 +244,7 @@ func parseConfig() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
if !filters.CheckFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
config.DNS.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
|
@ -263,6 +304,17 @@ func (c *configuration) write() error {
|
|||
config.DNS.DnsfilterConf = c
|
||||
}
|
||||
|
||||
if Context.filters != nil {
|
||||
fconf := filters.ModuleConf{}
|
||||
Context.filters.WriteDiskConfig(&fconf)
|
||||
config.DNS.FilteringEnabled = fconf.Enabled
|
||||
config.DNS.FiltersUpdateIntervalHours = fconf.UpdateIntervalHours
|
||||
config.Filters = fconf.DNSBlocklist
|
||||
config.WhitelistFilters = fconf.DNSAllowlist
|
||||
config.ProxyFilters = fconf.Proxylist
|
||||
config.UserRules = fconf.UserRules
|
||||
}
|
||||
|
||||
if Context.dnsServer != nil {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
Context.dnsServer.WriteDiskConfig(&c)
|
||||
|
@ -275,6 +327,12 @@ func (c *configuration) write() error {
|
|||
config.DHCP = c
|
||||
}
|
||||
|
||||
if Context.mitmProxy != nil {
|
||||
c := mitmproxy.Config{}
|
||||
Context.mitmProxy.WriteDiskConfig(&c)
|
||||
config.MITM = c
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ----------------
|
||||
|
@ -87,6 +88,48 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
|||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
type checkHostResp struct {
|
||||
Reason string `json:"reason"`
|
||||
FilterID int64 `json:"filter_id"`
|
||||
Rule string `json:"rule"`
|
||||
|
||||
// for FilteredBlockedService:
|
||||
SvcName string `json:"service_name"`
|
||||
|
||||
// for ReasonRewrite:
|
||||
CanonName string `json:"cname"` // CNAME value
|
||||
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
|
||||
}
|
||||
|
||||
func handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
host := q.Get("name")
|
||||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := checkHostResp{}
|
||||
resp.Reason = result.Reason.String()
|
||||
resp.FilterID = result.FilterID
|
||||
resp.Rule = result.Rule
|
||||
resp.SvcName = result.ServiceName
|
||||
resp.CanonName = result.CanonName
|
||||
resp.IPList = result.IPList
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// registration of handlers
|
||||
// ------------------------
|
||||
|
@ -96,6 +139,7 @@ func registerControlHandlers() {
|
|||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
|
||||
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
||||
httpRegister("GET", "/control/filtering/check_host", handleCheckHost)
|
||||
|
||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||
RegisterAuthHandlers()
|
||||
|
|
|
@ -1,391 +0,0 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// isValidURL - return TRUE if URL or file path is valid
|
||||
func isValidURL(rawurl string) bool {
|
||||
if filepath.IsAbs(rawurl) {
|
||||
// this is a file path
|
||||
return util.FileExists(rawurl)
|
||||
}
|
||||
|
||||
url, err := url.ParseRequestURI(rawurl)
|
||||
if err != nil {
|
||||
return false //Couldn't even parse the rawurl
|
||||
}
|
||||
if len(url.Scheme) == 0 {
|
||||
return false //No Scheme found
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type filterAddJSON struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterAddJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidURL(fj.URL) {
|
||||
http.Error(w, "Invalid URL or file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
if filterExists(fj.URL) {
|
||||
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
return
|
||||
}
|
||||
|
||||
// Set necessary properties
|
||||
filt := filter{
|
||||
Enabled: true,
|
||||
URL: fj.URL,
|
||||
Name: fj.Name,
|
||||
white: fj.Whitelist,
|
||||
}
|
||||
filt.ID = assignUniqueFilterID()
|
||||
|
||||
// Download the filter contents
|
||||
ok, err := f.update(&filt)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
|
||||
return
|
||||
}
|
||||
|
||||
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
|
||||
if !filterAdd(filt) {
|
||||
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type request struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
req := request{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// go through each element and delete if url matches
|
||||
config.Lock()
|
||||
newFilters := []filter{}
|
||||
filters := &config.Filters
|
||||
if req.Whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
}
|
||||
for _, filter := range *filters {
|
||||
if filter.URL != req.URL {
|
||||
newFilters = append(newFilters, filter)
|
||||
} else {
|
||||
err := os.Rename(filter.Path(), filter.Path()+".old")
|
||||
if err != nil {
|
||||
log.Error("os.Rename: %s: %s", filter.Path(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update the configuration after removing filter files
|
||||
*filters = newFilters
|
||||
config.Unlock()
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
|
||||
// Note: the old files "filter.txt.old" aren't deleted - it's not really necessary,
|
||||
// but will require the additional code to run after enableFilters() is finished: i.e. complicated
|
||||
}
|
||||
|
||||
type filterURLJSON struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filterURLReq struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
Data filterURLJSON `json:"data"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterURLReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidURL(fj.Data.URL) {
|
||||
http.Error(w, "invalid URL or file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
filt := filter{
|
||||
Enabled: fj.Data.Enabled,
|
||||
Name: fj.Data.Name,
|
||||
URL: fj.Data.URL,
|
||||
}
|
||||
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
if (status & statusFound) == 0 {
|
||||
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if (status & statusURLExists) != 0 {
|
||||
http.Error(w, "URL already exists", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
restart := false
|
||||
if (status & statusEnabledChanged) != 0 {
|
||||
// we must add or remove filter rules
|
||||
restart = true
|
||||
}
|
||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||
// download new filter and apply its rules
|
||||
flags := FilterRefreshBlocklists
|
||||
if fj.Whitelist {
|
||||
flags = FilterRefreshAllowlists
|
||||
}
|
||||
nUpdated, _ := f.refreshFilters(flags, true)
|
||||
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
|
||||
// if not - we restart the filtering ourselves
|
||||
restart = false
|
||||
if nUpdated == 0 {
|
||||
restart = true
|
||||
}
|
||||
}
|
||||
if restart {
|
||||
enableFilters(true)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
config.UserRules = strings.Split(string(body), "\n")
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
type Req struct {
|
||||
White bool `json:"whitelist"`
|
||||
}
|
||||
type Resp struct {
|
||||
Updated int `json:"updated"`
|
||||
}
|
||||
resp := Resp{}
|
||||
var err error
|
||||
|
||||
req := Req{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
Context.controlLock.Unlock()
|
||||
flags := FilterRefreshBlocklists
|
||||
if req.White {
|
||||
flags = FilterRefreshAllowlists
|
||||
}
|
||||
resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false)
|
||||
Context.controlLock.Lock()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
}
|
||||
|
||||
type filteringConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
Filters []filterJSON `json:"filters"`
|
||||
WhitelistFilters []filterJSON `json:"whitelist_filters"`
|
||||
UserRules []string `json:"user_rules"`
|
||||
}
|
||||
|
||||
func filterToJSON(f filter) filterJSON {
|
||||
fj := filterJSON{
|
||||
ID: f.ID,
|
||||
Enabled: f.Enabled,
|
||||
URL: f.URL,
|
||||
Name: f.Name,
|
||||
RulesCount: uint32(f.RulesCount),
|
||||
}
|
||||
|
||||
if !f.LastUpdated.IsZero() {
|
||||
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
return fj
|
||||
}
|
||||
|
||||
// Get filtering configuration
|
||||
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := filteringConfig{}
|
||||
config.RLock()
|
||||
resp.Enabled = config.DNS.FilteringEnabled
|
||||
resp.Interval = config.DNS.FiltersUpdateIntervalHours
|
||||
for _, f := range config.Filters {
|
||||
fj := filterToJSON(f)
|
||||
resp.Filters = append(resp.Filters, fj)
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
fj := filterToJSON(f)
|
||||
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
||||
}
|
||||
resp.UserRules = config.UserRules
|
||||
config.RUnlock()
|
||||
|
||||
jsonVal, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "http write: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set filtering configuration
|
||||
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := filteringConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(req.Interval) {
|
||||
httpError(w, http.StatusBadRequest, "Unsupported interval")
|
||||
return
|
||||
}
|
||||
|
||||
config.DNS.FilteringEnabled = req.Enabled
|
||||
config.DNS.FiltersUpdateIntervalHours = req.Interval
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
}
|
||||
|
||||
type checkHostResp struct {
|
||||
Reason string `json:"reason"`
|
||||
FilterID int64 `json:"filter_id"`
|
||||
Rule string `json:"rule"`
|
||||
|
||||
// for FilteredBlockedService:
|
||||
SvcName string `json:"service_name"`
|
||||
|
||||
// for ReasonRewrite:
|
||||
CanonName string `json:"cname"` // CNAME value
|
||||
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
|
||||
}
|
||||
|
||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
host := q.Get("name")
|
||||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := checkHostResp{}
|
||||
resp.Reason = result.Reason.String()
|
||||
resp.FilterID = result.FilterID
|
||||
resp.Rule = result.Rule
|
||||
resp.SvcName = result.ServiceName
|
||||
resp.CanonName = result.CanonName
|
||||
resp.IPList = result.IPList
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
// RegisterFilteringHandlers - register handlers
|
||||
func (f *Filtering) RegisterFilteringHandlers() {
|
||||
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
|
||||
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
|
||||
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
|
||||
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
||||
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
|
||||
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
|
||||
}
|
||||
|
||||
func checkFiltersUpdateIntervalHours(i uint32) bool {
|
||||
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
||||
}
|
61
home/dns.go
61
home/dns.go
|
@ -4,9 +4,11 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/filters"
|
||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
|
@ -77,8 +79,6 @@ func initDNSServer() error {
|
|||
|
||||
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
|
||||
Context.whois = initWhois(&Context.clients)
|
||||
|
||||
Context.filters.Init()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -277,6 +277,8 @@ func startDNSServer() error {
|
|||
|
||||
Context.dnsFilter.Start()
|
||||
Context.filters.Start()
|
||||
Context.filters.GetList(filters.DNSBlocklist).SetObserver(onFiltersChanged)
|
||||
Context.filters.GetList(filters.DNSAllowlist).SetObserver(onFiltersChanged)
|
||||
Context.stats.Start()
|
||||
Context.queryLog.Start()
|
||||
|
||||
|
@ -345,3 +347,58 @@ func closeDNSServer() {
|
|||
|
||||
log.Debug("Closed all DNS modules")
|
||||
}
|
||||
|
||||
func onFiltersChanged(flags uint) {
|
||||
switch flags {
|
||||
case filters.EventBeforeUpdate:
|
||||
//
|
||||
|
||||
case filters.EventAfterUpdate:
|
||||
enableFilters(true)
|
||||
}
|
||||
}
|
||||
|
||||
// Activate new DNS filters
|
||||
// async: do it asynchronously (the function returns immediately)
|
||||
func enableFilters(async bool) {
|
||||
var blockFilters []dnsfilter.Filter
|
||||
var allowFilters []dnsfilter.Filter
|
||||
if config.DNS.FilteringEnabled {
|
||||
// convert array of filters
|
||||
|
||||
// add user filter
|
||||
userFilter := dnsfilter.Filter{
|
||||
ID: 0,
|
||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
||||
}
|
||||
blockFilters = append(blockFilters, userFilter)
|
||||
|
||||
// add blocklist filters
|
||||
list := Context.filters.GetList(filters.DNSBlocklist).List(0)
|
||||
for _, f := range list {
|
||||
if !f.Enabled || f.RuleCount == 0 {
|
||||
continue
|
||||
}
|
||||
f := dnsfilter.Filter{
|
||||
ID: int64(f.ID),
|
||||
FilePath: f.Path,
|
||||
}
|
||||
blockFilters = append(blockFilters, f)
|
||||
}
|
||||
|
||||
// add allowlist filters
|
||||
list = Context.filters.GetList(filters.DNSAllowlist).List(0)
|
||||
for _, f := range list {
|
||||
if !f.Enabled || f.RuleCount == 0 {
|
||||
continue
|
||||
}
|
||||
f := dnsfilter.Filter{
|
||||
ID: int64(f.ID),
|
||||
FilePath: f.Path,
|
||||
}
|
||||
allowFilters = append(allowFilters, f)
|
||||
}
|
||||
}
|
||||
|
||||
_ = Context.dnsFilter.SetFilters(blockFilters, allowFilters, async)
|
||||
}
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testStartFilterListener() net.Listener {
|
||||
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
content := `||example.org^$third-party
|
||||
# Inline comment example
|
||||
||example.com^$third-party
|
||||
0.0.0.0 example.com
|
||||
`
|
||||
_, _ = w.Write([]byte(content))
|
||||
})
|
||||
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() { _ = http.Serve(listener, nil) }()
|
||||
return listener
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
l := testStartFilterListener()
|
||||
defer func() { _ = l.Close() }()
|
||||
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
Context = homeContext{}
|
||||
Context.workDir = dir
|
||||
Context.client = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
Context.filters.Init()
|
||||
|
||||
f := filter{
|
||||
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
|
||||
}
|
||||
|
||||
// download
|
||||
ok, err := Context.filters.update(&f)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 3, f.RulesCount)
|
||||
|
||||
// refresh
|
||||
ok, err = Context.filters.update(&f)
|
||||
assert.True(t, !ok && err == nil)
|
||||
|
||||
err = Context.filters.load(&f)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
f.unload()
|
||||
_ = os.Remove(f.Path())
|
||||
}
|
41
home/home.go
41
home/home.go
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/filters"
|
||||
"github.com/AdguardTeam/AdGuardHome/update"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
|
||||
|
@ -30,6 +31,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/mitmproxy"
|
||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -62,12 +64,14 @@ type homeContext struct {
|
|||
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
filters *filters.Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
|
||||
updater *update.Updater
|
||||
|
||||
mitmProxy *mitmproxy.MITMProxy // MITM proxy module
|
||||
|
||||
// Runtime properties
|
||||
// --
|
||||
|
||||
|
@ -279,6 +283,28 @@ func run(args options) {
|
|||
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
|
||||
}
|
||||
|
||||
fconf := filters.ModuleConf{}
|
||||
fconf.Enabled = config.DNS.FilteringEnabled
|
||||
fconf.UpdateIntervalHours = config.DNS.FiltersUpdateIntervalHours
|
||||
fconf.DataDir = Context.getDataDir()
|
||||
fconf.DNSBlocklist = config.Filters
|
||||
fconf.DNSAllowlist = config.WhitelistFilters
|
||||
fconf.UserRules = config.UserRules
|
||||
fconf.Proxylist = config.ProxyFilters
|
||||
fconf.HTTPClient = Context.client
|
||||
fconf.ConfigModified = onConfigModified
|
||||
fconf.HTTPRegister = httpRegister
|
||||
Context.filters = filters.NewModule(fconf)
|
||||
|
||||
config.MITM.CertDir = Context.getDataDir()
|
||||
config.MITM.ConfigModified = onConfigModified
|
||||
config.MITM.HTTPRegister = httpRegister
|
||||
config.MITM.Filter = Context.filters.GetList(filters.Proxylist)
|
||||
Context.mitmProxy = mitmproxy.New(config.MITM)
|
||||
if Context.mitmProxy == nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
GLMode = args.glinetMode
|
||||
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
||||
|
@ -317,6 +343,13 @@ func run(args options) {
|
|||
}
|
||||
}()
|
||||
|
||||
if Context.mitmProxy != nil {
|
||||
err = Context.mitmProxy.Start()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = startDHCPServer()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -501,10 +534,16 @@ func cleanup() {
|
|||
Context.auth = nil
|
||||
}
|
||||
|
||||
if Context.mitmProxy != nil {
|
||||
Context.mitmProxy.Close()
|
||||
Context.mitmProxy = nil
|
||||
}
|
||||
|
||||
err := stopDNSServer()
|
||||
if err != nil {
|
||||
log.Error("Couldn't stop DNS server: %s", err)
|
||||
}
|
||||
|
||||
err = stopDHCPServer()
|
||||
if err != nil {
|
||||
log.Error("Couldn't stop DHCP server: %s", err)
|
||||
|
|
|
@ -170,7 +170,7 @@ func TestHome(t *testing.T) {
|
|||
assert.True(t, haveIP)
|
||||
|
||||
for i := 1; ; i++ {
|
||||
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
|
||||
st, err := os.Stat(filepath.Join(dir, "data", "filters_dnsblock", "1.txt"))
|
||||
if err == nil && st.Size() != 0 {
|
||||
break
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package home
|
|||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/filters"
|
||||
)
|
||||
|
||||
func TestUpgrade1to2(t *testing.T) {
|
||||
|
@ -148,13 +150,13 @@ func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{})
|
|||
if v != (*oldConfig)[k].(bool) {
|
||||
t.Fatalf("wrong boolean value for %s", k)
|
||||
}
|
||||
case []filter:
|
||||
if len((*oldConfig)[k].([]filter)) != len(value) {
|
||||
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
|
||||
case []filters.Filter:
|
||||
if len((*oldConfig)[k].([]filters.Filter)) != len(value) {
|
||||
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filters.Filter)), len(value))
|
||||
}
|
||||
for i, newFilter := range value {
|
||||
oldFilter := (*oldConfig)[k].([]filter)[i]
|
||||
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
|
||||
oldFilter := (*oldConfig)[k].([]filters.Filter)[i]
|
||||
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RuleCount != newFilter.RuleCount {
|
||||
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
|
||||
}
|
||||
}
|
||||
|
@ -179,16 +181,16 @@ func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVers
|
|||
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
|
||||
diskConfig = make(map[string]interface{})
|
||||
diskConfig["language"] = "en"
|
||||
diskConfig["filters"] = []filter{
|
||||
diskConfig["filters"] = []filters.Filter{
|
||||
{
|
||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||
Name: "Latvian filter",
|
||||
RulesCount: 100,
|
||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||
Name: "Latvian filter",
|
||||
RuleCount: 100,
|
||||
},
|
||||
{
|
||||
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
|
||||
Name: "Germany filter",
|
||||
RulesCount: 200,
|
||||
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
|
||||
Name: "Germany filter",
|
||||
RuleCount: 200,
|
||||
},
|
||||
}
|
||||
diskConfig["user_rules"] = []string{}
|
||||
|
|
99
mitmproxy/mitm_http.go
Normal file
99
mitmproxy/mitm_http.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package mitmproxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/golibs/jsonutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Print to log and set HTTP error message
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("MITM: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
type mitmConfigJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ListenAddr string `json:"listen_address"`
|
||||
ListenPort int `json:"listen_port"`
|
||||
|
||||
UserName string `json:"auth_username"`
|
||||
Password string `json:"auth_password"`
|
||||
|
||||
CertData string `json:"cert_data"`
|
||||
PKeyData string `json:"pkey_data"`
|
||||
}
|
||||
|
||||
func (p *MITMProxy) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := mitmConfigJSON{}
|
||||
p.confLock.Lock()
|
||||
resp.Enabled = p.conf.Enabled
|
||||
host, port, _ := net.SplitHostPort(p.conf.ListenAddr)
|
||||
resp.ListenAddr = host
|
||||
resp.ListenPort, _ = strconv.Atoi(port)
|
||||
resp.UserName = p.conf.UserName
|
||||
resp.Password = p.conf.Password
|
||||
p.confLock.Unlock()
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
func (p *MITMProxy) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := mitmConfigJSON{}
|
||||
_, err := jsonutil.DecodeObject(&req, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !((len(req.CertData) != 0 && len(req.PKeyData) != 0) ||
|
||||
(len(req.CertData) == 0 && len(req.PKeyData) == 0)) {
|
||||
httpError(r, w, http.StatusBadRequest, "certificate & private key must be both empty or specified")
|
||||
return
|
||||
}
|
||||
|
||||
p.confLock.Lock()
|
||||
if len(req.CertData) != 0 {
|
||||
err = p.storeCert([]byte(req.CertData), []byte(req.PKeyData))
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "%s", err)
|
||||
p.confLock.Unlock()
|
||||
return
|
||||
}
|
||||
p.conf.RegenCert = false
|
||||
} else {
|
||||
p.conf.RegenCert = true
|
||||
}
|
||||
p.conf.Enabled = req.Enabled
|
||||
p.conf.ListenAddr = net.JoinHostPort(req.ListenAddr, strconv.Itoa(req.ListenPort))
|
||||
p.conf.UserName = req.UserName
|
||||
p.conf.Password = req.Password
|
||||
p.confLock.Unlock()
|
||||
|
||||
p.conf.ConfigModified()
|
||||
|
||||
p.Close()
|
||||
err = p.Restart()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize web handlers
|
||||
func (p *MITMProxy) initWeb() {
|
||||
p.conf.HTTPRegister("GET", "/control/proxy_info", p.handleGetConfig)
|
||||
p.conf.HTTPRegister("POST", "/control/proxy_config", p.handleSetConfig)
|
||||
}
|
57
mitmproxy/mitm_test.go
Normal file
57
mitmproxy/mitm_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package mitmproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/filters"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func prepareTestDir() string {
|
||||
const dir = "./agh-test"
|
||||
_ = os.RemoveAll(dir)
|
||||
_ = os.MkdirAll(dir, 0755)
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestMITM(t *testing.T) {
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
fconf := filters.Conf{}
|
||||
fconf.FilterDir = dir
|
||||
fconf.HTTPClient = http.DefaultClient
|
||||
filters := filters.New(fconf)
|
||||
|
||||
conf := Config{}
|
||||
conf.Enabled = true
|
||||
conf.CertDir = dir
|
||||
conf.RegenCert = true
|
||||
conf.ListenAddr = "127.0.0.1:8081"
|
||||
conf.Filter = filters
|
||||
s := New(conf)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
err := s.Start()
|
||||
assert.Nil(t, err)
|
||||
|
||||
proxyURL, _ := url.Parse("http://127.0.0.1:8081")
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
c := http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := c.Get("http://example.com/")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
resp, err = c.Get("http://adguardhome.api/cert.crt")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
s.Close()
|
||||
}
|
279
mitmproxy/mitmproxy.go
Normal file
279
mitmproxy/mitmproxy.go
Normal file
|
@ -0,0 +1,279 @@
|
|||
package mitmproxy
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/filters"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/gomitmproxy/mitm"
|
||||
"github.com/AdguardTeam/urlfilter/proxy"
|
||||
)
|
||||
|
||||
// MITMProxy - MITM proxy structure
|
||||
type MITMProxy struct {
|
||||
proxy *proxy.Server
|
||||
conf Config
|
||||
confLock sync.Mutex
|
||||
}
|
||||
|
||||
// Config - module configuration
|
||||
type Config struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
ListenAddr string `yaml:"listen_address"`
|
||||
|
||||
UserName string `yaml:"auth_username"`
|
||||
Password string `yaml:"auth_password"`
|
||||
|
||||
// TLS:
|
||||
RegenCert bool `yaml:"regenerate_cert"` // Regenerate certificate on cert loading failure
|
||||
CertDir string `yaml:"-"` // Directory where Root certificate & pkey is stored
|
||||
certFileName string
|
||||
pkeyFileName string
|
||||
certData []byte
|
||||
pkeyData []byte
|
||||
|
||||
Filter filters.Filters `yaml:"-"`
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
}
|
||||
|
||||
// New - create a new instance of the query log
|
||||
func New(conf Config) *MITMProxy {
|
||||
p := MITMProxy{}
|
||||
|
||||
p.conf = conf
|
||||
p.conf.certFileName = filepath.Join(p.conf.CertDir, "/http_proxy.crt")
|
||||
p.conf.pkeyFileName = filepath.Join(p.conf.CertDir, "/http_proxy.key")
|
||||
|
||||
err := p.create()
|
||||
if err != nil {
|
||||
log.Error("MITM: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.conf.HTTPRegister != nil {
|
||||
p.initWeb()
|
||||
}
|
||||
|
||||
p.conf.Filter.SetObserver(p.onFiltersChanged)
|
||||
|
||||
return &p
|
||||
}
|
||||
|
||||
// Close - close the object
|
||||
func (p *MITMProxy) Close() {
|
||||
if p.proxy != nil {
|
||||
p.proxy.Close()
|
||||
p.proxy = nil
|
||||
log.Debug("MITM: Closed proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration on disk
|
||||
func (p *MITMProxy) WriteDiskConfig(c *Config) {
|
||||
p.confLock.Lock()
|
||||
*c = p.conf
|
||||
p.confLock.Unlock()
|
||||
}
|
||||
|
||||
// Start - start proxy server
|
||||
func (p *MITMProxy) Start() error {
|
||||
if !p.conf.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.proxy.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("MITM: Running...")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restart - restart proxy server after Close()
|
||||
func (p *MITMProxy) Restart() error {
|
||||
err := p.create()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.Start()
|
||||
}
|
||||
|
||||
// Create a gomitmproxy object
|
||||
func (p *MITMProxy) create() error {
|
||||
if !p.conf.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := proxy.Config{}
|
||||
c.ProxyConfig.APIHost = "adguardhome.api"
|
||||
addr, port, err := net.SplitHostPort(p.conf.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("net.SplitHostPort: %s", err)
|
||||
}
|
||||
|
||||
c.CompressContentScript = true
|
||||
c.ProxyConfig.ListenAddr = &net.TCPAddr{}
|
||||
c.ProxyConfig.ListenAddr.IP = net.ParseIP(addr)
|
||||
if c.ProxyConfig.ListenAddr.IP == nil {
|
||||
return fmt.Errorf("invalid IP: %s", addr)
|
||||
}
|
||||
c.ProxyConfig.ListenAddr.Port, err = strconv.Atoi(port)
|
||||
if c.ProxyConfig.ListenAddr.Port < 0 || c.ProxyConfig.ListenAddr.Port > 0xffff || err != nil {
|
||||
return fmt.Errorf("invalid port number: %s", port)
|
||||
}
|
||||
|
||||
c.ProxyConfig.Username = p.conf.UserName
|
||||
c.ProxyConfig.Password = p.conf.Password
|
||||
|
||||
err = p.loadCert()
|
||||
if err != nil {
|
||||
if !p.conf.RegenCert {
|
||||
return err
|
||||
}
|
||||
log.Debug("%s", err)
|
||||
|
||||
// certificate or private key file doesn't exist - generate new
|
||||
err = p.createRootCert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.ProxyConfig.MITMConfig, err = p.prepareMITMConfig()
|
||||
if err != nil {
|
||||
if !p.conf.RegenCert {
|
||||
return err
|
||||
}
|
||||
|
||||
// certificate or private key is invalid - generate new
|
||||
err = p.createRootCert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.ProxyConfig.MITMConfig, err = p.prepareMITMConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.FiltersPaths = make(map[int]string)
|
||||
filtrs := p.conf.Filter.List(0)
|
||||
i := 0
|
||||
for _, f := range filtrs {
|
||||
if !f.Enabled ||
|
||||
f.RuleCount == 0 { // not loaded
|
||||
continue
|
||||
}
|
||||
|
||||
c.FiltersPaths[i] = f.Path
|
||||
i++
|
||||
}
|
||||
|
||||
p.proxy, err = proxy.NewServer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("proxy.NewServer: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load cert and pkey from file
|
||||
func (p *MITMProxy) loadCert() error {
|
||||
var err error
|
||||
p.conf.certData, err = ioutil.ReadFile(p.conf.certFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.conf.pkeyData, err = ioutil.ReadFile(p.conf.pkeyFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create Root certificate and pkey and store it on disk
|
||||
func (p *MITMProxy) createRootCert() error {
|
||||
cert, key, err := mitm.NewAuthority("AdGuardHome Root", "AdGuard", 365*24*time.Hour)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.conf.certData = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
p.conf.pkeyData = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
log.Debug("MITM: Created root certificate and key")
|
||||
|
||||
err = p.storeCert(p.conf.certData, p.conf.pkeyData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store cert & pkey on disk
|
||||
func (p *MITMProxy) storeCert(certData []byte, pkeyData []byte) error {
|
||||
err := file.SafeWrite(p.conf.certFileName, certData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = file.SafeWrite(p.conf.pkeyFileName, pkeyData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("MITM: stored root certificate and key: %s, %s", p.conf.certFileName, p.conf.pkeyFileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fill TLSConfig & MITMConfig objects
|
||||
func (p *MITMProxy) prepareMITMConfig() (*mitm.Config, error) {
|
||||
tlsCert, err := tls.X509KeyPair(p.conf.certData, p.conf.pkeyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load root CA: %v", err)
|
||||
}
|
||||
privateKey := tlsCert.PrivateKey.(*rsa.PrivateKey)
|
||||
|
||||
x509c, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid certificate: %v", err)
|
||||
}
|
||||
|
||||
mitmConfig, err := mitm.NewConfig(x509c, privateKey, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create MITM config: %v", err)
|
||||
}
|
||||
|
||||
mitmConfig.SetValidity(time.Hour * 24 * 7) // generate certs valid for 7 days
|
||||
mitmConfig.SetOrganization("AdGuard") // cert organization
|
||||
return mitmConfig, nil
|
||||
}
|
||||
|
||||
func (p *MITMProxy) onFiltersChanged(flags uint) {
|
||||
switch flags {
|
||||
case filters.EventBeforeUpdate:
|
||||
p.Close()
|
||||
|
||||
case filters.EventAfterUpdate:
|
||||
err := p.Restart()
|
||||
if err != nil {
|
||||
log.Error("MITM: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue