go-pkg/mercury/sql/sql.go
2024-04-19 10:56:27 -06:00

595 lines
14 KiB
Go

package sql
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"slices"
"strings"
sq "github.com/Masterminds/squirrel"
"go.sour.is/pkg/lg"
"go.sour.is/pkg/mercury"
"golang.org/x/exp/maps"
)
// MAX_FILTER bounds the number of search.Find filters that GetWherePG and
// GetWhereSQ accept; exceeding it yields a "too many filters" error.
var MAX_FILTER int = 40
// sqlHandler is the SQL-backed mercury handler created by Register. It reads
// spaces from the mercury_spaces table and their values from mercury_values.
//
// NOTE: Register constructs this struct with positional literals, so the field
// order must not change without updating those literals.
type sqlHandler struct {
// name is the space name the handler was registered for; it is attached to
// trace spans as an event.
name string
// db is the open database connection used for all queries.
db *sql.DB
// paceholderFormat (sic, "placeholder") is the squirrel placeholder style:
// sq.Question for sqlite/libsql and sq.Dollar for postgres.
paceholderFormat sq.PlaceholderFormat
// listFormat is the rune pair handed to listScan/listValue when reading and
// writing list columns: '[' ']' for sqlite/libsql, '{' '}' for postgres.
// Presumably the opening and closing delimiter of the encoded list.
listFormat [2]rune
// readonly makes WriteConfig fail with "readonly database".
readonly bool
// getWhere converts a mercury.Search into a function that decorates the
// space SELECT with the matching joins, WHERE, LIMIT and OFFSET clauses.
getWhere func(search mercury.Search) (func(sq.SelectBuilder) sq.SelectBuilder, error)
}
// Compile-time checks that *sqlHandler satisfies the mercury handler
// interfaces it is meant to implement.
var (
_ mercury.GetIndex = (*sqlHandler)(nil)
_ mercury.GetConfig = (*sqlHandler)(nil)
_ mercury.GetRules = (*sqlHandler)(nil)
_ mercury.WriteConfig = (*sqlHandler)(nil)
)
// Register installs the "sql" handler factory in mercury.Registry and returns
// a shutdown function that closes every database the factory has opened.
//
// The factory reads its settings from the values of the space it is given:
//
//   - "match" is skipped here.
//   - "dbtype" selects the driver and SQL dialect ("sqlite", "libsql",
//     "libsql+embed" or "postgres").
//   - "dsn" is the connection string; the first non-empty one wins.
//   - any other value is collected as a "name = value" line, and those lines
//     are used as the DSN when no explicit "dsn" is present.
//
// A "readonly" tag on the space makes the resulting handler reject writes.
// The factory returns either the handler or an error value.
func Register() func(context.Context) error {
	var hdlrs []*sqlHandler

	mercury.Registry.Register("sql", func(s *mercury.Space) any {
		var (
			dsn    string
			dbtype string
			opts   strings.Builder
		)
		readonly := slices.Contains(s.Tags, "readonly")

		// Scan every entry rather than stopping at "dsn", so that a "dbtype"
		// listed after the "dsn" is still honoured.
		for _, c := range s.List {
			switch c.Name {
			case "match":
				// not a connection option
			case "dbtype":
				dbtype = c.First()
			case "dsn":
				if dsn == "" {
					dsn = c.First()
				}
			default:
				fmt.Fprintln(&opts, c.Name, "=", c.First())
			}
		}
		if dsn == "" {
			dsn = opts.String()
		}

		db, err := openDB(dbtype, dsn)
		if err != nil {
			return err
		}
		if err = db.Ping(); err != nil {
			// The ping failure is the error worth reporting; the close is
			// only there to release the connection pool.
			_ = db.Close()
			return err
		}

		h := &sqlHandler{name: s.Space, db: db, readonly: readonly}
		switch dbtype {
		case "sqlite", "libsql", "libsql+embed":
			h.paceholderFormat = sq.Question
			h.listFormat = [2]rune{'[', ']'}
			h.getWhere = GetWhereSQ
		case "postgres":
			h.paceholderFormat = sq.Dollar
			h.listFormat = [2]rune{'{', '}'}
			h.getWhere = GetWherePG
		default:
			// No handler will own this connection, so release it here.
			_ = db.Close()
			return fmt.Errorf("unsupported dbtype: %s", dbtype)
		}

		hdlrs = append(hdlrs, h)
		return h
	})

	// The shutdown function closes every handler's database regardless of
	// ctx, and reports all close errors joined together.
	return func(ctx context.Context) error {
		var errs error
		for _, h := range hdlrs {
			errs = errors.Join(errs, h.db.Close())
		}
		return errs
	}
}
// Space is a mercury.Space together with its database row id, as read from
// the "id" column of mercury_spaces.
type Space struct {
mercury.Space
id uint64
}
// Value is a mercury.Value together with the "id" column of its
// mercury_values row. That id is the id of the space the value belongs to:
// GetConfig matches it against Space.id, and WriteConfig fills it with the
// owning space's id.
type Value struct {
mercury.Value
id uint64
}
// GetIndex returns the spaces matching search without loading their values.
// Each entry of the returned config points at the Space row read from
// mercury_spaces. Errors from building the filter or running the query are
// returned as-is; query errors are also logged.
func (p *sqlHandler) GetIndex(ctx context.Context, search mercury.Search) (mercury.Config, error) {
	ctx, span := lg.Span(ctx)
	defer span.End()

	filter, err := p.getWhere(search)
	if err != nil {
		return nil, err
	}

	spaces, err := p.listSpace(ctx, nil, filter)
	if err != nil {
		log.Println(err)
		return nil, err
	}

	index := make(mercury.Config, 0, len(spaces))
	for _, sp := range spaces {
		index = append(index, &sp.Space)
	}
	return index, nil
}
// GetConfig returns the spaces matching search with their values attached.
//
// It first lists the matching spaces, then loads every mercury_values row
// whose id is one of those space ids (ordered by id, then seq) and appends
// each value to the List of its space. When no space matches it returns
// (nil, nil). A row iteration error is recorded on the span and returned
// alongside the config built so far.
func (p *sqlHandler) GetConfig(ctx context.Context, search mercury.Search) (config mercury.Config, err error) {
	ctx, span := lg.Span(ctx)
	defer span.End()

	filter, err := p.getWhere(search)
	if err != nil {
		return nil, err
	}

	spaces, err := p.listSpace(ctx, nil, filter)
	if err != nil {
		log.Println(err)
		return nil, err
	}
	if len(spaces) == 0 {
		return nil, nil
	}

	// ids feeds the IN clause; position maps a space id back to its slot.
	ids := make([]uint64, 0, len(spaces))
	position := make(map[uint64]int, len(spaces))
	config = make(mercury.Config, 0, len(spaces))
	for n, sp := range spaces {
		ids = append(ids, sp.id)
		config = append(config, &sp.Space)
		position[sp.id] = n
	}

	query := sq.Select(`"id"`, `"name"`, `"seq"`, `"notes"`, `"tags"`, `"values"`).
		From("mercury_values").
		Where(sq.Eq{"id": ids}).
		OrderBy("id asc", "seq asc").
		PlaceholderFormat(p.paceholderFormat)

	span.AddEvent(p.name)
	span.AddEvent(lg.LogQuery(query.ToSql()))

	rows, err := query.RunWith(p.db).QueryContext(ctx)
	if err != nil {
		log.Println(err)
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var row Value
		if err = rows.Scan(
			&row.id,
			&row.Name,
			&row.Seq,
			listScan(&row.Notes, p.listFormat),
			listScan(&row.Tags, p.listFormat),
			listScan(&row.Values, p.listFormat),
		); err != nil {
			return nil, err
		}

		slot, known := position[row.id]
		if !known {
			continue
		}
		spaces[slot].List = append(spaces[slot].List, row.Value)
	}

	err = rows.Err()
	span.RecordError(err)
	span.AddEvent(fmt.Sprint("read index ", len(spaces)))

	return config, err
}
// listSpace selects rows from mercury_spaces ordered by space name and
// returns them as Space values carrying their row id.
//
// tx is the runner the query executes on; when nil the handler's own database
// is used. where decorates the base SELECT (filters, joins, limits) and must
// not be nil. A row iteration error is recorded on the span and returned
// together with the rows read so far.
func (p *sqlHandler) listSpace(ctx context.Context, tx sq.BaseRunner, where func(sq.SelectBuilder) sq.SelectBuilder) ([]*Space, error) {
	ctx, span := lg.Span(ctx)
	defer span.End()

	var runner sq.BaseRunner = p.db
	if tx != nil {
		runner = tx
	}

	base := sq.Select(`"id"`, `"space"`, `"notes"`, `"tags"`, `"trailer"`).
		From("mercury_spaces").
		OrderBy("space asc").
		PlaceholderFormat(p.paceholderFormat)
	query := where(base)

	span.AddEvent(p.name)
	span.AddEvent(lg.LogQuery(query.ToSql()))

	rows, err := query.RunWith(runner).QueryContext(ctx)
	if err != nil {
		log.Println(err)
		return nil, err
	}
	defer rows.Close()

	var found []*Space
	for rows.Next() {
		row := new(Space)
		if err = rows.Scan(
			&row.id,
			&row.Space.Space,
			listScan(&row.Space.Notes, p.listFormat),
			listScan(&row.Space.Tags, p.listFormat),
			listScan(&row.Trailer, p.listFormat),
		); err != nil {
			return nil, err
		}
		found = append(found, row)
	}

	err = rows.Err()
	span.RecordError(err)
	span.AddEvent(fmt.Sprint("read config ", len(found)))

	return found, err
}
// WriteConfig writes a config map to the database inside one transaction.
//
// For every space in config:
//   - if it has no tags, notes or values it is treated as a deletion: an
//     existing row (and its values) is removed, and a space that is not
//     stored yet is left absent;
//   - if it already exists its tags, notes and trailer are updated and its
//     values are replaced;
//   - otherwise it is inserted together with its values.
//
// The transaction is committed only when every step succeeded; on any error,
// including a failed commit, it is rolled back and the error is returned.
// A readonly handler returns an error without touching the database.
func (p *sqlHandler) WriteConfig(ctx context.Context, config mercury.Config) (err error) {
	ctx, span := lg.Span(ctx)
	defer span.End()

	if p.readonly {
		return fmt.Errorf("readonly database")
	}

	// Spaces that are present in the input but empty are deleted.
	deleteSpaces := make(map[string]struct{})
	// names maps each space name to its index in config.
	names := make(map[string]int, len(config))
	for i, v := range config {
		names[v.Space] = i
		if len(v.Tags) == 0 && len(v.Notes) == 0 && len(v.List) == 0 {
			deleteSpaces[v.Space] = struct{}{}
		}
	}

	tx, err := p.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	committed := false
	defer func() {
		if !committed {
			// The original failure is what the caller needs to see; a
			// rollback error (e.g. after a failed commit) adds nothing.
			_ = tx.Rollback()
		}
	}()

	// Fetch the spaces that already exist.
	where := func(qry sq.SelectBuilder) sq.SelectBuilder {
		return qry.Where(sq.Eq{"space": maps.Keys(names)})
	}
	lis, err := p.listSpace(ctx, tx, where)
	if err != nil {
		return err
	}

	// Sort existing spaces into deletions and updates.
	var (
		deleteIDs    []uint64
		updateIDs    []uint64
		updateSpaces []*mercury.Space
		insertSpaces []*mercury.Space
	)
	currentNames := make(map[string]struct{}, len(lis))
	for _, s := range lis {
		spaceName := s.Space.Space
		currentNames[spaceName] = struct{}{}

		if _, ok := deleteSpaces[spaceName]; ok {
			deleteIDs = append(deleteIDs, s.id)
			continue
		}

		updateSpaces = append(updateSpaces, config[names[spaceName]])
		updateIDs = append(updateIDs, s.id)
	}

	// Everything else is new, except empty spaces: those are deletions, so
	// inserting them would only create a row the next write removes again.
	for _, s := range config {
		if _, ok := currentNames[s.Space]; ok {
			continue
		}
		if _, ok := deleteSpaces[s.Space]; ok {
			continue
		}
		insertSpaces = append(insertSpaces, s)
	}

	// Delete spaces.
	if len(deleteIDs) > 0 {
		_, err = sq.Delete("mercury_spaces").
			Where(sq.Eq{"id": deleteIDs}).
			RunWith(tx).
			PlaceholderFormat(p.paceholderFormat).
			ExecContext(ctx)
		if err != nil {
			return err
		}
	}

	// Delete the values of updated and deleted spaces. The id list is built
	// in a fresh slice so updateIDs' backing array is never written to.
	staleIDs := make([]uint64, 0, len(updateIDs)+len(deleteIDs))
	staleIDs = append(staleIDs, updateIDs...)
	staleIDs = append(staleIDs, deleteIDs...)
	if len(staleIDs) > 0 {
		_, err = sq.Delete("mercury_values").
			Where(sq.Eq{"id": staleIDs}).
			RunWith(tx).
			PlaceholderFormat(p.paceholderFormat).
			ExecContext(ctx)
		if err != nil {
			return err
		}
	}

	var newValues []*Value

	// Update existing spaces.
	for i, u := range updateSpaces {
		query := sq.Update("mercury_spaces").
			Where(sq.Eq{"id": updateIDs[i]}).
			Set("tags", listValue(u.Tags, p.listFormat)).
			Set("notes", listValue(u.Notes, p.listFormat)).
			Set("trailer", listValue(u.Trailer, p.listFormat)).
			PlaceholderFormat(p.paceholderFormat)

		span.AddEvent(p.name)
		span.AddEvent(lg.LogQuery(query.ToSql()))

		if _, err = query.RunWith(tx).ExecContext(ctx); err != nil {
			return err
		}

		for _, v := range u.List {
			newValues = append(newValues, &Value{Value: v, id: updateIDs[i]})
		}
	}

	// Insert new spaces, collecting the generated ids for their values.
	for _, s := range insertSpaces {
		var id uint64

		query := sq.Insert("mercury_spaces").
			PlaceholderFormat(p.paceholderFormat).
			Columns("space", "tags", "notes", "trailer").
			Values(
				s.Space,
				listValue(s.Tags, p.listFormat),
				listValue(s.Notes, p.listFormat),
				listValue(s.Trailer, p.listFormat),
			).
			Suffix("RETURNING \"id\"")

		span.AddEvent(p.name)
		span.AddEvent(lg.LogQuery(query.ToSql()))

		err = query.RunWith(tx).QueryRowContext(ctx).Scan(&id)
		if err != nil {
			span.AddEvent(p.name)
			stmt, args, _ := query.ToSql()
			log.Println(stmt, args, err)
			return err
		}

		for _, v := range s.List {
			newValues = append(newValues, &Value{Value: v, id: id})
		}
	}

	// Write all values; a failure here must abort the whole transaction.
	if err = p.writeValues(ctx, tx, newValues); err != nil {
		return err
	}

	if err = tx.Commit(); err != nil {
		return err
	}
	committed = true

	return nil
}
// writeValues inserts lis into mercury_values using multi-row INSERT
// statements executed on tx.
//
// The rows are split into batches so that no single statement exceeds the
// bind-parameter limit of the supported databases. Every row is written
// exactly once, and no statement is issued for an empty batch. The first
// failing statement aborts the write and its error is returned. An empty lis
// is a no-op.
func (p *sqlHandler) writeValues(ctx context.Context, tx sq.BaseRunner, lis []*Value) (err error) {
	ctx, span := lg.Span(ctx)
	defer span.End()

	if len(lis) == 0 {
		return nil
	}

	columns := []string{
		`"id"`,
		`"seq"`,
		`"name"`,
		`"values"`,
		`"notes"`,
		`"tags"`,
	}

	// Each row binds one parameter per column. PostgreSQL allows at most
	// 65535 parameters per statement and SQLite defaults to 32766, so stay
	// below the smaller limit.
	// NOTE(review): SQLite builds older than 3.32 default to 999; lower
	// maxParams if such a build has to be supported.
	const maxParams = 32000
	chunk := maxParams / len(columns)

	for start := 0; start < len(lis); start += chunk {
		end := start + chunk
		if end > len(lis) {
			end = len(lis)
		}

		insert := sq.Insert("mercury_values").
			RunWith(tx).
			PlaceholderFormat(p.paceholderFormat).
			Columns(columns...)

		for _, s := range lis[start:end] {
			insert = insert.Values(
				s.id,
				s.Seq,
				s.Name,
				listValue(s.Values, p.listFormat),
				listValue(s.Notes, p.listFormat),
				listValue(s.Tags, p.listFormat),
			)
		}

		span.AddEvent(p.name)
		span.AddEvent(lg.LogQuery(insert.ToSql()))

		if _, err = insert.ExecContext(ctx); err != nil {
			return err
		}
	}

	return nil
}
// GetWherePG translates search into a function that decorates a SELECT on
// mercury_spaces for PostgreSQL.
//
// Namespace terms are OR-ed together: a node matches the space exactly, a
// star matches with LIKE, and a trace matches every space that is a prefix of
// the given value. Each search.Find filter becomes a joined sub-select on
// mercury_values ("key", "nkey", "eq", "neq", "gt", "lt", "ge", "le"), so a
// space must satisfy all filters. An operator not in that list adds no
// condition to its sub-select. search.Count and search.Offset, when
// positive, become LIMIT and OFFSET.
//
// It returns an error when more than MAX_FILTER filters are supplied.
func GetWherePG(search mercury.Search) (func(sq.SelectBuilder) sq.SelectBuilder, error) {
	if len(search.Find) > MAX_FILTER {
		return nil, fmt.Errorf("too many filters [%d]", MAX_FILTER)
	}

	var where sq.Or
	space := "space"

	for _, m := range search.NamespaceSearch {
		switch m.(type) {
		case mercury.NamespaceNode:
			where = append(where, sq.Eq{space: m.Value()})
		case mercury.NamespaceStar:
			where = append(where, sq.Like{space: m.Value()})
		case mercury.NamespaceTrace:
			e := sq.Expr(`? LIKE `+space+` || '%'`, m.Value())
			where = append(where, e)
		}
	}

	joins := make([]sq.SelectBuilder, 0, len(search.Find))
	for _, o := range search.Find {
		log.Println(o)

		q := sq.Select("DISTINCT id").From("mercury_values")

		switch o.Op {
		case "key":
			q = q.Where(sq.Eq{"name": o.Left})
		case "nkey":
			q = q.Where(sq.NotEq{"name": o.Left})
		case "eq":
			q = q.Where("name = ? AND ? = any (values)", o.Left, o.Right)
		case "neq":
			q = q.Where("name = ? AND ? != any (values)", o.Left, o.Right)
		case "gt":
			q = q.Where("name = ? AND ? > any (values)", o.Left, o.Right)
		case "lt":
			q = q.Where("name = ? AND ? < any (values)", o.Left, o.Right)
		case "ge":
			q = q.Where("name = ? AND ? >= any (values)", o.Left, o.Right)
		case "le":
			q = q.Where("name = ? AND ? <= any (values)", o.Left, o.Right)
			// case "like":
			// 	q = q.Where("name = ? AND value LIKE ?", o.Left, o.Right)
			// case "in":
			// 	q = q.Where(sq.Eq{"name": o.Left, "value": strings.Split(o.Right, " ")})
		}
		joins = append(joins, q)
	}

	return func(s sq.SelectBuilder) sq.SelectBuilder {
		for i, q := range joins {
			s = s.JoinClause(q.Prefix("JOIN (").Suffix(fmt.Sprintf(`) r%03d USING (id)`, i)))
		}

		if search.Count > 0 {
			s = s.Limit(search.Count)
		}
		// Honour the offset the same way GetWhereSQ does.
		if search.Offset > 0 {
			s = s.Offset(search.Offset)
		}

		return s.Where(where)
	}, nil
}
// GetWhereSQ translates search into a function that decorates a SELECT on
// mercury_spaces for SQLite/libSQL.
//
// Namespace terms are OR-ed together: a node matches the space exactly, a
// star matches with LIKE, and a trace matches every space that is a prefix of
// the given value. Each search.Find filter becomes a joined sub-select over
// mercury_values expanded with json_each on the "values" column, so "vs.value"
// is one element of the stored list. Supported operators are "key", "nkey",
// "eq", "neq", "gt", "lt", "ge", "le", "like" and "in" (space separated
// alternatives); any other operator adds no condition to its sub-select.
// search.Count and search.Offset, when positive, become LIMIT and OFFSET.
//
// It returns an error when more than MAX_FILTER filters are supplied.
func GetWhereSQ(search mercury.Search) (func(sq.SelectBuilder) sq.SelectBuilder, error) {
	if len(search.Find) > MAX_FILTER {
		return nil, fmt.Errorf("too many filters [%d]", MAX_FILTER)
	}

	var where sq.Or

	// json_each exposes its own "id" column, so the row id has to be
	// qualified with the mercury_values alias and re-exported as "id" for the
	// USING (id) join below.
	id := `mv.id AS id`
	space := "space"
	name := "name"
	// json_each is the table-valued function that yields one row per list
	// element; json_valid is only usable as a predicate.
	values_each := `json_each(mv."values")`
	values_valid := `json_valid("values")`

	for _, m := range search.NamespaceSearch {
		switch m.(type) {
		case mercury.NamespaceNode:
			where = append(where, sq.Eq{space: m.Value()})
		case mercury.NamespaceStar:
			where = append(where, sq.Like{space: m.Value()})
		case mercury.NamespaceTrace:
			e := sq.Expr(`? LIKE `+space+` || '%'`, m.Value())
			where = append(where, e)
		}
	}

	joins := make([]sq.SelectBuilder, 0, len(search.Find))
	for _, o := range search.Find {
		log.Println(o)

		q := sq.Select("DISTINCT " + id).From(`mercury_values mv, ` + values_each + ` vs`)

		switch o.Op {
		case "key":
			q = q.Where(sq.Eq{name: o.Left})
		case "nkey":
			q = q.Where(sq.NotEq{name: o.Left})
		case "eq":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left, `vs.value`: o.Right}})
		case "neq":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left}, sq.NotEq{`vs.value`: o.Right}})
		case "gt":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left}, sq.Gt{`vs.value`: o.Right}})
		case "lt":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left}, sq.Lt{`vs.value`: o.Right}})
		case "ge":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left}, sq.GtOrEq{`vs.value`: o.Right}})
		case "le":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left}, sq.LtOrEq{`vs.value`: o.Right}})
		case "like":
			q = q.Where(sq.And{sq.Expr(values_valid), sq.Eq{name: o.Left}, sq.Like{`vs.value`: o.Right}})
		case "in":
			q = q.Where(sq.Eq{name: o.Left, "vs.value": strings.Split(o.Right, " ")})
		}
		joins = append(joins, q)
	}

	return func(s sq.SelectBuilder) sq.SelectBuilder {
		for i, q := range joins {
			s = s.JoinClause(q.Prefix("JOIN (").Suffix(fmt.Sprintf(`) r%03d USING (id)`, i)))
		}

		if search.Count > 0 {
			s = s.Limit(search.Count)
		}
		if search.Offset > 0 {
			s = s.Offset(search.Offset)
		}

		return s.Where(where)
	}, nil
}