From fcc5d08aa7b8743ec85a9db278bdd74ee756c69e Mon Sep 17 00:00:00 2001 From: Jon Lundy Date: Sun, 19 Mar 2023 08:31:00 -0600 Subject: [PATCH] feat(locker): add check for nested calls --- app/peerfinder/http.go | 12 ++++---- app/peerfinder/jobs.go | 6 ++-- app/peerfinder/service.go | 4 +-- cmd/webfinger/webfinger_e2e_test.go | 1 + ev.go | 2 +- pkg/cron/cron.go | 4 +-- pkg/es/driver/disk-store/disk-store.go | 16 +++++------ pkg/es/driver/mem-store/mem-store.go | 8 +++--- pkg/es/driver/streamer/streamer.go | 12 ++++---- pkg/es/event/reflect.go | 4 +-- pkg/locker/locker.go | 18 ++++++++++-- pkg/locker/locker_test.go | 40 ++++++++++++++++++++++++-- 12 files changed, 87 insertions(+), 40 deletions(-) diff --git a/app/peerfinder/http.go b/app/peerfinder/http.go index e20c677..3b56569 100644 --- a/app/peerfinder/http.go +++ b/app/peerfinder/http.go @@ -97,7 +97,7 @@ func (s *service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var requests []*Request - s.state.Modify(r.Context(), func(ctx context.Context, state *state) error { + s.state.Use(r.Context(), func(ctx context.Context, state *state) error { for id, p := range state.peers { fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name) } @@ -170,7 +170,7 @@ func (s *service) getPending(w http.ResponseWriter, r *http.Request, peerID stri ) var peer *Peer - err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { + err := s.state.Use(ctx, func(ctx context.Context, state *state) error { var ok bool if peer, ok = state.peers[peerID]; !ok { return fmt.Errorf("peer not found: %s", peerID) @@ -285,7 +285,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) { // } var requests ListRequest - s.state.Modify(ctx, func(ctx context.Context, state *state) error { + s.state.Use(ctx, func(ctx context.Context, state *state) error { requests = make([]*Request, 0, len(state.requests)) for _, req := range state.requests { @@ -306,7 +306,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) { args := requestArgs(r) args.Requests = requests[:maxResults] - s.state.Modify(ctx, func(ctx context.Context, state *state) error { + s.state.Use(ctx, func(ctx context.Context, state *state) error { args.CountPeers = len(state.peers) return nil }) @@ -323,7 +323,7 @@ func (s *service) getResultsForRequest(w http.ResponseWriter, r *http.Request, u ) var request *Request - err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { + err := s.state.Use(ctx, func(ctx context.Context, state *state) error { request = state.requests[uuid] return nil @@ -430,7 +430,7 @@ func (s *service) postResult(w http.ResponseWriter, r *http.Request, reqID strin peerID := r.Form.Get("peer_id") - err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { + err := s.state.Use(ctx, func(ctx context.Context, state *state) error { var ok bool if _, ok = state.peers[peerID]; !ok { log.Printf("peer not found: %s\n", peerID) diff --git a/app/peerfinder/jobs.go b/app/peerfinder/jobs.go index 1f30094..ccb45bf 100644 --- a/app/peerfinder/jobs.go +++ b/app/peerfinder/jobs.go @@ -41,7 +41,7 @@ func (s *service) RefreshJob(ctx context.Context, _ time.Time) error { return err } - err = s.state.Modify(ctx, func(ctx context.Context, t *state) error { + err = s.state.Use(ctx, func(ctx context.Context, t *state) error { for _, peer := range peers { t.peers[peer.ID] = peer } @@ -88,7 +88,7 @@ func (s *service) cleanPeerJobs(ctx context.Context) error { defer span.End() peers := set.New[string]() - err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { + err := s.state.Use(ctx, func(ctx context.Context, state *state) error { for id := range state.peers { peers.Add(id) } @@ -181,7 +181,7 @@ func (s *service) cleanRequests(ctx context.Context, now time.Time) error { // truncate all the request streams for _, streamID := range streamIDs { - s.state.Modify(ctx, func(ctx context.Context, state *state) error { + s.state.Use(ctx, func(ctx context.Context, state *state) error { return state.ApplyEvents(event.NewEvents(&RequestTruncated{ RequestID: streamID, })) diff --git a/app/peerfinder/service.go b/app/peerfinder/service.go index 677ab80..3bf713d 100644 --- a/app/peerfinder/service.go +++ b/app/peerfinder/service.go @@ -63,7 +63,7 @@ func (s *service) loadResult(ctx context.Context, request *Request) (*Request, e return request, nil } - return request, s.state.Modify(ctx, func(ctx context.Context, t *state) error { + return request, s.state.Use(ctx, func(ctx context.Context, t *state) error { for i := range request.Responses { res := request.Responses[i] @@ -116,7 +116,7 @@ func (s *service) Run(ctx context.Context) (err error) { } } - s.state.Modify(ctx, func(ctx context.Context, state *state) error { + s.state.Use(ctx, func(ctx context.Context, state *state) error { return state.ApplyEvents(events) }) events = events[:0] diff --git a/cmd/webfinger/webfinger_e2e_test.go b/cmd/webfinger/webfinger_e2e_test.go index 534e86d..d34ba37 100644 --- a/cmd/webfinger/webfinger_e2e_test.go +++ b/cmd/webfinger/webfinger_e2e_test.go @@ -24,6 +24,7 @@ func TestMain(m *testing.M) { } defer os.RemoveAll(data) + os.Setenv("EV_DATA", "mem:") os.Setenv("EV_HTTP", "[::1]:61234") os.Setenv("WEBFINGER_DOMAINS", "::1") diff --git a/ev.go b/ev.go index c1bd479..f704277 100644 --- a/ev.go +++ b/ev.go @@ -54,7 +54,7 @@ func Register(ctx context.Context, name string, d driver.Driver) error { ctx, span := lg.Span(ctx) defer span.End() - return drivers.Modify(ctx, func(ctx context.Context, c *config) error { + return drivers.Use(ctx, func(ctx context.Context, c *config) error { if _, set := c.drivers[name]; set { return fmt.Errorf("driver %s already set", name) } diff --git a/pkg/cron/cron.go b/pkg/cron/cron.go index 5678cef..bb5131d 100644 --- a/pkg/cron/cron.go +++ b/pkg/cron/cron.go @@ -74,7 +74,7 @@ func (c *cron) NewCron(expr string, task func(context.Context, time.Time) error) c.jobs = append(c.jobs, job) } func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) { - c.state.Modify(ctx, func(ctx context.Context, state *state) error { + c.state.Use(ctx, func(ctx context.Context, state *state) error { state.queue = append(state.queue, once) return nil }) @@ -126,7 +126,7 @@ func (c *cron) run(ctx context.Context, now time.Time) { span.AddEvent("Cron Run: " + now.Format(time.RFC822)) // fmt.Println("Cron Run: ", now.Format(time.RFC822)) - c.state.Modify(ctx, func(ctx context.Context, state *state) error { + c.state.Use(ctx, func(ctx context.Context, state *state) error { run = append(run, state.queue...) state.queue = state.queue[:0] diff --git a/pkg/es/driver/disk-store/disk-store.go b/pkg/es/driver/disk-store/disk-store.go index 6d696c0..e4e551f 100644 --- a/pkg/es/driver/disk-store/disk-store.go +++ b/pkg/es/driver/disk-store/disk-store.go @@ -100,7 +100,7 @@ func (d *diskStore) Open(ctx context.Context, dsn string) (driver.Driver, error) ctx, span := lg.Span(ctx) defer span.End() - l.Modify(ctx, func(ctx context.Context, w *wal.Log) error { + l.Use(ctx, func(ctx context.Context, w *wal.Log) error { ctx, span := lg.Span(ctx) defer span.End() @@ -139,7 +139,7 @@ func (d *diskStore) EventLog(ctx context.Context, streamID string) (driver.Event el := &eventLog{streamID: streamID, diskStore: d} - return el, d.openlogs.Modify(ctx, func(ctx context.Context, openlogs *openlogs) error { + return el, d.openlogs.Use(ctx, func(ctx context.Context, openlogs *openlogs) error { ctx, span := lg.Span(ctx) defer span.End() @@ -193,7 +193,7 @@ func (e *eventLog) Append(ctx context.Context, events event.Events, version uint event.SetStreamID(e.streamID, events...) var count uint64 - err := e.events.Modify(ctx, func(ctx context.Context, l *wal.Log) error { + err := e.events.Use(ctx, func(ctx context.Context, l *wal.Log) error { ctx, span := lg.Span(ctx) defer span.End() @@ -248,7 +248,7 @@ func (e *eventLog) Read(ctx context.Context, after, count int64) (event.Events, var events event.Events - err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error { + err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error { ctx, span := lg.Span(ctx) defer span.End() @@ -330,7 +330,7 @@ func (e *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er ) var events event.Events - err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error { + err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error { var err error events, err = readStreamN(ctx, stream, index...) @@ -352,7 +352,7 @@ func (e *eventLog) FirstIndex(ctx context.Context) (uint64, error) { var idx uint64 var err error - err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error { + err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error { idx, err = events.FirstIndex() return err }) @@ -371,7 +371,7 @@ func (e *eventLog) LastIndex(ctx context.Context) (uint64, error) { var idx uint64 var err error - err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error { + err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error { idx, err = events.LastIndex() return err }) @@ -391,7 +391,7 @@ func (e *eventLog) Truncate(ctx context.Context, index int64) error { if index == 0 { return nil } - return e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error { + return e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error { if index < 0 { return events.TruncateBack(uint64(-index)) } diff --git a/pkg/es/driver/mem-store/mem-store.go b/pkg/es/driver/mem-store/mem-store.go index a4741a4..645da9e 100644 --- a/pkg/es/driver/mem-store/mem-store.go +++ b/pkg/es/driver/mem-store/mem-store.go @@ -49,7 +49,7 @@ func (m *memstore) EventLog(ctx context.Context, streamID string) (driver.EventL el := &eventLog{streamID: streamID} - err := m.state.Modify(ctx, func(ctx context.Context, state *state) error { + err := m.state.Use(ctx, func(ctx context.Context, state *state) error { _, span := lg.Span(ctx) defer span.End() @@ -76,7 +76,7 @@ func (m *eventLog) Append(ctx context.Context, events event.Events, version uint event.SetStreamID(m.streamID, events...) - return uint64(len(events)), m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error { + return uint64(len(events)), m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error { ctx, span := lg.Span(ctx) defer span.End() @@ -117,7 +117,7 @@ func (m *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er defer span.End() var events event.Events - err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error { + err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error { var err error events, err = readStreamN(ctx, stream, index...) @@ -135,7 +135,7 @@ func (m *eventLog) Read(ctx context.Context, after int64, count int64) (event.Ev var events event.Events - err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error { + err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error { ctx, span := lg.Span(ctx) defer span.End() diff --git a/pkg/es/driver/streamer/streamer.go b/pkg/es/driver/streamer/streamer.go index 2f1949a..b37bfb2 100644 --- a/pkg/es/driver/streamer/streamer.go +++ b/pkg/es/driver/streamer/streamer.go @@ -76,7 +76,7 @@ func (s *streamer) Subscribe(ctx context.Context, streamID string, start int64) }) sub.unsub = s.delete(streamID, sub) - return sub, s.state.Modify(ctx, func(ctx context.Context, state *state) error { + return sub, s.state.Use(ctx, func(ctx context.Context, state *state) error { state.subscribers[streamID] = append(state.subscribers[streamID], sub) return nil }) @@ -85,14 +85,14 @@ func (s *streamer) Send(ctx context.Context, streamID string, events event.Event ctx, span := lg.Span(ctx) defer span.End() - return s.state.Modify(ctx, func(ctx context.Context, state *state) error { + return s.state.Use(ctx, func(ctx context.Context, state *state) error { ctx, span := lg.Span(ctx) defer span.End() span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID]))) for _, sub := range state.subscribers[streamID] { - err := sub.position.Modify(ctx, func(ctx context.Context, position *position) error { + err := sub.position.Use(ctx, func(ctx context.Context, position *position) error { ctx, span := lg.Span(ctx) defer span.End() @@ -128,7 +128,7 @@ func (s *streamer) delete(streamID string, sub *subscription) func(context.Conte if err := ctx.Err(); err != nil { return err } - return s.state.Modify(ctx, func(ctx context.Context, state *state) error { + return s.state.Use(ctx, func(ctx context.Context, state *state) error { _, span := lg.Span(ctx) defer span.End() @@ -228,7 +228,7 @@ func (s *subscription) Recv(ctx context.Context) <-chan bool { var wait func(context.Context) bool defer close(done) - err := s.position.Modify(ctx, func(ctx context.Context, position *position) error { + err := s.position.Use(ctx, func(ctx context.Context, position *position) error { _, span := lg.Span(ctx) defer span.End() @@ -280,7 +280,7 @@ func (s *subscription) Events(ctx context.Context) (event.Events, error) { defer span.End() var events event.Events - return events, s.position.Modify(ctx, func(ctx context.Context, position *position) error { + return events, s.position.Use(ctx, func(ctx context.Context, position *position) error { ctx, span := lg.Span(ctx) defer span.End() diff --git a/pkg/es/event/reflect.go b/pkg/es/event/reflect.go index afcab2c..e2407f2 100644 --- a/pkg/es/event/reflect.go +++ b/pkg/es/event/reflect.go @@ -106,7 +106,7 @@ func RegisterName(ctx context.Context, name string, e Event) error { span.AddEvent("register: " + name) - if err := eventTypes.Modify(ctx, func(ctx context.Context, c *config) error { + if err := eventTypes.Use(ctx, func(ctx context.Context, c *config) error { _, span := lg.Span(ctx) defer span.End() @@ -124,7 +124,7 @@ func GetContainer(ctx context.Context, s string) Event { var e Event - eventTypes.Modify(ctx, func(ctx context.Context, c *config) error { + eventTypes.Use(ctx, func(ctx context.Context, c *config) error { _, span := lg.Span(ctx) defer span.End() diff --git a/pkg/locker/locker.go b/pkg/locker/locker.go index f2706a1..be5461e 100644 --- a/pkg/locker/locker.go +++ b/pkg/locker/locker.go @@ -2,6 +2,7 @@ package locker import ( "context" + "errors" "fmt" "go.opentelemetry.io/otel/attribute" @@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] { return s } -// Modify will call the function with the locked value -func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error { +type ctxKey struct{ name string } + +// Use will call the function with the locked value +func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error { if s == nil { return fmt.Errorf("locker not initialized") } + key := ctxKey{fmt.Sprintf("%p", s)} + + if value := ctx.Value(key); value != nil { + return fmt.Errorf("%w: %T", ErrNested, s) + } + ctx = context.WithValue(ctx, key, key) + ctx, span := lg.Span(ctx) defer span.End() @@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err func (s *Locked[T]) Copy(ctx context.Context) (T, error) { var t T - err := s.Modify(ctx, func(ctx context.Context, c *T) error { + err := s.Use(ctx, func(ctx context.Context, c *T) error { if c != nil { t = *c } @@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) { return t, err } + +var ErrNested = errors.New("nested locker call") diff --git a/pkg/locker/locker_test.go b/pkg/locker/locker_test.go index a95dbd0..9c28dd9 100644 --- a/pkg/locker/locker_test.go +++ b/pkg/locker/locker_test.go @@ -2,6 +2,7 @@ package locker_test import ( "context" + "errors" "testing" "github.com/matryer/is" @@ -22,7 +23,7 @@ func TestLocker(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err := value.Modify(ctx, func(ctx context.Context, c *config) error { + err := value.Use(ctx, func(ctx context.Context, c *config) error { c.Value = "one" c.Counter++ return nil @@ -37,7 +38,7 @@ func TestLocker(t *testing.T) { wait := make(chan struct{}) - go value.Modify(ctx, func(ctx context.Context, c *config) error { + go value.Use(ctx, func(ctx context.Context, c *config) error { c.Value = "two" c.Counter++ close(wait) @@ -47,7 +48,7 @@ func TestLocker(t *testing.T) { <-wait cancel() - err = value.Modify(ctx, func(ctx context.Context, c *config) error { + err = value.Use(ctx, func(ctx context.Context, c *config) error { c.Value = "three" c.Counter++ return nil @@ -60,3 +61,36 @@ func TestLocker(t *testing.T) { is.Equal(c.Value, "two") is.Equal(c.Counter, 2) } + +func TestNestedLocker(t *testing.T) { + is := is.New(t) + + value := locker.New(&config{}) + other := locker.New(&config{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := value.Use(ctx, func(ctx context.Context, c *config) error { + return value.Use(ctx, func(ctx context.Context, t *config) error { + return nil + }) + }) + is.True(errors.Is(err, locker.ErrNested)) + + err = value.Use(ctx, func(ctx context.Context, c *config) error { + return other.Use(ctx, func(ctx context.Context, t *config) error { + return nil + }) + }) + is.NoErr(err) + + err = value.Use(ctx, func(ctx context.Context, c *config) error { + return other.Use(ctx, func(ctx context.Context, t *config) error { + return value.Use(ctx, func(ctx context.Context, x *config) error { + return nil + }) + }) + }) + is.True(errors.Is(err, locker.ErrNested)) +}