mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-22 04:55:33 +03:00
Pull request: 2305 limit message size
Merge in DNS/adguard-home from 2305-limit-message-size to master Closes #2305. Squashed commit of the following: commit 6edd1e0521277a680f0053308efcf3d9cacc8e62 Author: Eugene Burkov <e.burkov@adguard.com> Date: Mon Nov 23 14:03:36 2020 +0300 aghio: fix final inaccuracies commit 4dd382aaf25132b31eb269749a2cd36daf0cb792 Author: Eugene Burkov <e.burkov@adguard.com> Date: Mon Nov 23 13:59:10 2020 +0300 all: improve code quality commit 060f923f6023d0e6f26441559b7023d5e5f96843 Author: Eugene Burkov <e.burkov@adguard.com> Date: Mon Nov 23 13:10:57 2020 +0300 aghio: add validation to constructor commit f57a2f596f5dc578548241c315c68dce7fc93905 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 19:19:26 2020 +0300 all: fix minor inaccuracies commit 93462c71725d3d00655a4bd565b77e64451fff60 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 19:13:23 2020 +0300 home: make test name follow convention commit 4922986ad84481b054479c43b4133a1b97bee86b Merge: 1f5472abc046ec13fd
Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 19:09:01 2020 +0300 Merge branch 'master' into 2305-limit-message-size commit 1f5472abcfa7427f389825fc59eb4253e1e2bfb7 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 19:08:21 2020 +0300 aghio: improve readability commit 60dc706b093fa22bbf62f13b2341934364ddc4df Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 18:44:08 2020 +0300 home: cover middleware with test commit bedf436b947ca1fa4493af2fc94f1f40beec7c35 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 17:10:23 2020 +0300 aghio: improved error informativeness commit 682c5da9f21fa330fb3536bb1c112129c91b9990 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Nov 20 13:37:51 2020 +0300 all: limit readers for ReadAll dealing with miscellanious data. commit 78c6dd8d90a0a43fe6ee3f9ed4d5fc637b15ba74 Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Nov 19 20:07:43 2020 +0300 all: handle ReadAll calls dealing with request's bodies. commit bfe1a6faf6468eb44515e2b0ecffa8c51f90b7e8 Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Nov 19 17:25:34 2020 +0300 home: add middlewares commit bbd1d491b318e6ba07f8af23ad546183383783a8 Merge: 7b77c2cad62a8fe0b7
Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Nov 19 16:44:04 2020 +0300 Merge branch 'master' into 2305-limit-message-size commit 7b77c2cad03154177392460982e1d73ee2a30177 Author: Eugene Burkov <e.burkov@adguard.com> Date: Tue Nov 17 15:33:33 2020 +0300 aghio: create package
This commit is contained in:
parent
046ec13fdc
commit
c129361e55
15 changed files with 413 additions and 64 deletions
13
CHANGELOG.md
13
CHANGELOG.md
|
@ -9,6 +9,19 @@ and this project adheres to
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- HTTP API request body size limit [#2305].
|
||||
|
||||
[#2305]: https://github.com/AdguardTeam/AdGuardHome/issues/2305
|
||||
|
||||
### Changed
|
||||
|
||||
- Various internal improvements ([#2271], [#2297]).
|
||||
|
||||
[#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271
|
||||
[#2297]: https://github.com/AdguardTeam/AdGuardHome/issues/2297
|
||||
|
||||
|
||||
|
||||
## [v0.104.3] - 2020-11-19
|
||||
|
|
46
HACKING.md
46
HACKING.md
|
@ -1,4 +1,4 @@
|
|||
# AdGuardHome Developer Guidelines
|
||||
# *AdGuardHome* Developer Guidelines
|
||||
|
||||
As of **2020-11-20**, this document is still a work-in-progress. Some of the
|
||||
rules aren't enforced, and others might change. Still, this is a good place to
|
||||
|
@ -6,7 +6,11 @@ find out about how we **want** our code to look like.
|
|||
|
||||
The rules are mostly sorted in the alphabetical order.
|
||||
|
||||
## Git
|
||||
## *Git*
|
||||
|
||||
* Call your branches either `NNNN-fix-foo` (where `NNNN` is the ID of the
|
||||
*GitHub* issue you worked on in this branch) or just `fix-foo` if there was
|
||||
no *GitHub* issue.
|
||||
|
||||
* Follow the commit message header format:
|
||||
|
||||
|
@ -22,9 +26,10 @@ The rules are mostly sorted in the alphabetical order.
|
|||
* Only use lowercase letters in your commit message headers. The rest of the
|
||||
message should follow the plain text conventions below.
|
||||
|
||||
The only exception are direct mentions of identifiers from the source code.
|
||||
The only exceptions are direct mentions of identifiers from the source code
|
||||
and filenames like `HACKING.md`.
|
||||
|
||||
## Go
|
||||
## *Go*
|
||||
|
||||
* <https://github.com/golang/go/wiki/CodeReviewComments>.
|
||||
|
||||
|
@ -32,6 +37,9 @@ The rules are mostly sorted in the alphabetical order.
|
|||
|
||||
* <https://go-proverbs.github.io/>
|
||||
|
||||
* Add an empty line before `break`, `continue`, and `return`, unless it's the
|
||||
only statement in that block.
|
||||
|
||||
* Avoid `init` and use explicit initialization functions instead.
|
||||
|
||||
* Avoid `new`, especially with structs.
|
||||
|
@ -53,6 +61,18 @@ The rules are mostly sorted in the alphabetical order.
|
|||
* Eschew external dependencies, including transitive, unless
|
||||
absolutely necessary.
|
||||
|
||||
* Name benchmarks and tests using the same convention as examples. For
|
||||
example:
|
||||
|
||||
```go
|
||||
func TestFunction(t *testing.T) { /* … */ }
|
||||
func TestFunction_suffix(t *testing.T) { /* … */ }
|
||||
func TestType_Method(t *testing.T) { /* … */ }
|
||||
func TestType_Method_suffix(t *testing.T) { /* … */ }
|
||||
```
|
||||
|
||||
* Name the deferred errors (e.g. when closing something) `cerr`.
|
||||
|
||||
* No `goto`.
|
||||
|
||||
* No shadowing, since it can often lead to subtle bugs, especially with
|
||||
|
@ -103,9 +123,9 @@ The rules are mostly sorted in the alphabetical order.
|
|||
[constant errors]: https://dave.cheney.net/2016/04/07/constant-errors
|
||||
[Linus said]: https://www.kernel.org/doc/html/v4.17/process/coding-style.html#indentation
|
||||
|
||||
## Markdown
|
||||
## *Markdown*
|
||||
|
||||
* **TODO(a.garipov):** Define our Markdown conventions.
|
||||
* **TODO(a.garipov):** Define our *Markdown* conventions.
|
||||
|
||||
## Text, Including Comments
|
||||
|
||||
|
@ -128,7 +148,7 @@ The rules are mostly sorted in the alphabetical order.
|
|||
|
||||
* Use double spacing between sentences to make sentence borders more clear.
|
||||
|
||||
* Use the serial comma (a.k.a. Oxford comma) to improve comprehension,
|
||||
* Use the serial comma (a.k.a. *Oxford* comma) to improve comprehension,
|
||||
decrease ambiguity, and use a common standard.
|
||||
|
||||
* Write todos like this:
|
||||
|
@ -143,16 +163,16 @@ The rules are mostly sorted in the alphabetical order.
|
|||
// TODO(usr1, usr2): Fix the frobulation issue.
|
||||
```
|
||||
|
||||
## YAML
|
||||
## *YAML*
|
||||
|
||||
* **TODO(a.garipov):** Define naming conventions for schema names in our
|
||||
OpenAPI YAML file. And just generally OpenAPI conventions.
|
||||
*OpenAPI* *YAML* file. And just generally OpenAPI conventions.
|
||||
|
||||
* **TODO(a.garipov):** Find a YAML formatter or write our own.
|
||||
* **TODO(a.garipov):** Find a *YAML* formatter or write our own.
|
||||
|
||||
* All strings, including keys, must be quoted. Reason: the [NO-rway Law].
|
||||
* All strings, including keys, must be quoted. Reason: the [*NO-rway Law*].
|
||||
|
||||
* Indent with two (**2**) spaces. YAML documents can get pretty
|
||||
* Indent with two (**2**) spaces. *YAML* documents can get pretty
|
||||
deeply-nested.
|
||||
|
||||
* No extra indentation in multiline arrays:
|
||||
|
@ -170,4 +190,4 @@ The rules are mostly sorted in the alphabetical order.
|
|||
|
||||
* Use `>` for multiline strings, unless you need to keep the line breaks.
|
||||
|
||||
[NO-rway Law]: https://news.ycombinator.com/item?id=17359376
|
||||
[*NO-rway Law*]: https://news.ycombinator.com/item?id=17359376
|
||||
|
|
59
internal/aghio/limitedreadcloser.go
Normal file
59
internal/aghio/limitedreadcloser.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
// Package aghio contains extensions for io package's types and methods
|
||||
package aghio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// LimitReachedError records the limit and the operation that caused it.
|
||||
type LimitReachedError struct {
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// Error implements error interface for LimitReachedError.
|
||||
// TODO(a.garipov): Think about error string format.
|
||||
func (lre *LimitReachedError) Error() string {
|
||||
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
|
||||
}
|
||||
|
||||
// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
|
||||
// dealing with agherr package.
|
||||
type limitedReadCloser struct {
|
||||
limit int64
|
||||
n int64
|
||||
rc io.ReadCloser
|
||||
}
|
||||
|
||||
// Read implements Reader interface.
|
||||
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
|
||||
if lrc.n == 0 {
|
||||
return 0, &LimitReachedError{
|
||||
Limit: lrc.limit,
|
||||
}
|
||||
}
|
||||
if int64(len(p)) > lrc.n {
|
||||
p = p[0:lrc.n]
|
||||
}
|
||||
n, err = lrc.rc.Read(p)
|
||||
lrc.n -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close implements Closer interface.
|
||||
func (lrc *limitedReadCloser) Close() error {
|
||||
return lrc.rc.Close()
|
||||
}
|
||||
|
||||
// LimitReadCloser wraps ReadCloser to make it's Reader stop with
|
||||
// ErrLimitReached after n bytes read.
|
||||
func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) {
|
||||
if n < 0 {
|
||||
return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n)
|
||||
}
|
||||
return &limitedReadCloser{
|
||||
limit: n,
|
||||
n: n,
|
||||
rc: rc,
|
||||
}, nil
|
||||
}
|
108
internal/aghio/limitedreadcloser_test.go
Normal file
108
internal/aghio/limitedreadcloser_test.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package aghio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLimitReadCloser(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
n int64
|
||||
want error
|
||||
}{{
|
||||
name: "positive",
|
||||
n: 1,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "zero",
|
||||
n: 0,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "negative",
|
||||
n: -1,
|
||||
want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := LimitReadCloser(nil, tc.n)
|
||||
assert.Equal(t, tc.want, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReadCloser_Read(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
limit int64
|
||||
rStr string
|
||||
want int
|
||||
err error
|
||||
}{{
|
||||
name: "perfectly_match",
|
||||
limit: 3,
|
||||
rStr: "abc",
|
||||
want: 3,
|
||||
err: nil,
|
||||
}, {
|
||||
name: "eof",
|
||||
limit: 3,
|
||||
rStr: "",
|
||||
want: 0,
|
||||
err: io.EOF,
|
||||
}, {
|
||||
name: "limit_reached",
|
||||
limit: 0,
|
||||
rStr: "abc",
|
||||
want: 0,
|
||||
err: &LimitReachedError{
|
||||
Limit: 0,
|
||||
},
|
||||
}, {
|
||||
name: "truncated",
|
||||
limit: 2,
|
||||
rStr: "abc",
|
||||
want: 2,
|
||||
err: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
readCloser := ioutil.NopCloser(strings.NewReader(tc.rStr))
|
||||
buf := make([]byte, tc.limit+1)
|
||||
|
||||
lreader, err := LimitReadCloser(readCloser, tc.limit)
|
||||
assert.Nil(t, err)
|
||||
|
||||
n, err := lreader.Read(buf)
|
||||
assert.Equal(t, n, tc.want)
|
||||
assert.Equal(t, tc.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReadCloser_LimitReachedError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
err error
|
||||
}{{
|
||||
name: "simplest",
|
||||
want: "attempted to read more than 0 bytes",
|
||||
err: &LimitReachedError{
|
||||
Limit: 0,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, tc.err.Error())
|
||||
})
|
||||
}
|
||||
}
|
|
@ -299,6 +299,7 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
|||
// . Check if a static IP is configured for the network interface
|
||||
// Respond with results
|
||||
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
|
@ -18,8 +19,10 @@ var GLMode bool
|
|||
|
||||
var glFilePrefix = "/tmp/gl_token_"
|
||||
|
||||
const glTokenTimeoutSeconds = 3600
|
||||
const glCookieName = "Admin-Token"
|
||||
const (
|
||||
glTokenTimeoutSeconds = 3600
|
||||
glCookieName = "Admin-Token"
|
||||
)
|
||||
|
||||
func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool {
|
||||
if !GLMode {
|
||||
|
@ -71,14 +74,28 @@ func archIsLittleEndian() bool {
|
|||
return (b == 0x04)
|
||||
}
|
||||
|
||||
// MaxFileSize is a maximum file length in bytes.
|
||||
const MaxFileSize = 1024 * 1024
|
||||
|
||||
func glGetTokenDate(file string) uint32 {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Error("os.Open: %s", err)
|
||||
return 0
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize)
|
||||
if err != nil {
|
||||
log.Error("LimitReadCloser: %s", err)
|
||||
return 0
|
||||
}
|
||||
defer fileReadCloser.Close()
|
||||
|
||||
var dateToken uint32
|
||||
bs, err := ioutil.ReadAll(f)
|
||||
|
||||
// This use of ReadAll is now safe, because we limited reader.
|
||||
bs, err := ioutil.ReadAll(fileReadCloser)
|
||||
if err != nil {
|
||||
log.Error("ioutil.ReadAll: %s", err)
|
||||
return 0
|
||||
|
|
|
@ -3,7 +3,6 @@ package home
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
|
@ -150,16 +149,11 @@ func clientHostToJSON(ip string, ch ClientHost) clientJSON {
|
|||
|
||||
// Add a new client
|
||||
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
cj := clientJSON{}
|
||||
err = json.Unmarshal(body, &cj)
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -183,16 +177,17 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
|||
|
||||
// Remove client
|
||||
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cj := clientJSON{}
|
||||
err = json.Unmarshal(body, &cj)
|
||||
if err != nil || len(cj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
if len(cj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "client's name must be non-empty")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -211,18 +206,14 @@ type updateJSON struct {
|
|||
|
||||
// Update client's properties
|
||||
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
dj := updateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&dj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var dj updateJSON
|
||||
err = json.Unmarshal(body, &dj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
if len(dj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "Invalid request")
|
||||
return
|
||||
|
|
|
@ -214,6 +214,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
|
|
|
@ -66,6 +66,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
|
|
59
internal/home/middlewares.go
Normal file
59
internal/home/middlewares.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// middlerware is a wrapper function signature.
|
||||
type middleware func(http.Handler) http.Handler
|
||||
|
||||
// withMiddlewares consequently wraps h with all the middlewares.
|
||||
func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Handler) {
|
||||
wrapped = h
|
||||
|
||||
for _, mw := range middlewares {
|
||||
wrapped = mw(wrapped)
|
||||
}
|
||||
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// RequestBodySizeLimit is maximum request body length in bytes.
|
||||
const RequestBodySizeLimit = 64 * 1024
|
||||
|
||||
// limitRequestBody wraps underlying handler h, making it's request's body Read
|
||||
// method limited.
|
||||
func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
|
||||
if err != nil {
|
||||
log.Error("limitRequestBody: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(a.garipov): We currently have to use this, because everything registers
|
||||
// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP
|
||||
// API initialization process and stop using the gosh darn http.DefaultServeMux
|
||||
// for anything at all. Gosh darn global variables.
|
||||
func filterPProf(h http.Handler) (filtered http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/debug/pprof") {
|
||||
http.NotFound(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
64
internal/home/middlewares_test.go
Normal file
64
internal/home/middlewares_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLimitRequestBody(t *testing.T) {
|
||||
errReqLimitReached := &aghio.LimitReachedError{
|
||||
Limit: RequestBodySizeLimit,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
body string
|
||||
want []byte
|
||||
wantErr error
|
||||
}{{
|
||||
name: "not_so_big",
|
||||
body: "somestr",
|
||||
want: []byte("somestr"),
|
||||
wantErr: nil,
|
||||
}, {
|
||||
name: "so_big",
|
||||
body: string(make([]byte, RequestBodySizeLimit+1)),
|
||||
want: make([]byte, RequestBodySizeLimit),
|
||||
wantErr: errReqLimitReached,
|
||||
}, {
|
||||
name: "empty",
|
||||
body: "",
|
||||
want: []byte(nil),
|
||||
wantErr: nil,
|
||||
}}
|
||||
|
||||
makeHandler := func(err *error) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var b []byte
|
||||
b, *err = ioutil.ReadAll(r.Body)
|
||||
w.Write(b)
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var err error
|
||||
handler := makeHandler(&err)
|
||||
lim := limitRequestBody(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
lim.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tc.want, res.Body.Bytes())
|
||||
assert.Equal(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -7,7 +7,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
@ -142,7 +141,7 @@ func (web *Web) Start() {
|
|||
web.httpServer = &http.Server{
|
||||
ErrorLog: web.errLogger,
|
||||
Addr: address,
|
||||
Handler: filterPPROF(http.DefaultServeMux),
|
||||
Handler: withMiddlewares(http.DefaultServeMux, filterPProf, limitRequestBody),
|
||||
}
|
||||
err := web.httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
|
@ -153,22 +152,6 @@ func (web *Web) Start() {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(a.garipov): We currently have to use this, because everything registers
|
||||
// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP
|
||||
// API initialization process and stop using the gosh darn http.DefaultServeMux
|
||||
// for anything at all. Gosh darn global variables.
|
||||
func filterPPROF(h http.Handler) (filtered http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/debug/pprof") {
|
||||
http.NotFound(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Close - stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func (web *Web) Close() {
|
||||
log.Info("Stopping HTTP server...")
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
|
@ -115,6 +116,9 @@ func whoisParse(data string) map[string]string {
|
|||
return m
|
||||
}
|
||||
|
||||
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
const MaxConnReadSize = 64 * 1024
|
||||
|
||||
// Send request to a server and receive the response
|
||||
func (w *Whois) query(target, serverAddr string) (string, error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
|
@ -127,13 +131,20 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer connReadCloser.Close()
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
|
||||
_, err = conn.Write([]byte(target + "\r\n"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(conn)
|
||||
// This use of ReadAll is now safe, because we limited the conn Reader.
|
||||
data, err := ioutil.ReadAll(connReadCloser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"io/ioutil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
)
|
||||
|
||||
const versionCheckPeriod = 8 * 60 * 60
|
||||
|
@ -19,6 +21,9 @@ type VersionInfo struct {
|
|||
CanAutoUpdate bool // If true - we can auto-update
|
||||
}
|
||||
|
||||
// MaxResponseSize is responses on server's requests maximum length in bytes.
|
||||
const MaxResponseSize = 64 * 1024
|
||||
|
||||
// GetVersionResponse - downloads version.json (if needed) and deserializes it
|
||||
func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
|
||||
if !forceRecheck &&
|
||||
|
@ -27,14 +32,19 @@ func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
|
|||
}
|
||||
|
||||
resp, err := u.Client.Get(u.VersionURL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// This use of ReadAll is safe, because we just limited the appropriate
|
||||
// ReadCloser.
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
@ -217,17 +218,27 @@ func (u *Updater) clean() {
|
|||
_ = os.RemoveAll(u.updateDir)
|
||||
}
|
||||
|
||||
// MaxPackageFileSize is a maximum package file length in bytes. The largest
|
||||
// package whose size is limited by this constant currently has the size of
|
||||
// approximately 9 MiB.
|
||||
const MaxPackageFileSize = 32 * 1024 * 1024
|
||||
|
||||
// Download package file and save it to disk
|
||||
func (u *Updater) downloadPackageFile(url string, filename string) error {
|
||||
resp, err := u.Client.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
log.Debug("updater: reading HTTP body")
|
||||
// This use of ReadAll is now safe, because we limited body's Reader.
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ioutil.ReadAll() failed: %w", err)
|
||||
|
|
Loading…
Reference in a new issue