feat(locker): add check for nested calls
This commit is contained in:
		
							parent
							
								
									d27eb21ffb
								
							
						
					
					
						commit
						fcc5d08aa7
					
				@ -97,7 +97,7 @@ func (s *service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
 | 
			
		||||
			var requests []*Request
 | 
			
		||||
 | 
			
		||||
			s.state.Modify(r.Context(), func(ctx context.Context, state *state) error {
 | 
			
		||||
			s.state.Use(r.Context(), func(ctx context.Context, state *state) error {
 | 
			
		||||
				for id, p := range state.peers {
 | 
			
		||||
					fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
 | 
			
		||||
				}
 | 
			
		||||
@ -170,7 +170,7 @@ func (s *service) getPending(w http.ResponseWriter, r *http.Request, peerID stri
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	var peer *Peer
 | 
			
		||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		if peer, ok = state.peers[peerID]; !ok {
 | 
			
		||||
			return fmt.Errorf("peer not found: %s", peerID)
 | 
			
		||||
@ -285,7 +285,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	// }
 | 
			
		||||
 | 
			
		||||
	var requests ListRequest
 | 
			
		||||
	s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		requests = make([]*Request, 0, len(state.requests))
 | 
			
		||||
 | 
			
		||||
		for _, req := range state.requests {
 | 
			
		||||
@ -306,7 +306,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	args := requestArgs(r)
 | 
			
		||||
	args.Requests = requests[:maxResults]
 | 
			
		||||
 | 
			
		||||
	s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		args.CountPeers = len(state.peers)
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
@ -323,7 +323,7 @@ func (s *service) getResultsForRequest(w http.ResponseWriter, r *http.Request, u
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	var request *Request
 | 
			
		||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		request = state.requests[uuid]
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
@ -430,7 +430,7 @@ func (s *service) postResult(w http.ResponseWriter, r *http.Request, reqID strin
 | 
			
		||||
 | 
			
		||||
	peerID := r.Form.Get("peer_id")
 | 
			
		||||
 | 
			
		||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		if _, ok = state.peers[peerID]; !ok {
 | 
			
		||||
			log.Printf("peer not found: %s\n", peerID)
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,7 @@ func (s *service) RefreshJob(ctx context.Context, _ time.Time) error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = s.state.Modify(ctx, func(ctx context.Context, t *state) error {
 | 
			
		||||
	err = s.state.Use(ctx, func(ctx context.Context, t *state) error {
 | 
			
		||||
		for _, peer := range peers {
 | 
			
		||||
			t.peers[peer.ID] = peer
 | 
			
		||||
		}
 | 
			
		||||
@ -88,7 +88,7 @@ func (s *service) cleanPeerJobs(ctx context.Context) error {
 | 
			
		||||
	defer span.End()
 | 
			
		||||
 | 
			
		||||
	peers := set.New[string]()
 | 
			
		||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		for id := range state.peers {
 | 
			
		||||
			peers.Add(id)
 | 
			
		||||
		}
 | 
			
		||||
@ -181,7 +181,7 @@ func (s *service) cleanRequests(ctx context.Context, now time.Time) error {
 | 
			
		||||
 | 
			
		||||
	// truncate all the request streams
 | 
			
		||||
	for _, streamID := range streamIDs {
 | 
			
		||||
		s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
			return state.ApplyEvents(event.NewEvents(&RequestTruncated{
 | 
			
		||||
				RequestID: streamID,
 | 
			
		||||
			}))
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,7 @@ func (s *service) loadResult(ctx context.Context, request *Request) (*Request, e
 | 
			
		||||
		return request, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return request, s.state.Modify(ctx, func(ctx context.Context, t *state) error {
 | 
			
		||||
	return request, s.state.Use(ctx, func(ctx context.Context, t *state) error {
 | 
			
		||||
 | 
			
		||||
		for i := range request.Responses {
 | 
			
		||||
			res := request.Responses[i]
 | 
			
		||||
@ -116,7 +116,7 @@ func (s *service) Run(ctx context.Context) (err error) {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
			return state.ApplyEvents(events)
 | 
			
		||||
		})
 | 
			
		||||
		events = events[:0]
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,7 @@ func TestMain(m *testing.M) {
 | 
			
		||||
	}
 | 
			
		||||
	defer os.RemoveAll(data)
 | 
			
		||||
 | 
			
		||||
	os.Setenv("EV_DATA", "mem:")
 | 
			
		||||
	os.Setenv("EV_HTTP", "[::1]:61234")
 | 
			
		||||
	os.Setenv("WEBFINGER_DOMAINS", "::1")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								ev.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								ev.go
									
									
									
									
									
								
							@ -54,7 +54,7 @@ func Register(ctx context.Context, name string, d driver.Driver) error {
 | 
			
		||||
	ctx, span := lg.Span(ctx)
 | 
			
		||||
	defer span.End()
 | 
			
		||||
 | 
			
		||||
	return drivers.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
	return drivers.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		if _, set := c.drivers[name]; set {
 | 
			
		||||
			return fmt.Errorf("driver %s already set", name)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -74,7 +74,7 @@ func (c *cron) NewCron(expr string, task func(context.Context, time.Time) error)
 | 
			
		||||
	c.jobs = append(c.jobs, job)
 | 
			
		||||
}
 | 
			
		||||
func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
 | 
			
		||||
	c.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	c.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		state.queue = append(state.queue, once)
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
@ -126,7 +126,7 @@ func (c *cron) run(ctx context.Context, now time.Time) {
 | 
			
		||||
	span.AddEvent("Cron Run: " + now.Format(time.RFC822))
 | 
			
		||||
	// fmt.Println("Cron Run: ", now.Format(time.RFC822))
 | 
			
		||||
 | 
			
		||||
	c.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	c.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		run = append(run, state.queue...)
 | 
			
		||||
		state.queue = state.queue[:0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -100,7 +100,7 @@ func (d *diskStore) Open(ctx context.Context, dsn string) (driver.Driver, error)
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
		l.Modify(ctx, func(ctx context.Context, w *wal.Log) error {
 | 
			
		||||
		l.Use(ctx, func(ctx context.Context, w *wal.Log) error {
 | 
			
		||||
			ctx, span := lg.Span(ctx)
 | 
			
		||||
			defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -139,7 +139,7 @@ func (d *diskStore) EventLog(ctx context.Context, streamID string) (driver.Event
 | 
			
		||||
 | 
			
		||||
	el := &eventLog{streamID: streamID, diskStore: d}
 | 
			
		||||
 | 
			
		||||
	return el, d.openlogs.Modify(ctx, func(ctx context.Context, openlogs *openlogs) error {
 | 
			
		||||
	return el, d.openlogs.Use(ctx, func(ctx context.Context, openlogs *openlogs) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -193,7 +193,7 @@ func (e *eventLog) Append(ctx context.Context, events event.Events, version uint
 | 
			
		||||
	event.SetStreamID(e.streamID, events...)
 | 
			
		||||
 | 
			
		||||
	var count uint64
 | 
			
		||||
	err := e.events.Modify(ctx, func(ctx context.Context, l *wal.Log) error {
 | 
			
		||||
	err := e.events.Use(ctx, func(ctx context.Context, l *wal.Log) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -248,7 +248,7 @@ func (e *eventLog) Read(ctx context.Context, after, count int64) (event.Events,
 | 
			
		||||
 | 
			
		||||
	var events event.Events
 | 
			
		||||
 | 
			
		||||
	err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
			
		||||
	err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -330,7 +330,7 @@ func (e *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	var events event.Events
 | 
			
		||||
	err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
			
		||||
	err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
			
		||||
		var err error
 | 
			
		||||
 | 
			
		||||
		events, err = readStreamN(ctx, stream, index...)
 | 
			
		||||
@ -352,7 +352,7 @@ func (e *eventLog) FirstIndex(ctx context.Context) (uint64, error) {
 | 
			
		||||
	var idx uint64
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
			
		||||
	err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
			
		||||
		idx, err = events.FirstIndex()
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
@ -371,7 +371,7 @@ func (e *eventLog) LastIndex(ctx context.Context) (uint64, error) {
 | 
			
		||||
	var idx uint64
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
			
		||||
	err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
			
		||||
		idx, err = events.LastIndex()
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
@ -391,7 +391,7 @@ func (e *eventLog) Truncate(ctx context.Context, index int64) error {
 | 
			
		||||
	if index == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
			
		||||
	return e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
			
		||||
		if index < 0 {
 | 
			
		||||
			return events.TruncateBack(uint64(-index))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -49,7 +49,7 @@ func (m *memstore) EventLog(ctx context.Context, streamID string) (driver.EventL
 | 
			
		||||
 | 
			
		||||
	el := &eventLog{streamID: streamID}
 | 
			
		||||
 | 
			
		||||
	err := m.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	err := m.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		_, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -76,7 +76,7 @@ func (m *eventLog) Append(ctx context.Context, events event.Events, version uint
 | 
			
		||||
 | 
			
		||||
	event.SetStreamID(m.streamID, events...)
 | 
			
		||||
 | 
			
		||||
	return uint64(len(events)), m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
			
		||||
	return uint64(len(events)), m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -117,7 +117,7 @@ func (m *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
 | 
			
		||||
	defer span.End()
 | 
			
		||||
 | 
			
		||||
	var events event.Events
 | 
			
		||||
	err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
			
		||||
	err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
			
		||||
		var err error
 | 
			
		||||
 | 
			
		||||
		events, err = readStreamN(ctx, stream, index...)
 | 
			
		||||
@ -135,7 +135,7 @@ func (m *eventLog) Read(ctx context.Context, after int64, count int64) (event.Ev
 | 
			
		||||
 | 
			
		||||
	var events event.Events
 | 
			
		||||
 | 
			
		||||
	err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
			
		||||
	err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -76,7 +76,7 @@ func (s *streamer) Subscribe(ctx context.Context, streamID string, start int64)
 | 
			
		||||
	})
 | 
			
		||||
	sub.unsub = s.delete(streamID, sub)
 | 
			
		||||
 | 
			
		||||
	return sub, s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	return sub, s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		state.subscribers[streamID] = append(state.subscribers[streamID], sub)
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
@ -85,14 +85,14 @@ func (s *streamer) Send(ctx context.Context, streamID string, events event.Event
 | 
			
		||||
	ctx, span := lg.Span(ctx)
 | 
			
		||||
	defer span.End()
 | 
			
		||||
 | 
			
		||||
	return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
	return s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
		span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
 | 
			
		||||
 | 
			
		||||
		for _, sub := range state.subscribers[streamID] {
 | 
			
		||||
			err := sub.position.Modify(ctx, func(ctx context.Context, position *position) error {
 | 
			
		||||
			err := sub.position.Use(ctx, func(ctx context.Context, position *position) error {
 | 
			
		||||
				ctx, span := lg.Span(ctx)
 | 
			
		||||
				defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -128,7 +128,7 @@ func (s *streamer) delete(streamID string, sub *subscription) func(context.Conte
 | 
			
		||||
		if err := ctx.Err(); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
		return s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
			
		||||
			_, span := lg.Span(ctx)
 | 
			
		||||
			defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -228,7 +228,7 @@ func (s *subscription) Recv(ctx context.Context) <-chan bool {
 | 
			
		||||
		var wait func(context.Context) bool
 | 
			
		||||
		defer close(done)
 | 
			
		||||
 | 
			
		||||
		err := s.position.Modify(ctx, func(ctx context.Context, position *position) error {
 | 
			
		||||
		err := s.position.Use(ctx, func(ctx context.Context, position *position) error {
 | 
			
		||||
			_, span := lg.Span(ctx)
 | 
			
		||||
			defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -280,7 +280,7 @@ func (s *subscription) Events(ctx context.Context) (event.Events, error) {
 | 
			
		||||
	defer span.End()
 | 
			
		||||
 | 
			
		||||
	var events event.Events
 | 
			
		||||
	return events, s.position.Modify(ctx, func(ctx context.Context, position *position) error {
 | 
			
		||||
	return events, s.position.Use(ctx, func(ctx context.Context, position *position) error {
 | 
			
		||||
		ctx, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -106,7 +106,7 @@ func RegisterName(ctx context.Context, name string, e Event) error {
 | 
			
		||||
 | 
			
		||||
	span.AddEvent("register: " + name)
 | 
			
		||||
 | 
			
		||||
	if err := eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
	if err := eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		_, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -124,7 +124,7 @@ func GetContainer(ctx context.Context, s string) Event {
 | 
			
		||||
 | 
			
		||||
	var e Event
 | 
			
		||||
 | 
			
		||||
	eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
	eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		_, span := lg.Span(ctx)
 | 
			
		||||
		defer span.End()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package locker
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"go.opentelemetry.io/otel/attribute"
 | 
			
		||||
@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] {
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Modify will call the function with the locked value
 | 
			
		||||
func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error {
 | 
			
		||||
type ctxKey struct{ name string }
 | 
			
		||||
 | 
			
		||||
// Use will call the function with the locked value
 | 
			
		||||
func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error {
 | 
			
		||||
	if s == nil {
 | 
			
		||||
		return fmt.Errorf("locker not initialized")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	key := ctxKey{fmt.Sprintf("%p", s)}
 | 
			
		||||
 | 
			
		||||
	if value := ctx.Value(key); value != nil {
 | 
			
		||||
		return fmt.Errorf("%w: %T", ErrNested, s)
 | 
			
		||||
	}
 | 
			
		||||
	ctx = context.WithValue(ctx, key, key)
 | 
			
		||||
 | 
			
		||||
	ctx, span := lg.Span(ctx)
 | 
			
		||||
	defer span.End()
 | 
			
		||||
 | 
			
		||||
@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err
 | 
			
		||||
func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
 | 
			
		||||
	var t T
 | 
			
		||||
 | 
			
		||||
	err := s.Modify(ctx, func(ctx context.Context, c *T) error {
 | 
			
		||||
	err := s.Use(ctx, func(ctx context.Context, c *T) error {
 | 
			
		||||
		if c != nil {
 | 
			
		||||
			t = *c
 | 
			
		||||
		}
 | 
			
		||||
@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
 | 
			
		||||
 | 
			
		||||
	return t, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ErrNested = errors.New("nested locker call")
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package locker_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/matryer/is"
 | 
			
		||||
@ -22,7 +23,7 @@ func TestLocker(t *testing.T) {
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	err := value.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
	err := value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		c.Value = "one"
 | 
			
		||||
		c.Counter++
 | 
			
		||||
		return nil
 | 
			
		||||
@ -37,7 +38,7 @@ func TestLocker(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	wait := make(chan struct{})
 | 
			
		||||
 | 
			
		||||
	go value.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
	go value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		c.Value = "two"
 | 
			
		||||
		c.Counter++
 | 
			
		||||
		close(wait)
 | 
			
		||||
@ -47,7 +48,7 @@ func TestLocker(t *testing.T) {
 | 
			
		||||
	<-wait
 | 
			
		||||
	cancel()
 | 
			
		||||
 | 
			
		||||
	err = value.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
	err = value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		c.Value = "three"
 | 
			
		||||
		c.Counter++
 | 
			
		||||
		return nil
 | 
			
		||||
@ -60,3 +61,36 @@ func TestLocker(t *testing.T) {
 | 
			
		||||
	is.Equal(c.Value, "two")
 | 
			
		||||
	is.Equal(c.Counter, 2)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNestedLocker(t *testing.T) {
 | 
			
		||||
	is := is.New(t)
 | 
			
		||||
 | 
			
		||||
	value := locker.New(&config{})
 | 
			
		||||
	other := locker.New(&config{})
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	err := value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		return value.Use(ctx, func(ctx context.Context, t *config) error {
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
	is.True(errors.Is(err, locker.ErrNested))
 | 
			
		||||
 | 
			
		||||
	err = value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		return other.Use(ctx, func(ctx context.Context, t *config) error {
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
	is.NoErr(err)
 | 
			
		||||
 | 
			
		||||
	err = value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
			
		||||
		return other.Use(ctx, func(ctx context.Context, t *config) error {
 | 
			
		||||
			return value.Use(ctx, func(ctx context.Context, x *config) error {
 | 
			
		||||
				return nil
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
	is.True(errors.Is(err, locker.ErrNested))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user