128 lines
2.0 KiB
Go
128 lines
2.0 KiB
Go
|
package graceful
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"os"
|
||
|
"os/signal"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/rs/zerolog/log"
|
||
|
"go.uber.org/multierr"
|
||
|
)
|
||
|
|
||
|
func WithInterupt(ctx context.Context) context.Context {
|
||
|
log := log.Ctx(ctx)
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
|
||
|
// Listen for Interrupt signals
|
||
|
c := make(chan os.Signal, 1)
|
||
|
signal.Notify(c, os.Interrupt)
|
||
|
|
||
|
go func() {
|
||
|
defer signal.Stop(c)
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-c:
|
||
|
cancel()
|
||
|
log.Warn().Msg("Shutting down! interrupt received")
|
||
|
return
|
||
|
case <-ctx.Done():
|
||
|
log.Warn().Msg("Shutting down! context cancelled")
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
return ctx
|
||
|
}
|
||
|
|
||
|
// contextKey is the key type for values this package stores in a
// context.Context. The type is unexported, so keys created here cannot
// compare equal to keys created by other packages.
type contextKey struct {
	string
}

// wgKey is the context key under which WithWaitGroup stores its *wgContext.
var wgKey = contextKey{string: "waitgroup"}
|
||
|
|
||
|
type wgContext struct {
|
||
|
wg sync.WaitGroup
|
||
|
err error
|
||
|
ctx context.Context
|
||
|
}
|
||
|
|
||
|
func (wg *wgContext) String() string {
|
||
|
return fmt.Sprintf("WaitGroup[%v %v]", wg.err, wg.ctx)
|
||
|
}
|
||
|
|
||
|
type WG interface {
|
||
|
Wait(time.Duration) error
|
||
|
Go(func() error)
|
||
|
}
|
||
|
|
||
|
func WithWaitGroup(ctx context.Context) (context.Context, WG) {
|
||
|
if wg := WaitGroup(ctx); wg != nil {
|
||
|
return ctx, wg
|
||
|
}
|
||
|
wg := &wgContext{ctx: ctx}
|
||
|
return context.WithValue(ctx, wgKey, wg), wg
|
||
|
}
|
||
|
|
||
|
func WaitGroup(ctx context.Context) *wgContext {
|
||
|
if wg, ok := ctx.Value(wgKey).(*wgContext); ok {
|
||
|
return wg
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (wg *wgContext) Go(fn func() error) {
|
||
|
if wg == nil {
|
||
|
panic("nil wait group")
|
||
|
}
|
||
|
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
err := fn()
|
||
|
wg.err = multierr.Append(wg.err, err)
|
||
|
wg.Done()
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
func (wg *wgContext) Add(n int) {
|
||
|
wg.wg.Add(n)
|
||
|
}
|
||
|
|
||
|
func (wg *wgContext) Done() {
|
||
|
wg.wg.Done()
|
||
|
}
|
||
|
|
||
|
func (wg *wgContext) Wait(gracetime time.Duration) error {
|
||
|
if wg == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
log := log.Ctx(wg.ctx)
|
||
|
|
||
|
ch := make(chan struct{})
|
||
|
go func() {
|
||
|
wg.wg.Wait()
|
||
|
close(ch)
|
||
|
}()
|
||
|
|
||
|
<-wg.ctx.Done()
|
||
|
wg.err = multierr.Append(wg.err, wg.ctx.Err())
|
||
|
|
||
|
log.Debug().Msg("shutdown begin")
|
||
|
timer := time.NewTimer(gracetime)
|
||
|
|
||
|
select {
|
||
|
case <-ch:
|
||
|
case <-timer.C:
|
||
|
wg.err = multierr.Append(wg.err, ErrExpiredGrace)
|
||
|
}
|
||
|
log.Debug().Msg("shutdown complete")
|
||
|
|
||
|
return wg.err
|
||
|
}
|
||
|
|
||
|
var (
	// ErrExpiredGrace is added to the error returned by Wait when tracked
	// goroutines have not finished before the grace period elapses.
	ErrExpiredGrace = errors.New("grace time expired")
)
|