fix: 将 libsql-client-go 作为普通目录提交,而非 gitlink

This commit is contained in:
2026-04-27 07:04:46 +08:00
parent 2359d6c9fa
commit f6332fbaaf
52 changed files with 42333 additions and 1 deletions

View File

@@ -0,0 +1,18 @@
package shared
type result struct {
id int64
changes int64
}
func NewResult(id, changes int64) *result {
return &result{id: id, changes: changes}
}
func (r *result) LastInsertId() (int64, error) {
return r.id, nil
}
func (r *result) RowsAffected() (int64, error) {
return r.changes, nil
}

View File

@@ -0,0 +1,69 @@
package shared
import (
"database/sql/driver"
"fmt"
"io"
)
type rowsProvider interface {
SetsCount() int
RowsCount(setIdx int) int
Columns(setIdx int) []string
FieldValue(setIdx, rowIdx int, columnIdx int) driver.Value
Error(setIdx int) string
HasResult(setIdx int) bool
}
func NewRows(result rowsProvider) driver.Rows {
return &rows{result: result}
}
type rows struct {
result rowsProvider
currentResultSetIndex int
currentRowIdx int
}
func (r *rows) Columns() []string {
return r.result.Columns(r.currentResultSetIndex)
}
func (r *rows) Close() error {
return nil
}
func (r *rows) Next(dest []driver.Value) error {
if r.currentRowIdx == r.result.RowsCount(r.currentResultSetIndex) {
return io.EOF
}
count := len(r.result.Columns(r.currentResultSetIndex))
for idx := 0; idx < count; idx++ {
dest[idx] = r.result.FieldValue(r.currentResultSetIndex, r.currentRowIdx, idx)
}
r.currentRowIdx++
return nil
}
func (r *rows) HasNextResultSet() bool {
return r.currentResultSetIndex < r.result.SetsCount()-1
}
func (r *rows) NextResultSet() error {
if !r.HasNextResultSet() {
return io.EOF
}
r.currentResultSetIndex++
r.currentRowIdx = 0
errStr := r.result.Error(r.currentResultSetIndex)
if errStr != "" {
return fmt.Errorf("failed to execute statement\n%s", errStr)
}
if !r.result.HasResult(r.currentResultSetIndex) {
return fmt.Errorf("no results for statement")
}
return nil
}

View File

@@ -0,0 +1,257 @@
package shared
import (
"database/sql/driver"
"encoding/json"
"fmt"
"regexp"
"sort"
"github.com/antlr4-go/antlr/v4"
"github.com/tursodatabase/libsql-client-go/sqliteparser"
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
)
type ParamsInfo struct {
NamedParameters []string
PositionalParametersCount int
}
func ParseStatement(sql string) ([]string, []ParamsInfo, error) {
stmts, _ := sqliteparserutils.SplitStatement(sql)
stmtsParams := make([]ParamsInfo, len(stmts))
for idx, stmt := range stmts {
nameParams, positionalParamsCount, err := extractParameters(stmt)
if err != nil {
return nil, nil, err
}
stmtsParams[idx] = ParamsInfo{nameParams, positionalParamsCount}
}
return stmts, stmtsParams, nil
}
func ParseStatementAndArgs(sql string, args []driver.NamedValue) ([]string, []Params, error) {
stmts, _ := sqliteparserutils.SplitStatement(sql)
if len(args) == 0 {
return stmts, nil, nil
}
parameters, err := ConvertArgs(args)
if err != nil {
return nil, nil, err
}
stmtsParams := make([]Params, len(stmts))
totalParametersAlreadyUsed := 0
for idx, stmt := range stmts {
stmtParams, err := generateStatementParameters(stmt, parameters, totalParametersAlreadyUsed)
if err != nil {
return nil, nil, fmt.Errorf("fail to generate statement parameter. statement: %s. error: %v", stmt, err)
}
stmtsParams[idx] = stmtParams
totalParametersAlreadyUsed += stmtParams.Len()
}
return stmts, stmtsParams, nil
}
type paramsType int
const (
namedParameters paramsType = iota
positionalParameters
)
type Params struct {
positional []any
named map[string]any
}
func (p *Params) MarshalJSON() ([]byte, error) {
if len(p.named) > 0 {
return json.Marshal(p.named)
}
if len(p.positional) > 0 {
return json.Marshal(p.positional)
}
return json.Marshal(make([]any, 0))
}
func (p *Params) Named() map[string]any {
return p.named
}
func (p *Params) Positional() []any {
return p.positional
}
func (p *Params) Len() int {
if p.named != nil {
return len(p.named)
}
return len(p.positional)
}
func (p *Params) Type() paramsType {
if p.named != nil {
return namedParameters
}
return positionalParameters
}
func NewParams(t paramsType) Params {
p := Params{}
switch t {
case namedParameters:
p.named = make(map[string]any)
case positionalParameters:
p.positional = make([]any, 0)
}
return p
}
func getParamType(arg *driver.NamedValue) paramsType {
if arg.Name == "" {
return positionalParameters
}
return namedParameters
}
func ConvertArgs(args []driver.NamedValue) (Params, error) {
if len(args) == 0 {
return NewParams(positionalParameters), nil
}
var sortedArgs []*driver.NamedValue
for idx := range args {
sortedArgs = append(sortedArgs, &args[idx])
}
sort.Slice(sortedArgs, func(i, j int) bool {
return sortedArgs[i].Ordinal < sortedArgs[j].Ordinal
})
parametersType := getParamType(sortedArgs[0])
parameters := NewParams(parametersType)
for _, arg := range sortedArgs {
if parametersType != getParamType(arg) {
return Params{}, fmt.Errorf("driver does not accept positional and named parameters at the same time")
}
switch parametersType {
case positionalParameters:
parameters.positional = append(parameters.positional, arg.Value)
case namedParameters:
parameters.named[arg.Name] = arg.Value
}
}
return parameters, nil
}
func isExplain(stmt string) bool {
statementStream := antlr.NewInputStream(stmt)
lexer := sqliteparser.NewSQLiteLexer(statementStream)
tokenStream := antlr.NewCommonTokenStream(lexer, 0)
firstToken := tokenStream.LT(1)
return firstToken.GetTokenType() == sqliteparser.SQLiteParserEXPLAIN_
}
func generateStatementParameters(stmt string, queryParams Params, positionalParametersOffset int) (Params, error) {
nameParams, positionalParamsCount, err := extractParameters(stmt)
if err != nil {
return Params{}, err
}
stmtParams := NewParams(queryParams.Type())
switch queryParams.Type() {
case positionalParameters:
if positionalParametersOffset+positionalParamsCount > len(queryParams.positional) {
if isExplain(stmt) {
return Params{}, nil
}
// Positional parameters with indexes most of the time will have fewer args than parameters.
stmtParams.positional = queryParams.positional[positionalParametersOffset:len(queryParams.positional)]
} else {
stmtParams.positional = queryParams.positional[positionalParametersOffset : positionalParametersOffset+positionalParamsCount]
}
case namedParameters:
stmtParametersNeeded := make(map[string]bool)
for _, stmtParametersName := range nameParams {
stmtParametersNeeded[stmtParametersName] = true
}
for queryParamsName, queryParamsValue := range queryParams.named {
if stmtParametersNeeded[queryParamsName] {
stmtParams.named[queryParamsName] = queryParamsValue
delete(stmtParametersNeeded, queryParamsName)
}
}
}
return stmtParams, nil
}
func extractParameters(stmt string) (nameParams []string, positionalParamsCount int, err error) {
statementStream := antlr.NewInputStream(stmt)
sqliteparser.NewSQLiteLexer(statementStream)
lexer := sqliteparser.NewSQLiteLexer(statementStream)
allTokens := lexer.GetAllTokens()
nameParamsSet := make(map[string]bool)
positionalParamsWithIndexesSet := make(map[string]bool)
// ^: asserts the start of the string.
// \?: matches a literal question mark character.
// (\d+)? captures one more digits (0-9) in a group, but group is optional due to the ? quantifier.
// $: asserts the end of thr string so as to avoid this scenario: ?123ABC.
re := regexp.MustCompile(`^\?(\d+)?$`)
for _, token := range allTokens {
tokenType := token.GetTokenType()
if tokenType == sqliteparser.SQLiteLexerBIND_PARAMETER {
parameter := token.GetText()
match := re.FindStringSubmatch(parameter)
if match == nil {
paramWithoutPrefix, err := removeParamPrefix(parameter)
if err != nil {
return []string{}, 0, err
}
nameParamsSet[paramWithoutPrefix] = true
continue
}
posS := string(match[1])
if posS == "" {
// When an empty string, it means the parameter is a
// positional parameter without an index (e.g, ?).
positionalParamsCount++
} else {
// Positional parameter with indexes (e.g., ?<number>)
// must be deduped.
positionalParamsWithIndexesSet[posS] = true
}
}
}
nameParams = make([]string, 0, len(nameParamsSet))
for k := range nameParamsSet {
nameParams = append(nameParams, k)
}
// Only count unique number of positional parameters.
positionalParamsCount += len(positionalParamsWithIndexesSet)
return nameParams, positionalParamsCount, nil
}
func removeParamPrefix(paramName string) (string, error) {
if paramName[0] == ':' || paramName[0] == '@' || paramName[0] == '$' {
return paramName[1:], nil
}
return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
}

View File

@@ -0,0 +1,89 @@
package shared
import (
"reflect"
"sort"
"testing"
)
func TestExtractParameters(t *testing.T) {
tests := []struct {
name string
value string
nameParams []string
positionalParamsCount int
err error
}{
{
name: "OnlyColonNameParams",
value: "select :column from :table",
nameParams: []string{"column", "table"},
},
{
name: "OnlyAtNameParams",
value: "select @column from @table",
nameParams: []string{"column", "table"},
},
{
name: "OnlyDollarSignNameParams",
value: "select $column from $table",
nameParams: []string{"column", "table"},
},
{
name: "RepeatedNamedParameter",
value: "select :number, :number",
nameParams: []string{"number"},
},
{
name: "OnlyPositionalParams",
value: "select ? from ?",
nameParams: []string{},
positionalParamsCount: 2,
},
{
name: "OnlyPositionalParamsWithoutIndexes",
value: "select ? from ?",
nameParams: []string{},
positionalParamsCount: 2,
},
{
name: "OnlyPositionalParamsWithIndexes (dedup)",
value: "select ?1 from ?1",
nameParams: []string{},
positionalParamsCount: 1,
},
{
name: "PositionalParamsWithIndexes",
value: "select ? from ?1",
nameParams: []string{},
positionalParamsCount: 2,
},
{
name: "MixedParams",
value: "select :column1, @column2, $column3, ? from ?",
nameParams: []string{"column1", "column2", "column3"},
positionalParamsCount: 2,
},
{
name: "NoParams",
value: "select myColumn from myTable",
nameParams: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNameParams, gotUniquePositionalParamsCount, gotErr := extractParameters(tt.value)
sort.Strings(gotNameParams)
sort.Strings(tt.nameParams)
if !reflect.DeepEqual(gotNameParams, tt.nameParams) {
t.Errorf("got nameParams %#v, want %#v", gotNameParams, tt.nameParams)
}
if !reflect.DeepEqual(gotUniquePositionalParamsCount, tt.positionalParamsCount) {
t.Errorf("got positionalParams %#v, want %#v", gotUniquePositionalParamsCount, tt.positionalParamsCount)
}
if !reflect.DeepEqual(gotUniquePositionalParamsCount, tt.positionalParamsCount) {
t.Errorf("got err %v, want %v", gotErr, tt.err)
}
})
}
}