diff --git a/internal/home/config.go b/internal/home/config.go index 810ec24e..8ddb691d 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -162,6 +162,10 @@ type configuration struct { // SchemaVersion is the version of the configuration schema. See // [configmigrate.LastSchemaVersion]. SchemaVersion uint `yaml:"schema_version"` + + // UnsafeCustomUpdateIndexURL is the URL to the custom update index. It's + // only used in testing purposes and should not be used in release. + UnsafeCustomUpdateIndexURL bool `yaml:"unsafe_custom_update_index_url,omitempty"` } // httpConfig is a block with HTTP configuration params. diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 50a1a6f3..acf3e0dd 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -77,28 +77,28 @@ func (web *webAPI) requestVersionInfo(resp *versionResponse, recheck bool) (err updater := web.conf.updater for i := 0; i != 3; i++ { resp.VersionInfo, err = updater.VersionInfo(recheck) - if err != nil { - var terr temporaryError - if errors.As(err, &terr) && terr.Temporary() { - // Temporary network error. This case may happen while we're - // restarting our DNS server. Log and sleep for some time. - // - // See https://github.com/AdguardTeam/AdGuardHome/issues/934. - d := time.Duration(i) * time.Second - log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d) - time.Sleep(d) + if err == nil { + return nil + } - continue - } + var terr temporaryError + if errors.As(err, &terr) && terr.Temporary() { + // Temporary network error. This case may happen while we're + // restarting our DNS server. Log and sleep for some time. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/934. + d := time.Duration(i) * time.Second + log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d) + time.Sleep(d) + + continue } break } if err != nil { - vcu := updater.VersionCheckURL() - - return fmt.Errorf("getting version info from %s: %w", vcu, err) + return fmt.Errorf("getting version info: %w", err) } return nil diff --git a/internal/home/home.go b/internal/home/home.go index 4c2771e4..d6f11489 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -604,29 +604,9 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { execPath, err := os.Executable() fatalOnError(errors.Annotate(err, "getting executable path: %w")) - u := &url.URL{ - Scheme: urlutil.SchemeHTTPS, - // TODO(a.garipov): Make configurable. - Host: "static.adtidy.org", - Path: path.Join("adguardhome", version.Channel(), "version.json"), - } - confPath := configFilePath() - log.Debug("using config path %q for updater", confPath) - upd := updater.NewUpdater(&updater.Config{ - Client: config.Filtering.HTTPClient, - Version: version.Version(), - Channel: version.Channel(), - GOARCH: runtime.GOARCH, - GOOS: runtime.GOOS, - GOARM: version.GOARM(), - GOMIPS: version.GOMIPS(), - WorkDir: Context.workDir, - ConfName: confPath, - ExecPath: execPath, - VersionCheckURL: u.String(), - }) + upd := newUpdater(Context.workDir, confPath, execPath, config) // TODO(e.burkov): This could be made earlier, probably as the option's // effect. @@ -698,6 +678,48 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { <-done } +// newUpdater creates a new AdGuard Home updater. +func newUpdater(workDir, confPath, execPath string, config *configuration) (upd *updater.Updater) { + // envName is the name of the environment variable that can be used to + // override the default version check URL. + const envName = "ADGUARD_HOME_TEST_UPDATE_VERSION_URL" + + var versionURL *url.URL + if version.Channel() == version.ChannelRelease || !config.UnsafeCustomUpdateIndexURL { + // Go on, use the default URL. + } else if versionURLStr, ok := os.LookupEnv(envName); ok { + var err error + versionURL, err = url.Parse(versionURLStr) + if err != nil { + log.Error(envName+" is not a valid URL: %s", err) + } + } + + if versionURL == nil { + versionURL = &url.URL{ + Scheme: urlutil.SchemeHTTPS, + Host: "static.adtidy.org", + Path: path.Join("adguardhome", version.Channel(), "version.json"), + } + } + + log.Debug("using config path %q for updater", confPath) + + return updater.NewUpdater(&updater.Config{ + Client: config.Filtering.HTTPClient, + Version: version.Version(), + Channel: version.Channel(), + GOARCH: runtime.GOARCH, + GOOS: runtime.GOOS, + GOARM: version.GOARM(), + GOMIPS: version.GOMIPS(), + WorkDir: workDir, + ConfName: confPath, + ExecPath: execPath, + VersionCheckURL: versionURL, + }) +} + // initUsers initializes context auth module. Clears config users field. func initUsers() (auth *Auth, err error) { sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") @@ -1018,8 +1040,7 @@ func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) { info, err := upd.VersionInfo(true) if err != nil { - vcu := upd.VersionCheckURL() - log.Error("getting version info from %s: %s", vcu, err) + log.Error("getting version info: %s", err) os.Exit(1) } diff --git a/internal/updater/check.go b/internal/updater/check.go index 84da6281..2a3e2cfe 100644 --- a/internal/updater/check.go +++ b/internal/updater/check.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" + "github.com/c2h5oh/datasize" ) // TODO(a.garipov): Make configurable. @@ -28,8 +29,9 @@ type VersionInfo struct { CanAutoUpdate aghalg.NullBool `json:"can_autoupdate,omitempty"` } -// MaxResponseSize is responses on server's requests maximum length in bytes. -const MaxResponseSize = 64 * 1024 +// maxVersionRespSize is the maximum length in bytes for version information +// response. +const maxVersionRespSize datasize.ByteSize = 64 * datasize.KB // VersionInfo downloads the latest version information. If forceRecheck is // false and there are cached results, those results are returned. @@ -51,7 +53,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() - r := ioutil.LimitReader(resp.Body, MaxResponseSize) + r := ioutil.LimitReader(resp.Body, maxVersionRespSize.Bytes()) // This use of ReadAll is safe, because we just limited the appropriate // ReadCloser. diff --git a/internal/updater/check_test.go b/internal/updater/check_test.go index e61ba443..5a7c0f5d 100644 --- a/internal/updater/check_test.go +++ b/internal/updater/check_test.go @@ -51,9 +51,11 @@ func TestUpdater_VersionInfo(t *testing.T) { })) t.Cleanup(srv.Close) - fakeURL, err := url.JoinPath(srv.URL, "adguardhome", version.ChannelBeta, "version.json") + srvURL, err := url.Parse(srv.URL) require.NoError(t, err) + fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json") + u := updater.NewUpdater(&updater.Config{ Client: srv.Client(), Version: "v0.103.0-beta.1", @@ -134,7 +136,7 @@ func TestUpdater_VersionInfo_others(t *testing.T) { GOARCH: tc.arch, GOARM: tc.arm, GOMIPS: tc.mips, - VersionCheckURL: fakeURL.String(), + VersionCheckURL: fakeURL, }) info, err := u.VersionInfo(false) diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 85cb69b0..704ba990 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -9,6 +9,7 @@ import ( "io" "io/fs" "net/http" + "net/url" "os" "os/exec" "path/filepath" @@ -65,6 +66,9 @@ type Updater struct { type Config struct { Client *http.Client + // VersionCheckURL is url to the latest version announcement. + VersionCheckURL *url.URL + Version string Channel string GOARCH string @@ -81,9 +85,6 @@ type Config struct { // ExecPath is path to the executable file. ExecPath string - - // VersionCheckURL is url to the latest version announcement. - VersionCheckURL string } // NewUpdater creates a new Updater. @@ -101,7 +102,7 @@ func NewUpdater(conf *Config) *Updater { confName: conf.ConfName, workDir: conf.WorkDir, execPath: conf.ExecPath, - versionCheckURL: conf.VersionCheckURL, + versionCheckURL: conf.VersionCheckURL.String(), mu: &sync.RWMutex{}, } @@ -167,14 +168,6 @@ func (u *Updater) NewVersion() (nv string) { return u.newVersion } -// VersionCheckURL returns the version check URL. -func (u *Updater) VersionCheckURL() (vcu string) { - u.mu.RLock() - defer u.mu.RUnlock() - - return u.versionCheckURL -} - // prepare fills all necessary fields in Updater object. func (u *Updater) prepare() (err error) { u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion)) diff --git a/internal/updater/updater_internal_test.go b/internal/updater/updater_internal_test.go index e233db3d..0ed94641 100644 --- a/internal/updater/updater_internal_test.go +++ b/internal/updater/updater_internal_test.go @@ -1,6 +1,7 @@ package updater import ( + "net/url" "os" "path/filepath" "testing" @@ -59,6 +60,9 @@ func TestUpdater_internal(t *testing.T) { ExecPath: exePath, WorkDir: wd, ConfName: yamlPath, + // TODO(e.burkov): Rewrite the test to use a fake version check + // URL with a fake URLs for the package files. + VersionCheckURL: &url.URL{}, }) u.newVersion = "v0.103.1" @@ -72,36 +76,40 @@ func TestUpdater_internal(t *testing.T) { u.clean() - // check backup files - d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml")) - require.NoError(t, err) + t.Run("backup", func(t *testing.T) { + var d []byte + d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml")) + require.NoError(t, err) - assert.Equal(t, "AdGuardHome.yaml", string(d)) + assert.Equal(t, "AdGuardHome.yaml", string(d)) - d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName)) - require.NoError(t, err) + d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName)) + require.NoError(t, err) - assert.Equal(t, tc.exeName, string(d)) + assert.Equal(t, tc.exeName, string(d)) + }) - // check updated files - d, err = os.ReadFile(exePath) - require.NoError(t, err) + t.Run("updated", func(t *testing.T) { + var d []byte + d, err = os.ReadFile(exePath) + require.NoError(t, err) - assert.Equal(t, "1", string(d)) + assert.Equal(t, "1", string(d)) - d, err = os.ReadFile(readmePath) - require.NoError(t, err) + d, err = os.ReadFile(readmePath) + require.NoError(t, err) - assert.Equal(t, "2", string(d)) + assert.Equal(t, "2", string(d)) - d, err = os.ReadFile(licensePath) - require.NoError(t, err) + d, err = os.ReadFile(licensePath) + require.NoError(t, err) - assert.Equal(t, "3", string(d)) + assert.Equal(t, "3", string(d)) - d, err = os.ReadFile(yamlPath) - require.NoError(t, err) + d, err = os.ReadFile(yamlPath) + require.NoError(t, err) - assert.Equal(t, "AdGuardHome.yaml", string(d)) + assert.Equal(t, "AdGuardHome.yaml", string(d)) + }) } } diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index 4af567c0..735d9c99 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -65,7 +65,10 @@ func TestUpdater_Update(t *testing.T) { srv := httptest.NewServer(mux) t.Cleanup(srv.Close) - versionCheckURL, err := url.JoinPath(srv.URL, versionPath) + srvURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + versionCheckURL := srvURL.JoinPath(versionPath) require.NoError(t, err) u := updater.NewUpdater(&updater.Config{