ev/pkg/es/driver/mem-store/mem-store.go
2022-08-04 22:11:34 -06:00

143 lines
3.2 KiB
Go

package memstore
import (
"context"
"fmt"
"github.com/sour-is/ev/pkg/es"
"github.com/sour-is/ev/pkg/es/driver"
"github.com/sour-is/ev/pkg/es/event"
"github.com/sour-is/ev/pkg/locker"
"github.com/sour-is/ev/pkg/math"
)
type (
	// state is the data guarded by a memstore's locker: the events of every
	// stream, keyed by stream ID.
	state struct {
		streams map[string]event.Events
	}

	// memstore is the in-memory event store. Init registers a zero value
	// (state == nil) as the "mem" driver; Open returns instances whose state
	// is populated, and those instances serve as the event stores.
	memstore struct {
		state *locker.Locked[state]
	}
)

// Compile-time assertion that *memstore satisfies driver.Driver.
var _ driver.Driver = (*memstore)(nil)
// Init registers a memstore driver with the es package under the driver
// name "mem".
func Init(ctx context.Context) {
	drv := new(memstore)
	es.Register(ctx, "mem", drv)
}
// Open creates a new, empty in-memory event store. name is not used, so
// every call returns an independent store that shares no streams with stores
// returned by earlier calls. The returned error is always nil.
func (memstore) Open(name string) (driver.EventStore, error) {
	fresh := &state{streams: map[string]event.Events{}}
	store := &memstore{state: locker.New(fresh)}
	return store, nil
}
// Append implements driver.EventStore.
//
// Append tags every event with streamID, assigns the events consecutive
// 1-based positions that follow the events already stored in the stream,
// and appends them to the stream. Unlike Save, no version check is made.
//
// It returns the number of events appended. If the state cannot be
// modified, it returns 0 and the error.
func (m *memstore) Append(ctx context.Context, streamID string, events event.Events) (uint64, error) {
	event.SetStreamID(streamID, events...)

	err := m.state.Modify(ctx, func(state *state) error {
		stream := state.streams[streamID]
		last := uint64(len(stream))
		for i := range events {
			event.SetPosition(events[i], last+uint64(i)+1)
		}
		// Write the grown slice back once, after all positions are assigned.
		state.streams[streamID] = append(stream, events...)
		return nil
	})
	if err != nil {
		// The count is meaningless when nothing was stored.
		return 0, err
	}
	return uint64(len(events)), nil
}
// Load implements driver.EventStore.
//
// Load stamps every stored event of agg's stream with the stream ID and then
// replays those events into agg via ApplyEvent. A stream with no stored
// events replays nothing.
// NOTE(review): the replay runs inside m.state.Modify, so ApplyEvent is
// presumably invoked while the state lock is held — confirm.
func (m *memstore) Load(ctx context.Context, agg event.Aggregate) error {
	replay := func(st *state) error {
		history := st.streams[agg.StreamID()]
		event.SetStreamID(agg.StreamID(), history...)
		agg.ApplyEvent(history...)
		return nil
	}
	return m.state.Modify(ctx, replay)
}
// Read implements driver.EventStore.
//
// Read returns up to |count| events from streamID, starting at pos:
//   - pos >= 0 is an offset from the first event (0 is the first event).
//   - pos < 0 is an offset from the end (-1 is the last event).
//
// A positive count reads forward and a negative count reads backward. As a
// convenience, pos == 0 always reads forward and pos == -1 always reads
// backward. Reading stops early at either end of the stream.
//
// An empty stream, count == 0, or a pos that falls outside the stream all
// yield an empty (non-nil) result.
func (m *memstore) Read(ctx context.Context, streamID string, pos int64, count int64) (event.Events, error) {
	events := make([]event.Event, 0)
	err := m.state.Modify(ctx, func(state *state) error {
		stream := state.streams[streamID]

		first := stream.First().EventMeta().Position
		last := stream.Last().EventMeta().Position
		// An empty stream reports position 0 for its first/last event.
		if first == 0 || last == 0 {
			return nil
		}

		// start is computed in signed arithmetic so that offsets before the
		// beginning of the stream become negative instead of wrapping around
		// as uint64.
		var start int64
		if pos >= 0 {
			start = int64(first) + pos
			if pos == 0 && count < 0 {
				count = -count // if pos=0 assume forward count.
			}
		} else {
			start = int64(last) + pos + 1
			if pos == -1 && count > 0 {
				count = -count // if pos=-1 assume backward count.
			}
		}

		// Positions outside [first, last] would index past the stream.
		if start < int64(first) || start > int64(last) {
			return nil
		}

		// Never allocate more than the stream holds, however large |count| is.
		// A negative value means Abs overflowed (count == MinInt64).
		size := int64(math.Abs(count))
		if size < 0 || size > int64(len(stream)) {
			size = int64(len(stream))
		}
		events = make([]event.Event, 0, size)

		// Positions are 1-based and contiguous, so position p lives at stream[p-1].
		for p := start; int64(len(events)) < size && p >= int64(first) && p <= int64(last); {
			events = append(events, stream[p-1])
			if count > 0 {
				p++
			} else {
				p--
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return events, nil
}
// Save implements driver.EventStore.
//
// Save writes the events returned by agg.Events(true) to agg's stream, guarded
// by an optimistic version check. The write is rejected with an error unless
// agg.StreamVersion() equals the number of events currently stored for the
// stream. Accepted events are tagged with the stream ID, given consecutive
// 1-based positions after the existing ones, and appended. agg.Commit() is
// then called, and the number of saved events is returned.
// NOTE(review): Events(true) presumably yields only the uncommitted events —
// confirm against event.Aggregate.
func (m *memstore) Save(ctx context.Context, agg event.Aggregate) (uint64, error) {
	pending := agg.Events(true)
	event.SetStreamID(agg.StreamID(), pending...)

	if err := m.state.Modify(ctx, func(st *state) error {
		existing := st.streams[agg.StreamID()]
		version := uint64(len(existing))
		if agg.StreamVersion() != version {
			return fmt.Errorf("current version wrong %d != %d", agg.StreamVersion(), version)
		}

		for offset, e := range pending {
			event.SetPosition(e, version+uint64(offset)+1)
			existing = append(existing, e)
		}
		st.streams[agg.StreamID()] = existing
		return nil
	}); err != nil {
		return 0, err
	}

	agg.Commit()
	return uint64(len(pending)), nil
}