feat(locker): add check for nested calls
This commit is contained in:
		
							parent
							
								
									d27eb21ffb
								
							
						
					
					
						commit
						fcc5d08aa7
					
				@ -97,7 +97,7 @@ func (s *service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			var requests []*Request
 | 
								var requests []*Request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			s.state.Modify(r.Context(), func(ctx context.Context, state *state) error {
 | 
								s.state.Use(r.Context(), func(ctx context.Context, state *state) error {
 | 
				
			||||||
				for id, p := range state.peers {
 | 
									for id, p := range state.peers {
 | 
				
			||||||
					fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
 | 
										fmt.Fprintln(w, "PEER:", id[24:], p.Owner, p.Name)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -170,7 +170,7 @@ func (s *service) getPending(w http.ResponseWriter, r *http.Request, peerID stri
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var peer *Peer
 | 
						var peer *Peer
 | 
				
			||||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		var ok bool
 | 
							var ok bool
 | 
				
			||||||
		if peer, ok = state.peers[peerID]; !ok {
 | 
							if peer, ok = state.peers[peerID]; !ok {
 | 
				
			||||||
			return fmt.Errorf("peer not found: %s", peerID)
 | 
								return fmt.Errorf("peer not found: %s", peerID)
 | 
				
			||||||
@ -285,7 +285,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	// }
 | 
						// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var requests ListRequest
 | 
						var requests ListRequest
 | 
				
			||||||
	s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		requests = make([]*Request, 0, len(state.requests))
 | 
							requests = make([]*Request, 0, len(state.requests))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, req := range state.requests {
 | 
							for _, req := range state.requests {
 | 
				
			||||||
@ -306,7 +306,7 @@ func (s *service) getResults(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	args := requestArgs(r)
 | 
						args := requestArgs(r)
 | 
				
			||||||
	args.Requests = requests[:maxResults]
 | 
						args.Requests = requests[:maxResults]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		args.CountPeers = len(state.peers)
 | 
							args.CountPeers = len(state.peers)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -323,7 +323,7 @@ func (s *service) getResultsForRequest(w http.ResponseWriter, r *http.Request, u
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var request *Request
 | 
						var request *Request
 | 
				
			||||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		request = state.requests[uuid]
 | 
							request = state.requests[uuid]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
@ -430,7 +430,7 @@ func (s *service) postResult(w http.ResponseWriter, r *http.Request, reqID strin
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	peerID := r.Form.Get("peer_id")
 | 
						peerID := r.Form.Get("peer_id")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		var ok bool
 | 
							var ok bool
 | 
				
			||||||
		if _, ok = state.peers[peerID]; !ok {
 | 
							if _, ok = state.peers[peerID]; !ok {
 | 
				
			||||||
			log.Printf("peer not found: %s\n", peerID)
 | 
								log.Printf("peer not found: %s\n", peerID)
 | 
				
			||||||
 | 
				
			|||||||
@ -41,7 +41,7 @@ func (s *service) RefreshJob(ctx context.Context, _ time.Time) error {
 | 
				
			|||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = s.state.Modify(ctx, func(ctx context.Context, t *state) error {
 | 
						err = s.state.Use(ctx, func(ctx context.Context, t *state) error {
 | 
				
			||||||
		for _, peer := range peers {
 | 
							for _, peer := range peers {
 | 
				
			||||||
			t.peers[peer.ID] = peer
 | 
								t.peers[peer.ID] = peer
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -88,7 +88,7 @@ func (s *service) cleanPeerJobs(ctx context.Context) error {
 | 
				
			|||||||
	defer span.End()
 | 
						defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	peers := set.New[string]()
 | 
						peers := set.New[string]()
 | 
				
			||||||
	err := s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						err := s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		for id := range state.peers {
 | 
							for id := range state.peers {
 | 
				
			||||||
			peers.Add(id)
 | 
								peers.Add(id)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -181,7 +181,7 @@ func (s *service) cleanRequests(ctx context.Context, now time.Time) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// truncate all the request streams
 | 
						// truncate all the request streams
 | 
				
			||||||
	for _, streamID := range streamIDs {
 | 
						for _, streamID := range streamIDs {
 | 
				
			||||||
		s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
							s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
			return state.ApplyEvents(event.NewEvents(&RequestTruncated{
 | 
								return state.ApplyEvents(event.NewEvents(&RequestTruncated{
 | 
				
			||||||
				RequestID: streamID,
 | 
									RequestID: streamID,
 | 
				
			||||||
			}))
 | 
								}))
 | 
				
			||||||
 | 
				
			|||||||
@ -63,7 +63,7 @@ func (s *service) loadResult(ctx context.Context, request *Request) (*Request, e
 | 
				
			|||||||
		return request, nil
 | 
							return request, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return request, s.state.Modify(ctx, func(ctx context.Context, t *state) error {
 | 
						return request, s.state.Use(ctx, func(ctx context.Context, t *state) error {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for i := range request.Responses {
 | 
							for i := range request.Responses {
 | 
				
			||||||
			res := request.Responses[i]
 | 
								res := request.Responses[i]
 | 
				
			||||||
@ -116,7 +116,7 @@ func (s *service) Run(ctx context.Context) (err error) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
							s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
			return state.ApplyEvents(events)
 | 
								return state.ApplyEvents(events)
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		events = events[:0]
 | 
							events = events[:0]
 | 
				
			||||||
 | 
				
			|||||||
@ -24,6 +24,7 @@ func TestMain(m *testing.M) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	defer os.RemoveAll(data)
 | 
						defer os.RemoveAll(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						os.Setenv("EV_DATA", "mem:")
 | 
				
			||||||
	os.Setenv("EV_HTTP", "[::1]:61234")
 | 
						os.Setenv("EV_HTTP", "[::1]:61234")
 | 
				
			||||||
	os.Setenv("WEBFINGER_DOMAINS", "::1")
 | 
						os.Setenv("WEBFINGER_DOMAINS", "::1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								ev.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								ev.go
									
									
									
									
									
								
							@ -54,7 +54,7 @@ func Register(ctx context.Context, name string, d driver.Driver) error {
 | 
				
			|||||||
	ctx, span := lg.Span(ctx)
 | 
						ctx, span := lg.Span(ctx)
 | 
				
			||||||
	defer span.End()
 | 
						defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return drivers.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
						return drivers.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
		if _, set := c.drivers[name]; set {
 | 
							if _, set := c.drivers[name]; set {
 | 
				
			||||||
			return fmt.Errorf("driver %s already set", name)
 | 
								return fmt.Errorf("driver %s already set", name)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
				
			|||||||
@ -74,7 +74,7 @@ func (c *cron) NewCron(expr string, task func(context.Context, time.Time) error)
 | 
				
			|||||||
	c.jobs = append(c.jobs, job)
 | 
						c.jobs = append(c.jobs, job)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
 | 
					func (c *cron) RunOnce(ctx context.Context, once func(context.Context, time.Time) error) {
 | 
				
			||||||
	c.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						c.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		state.queue = append(state.queue, once)
 | 
							state.queue = append(state.queue, once)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -126,7 +126,7 @@ func (c *cron) run(ctx context.Context, now time.Time) {
 | 
				
			|||||||
	span.AddEvent("Cron Run: " + now.Format(time.RFC822))
 | 
						span.AddEvent("Cron Run: " + now.Format(time.RFC822))
 | 
				
			||||||
	// fmt.Println("Cron Run: ", now.Format(time.RFC822))
 | 
						// fmt.Println("Cron Run: ", now.Format(time.RFC822))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						c.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		run = append(run, state.queue...)
 | 
							run = append(run, state.queue...)
 | 
				
			||||||
		state.queue = state.queue[:0]
 | 
							state.queue = state.queue[:0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -100,7 +100,7 @@ func (d *diskStore) Open(ctx context.Context, dsn string) (driver.Driver, error)
 | 
				
			|||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		l.Modify(ctx, func(ctx context.Context, w *wal.Log) error {
 | 
							l.Use(ctx, func(ctx context.Context, w *wal.Log) error {
 | 
				
			||||||
			ctx, span := lg.Span(ctx)
 | 
								ctx, span := lg.Span(ctx)
 | 
				
			||||||
			defer span.End()
 | 
								defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -139,7 +139,7 @@ func (d *diskStore) EventLog(ctx context.Context, streamID string) (driver.Event
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	el := &eventLog{streamID: streamID, diskStore: d}
 | 
						el := &eventLog{streamID: streamID, diskStore: d}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return el, d.openlogs.Modify(ctx, func(ctx context.Context, openlogs *openlogs) error {
 | 
						return el, d.openlogs.Use(ctx, func(ctx context.Context, openlogs *openlogs) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -193,7 +193,7 @@ func (e *eventLog) Append(ctx context.Context, events event.Events, version uint
 | 
				
			|||||||
	event.SetStreamID(e.streamID, events...)
 | 
						event.SetStreamID(e.streamID, events...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var count uint64
 | 
						var count uint64
 | 
				
			||||||
	err := e.events.Modify(ctx, func(ctx context.Context, l *wal.Log) error {
 | 
						err := e.events.Use(ctx, func(ctx context.Context, l *wal.Log) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -248,7 +248,7 @@ func (e *eventLog) Read(ctx context.Context, after, count int64) (event.Events,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var events event.Events
 | 
						var events event.Events
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
						err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -330,7 +330,7 @@ func (e *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var events event.Events
 | 
						var events event.Events
 | 
				
			||||||
	err := e.events.Modify(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
						err := e.events.Use(ctx, func(ctx context.Context, stream *wal.Log) error {
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		events, err = readStreamN(ctx, stream, index...)
 | 
							events, err = readStreamN(ctx, stream, index...)
 | 
				
			||||||
@ -352,7 +352,7 @@ func (e *eventLog) FirstIndex(ctx context.Context) (uint64, error) {
 | 
				
			|||||||
	var idx uint64
 | 
						var idx uint64
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
						err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
				
			||||||
		idx, err = events.FirstIndex()
 | 
							idx, err = events.FirstIndex()
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -371,7 +371,7 @@ func (e *eventLog) LastIndex(ctx context.Context) (uint64, error) {
 | 
				
			|||||||
	var idx uint64
 | 
						var idx uint64
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
						err = e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
				
			||||||
		idx, err = events.LastIndex()
 | 
							idx, err = events.LastIndex()
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -391,7 +391,7 @@ func (e *eventLog) Truncate(ctx context.Context, index int64) error {
 | 
				
			|||||||
	if index == 0 {
 | 
						if index == 0 {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return e.events.Modify(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
						return e.events.Use(ctx, func(ctx context.Context, events *wal.Log) error {
 | 
				
			||||||
		if index < 0 {
 | 
							if index < 0 {
 | 
				
			||||||
			return events.TruncateBack(uint64(-index))
 | 
								return events.TruncateBack(uint64(-index))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
				
			|||||||
@ -49,7 +49,7 @@ func (m *memstore) EventLog(ctx context.Context, streamID string) (driver.EventL
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	el := &eventLog{streamID: streamID}
 | 
						el := &eventLog{streamID: streamID}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := m.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						err := m.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		_, span := lg.Span(ctx)
 | 
							_, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -76,7 +76,7 @@ func (m *eventLog) Append(ctx context.Context, events event.Events, version uint
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	event.SetStreamID(m.streamID, events...)
 | 
						event.SetStreamID(m.streamID, events...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return uint64(len(events)), m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
						return uint64(len(events)), m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -117,7 +117,7 @@ func (m *eventLog) ReadN(ctx context.Context, index ...uint64) (event.Events, er
 | 
				
			|||||||
	defer span.End()
 | 
						defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var events event.Events
 | 
						var events event.Events
 | 
				
			||||||
	err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
						err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		events, err = readStreamN(ctx, stream, index...)
 | 
							events, err = readStreamN(ctx, stream, index...)
 | 
				
			||||||
@ -135,7 +135,7 @@ func (m *eventLog) Read(ctx context.Context, after int64, count int64) (event.Ev
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var events event.Events
 | 
						var events event.Events
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := m.events.Modify(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
						err := m.events.Use(ctx, func(ctx context.Context, stream *event.Events) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -76,7 +76,7 @@ func (s *streamer) Subscribe(ctx context.Context, streamID string, start int64)
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	sub.unsub = s.delete(streamID, sub)
 | 
						sub.unsub = s.delete(streamID, sub)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return sub, s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						return sub, s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		state.subscribers[streamID] = append(state.subscribers[streamID], sub)
 | 
							state.subscribers[streamID] = append(state.subscribers[streamID], sub)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@ -85,14 +85,14 @@ func (s *streamer) Send(ctx context.Context, streamID string, events event.Event
 | 
				
			|||||||
	ctx, span := lg.Span(ctx)
 | 
						ctx, span := lg.Span(ctx)
 | 
				
			||||||
	defer span.End()
 | 
						defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
						return s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
 | 
							span.AddEvent(fmt.Sprint("subscribers=", len(state.subscribers[streamID])))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, sub := range state.subscribers[streamID] {
 | 
							for _, sub := range state.subscribers[streamID] {
 | 
				
			||||||
			err := sub.position.Modify(ctx, func(ctx context.Context, position *position) error {
 | 
								err := sub.position.Use(ctx, func(ctx context.Context, position *position) error {
 | 
				
			||||||
				ctx, span := lg.Span(ctx)
 | 
									ctx, span := lg.Span(ctx)
 | 
				
			||||||
				defer span.End()
 | 
									defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -128,7 +128,7 @@ func (s *streamer) delete(streamID string, sub *subscription) func(context.Conte
 | 
				
			|||||||
		if err := ctx.Err(); err != nil {
 | 
							if err := ctx.Err(); err != nil {
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return s.state.Modify(ctx, func(ctx context.Context, state *state) error {
 | 
							return s.state.Use(ctx, func(ctx context.Context, state *state) error {
 | 
				
			||||||
			_, span := lg.Span(ctx)
 | 
								_, span := lg.Span(ctx)
 | 
				
			||||||
			defer span.End()
 | 
								defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -228,7 +228,7 @@ func (s *subscription) Recv(ctx context.Context) <-chan bool {
 | 
				
			|||||||
		var wait func(context.Context) bool
 | 
							var wait func(context.Context) bool
 | 
				
			||||||
		defer close(done)
 | 
							defer close(done)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		err := s.position.Modify(ctx, func(ctx context.Context, position *position) error {
 | 
							err := s.position.Use(ctx, func(ctx context.Context, position *position) error {
 | 
				
			||||||
			_, span := lg.Span(ctx)
 | 
								_, span := lg.Span(ctx)
 | 
				
			||||||
			defer span.End()
 | 
								defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -280,7 +280,7 @@ func (s *subscription) Events(ctx context.Context) (event.Events, error) {
 | 
				
			|||||||
	defer span.End()
 | 
						defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var events event.Events
 | 
						var events event.Events
 | 
				
			||||||
	return events, s.position.Modify(ctx, func(ctx context.Context, position *position) error {
 | 
						return events, s.position.Use(ctx, func(ctx context.Context, position *position) error {
 | 
				
			||||||
		ctx, span := lg.Span(ctx)
 | 
							ctx, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -106,7 +106,7 @@ func RegisterName(ctx context.Context, name string, e Event) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	span.AddEvent("register: " + name)
 | 
						span.AddEvent("register: " + name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
						if err := eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
		_, span := lg.Span(ctx)
 | 
							_, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -124,7 +124,7 @@ func GetContainer(ctx context.Context, s string) Event {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var e Event
 | 
						var e Event
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	eventTypes.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
						eventTypes.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
		_, span := lg.Span(ctx)
 | 
							_, span := lg.Span(ctx)
 | 
				
			||||||
		defer span.End()
 | 
							defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package locker
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"go.opentelemetry.io/otel/attribute"
 | 
						"go.opentelemetry.io/otel/attribute"
 | 
				
			||||||
@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] {
 | 
				
			|||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Modify will call the function with the locked value
 | 
					type ctxKey struct{ name string }
 | 
				
			||||||
func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error {
 | 
					
 | 
				
			||||||
 | 
					// Use will call the function with the locked value
 | 
				
			||||||
 | 
					func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error {
 | 
				
			||||||
	if s == nil {
 | 
						if s == nil {
 | 
				
			||||||
		return fmt.Errorf("locker not initialized")
 | 
							return fmt.Errorf("locker not initialized")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						key := ctxKey{fmt.Sprintf("%p", s)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if value := ctx.Value(key); value != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("%w: %T", ErrNested, s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ctx = context.WithValue(ctx, key, key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ctx, span := lg.Span(ctx)
 | 
						ctx, span := lg.Span(ctx)
 | 
				
			||||||
	defer span.End()
 | 
						defer span.End()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err
 | 
				
			|||||||
func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
 | 
					func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
 | 
				
			||||||
	var t T
 | 
						var t T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := s.Modify(ctx, func(ctx context.Context, c *T) error {
 | 
						err := s.Use(ctx, func(ctx context.Context, c *T) error {
 | 
				
			||||||
		if c != nil {
 | 
							if c != nil {
 | 
				
			||||||
			t = *c
 | 
								t = *c
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return t, err
 | 
						return t, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var ErrNested = errors.New("nested locker call")
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package locker_test
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/matryer/is"
 | 
						"github.com/matryer/is"
 | 
				
			||||||
@ -22,7 +23,7 @@ func TestLocker(t *testing.T) {
 | 
				
			|||||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
						ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
	defer cancel()
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := value.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
						err := value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
		c.Value = "one"
 | 
							c.Value = "one"
 | 
				
			||||||
		c.Counter++
 | 
							c.Counter++
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
@ -37,7 +38,7 @@ func TestLocker(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	wait := make(chan struct{})
 | 
						wait := make(chan struct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go value.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
						go value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
		c.Value = "two"
 | 
							c.Value = "two"
 | 
				
			||||||
		c.Counter++
 | 
							c.Counter++
 | 
				
			||||||
		close(wait)
 | 
							close(wait)
 | 
				
			||||||
@ -47,7 +48,7 @@ func TestLocker(t *testing.T) {
 | 
				
			|||||||
	<-wait
 | 
						<-wait
 | 
				
			||||||
	cancel()
 | 
						cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = value.Modify(ctx, func(ctx context.Context, c *config) error {
 | 
						err = value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
		c.Value = "three"
 | 
							c.Value = "three"
 | 
				
			||||||
		c.Counter++
 | 
							c.Counter++
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
@ -60,3 +61,36 @@ func TestLocker(t *testing.T) {
 | 
				
			|||||||
	is.Equal(c.Value, "two")
 | 
						is.Equal(c.Value, "two")
 | 
				
			||||||
	is.Equal(c.Counter, 2)
 | 
						is.Equal(c.Counter, 2)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestNestedLocker(t *testing.T) {
 | 
				
			||||||
 | 
						is := is.New(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value := locker.New(&config{})
 | 
				
			||||||
 | 
						other := locker.New(&config{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
 | 
							return value.Use(ctx, func(ctx context.Context, t *config) error {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						is.True(errors.Is(err, locker.ErrNested))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
 | 
							return other.Use(ctx, func(ctx context.Context, t *config) error {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						is.NoErr(err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = value.Use(ctx, func(ctx context.Context, c *config) error {
 | 
				
			||||||
 | 
							return other.Use(ctx, func(ctx context.Context, t *config) error {
 | 
				
			||||||
 | 
								return value.Use(ctx, func(ctx context.Context, x *config) error {
 | 
				
			||||||
 | 
									return nil
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						is.True(errors.Is(err, locker.ErrNested))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user