fix: 将 libsql-client-go 作为普通目录提交,而非 gitlink

This commit is contained in:
2026-04-27 07:04:46 +08:00
parent 2359d6c9fa
commit f6332fbaaf
52 changed files with 42333 additions and 1 deletions

View File

@@ -0,0 +1,22 @@
package hrana
type Batch struct {
Steps []BatchStep `json:"steps"`
ReplicationIndex *uint64 `json:"replication_index,omitempty"`
}
type BatchStep struct {
Stmt Stmt `json:"stmt"`
Condition *BatchCondition `json:"condition,omitempty"`
}
type BatchCondition struct {
Type string `json:"type"`
Step *int32 `json:"step,omitempty"`
Cond *BatchCondition `json:"cond,omitempty"`
Conds []BatchCondition `json:"conds,omitempty"`
}
func (b *Batch) Add(stmt Stmt, condition *BatchCondition) {
b.Steps = append(b.Steps, BatchStep{Stmt: stmt, Condition: condition})
}

View File

@@ -0,0 +1,49 @@
package hrana
import (
"encoding/json"
"fmt"
"strconv"
)
type BatchResult struct {
StepResults []*StmtResult `json:"step_results"`
StepErrors []*Error `json:"step_errors"`
ReplicationIndex *uint64 `json:"replication_index"`
}
func (b *BatchResult) UnmarshalJSON(data []byte) error {
type Alias BatchResult
aux := &struct {
ReplicationIndex interface{} `json:"replication_index,omitempty"`
*Alias
}{
Alias: (*Alias)(b),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ReplicationIndex == nil {
return nil
}
switch v := aux.ReplicationIndex.(type) {
case float64:
repIndex := uint64(v)
b.ReplicationIndex = &repIndex
case string:
if v == "" {
return nil
}
repIndex, err := strconv.ParseUint(v, 10, 64)
if err != nil {
return err
}
b.ReplicationIndex = &repIndex
default:
return fmt.Errorf("invalid type for replication index: %T", v)
}
return nil
}

View File

@@ -0,0 +1,58 @@
package hrana
import (
"encoding/json"
"fmt"
"reflect"
"testing"
)
func TestBatchResult_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
jsonData []byte
expected *uint64
}{
{
jsonData: []byte(`{"replication_index":1}`),
expected: uint64Ptr(1),
},
{
jsonData: []byte(`{"replication_index":"1"}`),
expected: uint64Ptr(1),
},
{
jsonData: []byte(`{"replication_index":""}`),
expected: nil,
},
{
jsonData: []byte(`{}`),
expected: nil,
},
{
jsonData: []byte(`{"replication_index":"0"}`),
expected: uint64Ptr(0),
},
{
jsonData: []byte(`{"replication_index":0}`),
expected: uint64Ptr(0),
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
batchResult := &BatchResult{}
err := json.Unmarshal(tc.jsonData, batchResult)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(batchResult.ReplicationIndex, tc.expected) {
t.Errorf("ReplicationIndex field is not correctly unmarshaled got = %v, want = %v", batchResult.ReplicationIndex, tc.expected)
}
})
}
}
func uint64Ptr(n uint64) *uint64 {
return &n
}

View File

@@ -0,0 +1,10 @@
package hrana
type PipelineRequest struct {
Baton string `json:"baton,omitempty"`
Requests []StreamRequest `json:"requests"`
}
func (pr *PipelineRequest) Add(request StreamRequest) {
pr.Requests = append(pr.Requests, request)
}

View File

@@ -0,0 +1,7 @@
package hrana
type PipelineResponse struct {
Baton string `json:"baton,omitempty"`
BaseUrl string `json:"base_url,omitempty"`
Results []StreamResult `json:"results"`
}

View File

@@ -0,0 +1,58 @@
package hrana
import (
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
)
type Stmt struct {
Sql *string `json:"sql,omitempty"`
SqlId *int32 `json:"sql_id,omitempty"`
Args []Value `json:"args,omitempty"`
NamedArgs []NamedArg `json:"named_args,omitempty"`
WantRows bool `json:"want_rows"`
ReplicationIndex *uint64 `json:"replication_index,omitempty"`
}
type NamedArg struct {
Name string `json:"name"`
Value Value `json:"value"`
}
func (s *Stmt) AddArgs(params shared.Params) error {
if len(params.Named()) > 0 {
return s.AddNamedArgs(params.Named())
} else {
return s.AddPositionalArgs(params.Positional())
}
}
func (s *Stmt) AddPositionalArgs(args []any) error {
argValues := make([]Value, len(args))
for idx := range args {
var err error
if argValues[idx], err = ToValue(args[idx]); err != nil {
return err
}
}
s.Args = argValues
return nil
}
func (s *Stmt) AddNamedArgs(args map[string]any) error {
argValues := make([]NamedArg, len(args))
idx := 0
for key, value := range args {
var err error
var v Value
if v, err = ToValue(value); err != nil {
return err
}
argValues[idx] = NamedArg{
Name: key,
Value: v,
}
idx++
}
s.NamedArgs = argValues
return nil
}

View File

@@ -0,0 +1,65 @@
package hrana
import (
"encoding/json"
"fmt"
"strconv"
)
type Column struct {
Name *string `json:"name"`
Type *string `json:"decltype"`
}
type StmtResult struct {
Cols []Column `json:"cols"`
Rows [][]Value `json:"rows"`
AffectedRowCount int32 `json:"affected_row_count"`
LastInsertRowId *string `json:"last_insert_rowid"`
ReplicationIndex *uint64 `json:"replication_index"`
}
func (r *StmtResult) GetLastInsertRowId() int64 {
if r.LastInsertRowId != nil {
if integer, err := strconv.ParseInt(*r.LastInsertRowId, 10, 64); err == nil {
return integer
}
}
return 0
}
func (r *StmtResult) UnmarshalJSON(data []byte) error {
type Alias StmtResult
aux := &struct {
ReplicationIndex interface{} `json:"replication_index,omitempty"`
*Alias
}{
Alias: (*Alias)(r),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ReplicationIndex == nil {
return nil
}
switch v := aux.ReplicationIndex.(type) {
case float64:
repIndex := uint64(v)
r.ReplicationIndex = &repIndex
case string:
if v == "" {
return nil
}
repIndex, err := strconv.ParseUint(v, 10, 64)
if err != nil {
return err
}
r.ReplicationIndex = &repIndex
default:
return fmt.Errorf("invalid type for replication index: %T", v)
}
return nil
}

View File

@@ -0,0 +1,54 @@
package hrana
import (
"encoding/json"
"fmt"
"reflect"
"testing"
)
func TestStmtResult_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
jsonData []byte
expected *uint64
}{
{
jsonData: []byte(`{"replication_index":1}`),
expected: uint64Ptr(1),
},
{
jsonData: []byte(`{"replication_index":"1"}`),
expected: uint64Ptr(1),
},
{
jsonData: []byte(`{"replication_index":""}`),
expected: nil,
},
{
jsonData: []byte(`{}`),
expected: nil,
},
{
jsonData: []byte(`{"replication_index":"0"}`),
expected: uint64Ptr(0),
},
{
jsonData: []byte(`{"replication_index":0}`),
expected: uint64Ptr(0),
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
stmtResult := &StmtResult{}
err := json.Unmarshal(tc.jsonData, stmtResult)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(stmtResult.ReplicationIndex, tc.expected) {
t.Errorf("ReplicationIndex field is not correctly unmarshaled got = %v, want = %v", stmtResult.ReplicationIndex, tc.expected)
}
})
}
}

View File

@@ -0,0 +1,98 @@
package hrana
import (
"reflect"
"testing"
)
func TestStmtWithPositionalArgs(t *testing.T) {
tests := []struct {
name string
args []any
want []Value
wantErr bool
}{
{
name: "int args",
args: []any{1, 2},
want: []Value{{Type: "integer", Value: "1"}, {Type: "integer", Value: "2"}},
},
{
name: "string args",
args: []any{"a", "b"},
want: []Value{{Type: "text", Value: "a"}, {Type: "text", Value: "b"}},
},
{
name: "invalid arg",
args: []any{make(chan int)},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stmt := Stmt{}
err := stmt.AddPositionalArgs(tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("StmtWithPositionalArgs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(stmt.Args, tt.want) {
t.Errorf("got = %v, want %v", stmt.Args, tt.want)
}
})
}
}
func TestStmtWithNamedArgs(t *testing.T) {
tests := []struct {
name string
args map[string]any
want []NamedArg
wantErr bool
}{
{
name: "int args",
args: map[string]any{"arg1": 1, "arg2": int64(2)},
want: []NamedArg{
{Name: "arg1", Value: Value{Type: "integer", Value: "1"}},
{Name: "arg2", Value: Value{Type: "integer", Value: "2"}},
},
},
{
name: "string args",
args: map[string]any{"arg1": "a", "arg2": "b"},
want: []NamedArg{
{Name: "arg1", Value: Value{Type: "text", Value: "a"}},
{Name: "arg2", Value: Value{Type: "text", Value: "b"}},
},
},
{
name: "invalid arg",
args: map[string]any{"arg1": make(chan int)},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stmt := Stmt{}
err := stmt.AddNamedArgs(tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("StmtWithNamedArgs() error = %v, wantErr %v", err, tt.wantErr)
return
}
got := make(map[NamedArg]struct{})
want := make(map[NamedArg]struct{})
for _, arg := range stmt.NamedArgs {
got[arg] = struct{}{}
}
for _, arg := range tt.want {
want[arg] = struct{}{}
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %v, want %v", stmt.NamedArgs, tt.want)
}
})
}
}

View File

@@ -0,0 +1,88 @@
package hrana
import (
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
)
type StreamRequest struct {
Type string `json:"type"`
Stmt *Stmt `json:"stmt,omitempty"`
Batch *Batch `json:"batch,omitempty"`
Sql *string `json:"sql,omitempty"`
SqlId *int32 `json:"sql_id,omitempty"`
}
func CloseStream() StreamRequest {
return StreamRequest{Type: "close"}
}
func ExecuteStream(sql string, params *shared.Params, wantRows bool) (*StreamRequest, error) {
stmt := &Stmt{
Sql: &sql,
WantRows: wantRows,
}
if params != nil {
if err := stmt.AddArgs(*params); err != nil {
return nil, err
}
}
return &StreamRequest{Type: "execute", Stmt: stmt}, nil
}
func ExecuteStoredStream(sqlId int32, params shared.Params, wantRows bool) (*StreamRequest, error) {
stmt := &Stmt{
SqlId: &sqlId,
WantRows: wantRows,
}
if err := stmt.AddArgs(params); err != nil {
return nil, err
}
return &StreamRequest{Type: "execute", Stmt: stmt}, nil
}
func BatchStream(sqls []string, params []shared.Params, wantRows bool, transactional bool) (*StreamRequest, error) {
size := len(sqls)
if transactional {
size += 1
}
batch := &Batch{Steps: make([]BatchStep, 0, size)}
addArgs := len(params) > 0
for idx, sql := range sqls {
s := sql
stmt := &Stmt{
Sql: &s,
WantRows: wantRows,
}
if addArgs {
if err := stmt.AddArgs(params[idx]); err != nil {
return nil, err
}
}
var condition *BatchCondition
if transactional {
if idx > 0 {
prev_idx := int32(idx - 1)
condition = &BatchCondition{
Type: "ok",
Step: &prev_idx,
}
}
}
batch.Add(*stmt, condition)
}
if transactional {
rollback := "ROLLBACK"
last_idx := int32(len(sqls) - 1)
batch.Add(Stmt{Sql: &rollback, WantRows: false},
&BatchCondition{Type: "not", Cond: &BatchCondition{Type: "ok", Step: &last_idx}})
}
return &StreamRequest{Type: "batch", Batch: batch}, nil
}
func StoreSqlStream(sql string, sqlId int32) StreamRequest {
return StreamRequest{Type: "store_sql", Sql: &sql, SqlId: &sqlId}
}
func CloseStoredSqlStream(sqlId int32) StreamRequest {
return StreamRequest{Type: "close_sql", SqlId: &sqlId}
}

View File

@@ -0,0 +1,52 @@
package hrana
import (
"encoding/json"
"errors"
"fmt"
)
type StreamResult struct {
Type string `json:"type"`
Response *StreamResponse `json:"response,omitempty"`
Error *Error `json:"error,omitempty"`
}
type StreamResponse struct {
Type string `json:"type"`
Result json.RawMessage `json:"result,omitempty"`
}
func (r *StreamResponse) ExecuteResult() (*StmtResult, error) {
if r.Type != "execute" {
return nil, fmt.Errorf("invalid response type: %s", r.Type)
}
var res StmtResult
if err := json.Unmarshal(r.Result, &res); err != nil {
return nil, err
}
return &res, nil
}
func (r *StreamResponse) BatchResult() (*BatchResult, error) {
if r.Type != "batch" {
return nil, fmt.Errorf("invalid response type: %s", r.Type)
}
var res BatchResult
if err := json.Unmarshal(r.Result, &res); err != nil {
return nil, err
}
for _, e := range res.StepErrors {
if e != nil {
return nil, errors.New(e.Message)
}
}
return &res, nil
}
type Error struct {
Message string `json:"message"`
Code *string `json:"code,omitempty"`
}

View File

@@ -0,0 +1,89 @@
package hrana
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"time"
)
type Value struct {
Type string `json:"type"`
Value any `json:"value,omitempty"`
Base64 *string `json:"base64,omitempty"`
}
func (v Value) ToValue(columnType *string) any {
if v.Type == "blob" {
if v.Base64 == nil {
return nil
}
bytes, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(*v.Base64)
if err != nil {
return nil
}
return bytes
} else if v.Type == "integer" {
integer, err := strconv.ParseInt(v.Value.(string), 10, 64)
if err != nil {
return nil
}
return integer
} else if columnType != nil {
if (strings.ToLower(*columnType) == "timestamp" || strings.ToLower(*columnType) == "datetime") && v.Type == "text" {
for _, format := range []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
} {
if t, err := time.ParseInLocation(format, v.Value.(string), time.UTC); err == nil {
return t
}
}
}
}
return v.Value
}
func ToValue(v any) (Value, error) {
var res Value
if v == nil {
res.Type = "null"
} else if integer, ok := v.(int64); ok {
res.Type = "integer"
res.Value = strconv.FormatInt(integer, 10)
} else if integer, ok := v.(int); ok {
res.Type = "integer"
res.Value = strconv.FormatInt(int64(integer), 10)
} else if text, ok := v.(string); ok {
res.Type = "text"
res.Value = text
} else if blob, ok := v.([]byte); ok {
res.Type = "blob"
b64 := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob)
res.Base64 = &b64
} else if float, ok := v.(float64); ok {
res.Type = "float"
res.Value = float
} else if t, ok := v.(time.Time); ok {
res.Type = "text"
res.Value = t.Format("2006-01-02 15:04:05.999999999-07:00")
} else if t, ok := v.(bool); ok {
res.Type = "integer"
res.Value = "0"
if t {
res.Value = "1"
}
} else {
return res, fmt.Errorf("unsupported value type: %s", v)
}
return res, nil
}

View File

@@ -0,0 +1,304 @@
package hrana
import (
"encoding/json"
"reflect"
"strconv"
"testing"
"time"
)
func toPtr[T any](v T) *T {
return &v
}
func TestValueToValue(t *testing.T) {
tests := []struct {
name string
columnType string
value Value
want any
}{
{
name: "null",
value: Value{
Type: "null",
Value: nil,
},
want: nil,
},
{
name: "int",
value: Value{
Type: "integer",
Value: strconv.FormatInt(int64(42), 10),
},
want: int64(42),
},
{
name: "string",
value: Value{
Type: "text",
Value: "foo",
},
want: "foo",
},
{
name: "bytes",
value: Value{
Type: "blob",
Base64: toPtr("YmFy"),
},
want: []byte("bar"),
},
{
name: "bytes",
value: Value{
Type: "blob",
Base64: toPtr(""),
},
want: []byte{},
},
{
name: "float",
value: Value{
Type: "float",
Value: 3.14,
},
want: 3.14,
},
{
name: "timestamp",
columnType: "TIMESTAMP",
value: Value{
Type: "text",
Value: "0001-01-01 01:00:00+00:00",
},
want: time.Time{}.Add(time.Hour),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var columnType *string = nil
if tt.columnType != "" {
columnType = &tt.columnType
}
got := tt.value.ToValue(columnType)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ToValue() = %v, want %v", got, tt.want)
}
})
}
}
func TestToValue(t *testing.T) {
tests := []struct {
name string
value any
want Value
wantErr bool
}{
{
name: "null",
value: nil,
want: Value{
Type: "null",
},
},
{
name: "int",
value: 42,
want: Value{
Type: "integer",
Value: strconv.FormatInt(int64(42), 10),
},
},
{
name: "string",
value: "foo",
want: Value{
Type: "text",
Value: "foo",
},
},
{
name: "bytes",
value: []byte{},
want: Value{
Type: "blob",
Base64: toPtr(""),
},
},
{
name: "float",
value: 3.14,
want: Value{
Type: "float",
Value: 3.14,
},
},
{
name: "boolean",
value: true,
want: Value{
Type: "integer",
Value: "1",
},
},
{
name: "timestamp",
value: time.Time{}.Add(time.Hour),
want: Value{
Type: "text",
Value: "0001-01-01 01:00:00+00:00",
},
},
{
name: "unsupported",
value: make(chan int),
want: Value{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ToValue(tt.value)
if (err != nil) != tt.wantErr {
t.Errorf("ToValue() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ToValue() = %v, want %v", got, tt.want)
}
})
}
}
func TestMarshal(t *testing.T) {
tests := []struct {
name string
value Value
marshaled string
}{
{
name: "null",
value: Value{
Type: "null",
},
marshaled: `{"type":"null"}`,
},
{
name: "int",
value: Value{
Type: "integer",
Value: strconv.FormatInt(int64(42), 10),
},
marshaled: `{"type":"integer","value":"42"}`,
},
{
name: "string",
value: Value{
Type: "text",
Value: "foo",
},
marshaled: `{"type":"text","value":"foo"}`,
},
{
name: "bytes",
value: Value{
Type: "blob",
Base64: toPtr("YmFy"),
},
marshaled: `{"type":"blob","base64":"YmFy"}`,
},
{
name: "float",
value: Value{
Type: "float",
Value: 3.14,
},
marshaled: `{"type":"float","value":3.14}`,
},
{
name: "timestamp",
value: Value{
Type: "text",
Value: time.Time{}.Add(time.Hour),
},
marshaled: `{"type":"text","value":"0001-01-01T01:00:00Z"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := json.Marshal(tt.value)
if err != nil {
t.Errorf("json.Marshal() error = %v", err)
return
}
if !reflect.DeepEqual(string(got), tt.marshaled) {
t.Errorf("json.Marshal() = %v, want %v", string(got), tt.marshaled)
}
})
}
}
func TestUnmarshal(t *testing.T) {
tests := []struct {
name string
value Value
marshaled string
}{
{
name: "null",
value: Value{
Type: "null",
},
marshaled: `{"type":"null"}`,
},
{
name: "int",
value: Value{
Type: "integer",
Value: strconv.FormatInt(int64(42), 10),
},
marshaled: `{"type":"integer","value":"42"}`,
},
{
name: "string",
value: Value{
Type: "text",
Value: "foo",
},
marshaled: `{"type":"text","value":"foo"}`,
},
{
name: "bytes",
value: Value{
Type: "blob",
Base64: toPtr("YmFy"),
},
marshaled: `{"type":"blob","base64":"YmFy"}`,
},
{
name: "float",
value: Value{
Type: "float",
Value: 3.14,
},
marshaled: `{"type":"float","value":3.14}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got Value
err := json.Unmarshal([]byte(tt.marshaled), &got)
if err != nil {
t.Errorf("json.Marshal() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.value) {
t.Errorf("json.Unmarshal() = %v, want %v", got, tt.value)
}
})
}
}

View File

@@ -0,0 +1,11 @@
package http
import (
"database/sql/driver"
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/hranaV2"
)
func Connect(url, jwt, host string, schemaDb bool, remoteEncryptionKey string) driver.Conn {
return hranaV2.Connect(url, jwt, host, schemaDb, remoteEncryptionKey)
}

View File

@@ -0,0 +1,612 @@
package hranaV2
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
net_url "net/url"
"runtime/debug"
"strings"
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
"github.com/tursodatabase/libsql-client-go/libsql/internal/hrana"
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
)
var commitHash string
func init() {
if info, ok := debug.ReadBuildInfo(); ok {
for _, module := range info.Deps {
if module.Path == "github.com/tursodatabase/libsql-client-go" {
parts := strings.Split(module.Version, "-")
if len(parts) == 3 {
commitHash = parts[2][:6]
return
}
}
}
}
commitHash = "unknown"
}
func Connect(url, jwt, host string, schemaDb bool, encryptionKey string) driver.Conn {
return &hranaV2Conn{url, jwt, host, schemaDb, encryptionKey, "", false, 0}
}
type hranaV2Stmt struct {
conn *hranaV2Conn
numInput int
sql string
}
func (s *hranaV2Stmt) Close() error {
return nil
}
func (s *hranaV2Stmt) NumInput() int {
return s.numInput
}
func convertToNamed(args []driver.Value) []driver.NamedValue {
if len(args) == 0 {
return nil
}
var result []driver.NamedValue
for idx := range args {
result = append(result, driver.NamedValue{Ordinal: idx, Value: args[idx]})
}
return result
}
func (s *hranaV2Stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), convertToNamed(args))
}
func (s *hranaV2Stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), convertToNamed(args))
}
func (s *hranaV2Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return s.conn.ExecContext(ctx, s.sql, args)
}
func (s *hranaV2Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return s.conn.QueryContext(ctx, s.sql, args)
}
type hranaV2Conn struct {
url string
jwt string
host string
schemaDb bool
remoteEncryptionKey string
baton string
streamClosed bool
replicationIndex uint64
}
func (h *hranaV2Conn) Ping() error {
return h.PingContext(context.Background())
}
func (h *hranaV2Conn) PingContext(ctx context.Context) error {
_, err := h.executeStmt(ctx, "SELECT 1", nil, false)
return err
}
func (h *hranaV2Conn) Prepare(query string) (driver.Stmt, error) {
return h.PrepareContext(context.Background(), query)
}
func (h *hranaV2Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
stmts, paramInfos, err := shared.ParseStatement(query)
if err != nil {
return nil, err
}
if len(stmts) != 1 {
return nil, fmt.Errorf("only one statement is supported got %d", len(stmts))
}
numInput := -1
if len(paramInfos[0].NamedParameters) == 0 {
numInput = paramInfos[0].PositionalParametersCount
}
return &hranaV2Stmt{h, numInput, query}, nil
}
func (h *hranaV2Conn) Close() error {
if h.baton != "" {
go func(baton, url, jwt, host, encryptionKey string) {
msg := hrana.PipelineRequest{Baton: baton}
msg.Add(hrana.CloseStream())
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
}
return nil
}
func (h *hranaV2Conn) Begin() (driver.Tx, error) {
return h.BeginTx(context.Background(), driver.TxOptions{})
}
type hranaV2Tx struct {
conn *hranaV2Conn
}
func (h hranaV2Tx) Commit() error {
_, err := h.conn.ExecContext(context.Background(), "COMMIT", nil)
return err
}
func (h hranaV2Tx) Rollback() error {
_, err := h.conn.ExecContext(context.Background(), "ROLLBACK", nil)
return err
}
func (h *hranaV2Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if opts.ReadOnly {
return nil, fmt.Errorf("read only transactions are not supported")
}
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
return nil, fmt.Errorf("isolation level %d is not supported", opts.Isolation)
}
_, err := h.ExecContext(ctx, "BEGIN", nil)
if err != nil {
return nil, err
}
return &hranaV2Tx{h}, nil
}
func (h *hranaV2Conn) sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, streamClose bool) (*hrana.PipelineResponse, error) {
if h.streamClosed {
// If the stream is closed, we can't send any more requests using this connection.
return nil, fmt.Errorf("stream is closed: %w", driver.ErrBadConn)
}
if h.baton != "" {
msg.Baton = h.baton
}
if h.replicationIndex > 0 {
addReplicationIndex(msg, h.replicationIndex)
}
result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host, h.remoteEncryptionKey)
if streamClosed {
h.streamClosed = true
}
if err != nil {
return nil, err
}
h.baton = result.Baton
if result.Baton == "" && !streamClose {
// We need to remember that the stream is closed so we don't try to send any more requests using this connection.
h.streamClosed = true
}
if result.BaseUrl != "" {
h.url = result.BaseUrl
}
if idx := getReplicationIndex(&result); idx > h.replicationIndex {
h.replicationIndex = idx
}
return &result, nil
}
func addReplicationIndex(msg *hrana.PipelineRequest, replicationIndex uint64) {
for i := range msg.Requests {
if msg.Requests[i].Stmt != nil && msg.Requests[i].Stmt.ReplicationIndex == nil {
msg.Requests[i].Stmt.ReplicationIndex = &replicationIndex
} else if msg.Requests[i].Batch != nil && msg.Requests[i].Batch.ReplicationIndex == nil {
msg.Requests[i].Batch.ReplicationIndex = &replicationIndex
}
}
}
func getReplicationIndex(response *hrana.PipelineResponse) uint64 {
if response == nil || len(response.Results) == 0 {
return 0
}
var replicationIndex uint64
for _, result := range response.Results {
if result.Response == nil {
continue
}
if result.Response.Type == "execute" {
if res, err := result.Response.ExecuteResult(); err == nil && res.ReplicationIndex != nil {
if *res.ReplicationIndex > replicationIndex {
replicationIndex = *res.ReplicationIndex
}
}
} else if result.Response.Type == "batch" {
if res, err := result.Response.BatchResult(); err == nil && res.ReplicationIndex != nil {
if *res.ReplicationIndex > replicationIndex {
replicationIndex = *res.ReplicationIndex
}
}
}
}
return replicationIndex
}
func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string, remoteEncryptionKey string) (result hrana.PipelineResponse, streamClosed bool, err error) {
reqBody, err := json.Marshal(msg)
if err != nil {
return hrana.PipelineResponse{}, false, err
}
pipelineURL, err := net_url.JoinPath(url, "/v2/pipeline")
if err != nil {
return hrana.PipelineResponse{}, false, err
}
req, err := http.NewRequestWithContext(ctx, "POST", pipelineURL, bytes.NewReader(reqBody))
if err != nil {
return hrana.PipelineResponse{}, false, err
}
if len(jwt) > 0 {
req.Header.Set("Authorization", "Bearer "+jwt)
}
req.Header.Set("x-libsql-client-version", "libsql-remote-go-"+commitHash)
if remoteEncryptionKey != "" {
req.Header.Set("x-turso-encryption-key", remoteEncryptionKey)
}
req.Host = host
resp, err := http.DefaultClient.Do(req)
if err != nil {
return hrana.PipelineResponse{}, false, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return hrana.PipelineResponse{}, false, err
}
if resp.StatusCode != http.StatusOK {
// We need to remember that the stream is closed so we don't try to send any more requests using this connection.
var serverError struct {
Error string `json:"error"`
}
if err := json.Unmarshal(body, &serverError); err == nil {
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %d: %s", resp.StatusCode, serverError.Error)
}
var errResponse hrana.Error
if err := json.Unmarshal(body, &errResponse); err == nil {
if errResponse.Code != nil {
if *errResponse.Code == "STREAM_EXPIRED" {
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %s: %s\n%w", *errResponse.Code, errResponse.Message, driver.ErrBadConn)
} else {
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %s: %s", *errResponse.Code, errResponse.Message)
}
}
return hrana.PipelineResponse{}, true, errors.New(errResponse.Message)
}
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %d: %s", resp.StatusCode, string(body))
}
if err = json.Unmarshal(body, &result); err != nil {
return hrana.PipelineResponse{}, false, err
}
return result, false, nil
}
func (h *hranaV2Conn) executeMsg(ctx context.Context, msg *hrana.PipelineRequest) (*hrana.PipelineResponse, error) {
result, err := h.sendPipelineRequest(ctx, msg, false)
if err != nil {
return nil, err
}
for _, r := range result.Results {
if r.Error != nil {
return nil, errors.New(r.Error.Message)
}
if r.Response == nil {
return nil, errors.New("no response received")
}
}
return result, nil
}
type chunker struct {
chunk []string
iterator *sqliteparserutils.StatementIterator
limit int
}
func newChunker(iterator *sqliteparserutils.StatementIterator, limit int) *chunker {
return &chunker{iterator: iterator, chunk: make([]string, 0, limit), limit: limit}
}
func isTransactionStatement(stmt string) bool {
patterns := [][]byte{[]byte("begin"), []byte("commit"), []byte("end"), []byte("rollback")}
for _, p := range patterns {
if len(stmt) >= len(p) && bytes.Equal(bytes.ToLower([]byte(stmt[0:len(p)])), p) {
return true
}
}
return false
}
func (c *chunker) Next() (chunk []string, isEOF bool) {
c.chunk = c.chunk[:0]
var stmt string
for !isEOF && len(c.chunk) < c.limit {
stmt, _, isEOF = c.iterator.Next()
// We need to skip transaction statements. Chunks run in a transaction by default.
if stmt != "" && !isTransactionStatement(stmt) {
c.chunk = append(c.chunk, stmt)
}
}
return c.chunk, isEOF
}
func (h *hranaV2Conn) executeSingleStmt(ctx context.Context, stmt string, wantRows bool) (*hrana.PipelineResponse, error) {
msg := &hrana.PipelineRequest{}
executeStream, err := hrana.ExecuteStream(stmt, nil, wantRows)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", stmt, err)
}
msg.Add(*executeStream)
res, err := h.executeMsg(ctx, msg)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", stmt, err)
}
return res, nil
}
func (h *hranaV2Conn) executeInChunks(ctx context.Context, query string, wantRows bool) (*hrana.PipelineResponse, error) {
const chunkSize = 4096
iterator := sqliteparserutils.CreateStatementIterator(query)
chunker := newChunker(iterator, chunkSize)
chunk, isEOF := chunker.Next()
if isEOF && len(chunk) == 1 {
return h.executeSingleStmt(ctx, chunk[0], wantRows)
}
_, err := h.executeSingleStmt(ctx, "BEGIN", false)
if err != nil {
return nil, err
}
batch := &hrana.Batch{Steps: make([]hrana.BatchStep, chunkSize)}
msg := &hrana.PipelineRequest{}
msg.Add(hrana.StreamRequest{Type: "batch", Batch: batch})
for idx := range batch.Steps {
batch.Steps[idx].Stmt.WantRows = wantRows
}
result := &hrana.PipelineResponse{}
for {
for idx := range chunk {
batch.Steps[idx].Stmt.Sql = &chunk[idx]
}
if len(chunk) < chunkSize {
// We can trim batch.Steps because this is the last chunk anyway.
// isEOF has to be true at this point.
batch.Steps = batch.Steps[:len(chunk)]
}
res, err := h.executeMsg(ctx, msg)
if err != nil {
h.closeStream()
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
result.Baton = res.Baton
result.BaseUrl = res.BaseUrl
result.Results = append(result.Results, res.Results...)
if isEOF {
break
}
chunk, isEOF = chunker.Next()
}
_, err = h.executeSingleStmt(ctx, "COMMIT", false)
if err != nil {
h.closeStream()
return nil, err
}
return result, nil
}
func (h *hranaV2Conn) executeStmt(ctx context.Context, query string, args []driver.NamedValue, wantRows bool) (*hrana.PipelineResponse, error) {
const querySizeLimitForChunking = 20 * 1024 * 1024
if len(args) == 0 && len(query) > querySizeLimitForChunking && !h.schemaDb {
return h.executeInChunks(ctx, query, wantRows)
}
stmts, params, err := shared.ParseStatementAndArgs(query, args)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
msg := &hrana.PipelineRequest{}
if len(stmts) == 1 {
var p *shared.Params
if len(params) > 0 {
p = &params[0]
}
executeStream, err := hrana.ExecuteStream(stmts[0], p, wantRows)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
msg.Add(*executeStream)
} else {
batchStream, err := hrana.BatchStream(stmts, params, wantRows, !h.schemaDb)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
msg.Add(*batchStream)
}
resp, err := h.executeMsg(ctx, msg)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
return resp, nil
}
func (h *hranaV2Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
result, err := h.executeStmt(ctx, query, args, false)
if err != nil {
return nil, err
}
switch result.Results[0].Response.Type {
case "execute":
res, err := result.Results[0].Response.ExecuteResult()
if err != nil {
return nil, err
}
return shared.NewResult(res.GetLastInsertRowId(), int64(res.AffectedRowCount)), nil
case "batch":
res, err := result.Results[0].Response.BatchResult()
if err != nil {
return nil, err
}
lastInsertRowId := int64(0)
affectedRowCount := int64(0)
upperBound := len(res.StepResults)
if !h.schemaDb {
upperBound -= 1
}
for idx := 0; idx < upperBound; idx++ {
r := res.StepResults[idx]
rowId := r.GetLastInsertRowId()
if rowId > 0 {
lastInsertRowId = rowId
}
affectedRowCount += int64(r.AffectedRowCount)
}
return shared.NewResult(lastInsertRowId, affectedRowCount), nil
default:
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, "unknown response type")
}
}
type StmtResultRowsProvider struct {
r *hrana.StmtResult
}
func (p *StmtResultRowsProvider) SetsCount() int {
return 1
}
func (p *StmtResultRowsProvider) RowsCount(setIdx int) int {
if setIdx != 0 {
return 0
}
return len(p.r.Rows)
}
func (p *StmtResultRowsProvider) Columns(setIdx int) []string {
if setIdx != 0 {
return nil
}
res := make([]string, len(p.r.Cols))
for i, c := range p.r.Cols {
if c.Name != nil {
res[i] = *c.Name
}
}
return res
}
func (p *StmtResultRowsProvider) FieldValue(setIdx, rowIdx, colIdx int) driver.Value {
if setIdx != 0 {
return nil
}
return p.r.Rows[rowIdx][colIdx].ToValue(p.r.Cols[colIdx].Type)
}
func (p *StmtResultRowsProvider) Error(setIdx int) string {
return ""
}
func (p *StmtResultRowsProvider) HasResult(setIdx int) bool {
return setIdx == 0
}
type BatchResultRowsProvider struct {
r *hrana.BatchResult
}
func (p *BatchResultRowsProvider) SetsCount() int {
return len(p.r.StepResults)
}
func (p *BatchResultRowsProvider) RowsCount(setIdx int) int {
if setIdx >= len(p.r.StepResults) || p.r.StepResults[setIdx] == nil {
return 0
}
return len(p.r.StepResults[setIdx].Rows)
}
func (p *BatchResultRowsProvider) Columns(setIdx int) []string {
if setIdx >= len(p.r.StepResults) || p.r.StepResults[setIdx] == nil {
return nil
}
res := make([]string, len(p.r.StepResults[setIdx].Cols))
for i, c := range p.r.StepResults[setIdx].Cols {
if c.Name != nil {
res[i] = *c.Name
}
}
return res
}
func (p *BatchResultRowsProvider) FieldValue(setIdx, rowIdx, colIdx int) driver.Value {
if setIdx >= len(p.r.StepResults) || p.r.StepResults[setIdx] == nil {
return nil
}
return p.r.StepResults[setIdx].Rows[rowIdx][colIdx].ToValue(p.r.StepResults[setIdx].Cols[colIdx].Type)
}
func (p *BatchResultRowsProvider) Error(setIdx int) string {
if setIdx >= len(p.r.StepErrors) || p.r.StepErrors[setIdx] == nil {
return ""
}
return p.r.StepErrors[setIdx].Message
}
func (p *BatchResultRowsProvider) HasResult(setIdx int) bool {
return setIdx < len(p.r.StepResults) && p.r.StepResults[setIdx] != nil
}
func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
result, err := h.executeStmt(ctx, query, args, true)
if err != nil {
return nil, err
}
switch result.Results[0].Response.Type {
case "execute":
res, err := result.Results[0].Response.ExecuteResult()
if err != nil {
return nil, err
}
return shared.NewRows(&StmtResultRowsProvider{res}), nil
case "batch":
res, err := result.Results[0].Response.BatchResult()
if err != nil {
return nil, err
}
if !h.schemaDb {
res.StepResults = res.StepResults[:len(res.StepResults)-1]
res.StepErrors = res.StepErrors[:len(res.StepErrors)-1]
}
return shared.NewRows(&BatchResultRowsProvider{res}), nil
default:
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, "unknown response type")
}
}
func (h *hranaV2Conn) closeStream() {
if h.baton != "" {
go func(baton, url, jwt, host, encryptionKey string) {
msg := hrana.PipelineRequest{Baton: baton}
msg.Add(hrana.CloseStream())
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
h.baton = ""
}
}
func (h *hranaV2Conn) ResetSession(ctx context.Context) error {
h.closeStream()
return nil
}

View File

@@ -0,0 +1,18 @@
package shared
type result struct {
id int64
changes int64
}
func NewResult(id, changes int64) *result {
return &result{id: id, changes: changes}
}
func (r *result) LastInsertId() (int64, error) {
return r.id, nil
}
func (r *result) RowsAffected() (int64, error) {
return r.changes, nil
}

View File

@@ -0,0 +1,69 @@
package shared
import (
"database/sql/driver"
"fmt"
"io"
)
type rowsProvider interface {
SetsCount() int
RowsCount(setIdx int) int
Columns(setIdx int) []string
FieldValue(setIdx, rowIdx int, columnIdx int) driver.Value
Error(setIdx int) string
HasResult(setIdx int) bool
}
func NewRows(result rowsProvider) driver.Rows {
return &rows{result: result}
}
type rows struct {
result rowsProvider
currentResultSetIndex int
currentRowIdx int
}
func (r *rows) Columns() []string {
return r.result.Columns(r.currentResultSetIndex)
}
func (r *rows) Close() error {
return nil
}
func (r *rows) Next(dest []driver.Value) error {
if r.currentRowIdx == r.result.RowsCount(r.currentResultSetIndex) {
return io.EOF
}
count := len(r.result.Columns(r.currentResultSetIndex))
for idx := 0; idx < count; idx++ {
dest[idx] = r.result.FieldValue(r.currentResultSetIndex, r.currentRowIdx, idx)
}
r.currentRowIdx++
return nil
}
func (r *rows) HasNextResultSet() bool {
return r.currentResultSetIndex < r.result.SetsCount()-1
}
func (r *rows) NextResultSet() error {
if !r.HasNextResultSet() {
return io.EOF
}
r.currentResultSetIndex++
r.currentRowIdx = 0
errStr := r.result.Error(r.currentResultSetIndex)
if errStr != "" {
return fmt.Errorf("failed to execute statement\n%s", errStr)
}
if !r.result.HasResult(r.currentResultSetIndex) {
return fmt.Errorf("no results for statement")
}
return nil
}

View File

@@ -0,0 +1,257 @@
package shared
import (
"database/sql/driver"
"encoding/json"
"fmt"
"regexp"
"sort"
"github.com/antlr4-go/antlr/v4"
"github.com/tursodatabase/libsql-client-go/sqliteparser"
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
)
type ParamsInfo struct {
NamedParameters []string
PositionalParametersCount int
}
func ParseStatement(sql string) ([]string, []ParamsInfo, error) {
stmts, _ := sqliteparserutils.SplitStatement(sql)
stmtsParams := make([]ParamsInfo, len(stmts))
for idx, stmt := range stmts {
nameParams, positionalParamsCount, err := extractParameters(stmt)
if err != nil {
return nil, nil, err
}
stmtsParams[idx] = ParamsInfo{nameParams, positionalParamsCount}
}
return stmts, stmtsParams, nil
}
func ParseStatementAndArgs(sql string, args []driver.NamedValue) ([]string, []Params, error) {
stmts, _ := sqliteparserutils.SplitStatement(sql)
if len(args) == 0 {
return stmts, nil, nil
}
parameters, err := ConvertArgs(args)
if err != nil {
return nil, nil, err
}
stmtsParams := make([]Params, len(stmts))
totalParametersAlreadyUsed := 0
for idx, stmt := range stmts {
stmtParams, err := generateStatementParameters(stmt, parameters, totalParametersAlreadyUsed)
if err != nil {
return nil, nil, fmt.Errorf("fail to generate statement parameter. statement: %s. error: %v", stmt, err)
}
stmtsParams[idx] = stmtParams
totalParametersAlreadyUsed += stmtParams.Len()
}
return stmts, stmtsParams, nil
}
type paramsType int
const (
namedParameters paramsType = iota
positionalParameters
)
type Params struct {
positional []any
named map[string]any
}
func (p *Params) MarshalJSON() ([]byte, error) {
if len(p.named) > 0 {
return json.Marshal(p.named)
}
if len(p.positional) > 0 {
return json.Marshal(p.positional)
}
return json.Marshal(make([]any, 0))
}
func (p *Params) Named() map[string]any {
return p.named
}
func (p *Params) Positional() []any {
return p.positional
}
func (p *Params) Len() int {
if p.named != nil {
return len(p.named)
}
return len(p.positional)
}
func (p *Params) Type() paramsType {
if p.named != nil {
return namedParameters
}
return positionalParameters
}
func NewParams(t paramsType) Params {
p := Params{}
switch t {
case namedParameters:
p.named = make(map[string]any)
case positionalParameters:
p.positional = make([]any, 0)
}
return p
}
func getParamType(arg *driver.NamedValue) paramsType {
if arg.Name == "" {
return positionalParameters
}
return namedParameters
}
func ConvertArgs(args []driver.NamedValue) (Params, error) {
if len(args) == 0 {
return NewParams(positionalParameters), nil
}
var sortedArgs []*driver.NamedValue
for idx := range args {
sortedArgs = append(sortedArgs, &args[idx])
}
sort.Slice(sortedArgs, func(i, j int) bool {
return sortedArgs[i].Ordinal < sortedArgs[j].Ordinal
})
parametersType := getParamType(sortedArgs[0])
parameters := NewParams(parametersType)
for _, arg := range sortedArgs {
if parametersType != getParamType(arg) {
return Params{}, fmt.Errorf("driver does not accept positional and named parameters at the same time")
}
switch parametersType {
case positionalParameters:
parameters.positional = append(parameters.positional, arg.Value)
case namedParameters:
parameters.named[arg.Name] = arg.Value
}
}
return parameters, nil
}
func isExplain(stmt string) bool {
statementStream := antlr.NewInputStream(stmt)
lexer := sqliteparser.NewSQLiteLexer(statementStream)
tokenStream := antlr.NewCommonTokenStream(lexer, 0)
firstToken := tokenStream.LT(1)
return firstToken.GetTokenType() == sqliteparser.SQLiteParserEXPLAIN_
}
func generateStatementParameters(stmt string, queryParams Params, positionalParametersOffset int) (Params, error) {
nameParams, positionalParamsCount, err := extractParameters(stmt)
if err != nil {
return Params{}, err
}
stmtParams := NewParams(queryParams.Type())
switch queryParams.Type() {
case positionalParameters:
if positionalParametersOffset+positionalParamsCount > len(queryParams.positional) {
if isExplain(stmt) {
return Params{}, nil
}
// Positional parameters with indexes most of the time will have fewer args than parameters.
stmtParams.positional = queryParams.positional[positionalParametersOffset:len(queryParams.positional)]
} else {
stmtParams.positional = queryParams.positional[positionalParametersOffset : positionalParametersOffset+positionalParamsCount]
}
case namedParameters:
stmtParametersNeeded := make(map[string]bool)
for _, stmtParametersName := range nameParams {
stmtParametersNeeded[stmtParametersName] = true
}
for queryParamsName, queryParamsValue := range queryParams.named {
if stmtParametersNeeded[queryParamsName] {
stmtParams.named[queryParamsName] = queryParamsValue
delete(stmtParametersNeeded, queryParamsName)
}
}
}
return stmtParams, nil
}
func extractParameters(stmt string) (nameParams []string, positionalParamsCount int, err error) {
statementStream := antlr.NewInputStream(stmt)
sqliteparser.NewSQLiteLexer(statementStream)
lexer := sqliteparser.NewSQLiteLexer(statementStream)
allTokens := lexer.GetAllTokens()
nameParamsSet := make(map[string]bool)
positionalParamsWithIndexesSet := make(map[string]bool)
// ^: asserts the start of the string.
// \?: matches a literal question mark character.
// (\d+)? captures one more digits (0-9) in a group, but group is optional due to the ? quantifier.
// $: asserts the end of thr string so as to avoid this scenario: ?123ABC.
re := regexp.MustCompile(`^\?(\d+)?$`)
for _, token := range allTokens {
tokenType := token.GetTokenType()
if tokenType == sqliteparser.SQLiteLexerBIND_PARAMETER {
parameter := token.GetText()
match := re.FindStringSubmatch(parameter)
if match == nil {
paramWithoutPrefix, err := removeParamPrefix(parameter)
if err != nil {
return []string{}, 0, err
}
nameParamsSet[paramWithoutPrefix] = true
continue
}
posS := string(match[1])
if posS == "" {
// When an empty string, it means the parameter is a
// positional parameter without an index (e.g, ?).
positionalParamsCount++
} else {
// Positional parameter with indexes (e.g., ?<number>)
// must be deduped.
positionalParamsWithIndexesSet[posS] = true
}
}
}
nameParams = make([]string, 0, len(nameParamsSet))
for k := range nameParamsSet {
nameParams = append(nameParams, k)
}
// Only count unique number of positional parameters.
positionalParamsCount += len(positionalParamsWithIndexesSet)
return nameParams, positionalParamsCount, nil
}
func removeParamPrefix(paramName string) (string, error) {
if paramName[0] == ':' || paramName[0] == '@' || paramName[0] == '$' {
return paramName[1:], nil
}
return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
}

View File

@@ -0,0 +1,89 @@
package shared
import (
"reflect"
"sort"
"testing"
)
func TestExtractParameters(t *testing.T) {
tests := []struct {
name string
value string
nameParams []string
positionalParamsCount int
err error
}{
{
name: "OnlyColonNameParams",
value: "select :column from :table",
nameParams: []string{"column", "table"},
},
{
name: "OnlyAtNameParams",
value: "select @column from @table",
nameParams: []string{"column", "table"},
},
{
name: "OnlyDollarSignNameParams",
value: "select $column from $table",
nameParams: []string{"column", "table"},
},
{
name: "RepeatedNamedParameter",
value: "select :number, :number",
nameParams: []string{"number"},
},
{
name: "OnlyPositionalParams",
value: "select ? from ?",
nameParams: []string{},
positionalParamsCount: 2,
},
{
name: "OnlyPositionalParamsWithoutIndexes",
value: "select ? from ?",
nameParams: []string{},
positionalParamsCount: 2,
},
{
name: "OnlyPositionalParamsWithIndexes (dedup)",
value: "select ?1 from ?1",
nameParams: []string{},
positionalParamsCount: 1,
},
{
name: "PositionalParamsWithIndexes",
value: "select ? from ?1",
nameParams: []string{},
positionalParamsCount: 2,
},
{
name: "MixedParams",
value: "select :column1, @column2, $column3, ? from ?",
nameParams: []string{"column1", "column2", "column3"},
positionalParamsCount: 2,
},
{
name: "NoParams",
value: "select myColumn from myTable",
nameParams: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNameParams, gotUniquePositionalParamsCount, gotErr := extractParameters(tt.value)
sort.Strings(gotNameParams)
sort.Strings(tt.nameParams)
if !reflect.DeepEqual(gotNameParams, tt.nameParams) {
t.Errorf("got nameParams %#v, want %#v", gotNameParams, tt.nameParams)
}
if !reflect.DeepEqual(gotUniquePositionalParamsCount, tt.positionalParamsCount) {
t.Errorf("got positionalParams %#v, want %#v", gotUniquePositionalParamsCount, tt.positionalParamsCount)
}
if !reflect.DeepEqual(gotUniquePositionalParamsCount, tt.positionalParamsCount) {
t.Errorf("got err %v, want %v", gotErr, tt.err)
}
})
}
}

View File

@@ -0,0 +1,194 @@
package ws
import (
"context"
"database/sql/driver"
"io"
"sort"
)
type result struct {
id int64
changes int64
}
func (r *result) LastInsertId() (int64, error) {
return r.id, nil
}
func (r *result) RowsAffected() (int64, error) {
return r.changes, nil
}
type rows struct {
res *execResponse
currentRowIdx int
}
func (r *rows) Columns() []string {
return r.res.columns()
}
func (r *rows) Close() error {
return nil
}
func (r *rows) Next(dest []driver.Value) error {
if r.currentRowIdx == r.res.rowsCount() {
return io.EOF
}
count := r.res.rowLen(r.currentRowIdx)
for idx := 0; idx < count; idx++ {
v, err := r.res.value(r.currentRowIdx, idx)
if err != nil {
return err
}
dest[idx] = v
}
r.currentRowIdx++
return nil
}
type conn struct {
ws *websocketConn
}
func Connect(url string, jwt string) (*conn, error) {
c, err := connect(url, jwt)
if err != nil {
return nil, err
}
return &conn{c}, nil
}
type stmt struct {
c *conn
query string
}
func (s stmt) Close() error {
return nil
}
func (s stmt) NumInput() int {
return -1
}
func convertToNamed(args []driver.Value) []driver.NamedValue {
if len(args) == 0 {
return nil
}
result := []driver.NamedValue{}
for idx := range args {
result = append(result, driver.NamedValue{Ordinal: idx, Value: args[idx]})
}
return result
}
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), convertToNamed(args))
}
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), convertToNamed(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return s.c.ExecContext(ctx, s.query, args)
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return s.c.QueryContext(ctx, s.query, args)
}
func (c *conn) Ping() error {
return c.PingContext(context.Background())
}
func (c *conn) PingContext(ctx context.Context) error {
_, err := c.ws.exec(ctx, "SELECT 1", params{}, false)
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return stmt{c, query}, nil
}
func (c *conn) Close() error {
return c.ws.Close()
}
type tx struct {
c *conn
}
func (t tx) Commit() error {
_, err := t.c.ExecContext(context.Background(), "COMMIT", nil)
if err != nil {
return err
}
return nil
}
func (t tx) Rollback() error {
_, err := t.c.ExecContext(context.Background(), "ROLLBACK", nil)
if err != nil {
return err
}
return nil
}
func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c *conn) BeginTx(ctx context.Context, _ driver.TxOptions) (driver.Tx, error) {
_, err := c.ExecContext(ctx, "BEGIN", nil)
if err != nil {
return tx{nil}, err
}
return tx{c}, nil
}
func convertArgs(args []driver.NamedValue) params {
if len(args) == 0 {
return params{}
}
positionalArgs := [](*driver.NamedValue){}
namedArgs := []namedParam{}
for idx := range args {
if len(args[idx].Name) > 0 {
namedArgs = append(namedArgs, namedParam{args[idx].Name, args[idx].Value})
} else {
positionalArgs = append(positionalArgs, &args[idx])
}
}
sort.Slice(positionalArgs, func(i, j int) bool {
return positionalArgs[i].Ordinal < positionalArgs[j].Ordinal
})
posArgs := [](any){}
for idx := range positionalArgs {
posArgs = append(posArgs, positionalArgs[idx].Value)
}
return params{PositinalArgs: posArgs, NamedArgs: namedArgs}
}
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
res, err := c.ws.exec(ctx, query, convertArgs(args), false)
if err != nil {
return nil, err
}
return &result{res.lastInsertId(), res.affectedRowCount()}, nil
}
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
res, err := c.ws.exec(ctx, query, convertArgs(args), true)
if err != nil {
return nil, err
}
return &rows{res, 0}, nil
}

View File

@@ -0,0 +1,322 @@
package ws
import (
"context"
"database/sql/driver"
"encoding/base64"
"fmt"
"strconv"
"sync"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// defaultWSTimeout specifies the timeout used for initial http connection
var defaultWSTimeout = 120 * time.Second
func errorMsg(errorResp interface{}) string {
return errorResp.(map[string]interface{})["error"].(map[string]interface{})["message"].(string)
}
func isErrorResp(resp interface{}) bool {
return resp.(map[string]interface{})["type"] == "response_error"
}
type websocketConn struct {
conn *websocket.Conn
idPool *idPool
}
type namedParam struct {
Name string
Value any
}
type params struct {
PositinalArgs []any
NamedArgs []namedParam
}
func convertValue(v any) (map[string]interface{}, error) {
res := map[string]interface{}{}
if v == nil {
res["type"] = "null"
} else if integer, ok := v.(int64); ok {
res["type"] = "integer"
res["value"] = strconv.FormatInt(integer, 10)
} else if text, ok := v.(string); ok {
res["type"] = "text"
res["value"] = text
} else if blob, ok := v.([]byte); ok {
res["type"] = "blob"
res["base64"] = base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob)
} else if float, ok := v.(float64); ok {
res["type"] = "float"
res["value"] = float
} else if boolean, ok := v.(bool); ok {
res["type"] = "integer"
if boolean {
res["value"] = "1"
} else {
res["value"] = "0"
}
} else {
return nil, fmt.Errorf("unsupported value type: %s", v)
}
return res, nil
}
type execResponse struct {
resp map[string]interface{}
}
func (r *execResponse) affectedRowCount() int64 {
return int64(r.resp["affected_row_count"].(float64))
}
func (r *execResponse) lastInsertId() int64 {
id, ok := r.resp["last_insert_rowid"].(string)
if !ok {
return 0
}
value, _ := strconv.ParseInt(id, 10, 64)
return value
}
func (r *execResponse) columns() []string {
res := []string{}
cols := r.resp["cols"].([]interface{})
for idx := range cols {
var v string = ""
if cols[idx].(map[string]interface{})["name"] != nil {
v = cols[idx].(map[string]interface{})["name"].(string)
}
res = append(res, v)
}
return res
}
func (r *execResponse) rowsCount() int {
return len(r.resp["rows"].([]interface{}))
}
func (r *execResponse) rowLen(rowIdx int) int {
return len(r.resp["rows"].([]interface{})[rowIdx].([]interface{}))
}
func (r *execResponse) value(rowIdx int, colIdx int) (any, error) {
val := r.resp["rows"].([]interface{})[rowIdx].([]interface{})[colIdx].(map[string]interface{})
switch val["type"] {
case "null":
return nil, nil
case "integer":
v, err := strconv.ParseInt(val["value"].(string), 10, 64)
if err != nil {
return nil, err
}
return v, nil
case "text":
return val["value"].(string), nil
case "blob":
base64Encoded := val["base64"].(string)
v, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(base64Encoded)
if err != nil {
return nil, err
}
return v, nil
case "float":
return val["value"].(float64), nil
}
return nil, fmt.Errorf("unrecognized value type: %s", val["type"])
}
func (ws *websocketConn) exec(ctx context.Context, sql string, sqlParams params, wantRows bool) (*execResponse, error) {
requestId := ws.idPool.Get()
defer ws.idPool.Put(requestId)
stmt := map[string]interface{}{
"sql": sql,
"want_rows": wantRows,
}
if len(sqlParams.PositinalArgs) > 0 {
args := []map[string]interface{}{}
for idx := range sqlParams.PositinalArgs {
v, err := convertValue(sqlParams.PositinalArgs[idx])
if err != nil {
return nil, err
}
args = append(args, v)
}
stmt["args"] = args
}
if len(sqlParams.NamedArgs) > 0 {
args := []map[string]interface{}{}
for idx := range sqlParams.NamedArgs {
v, err := convertValue(sqlParams.NamedArgs[idx].Value)
if err != nil {
return nil, err
}
arg := map[string]interface{}{
"name": sqlParams.NamedArgs[idx].Name,
"value": v,
}
args = append(args, arg)
}
stmt["named_args"] = args
}
err := wsjson.Write(ctx, ws.conn, map[string]interface{}{
"type": "request",
"request_id": requestId,
"request": map[string]interface{}{
"type": "execute",
"stream_id": 0,
"stmt": stmt,
},
})
if err != nil {
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
}
var resp interface{}
if err = wsjson.Read(ctx, ws.conn, &resp); err != nil {
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
}
if isErrorResp(resp) {
err = fmt.Errorf("unable to execute %s: %s", sql, errorMsg(resp))
return nil, err
}
return &execResponse{resp.(map[string]interface{})["response"].(map[string]interface{})["result"].(map[string]interface{})}, nil
}
func (ws *websocketConn) Close() error {
return ws.conn.Close(websocket.StatusNormalClosure, "All's good")
}
func connect(url string, jwt string) (*websocketConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultWSTimeout)
defer cancel()
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
Subprotocols: []string{"hrana1"},
})
if err != nil {
return nil, err
}
c.SetReadLimit(1024 * 1024 * 16) // 16MB
err = wsjson.Write(ctx, c, map[string]interface{}{
"type": "hello",
"jwt": jwt,
})
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
err = wsjson.Write(ctx, c, map[string]interface{}{
"type": "request",
"request_id": 0,
"request": map[string]interface{}{
"type": "open_stream",
"stream_id": 0,
},
})
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
var helloResp interface{}
err = wsjson.Read(ctx, c, &helloResp)
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
if helloResp.(map[string]interface{})["type"] == "hello_error" {
err = fmt.Errorf("handshake error: %s", errorMsg(helloResp))
c.Close(websocket.StatusProtocolError, err.Error())
return nil, err
}
var openStreamResp interface{}
err = wsjson.Read(ctx, c, &openStreamResp)
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
if isErrorResp(openStreamResp) {
err = fmt.Errorf("unable to open stream: %s", errorMsg(helloResp))
c.Close(websocket.StatusProtocolError, err.Error())
return nil, err
}
return &websocketConn{c, newIDPool()}, nil
}
// Below is modified IDPool from "vitess.io/vitess/go/pools"
// idPool is used to ensure that the set of IDs in use concurrently never
// contains any duplicates. The IDs start at 1 and increase without bound, but
// will never be larger than the peak number of concurrent uses.
//
// idPool's Get() and Put() methods can be used concurrently.
type idPool struct {
sync.Mutex
// used holds the set of values that have been returned to us with Put().
used map[uint32]bool
// maxUsed remembers the largest value we've given out.
maxUsed uint32
}
// NewIDPool creates and initializes an idPool.
func newIDPool() *idPool {
return &idPool{
used: make(map[uint32]bool),
maxUsed: 0,
}
}
// Get returns an ID that is unique among currently active users of this pool.
func (pool *idPool) Get() (id uint32) {
pool.Lock()
defer pool.Unlock()
// Pick a value that's been returned, if any.
for key := range pool.used {
delete(pool.used, key)
return key
}
// No recycled IDs are available, so increase the pool size.
pool.maxUsed++
return pool.maxUsed
}
// Put recycles an ID back into the pool for others to use. Putting back a value
// or 0, or a value that is not currently "checked out", will result in a panic
// because that should never happen except in the case of a programming error.
func (pool *idPool) Put(id uint32) {
pool.Lock()
defer pool.Unlock()
if id < 1 || id > pool.maxUsed {
panic(fmt.Errorf("idPool.Put(%v): invalid value, must be in the range [1,%v]", id, pool.maxUsed))
}
if pool.used[id] {
panic(fmt.Errorf("idPool.Put(%v): can't put value that was already recycled", id))
}
// If we're recycling maxUsed, just shrink the pool.
if id == pool.maxUsed {
pool.maxUsed = id - 1
return
}
// Add it to the set of recycled IDs.
pool.used[id] = true
}

View File

@@ -0,0 +1,136 @@
package ws
import (
"fmt"
"reflect"
"testing"
)
func TestConvertValue(t *testing.T) {
tests := []struct {
name string
value any
want map[string]any
err error
}{
{
name: "nil",
value: nil,
want: map[string]any{
"type": "null",
},
err: nil,
},
{
name: "integer",
value: int64(42),
want: map[string]any{
"type": "integer",
"value": "42",
},
err: nil,
},
{
name: "text",
value: "turso for win",
want: map[string]any{
"type": "text",
"value": "turso for win",
},
err: nil,
},
{
name: "blob",
value: []byte("hello world"),
want: map[string]any{
"type": "blob",
// `hello world` encoded is `aGVsbG8gd29ybGQ=` but we want without padding
"base64": "aGVsbG8gd29ybGQ",
},
err: nil,
},
{
name: "float",
value: 3.14,
want: map[string]any{
"type": "float",
"value": 3.14,
},
err: nil,
},
{
name: "boolean_true",
value: true,
want: map[string]any{
"type": "integer",
"value": "1",
},
err: nil,
},
{
name: "boolean_false",
value: false,
want: map[string]any{
"type": "integer",
"value": "0",
},
err: nil,
},
{
name: "unsupported",
value: struct{}{},
want: nil,
err: fmt.Errorf("unsupported value type: %s", struct{}{}),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertValue(tt.value)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v, want %v", got, tt.want)
}
if !reflect.DeepEqual(err, tt.err) {
t.Errorf("got error %v, want %v", err, tt.err)
}
})
}
}
func Test_execResponse_lastInsertId(t *testing.T) {
tests := []struct {
name string
value map[string]interface{}
want int64
}{
{
name: "valid",
value: map[string]interface{}{"last_insert_rowid": "42"},
want: 42,
},
{
name: "empty",
value: map[string]interface{}{},
want: 0,
},
{
name: "invalid",
value: map[string]interface{}{"last_insert_rowid": "invalid"},
want: 0,
},
{
name: "invalid_type",
value: map[string]interface{}{"last_insert_rowid": 42.0},
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &execResponse{
resp: tt.value,
}
if got := r.lastInsertId(); got != tt.want {
t.Errorf("lastInsertId() = %v, want %v", got, tt.want)
}
})
}
}