diff --git a/lsm/marshal.go b/lsm/marshal.go new file mode 100644 index 0000000..00034d2 --- /dev/null +++ b/lsm/marshal.go @@ -0,0 +1,138 @@ +package lsm + +import ( + "bytes" + "encoding" + "encoding/binary" + "fmt" +) + +type entry struct { + key string + value uint64 +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (e *entry) MarshalBinary() (data []byte, err error) { + data = make([]byte, len(e.key), len(e.key)+binary.MaxVarintLen16) + copy(data, e.key) + + data = binary.AppendUvarint(data, e.value) + reverse(data[len(e.key):]) + return data, err +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (e *entry) UnmarshalBinary(data []byte) error { + // fmt.Println("unmarshal", data, string(data)) + + if len(data) < binary.MaxVarintLen16 { + return fmt.Errorf("%w: bad data", ErrDecode) + } + head := make([]byte, binary.MaxVarintLen16) + copy(head, data[max(0, len(data)-cap(head)):]) + reverse(head) + + size := 0 + e.value, size = binary.Uvarint(head) + if size == 0 { + return fmt.Errorf("%w: invalid data", ErrDecode) + } + e.key = string(data[:len(data)-size]) + + return nil +} + +var _ encoding.BinaryMarshaler = (*entry)(nil) +var _ encoding.BinaryUnmarshaler = (*entry)(nil) + +type entries []entry + +// MarshalBinary implements encoding.BinaryMarshaler. +func (lis *entries) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + for _, e := range *lis { + d, err := e.MarshalBinary() + if err != nil { + return nil, err + } + + _, err = buf.Write(d) + if err != nil { + return nil, err + } + + _, err = buf.Write(reverse(binary.AppendUvarint(make([]byte, 0, binary.MaxVarintLen32), uint64(len(d))))) + if err != nil { + return nil, err + } + } + + return buf.Bytes(), err +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (lis *entries) UnmarshalBinary(data []byte) error { + head := make([]byte, binary.MaxVarintLen16) + pos := uint64(len(data)) + + for pos > 0 { + copy(head, data[max(0, pos-uint64(cap(head))):]) + length, size := binary.Uvarint(reverse(head)) + + e := entry{} + if err := e.UnmarshalBinary(data[max(0, pos-(length+uint64(size))) : pos-uint64(size)]); err != nil { + return err + } + *lis = append(*lis, e) + + pos -= length + uint64(size) + } + reverse(*lis) + return nil +} + +var _ encoding.BinaryMarshaler = (*entries)(nil) +var _ encoding.BinaryUnmarshaler = (*entries)(nil) + +type segment struct { + entries entries +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (s *segment) MarshalBinary() (data []byte, err error) { + head := header{ + entries: uint64(len(s.entries)), + } + + data, err = s.entries.MarshalBinary() + if err != nil { + return nil, err + } + + head.datalen = uint64(len(data)) + + h := hash() + h.Write(data) + head.sig = h.Sum(nil) + + return head.Append(data), err +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (s *segment) UnmarshalBinary(data []byte) error { + head, err := ReadHead(data) + if err != nil { + return err + } + + h := hash() + h.Write(data[:head.datalen]) + if !bytes.Equal(head.sig, h.Sum(nil)) { + return fmt.Errorf("%w: invalid checksum", ErrDecode) + } + + s.entries = make(entries, 0, head.entries) + return s.entries.UnmarshalBinary(data[:head.datalen]) +} diff --git a/lsm/marshal_test.go b/lsm/marshal_test.go new file mode 100644 index 0000000..e67f0a0 --- /dev/null +++ b/lsm/marshal_test.go @@ -0,0 +1,71 @@ +package lsm + +import ( + "io/fs" + "testing" + + "github.com/matryer/is" +) + +func TestEncoding(t *testing.T) { + is := is.New(t) + + data := segment{entries: entries{ + {"key-1", 1}, + {"key-2", 2}, + {"key-3", 3}, + {"longerkey-4", 65535}, + }} + + b, err := data.MarshalBinary() + is.NoErr(err) + + var got segment + err = got.UnmarshalBinary(b) + is.NoErr(err) + + is.Equal(data, got) +} + +func TestReverse(t *testing.T) { + is := is.New(t) + + got := []byte("gnirts a si siht") + reverse(got) + + is.Equal(got, []byte("this is a string")) + + got = []byte("!gnirts a si siht") + reverse(got) + + is.Equal(got, []byte("this is a string!")) +} + +func TestFile(t *testing.T) { + is := is.New(t) + + f := basicFile(t) + + sf, err := ReadFile(f) + is.NoErr(err) + + is.Equal(len(sf.segments), 3) +} + +func basicFile(t *testing.T) fs.File { + t.Helper() + + data := segment{entries: entries{ + {"key-1", 1}, + {"key-2", 2}, + {"key-3", 3}, + {"longerkey-4", 65535}, + }} + + b, err := data.MarshalBinary() + if err != nil { + t.Error(err) + } + + return NewFile(b, b, b) +} diff --git a/lsm/sst.go b/lsm/sst.go new file mode 100644 index 0000000..bc901c7 --- /dev/null +++ b/lsm/sst.go @@ -0,0 +1,315 @@ +// SPDX-FileCopyrightText: 2023 Jon Lundy +// SPDX-License-Identifier: BSD-3-Clause +package lsm + +import ( + "bytes" + "encoding" + "encoding/binary" + "errors" + "fmt" + "hash/fnv" + "io" + "io/fs" + "sort" +) + +var ( + magic = reverse(append([]byte("Souris"), '\x01')) + hash = fnv.New32a + hashLength = hash().Size() + // segmentSize = 2 ^ 16 // min 2^9 = 512b, max? 2^20 = 1M + segmentFooterLength = len(magic) + hashLength + binary.MaxVarintLen32 + binary.MaxVarintLen32 +) + +type header struct { + sig []byte + entries uint64 + datalen uint64 + headlen uint64 + end int64 +} + +func ReadHead(data []byte) (*header, error) { + if len(data) < len(magic)+6 { + return nil, fmt.Errorf("%w: invalid size", ErrDecode) + } + + if !bytes.Equal(data[len(data)-len(magic):], magic) { + return nil, fmt.Errorf("%w: invalid header", ErrDecode) + } + + head := make([]byte, 0, segmentFooterLength) + head = reverse(append(head, data[max(0, len(data)-cap(head)-1):]...)) + size, s := binary.Uvarint(head[len(magic)+4:]) + length, i := binary.Uvarint(head[len(magic)+4+s:]) + + return &header{ + sig: head[len(magic) : len(magic)+4], + entries: size, + datalen: length, + headlen: uint64(len(magic) + hashLength + s + i), + end: int64(len(data)), + }, nil +} +func (h *header) Append(data []byte) []byte { + + length := len(data) + data = append(data, h.sig...) + data = binary.AppendUvarint(data, h.entries) + data = binary.AppendUvarint(data, h.datalen) + reverse(data[length:]) + + return append(data, magic...) +} + +var _ encoding.BinaryMarshaler = (*segment)(nil) +var _ encoding.BinaryUnmarshaler = (*segment)(nil) + +var ErrDecode = errors.New("decode") + +func reverse[T any](b []T) []T { + l := len(b) + for i := 0; i < l/2; i++ { + b[i], b[l-i-1] = b[l-i-1], b[i] + } + return b +} + +// func clone[T ~[]E, E any](e []E) []E { +// return append(e[0:0:0], e...) +// } + +type entryBytes []byte + +func (e entryBytes) KeyValue() ([]byte, uint64) { + if len(e) < 2 { + return nil, 0 + } + head := reverse(append(e[0:0:0], e[max(0, len(e)-binary.MaxVarintLen64):]...)) + value, i := binary.Uvarint(head) + return append(e[0:0:0], e[:len(e)-i]...), value +} +func NewKeyValue(key []byte, val uint64) entryBytes { + length := len(key) + data := append(key[0:0:0], key...) + data = binary.AppendUvarint(data, val) + reverse(data[length:]) + + return data +} + +type listEntries []entryBytes + +func (lis *listEntries) WriteTo(wr io.Writer) (int64, error) { + if lis == nil { + return 0, nil + } + + head := header{ + entries: uint64(len(*lis)), + } + h := hash() + + wr = io.MultiWriter(wr, h) + + var i int64 + for _, b := range *lis { + j, err := wr.Write(b) + i += int64(j) + if err != nil { + return i, err + } + + j, err = wr.Write(reverse(binary.AppendUvarint(make([]byte, 0, binary.MaxVarintLen32), uint64(len(b))))) + i += int64(j) + if err != nil { + return i, err + } + } + head.datalen = uint64(i) + head.sig = h.Sum(nil) + + b := head.Append([]byte{}) + j, err := wr.Write(b) + i += int64(j) + + return i, err +} + +var _ sort.Interface = listEntries{} + +// Len implements sort.Interface. +func (lis listEntries) Len() int { + return len(lis) +} + +// Less implements sort.Interface. +func (lis listEntries) Less(i int, j int) bool { + iname, _ := lis[i].KeyValue() + jname, _ := lis[j].KeyValue() + + return bytes.Compare(iname, jname) < 0 +} + +// Swap implements sort.Interface. +func (lis listEntries) Swap(i int, j int) { + lis[i], lis[j] = lis[j], lis[i] +} + +type segmentReader struct { + head *header + rd io.ReaderAt +} + +func (s *segmentReader) FirstEntry() (*entryBytes, error) { + e, _, err := s.readEntryAt(-1) + return e, err +} +func (s *segmentReader) Find(needle []byte) (*entryBytes, bool, error) { + if s == nil { + return nil, false, nil + } + e, pos, err := s.readEntryAt(-1) + if err != nil { + return nil, false, err + } + + last := e + for pos > 0 { + key, _ := e.KeyValue() + switch bytes.Compare(key, needle) { + case 0: // equal + return e, true, nil + case -1: // key=aaa, needle=bbb + last = e + e, pos, err = s.readEntryAt(pos) + if err != nil { + return nil, false, err + } + + case 1: // key=ccc, needle=bbb + return last, false, nil + } + } + return last, false, nil +} +func (s *segmentReader) readEntryAt(pos int64) (*entryBytes, int64, error) { + if pos < 0 { + pos = s.head.end + } + head := make([]byte, binary.MaxVarintLen16) + s.rd.ReadAt(head, pos-binary.MaxVarintLen16) + length, hsize := binary.Uvarint(reverse(head)) + + e := make(entryBytes, length) + _, err := s.rd.ReadAt(e, pos-int64(length)-int64(hsize)) + + return &e, pos - int64(length) - int64(hsize), err +} + +type logFile struct { + rd interface{io.ReaderAt; io.WriterTo} + segments []segmentReader + + fs.File +} + +func ReadFile(fd fs.File) (*logFile, error) { + l := &logFile{File: fd} + + stat, err := fd.Stat() + if err != nil { + return nil, err + } + + eof := stat.Size() + if rd, ok := fd.(interface{io.ReaderAt; io.WriterTo}); ok { + l.rd = rd + + } else { + rd, err := io.ReadAll(fd) + if err != nil { + return nil, err + } + l.rd = bytes.NewReader(rd) + } + + for eof > 0 { + head := make([]byte, segmentFooterLength) + _, err = l.rd.ReadAt(head, eof-int64(segmentFooterLength)) + if err != nil { + return nil, err + } + + s := segmentReader{ + rd: l.rd, + } + s.head, err = ReadHead(head) + s.head.end = eof - int64(s.head.headlen) + if err != nil { + return nil, err + } + eof -= int64(s.head.datalen) + int64(s.head.headlen) + l.segments = append(l.segments, s) + } + + return l, nil +} + +func (l *logFile) Count() int64 { + return int64(len(l.segments)) +} +func (l *logFile) LoadSegment(pos int64) (*segmentBytes, error) { + if pos < 0 { + pos = int64(len(l.segments) - 1) + } + if pos > int64(len(l.segments)-1) { + return nil, ErrDecode + } + s := l.segments[pos] + + b := make([]byte, s.head.datalen+s.head.headlen) + _, err := l.rd.ReadAt(b, s.head.end-int64(len(b))) + if err != nil { + return nil, err + } + + return &segmentBytes{b, -1}, nil +} +func (l *logFile) Find(needle []byte) (*entryBytes, bool, error) { + var last segmentReader + + for _, s := range l.segments { + e, err := s.FirstEntry() + if err != nil { + return nil, false, err + } + k, _ := e.KeyValue() + if bytes.Compare(k, needle) > 0 { + break + } + last = s + } + + return last.Find(needle) +} +func (l *logFile) WriteTo(w io.Writer) (int64, error) { + return l.rd.WriteTo(w) +} + +type segmentBytes struct { + b []byte + pos int +} + +type dataset struct { + rd io.ReaderAt + files []logFile + + fs.FS +} + +func ReadDataset(fd fs.FS) (*dataset, error) { + panic("not implemented") +} diff --git a/lsm/sst_test.go b/lsm/sst_test.go new file mode 100644 index 0000000..839a924 --- /dev/null +++ b/lsm/sst_test.go @@ -0,0 +1,302 @@ +// SPDX-FileCopyrightText: 2023 Jon Lundy +// SPDX-License-Identifier: BSD-3-Clause +package lsm + +import ( + "bytes" + crand "crypto/rand" + "encoding/base64" + "io" + "io/fs" + "math/rand" + "os" + "sort" + "sync" + "testing" + "time" + + "github.com/matryer/is" +) + +func TestLargeFile(t *testing.T) { + is := is.New(t) + + segCount := 4098 + + f := randFile(t, 2_000_000, segCount) + + sf, err := ReadFile(f) + is.NoErr(err) + + is.True(len(sf.segments) <= segCount) + var needle []byte + for i, s := range sf.segments { + e, err := s.FirstEntry() + is.NoErr(err) + k, v := e.KeyValue() + needle = k + t.Logf("Segment-%d: %s = %d", i, k, v) + } + t.Log(f.Stat()) + + tt, ok, err := sf.Find(needle) + is.NoErr(err) + is.True(ok) + key, val := tt.KeyValue() + t.Log(string(key), val) + + tt, ok, err = sf.Find([]byte("needle")) + is.NoErr(err) + is.True(!ok) + key, val = tt.KeyValue() + t.Log(string(key), val) + + tt, ok, err = sf.Find([]byte{'\xff'}) + is.NoErr(err) + is.True(!ok) + key, val = tt.KeyValue() + t.Log(string(key), val) +} + +func TestLargeFileDisk(t *testing.T) { + is := is.New(t) + + segCount := 4098 + + t.Log("generate large file") + f := randFile(t, 2_000_000, segCount) + + fd, err := os.CreateTemp("", "sst*") + is.NoErr(err) + defer func() { t.Log("cleanup:", fd.Name()); fd.Close(); os.Remove(fd.Name()) }() + + t.Log("write file:", fd.Name()) + _, err = io.Copy(fd, f) + is.NoErr(err) + fd.Seek(0, 0) + + sf, err := ReadFile(fd) + is.NoErr(err) + + is.True(len(sf.segments) <= segCount) + var needle []byte + for i, s := range sf.segments { + e, err := s.FirstEntry() + is.NoErr(err) + k, v := e.KeyValue() + needle = k + t.Logf("Segment-%d: %s = %d", i, k, v) + } + t.Log(f.Stat()) + + tt, ok, err := sf.Find(needle) + is.NoErr(err) + is.True(ok) + key, val := tt.KeyValue() + t.Log(string(key), val) + + tt, ok, err = sf.Find([]byte("needle")) + is.NoErr(err) + is.True(!ok) + key, val = tt.KeyValue() + t.Log(string(key), val) + + tt, ok, err = sf.Find([]byte{'\xff'}) + is.NoErr(err) + is.True(!ok) + key, val = tt.KeyValue() + t.Log(string(key), val) +} + +func BenchmarkLargeFile(b *testing.B) { + segCount := 4098 / 4 + f := randFile(b, 2_000_000, segCount) + + sf, err := ReadFile(f) + if err != nil { + b.Error(err) + } + key := make([]byte, 5) + keys := make([][]byte, b.N) + for i := range keys { + _, err = crand.Read(key) + if err != nil { + b.Error(err) + } + keys[i] = []byte(base64.RawURLEncoding.EncodeToString(key)) + } + b.Log("ready", b.N) + b.ResetTimer() + okays := 0 + each := b.N / 10 + for n := 0; n < b.N; n++ { + if each > 0 && n%each == 0 { + b.Log(n) + } + _, ok, err := sf.Find(keys[n]) + if err != nil { + b.Error(err) + } + if ok { + okays++ + } + } + b.Log("okays=", b.N, okays) +} + +func BenchmarkLargeFileB(b *testing.B) { + segCount := 4098 / 16 + f := randFile(b, 2_000_000, segCount) + + sf, err := ReadFile(f) + if err != nil { + b.Error(err) + } + key := make([]byte, 5) + keys := make([][]byte, b.N) + for i := range keys { + _, err = crand.Read(key) + if err != nil { + b.Error(err) + } + keys[i] = []byte(base64.RawURLEncoding.EncodeToString(key)) + } + b.Log("ready", b.N) + b.ResetTimer() + okays := 0 + each := b.N / 10 + for n := 0; n < b.N; n++ { + if each > 0 && n%each == 0 { + b.Log(n) + } + _, ok, err := sf.Find(keys[n]) + if err != nil { + b.Error(err) + } + if ok { + okays++ + } + } + b.Log("okays=", b.N, okays) +} + +func randFile(t interface { + Helper() + Error(...any) +}, size int, segments int) fs.File { + t.Helper() + + lis := make(listEntries, size) + for i := range lis { + key := make([]byte, 5) + _, err := crand.Read(key) + if err != nil { + t.Error(err) + } + key = []byte(base64.RawURLEncoding.EncodeToString(key)) + // key := []byte(fmt.Sprintf("key-%05d", i)) + + lis[i] = NewKeyValue(key, rand.Uint64()%16_777_216) + } + + sort.Sort(sort.Reverse(&lis)) + each := size / segments + if size%segments != 0 { + each++ + } + split := make([]listEntries, segments) + + for i := range split { + if (i+1)*each > len(lis) { + split[i] = lis[i*each : i*each+len(lis[i*each:])] + split = split[:i+1] + break + } + split[i] = lis[i*each : (i+1)*each] + } + + var b bytes.Buffer + for _, s := range split { + s.WriteTo(&b) + } + + return NewFile(b.Bytes()) +} + +type fakeStat struct { + size int64 +} + +// IsDir implements fs.FileInfo. +func (*fakeStat) IsDir() bool { + panic("unimplemented") +} + +// ModTime implements fs.FileInfo. +func (*fakeStat) ModTime() time.Time { + panic("unimplemented") +} + +// Mode implements fs.FileInfo. +func (*fakeStat) Mode() fs.FileMode { + panic("unimplemented") +} + +// Name implements fs.FileInfo. +func (*fakeStat) Name() string { + panic("unimplemented") +} + +// Size implements fs.FileInfo. +func (s *fakeStat) Size() int64 { + return s.size +} + +// Sys implements fs.FileInfo. +func (*fakeStat) Sys() any { + panic("unimplemented") +} + +var _ fs.FileInfo = (*fakeStat)(nil) + +type rd interface { + io.ReaderAt + io.Reader +} +type fakeFile struct { + stat func() fs.FileInfo + + rd +} + +func (fakeFile) Close() error { return nil } +func (f fakeFile) Stat() (fs.FileInfo, error) { return f.stat(), nil } + +func NewFile(b ...[]byte) fs.File { + in := bytes.Join(b, nil) + rd := bytes.NewReader(in) + size := int64(len(in)) + return &fakeFile{stat: func() fs.FileInfo { return &fakeStat{size: size} }, rd: rd} +} +func NewFileFromReader(rd *bytes.Reader) fs.File { + return &fakeFile{stat: func() fs.FileInfo { return &fakeStat{size: int64(rd.Len())} }, rd: rd} +} + +type fakeFS struct { + files map[string]*fakeFile + mu sync.RWMutex +} + +// Open implements fs.FS. +func (f *fakeFS) Open(name string) (fs.File, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if file, ok := f.files[name]; ok { + return file, nil + } + + return nil, fs.ErrNotExist +} + +var _ fs.FS = (*fakeFS)(nil)