diff --git a/core/proc/shutdown.go b/core/proc/shutdown.go index d5072cb15826..e3fa7174088a 100644 --- a/core/proc/shutdown.go +++ b/core/proc/shutdown.go @@ -13,19 +13,26 @@ import ( "github.com/zeromicro/go-zero/core/threading" ) -type ProcConf struct { - WrapUpTime time.Duration `json:",default=1s"` - WaitTime time.Duration `json:",default=5.5s"` -} +const ( + defaultWrapUpTime = time.Second + // why we use 5500 milliseconds is because most of our queue are blocking mode with 5 seconds + defaultWaitTime = 5500 * time.Millisecond +) var ( wrapUpListeners = new(listenerManager) shutdownListeners = new(listenerManager) - wrapUpTime = time.Second - // why we use 5500 milliseconds is because most of our queue are blocking mode with 5 seconds - delayTimeBeforeForceQuit = 5500 * time.Millisecond + wrapUpTime = defaultWrapUpTime + waitTime = defaultWaitTime + shutdownLock sync.Mutex ) +// ShutdownConf defines the shutdown configuration for the process. +type ShutdownConf struct { + WrapUpTime time.Duration `json:",default=1s"` + WaitTime time.Duration `json:",default=5.5s"` +} + // AddShutdownListener adds fn as a shutdown listener. // The returned func can be used to wait for fn getting called. func AddShutdownListener(fn func()) (waitForCalled func()) { @@ -40,12 +47,21 @@ func AddWrapUpListener(fn func()) (waitForCalled func()) { // SetTimeToForceQuit sets the waiting time before force quitting. func SetTimeToForceQuit(duration time.Duration) { - delayTimeBeforeForceQuit = duration + shutdownLock.Lock() + defer shutdownLock.Unlock() + waitTime = duration } -func Setup(conf ProcConf) { - wrapUpTime = conf.WrapUpTime - delayTimeBeforeForceQuit = conf.WaitTime +func Setup(conf ShutdownConf) { + shutdownLock.Lock() + defer shutdownLock.Unlock() + + if conf.WrapUpTime > 0 { + wrapUpTime = conf.WrapUpTime + } + if conf.WaitTime > 0 { + waitTime = conf.WaitTime + } } // Shutdown calls the registered shutdown listeners, only for test purpose. @@ -67,8 +83,12 @@ func gracefulStop(signals chan os.Signal, sig syscall.Signal) { time.Sleep(wrapUpTime) go shutdownListeners.notifyListeners() - time.Sleep(delayTimeBeforeForceQuit - wrapUpTime) - logx.Infof("Still alive after %v, going to force kill the process...", delayTimeBeforeForceQuit) + shutdownLock.Lock() + remainingTime := waitTime - wrapUpTime + shutdownLock.Unlock() + + time.Sleep(remainingTime) + logx.Infof("Still alive after %v, going to force kill the process...", waitTime) _ = syscall.Kill(syscall.Getpid(), sig) } diff --git a/core/proc/shutdown_test.go b/core/proc/shutdown_test.go index d5f5869bd67c..4f3aed59279e 100644 --- a/core/proc/shutdown_test.go +++ b/core/proc/shutdown_test.go @@ -11,8 +11,12 @@ import ( ) func TestShutdown(t *testing.T) { + t.Cleanup(restoreSettings) + SetTimeToForceQuit(time.Hour) - assert.Equal(t, time.Hour, delayTimeBeforeForceQuit) + shutdownLock.Lock() + assert.Equal(t, time.Hour, waitTime) + shutdownLock.Unlock() var val int called := AddWrapUpListener(func() { @@ -31,8 +35,12 @@ func TestShutdown(t *testing.T) { } func TestShutdownWithMultipleServices(t *testing.T) { + t.Cleanup(restoreSettings) + SetTimeToForceQuit(time.Hour) - assert.Equal(t, time.Hour, delayTimeBeforeForceQuit) + shutdownLock.Lock() + assert.Equal(t, time.Hour, waitTime) + shutdownLock.Unlock() var val int32 called1 := AddShutdownListener(func() { @@ -49,8 +57,12 @@ func TestShutdownWithMultipleServices(t *testing.T) { } func TestWrapUpWithMultipleServices(t *testing.T) { + t.Cleanup(restoreSettings) + SetTimeToForceQuit(time.Hour) - assert.Equal(t, time.Hour, delayTimeBeforeForceQuit) + shutdownLock.Lock() + assert.Equal(t, time.Hour, waitTime) + shutdownLock.Unlock() var val int32 called1 := AddWrapUpListener(func() { @@ -67,6 +79,8 @@ func TestWrapUpWithMultipleServices(t *testing.T) { } func TestNotifyMoreThanOnce(t *testing.T) { + t.Cleanup(restoreSettings) + ch := make(chan struct{}, 1) go func() { @@ -97,10 +111,36 @@ func TestNotifyMoreThanOnce(t *testing.T) { } func TestSetup(t *testing.T) { - Setup(ProcConf{ - WrapUpTime: time.Second * 2, - WaitTime: time.Second * 30, + t.Run("valid time", func(t *testing.T) { + defer restoreSettings() + + Setup(ShutdownConf{ + WrapUpTime: time.Second * 2, + WaitTime: time.Second * 30, + }) + + shutdownLock.Lock() + assert.Equal(t, time.Second*2, wrapUpTime) + assert.Equal(t, time.Second*30, waitTime) + shutdownLock.Unlock() }) - assert.Equal(t, time.Second*2, wrapUpTime) - assert.Equal(t, time.Second*30, delayTimeBeforeForceQuit) + + t.Run("valid time", func(t *testing.T) { + defer restoreSettings() + + Setup(ShutdownConf{}) + + shutdownLock.Lock() + assert.Equal(t, defaultWrapUpTime, wrapUpTime) + assert.Equal(t, defaultWaitTime, waitTime) + shutdownLock.Unlock() + }) +} + +func restoreSettings() { + shutdownLock.Lock() + defer shutdownLock.Unlock() + + wrapUpTime = defaultWrapUpTime + waitTime = defaultWaitTime } diff --git a/core/service/serviceconf.go b/core/service/serviceconf.go index fc66e91e281b..bc3ab967c5f2 100644 --- a/core/service/serviceconf.go +++ b/core/service/serviceconf.go @@ -37,7 +37,7 @@ type ( Prometheus prometheus.Config `json:",optional"` Telemetry trace.Config `json:",optional"` DevServer DevServerConfig `json:",optional"` - Proc proc.ProcConf `json:",optional"` + Shutdown proc.ShutdownConf `json:",optional"` } ) @@ -62,7 +62,7 @@ func (sc ServiceConf) SetUp() error { sc.Telemetry.Name = sc.Name } trace.StartAgent(sc.Telemetry) - proc.Setup(sc.Proc) + proc.Setup(sc.Shutdown) proc.AddShutdownListener(func() { trace.StopAgent() })