301 lines
5.4 KiB
Go
301 lines
5.4 KiB
Go
|
package squirrel
|
||
|
|
||
|
import (
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/Masterminds/squirrel"
	"go.sour.is/pkg/rsql"
)
|
||
|
|
||
|
// dbInfo translates a name used in a query into the name the database
// understands. The decoder calls it for every argument key and for the left
// side of every comparison.
type dbInfo interface {
	// Col returns the column to use for name. A non-nil error means the name
	// cannot be used; the decoder then drops the argument or fails the
	// comparison.
	Col(name string) (column string, err error)
}
|
||
|
|
||
|
// args holds the key/value arguments that accompany a query. The keys "limit"
// and "offset" are reserved (see Limit and Offset); every other key/value pair
// is emitted by Order as "<key> <value>".
type args map[string]string
|
||
|
|
||
|
func (d *decoder) mkArgs(a args) args {
|
||
|
m := make(args, len(a))
|
||
|
for k, v := range a {
|
||
|
if k == "limit" || k == "offset" {
|
||
|
m[k] = v
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
var err error
|
||
|
if k, err = d.dbInfo.Col(k); err == nil {
|
||
|
m[k] = v
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
func (a args) Limit() (uint64, bool) {
|
||
|
if a == nil {
|
||
|
return 0, false
|
||
|
}
|
||
|
if v, ok := a["limit"]; ok {
|
||
|
i, err := strconv.ParseUint(v, 10, 64)
|
||
|
return i, err == nil
|
||
|
}
|
||
|
|
||
|
return 0, false
|
||
|
}
|
||
|
func (a args) Offset() (uint64, bool) {
|
||
|
if a == nil {
|
||
|
return 0, false
|
||
|
}
|
||
|
if v, ok := a["offset"]; ok {
|
||
|
i, err := strconv.ParseUint(v, 10, 64)
|
||
|
return i, err == nil
|
||
|
}
|
||
|
|
||
|
return 0, false
|
||
|
}
|
||
|
func (a args) Order() []string {
|
||
|
var lis []string
|
||
|
|
||
|
for k, v := range a {
|
||
|
if k == "limit" || k == "offset" {
|
||
|
continue
|
||
|
}
|
||
|
lis = append(lis, k+" "+v)
|
||
|
}
|
||
|
|
||
|
return lis
|
||
|
}
|
||
|
|
||
|
func Query(in string, db dbInfo) (squirrel.Sqlizer, args, error) {
|
||
|
d := decoder{dbInfo: db}
|
||
|
program := rsql.DefaultParse(in)
|
||
|
sql, err := d.decode(program)
|
||
|
return sql, d.mkArgs(program.Args), err
|
||
|
}
|
||
|
|
||
|
type decoder struct {
|
||
|
dbInfo dbInfo
|
||
|
}
|
||
|
|
||
|
func (db *decoder) decode(in *rsql.Program) (squirrel.Sqlizer, error) {
|
||
|
switch len(in.Statements) {
|
||
|
case 0:
|
||
|
return nil, nil
|
||
|
case 1:
|
||
|
return db.decodeStatement(in.Statements[0])
|
||
|
default:
|
||
|
a := squirrel.And{}
|
||
|
for _, stmt := range in.Statements {
|
||
|
d, err := db.decodeStatement(stmt)
|
||
|
if d == nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
a = append(a, d)
|
||
|
}
|
||
|
return a, nil
|
||
|
}
|
||
|
}
|
||
|
func (db *decoder) decodeStatement(in rsql.Statement) (squirrel.Sqlizer, error) {
|
||
|
switch s := in.(type) {
|
||
|
case *rsql.ExpressionStatement:
|
||
|
return db.decodeExpression(s.Expression)
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (db *decoder) decodeExpression(in rsql.Expression) (squirrel.Sqlizer, error) {
|
||
|
switch e := in.(type) {
|
||
|
case *rsql.InfixExpression:
|
||
|
return db.decodeInfix(e)
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (db *decoder) decodeInfix(in *rsql.InfixExpression) (squirrel.Sqlizer, error) {
|
||
|
|
||
|
switch in.Token.Type {
|
||
|
case rsql.TokAND:
|
||
|
a := squirrel.And{}
|
||
|
left, err := db.decodeExpression(in.Left)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch v := left.(type) {
|
||
|
case squirrel.And:
|
||
|
for _, el := range v {
|
||
|
if el != nil {
|
||
|
a = append(a, el)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
default:
|
||
|
if v != nil {
|
||
|
a = append(a, v)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
right, err := db.decodeExpression(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch v := right.(type) {
|
||
|
case squirrel.And:
|
||
|
for _, el := range v {
|
||
|
if el != nil {
|
||
|
a = append(a, el)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
default:
|
||
|
if v != nil {
|
||
|
a = append(a, v)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return a, nil
|
||
|
case rsql.TokOR:
|
||
|
left, err := db.decodeExpression(in.Left)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
right, err := db.decodeExpression(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.Or{left, right}, nil
|
||
|
case rsql.TokEQ:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.Eq{col: v}, nil
|
||
|
case rsql.TokLIKE:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch value := v.(type) {
|
||
|
case string:
|
||
|
return Like{col, strings.Replace(value, "*", "%", -1)}, nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("LIKE requires a string value")
|
||
|
}
|
||
|
|
||
|
case rsql.TokNEQ:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.NotEq{col: v}, nil
|
||
|
case rsql.TokGT:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.Gt{col: v}, nil
|
||
|
case rsql.TokGE:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.GtOrEq{col: v}, nil
|
||
|
case rsql.TokLT:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.Lt{col: v}, nil
|
||
|
case rsql.TokLE:
|
||
|
col, err := db.dbInfo.Col(in.Left.String())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
v, err := db.decodeValue(in.Right)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return squirrel.LtOrEq{col: v}, nil
|
||
|
default:
|
||
|
return nil, nil
|
||
|
}
|
||
|
}
|
||
|
func (db *decoder) decodeValue(in rsql.Expression) (interface{}, error) {
|
||
|
switch v := in.(type) {
|
||
|
case *rsql.Array:
|
||
|
var values []interface{}
|
||
|
for _, el := range v.Elements {
|
||
|
v, err := db.decodeValue(el)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
values = append(values, v)
|
||
|
}
|
||
|
|
||
|
return values, nil
|
||
|
case *rsql.InfixExpression:
|
||
|
return db.decodeInfix(v)
|
||
|
case *rsql.Identifier:
|
||
|
return v.Value, nil
|
||
|
case *rsql.Integer:
|
||
|
return v.Value, nil
|
||
|
case *rsql.Float:
|
||
|
return v.Value, nil
|
||
|
case *rsql.String:
|
||
|
return v.Value, nil
|
||
|
case *rsql.Bool:
|
||
|
return v.Value, nil
|
||
|
case *rsql.Null:
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
// Like is a condition that compares a column against a pattern with the SQL
// LIKE operator. Its ToSql method makes it usable as a squirrel.Sqlizer.
type Like struct {
	column string // left-hand side, written into the SQL text verbatim
	value  string // pattern, passed as a bind argument
}

// ToSql renders the condition as "<column> LIKE(?)" with the pattern as the
// only argument. The returned error is always nil.
func (l Like) ToSql() (sql string, args []interface{}, err error) {
	return fmt.Sprintf("%s LIKE(?)", l.column), []interface{}{l.value}, nil
}
|