diff --git a/CHANGELOG.md b/CHANGELOG.md index eb01e003..bb2f1be7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,19 @@ and this project adheres to ## [Unreleased] +### Added + +- HTTP API request body size limit [#2305]. + +[#2305]: https://github.com/AdguardTeam/AdGuardHome/issues/2305 + +### Changed + +- Various internal improvements ([#2271], [#2297]). + +[#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271 +[#2297]: https://github.com/AdguardTeam/AdGuardHome/issues/2297 + ## [v0.104.3] - 2020-11-19 diff --git a/HACKING.md b/HACKING.md index fbd56cc4..bbd4dfbd 100644 --- a/HACKING.md +++ b/HACKING.md @@ -1,4 +1,4 @@ - # AdGuardHome Developer Guidelines + # *AdGuardHome* Developer Guidelines As of **2020-11-20**, this document is still a work-in-progress. Some of the rules aren't enforced, and others might change. Still, this is a good place to @@ -6,7 +6,11 @@ find out about how we **want** our code to look like. The rules are mostly sorted in the alphabetical order. -## Git +## *Git* + + * Call your branches either `NNNN-fix-foo` (where `NNNN` is the ID of the + *GitHub* issue you worked on in this branch) or just `fix-foo` if there was + no *GitHub* issue. * Follow the commit message header format: @@ -22,9 +26,10 @@ The rules are mostly sorted in the alphabetical order. * Only use lowercase letters in your commit message headers. The rest of the message should follow the plain text conventions below. - The only exception are direct mentions of identifiers from the source code. + The only exceptions are direct mentions of identifiers from the source code + and filenames like `HACKING.md`. -## Go +## *Go* * . @@ -32,6 +37,9 @@ The rules are mostly sorted in the alphabetical order. * + * Add an empty line before `break`, `continue`, and `return`, unless it's the + only statement in that block. + * Avoid `init` and use explicit initialization functions instead. * Avoid `new`, especially with structs. @@ -53,6 +61,18 @@ The rules are mostly sorted in the alphabetical order. * Eschew external dependencies, including transitive, unless absolutely necessary. + * Name benchmarks and tests using the same convention as examples. For + example: + + ```go + func TestFunction(t *testing.T) { /* … */ } + func TestFunction_suffix(t *testing.T) { /* … */ } + func TestType_Method(t *testing.T) { /* … */ } + func TestType_Method_suffix(t *testing.T) { /* … */ } + ``` + + * Name the deferred errors (e.g. when closing something) `cerr`. + * No `goto`. * No shadowing, since it can often lead to subtle bugs, especially with @@ -103,9 +123,9 @@ The rules are mostly sorted in the alphabetical order. [constant errors]: https://dave.cheney.net/2016/04/07/constant-errors [Linus said]: https://www.kernel.org/doc/html/v4.17/process/coding-style.html#indentation -## Markdown +## *Markdown* - * **TODO(a.garipov):** Define our Markdown conventions. + * **TODO(a.garipov):** Define our *Markdown* conventions. ## Text, Including Comments @@ -128,7 +148,7 @@ The rules are mostly sorted in the alphabetical order. * Use double spacing between sentences to make sentence borders more clear. - * Use the serial comma (a.k.a. Oxford comma) to improve comprehension, + * Use the serial comma (a.k.a. *Oxford* comma) to improve comprehension, decrease ambiguity, and use a common standard. * Write todos like this: @@ -143,16 +163,16 @@ The rules are mostly sorted in the alphabetical order. // TODO(usr1, usr2): Fix the frobulation issue. ``` -## YAML +## *YAML* * **TODO(a.garipov):** Define naming conventions for schema names in our - OpenAPI YAML file. And just generally OpenAPI conventions. + *OpenAPI* *YAML* file. And just generally OpenAPI conventions. - * **TODO(a.garipov):** Find a YAML formatter or write our own. + * **TODO(a.garipov):** Find a *YAML* formatter or write our own. - * All strings, including keys, must be quoted. Reason: the [NO-rway Law]. + * All strings, including keys, must be quoted. Reason: the [*NO-rway Law*]. - * Indent with two (**2**) spaces. YAML documents can get pretty + * Indent with two (**2**) spaces. *YAML* documents can get pretty deeply-nested. * No extra indentation in multiline arrays: @@ -170,4 +190,4 @@ The rules are mostly sorted in the alphabetical order. * Use `>` for multiline strings, unless you need to keep the line breaks. -[NO-rway Law]: https://news.ycombinator.com/item?id=17359376 +[*NO-rway Law*]: https://news.ycombinator.com/item?id=17359376 diff --git a/internal/aghio/limitedreadcloser.go b/internal/aghio/limitedreadcloser.go new file mode 100644 index 00000000..7690705a --- /dev/null +++ b/internal/aghio/limitedreadcloser.go @@ -0,0 +1,59 @@ +// Package aghio contains extensions for io package's types and methods +package aghio + +import ( + "fmt" + "io" +) + +// LimitReachedError records the limit and the operation that caused it. +type LimitReachedError struct { + Limit int64 +} + +// Error implements error interface for LimitReachedError. +// TODO(a.garipov): Think about error string format. +func (lre *LimitReachedError) Error() string { + return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit) +} + +// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and +// dealing with agherr package. +type limitedReadCloser struct { + limit int64 + n int64 + rc io.ReadCloser +} + +// Read implements Reader interface. +func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) { + if lrc.n == 0 { + return 0, &LimitReachedError{ + Limit: lrc.limit, + } + } + if int64(len(p)) > lrc.n { + p = p[0:lrc.n] + } + n, err = lrc.rc.Read(p) + lrc.n -= int64(n) + return n, err +} + +// Close implements Closer interface. +func (lrc *limitedReadCloser) Close() error { + return lrc.rc.Close() +} + +// LimitReadCloser wraps ReadCloser to make it's Reader stop with +// ErrLimitReached after n bytes read. +func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) { + if n < 0 { + return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n) + } + return &limitedReadCloser{ + limit: n, + n: n, + rc: rc, + }, nil +} diff --git a/internal/aghio/limitedreadcloser_test.go b/internal/aghio/limitedreadcloser_test.go new file mode 100644 index 00000000..1f10e32b --- /dev/null +++ b/internal/aghio/limitedreadcloser_test.go @@ -0,0 +1,108 @@ +package aghio + +import ( + "fmt" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLimitReadCloser(t *testing.T) { + testCases := []struct { + name string + n int64 + want error + }{{ + name: "positive", + n: 1, + want: nil, + }, { + name: "zero", + n: 0, + want: nil, + }, { + name: "negative", + n: -1, + want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := LimitReadCloser(nil, tc.n) + assert.Equal(t, tc.want, err) + }) + } +} + +func TestLimitedReadCloser_Read(t *testing.T) { + testCases := []struct { + name string + limit int64 + rStr string + want int + err error + }{{ + name: "perfectly_match", + limit: 3, + rStr: "abc", + want: 3, + err: nil, + }, { + name: "eof", + limit: 3, + rStr: "", + want: 0, + err: io.EOF, + }, { + name: "limit_reached", + limit: 0, + rStr: "abc", + want: 0, + err: &LimitReachedError{ + Limit: 0, + }, + }, { + name: "truncated", + limit: 2, + rStr: "abc", + want: 2, + err: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + readCloser := ioutil.NopCloser(strings.NewReader(tc.rStr)) + buf := make([]byte, tc.limit+1) + + lreader, err := LimitReadCloser(readCloser, tc.limit) + assert.Nil(t, err) + + n, err := lreader.Read(buf) + assert.Equal(t, n, tc.want) + assert.Equal(t, tc.err, err) + }) + } +} + +func TestLimitedReadCloser_LimitReachedError(t *testing.T) { + testCases := []struct { + name string + want string + err error + }{{ + name: "simplest", + want: "attempted to read more than 0 bytes", + err: &LimitReachedError{ + Limit: 0, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, tc.err.Error()) + }) + } +} diff --git a/internal/dhcpd/dhcphttp.go b/internal/dhcpd/dhcphttp.go index f4ce801b..1cacd83c 100644 --- a/internal/dhcpd/dhcphttp.go +++ b/internal/dhcpd/dhcphttp.go @@ -299,6 +299,7 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { // . Check if a static IP is configured for the network interface // Respond with results func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) { + // This use of ReadAll is safe, because request's body is now limited. body, err := ioutil.ReadAll(r.Body) if err != nil { msg := fmt.Sprintf("failed to read request body: %s", err) diff --git a/internal/home/auth_glinet.go b/internal/home/auth_glinet.go index 7dd2790d..228843fa 100644 --- a/internal/home/auth_glinet.go +++ b/internal/home/auth_glinet.go @@ -10,6 +10,7 @@ import ( "time" "unsafe" + "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/golibs/log" ) @@ -18,8 +19,10 @@ var GLMode bool var glFilePrefix = "/tmp/gl_token_" -const glTokenTimeoutSeconds = 3600 -const glCookieName = "Admin-Token" +const ( + glTokenTimeoutSeconds = 3600 + glCookieName = "Admin-Token" +) func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool { if !GLMode { @@ -71,14 +74,28 @@ func archIsLittleEndian() bool { return (b == 0x04) } +// MaxFileSize is a maximum file length in bytes. +const MaxFileSize = 1024 * 1024 + func glGetTokenDate(file string) uint32 { f, err := os.Open(file) if err != nil { log.Error("os.Open: %s", err) return 0 } + defer f.Close() + + fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize) + if err != nil { + log.Error("LimitReadCloser: %s", err) + return 0 + } + defer fileReadCloser.Close() + var dateToken uint32 - bs, err := ioutil.ReadAll(f) + + // This use of ReadAll is now safe, because we limited reader. + bs, err := ioutil.ReadAll(fileReadCloser) if err != nil { log.Error("ioutil.ReadAll: %s", err) return 0 diff --git a/internal/home/clients_http.go b/internal/home/clients_http.go index 42fa6f2a..752b3c6f 100644 --- a/internal/home/clients_http.go +++ b/internal/home/clients_http.go @@ -3,7 +3,6 @@ package home import ( "encoding/json" "fmt" - "io/ioutil" "net/http" ) @@ -150,16 +149,11 @@ func clientHostToJSON(ip string, ch ClientHost) clientJSON { // Add a new client func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) - if err != nil { - httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) - return - } - cj := clientJSON{} - err = json.Unmarshal(body, &cj) + err := json.NewDecoder(r.Body).Decode(&cj) if err != nil { - httpError(w, http.StatusBadRequest, "JSON parse: %s", err) + httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) + return } @@ -183,16 +177,17 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. // Remove client func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) + cj := clientJSON{} + err := json.NewDecoder(r.Body).Decode(&cj) if err != nil { - httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) + httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) + return } - cj := clientJSON{} - err = json.Unmarshal(body, &cj) - if err != nil || len(cj.Name) == 0 { - httpError(w, http.StatusBadRequest, "JSON parse: %s", err) + if len(cj.Name) == 0 { + httpError(w, http.StatusBadRequest, "client's name must be non-empty") + return } @@ -211,18 +206,14 @@ type updateJSON struct { // Update client's properties func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) + dj := updateJSON{} + err := json.NewDecoder(r.Body).Decode(&dj) if err != nil { - httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) + httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) + return } - var dj updateJSON - err = json.Unmarshal(body, &dj) - if err != nil { - httpError(w, http.StatusBadRequest, "JSON parse: %s", err) - return - } if len(dj.Name) == 0 { httpError(w, http.StatusBadRequest, "Invalid request") return diff --git a/internal/home/control_filtering.go b/internal/home/control_filtering.go index 37f9af81..1794cce7 100644 --- a/internal/home/control_filtering.go +++ b/internal/home/control_filtering.go @@ -214,6 +214,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request } func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { + // This use of ReadAll is safe, because request's body is now limited. body, err := ioutil.ReadAll(r.Body) if err != nil { httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) diff --git a/internal/home/i18n.go b/internal/home/i18n.go index 6ddfe549..adbc95aa 100644 --- a/internal/home/i18n.go +++ b/internal/home/i18n.go @@ -66,6 +66,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) { } func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { + // This use of ReadAll is safe, because request's body is now limited. body, err := ioutil.ReadAll(r.Body) if err != nil { msg := fmt.Sprintf("failed to read request body: %s", err) diff --git a/internal/home/middlewares.go b/internal/home/middlewares.go new file mode 100644 index 00000000..4a38160d --- /dev/null +++ b/internal/home/middlewares.go @@ -0,0 +1,59 @@ +package home + +import ( + "net/http" + "strings" + + "github.com/AdguardTeam/AdGuardHome/internal/aghio" + + "github.com/AdguardTeam/golibs/log" +) + +// middlerware is a wrapper function signature. +type middleware func(http.Handler) http.Handler + +// withMiddlewares consequently wraps h with all the middlewares. +func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Handler) { + wrapped = h + + for _, mw := range middlewares { + wrapped = mw(wrapped) + } + + return wrapped +} + +// RequestBodySizeLimit is maximum request body length in bytes. +const RequestBodySizeLimit = 64 * 1024 + +// limitRequestBody wraps underlying handler h, making it's request's body Read +// method limited. +func limitRequestBody(h http.Handler) (limited http.Handler) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit) + if err != nil { + log.Error("limitRequestBody: %s", err) + + return + } + + h.ServeHTTP(w, r) + }) +} + +// TODO(a.garipov): We currently have to use this, because everything registers +// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP +// API initialization process and stop using the gosh darn http.DefaultServeMux +// for anything at all. Gosh darn global variables. +func filterPProf(h http.Handler) (filtered http.Handler) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/debug/pprof") { + http.NotFound(w, r) + + return + } + + h.ServeHTTP(w, r) + }) +} diff --git a/internal/home/middlewares_test.go b/internal/home/middlewares_test.go new file mode 100644 index 00000000..4d6a33d0 --- /dev/null +++ b/internal/home/middlewares_test.go @@ -0,0 +1,64 @@ +package home + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/stretchr/testify/assert" +) + +func TestLimitRequestBody(t *testing.T) { + errReqLimitReached := &aghio.LimitReachedError{ + Limit: RequestBodySizeLimit, + } + + testCases := []struct { + name string + body string + want []byte + wantErr error + }{{ + name: "not_so_big", + body: "somestr", + want: []byte("somestr"), + wantErr: nil, + }, { + name: "so_big", + body: string(make([]byte, RequestBodySizeLimit+1)), + want: make([]byte, RequestBodySizeLimit), + wantErr: errReqLimitReached, + }, { + name: "empty", + body: "", + want: []byte(nil), + wantErr: nil, + }} + + makeHandler := func(err *error) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var b []byte + b, *err = ioutil.ReadAll(r.Body) + w.Write(b) + }) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var err error + handler := makeHandler(&err) + lim := limitRequestBody(handler) + + req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body)) + res := httptest.NewRecorder() + + lim.ServeHTTP(res, req) + + assert.Equal(t, tc.want, res.Body.Bytes()) + assert.Equal(t, tc.wantErr, err) + }) + } +} diff --git a/internal/home/web.go b/internal/home/web.go index f8ceb296..0d6a1628 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "strconv" - "strings" "sync" "github.com/AdguardTeam/AdGuardHome/internal/util" @@ -142,7 +141,7 @@ func (web *Web) Start() { web.httpServer = &http.Server{ ErrorLog: web.errLogger, Addr: address, - Handler: filterPPROF(http.DefaultServeMux), + Handler: withMiddlewares(http.DefaultServeMux, filterPProf, limitRequestBody), } err := web.httpServer.ListenAndServe() if err != http.ErrServerClosed { @@ -153,22 +152,6 @@ func (web *Web) Start() { } } -// TODO(a.garipov): We currently have to use this, because everything registers -// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP -// API initialization process and stop using the gosh darn http.DefaultServeMux -// for anything at all. Gosh darn global variables. -func filterPPROF(h http.Handler) (filtered http.Handler) { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.HasPrefix(r.URL.Path, "/debug/pprof") { - http.NotFound(w, r) - - return - } - - h.ServeHTTP(w, r) - }) -} - // Close - stop HTTP server, possibly waiting for all active connections to be closed func (web *Web) Close() { log.Info("Stopping HTTP server...") diff --git a/internal/home/whois.go b/internal/home/whois.go index 1fcff3dc..4884d776 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/golibs/cache" @@ -115,6 +116,9 @@ func whoisParse(data string) map[string]string { return m } +// MaxConnReadSize is an upper limit in bytes for reading from net.Conn. +const MaxConnReadSize = 64 * 1024 + // Send request to a server and receive the response func (w *Whois) query(target, serverAddr string) (string, error) { addr, _, _ := net.SplitHostPort(serverAddr) @@ -127,13 +131,20 @@ func (w *Whois) query(target, serverAddr string) (string, error) { } defer conn.Close() + connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize) + if err != nil { + return "", err + } + defer connReadCloser.Close() + _ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond)) _, err = conn.Write([]byte(target + "\r\n")) if err != nil { return "", err } - data, err := ioutil.ReadAll(conn) + // This use of ReadAll is now safe, because we limited the conn Reader. + data, err := ioutil.ReadAll(connReadCloser) if err != nil { return "", err } diff --git a/internal/update/check.go b/internal/update/check.go index b10cec73..e83ab5c2 100644 --- a/internal/update/check.go +++ b/internal/update/check.go @@ -6,6 +6,8 @@ import ( "io/ioutil" "strings" "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghio" ) const versionCheckPeriod = 8 * 60 * 60 @@ -19,6 +21,9 @@ type VersionInfo struct { CanAutoUpdate bool // If true - we can auto-update } +// MaxResponseSize is responses on server's requests maximum length in bytes. +const MaxResponseSize = 64 * 1024 + // GetVersionResponse - downloads version.json (if needed) and deserializes it func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) { if !forceRecheck && @@ -27,14 +32,19 @@ func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) { } resp, err := u.Client.Get(u.VersionURL) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - if err != nil { return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err) } + defer resp.Body.Close() + resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize) + if err != nil { + return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err) + } + defer resp.Body.Close() + + // This use of ReadAll is safe, because we just limited the appropriate + // ReadCloser. body, err := ioutil.ReadAll(resp.Body) if err != nil { return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err) diff --git a/internal/update/updater.go b/internal/update/updater.go index f78f85c5..34d66819 100644 --- a/internal/update/updater.go +++ b/internal/update/updater.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/golibs/log" ) @@ -217,17 +218,27 @@ func (u *Updater) clean() { _ = os.RemoveAll(u.updateDir) } +// MaxPackageFileSize is a maximum package file length in bytes. The largest +// package whose size is limited by this constant currently has the size of +// approximately 9 MiB. +const MaxPackageFileSize = 32 * 1024 * 1024 + // Download package file and save it to disk func (u *Updater) downloadPackageFile(url string, filename string) error { resp, err := u.Client.Get(url) if err != nil { return fmt.Errorf("http request failed: %w", err) } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() + defer resp.Body.Close() + + resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize) + if err != nil { + return fmt.Errorf("http request failed: %w", err) } + defer resp.Body.Close() log.Debug("updater: reading HTTP body") + // This use of ReadAll is now safe, because we limited body's Reader. body, err := ioutil.ReadAll(resp.Body) if err != nil { return fmt.Errorf("ioutil.ReadAll() failed: %w", err)