feat(locker): add check for nested calls
This commit is contained in:
parent
d27eb21ffb
commit
fcc5d08aa7
|
@ -97,7 +97,7 @@ func (s *service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
var requests []*Request
|
var requests []*Request
|
||||||
|
|
||||||
s.state.Modify(r.Context(), func(ctx context.Context, state *state) error {
|
s.state.Use(r.Context(), func(ctx context.Context, state *state) error {
|
||||||
for id, p := range state.peers {
|
for id, p := range state.peers {
|
||||||
fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
|
fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
|
||||||
}
|
}
|
||||||
|
@ -170,7 +170,7 @@ func (s *service) getPending(w http.ResponseWriter, r *http.Request, peerID stri
|
||||||
)
|
)
|
||||||
|
|
||||||
var peer *Peer
|
var peer *Peer
|
||||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
var ok bool
|
var ok bool
|
||||||
if peer, ok = state.peers[peerID]; !ok {
|
if peer, ok = state.peers[peerID]; !ok {
|
||||||
return fmt.Errorf("peer not found: %s", peerID)
|
return fmt.Errorf("peer not found: %s", peerID)
|
||||||
|
@ -285,7 +285,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
var requests ListRequest
|
var requests ListRequest
|
||||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
requests = make([]*Request, 0, len(state.requests))
|
requests = make([]*Request, 0, len(state.requests))
|
||||||
|
|
||||||
for _, req := range state.requests {
|
for _, req := range state.requests {
|
||||||
|
@ -306,7 +306,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
|
||||||
args := requestArgs(r)
|
args := requestArgs(r)
|
||||||
args.Requests = requests[:maxResults]
|
args.Requests = requests[:maxResults]
|
||||||
|
|
||||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
args.CountPeers = len(state.peers)
|
args.CountPeers = len(state.peers)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -323,7 +323,7 @@ func (s *service) getResultsForRequest(w http.ResponseWriter, r *http.Request, u
|
||||||
)
|
)
|
||||||
|
|
||||||
var request *Request
|
var request *Request
|
||||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
request = state.requests[uuid]
|
request = state.requests[uuid]
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -430,7 +430,7 @@ func (s *service) postResult(w http.ResponseWriter, r *http.Request, reqID strin
|
||||||
|
|
||||||
peerID := r.Form.Get("peer_id")
|
peerID := r.Form.Get("peer_id")
|
||||||
|
|
||||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
var ok bool
|
var ok bool
|
||||||
if _, ok = state.peers[peerID]; !ok {
|
if _, ok = state.peers[peerID]; !ok {
|
||||||
log.Printf("peer not found: %s\n", peerID)
|
log.Printf("peer not found: %s\n", peerID)
|
||||||
|
|
|
@ -41,7 +41,7 @@ func (s *service) RefreshJob(ctx context.Context, _ time.Time) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.state.Modify(ctx, func(ctx context.Context, t *state) error {
|
err = s.state.Use(ctx, func(ctx context.Context, t *state) error {
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
t.peers[peer.ID] = peer
|
t.peers[peer.ID] = peer
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ func (s *service) cleanPeerJobs(ctx context.Context) error {
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
peers := set.New[string]()
|
peers := set.New[string]()
|
||||||
err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
for id := range state.peers {
|
for id := range state.peers {
|
||||||
peers.Add(id)
|
peers.Add(id)
|
||||||
}
|
}
|
||||||
|
@ -181,7 +181,7 @@ func (s *service) cleanRequests(ctx context.Context, now time.Time) error {
|
||||||
|
|
||||||
// truncate all the request streams
|
// truncate all the request streams
|
||||||
for _, streamID := range streamIDs {
|
for _, streamID := range streamIDs {
|
||||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
return state.ApplyEvents(event.NewEvents(&RequestTruncated{
|
return state.ApplyEvents(event.NewEvents(&RequestTruncated{
|
||||||
RequestID: streamID,
|
RequestID: streamID,
|
||||||
}))
|
}))
|
||||||
|
|
|
@ -63,7 +63,7 @@ func (s *service) loadResult(ctx context.Context, request *Request) (*Request, e
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return request, s.state.Modify(ctx, func(ctx context.Context, t *state) error {
|
return request, s.state.Use(ctx, func(ctx context.Context, t *state) error {
|
||||||
|
|
||||||
for i := range request.Responses {
|
for i := range request.Responses {
|
||||||
res := request.Responses[i]
|
res := request.Responses[i]
|
||||||
|
@ -116,7 +116,7 @@ func (s *service) Run(ctx context.Context) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
return state.ApplyEvents(events)
|
return state.ApplyEvents(events)
|
||||||
})
|
})
|
||||||
events = events[:0]
|
events = events[:0]
|
||||||
|
|
|
@ -24,6 +24,7 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(data)
|
defer os.RemoveAll(data)
|
||||||
|
|
||||||
|
os.Setenv("EV_DATA", "mem:")
|
||||||
os.Setenv("EV_HTTP", "[::1]:61234")
|
os.Setenv("EV_HTTP", "[::1]:61234")
|
||||||
os.Setenv("WEBFINGER_DOMAINS", "::1")
|
os.Setenv("WEBFINGER_DOMAINS", "::1")
|
||||||
|
|
||||||
|
|
2
ev.go
2
ev.go
|
@ -54,7 +54,7 @@ func Register(ctx context.Context, name string, d driver.Driver) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
return drivers.Modify(ctx, func(ctx context.Context, c *config) error {
|
return drivers.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
if _, set := c.drivers[name]; set {
|
if _, set := c.drivers[name]; set {
|
||||||
return fmt.Errorf("driver %s already set", name)
|
return fmt.Errorf("driver %s already set", name)
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (c *cron) NewCron(expr string, task func(context.Context, time.Time) error)
|
||||||
c.jobs = append(c.jobs, job)
|
c.jobs = append(c.jobs, job)
|
||||||
}
|
}
|
||||||
func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
|
func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
|
||||||
c.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
c.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
state.queue = append(state.queue, once)
|
state.queue = append(state.queue, once)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -126,7 +126,7 @@ func (c *cron) run(ctx context.Context, now time.Time) {
|
||||||
span.AddEvent("Cron Run: " + now.Format(time.RFC822))
|
span.AddEvent("Cron Run: " + now.Format(time.RFC822))
|
||||||
// fmt.Println("Cron Run: ", now.Format(time.RFC822))
|
// fmt.Println("Cron Run: ", now.Format(time.RFC822))
|
||||||
|
|
||||||
c.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
c.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
run = append(run, state.queue...)
|
run = append(run, state.queue...)
|
||||||
state.queue = state.queue[:0]
|
state.queue = state.queue[:0]
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ func (d *diskStore) Open(ctx context.Context, dsn string) (driver.Driver, error)
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
l.Modify(ctx, func(ctx context.Context, w *wal.Log) error {
|
l.Use(ctx, func(ctx context.Context, w *wal.Log) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ func (d *diskStore) EventLog(ctx context.Context, streamID string) (driver.Event
|
||||||
|
|
||||||
el := &eventLog{streamID: streamID, diskStore: d}
|
el := &eventLog{streamID: streamID, diskStore: d}
|
||||||
|
|
||||||
return el, d.openlogs.Modify(ctx, func(ctx context.Context, openlogs *openlogs) error {
|
return el, d.openlogs.Use(ctx, func(ctx context.Context, openlogs *openlogs) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -193,7 +193,7 @@ func (e *eventLog) Append(ctx context.Context, events event.Events, version uint
|
||||||
event.SetStreamID(e.streamID, events...)
|
event.SetStreamID(e.streamID, events...)
|
||||||
|
|
||||||
var count uint64
|
var count uint64
|
||||||
err := e.events.Modify(ctx, func(ctx context.Context, l *wal.Log) error {
|
err := e.events.Use(ctx, func(ctx context.Context, l *wal.Log) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -248,7 +248,7 @@ func (e *eventLog) Read(ctx context.Context, after, count int64) (event.Events,
|
||||||
|
|
||||||
var events event.Events
|
var events event.Events
|
||||||
|
|
||||||
err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
|
err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -330,7 +330,7 @@ func (e *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
|
||||||
)
|
)
|
||||||
|
|
||||||
var events event.Events
|
var events event.Events
|
||||||
err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
|
err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
events, err = readStreamN(ctx, stream, index...)
|
events, err = readStreamN(ctx, stream, index...)
|
||||||
|
@ -352,7 +352,7 @@ func (e *eventLog) FirstIndex(ctx context.Context) (uint64, error) {
|
||||||
var idx uint64
|
var idx uint64
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
|
err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||||
idx, err = events.FirstIndex()
|
idx, err = events.FirstIndex()
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
@ -371,7 +371,7 @@ func (e *eventLog) LastIndex(ctx context.Context) (uint64, error) {
|
||||||
var idx uint64
|
var idx uint64
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
|
err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||||
idx, err = events.LastIndex()
|
idx, err = events.LastIndex()
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
@ -391,7 +391,7 @@ func (e *eventLog) Truncate(ctx context.Context, index int64) error {
|
||||||
if index == 0 {
|
if index == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
|
return e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
|
||||||
if index < 0 {
|
if index < 0 {
|
||||||
return events.TruncateBack(uint64(-index))
|
return events.TruncateBack(uint64(-index))
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ func (m *memstore) EventLog(ctx context.Context, streamID string) (driver.EventL
|
||||||
|
|
||||||
el := &eventLog{streamID: streamID}
|
el := &eventLog{streamID: streamID}
|
||||||
|
|
||||||
err := m.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
err := m.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
_, span := lg.Span(ctx)
|
_, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ func (m *eventLog) Append(ctx context.Context, events event.Events, version uint
|
||||||
|
|
||||||
event.SetStreamID(m.streamID, events...)
|
event.SetStreamID(m.streamID, events...)
|
||||||
|
|
||||||
return uint64(len(events)), m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
|
return uint64(len(events)), m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ func (m *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
var events event.Events
|
var events event.Events
|
||||||
err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
|
err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
events, err = readStreamN(ctx, stream, index...)
|
events, err = readStreamN(ctx, stream, index...)
|
||||||
|
@ -135,7 +135,7 @@ func (m *eventLog) Read(ctx context.Context, after int64, count int64) (event.Ev
|
||||||
|
|
||||||
var events event.Events
|
var events event.Events
|
||||||
|
|
||||||
err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
|
err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
|
|
@ -76,7 +76,7 @@ func (s *streamer) Subscribe(ctx context.Context, streamID string, start int64)
|
||||||
})
|
})
|
||||||
sub.unsub = s.delete(streamID, sub)
|
sub.unsub = s.delete(streamID, sub)
|
||||||
|
|
||||||
return sub, s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
return sub, s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
state.subscribers[streamID] = append(state.subscribers[streamID], sub)
|
state.subscribers[streamID] = append(state.subscribers[streamID], sub)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -85,14 +85,14 @@ func (s *streamer) Send(ctx context.Context, streamID string, events event.Event
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
return s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
|
span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
|
||||||
|
|
||||||
for _, sub := range state.subscribers[streamID] {
|
for _, sub := range state.subscribers[streamID] {
|
||||||
err := sub.position.Modify(ctx, func(ctx context.Context, position *position) error {
|
err := sub.position.Use(ctx, func(ctx context.Context, position *position) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ func (s *streamer) delete(streamID string, sub *subscription) func(context.Conte
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
|
return s.state.Use(ctx, func(ctx context.Context, state *state) error {
|
||||||
_, span := lg.Span(ctx)
|
_, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -228,7 +228,7 @@ func (s *subscription) Recv(ctx context.Context) <-chan bool {
|
||||||
var wait func(context.Context) bool
|
var wait func(context.Context) bool
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
err := s.position.Modify(ctx, func(ctx context.Context, position *position) error {
|
err := s.position.Use(ctx, func(ctx context.Context, position *position) error {
|
||||||
_, span := lg.Span(ctx)
|
_, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -280,7 +280,7 @@ func (s *subscription) Events(ctx context.Context) (event.Events, error) {
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
var events event.Events
|
var events event.Events
|
||||||
return events, s.position.Modify(ctx, func(ctx context.Context, position *position) error {
|
return events, s.position.Use(ctx, func(ctx context.Context, position *position) error {
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ func RegisterName(ctx context.Context, name string, e Event) error {
|
||||||
|
|
||||||
span.AddEvent("register: " + name)
|
span.AddEvent("register: " + name)
|
||||||
|
|
||||||
if err := eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
|
if err := eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
_, span := lg.Span(ctx)
|
_, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ func GetContainer(ctx context.Context, s string) Event {
|
||||||
|
|
||||||
var e Event
|
var e Event
|
||||||
|
|
||||||
eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
|
eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
_, span := lg.Span(ctx)
|
_, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package locker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Modify will call the function with the locked value
|
type ctxKey struct{ name string }
|
||||||
func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error {
|
|
||||||
|
// Use will call the function with the locked value
|
||||||
|
func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return fmt.Errorf("locker not initialized")
|
return fmt.Errorf("locker not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
key := ctxKey{fmt.Sprintf("%p", s)}
|
||||||
|
|
||||||
|
if value := ctx.Value(key); value != nil {
|
||||||
|
return fmt.Errorf("%w: %T", ErrNested, s)
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, key, key)
|
||||||
|
|
||||||
ctx, span := lg.Span(ctx)
|
ctx, span := lg.Span(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err
|
||||||
func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
|
func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
|
||||||
var t T
|
var t T
|
||||||
|
|
||||||
err := s.Modify(ctx, func(ctx context.Context, c *T) error {
|
err := s.Use(ctx, func(ctx context.Context, c *T) error {
|
||||||
if c != nil {
|
if c != nil {
|
||||||
t = *c
|
t = *c
|
||||||
}
|
}
|
||||||
|
@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
|
||||||
|
|
||||||
return t, err
|
return t, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrNested = errors.New("nested locker call")
|
||||||
|
|
|
@ -2,6 +2,7 @@ package locker_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matryer/is"
|
"github.com/matryer/is"
|
||||||
|
@ -22,7 +23,7 @@ func TestLocker(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := value.Modify(ctx, func(ctx context.Context, c *config) error {
|
err := value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
c.Value = "one"
|
c.Value = "one"
|
||||||
c.Counter++
|
c.Counter++
|
||||||
return nil
|
return nil
|
||||||
|
@ -37,7 +38,7 @@ func TestLocker(t *testing.T) {
|
||||||
|
|
||||||
wait := make(chan struct{})
|
wait := make(chan struct{})
|
||||||
|
|
||||||
go value.Modify(ctx, func(ctx context.Context, c *config) error {
|
go value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
c.Value = "two"
|
c.Value = "two"
|
||||||
c.Counter++
|
c.Counter++
|
||||||
close(wait)
|
close(wait)
|
||||||
|
@ -47,7 +48,7 @@ func TestLocker(t *testing.T) {
|
||||||
<-wait
|
<-wait
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
err = value.Modify(ctx, func(ctx context.Context, c *config) error {
|
err = value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
c.Value = "three"
|
c.Value = "three"
|
||||||
c.Counter++
|
c.Counter++
|
||||||
return nil
|
return nil
|
||||||
|
@ -60,3 +61,36 @@ func TestLocker(t *testing.T) {
|
||||||
is.Equal(c.Value, "two")
|
is.Equal(c.Value, "two")
|
||||||
is.Equal(c.Counter, 2)
|
is.Equal(c.Counter, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNestedLocker(t *testing.T) {
|
||||||
|
is := is.New(t)
|
||||||
|
|
||||||
|
value := locker.New(&config{})
|
||||||
|
other := locker.New(&config{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
|
return value.Use(ctx, func(ctx context.Context, t *config) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
is.True(errors.Is(err, locker.ErrNested))
|
||||||
|
|
||||||
|
err = value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
|
return other.Use(ctx, func(ctx context.Context, t *config) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
is.NoErr(err)
|
||||||
|
|
||||||
|
err = value.Use(ctx, func(ctx context.Context, c *config) error {
|
||||||
|
return other.Use(ctx, func(ctx context.Context, t *config) error {
|
||||||
|
return value.Use(ctx, func(ctx context.Context, x *config) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
is.True(errors.Is(err, locker.ErrNested))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user