From 5f0e53ded7312298c8376bfac46ea22f824fcf91 Mon Sep 17 00:00:00 2001 From: Dimitry Kolyshev Date: Wed, 14 Jun 2023 10:51:17 +0400 Subject: [PATCH] home: write timeout middleware --- internal/home/middlewares.go | 49 ++++++++++++++++++++++++++++++++++-- internal/home/web.go | 12 +++++---- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/internal/home/middlewares.go b/internal/home/middlewares.go index 5ad02ee0..a32c367c 100644 --- a/internal/home/middlewares.go +++ b/internal/home/middlewares.go @@ -3,13 +3,13 @@ package home import ( "io" "net/http" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghio" - "github.com/AdguardTeam/golibs/log" ) -// middlerware is a wrapper function signature. +// middleware is a wrapper function signature. type middleware func(http.Handler) http.Handler // withMiddlewares consequently wraps h with all the middlewares. @@ -75,3 +75,48 @@ func limitRequestBody(h http.Handler) (limited http.Handler) { h.ServeHTTP(w, rr) }) } + +const ( + // defaultWriteTimeout is the maximum duration before timing out writes of + // the response. + defaultWriteTimeout = 60 * time.Second + + // longerWriteTimeout is the maximum duration before timing out for APIs + // expecting longer response requests. + longerWriteTimeout = 5 * time.Minute +) + +// expectsLongTimeoutRequests shows if this request should use a bigger write +// timeout value. These are exceptions for poorly designed current APIs as +// well as APIs that are designed to expect large files and requests. Remove +// once the new, better APIs are up. +// +// TODO(d.kolyshev): This could be achieved with [http.NewResponseController] +// with go v1.20. +func expectsLongTimeoutRequests(r *http.Request) (ok bool) { + if r.Method != http.MethodGet { + return false + } + + return r.URL.Path == "/control/querylog/export" +} + +// addWriteTimeout wraps underlying handler h, adding a response write timeout. +func addWriteTimeout(h http.Handler) (limited http.Handler) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var handler http.Handler + if expectsLongTimeoutRequests(r) { + handler = http.TimeoutHandler(h, longerWriteTimeout, "write timeout exceeded") + } else { + handler = http.TimeoutHandler(h, defaultWriteTimeout, "write timeout exceeded") + } + + handler.ServeHTTP(w, r) + }) +} + +// limitHandler wraps underlying handler h with default limits, such as request +// body limit and write timeout. +func limitHandler(h http.Handler) (limited http.Handler) { + return limitRequestBody(addWriteTimeout(h)) +} diff --git a/internal/home/web.go b/internal/home/web.go index e0b45fe5..0e80fb21 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -25,11 +25,13 @@ const ( // readTimeout is the maximum duration for reading the entire request, // including the body. readTimeout = 60 * time.Second + // readHdrTimeout is the amount of time allowed to read request headers. readHdrTimeout = 60 * time.Second + // writeTimeout is the maximum duration before timing out writes of the - // response. - writeTimeout = 60 * time.Second + // response. This limit is overwritten by [addWriteTimeout] middleware. + writeTimeout = 10 * time.Minute ) type webConfig struct { @@ -169,7 +171,7 @@ func (web *webAPI) start() { errs := make(chan error, 2) // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. - hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{}) + hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitHandler), &http2.Server{}) // Create a new instance, because the Web is not usable after Shutdown. hostStr := web.conf.BindHost.String() @@ -254,7 +256,7 @@ func (web *webAPI) tlsServerLoop() { CipherSuites: Context.tlsCipherIDs, MinVersion: tls.VersionTLS12, }, - Handler: withMiddlewares(Context.mux, limitRequestBody), + Handler: withMiddlewares(Context.mux, limitHandler), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, WriteTimeout: web.conf.WriteTimeout, @@ -288,7 +290,7 @@ func (web *webAPI) mustStartHTTP3(address string) { CipherSuites: Context.tlsCipherIDs, MinVersion: tls.VersionTLS12, }, - Handler: withMiddlewares(Context.mux, limitRequestBody), + Handler: withMiddlewares(Context.mux, limitHandler), } log.Debug("web: starting http/3 server")