package aghio

import (
	"io"
	"strings"
	"testing"

	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestLimitReader verifies the error reported by LimitReader for a positive,
// a zero, and a negative limit.  Only the error is inspected; the returned
// reader is discarded.
func TestLimitReader(t *testing.T) {
	type testCase struct {
		wantErrMsg string
		name       string
		n          int64
	}

	cases := []testCase{
		{wantErrMsg: "", name: "positive", n: 1},
		{wantErrMsg: "", name: "zero", n: 0},
		{wantErrMsg: "aghio: invalid n in LimitReader: -1", name: "negative", n: -1},
	}

	for i := range cases {
		c := cases[i]

		t.Run(c.name, func(t *testing.T) {
			// The underlying reader is nil on purpose: this test only
			// exercises the validation of n.
			_, gotErr := LimitReader(nil, c.n)
			testutil.AssertErrorMsg(t, c.wantErrMsg, gotErr)
		})
	}
}

// TestLimitedReader_Read performs a single Read on a reader returned by
// LimitReader and checks both the number of bytes read and the returned
// error against the expectations of each case.
func TestLimitedReader_Read(t *testing.T) {
	testCases := []struct {
		// err is the error expected from Read.
		err error
		// name is the subtest name.
		name string
		// rStr is the content of the underlying reader.
		rStr string
		// limit is the limit passed to LimitReader.
		limit int64
		// want is the expected number of bytes returned by Read.
		want int
	}{{
		err:   nil,
		name:  "perfectly_match",
		rStr:  "abc",
		limit: 3,
		want:  3,
	}, {
		err:   io.EOF,
		name:  "eof",
		rStr:  "",
		limit: 3,
		want:  0,
	}, {
		err: &LimitReachedError{
			Limit: 0,
		},
		name:  "limit_reached",
		rStr:  "abc",
		limit: 0,
		want:  0,
	}, {
		err:   nil,
		name:  "truncated",
		rStr:  "abc",
		limit: 2,
		want:  2,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Build the reader inside the subtest so that a setup failure is
			// reported against this case only and doesn't abort the others.
			readCloser := io.NopCloser(strings.NewReader(tc.rStr))
			lreader, err := LimitReader(readCloser, tc.limit)
			require.NoError(t, err)
			require.NotNil(t, lreader)

			// The buffer is one byte larger than the limit, so that Read is
			// always asked for more data than the limit allows.
			buf := make([]byte, tc.limit+1)
			n, rerr := lreader.Read(buf)

			// testify expects the arguments in (expected, actual) order.
			require.Equal(t, tc.err, rerr)

			assert.Equal(t, tc.want, n)
		})
	}
}

func TestLimitedReader_LimitReachedError(t *testing.T) {
	testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
		Limit: 0,
	})
}