AdGuardHome/internal/aghrenameio/renameio_test.go
2024-05-15 13:34:12 +03:00

100 lines
2.1 KiB
Go

package aghrenameio_test
import (
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testPerm is the common permission mode for tests.  It is the mode that the
// tests in this file pass to [aghrenameio.NewPendingFile].
const testPerm fs.FileMode = 0o644
// Common file data for tests.
var (
// initialData is the content that newInitialFile writes to the target file
// before a test modifies it.
initialData = []byte("initial data\n")
// newData is the content that the tests write into the pending file.
newData = []byte("new data\n")
)
// TestPendingFile checks that data written to a pending file created by
// [aghrenameio.NewPendingFile] is what the target file contains after the
// pending file has been closed with CloseReplace.
func TestPendingFile(t *testing.T) {
	t.Parallel()

	// Start from a target that already holds the initial data.
	path := newInitialFile(t)

	pending, err := aghrenameio.NewPendingFile(path, testPerm)
	require.NoError(t, err)

	_, err = pending.Write(newData)
	require.NoError(t, err)

	require.NoError(t, pending.CloseReplace())

	got, err := os.ReadFile(path)
	require.NoError(t, err)

	assert.Equal(t, newData, got)
}
// newInitialFile is a test helper that returns the path to the file containing
// [initialData].
func newInitialFile(t *testing.T) (targetPath string) {
t.Helper()
dir := t.TempDir()
targetPath = filepath.Join(dir, "target")
err := os.WriteFile(targetPath, initialData, 0o644)
require.NoError(t, err)
return targetPath
}
// TestWithDeferredCleanup checks the result of
// [aghrenameio.WithDeferredCleanup] for a pending file that already has the
// new data written to it.  With a nil error, the target file is expected to
// contain the new data.  With a non-nil error, the returned error is expected
// to carry that error's message and the target file is expected to still
// contain the initial data.
func TestWithDeferredCleanup(t *testing.T) {
	t.Parallel()

	const testError errors.Error = "test error"

	testCases := []struct {
		err        error
		name       string
		wantErrMsg string
		wantData   []byte
	}{{
		name:       "success",
		err:        nil,
		wantErrMsg: "",
		wantData:   newData,
	}, {
		name:       "error",
		err:        testError,
		wantErrMsg: testError.Error(),
		wantData:   initialData,
	}}

	for i := range testCases {
		// Give every parallel subtest its own copy of the case.
		tc := testCases[i]

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			path := newInitialFile(t)

			pending, err := aghrenameio.NewPendingFile(path, testPerm)
			require.NoError(t, err)

			_, err = pending.Write(newData)
			require.NoError(t, err)

			err = aghrenameio.WithDeferredCleanup(tc.err, pending)
			testutil.AssertErrorMsg(t, tc.wantErrMsg, err)

			got, err := os.ReadFile(path)
			require.NoError(t, err)

			assert.Equal(t, tc.wantData, got)
		})
	}
}