diff --git a/go.mod b/go.mod index 395c7606..f52e313d 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/AdguardTeam/dnsproxy v0.56.0 - github.com/AdguardTeam/golibs v0.17.0 + github.com/AdguardTeam/golibs v0.17.1 github.com/AdguardTeam/urlfilter v0.17.0 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.7 diff --git a/go.sum b/go.sum index 32869da1..637de98c 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/AdguardTeam/dnsproxy v0.56.0 h1:kAg88woRTWTgqVEB2i2Uhze8Lv0JF1zTIAGhe0prjKM= github.com/AdguardTeam/dnsproxy v0.56.0/go.mod h1:fqmehcE3cHFNqKbWQpIjGk7GqBy7ur1v5At499lFjRc= -github.com/AdguardTeam/golibs v0.17.0 h1:oPp2+2kV41qH45AIFbAlHFTPQOQ6JbF+JemjeECFn1g= -github.com/AdguardTeam/golibs v0.17.0/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= +github.com/AdguardTeam/golibs v0.17.1 h1:j3Ehhld5GI/amcHYG+CF0sJ4OOzAQ06BY3N/iBYJZ1M= +github.com/AdguardTeam/golibs v0.17.1/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= github.com/AdguardTeam/urlfilter v0.17.0 h1:tUzhtR9wMx704GIP3cibsDQJrixlMHfwoQbYJfPdFow= github.com/AdguardTeam/urlfilter v0.17.0/go.mod h1:bbuZjPUzm/Ip+nz5qPPbwIP+9rZyQbQad8Lt/0fCulU= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= diff --git a/internal/aghio/limitedreader.go b/internal/aghio/limitedreader.go deleted file mode 100644 index dbf7edce..00000000 --- a/internal/aghio/limitedreader.go +++ /dev/null @@ -1,60 +0,0 @@ -// Package aghio contains extensions for io package's types and methods -package aghio - -import ( - "fmt" - "io" - - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/mathutil" -) - -// LimitReachedError records the limit and the operation that caused it. -type LimitReachedError struct { - Limit int64 -} - -// Error implements the [error] interface for *LimitReachedError. -// -// TODO(a.garipov): Think about error string format. -func (lre *LimitReachedError) Error() string { - return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit) -} - -// limitedReader is a wrapper for [io.Reader] limiting the input and dealing -// with errors package. -type limitedReader struct { - r io.Reader - limit int64 - n int64 -} - -// Read implements the [io.Reader] interface. -func (lr *limitedReader) Read(p []byte) (n int, err error) { - if lr.n == 0 { - return 0, &LimitReachedError{ - Limit: lr.limit, - } - } - - p = p[:mathutil.Min(lr.n, int64(len(p)))] - - n, err = lr.r.Read(p) - lr.n -= int64(n) - - return n, err -} - -// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after -// n bytes read. -func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) { - if n < 0 { - return nil, errors.Error("limit must be non-negative") - } - - return &limitedReader{ - r: r, - limit: n, - n: n, - }, nil -} diff --git a/internal/aghio/limitedreader_test.go b/internal/aghio/limitedreader_test.go deleted file mode 100644 index a85fd051..00000000 --- a/internal/aghio/limitedreader_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package aghio_test - -import ( - "io" - "strings" - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/aghio" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLimitReader(t *testing.T) { - testCases := []struct { - wantErrMsg string - name string - n int64 - }{{ - wantErrMsg: "", - name: "positive", - n: 1, - }, { - wantErrMsg: "", - name: "zero", - n: 0, - }, { - wantErrMsg: "limit must be non-negative", - name: "negative", - n: -1, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := aghio.LimitReader(nil, tc.n) - testutil.AssertErrorMsg(t, tc.wantErrMsg, err) - }) - } -} - -func TestLimitedReader_Read(t *testing.T) { - testCases := []struct { - err error - name string - rStr string - limit int64 - want int - }{{ - err: nil, - name: "perfectly_match", - rStr: "abc", - limit: 3, - want: 3, - }, { - err: io.EOF, - name: "eof", - rStr: "", - limit: 3, - want: 0, - }, { - err: &aghio.LimitReachedError{ - Limit: 0, - }, - name: "limit_reached", - rStr: "abc", - limit: 0, - want: 0, - }, { - err: nil, - name: "truncated", - rStr: "abc", - limit: 2, - want: 2, - }} - - for _, tc := range testCases { - readCloser := io.NopCloser(strings.NewReader(tc.rStr)) - lreader, err := aghio.LimitReader(readCloser, tc.limit) - require.NoError(t, err) - require.NotNil(t, lreader) - - t.Run(tc.name, func(t *testing.T) { - buf := make([]byte, tc.limit+1) - n, rerr := lreader.Read(buf) - require.Equal(t, rerr, tc.err) - - assert.Equal(t, tc.want, n) - }) - } -} - -func TestLimitedReader_LimitReachedError(t *testing.T) { - testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{ - Limit: 0, - }) -} diff --git a/internal/aghos/os_test.go b/internal/aghos/os_test.go index f56a93ac..ebd98420 100644 --- a/internal/aghos/os_test.go +++ b/internal/aghos/os_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/golibs/ioutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -72,11 +72,10 @@ func TestLargestLabeled(t *testing.T) { } t.Run("scanner_fail", func(t *testing.T) { - lr, err := aghio.LimitReader(bytes.NewReader([]byte{1, 2, 3}), 0) - require.NoError(t, err) + lr := ioutil.LimitReader(bytes.NewReader([]byte{1, 2, 3}), 0) - target := &aghio.LimitReachedError{} - _, _, err = parsePSOutput(lr, "", nil) + target := &ioutil.LimitError{} + _, _, err := parsePSOutput(lr, "", nil) require.ErrorAs(t, err, &target) assert.EqualValues(t, 0, target.Limit) diff --git a/internal/filtering/filter.go b/internal/filtering/filter.go index 329b6745..169b2d51 100644 --- a/internal/filtering/filter.go +++ b/internal/filtering/filter.go @@ -504,7 +504,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) { } defer func() { err = errors.WithDeferred(err, r.Close()) }() - bufPtr := d.bufPool.Get().(*[]byte) + bufPtr := d.bufPool.Get() defer d.bufPool.Put(bufPtr) p := rulelist.NewParser() @@ -607,7 +607,7 @@ func (d *DNSFilter) load(flt *FilterYAML) (err error) { log.Debug("filtering: file %q, id %d, length %d", fileName, flt.ID, st.Size()) - bufPtr := d.bufPool.Get().(*[]byte) + bufPtr := d.bufPool.Get() defer d.bufPool.Put(bufPtr) p := rulelist.NewParser() diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 5dab0524..c9f52dfc 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -26,6 +26,7 @@ import ( "github.com/AdguardTeam/golibs/mathutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" + "github.com/AdguardTeam/golibs/syncutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/rules" @@ -232,7 +233,7 @@ type Checker interface { // DNSFilter matches hostnames and DNS requests against filtering rules. type DNSFilter struct { // bufPool is a pool of buffers used for filtering-rule list parsing. - bufPool *sync.Pool + bufPool *syncutil.Pool[[]byte] rulesStorage *filterlist.RuleStorage filteringEngine *urlfilter.DNSEngine @@ -1061,13 +1062,7 @@ func InitModule() { // be non-nil. func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { d = &DNSFilter{ - bufPool: &sync.Pool{ - New: func() (buf any) { - bufVal := make([]byte, rulelist.DefaultRuleBufSize) - - return &bufVal - }, - }, + bufPool: syncutil.NewSlicePool[byte](rulelist.DefaultRuleBufSize), refreshLock: &sync.Mutex{}, safeBrowsingChecker: c.SafeBrowsingChecker, parentalControlChecker: c.ParentalControlChecker, diff --git a/internal/home/authglinet.go b/internal/home/authglinet.go index 81fe8065..7b65a28b 100644 --- a/internal/home/authglinet.go +++ b/internal/home/authglinet.go @@ -9,7 +9,7 @@ import ( "os" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" "github.com/josharian/native" ) @@ -83,12 +83,7 @@ func glGetTokenDate(file string) uint32 { } }() - fileReader, err := aghio.LimitReader(f, MaxFileSize) - if err != nil { - log.Error("creating limited reader: %s", err) - - return 0 - } + fileReader := ioutil.LimitReader(f, MaxFileSize) var dateToken uint32 diff --git a/internal/home/middlewares.go b/internal/home/middlewares.go index 5ad02ee0..c6297375 100644 --- a/internal/home/middlewares.go +++ b/internal/home/middlewares.go @@ -4,9 +4,7 @@ import ( "io" "net/http" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" - - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/ioutil" ) // middlerware is a wrapper function signature. @@ -23,12 +21,14 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha return wrapped } -// defaultReqBodySzLim is the default maximum request body size. -const defaultReqBodySzLim = 64 * 1024 +const ( + // defaultReqBodySzLim is the default maximum request body size. + defaultReqBodySzLim = 64 * 1024 -// largerReqBodySzLim is the maximum request body size for APIs expecting larger -// requests. -const largerReqBodySzLim = 4 * 1024 * 1024 + // largerReqBodySzLim is the maximum request body size for APIs expecting + // larger requests. + largerReqBodySzLim = 4 * 1024 * 1024 +) // expectsLargerRequests shows if this request should use a larger body size // limit. These are exceptions for poorly designed current APIs as well as APIs @@ -52,20 +52,12 @@ func expectsLargerRequests(r *http.Request) (ok bool) { // method limited. func limitRequestBody(h http.Handler) (limited http.Handler) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var err error - - var szLim int64 = defaultReqBodySzLim + var szLim uint64 = defaultReqBodySzLim if expectsLargerRequests(r) { szLim = largerReqBodySzLim } - var reader io.Reader - reader, err = aghio.LimitReader(r.Body, szLim) - if err != nil { - log.Error("limitRequestBody: %s", err) - - return - } + reader := ioutil.LimitReader(r.Body, szLim) // HTTP handlers aren't supposed to call r.Body.Close(), so just // replace the body in a clone. diff --git a/internal/home/middlewares_test.go b/internal/home/middlewares_test.go index 96da52fb..5393503f 100644 --- a/internal/home/middlewares_test.go +++ b/internal/home/middlewares_test.go @@ -7,13 +7,13 @@ import ( "strings" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/golibs/ioutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLimitRequestBody(t *testing.T) { - errReqLimitReached := &aghio.LimitReachedError{ + errReqLimitReached := &ioutil.LimitError{ Limit: defaultReqBodySzLim, } diff --git a/internal/updater/check.go b/internal/updater/check.go index a72e58ee..71dc9582 100644 --- a/internal/updater/check.go +++ b/internal/updater/check.go @@ -8,8 +8,8 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -51,11 +51,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() - var r io.Reader - r, err = aghio.LimitReader(resp.Body, MaxResponseSize) - if err != nil { - return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err) - } + r := ioutil.LimitReader(resp.Body, MaxResponseSize) // This use of ReadAll is safe, because we just limited the appropriate // ReadCloser. diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 7041963e..bf7a9dae 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -15,9 +15,9 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" ) @@ -328,11 +328,7 @@ func (u *Updater) downloadPackageFile() (err error) { } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() - var r io.Reader - r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize) - if err != nil { - return fmt.Errorf("http request failed: %w", err) - } + r := ioutil.LimitReader(resp.Body, MaxPackageFileSize) log.Debug("updater: reading http body") // This use of ReadAll is now safe, because we limited body's Reader. diff --git a/internal/whois/whois.go b/internal/whois/whois.go index ae01304b..b2d20a80 100644 --- a/internal/whois/whois.go +++ b/internal/whois/whois.go @@ -12,9 +12,9 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" @@ -62,7 +62,7 @@ type Config struct { CacheTTL time.Duration // MaxConnReadSize is an upper limit in bytes for reading from net.Conn. - MaxConnReadSize int64 + MaxConnReadSize uint64 // MaxRedirects is the maximum redirects count. MaxRedirects int @@ -102,7 +102,7 @@ type Default struct { cacheTTL time.Duration // maxConnReadSize is an upper limit in bytes for reading from net.Conn. - maxConnReadSize int64 + maxConnReadSize uint64 // maxRedirects is the maximum redirects count. maxRedirects int @@ -208,11 +208,7 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data [] } defer func() { err = errors.WithDeferred(err, conn.Close()) }() - r, err := aghio.LimitReader(conn, w.maxConnReadSize) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return nil, err - } + r := ioutil.LimitReader(conn, w.maxConnReadSize) _ = conn.SetDeadline(time.Now().Add(w.timeout)) _, err = io.WriteString(conn, target+"\r\n") diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index e89eca6d..c8a9a3d2 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -206,7 +206,6 @@ run_linter gocognit --over='10'\ ./internal/aghalg/\ ./internal/aghchan/\ ./internal/aghhttp/\ - ./internal/aghio/\ ./internal/aghrenameio/\ ./internal/aghtest/\ ./internal/arpdb/\ @@ -243,7 +242,6 @@ run_linter fieldalignment \ ./internal/aghalg/\ ./internal/aghchan/\ ./internal/aghhttp/\ - ./internal/aghio/\ ./internal/aghos/\ ./internal/aghrenameio/\ ./internal/aghtest/\ @@ -272,7 +270,6 @@ run_linter gosec --quiet\ ./internal/aghalg/\ ./internal/aghchan/\ ./internal/aghhttp/\ - ./internal/aghio/\ ./internal/aghnet/\ ./internal/aghos/\ ./internal/aghrenameio/\ diff --git a/scripts/translations/download.go b/scripts/translations/download.go index b9b809c9..123cf350 100644 --- a/scripts/translations/download.go +++ b/scripts/translations/download.go @@ -12,8 +12,8 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" "golang.org/x/exp/slices" ) @@ -148,12 +148,7 @@ func getTranslation(client *http.Client, url string) (data []byte, err error) { // Go on and download the body for inspection. } - limitReader, lrErr := aghio.LimitReader(resp.Body, readLimit) - if lrErr != nil { - // Generally shouldn't happen, since the only error returned by - // [aghio.LimitReader] is an argument error. - panic(fmt.Errorf("limit reading: %w", lrErr)) - } + limitReader := ioutil.LimitReader(resp.Body, readLimit) data, readErr := io.ReadAll(limitReader)