feat(locker): add check for nested calls

This commit is contained in:
Jon Lundy
2023-03-19 08:31:00 -06:00
parent d27eb21ffb
commit fcc5d08aa7
12 changed files with 87 additions and 40 deletions

View File

@@ -2,6 +2,7 @@ package locker
import (
"context"
"errors"
"fmt"
"go.opentelemetry.io/otel/attribute"
@@ -20,12 +21,21 @@ func New[T any](initial *T) *Locked[T] {
return s
}
// Modify will call the function with the locked value
func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) error) error {
type ctxKey struct{ name string }
// Use will call the function with the locked value
func (s *Locked[T]) Use(ctx context.Context, fn func(context.Context, *T) error) error {
if s == nil {
return fmt.Errorf("locker not initialized")
}
key := ctxKey{fmt.Sprintf("%p", s)}
if value := ctx.Value(key); value != nil {
return fmt.Errorf("%w: %T", ErrNested, s)
}
ctx = context.WithValue(ctx, key, key)
ctx, span := lg.Span(ctx)
defer span.End()
@@ -51,7 +61,7 @@ func (s *Locked[T]) Modify(ctx context.Context, fn func(context.Context, *T) err
func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
var t T
err := s.Modify(ctx, func(ctx context.Context, c *T) error {
err := s.Use(ctx, func(ctx context.Context, c *T) error {
if c != nil {
t = *c
}
@@ -60,3 +70,5 @@ func (s *Locked[T]) Copy(ctx context.Context) (T, error) {
return t, err
}
var ErrNested = errors.New("nested locker call")

View File

@@ -2,6 +2,7 @@ package locker_test
import (
"context"
"errors"
"testing"
"github.com/matryer/is"
@@ -22,7 +23,7 @@ func TestLocker(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := value.Modify(ctx, func(ctx context.Context, c *config) error {
err := value.Use(ctx, func(ctx context.Context, c *config) error {
c.Value = "one"
c.Counter++
return nil
@@ -37,7 +38,7 @@ func TestLocker(t *testing.T) {
wait := make(chan struct{})
go value.Modify(ctx, func(ctx context.Context, c *config) error {
go value.Use(ctx, func(ctx context.Context, c *config) error {
c.Value = "two"
c.Counter++
close(wait)
@@ -47,7 +48,7 @@ func TestLocker(t *testing.T) {
<-wait
cancel()
err = value.Modify(ctx, func(ctx context.Context, c *config) error {
err = value.Use(ctx, func(ctx context.Context, c *config) error {
c.Value = "three"
c.Counter++
return nil
@@ -60,3 +61,36 @@ func TestLocker(t *testing.T) {
is.Equal(c.Value, "two")
is.Equal(c.Counter, 2)
}
func TestNestedLocker(t *testing.T) {
is := is.New(t)
value := locker.New(&config{})
other := locker.New(&config{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := value.Use(ctx, func(ctx context.Context, c *config) error {
return value.Use(ctx, func(ctx context.Context, t *config) error {
return nil
})
})
is.True(errors.Is(err, locker.ErrNested))
err = value.Use(ctx, func(ctx context.Context, c *config) error {
return other.Use(ctx, func(ctx context.Context, t *config) error {
return nil
})
})
is.NoErr(err)
err = value.Use(ctx, func(ctx context.Context, c *config) error {
return other.Use(ctx, func(ctx context.Context, t *config) error {
return value.Use(ctx, func(ctx context.Context, x *config) error {
return nil
})
})
})
is.True(errors.Is(err, locker.ErrNested))
}