feat(locker): add check for nested calls
This commit is contained in:
parent
d27eb21ffb
commit
fcc5d08aa7
|
@ -97,7 +97,7 @@ func (s *service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var requests []*Request
|
||||
|
||||
s.state.Modify(r.Context(), func(ctx context.Context, state *state) error {
|
||||
s.state.Use(r.Context(), func(ctx context.Context, state *state) error {
|
||||
for id, p := range state.peers {
|
||||
fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ func (s *service) getPending(w http.ResponseWriter, r *http.Request, peerID stri
|
|||
)
|
||||
|
||||
var peer *Peer
|
||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
var ok bool
|
||||
if peer, ok = state.peers[peerID]; !ok {
|
||||
return fmt.Errorf("peer not found: %s", peerID)
|
||||
|
@ -285,7 +285,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
|
|||
// }
|
||||
|
||||
var requests ListRequest
|
||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
requests = make([]*Request, 0, len(state.requests))
|
||||
|
||||
for _, req := range state.requests {
|
||||
|
@ -306,7 +306,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
|
|||
args := requestArgs(r)
|
||||
args.Requests = requests[:maxResults]
|
||||
|
||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
args.CountPeers = len(state.peers)
|
||||
return nil
|
||||
})
|
||||
|
@ -323,7 +323,7 @@ func (s *service) getResultsForRequest(w http.ResponseWriter, r *http.Request, u
|
|||
)
|
||||
|
||||
var request *Request
|
||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
request = state.requests[uuid]
|
||||
|
||||
return nil
|
||||
|
@ -430,7 +430,7 @@ func (s *service) postResult(w http.ResponseWriter, r *http.Request, reqID strin
|
|||
|
||||
peerID := r.Form.Get("peer_id")
|
||||
|
||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
var ok bool
|
||||
if _, ok = state.peers[peerID]; !ok {
|
||||
log.Printf("peer not found: %s\n", peerID)
|
||||
|
|
|
@ -41,7 +41,7 @@ func (s *service) RefreshJob(ctx context.Context, _ time.Time) error {
|
|||
return err
|
||||
}
|
||||
|
||||
err = s.state.Modify(ctx, func(ctx context.Context, t *state) error {
|
||||
err = s.state.Use(ctx, func(ctx context.Context, t *state) error {
|
||||
for _, peer := range peers {
|
||||
t.peers[peer.ID] = peer
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ func (s *service) cleanPeerJobs(ctx context.Context) error {
|
|||
defer span.End()
|
||||
|
||||
peers := set.New[string]()
|
||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
for id := range state.peers {
|
||||
peers.Add(id)
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ func (s *service) cleanRequests(ctx context.Context, now time.Time) error {
|
|||
|
||||
// truncate all the request streams
|
||||
for _, streamID := range streamIDs {
|
||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
return state.ApplyEvents(event.NewEvents(&RequestTruncated{
|
||||
RequestID: streamID,
|
||||
}))
|
||||
|
|
|
@ -63,7 +63,7 @@ func (s *service) loadResult(ctx context.Context, request *Request) (*Request, e
|
|||
return request, nil
|
||||
}
|
||||
|
||||
return request, s.state.Modify(ctx, func(ctx context.Context, t *state) error {
|
||||
return request, s.state.Use(ctx, func(ctx context.Context, t *state) error {
|
||||
|
||||
for i := range request.Responses {
|
||||
res := request.Responses[i]
|
||||
|
@ -116,7 +116,7 @@ func (s *service) Run(ctx context.Context) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
return state.ApplyEvents(events)
|
||||
})
|
||||
events = events[:0]
|
||||
|
|
|
@ -24,6 +24,7 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
defer os.RemoveAll(data)
|
||||
|
||||
os.Setenv("EV_DATA", "mem:")
|
||||
os.Setenv("EV_HTTP", "[::1]:61234")
|
||||
os.Setenv("WEBFINGER_DOMAINS", "::1")
|
||||
|
||||
|
|
2
ev.go
2
ev.go
|
@ -54,7 +54,7 @@ func Register(ctx context.Context, name string, d driver.Driver) error {
|
|||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
return drivers.Modify(ctx, func(ctx context.Context, c *config) error {
|
||||
return drivers.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
if _, set := c.drivers[name]; set {
|
||||
return fmt.Errorf("driver %s already set", name)
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ func (c *cron) NewCron(expr string, task func(context.Context, time.Time) error)
|
|||
c.jobs = append(c.jobs, job)
|
||||
}
|
||||
func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
|
||||
c.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
c.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
state.queue = append(state.queue, once)
|
||||
return nil
|
||||
})
|
||||
|
@ -126,7 +126,7 @@ func (c *cron) run(ctx context.Context, now time.Time) {
|
|||
span.AddEvent("Cron Run: " + now.Format(time.RFC822))
|
||||
// fmt.Println("Cron Run: ", now.Format(time.RFC822))
|
||||
|
||||
c.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
c.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
run = append(run, state.queue...)
|
||||
state.queue = state.queue[:0]
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ func (d *diskStore) Open(ctx context.Context, dsn string) (driver.Driver, error)
|
|||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
l.Modify(ctx, func(ctx context.Context, w *wal.Log) error {
|
||||
l.Use(ctx, func(ctx context.Context, w *wal.Log) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -139,7 +139,7 @@ func (d *diskStore) EventLog(ctx context.Context, streamID string) (driver.Event
|
|||
|
||||
el := &eventLog{streamID: streamID, diskStore: d}
|
||||
|
||||
return el, d.openlogs.Modify(ctx, func(ctx context.Context, openlogs *openlogs) error {
|
||||
return el, d.openlogs.Use(ctx, func(ctx context.Context, openlogs *openlogs) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -193,7 +193,7 @@ func (e *eventLog) Append(ctx context.Context, events event.Events, version uint
|
|||
event.SetStreamID(e.streamID, events...)
|
||||
|
||||
var count uint64
|
||||
err := e.events.Modify(ctx, func(ctx context.Context, l *wal.Log) error {
|
||||
err := e.events.Use(ctx, func(ctx context.Context, l *wal.Log) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -248,7 +248,7 @@ func (e *eventLog) Read(ctx context.Context, after, count int64) (event.Events,
|
|||
|
||||
var events event.Events
|
||||
|
||||
err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
|
||||
err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -330,7 +330,7 @@ func (e *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
|
|||
)
|
||||
|
||||
var events event.Events
|
||||
err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
|
||||
err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
|
||||
var err error
|
||||
|
||||
events, err = readStreamN(ctx, stream, index...)
|
||||
|
@ -352,7 +352,7 @@ func (e *eventLog) FirstIndex(ctx context.Context) (uint64, error) {
|
|||
var idx uint64
|
||||
var err error
|
||||
|
||||
err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||
err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||
idx, err = events.FirstIndex()
|
||||
return err
|
||||
})
|
||||
|
@ -371,7 +371,7 @@ func (e *eventLog) LastIndex(ctx context.Context) (uint64, error) {
|
|||
var idx uint64
|
||||
var err error
|
||||
|
||||
err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||
err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||
idx, err = events.LastIndex()
|
||||
return err
|
||||
})
|
||||
|
@ -391,7 +391,7 @@ func (e *eventLog) Truncate(ctx context.Context, index int64) error {
|
|||
if index == 0 {
|
||||
return nil
|
||||
}
|
||||
return e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||
return e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||
if index < 0 {
|
||||
return events.TruncateBack(uint64(-index))
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ func (m *memstore) EventLog(ctx context.Context, streamID string) (driver.EventL
|
|||
|
||||
el := &eventLog{streamID: streamID}
|
||||
|
||||
err := m.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
err := m.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
_, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -76,7 +76,7 @@ func (m *eventLog) Append(ctx context.Context, events event.Events, version uint
|
|||
|
||||
event.SetStreamID(m.streamID, events...)
|
||||
|
||||
return uint64(len(events)), m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||
return uint64(len(events)), m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -117,7 +117,7 @@ func (m *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
|
|||
defer span.End()
|
||||
|
||||
var events event.Events
|
||||
err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||
err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||
var err error
|
||||
|
||||
events, err = readStreamN(ctx, stream, index...)
|
||||
|
@ -135,7 +135,7 @@ func (m *eventLog) Read(ctx context.Context, after int64, count int64) (event.Ev
|
|||
|
||||
var events event.Events
|
||||
|
||||
err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||
err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ func (s *streamer) Subscribe(ctx context.Context, streamID string, start int64)
|
|||
})
|
||||
sub.unsub = s.delete(streamID, sub)
|
||||
|
||||
return sub, s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
return sub, s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
state.subscribers[streamID] = append(state.subscribers[streamID], sub)
|
||||
return nil
|
||||
})
|
||||
|
@ -85,14 +85,14 @@ func (s *streamer) Send(ctx context.Context, streamID string, events event.Event
|
|||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
return s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
|
||||
|
||||
for _, sub := range state.subscribers[streamID] {
|
||||
err := sub.position.Modify(ctx, func(ctx context.Context, position *position) error {
|
||||
err := sub.position.Use(ctx, func(ctx context.Context, position *position) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -128,7 +128,7 @@ func (s *streamer) delete(streamID string, sub *subscription) func(context.Conte
|
|||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
||||
return s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||
_, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -228,7 +228,7 @@ func (s *subscription) Recv(ctx context.Context) <-chan bool {
|
|||
var wait func(context.Context) bool
|
||||
defer close(done)
|
||||
|
||||
err := s.position.Modify(ctx, func(ctx context.Context, position *position) error {
|
||||
err := s.position.Use(ctx, func(ctx context.Context, position *position) error {
|
||||
_, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -280,7 +280,7 @@ func (s *subscription) Events(ctx context.Context) (event.Events, error) {
|
|||
defer span.End()
|
||||
|
||||
var events event.Events
|
||||
return events, s.position.Modify(ctx, func(ctx context.Context, position *position) error {
|
||||
return events, s.position.Use(ctx, func(ctx context.Context, position *position) error {
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ func RegisterName(ctx context.Context, name string, e Event) error {
|
|||
|
||||
span.AddEvent("register: " + name)
|
||||
|
||||
if err := eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
|
||||
if err := eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
_, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -124,7 +124,7 @@ func GetContainer(ctx context.Context, s string) Event {
|
|||
|
||||
var e Event
|
||||
|
||||
eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
|
||||
eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
_, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package locker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] {
|
|||
return s
|
||||
}
|
||||
|
||||
// Modify will call the function with the locked value
|
||||
func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error {
|
||||
type ctxKey struct{ name string }
|
||||
|
||||
// Use will call the function with the locked value
|
||||
func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error {
|
||||
if s == nil {
|
||||
return fmt.Errorf("locker not initialized")
|
||||
}
|
||||
|
||||
key := ctxKey{fmt.Sprintf("%p", s)}
|
||||
|
||||
if value := ctx.Value(key); value != nil {
|
||||
return fmt.Errorf("%w: %T", ErrNested, s)
|
||||
}
|
||||
ctx = context.WithValue(ctx, key, key)
|
||||
|
||||
ctx, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err
|
|||
func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
|
||||
var t T
|
||||
|
||||
err := s.Modify(ctx, func(ctx context.Context, c *T) error {
|
||||
err := s.Use(ctx, func(ctx context.Context, c *T) error {
|
||||
if c != nil {
|
||||
t = *c
|
||||
}
|
||||
|
@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
|
|||
|
||||
return t, err
|
||||
}
|
||||
|
||||
var ErrNested = errors.New("nested locker call")
|
||||
|
|
|
@ -2,6 +2,7 @@ package locker_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/matryer/is"
|
||||
|
@ -22,7 +23,7 @@ func TestLocker(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err := value.Modify(ctx, func(ctx context.Context, c *config) error {
|
||||
err := value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
c.Value = "one"
|
||||
c.Counter++
|
||||
return nil
|
||||
|
@ -37,7 +38,7 @@ func TestLocker(t *testing.T) {
|
|||
|
||||
wait := make(chan struct{})
|
||||
|
||||
go value.Modify(ctx, func(ctx context.Context, c *config) error {
|
||||
go value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
c.Value = "two"
|
||||
c.Counter++
|
||||
close(wait)
|
||||
|
@ -47,7 +48,7 @@ func TestLocker(t *testing.T) {
|
|||
<-wait
|
||||
cancel()
|
||||
|
||||
err = value.Modify(ctx, func(ctx context.Context, c *config) error {
|
||||
err = value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
c.Value = "three"
|
||||
c.Counter++
|
||||
return nil
|
||||
|
@ -60,3 +61,36 @@ func TestLocker(t *testing.T) {
|
|||
is.Equal(c.Value, "two")
|
||||
is.Equal(c.Counter, 2)
|
||||
}
|
||||
|
||||
func TestNestedLocker(t *testing.T) {
|
||||
is := is.New(t)
|
||||
|
||||
value := locker.New(&config{})
|
||||
other := locker.New(&config{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err := value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
return value.Use(ctx, func(ctx context.Context, t *config) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
is.True(errors.Is(err, locker.ErrNested))
|
||||
|
||||
err = value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
return other.Use(ctx, func(ctx context.Context, t *config) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
is.NoErr(err)
|
||||
|
||||
err = value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||
return other.Use(ctx, func(ctx context.Context, t *config) error {
|
||||
return value.Use(ctx, func(ctx context.Context, x *config) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
})
|
||||
is.True(errors.Is(err, locker.ErrNested))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user