diff --git a/cmd/components.go b/cmd/components.go index 9358f885..92b0556c 100644 --- a/cmd/components.go +++ b/cmd/components.go @@ -39,6 +39,7 @@ func StartStopComponents(r *reporter.Reporter, daemonComponent daemon.Component, <-daemonComponent.Terminated() r.Info().Msg("stopping all components") + daemonComponent.FinishReexec() return nil } diff --git a/common/daemon/root.go b/common/daemon/root.go index 4caec74e..6a310fb5 100644 --- a/common/daemon/root.go +++ b/common/daemon/root.go @@ -9,6 +9,7 @@ package daemon import ( "os" "os/signal" + "sync/atomic" "syscall" "gopkg.in/tomb.v2" @@ -20,6 +21,8 @@ import ( type Component interface { Start() error Stop() error + Reexec() + FinishReexec() Track(t *tomb.Tomb, who string) // Lifecycle @@ -30,8 +33,9 @@ type Component interface { // realComponent is a non-mock implementation of the Component // interface. type realComponent struct { - r *reporter.Reporter - tombs []tombWithOrigin + r *reporter.Reporter + tombs []tombWithOrigin + shouldReexec atomic.Bool lifecycleComponent } @@ -70,11 +74,12 @@ func (c *realComponent) Start() error { c.Terminate() }(t) } - // On signal, terminate + // On signal, terminate or reexec go func() { signals := make(chan os.Signal, 1) signal.Notify(signals, - syscall.SIGINT, syscall.SIGTERM) + syscall.SIGINT, + syscall.SIGTERM) select { case s := <-signals: c.r.Debug().Stringer("signal", s).Msg("signal received") @@ -97,6 +102,30 @@ func (c *realComponent) Stop() error { return nil } +// Reexec will reexecute the current process with the same arguments. +func (c *realComponent) Reexec() { + c.shouldReexec.Store(true) + c.Terminate() +} + +// FinishReexec should be called just before exiting to trigger the real reexec. +func (c *realComponent) FinishReexec() { + if c.shouldReexec.Load() { + executable, err := os.Executable() + if err != nil { + c.r.Err(err).Msg("cannot get executable name") + return + } + + env := os.Environ() + args := append([]string{executable}, os.Args[1:]...) + c.r.Info().Strs("args", args).Msg("reexec in progress") + if err := syscall.Exec(executable, args, env); err != nil { + c.r.Err(err).Msg("cannot reexec") + } + } +} + // Add a new tomb to be tracked. This is only used before Start(). func (c *realComponent) Track(t *tomb.Tomb, who string) { c.tombs = append(c.tombs, tombWithOrigin{ diff --git a/common/daemon/root_test.go b/common/daemon/root_test.go index fe835638..2ea5e3a7 100644 --- a/common/daemon/root_test.go +++ b/common/daemon/root_test.go @@ -5,6 +5,7 @@ package daemon import ( "errors" + "os" "syscall" "testing" "testing/synctest" @@ -62,6 +63,28 @@ func TestTerminateWithSignal(t *testing.T) { } } +func TestReexecWithSignal(t *testing.T) { + + r := reporter.NewMock(t) + c, err := New(r) + if err != nil { + t.Fatalf("New() error:\n%+v", err) + } + helpers.StartStop(t, c) + + c.Reexec() + if os.Getenv("TEST_DAEMON_REEXEC") == "1" { + // This is a way to increase a bit coverage + executable, _ := os.Executable() + os.Remove(executable) + c.FinishReexec() + return + } + os.Setenv("TEST_DAEMON_REEXEC", "1") + c.FinishReexec() + t.Fatalf("No reexec done!") +} + func TestStop(t *testing.T) { r := reporter.NewMock(t) c, err := New(r) diff --git a/common/daemon/tests.go b/common/daemon/tests.go index 1e22e4ad..af9041bc 100644 --- a/common/daemon/tests.go +++ b/common/daemon/tests.go @@ -17,6 +17,8 @@ type MockComponent struct { lifecycleComponent } +var _ Component = &MockComponent{} + // NewMock will create a daemon component that does nothing. func NewMock(t testing.TB) Component { t.Helper() @@ -38,6 +40,14 @@ func (c *MockComponent) Stop() error { return nil } +// Reexec does nothing for the mock implementation +func (c *MockComponent) Reexec() { +} + +// FinishReexec does nothing for the mock implementation +func (c *MockComponent) FinishReexec() { +} + // Track does nothing func (c *MockComponent) Track(_ *tomb.Tomb, _ string) { }