feat(locker): add check for nested calls

This commit is contained in:
Jon Lundy 2023-03-19 08:31:00 -06:00
parent d27eb21ffb
commit fcc5d08aa7
Signed by untrusted user who does not match committer: xuu
GPG Key ID: C63E6D61F3035024
12 changed files with 87 additions and 40 deletions

View File

@ -97,7 +97,7 @@ func (s *service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var requests []*Request var requests []*Request
s.state.Modify(r.Context(), func(ctx context.Context, state *state) error { s.state.Use(r.Context(), func(ctx context.Context, state *state) error {
for id, p := range state.peers { for id, p := range state.peers {
fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name) fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
} }
@ -170,7 +170,7 @@ func (s *service) getPending(w http.ResponseWriter, r *http.Request, peerID stri
) )
var peer *Peer var peer *Peer
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
var ok bool var ok bool
if peer, ok = state.peers[peerID]; !ok { if peer, ok = state.peers[peerID]; !ok {
return fmt.Errorf("peer not found: %s", peerID) return fmt.Errorf("peer not found: %s", peerID)
@ -285,7 +285,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
// } // }
var requests ListRequest var requests ListRequest
s.state.Modify(ctx, func(ctx context.Context, state *state) error { s.state.Use(ctx, func(ctx context.Context, state *state) error {
requests = make([]*Request, 0, len(state.requests)) requests = make([]*Request, 0, len(state.requests))
for _, req := range state.requests { for _, req := range state.requests {
@ -306,7 +306,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
args := requestArgs(r) args := requestArgs(r)
args.Requests = requests[:maxResults] args.Requests = requests[:maxResults]
s.state.Modify(ctx, func(ctx context.Context, state *state) error { s.state.Use(ctx, func(ctx context.Context, state *state) error {
args.CountPeers = len(state.peers) args.CountPeers = len(state.peers)
return nil return nil
}) })
@ -323,7 +323,7 @@ func (s *service) getResultsForRequest(w http.ResponseWriter, r *http.Request, u
) )
var request *Request var request *Request
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
request = state.requests[uuid] request = state.requests[uuid]
return nil return nil
@ -430,7 +430,7 @@ func (s *service) postResult(w http.ResponseWriter, r *http.Request, reqID strin
peerID := r.Form.Get("peer_id") peerID := r.Form.Get("peer_id")
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
var ok bool var ok bool
if _, ok = state.peers[peerID]; !ok { if _, ok = state.peers[peerID]; !ok {
log.Printf("peer not found: %s\n", peerID) log.Printf("peer not found: %s\n", peerID)

View File

@ -41,7 +41,7 @@ func (s *service) RefreshJob(ctx context.Context, _ time.Time) error {
return err return err
} }
err = s.state.Modify(ctx, func(ctx context.Context, t *state) error { err = s.state.Use(ctx, func(ctx context.Context, t *state) error {
for _, peer := range peers { for _, peer := range peers {
t.peers[peer.ID] = peer t.peers[peer.ID] = peer
} }
@ -88,7 +88,7 @@ func (s *service) cleanPeerJobs(ctx context.Context) error {
defer span.End() defer span.End()
peers := set.New[string]() peers := set.New[string]()
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error { err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
for id := range state.peers { for id := range state.peers {
peers.Add(id) peers.Add(id)
} }
@ -181,7 +181,7 @@ func (s *service) cleanRequests(ctx context.Context, now time.Time) error {
// truncate all the request streams // truncate all the request streams
for _, streamID := range streamIDs { for _, streamID := range streamIDs {
s.state.Modify(ctx, func(ctx context.Context, state *state) error { s.state.Use(ctx, func(ctx context.Context, state *state) error {
return state.ApplyEvents(event.NewEvents(&RequestTruncated{ return state.ApplyEvents(event.NewEvents(&RequestTruncated{
RequestID: streamID, RequestID: streamID,
})) }))

View File

@ -63,7 +63,7 @@ func (s *service) loadResult(ctx context.Context, request *Request) (*Request, e
return request, nil return request, nil
} }
return request, s.state.Modify(ctx, func(ctx context.Context, t *state) error { return request, s.state.Use(ctx, func(ctx context.Context, t *state) error {
for i := range request.Responses { for i := range request.Responses {
res := request.Responses[i] res := request.Responses[i]
@ -116,7 +116,7 @@ func (s *service) Run(ctx context.Context) (err error) {
} }
} }
s.state.Modify(ctx, func(ctx context.Context, state *state) error { s.state.Use(ctx, func(ctx context.Context, state *state) error {
return state.ApplyEvents(events) return state.ApplyEvents(events)
}) })
events = events[:0] events = events[:0]

View File

@ -24,6 +24,7 @@ func TestMain(m *testing.M) {
} }
defer os.RemoveAll(data) defer os.RemoveAll(data)
os.Setenv("EV_DATA", "mem:")
os.Setenv("EV_HTTP", "[::1]:61234") os.Setenv("EV_HTTP", "[::1]:61234")
os.Setenv("WEBFINGER_DOMAINS", "::1") os.Setenv("WEBFINGER_DOMAINS", "::1")

2
ev.go
View File

@ -54,7 +54,7 @@ func Register(ctx context.Context, name string, d driver.Driver) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
return drivers.Modify(ctx, func(ctx context.Context, c *config) error { return drivers.Use(ctx, func(ctx context.Context, c *config) error {
if _, set := c.drivers[name]; set { if _, set := c.drivers[name]; set {
return fmt.Errorf("driver %s already set", name) return fmt.Errorf("driver %s already set", name)
} }

View File

@ -74,7 +74,7 @@ func (c *cron) NewCron(expr string, task func(context.Context, time.Time) error)
c.jobs = append(c.jobs, job) c.jobs = append(c.jobs, job)
} }
func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) { func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
c.state.Modify(ctx, func(ctx context.Context, state *state) error { c.state.Use(ctx, func(ctx context.Context, state *state) error {
state.queue = append(state.queue, once) state.queue = append(state.queue, once)
return nil return nil
}) })
@ -126,7 +126,7 @@ func (c *cron) run(ctx context.Context, now time.Time) {
span.AddEvent("Cron Run: " + now.Format(time.RFC822)) span.AddEvent("Cron Run: " + now.Format(time.RFC822))
// fmt.Println("Cron Run: ", now.Format(time.RFC822)) // fmt.Println("Cron Run: ", now.Format(time.RFC822))
c.state.Modify(ctx, func(ctx context.Context, state *state) error { c.state.Use(ctx, func(ctx context.Context, state *state) error {
run = append(run, state.queue...) run = append(run, state.queue...)
state.queue = state.queue[:0] state.queue = state.queue[:0]

View File

@ -100,7 +100,7 @@ func (d *diskStore) Open(ctx context.Context, dsn string) (driver.Driver, error)
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
l.Modify(ctx, func(ctx context.Context, w *wal.Log) error { l.Use(ctx, func(ctx context.Context, w *wal.Log) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -139,7 +139,7 @@ func (d *diskStore) EventLog(ctx context.Context, streamID string) (driver.Event
el := &eventLog{streamID: streamID, diskStore: d} el := &eventLog{streamID: streamID, diskStore: d}
return el, d.openlogs.Modify(ctx, func(ctx context.Context, openlogs *openlogs) error { return el, d.openlogs.Use(ctx, func(ctx context.Context, openlogs *openlogs) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -193,7 +193,7 @@ func (e *eventLog) Append(ctx context.Context, events event.Events, version uint
event.SetStreamID(e.streamID, events...) event.SetStreamID(e.streamID, events...)
var count uint64 var count uint64
err := e.events.Modify(ctx, func(ctx context.Context, l *wal.Log) error { err := e.events.Use(ctx, func(ctx context.Context, l *wal.Log) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -248,7 +248,7 @@ func (e *eventLog) Read(ctx context.Context, after, count int64) (event.Events,
var events event.Events var events event.Events
err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error { err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -330,7 +330,7 @@ func (e *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
) )
var events event.Events var events event.Events
err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error { err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
var err error var err error
events, err = readStreamN(ctx, stream, index...) events, err = readStreamN(ctx, stream, index...)
@ -352,7 +352,7 @@ func (e *eventLog) FirstIndex(ctx context.Context) (uint64, error) {
var idx uint64 var idx uint64
var err error var err error
err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error { err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
idx, err = events.FirstIndex() idx, err = events.FirstIndex()
return err return err
}) })
@ -371,7 +371,7 @@ func (e *eventLog) LastIndex(ctx context.Context) (uint64, error) {
var idx uint64 var idx uint64
var err error var err error
err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error { err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
idx, err = events.LastIndex() idx, err = events.LastIndex()
return err return err
}) })
@ -391,7 +391,7 @@ func (e *eventLog) Truncate(ctx context.Context, index int64) error {
if index == 0 { if index == 0 {
return nil return nil
} }
return e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error { return e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
if index < 0 { if index < 0 {
return events.TruncateBack(uint64(-index)) return events.TruncateBack(uint64(-index))
} }

View File

@ -49,7 +49,7 @@ func (m *memstore) EventLog(ctx context.Context, streamID string) (driver.EventL
el := &eventLog{streamID: streamID} el := &eventLog{streamID: streamID}
err := m.state.Modify(ctx, func(ctx context.Context, state *state) error { err := m.state.Use(ctx, func(ctx context.Context, state *state) error {
_, span := lg.Span(ctx) _, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -76,7 +76,7 @@ func (m *eventLog) Append(ctx context.Context, events event.Events, version uint
event.SetStreamID(m.streamID, events...) event.SetStreamID(m.streamID, events...)
return uint64(len(events)), m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error { return uint64(len(events)), m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -117,7 +117,7 @@ func (m *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
defer span.End() defer span.End()
var events event.Events var events event.Events
err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error { err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
var err error var err error
events, err = readStreamN(ctx, stream, index...) events, err = readStreamN(ctx, stream, index...)
@ -135,7 +135,7 @@ func (m *eventLog) Read(ctx context.Context, after int64, count int64) (event.Ev
var events event.Events var events event.Events
err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error { err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()

View File

@ -76,7 +76,7 @@ func (s *streamer) Subscribe(ctx context.Context, streamID string, start int64)
}) })
sub.unsub = s.delete(streamID, sub) sub.unsub = s.delete(streamID, sub)
return sub, s.state.Modify(ctx, func(ctx context.Context, state *state) error { return sub, s.state.Use(ctx, func(ctx context.Context, state *state) error {
state.subscribers[streamID] = append(state.subscribers[streamID], sub) state.subscribers[streamID] = append(state.subscribers[streamID], sub)
return nil return nil
}) })
@ -85,14 +85,14 @@ func (s *streamer) Send(ctx context.Context, streamID string, events event.Event
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
return s.state.Modify(ctx, func(ctx context.Context, state *state) error { return s.state.Use(ctx, func(ctx context.Context, state *state) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID]))) span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
for _, sub := range state.subscribers[streamID] { for _, sub := range state.subscribers[streamID] {
err := sub.position.Modify(ctx, func(ctx context.Context, position *position) error { err := sub.position.Use(ctx, func(ctx context.Context, position *position) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -128,7 +128,7 @@ func (s *streamer) delete(streamID string, sub *subscription) func(context.Conte
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
return s.state.Modify(ctx, func(ctx context.Context, state *state) error { return s.state.Use(ctx, func(ctx context.Context, state *state) error {
_, span := lg.Span(ctx) _, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -228,7 +228,7 @@ func (s *subscription) Recv(ctx context.Context) <-chan bool {
var wait func(context.Context) bool var wait func(context.Context) bool
defer close(done) defer close(done)
err := s.position.Modify(ctx, func(ctx context.Context, position *position) error { err := s.position.Use(ctx, func(ctx context.Context, position *position) error {
_, span := lg.Span(ctx) _, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -280,7 +280,7 @@ func (s *subscription) Events(ctx context.Context) (event.Events, error) {
defer span.End() defer span.End()
var events event.Events var events event.Events
return events, s.position.Modify(ctx, func(ctx context.Context, position *position) error { return events, s.position.Use(ctx, func(ctx context.Context, position *position) error {
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()

View File

@ -106,7 +106,7 @@ func RegisterName(ctx context.Context, name string, e Event) error {
span.AddEvent("register: " + name) span.AddEvent("register: " + name)
if err := eventTypes.Modify(ctx, func(ctx context.Context, c *config) error { if err := eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
_, span := lg.Span(ctx) _, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -124,7 +124,7 @@ func GetContainer(ctx context.Context, s string) Event {
var e Event var e Event
eventTypes.Modify(ctx, func(ctx context.Context, c *config) error { eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
_, span := lg.Span(ctx) _, span := lg.Span(ctx)
defer span.End() defer span.End()

View File

@ -2,6 +2,7 @@ package locker
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] {
return s return s
} }
// Modify will call the function with the locked value type ctxKey struct{ name string }
func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error {
// Use will call the function with the locked value
func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error {
if s == nil { if s == nil {
return fmt.Errorf("locker not initialized") return fmt.Errorf("locker not initialized")
} }
key := ctxKey{fmt.Sprintf("%p", s)}
if value := ctx.Value(key); value != nil {
return fmt.Errorf("%w: %T", ErrNested, s)
}
ctx = context.WithValue(ctx, key, key)
ctx, span := lg.Span(ctx) ctx, span := lg.Span(ctx)
defer span.End() defer span.End()
@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err
func (s *Locked[T]) Copy(ctx context.Context) (T, error) { func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
var t T var t T
err := s.Modify(ctx, func(ctx context.Context, c *T) error { err := s.Use(ctx, func(ctx context.Context, c *T) error {
if c != nil { if c != nil {
t = *c t = *c
} }
@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
return t, err return t, err
} }
var ErrNested = errors.New("nested locker call")

View File

@ -2,6 +2,7 @@ package locker_test
import ( import (
"context" "context"
"errors"
"testing" "testing"
"github.com/matryer/is" "github.com/matryer/is"
@ -22,7 +23,7 @@ func TestLocker(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
err := value.Modify(ctx, func(ctx context.Context, c *config) error { err := value.Use(ctx, func(ctx context.Context, c *config) error {
c.Value = "one" c.Value = "one"
c.Counter++ c.Counter++
return nil return nil
@ -37,7 +38,7 @@ func TestLocker(t *testing.T) {
wait := make(chan struct{}) wait := make(chan struct{})
go value.Modify(ctx, func(ctx context.Context, c *config) error { go value.Use(ctx, func(ctx context.Context, c *config) error {
c.Value = "two" c.Value = "two"
c.Counter++ c.Counter++
close(wait) close(wait)
@ -47,7 +48,7 @@ func TestLocker(t *testing.T) {
<-wait <-wait
cancel() cancel()
err = value.Modify(ctx, func(ctx context.Context, c *config) error { err = value.Use(ctx, func(ctx context.Context, c *config) error {
c.Value = "three" c.Value = "three"
c.Counter++ c.Counter++
return nil return nil
@ -60,3 +61,36 @@ func TestLocker(t *testing.T) {
is.Equal(c.Value, "two") is.Equal(c.Value, "two")
is.Equal(c.Counter, 2) is.Equal(c.Counter, 2)
} }
func TestNestedLocker(t *testing.T) {
is := is.New(t)
value := locker.New(&config{})
other := locker.New(&config{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := value.Use(ctx, func(ctx context.Context, c *config) error {
return value.Use(ctx, func(ctx context.Context, t *config) error {
return nil
})
})
is.True(errors.Is(err, locker.ErrNested))
err = value.Use(ctx, func(ctx context.Context, c *config) error {
return other.Use(ctx, func(ctx context.Context, t *config) error {
return nil
})
})
is.NoErr(err)
err = value.Use(ctx, func(ctx context.Context, c *config) error {
return other.Use(ctx, func(ctx context.Context, t *config) error {
return value.Use(ctx, func(ctx context.Context, x *config) error {
return nil
})
})
})
is.True(errors.Is(err, locker.ErrNested))
}