go-pkg/libsql_embed/open.go

163 lines
3.2 KiB
Go
Raw Normal View History

2024-04-05 12:41:30 -06:00
package libsqlembed
import (
"context"
"database/sql"
"database/sql/driver"
2024-04-19 10:56:27 -06:00
"errors"
2024-04-05 12:41:30 -06:00
"fmt"
2024-04-19 10:56:27 -06:00
"io"
"log"
2024-04-05 12:41:30 -06:00
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/tursodatabase/go-libsql"
)
func init() {
2024-04-19 10:56:27 -06:00
sql.Register("libsql+embed", &db{conns: make(map[string]*connector)})
2024-04-05 12:41:30 -06:00
}
type db struct {
2024-04-19 10:56:27 -06:00
conns map[string]*connector
2024-04-05 12:41:30 -06:00
mu sync.RWMutex
}
type connector struct {
*libsql.Connector
2024-04-19 10:56:27 -06:00
dsn string
dir string
driver *db
removeDir bool
2024-04-05 12:41:30 -06:00
}
2024-04-19 10:56:27 -06:00
var _ io.Closer = (*connector)(nil)
2024-04-05 12:41:30 -06:00
func (c *connector) Close() error {
2024-04-19 10:56:27 -06:00
log.Println("closing db connection", c.dir)
defer log.Println("closed db connection", c.dir)
2024-04-05 12:41:30 -06:00
c.driver.mu.Lock()
delete(c.driver.conns, c.dsn)
c.driver.mu.Unlock()
2024-04-19 10:56:27 -06:00
if c.removeDir {
defer os.RemoveAll(c.dir)
}
2024-04-05 12:41:30 -06:00
2024-04-19 10:56:27 -06:00
log.Println("sync db")
2024-04-05 12:41:30 -06:00
if err := c.Connector.Sync(); err != nil {
return fmt.Errorf("syncing database: %w", err)
}
return c.Connector.Close()
}
func (db *db) OpenConnector(dsn string) (driver.Connector, error) {
2024-04-19 10:56:27 -06:00
// log.Println("connector", dsn)
if dsn == "" {
return nil, fmt.Errorf("no dsn")
}
if c, ok := func() (*connector, bool) {
2024-04-05 12:41:30 -06:00
db.mu.RLock()
defer db.mu.RUnlock()
c, ok := db.conns[dsn]
return c, ok
}(); ok {
return c, nil
}
db.mu.Lock()
defer db.mu.Unlock()
u, err := url.Parse(dsn)
if err != nil {
return nil, err
}
var primary url.URL
primary.Scheme = strings.TrimSuffix(u.Scheme, "+embed")
primary.Host = u.Host
dbname, _, _ := strings.Cut(u.Host, ".")
authToken := u.Query().Get("authToken")
if authToken == "" {
return nil, fmt.Errorf("missing authToken")
}
opts := []libsql.Option{
libsql.WithAuthToken(authToken),
}
2024-04-19 10:56:27 -06:00
if refresh, err := strconv.ParseInt(u.Query().Get("refresh"), 10, 64); err == nil {
log.Println("refresh: ", refresh)
2024-04-05 12:41:30 -06:00
opts = append(opts, libsql.WithSyncInterval(time.Duration(refresh)*time.Minute))
}
if readWrite, err := strconv.ParseBool(u.Query().Get("readYourWrites")); err == nil {
2024-04-19 10:56:27 -06:00
log.Println("read your writes: ", readWrite)
2024-04-05 12:41:30 -06:00
opts = append(opts, libsql.WithReadYourWrites(readWrite))
}
if key := u.Query().Get("key"); key != "" {
opts = append(opts, libsql.WithEncryption(key))
}
2024-04-19 10:56:27 -06:00
var dir string
var removeDir bool
if dir = u.Query().Get("store"); dir == "" {
removeDir = true
dir, err = os.MkdirTemp("", "libsql-*")
log.Println("creating temporary directory:", dir)
if err != nil {
return nil, fmt.Errorf("creating temporary directory: %w", err)
}
} else {
stat, err := os.Stat(dir)
if errors.Is(err, os.ErrNotExist) {
if err = os.MkdirAll(dir, 0700); err != nil {
return nil, err
}
} else {
if !stat.IsDir() {
return nil, fmt.Errorf("store not directory")
}
}
2024-04-05 12:41:30 -06:00
}
dbPath := filepath.Join(dir, dbname)
c, err := libsql.NewEmbeddedReplicaConnector(
dbPath,
primary.String(),
opts...)
if err != nil {
return nil, fmt.Errorf("creating connector: %w", err)
}
2024-04-19 10:56:27 -06:00
log.Println("sync db")
if err := c.Sync(); err != nil {
return nil, fmt.Errorf("syncing database: %w", err)
}
connector := &connector{c, dsn, dir, db, removeDir}
2024-04-05 12:41:30 -06:00
db.conns[dsn] = connector
return connector, nil
}
func (db *db) Open(dsn string) (driver.Conn, error) {
2024-04-19 10:56:27 -06:00
log.Println("open", dsn)
2024-04-05 12:41:30 -06:00
c, err := db.OpenConnector(dsn)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}