diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 45512ee..76d9e18 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -1,7 +1,8 @@
 # Release Notes
 
-## 3.4.2
+## 3.5.0
 
+* [Allow graceful shutdowns](https://code.forgejo.org/forgejo/runner/pulls/202): when receiving a signal (INT or TERM), wait for running jobs to complete (up to shutdown_timeout).
 * [Fix label declaration](https://code.forgejo.org/forgejo/runner/pulls/176): Runner in daemon mode now takes labels found in config.yml into account when declaration was successful.
 * [Fix the docker compose example](https://code.forgejo.org/forgejo/runner/pulls/175) to workaround the race on labels.
 * [Fix the kubernetes dind example](https://code.forgejo.org/forgejo/runner/pulls/169).
diff --git a/internal/app/cmd/daemon.go b/internal/app/cmd/daemon.go
index 8e47bf6..9eb7c6e 100644
--- a/internal/app/cmd/daemon.go
+++ b/internal/app/cmd/daemon.go
@@ -120,8 +120,18 @@ func runDaemon(ctx context.Context, configFile *string) func(cmd *cobra.Command,
 
 	poller := poll.New(cfg, cli, runner)
 
-	poller.Poll(ctx)
+	go poller.Poll()
+	<-ctx.Done()
+	log.Infof("runner: %s shutdown initiated, waiting [runner].shutdown_timeout=%s for running jobs to complete before shutting down", resp.Msg.Runner.Name, cfg.Runner.ShutdownTimeout)
+
+	ctx, cancel := context.WithTimeout(context.Background(), cfg.Runner.ShutdownTimeout)
+	defer cancel()
+
+	err = poller.Shutdown(ctx)
+	if err != nil {
+		log.Warnf("runner: %s cancelled in-progress jobs during shutdown", resp.Msg.Runner.Name)
+	}
 
 	return nil
 	}
 }
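
For this to work, the ctx that runDaemon blocks on must be cancelled when INT or TERM is delivered. Below is a minimal sketch of that wiring, assuming signal.NotifyContext from the standard library (Go 1.16+); runDaemon here is an illustrative stand-in, not the runner's actual cobra command wiring:

    package main

    import (
        "context"
        "os"
        "os/signal"
        "syscall"
    )

    // runDaemon stands in for the real daemon loop: it blocks until the
    // signal-aware context is cancelled, then starts the graceful shutdown.
    func runDaemon(ctx context.Context) {
        <-ctx.Done()
        // ... poller.Shutdown(...) would run here ...
    }

    func main() {
        // ctx is cancelled as soon as INT or TERM arrives, which is what
        // unblocks <-ctx.Done() in runDaemon above.
        ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
        defer stop() // restore the default signal handlers on exit

        runDaemon(ctx)
    }
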
diff --git a/internal/app/poll/poller.go b/internal/app/poll/poller.go
index f79e98e..6198fe0 100644
--- a/internal/app/poll/poller.go
+++ b/internal/app/poll/poller.go
@@ -20,49 +20,100 @@ import (
 	"gitea.com/gitea/act_runner/internal/pkg/config"
 )
 
-type Poller struct {
+const PollerID = "PollerID"
+
+type Poller interface {
+	Poll()
+	Shutdown(ctx context.Context) error
+}
+
+type poller struct {
 	client client.Client
-	runner *run.Runner
+	runner run.RunnerInterface
 	cfg    *config.Config
 
 	tasksVersion atomic.Int64 // tasksVersion used to store the version of the last task fetched from the Gitea.
+
+	pollingCtx      context.Context
+	shutdownPolling context.CancelFunc
+
+	jobsCtx      context.Context
+	shutdownJobs context.CancelFunc
+
+	done chan any
 }
 
-func New(cfg *config.Config, client client.Client, runner *run.Runner) *Poller {
-	return &Poller{
-		client: client,
-		runner: runner,
-		cfg:    cfg,
-	}
+func New(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+	return (&poller{}).init(cfg, client, runner)
 }
 
-func (p *Poller) Poll(ctx context.Context) {
+func (p *poller) init(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+	pollingCtx, shutdownPolling := context.WithCancel(context.Background())
+
+	jobsCtx, shutdownJobs := context.WithCancel(context.Background())
+
+	done := make(chan any)
+
+	p.client = client
+	p.runner = runner
+	p.cfg = cfg
+
+	p.pollingCtx = pollingCtx
+	p.shutdownPolling = shutdownPolling
+
+	p.jobsCtx = jobsCtx
+	p.shutdownJobs = shutdownJobs
+	p.done = done
+
+	return p
+}
+
+func (p *poller) Poll() {
 	limiter := rate.NewLimiter(rate.Every(p.cfg.Runner.FetchInterval), 1)
 	wg := &sync.WaitGroup{}
 	for i := 0; i < p.cfg.Runner.Capacity; i++ {
 		wg.Add(1)
-		go p.poll(ctx, wg, limiter)
+		go p.poll(i, wg, limiter)
 	}
 	wg.Wait()
+
+	// signal the poller is finished
+	close(p.done)
 }
 
-func (p *Poller) poll(ctx context.Context, wg *sync.WaitGroup, limiter *rate.Limiter) {
+func (p *poller) Shutdown(ctx context.Context) error {
+	p.shutdownPolling()
+
+	select {
+	case <-p.done:
+		log.Trace("all jobs are complete")
+		return nil
+
+	case <-ctx.Done():
+		log.Trace("forcing the jobs to shutdown")
+		p.shutdownJobs()
+		<-p.done
+		log.Trace("all jobs have been shutdown")
+		return ctx.Err()
+	}
+}
+
+func (p *poller) poll(id int, wg *sync.WaitGroup, limiter *rate.Limiter) {
+	log.Infof("[poller %d] launched", id)
 	defer wg.Done()
 	for {
-		if err := limiter.Wait(ctx); err != nil {
-			if ctx.Err() != nil {
-				log.WithError(err).Debug("limiter wait failed")
-			}
+		if err := limiter.Wait(p.pollingCtx); err != nil {
+			log.Infof("[poller %d] shutdown", id)
 			return
 		}
-		task, ok := p.fetchTask(ctx)
+		task, ok := p.fetchTask(p.pollingCtx)
 		if !ok {
 			continue
 		}
-		p.runTaskWithRecover(ctx, task)
+		p.runTaskWithRecover(p.jobsCtx, task)
 	}
 }
 
-func (p *Poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) {
+func (p *poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) {
 	defer func() {
 		if r := recover(); r != nil {
 			err := fmt.Errorf("panic: %v", r)
@@ -75,7 +126,7 @@ func (p *Poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) {
 	}
 }
 
-func (p *Poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) {
+func (p *poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) {
 	reqCtx, cancel := context.WithTimeout(ctx, p.cfg.Runner.FetchTimeout)
 	defer cancel()
 
@@ -85,10 +136,15 @@ func (p *Poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) {
 		TasksVersion: v,
 	}))
 	if errors.Is(err, context.DeadlineExceeded) {
+		log.Trace("deadline exceeded")
 		err = nil
 	}
 	if err != nil {
-		log.WithError(err).Error("failed to fetch task")
+		if errors.Is(err, context.Canceled) {
+			log.WithError(err).Debugf("shutdown, fetch task canceled")
+		} else {
+			log.WithError(err).Error("failed to fetch task")
+		}
 		return nil, false
 	}
 
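
The shutdown logic above is a two-phase pattern: cancelling pollingCtx stops new tasks from being fetched, the done channel reports when every poll goroutine has drained, and jobsCtx is only cancelled if the grace period runs out. A self-contained sketch of the same pattern, with illustrative names (miniPoller, runJob) that are not part of the patch:

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    type miniPoller struct {
        pollingCtx      context.Context // phase 1: cancelled to stop fetching new work
        shutdownPolling context.CancelFunc
        jobsCtx         context.Context // phase 2: cancelled to abort running work
        shutdownJobs    context.CancelFunc
        done            chan struct{} // closed once the poll loop has fully exited
    }

    func newMiniPoller() *miniPoller {
        p := &miniPoller{done: make(chan struct{})}
        p.pollingCtx, p.shutdownPolling = context.WithCancel(context.Background())
        p.jobsCtx, p.shutdownJobs = context.WithCancel(context.Background())
        return p
    }

    func (p *miniPoller) Poll() {
        defer close(p.done)
        for {
            select {
            case <-p.pollingCtx.Done():
                return // stop picking up new jobs
            case <-time.After(10 * time.Millisecond): // pretend a task was fetched
                runJob(p.jobsCtx) // a running job blocks the loop, like runTaskWithRecover
            }
        }
    }

    func runJob(ctx context.Context) {
        select {
        case <-ctx.Done():
            fmt.Println("job cancelled")
        case <-time.After(50 * time.Millisecond):
            fmt.Println("job finished")
        }
    }

    func (p *miniPoller) Shutdown(ctx context.Context) error {
        p.shutdownPolling() // phase 1: no new jobs
        select {
        case <-p.done: // running jobs drained within the grace period
            return nil
        case <-ctx.Done(): // grace period expired
            p.shutdownJobs() // phase 2: cancel the jobs still running
            <-p.done
            return ctx.Err()
        }
    }

    func main() {
        p := newMiniPoller()
        go p.Poll()
        time.Sleep(20 * time.Millisecond) // let one job start

        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
        fmt.Println("shutdown:", p.Shutdown(ctx))
    }
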
diff --git a/internal/app/poll/poller_test.go b/internal/app/poll/poller_test.go
new file mode 100644
index 0000000..2fdd8d6
--- /dev/null
+++ b/internal/app/poll/poller_test.go
@@ -0,0 +1,263 @@
+// Copyright The Forgejo Authors.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/bufbuild/connect-go"
+
+	"code.gitea.io/actions-proto-go/ping/v1/pingv1connect"
+	runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+	"code.gitea.io/actions-proto-go/runner/v1/runnerv1connect"
+	"gitea.com/gitea/act_runner/internal/pkg/config"
+
+	log "github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
+)
+
+type mockPoller struct {
+	poller
+}
+
+func (o *mockPoller) Poll() {
+	o.poller.Poll()
+}
+
+type mockClient struct {
+	pingv1connect.PingServiceClient
+	runnerv1connect.RunnerServiceClient
+
+	sleep  time.Duration
+	cancel bool
+	err    error
+	noTask bool
+}
+
+func (o mockClient) Address() string {
+	return ""
+}
+
+func (o mockClient) Insecure() bool {
+	return true
+}
+
+func (o *mockClient) FetchTask(ctx context.Context, req *connect.Request[runnerv1.FetchTaskRequest]) (*connect.Response[runnerv1.FetchTaskResponse], error) {
+	if o.sleep > 0 {
+		select {
+		case <-ctx.Done():
+			log.Trace("fetch task done")
+			return nil, context.DeadlineExceeded
+		case <-time.After(o.sleep):
+			log.Trace("slept")
+			return nil, fmt.Errorf("unexpected")
+		}
+	}
+	if o.cancel {
+		return nil, context.Canceled
+	}
+	if o.err != nil {
+		return nil, o.err
+	}
+	task := &runnerv1.Task{}
+	if o.noTask {
+		task = nil
+		o.noTask = false
+	}
+
+	return connect.NewResponse(&runnerv1.FetchTaskResponse{
+		Task:         task,
+		TasksVersion: int64(1),
+	}), nil
+}
+
+type mockRunner struct {
+	cfg    *config.Runner
+	log    chan string
+	panics bool
+	err    error
+}
+
+func (o *mockRunner) Run(ctx context.Context, task *runnerv1.Task) error {
+	o.log <- "runner starts"
+	if o.panics {
+		log.Trace("panics")
+		o.log <- "runner panics"
+		o.panics = false
+		panic("whatever")
+	}
+	if o.err != nil {
+		log.Trace("error")
+		o.log <- "runner error"
+		err := o.err
+		o.err = nil
+		return err
+	}
+	for {
+		select {
+		case <-ctx.Done():
+			log.Trace("shutdown")
+			o.log <- "runner shutdown"
+			return nil
+		case <-time.After(o.cfg.Timeout):
+			log.Trace("after")
+			o.log <- "runner timeout"
+			return nil
+		}
+	}
+}
+
+func setTrace(t *testing.T) {
+	t.Helper()
+	log.SetReportCaller(true)
+	log.SetLevel(log.TraceLevel)
+}
+
+func TestPoller_New(t *testing.T) {
+	p := New(&config.Config{}, &mockClient{}, &mockRunner{})
+	assert.NotNil(t, p)
+}
+
+func TestPoller_Runner(t *testing.T) {
+	setTrace(t)
+	for _, testCase := range []struct {
+		name           string
+		timeout        time.Duration
+		noTask         bool
+		panics         bool
+		err            error
+		expected       string
+		contextTimeout time.Duration
+	}{
+		{
+			name:     "Simple",
+			timeout:  10 * time.Second,
+			expected: "runner shutdown",
+		},
+		{
+			name:     "Panics",
+			timeout:  10 * time.Second,
+			panics:   true,
+			expected: "runner panics",
+		},
+		{
+			name:     "Error",
+			timeout:  10 * time.Second,
+			err:      fmt.Errorf("ERROR"),
+			expected: "runner error",
+		},
+		{
+			name:     "PollTaskError",
+			timeout:  10 * time.Second,
+			noTask:   true,
+			expected: "runner shutdown",
+		},
+		{
+			name:           "ShutdownTimeout",
+			timeout:        1 * time.Second,
+			contextTimeout: 1 * time.Minute,
+			expected:       "runner timeout",
+		},
+	} {
+		t.Run(testCase.name, func(t *testing.T) {
+			runnerLog := make(chan string, 3)
+			configRunner := config.Runner{
+				FetchInterval: 1,
+				Capacity:      1,
+				Timeout:       testCase.timeout,
+			}
+			p := &mockPoller{}
+			p.init(
+				&config.Config{
+					Runner: configRunner,
+				},
+				&mockClient{
+					noTask: testCase.noTask,
+				},
+				&mockRunner{
+					cfg:    &configRunner,
+					log:    runnerLog,
+					panics: testCase.panics,
+					err:    testCase.err,
+				})
+			go p.Poll()
+
assert.Equal(t, "runner starts", <-runnerLog) + var ctx context.Context + var cancel context.CancelFunc + if testCase.contextTimeout > 0 { + ctx, cancel = context.WithTimeout(context.Background(), testCase.contextTimeout) + defer cancel() + } else { + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } + p.Shutdown(ctx) + <-p.done + assert.Equal(t, testCase.expected, <-runnerLog) + }) + } +} + +func TestPoller_Fetch(t *testing.T) { + setTrace(t) + for _, testCase := range []struct { + name string + noTask bool + sleep time.Duration + err error + cancel bool + success bool + }{ + { + name: "Success", + success: true, + }, + { + name: "Timeout", + sleep: 100 * time.Millisecond, + }, + { + name: "Canceled", + cancel: true, + }, + { + name: "NoTask", + noTask: true, + }, + { + name: "Error", + err: fmt.Errorf("random error"), + }, + } { + t.Run(testCase.name, func(t *testing.T) { + configRunner := config.Runner{ + FetchTimeout: 1 * time.Millisecond, + } + p := &mockPoller{} + p.init( + &config.Config{ + Runner: configRunner, + }, + &mockClient{ + sleep: testCase.sleep, + cancel: testCase.cancel, + noTask: testCase.noTask, + err: testCase.err, + }, + &mockRunner{}, + ) + task, ok := p.fetchTask(context.Background()) + if testCase.success { + assert.True(t, ok) + assert.NotNil(t, task) + } else { + assert.False(t, ok) + assert.Nil(t, task) + } + }) + } +} diff --git a/internal/app/run/runner.go b/internal/app/run/runner.go index 69001d1..b17705d 100644 --- a/internal/app/run/runner.go +++ b/internal/app/run/runner.go @@ -41,6 +41,10 @@ type Runner struct { runningTasks sync.Map } +type RunnerInterface interface { + Run(ctx context.Context, task *runnerv1.Task) error +} + func NewRunner(cfg *config.Config, reg *config.Registration, cli client.Client) *Runner { ls := labels.Labels{} for _, v := range reg.Labels { diff --git a/internal/pkg/config/config.example.yaml b/internal/pkg/config/config.example.yaml index bc26489..fa40f5c 100644 --- a/internal/pkg/config/config.example.yaml +++ b/internal/pkg/config/config.example.yaml @@ -23,7 +23,13 @@ runner: # Please note that the Forgejo instance also has a timeout (3h by default) for the job. # So the job could be stopped by the Forgejo instance if it's timeout is shorter than this. timeout: 3h - # Whether skip verifying the TLS certificate of the Forgejo instance. + # The timeout for the runner to wait for running jobs to finish when + # shutting down because a TERM or INT signal has been received. Any + # running jobs that haven't finished after this timeout will be + # cancelled. + # If unset or zero the jobs will be cancelled immediately. + shutdown_timeout: 3h + # Whether skip verifying the TLS certificate of the instance. insecure: false # The timeout for fetching the job from the Forgejo instance. fetch_timeout: 5s diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index a7bb977..5c260fb 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -21,15 +21,16 @@ type Log struct { // Runner represents the configuration for the runner. type Runner struct { - File string `yaml:"file"` // File specifies the file path for the runner. - Capacity int `yaml:"capacity"` // Capacity specifies the capacity of the runner. - Envs map[string]string `yaml:"envs"` // Envs stores environment variables for the runner. - EnvFile string `yaml:"env_file"` // EnvFile specifies the path to the file containing environment variables for the runner. 
diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go
index a7bb977..5c260fb 100644
--- a/internal/pkg/config/config.go
+++ b/internal/pkg/config/config.go
@@ -21,15 +21,16 @@ type Log struct {
 
 // Runner represents the configuration for the runner.
 type Runner struct {
-	File          string            `yaml:"file"`           // File specifies the file path for the runner.
-	Capacity      int               `yaml:"capacity"`       // Capacity specifies the capacity of the runner.
-	Envs          map[string]string `yaml:"envs"`           // Envs stores environment variables for the runner.
-	EnvFile       string            `yaml:"env_file"`       // EnvFile specifies the path to the file containing environment variables for the runner.
-	Timeout       time.Duration     `yaml:"timeout"`        // Timeout specifies the duration for runner timeout.
-	Insecure      bool              `yaml:"insecure"`       // Insecure indicates whether the runner operates in an insecure mode.
-	FetchTimeout  time.Duration     `yaml:"fetch_timeout"`  // FetchTimeout specifies the timeout duration for fetching resources.
-	FetchInterval time.Duration     `yaml:"fetch_interval"` // FetchInterval specifies the interval duration for fetching resources.
-	Labels        []string          `yaml:"labels"`         // Labels specifies the labels of the runner. Labels are declared on each startup
+	File            string            `yaml:"file"`             // File specifies the file path for the runner.
+	Capacity        int               `yaml:"capacity"`         // Capacity specifies the capacity of the runner.
+	Envs            map[string]string `yaml:"envs"`             // Envs stores environment variables for the runner.
+	EnvFile         string            `yaml:"env_file"`         // EnvFile specifies the path to the file containing environment variables for the runner.
+	Timeout         time.Duration     `yaml:"timeout"`          // Timeout specifies the duration for runner timeout.
+	ShutdownTimeout time.Duration     `yaml:"shutdown_timeout"` // ShutdownTimeout specifies the duration to wait for running jobs to complete during a shutdown of the runner.
+	Insecure        bool              `yaml:"insecure"`         // Insecure indicates whether the runner operates in an insecure mode.
+	FetchTimeout    time.Duration     `yaml:"fetch_timeout"`    // FetchTimeout specifies the timeout duration for fetching resources.
+	FetchInterval   time.Duration     `yaml:"fetch_interval"`   // FetchInterval specifies the interval duration for fetching resources.
+	Labels          []string          `yaml:"labels"`           // Labels specify the labels of the runner. Labels are declared on each startup
 }
 
 // Cache represents the configuration for caching.
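
One consequence of the new shutdown_timeout option worth spelling out: as the example config says, if it is unset or zero the jobs are cancelled immediately, because context.WithTimeout with a zero duration returns an already-expired context, so Shutdown takes its forced-cancellation branch at once. A standalone sketch of that behaviour:

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    func main() {
        var shutdownTimeout time.Duration // zero value, as when shutdown_timeout is unset

        ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
        defer cancel()

        select {
        case <-ctx.Done():
            // The deadline is already in the past, so this branch is taken
            // immediately: the poller would cancel running jobs at once.
            fmt.Println("expired immediately:", ctx.Err())
        case <-time.After(time.Second):
            fmt.Println("not reached")
        }
    }
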