home: write timeout middleware

This commit is contained in:
Dimitry Kolyshev 2023-06-14 10:51:17 +04:00
parent 5cd4ce766d
commit 5f0e53ded7
2 changed files with 54 additions and 7 deletions

View file

@ -3,13 +3,13 @@ package home
import ( import (
"io" "io"
"net/http" "net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// middlerware is a wrapper function signature. // middleware is a wrapper function signature.
type middleware func(http.Handler) http.Handler type middleware func(http.Handler) http.Handler
// withMiddlewares consequently wraps h with all the middlewares. // withMiddlewares consequently wraps h with all the middlewares.
@ -75,3 +75,48 @@ func limitRequestBody(h http.Handler) (limited http.Handler) {
h.ServeHTTP(w, rr) h.ServeHTTP(w, rr)
}) })
} }
const (
// defaultWriteTimeout is the maximum duration before timing out writes of
// the response.
defaultWriteTimeout = 60 * time.Second
// longerWriteTimeout is the maximum duration before timing out for APIs
// expecting longer response requests.
longerWriteTimeout = 5 * time.Minute
)
// expectsLongTimeoutRequests shows if this request should use a bigger write
// timeout value. These are exceptions for poorly designed current APIs as
// well as APIs that are designed to expect large files and requests. Remove
// once the new, better APIs are up.
//
// TODO(d.kolyshev): This could be achieved with [http.NewResponseController]
// with go v1.20.
func expectsLongTimeoutRequests(r *http.Request) (ok bool) {
if r.Method != http.MethodGet {
return false
}
return r.URL.Path == "/control/querylog/export"
}
// addWriteTimeout wraps underlying handler h, adding a response write timeout.
func addWriteTimeout(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var handler http.Handler
if expectsLongTimeoutRequests(r) {
handler = http.TimeoutHandler(h, longerWriteTimeout, "write timeout exceeded")
} else {
handler = http.TimeoutHandler(h, defaultWriteTimeout, "write timeout exceeded")
}
handler.ServeHTTP(w, r)
})
}
// limitHandler wraps underlying handler h with default limits, such as request
// body limit and write timeout.
func limitHandler(h http.Handler) (limited http.Handler) {
return limitRequestBody(addWriteTimeout(h))
}

View file

@ -25,11 +25,13 @@ const (
// readTimeout is the maximum duration for reading the entire request, // readTimeout is the maximum duration for reading the entire request,
// including the body. // including the body.
readTimeout = 60 * time.Second readTimeout = 60 * time.Second
// readHdrTimeout is the amount of time allowed to read request headers. // readHdrTimeout is the amount of time allowed to read request headers.
readHdrTimeout = 60 * time.Second readHdrTimeout = 60 * time.Second
// writeTimeout is the maximum duration before timing out writes of the // writeTimeout is the maximum duration before timing out writes of the
// response. // response. This limit is overwritten by [addWriteTimeout] middleware.
writeTimeout = 60 * time.Second writeTimeout = 10 * time.Minute
) )
type webConfig struct { type webConfig struct {
@ -169,7 +171,7 @@ func (web *webAPI) start() {
errs := make(chan error, 2) errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{}) hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitHandler), &http2.Server{})
// Create a new instance, because the Web is not usable after Shutdown. // Create a new instance, because the Web is not usable after Shutdown.
hostStr := web.conf.BindHost.String() hostStr := web.conf.BindHost.String()
@ -254,7 +256,7 @@ func (web *webAPI) tlsServerLoop() {
CipherSuites: Context.tlsCipherIDs, CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
Handler: withMiddlewares(Context.mux, limitRequestBody), Handler: withMiddlewares(Context.mux, limitHandler),
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout, WriteTimeout: web.conf.WriteTimeout,
@ -288,7 +290,7 @@ func (web *webAPI) mustStartHTTP3(address string) {
CipherSuites: Context.tlsCipherIDs, CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
Handler: withMiddlewares(Context.mux, limitRequestBody), Handler: withMiddlewares(Context.mux, limitHandler),
} }
log.Debug("web: starting http/3 server") log.Debug("web: starting http/3 server")