fix: 将 libsql-client-go 作为普通目录提交,而非 gitlink
This commit is contained in:
22
go_modules/libsql-client-go/libsql/internal/hrana/batch.go
Normal file
22
go_modules/libsql-client-go/libsql/internal/hrana/batch.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package hrana
|
||||
|
||||
type Batch struct {
|
||||
Steps []BatchStep `json:"steps"`
|
||||
ReplicationIndex *uint64 `json:"replication_index,omitempty"`
|
||||
}
|
||||
|
||||
type BatchStep struct {
|
||||
Stmt Stmt `json:"stmt"`
|
||||
Condition *BatchCondition `json:"condition,omitempty"`
|
||||
}
|
||||
|
||||
type BatchCondition struct {
|
||||
Type string `json:"type"`
|
||||
Step *int32 `json:"step,omitempty"`
|
||||
Cond *BatchCondition `json:"cond,omitempty"`
|
||||
Conds []BatchCondition `json:"conds,omitempty"`
|
||||
}
|
||||
|
||||
func (b *Batch) Add(stmt Stmt, condition *BatchCondition) {
|
||||
b.Steps = append(b.Steps, BatchStep{Stmt: stmt, Condition: condition})
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type BatchResult struct {
|
||||
StepResults []*StmtResult `json:"step_results"`
|
||||
StepErrors []*Error `json:"step_errors"`
|
||||
ReplicationIndex *uint64 `json:"replication_index"`
|
||||
}
|
||||
|
||||
func (b *BatchResult) UnmarshalJSON(data []byte) error {
|
||||
type Alias BatchResult
|
||||
aux := &struct {
|
||||
ReplicationIndex interface{} `json:"replication_index,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(b),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if aux.ReplicationIndex == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := aux.ReplicationIndex.(type) {
|
||||
case float64:
|
||||
repIndex := uint64(v)
|
||||
b.ReplicationIndex = &repIndex
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
repIndex, err := strconv.ParseUint(v, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.ReplicationIndex = &repIndex
|
||||
default:
|
||||
return fmt.Errorf("invalid type for replication index: %T", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBatchResult_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
jsonData []byte
|
||||
expected *uint64
|
||||
}{
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":1}`),
|
||||
expected: uint64Ptr(1),
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":"1"}`),
|
||||
expected: uint64Ptr(1),
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":""}`),
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{}`),
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":"0"}`),
|
||||
expected: uint64Ptr(0),
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":0}`),
|
||||
expected: uint64Ptr(0),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||
batchResult := &BatchResult{}
|
||||
err := json.Unmarshal(tc.jsonData, batchResult)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(batchResult.ReplicationIndex, tc.expected) {
|
||||
t.Errorf("ReplicationIndex field is not correctly unmarshaled got = %v, want = %v", batchResult.ReplicationIndex, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func uint64Ptr(n uint64) *uint64 {
|
||||
return &n
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package hrana
|
||||
|
||||
type PipelineRequest struct {
|
||||
Baton string `json:"baton,omitempty"`
|
||||
Requests []StreamRequest `json:"requests"`
|
||||
}
|
||||
|
||||
func (pr *PipelineRequest) Add(request StreamRequest) {
|
||||
pr.Requests = append(pr.Requests, request)
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package hrana
|
||||
|
||||
type PipelineResponse struct {
|
||||
Baton string `json:"baton,omitempty"`
|
||||
BaseUrl string `json:"base_url,omitempty"`
|
||||
Results []StreamResult `json:"results"`
|
||||
}
|
||||
58
go_modules/libsql-client-go/libsql/internal/hrana/stmt.go
Normal file
58
go_modules/libsql-client-go/libsql/internal/hrana/stmt.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
Sql *string `json:"sql,omitempty"`
|
||||
SqlId *int32 `json:"sql_id,omitempty"`
|
||||
Args []Value `json:"args,omitempty"`
|
||||
NamedArgs []NamedArg `json:"named_args,omitempty"`
|
||||
WantRows bool `json:"want_rows"`
|
||||
ReplicationIndex *uint64 `json:"replication_index,omitempty"`
|
||||
}
|
||||
|
||||
type NamedArg struct {
|
||||
Name string `json:"name"`
|
||||
Value Value `json:"value"`
|
||||
}
|
||||
|
||||
func (s *Stmt) AddArgs(params shared.Params) error {
|
||||
if len(params.Named()) > 0 {
|
||||
return s.AddNamedArgs(params.Named())
|
||||
} else {
|
||||
return s.AddPositionalArgs(params.Positional())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stmt) AddPositionalArgs(args []any) error {
|
||||
argValues := make([]Value, len(args))
|
||||
for idx := range args {
|
||||
var err error
|
||||
if argValues[idx], err = ToValue(args[idx]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.Args = argValues
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stmt) AddNamedArgs(args map[string]any) error {
|
||||
argValues := make([]NamedArg, len(args))
|
||||
idx := 0
|
||||
for key, value := range args {
|
||||
var err error
|
||||
var v Value
|
||||
if v, err = ToValue(value); err != nil {
|
||||
return err
|
||||
}
|
||||
argValues[idx] = NamedArg{
|
||||
Name: key,
|
||||
Value: v,
|
||||
}
|
||||
idx++
|
||||
}
|
||||
s.NamedArgs = argValues
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Column struct {
|
||||
Name *string `json:"name"`
|
||||
Type *string `json:"decltype"`
|
||||
}
|
||||
|
||||
type StmtResult struct {
|
||||
Cols []Column `json:"cols"`
|
||||
Rows [][]Value `json:"rows"`
|
||||
AffectedRowCount int32 `json:"affected_row_count"`
|
||||
LastInsertRowId *string `json:"last_insert_rowid"`
|
||||
ReplicationIndex *uint64 `json:"replication_index"`
|
||||
}
|
||||
|
||||
func (r *StmtResult) GetLastInsertRowId() int64 {
|
||||
if r.LastInsertRowId != nil {
|
||||
if integer, err := strconv.ParseInt(*r.LastInsertRowId, 10, 64); err == nil {
|
||||
return integer
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *StmtResult) UnmarshalJSON(data []byte) error {
|
||||
type Alias StmtResult
|
||||
aux := &struct {
|
||||
ReplicationIndex interface{} `json:"replication_index,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(r),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if aux.ReplicationIndex == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := aux.ReplicationIndex.(type) {
|
||||
case float64:
|
||||
repIndex := uint64(v)
|
||||
r.ReplicationIndex = &repIndex
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
repIndex, err := strconv.ParseUint(v, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.ReplicationIndex = &repIndex
|
||||
default:
|
||||
return fmt.Errorf("invalid type for replication index: %T", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStmtResult_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
jsonData []byte
|
||||
expected *uint64
|
||||
}{
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":1}`),
|
||||
expected: uint64Ptr(1),
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":"1"}`),
|
||||
expected: uint64Ptr(1),
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":""}`),
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{}`),
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":"0"}`),
|
||||
expected: uint64Ptr(0),
|
||||
},
|
||||
{
|
||||
jsonData: []byte(`{"replication_index":0}`),
|
||||
expected: uint64Ptr(0),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||
stmtResult := &StmtResult{}
|
||||
err := json.Unmarshal(tc.jsonData, stmtResult)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(stmtResult.ReplicationIndex, tc.expected) {
|
||||
t.Errorf("ReplicationIndex field is not correctly unmarshaled got = %v, want = %v", stmtResult.ReplicationIndex, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStmtWithPositionalArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []any
|
||||
want []Value
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "int args",
|
||||
args: []any{1, 2},
|
||||
want: []Value{{Type: "integer", Value: "1"}, {Type: "integer", Value: "2"}},
|
||||
},
|
||||
{
|
||||
name: "string args",
|
||||
args: []any{"a", "b"},
|
||||
want: []Value{{Type: "text", Value: "a"}, {Type: "text", Value: "b"}},
|
||||
},
|
||||
{
|
||||
name: "invalid arg",
|
||||
args: []any{make(chan int)},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stmt := Stmt{}
|
||||
err := stmt.AddPositionalArgs(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("StmtWithPositionalArgs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(stmt.Args, tt.want) {
|
||||
t.Errorf("got = %v, want %v", stmt.Args, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmtWithNamedArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args map[string]any
|
||||
want []NamedArg
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "int args",
|
||||
args: map[string]any{"arg1": 1, "arg2": int64(2)},
|
||||
want: []NamedArg{
|
||||
{Name: "arg1", Value: Value{Type: "integer", Value: "1"}},
|
||||
{Name: "arg2", Value: Value{Type: "integer", Value: "2"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string args",
|
||||
args: map[string]any{"arg1": "a", "arg2": "b"},
|
||||
want: []NamedArg{
|
||||
{Name: "arg1", Value: Value{Type: "text", Value: "a"}},
|
||||
{Name: "arg2", Value: Value{Type: "text", Value: "b"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid arg",
|
||||
args: map[string]any{"arg1": make(chan int)},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stmt := Stmt{}
|
||||
err := stmt.AddNamedArgs(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("StmtWithNamedArgs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
got := make(map[NamedArg]struct{})
|
||||
want := make(map[NamedArg]struct{})
|
||||
for _, arg := range stmt.NamedArgs {
|
||||
got[arg] = struct{}{}
|
||||
}
|
||||
for _, arg := range tt.want {
|
||||
want[arg] = struct{}{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got = %v, want %v", stmt.NamedArgs, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
|
||||
)
|
||||
|
||||
type StreamRequest struct {
|
||||
Type string `json:"type"`
|
||||
Stmt *Stmt `json:"stmt,omitempty"`
|
||||
Batch *Batch `json:"batch,omitempty"`
|
||||
Sql *string `json:"sql,omitempty"`
|
||||
SqlId *int32 `json:"sql_id,omitempty"`
|
||||
}
|
||||
|
||||
func CloseStream() StreamRequest {
|
||||
return StreamRequest{Type: "close"}
|
||||
}
|
||||
|
||||
func ExecuteStream(sql string, params *shared.Params, wantRows bool) (*StreamRequest, error) {
|
||||
stmt := &Stmt{
|
||||
Sql: &sql,
|
||||
WantRows: wantRows,
|
||||
}
|
||||
if params != nil {
|
||||
if err := stmt.AddArgs(*params); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &StreamRequest{Type: "execute", Stmt: stmt}, nil
|
||||
}
|
||||
|
||||
func ExecuteStoredStream(sqlId int32, params shared.Params, wantRows bool) (*StreamRequest, error) {
|
||||
stmt := &Stmt{
|
||||
SqlId: &sqlId,
|
||||
WantRows: wantRows,
|
||||
}
|
||||
if err := stmt.AddArgs(params); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StreamRequest{Type: "execute", Stmt: stmt}, nil
|
||||
}
|
||||
|
||||
func BatchStream(sqls []string, params []shared.Params, wantRows bool, transactional bool) (*StreamRequest, error) {
|
||||
size := len(sqls)
|
||||
if transactional {
|
||||
size += 1
|
||||
}
|
||||
batch := &Batch{Steps: make([]BatchStep, 0, size)}
|
||||
addArgs := len(params) > 0
|
||||
for idx, sql := range sqls {
|
||||
s := sql
|
||||
stmt := &Stmt{
|
||||
Sql: &s,
|
||||
WantRows: wantRows,
|
||||
}
|
||||
if addArgs {
|
||||
if err := stmt.AddArgs(params[idx]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var condition *BatchCondition
|
||||
if transactional {
|
||||
if idx > 0 {
|
||||
prev_idx := int32(idx - 1)
|
||||
condition = &BatchCondition{
|
||||
Type: "ok",
|
||||
Step: &prev_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
batch.Add(*stmt, condition)
|
||||
}
|
||||
if transactional {
|
||||
rollback := "ROLLBACK"
|
||||
last_idx := int32(len(sqls) - 1)
|
||||
batch.Add(Stmt{Sql: &rollback, WantRows: false},
|
||||
&BatchCondition{Type: "not", Cond: &BatchCondition{Type: "ok", Step: &last_idx}})
|
||||
}
|
||||
return &StreamRequest{Type: "batch", Batch: batch}, nil
|
||||
}
|
||||
|
||||
func StoreSqlStream(sql string, sqlId int32) StreamRequest {
|
||||
return StreamRequest{Type: "store_sql", Sql: &sql, SqlId: &sqlId}
|
||||
}
|
||||
|
||||
func CloseStoredSqlStream(sqlId int32) StreamRequest {
|
||||
return StreamRequest{Type: "close_sql", SqlId: &sqlId}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type StreamResult struct {
|
||||
Type string `json:"type"`
|
||||
Response *StreamResponse `json:"response,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type StreamResponse struct {
|
||||
Type string `json:"type"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r *StreamResponse) ExecuteResult() (*StmtResult, error) {
|
||||
if r.Type != "execute" {
|
||||
return nil, fmt.Errorf("invalid response type: %s", r.Type)
|
||||
}
|
||||
|
||||
var res StmtResult
|
||||
if err := json.Unmarshal(r.Result, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (r *StreamResponse) BatchResult() (*BatchResult, error) {
|
||||
if r.Type != "batch" {
|
||||
return nil, fmt.Errorf("invalid response type: %s", r.Type)
|
||||
}
|
||||
|
||||
var res BatchResult
|
||||
if err := json.Unmarshal(r.Result, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, e := range res.StepErrors {
|
||||
if e != nil {
|
||||
return nil, errors.New(e.Message)
|
||||
}
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Code *string `json:"code,omitempty"`
|
||||
}
|
||||
89
go_modules/libsql-client-go/libsql/internal/hrana/value.go
Normal file
89
go_modules/libsql-client-go/libsql/internal/hrana/value.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Value struct {
|
||||
Type string `json:"type"`
|
||||
Value any `json:"value,omitempty"`
|
||||
Base64 *string `json:"base64,omitempty"`
|
||||
}
|
||||
|
||||
func (v Value) ToValue(columnType *string) any {
|
||||
if v.Type == "blob" {
|
||||
if v.Base64 == nil {
|
||||
return nil
|
||||
}
|
||||
bytes, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(*v.Base64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return bytes
|
||||
} else if v.Type == "integer" {
|
||||
integer, err := strconv.ParseInt(v.Value.(string), 10, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return integer
|
||||
} else if columnType != nil {
|
||||
if (strings.ToLower(*columnType) == "timestamp" || strings.ToLower(*columnType) == "datetime") && v.Type == "text" {
|
||||
for _, format := range []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02T15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02T15:04",
|
||||
"2006-01-02",
|
||||
} {
|
||||
if t, err := time.ParseInLocation(format, v.Value.(string), time.UTC); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return v.Value
|
||||
}
|
||||
|
||||
func ToValue(v any) (Value, error) {
|
||||
var res Value
|
||||
if v == nil {
|
||||
res.Type = "null"
|
||||
} else if integer, ok := v.(int64); ok {
|
||||
res.Type = "integer"
|
||||
res.Value = strconv.FormatInt(integer, 10)
|
||||
} else if integer, ok := v.(int); ok {
|
||||
res.Type = "integer"
|
||||
res.Value = strconv.FormatInt(int64(integer), 10)
|
||||
} else if text, ok := v.(string); ok {
|
||||
res.Type = "text"
|
||||
res.Value = text
|
||||
} else if blob, ok := v.([]byte); ok {
|
||||
res.Type = "blob"
|
||||
b64 := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob)
|
||||
res.Base64 = &b64
|
||||
} else if float, ok := v.(float64); ok {
|
||||
res.Type = "float"
|
||||
res.Value = float
|
||||
} else if t, ok := v.(time.Time); ok {
|
||||
res.Type = "text"
|
||||
res.Value = t.Format("2006-01-02 15:04:05.999999999-07:00")
|
||||
} else if t, ok := v.(bool); ok {
|
||||
res.Type = "integer"
|
||||
res.Value = "0"
|
||||
if t {
|
||||
res.Value = "1"
|
||||
}
|
||||
} else {
|
||||
return res, fmt.Errorf("unsupported value type: %s", v)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
304
go_modules/libsql-client-go/libsql/internal/hrana/value_test.go
Normal file
304
go_modules/libsql-client-go/libsql/internal/hrana/value_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package hrana
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func toPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestValueToValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columnType string
|
||||
value Value
|
||||
want any
|
||||
}{
|
||||
{
|
||||
name: "null",
|
||||
value: Value{
|
||||
Type: "null",
|
||||
Value: nil,
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
value: Value{
|
||||
Type: "integer",
|
||||
Value: strconv.FormatInt(int64(42), 10),
|
||||
},
|
||||
want: int64(42),
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
value: Value{
|
||||
Type: "text",
|
||||
Value: "foo",
|
||||
},
|
||||
want: "foo",
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
value: Value{
|
||||
Type: "blob",
|
||||
Base64: toPtr("YmFy"),
|
||||
},
|
||||
want: []byte("bar"),
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
value: Value{
|
||||
Type: "blob",
|
||||
Base64: toPtr(""),
|
||||
},
|
||||
want: []byte{},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
value: Value{
|
||||
Type: "float",
|
||||
Value: 3.14,
|
||||
},
|
||||
want: 3.14,
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
columnType: "TIMESTAMP",
|
||||
value: Value{
|
||||
Type: "text",
|
||||
Value: "0001-01-01 01:00:00+00:00",
|
||||
},
|
||||
want: time.Time{}.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var columnType *string = nil
|
||||
if tt.columnType != "" {
|
||||
columnType = &tt.columnType
|
||||
}
|
||||
got := tt.value.ToValue(columnType)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ToValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value any
|
||||
want Value
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "null",
|
||||
value: nil,
|
||||
want: Value{
|
||||
Type: "null",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
value: 42,
|
||||
want: Value{
|
||||
Type: "integer",
|
||||
Value: strconv.FormatInt(int64(42), 10),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
value: "foo",
|
||||
want: Value{
|
||||
Type: "text",
|
||||
Value: "foo",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
value: []byte{},
|
||||
want: Value{
|
||||
Type: "blob",
|
||||
Base64: toPtr(""),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
value: 3.14,
|
||||
want: Value{
|
||||
Type: "float",
|
||||
Value: 3.14,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
value: true,
|
||||
want: Value{
|
||||
Type: "integer",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
value: time.Time{}.Add(time.Hour),
|
||||
want: Value{
|
||||
Type: "text",
|
||||
Value: "0001-01-01 01:00:00+00:00",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
value: make(chan int),
|
||||
want: Value{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ToValue(tt.value)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ToValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ToValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value Value
|
||||
marshaled string
|
||||
}{
|
||||
{
|
||||
name: "null",
|
||||
value: Value{
|
||||
Type: "null",
|
||||
},
|
||||
marshaled: `{"type":"null"}`,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
value: Value{
|
||||
Type: "integer",
|
||||
Value: strconv.FormatInt(int64(42), 10),
|
||||
},
|
||||
marshaled: `{"type":"integer","value":"42"}`,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
value: Value{
|
||||
Type: "text",
|
||||
Value: "foo",
|
||||
},
|
||||
marshaled: `{"type":"text","value":"foo"}`,
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
value: Value{
|
||||
Type: "blob",
|
||||
Base64: toPtr("YmFy"),
|
||||
},
|
||||
marshaled: `{"type":"blob","base64":"YmFy"}`,
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
value: Value{
|
||||
Type: "float",
|
||||
Value: 3.14,
|
||||
},
|
||||
marshaled: `{"type":"float","value":3.14}`,
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
value: Value{
|
||||
Type: "text",
|
||||
Value: time.Time{}.Add(time.Hour),
|
||||
},
|
||||
marshaled: `{"type":"text","value":"0001-01-01T01:00:00Z"}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := json.Marshal(tt.value)
|
||||
if err != nil {
|
||||
t.Errorf("json.Marshal() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(string(got), tt.marshaled) {
|
||||
t.Errorf("json.Marshal() = %v, want %v", string(got), tt.marshaled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value Value
|
||||
marshaled string
|
||||
}{
|
||||
{
|
||||
name: "null",
|
||||
value: Value{
|
||||
Type: "null",
|
||||
},
|
||||
marshaled: `{"type":"null"}`,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
value: Value{
|
||||
Type: "integer",
|
||||
Value: strconv.FormatInt(int64(42), 10),
|
||||
},
|
||||
marshaled: `{"type":"integer","value":"42"}`,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
value: Value{
|
||||
Type: "text",
|
||||
Value: "foo",
|
||||
},
|
||||
marshaled: `{"type":"text","value":"foo"}`,
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
value: Value{
|
||||
Type: "blob",
|
||||
Base64: toPtr("YmFy"),
|
||||
},
|
||||
marshaled: `{"type":"blob","base64":"YmFy"}`,
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
value: Value{
|
||||
Type: "float",
|
||||
Value: 3.14,
|
||||
},
|
||||
marshaled: `{"type":"float","value":3.14}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Value
|
||||
err := json.Unmarshal([]byte(tt.marshaled), &got)
|
||||
if err != nil {
|
||||
t.Errorf("json.Marshal() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.value) {
|
||||
t.Errorf("json.Unmarshal() = %v, want %v", got, tt.value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
11
go_modules/libsql-client-go/libsql/internal/http/driver.go
Normal file
11
go_modules/libsql-client-go/libsql/internal/http/driver.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/hranaV2"
|
||||
)
|
||||
|
||||
func Connect(url, jwt, host string, schemaDb bool, remoteEncryptionKey string) driver.Conn {
|
||||
return hranaV2.Connect(url, jwt, host, schemaDb, remoteEncryptionKey)
|
||||
}
|
||||
@@ -0,0 +1,612 @@
|
||||
package hranaV2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
net_url "net/url"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
|
||||
|
||||
"github.com/tursodatabase/libsql-client-go/libsql/internal/hrana"
|
||||
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
|
||||
)
|
||||
|
||||
var commitHash string
|
||||
|
||||
func init() {
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
for _, module := range info.Deps {
|
||||
if module.Path == "github.com/tursodatabase/libsql-client-go" {
|
||||
parts := strings.Split(module.Version, "-")
|
||||
if len(parts) == 3 {
|
||||
commitHash = parts[2][:6]
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
commitHash = "unknown"
|
||||
}
|
||||
|
||||
func Connect(url, jwt, host string, schemaDb bool, encryptionKey string) driver.Conn {
|
||||
return &hranaV2Conn{url, jwt, host, schemaDb, encryptionKey, "", false, 0}
|
||||
}
|
||||
|
||||
type hranaV2Stmt struct {
|
||||
conn *hranaV2Conn
|
||||
numInput int
|
||||
sql string
|
||||
}
|
||||
|
||||
func (s *hranaV2Stmt) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *hranaV2Stmt) NumInput() int {
|
||||
return s.numInput
|
||||
}
|
||||
|
||||
func convertToNamed(args []driver.Value) []driver.NamedValue {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
var result []driver.NamedValue
|
||||
for idx := range args {
|
||||
result = append(result, driver.NamedValue{Ordinal: idx, Value: args[idx]})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *hranaV2Stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), convertToNamed(args))
|
||||
}
|
||||
|
||||
func (s *hranaV2Stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), convertToNamed(args))
|
||||
}
|
||||
|
||||
func (s *hranaV2Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
return s.conn.ExecContext(ctx, s.sql, args)
|
||||
}
|
||||
|
||||
func (s *hranaV2Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
return s.conn.QueryContext(ctx, s.sql, args)
|
||||
}
|
||||
|
||||
type hranaV2Conn struct {
|
||||
url string
|
||||
jwt string
|
||||
host string
|
||||
schemaDb bool
|
||||
remoteEncryptionKey string
|
||||
baton string
|
||||
streamClosed bool
|
||||
replicationIndex uint64
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) Ping() error {
|
||||
return h.PingContext(context.Background())
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) PingContext(ctx context.Context) error {
|
||||
_, err := h.executeStmt(ctx, "SELECT 1", nil, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return h.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
stmts, paramInfos, err := shared.ParseStatement(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(stmts) != 1 {
|
||||
return nil, fmt.Errorf("only one statement is supported got %d", len(stmts))
|
||||
}
|
||||
numInput := -1
|
||||
if len(paramInfos[0].NamedParameters) == 0 {
|
||||
numInput = paramInfos[0].PositionalParametersCount
|
||||
}
|
||||
return &hranaV2Stmt{h, numInput, query}, nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) Close() error {
|
||||
if h.baton != "" {
|
||||
go func(baton, url, jwt, host, encryptionKey string) {
|
||||
msg := hrana.PipelineRequest{Baton: baton}
|
||||
msg.Add(hrana.CloseStream())
|
||||
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
|
||||
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) Begin() (driver.Tx, error) {
|
||||
return h.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
type hranaV2Tx struct {
|
||||
conn *hranaV2Conn
|
||||
}
|
||||
|
||||
func (h hranaV2Tx) Commit() error {
|
||||
_, err := h.conn.ExecContext(context.Background(), "COMMIT", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h hranaV2Tx) Rollback() error {
|
||||
_, err := h.conn.ExecContext(context.Background(), "ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
if opts.ReadOnly {
|
||||
return nil, fmt.Errorf("read only transactions are not supported")
|
||||
}
|
||||
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
|
||||
return nil, fmt.Errorf("isolation level %d is not supported", opts.Isolation)
|
||||
}
|
||||
_, err := h.ExecContext(ctx, "BEGIN", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &hranaV2Tx{h}, nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, streamClose bool) (*hrana.PipelineResponse, error) {
|
||||
if h.streamClosed {
|
||||
// If the stream is closed, we can't send any more requests using this connection.
|
||||
return nil, fmt.Errorf("stream is closed: %w", driver.ErrBadConn)
|
||||
}
|
||||
if h.baton != "" {
|
||||
msg.Baton = h.baton
|
||||
}
|
||||
if h.replicationIndex > 0 {
|
||||
addReplicationIndex(msg, h.replicationIndex)
|
||||
}
|
||||
result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host, h.remoteEncryptionKey)
|
||||
if streamClosed {
|
||||
h.streamClosed = true
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.baton = result.Baton
|
||||
if result.Baton == "" && !streamClose {
|
||||
// We need to remember that the stream is closed so we don't try to send any more requests using this connection.
|
||||
h.streamClosed = true
|
||||
}
|
||||
if result.BaseUrl != "" {
|
||||
h.url = result.BaseUrl
|
||||
}
|
||||
if idx := getReplicationIndex(&result); idx > h.replicationIndex {
|
||||
h.replicationIndex = idx
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func addReplicationIndex(msg *hrana.PipelineRequest, replicationIndex uint64) {
|
||||
for i := range msg.Requests {
|
||||
if msg.Requests[i].Stmt != nil && msg.Requests[i].Stmt.ReplicationIndex == nil {
|
||||
msg.Requests[i].Stmt.ReplicationIndex = &replicationIndex
|
||||
} else if msg.Requests[i].Batch != nil && msg.Requests[i].Batch.ReplicationIndex == nil {
|
||||
msg.Requests[i].Batch.ReplicationIndex = &replicationIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getReplicationIndex(response *hrana.PipelineResponse) uint64 {
|
||||
if response == nil || len(response.Results) == 0 {
|
||||
return 0
|
||||
}
|
||||
var replicationIndex uint64
|
||||
for _, result := range response.Results {
|
||||
if result.Response == nil {
|
||||
continue
|
||||
}
|
||||
if result.Response.Type == "execute" {
|
||||
if res, err := result.Response.ExecuteResult(); err == nil && res.ReplicationIndex != nil {
|
||||
if *res.ReplicationIndex > replicationIndex {
|
||||
replicationIndex = *res.ReplicationIndex
|
||||
}
|
||||
}
|
||||
} else if result.Response.Type == "batch" {
|
||||
if res, err := result.Response.BatchResult(); err == nil && res.ReplicationIndex != nil {
|
||||
if *res.ReplicationIndex > replicationIndex {
|
||||
replicationIndex = *res.ReplicationIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return replicationIndex
|
||||
}
|
||||
|
||||
func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string, remoteEncryptionKey string) (result hrana.PipelineResponse, streamClosed bool, err error) {
|
||||
reqBody, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return hrana.PipelineResponse{}, false, err
|
||||
}
|
||||
pipelineURL, err := net_url.JoinPath(url, "/v2/pipeline")
|
||||
if err != nil {
|
||||
return hrana.PipelineResponse{}, false, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", pipelineURL, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return hrana.PipelineResponse{}, false, err
|
||||
}
|
||||
if len(jwt) > 0 {
|
||||
req.Header.Set("Authorization", "Bearer "+jwt)
|
||||
}
|
||||
req.Header.Set("x-libsql-client-version", "libsql-remote-go-"+commitHash)
|
||||
if remoteEncryptionKey != "" {
|
||||
req.Header.Set("x-turso-encryption-key", remoteEncryptionKey)
|
||||
}
|
||||
|
||||
req.Host = host
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return hrana.PipelineResponse{}, false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return hrana.PipelineResponse{}, false, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// We need to remember that the stream is closed so we don't try to send any more requests using this connection.
|
||||
var serverError struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &serverError); err == nil {
|
||||
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %d: %s", resp.StatusCode, serverError.Error)
|
||||
}
|
||||
var errResponse hrana.Error
|
||||
if err := json.Unmarshal(body, &errResponse); err == nil {
|
||||
if errResponse.Code != nil {
|
||||
if *errResponse.Code == "STREAM_EXPIRED" {
|
||||
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %s: %s\n%w", *errResponse.Code, errResponse.Message, driver.ErrBadConn)
|
||||
} else {
|
||||
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %s: %s", *errResponse.Code, errResponse.Message)
|
||||
}
|
||||
}
|
||||
return hrana.PipelineResponse{}, true, errors.New(errResponse.Message)
|
||||
}
|
||||
return hrana.PipelineResponse{}, true, fmt.Errorf("error code %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return hrana.PipelineResponse{}, false, err
|
||||
}
|
||||
return result, false, nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) executeMsg(ctx context.Context, msg *hrana.PipelineRequest) (*hrana.PipelineResponse, error) {
|
||||
result, err := h.sendPipelineRequest(ctx, msg, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, r := range result.Results {
|
||||
if r.Error != nil {
|
||||
return nil, errors.New(r.Error.Message)
|
||||
}
|
||||
if r.Response == nil {
|
||||
return nil, errors.New("no response received")
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type chunker struct {
|
||||
chunk []string
|
||||
iterator *sqliteparserutils.StatementIterator
|
||||
limit int
|
||||
}
|
||||
|
||||
func newChunker(iterator *sqliteparserutils.StatementIterator, limit int) *chunker {
|
||||
return &chunker{iterator: iterator, chunk: make([]string, 0, limit), limit: limit}
|
||||
}
|
||||
|
||||
func isTransactionStatement(stmt string) bool {
|
||||
patterns := [][]byte{[]byte("begin"), []byte("commit"), []byte("end"), []byte("rollback")}
|
||||
for _, p := range patterns {
|
||||
if len(stmt) >= len(p) && bytes.Equal(bytes.ToLower([]byte(stmt[0:len(p)])), p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *chunker) Next() (chunk []string, isEOF bool) {
|
||||
c.chunk = c.chunk[:0]
|
||||
var stmt string
|
||||
for !isEOF && len(c.chunk) < c.limit {
|
||||
stmt, _, isEOF = c.iterator.Next()
|
||||
// We need to skip transaction statements. Chunks run in a transaction by default.
|
||||
if stmt != "" && !isTransactionStatement(stmt) {
|
||||
c.chunk = append(c.chunk, stmt)
|
||||
}
|
||||
}
|
||||
return c.chunk, isEOF
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) executeSingleStmt(ctx context.Context, stmt string, wantRows bool) (*hrana.PipelineResponse, error) {
|
||||
msg := &hrana.PipelineRequest{}
|
||||
executeStream, err := hrana.ExecuteStream(stmt, nil, wantRows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", stmt, err)
|
||||
}
|
||||
msg.Add(*executeStream)
|
||||
res, err := h.executeMsg(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", stmt, err)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) executeInChunks(ctx context.Context, query string, wantRows bool) (*hrana.PipelineResponse, error) {
|
||||
const chunkSize = 4096
|
||||
iterator := sqliteparserutils.CreateStatementIterator(query)
|
||||
chunker := newChunker(iterator, chunkSize)
|
||||
|
||||
chunk, isEOF := chunker.Next()
|
||||
if isEOF && len(chunk) == 1 {
|
||||
return h.executeSingleStmt(ctx, chunk[0], wantRows)
|
||||
}
|
||||
|
||||
_, err := h.executeSingleStmt(ctx, "BEGIN", false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
batch := &hrana.Batch{Steps: make([]hrana.BatchStep, chunkSize)}
|
||||
msg := &hrana.PipelineRequest{}
|
||||
msg.Add(hrana.StreamRequest{Type: "batch", Batch: batch})
|
||||
for idx := range batch.Steps {
|
||||
batch.Steps[idx].Stmt.WantRows = wantRows
|
||||
}
|
||||
|
||||
result := &hrana.PipelineResponse{}
|
||||
for {
|
||||
for idx := range chunk {
|
||||
batch.Steps[idx].Stmt.Sql = &chunk[idx]
|
||||
}
|
||||
if len(chunk) < chunkSize {
|
||||
// We can trim batch.Steps because this is the last chunk anyway.
|
||||
// isEOF has to be true at this point.
|
||||
batch.Steps = batch.Steps[:len(chunk)]
|
||||
}
|
||||
res, err := h.executeMsg(ctx, msg)
|
||||
if err != nil {
|
||||
h.closeStream()
|
||||
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
|
||||
}
|
||||
result.Baton = res.Baton
|
||||
result.BaseUrl = res.BaseUrl
|
||||
result.Results = append(result.Results, res.Results...)
|
||||
if isEOF {
|
||||
break
|
||||
}
|
||||
chunk, isEOF = chunker.Next()
|
||||
}
|
||||
_, err = h.executeSingleStmt(ctx, "COMMIT", false)
|
||||
if err != nil {
|
||||
h.closeStream()
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) executeStmt(ctx context.Context, query string, args []driver.NamedValue, wantRows bool) (*hrana.PipelineResponse, error) {
|
||||
const querySizeLimitForChunking = 20 * 1024 * 1024
|
||||
if len(args) == 0 && len(query) > querySizeLimitForChunking && !h.schemaDb {
|
||||
return h.executeInChunks(ctx, query, wantRows)
|
||||
}
|
||||
stmts, params, err := shared.ParseStatementAndArgs(query, args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
|
||||
}
|
||||
msg := &hrana.PipelineRequest{}
|
||||
if len(stmts) == 1 {
|
||||
var p *shared.Params
|
||||
if len(params) > 0 {
|
||||
p = ¶ms[0]
|
||||
}
|
||||
executeStream, err := hrana.ExecuteStream(stmts[0], p, wantRows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
|
||||
}
|
||||
msg.Add(*executeStream)
|
||||
} else {
|
||||
batchStream, err := hrana.BatchStream(stmts, params, wantRows, !h.schemaDb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
|
||||
}
|
||||
msg.Add(*batchStream)
|
||||
}
|
||||
|
||||
resp, err := h.executeMsg(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
result, err := h.executeStmt(ctx, query, args, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch result.Results[0].Response.Type {
|
||||
case "execute":
|
||||
res, err := result.Results[0].Response.ExecuteResult()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return shared.NewResult(res.GetLastInsertRowId(), int64(res.AffectedRowCount)), nil
|
||||
case "batch":
|
||||
res, err := result.Results[0].Response.BatchResult()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastInsertRowId := int64(0)
|
||||
affectedRowCount := int64(0)
|
||||
upperBound := len(res.StepResults)
|
||||
if !h.schemaDb {
|
||||
upperBound -= 1
|
||||
}
|
||||
for idx := 0; idx < upperBound; idx++ {
|
||||
r := res.StepResults[idx]
|
||||
rowId := r.GetLastInsertRowId()
|
||||
if rowId > 0 {
|
||||
lastInsertRowId = rowId
|
||||
}
|
||||
affectedRowCount += int64(r.AffectedRowCount)
|
||||
}
|
||||
return shared.NewResult(lastInsertRowId, affectedRowCount), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, "unknown response type")
|
||||
}
|
||||
}
|
||||
|
||||
type StmtResultRowsProvider struct {
|
||||
r *hrana.StmtResult
|
||||
}
|
||||
|
||||
func (p *StmtResultRowsProvider) SetsCount() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (p *StmtResultRowsProvider) RowsCount(setIdx int) int {
|
||||
if setIdx != 0 {
|
||||
return 0
|
||||
}
|
||||
return len(p.r.Rows)
|
||||
}
|
||||
|
||||
func (p *StmtResultRowsProvider) Columns(setIdx int) []string {
|
||||
if setIdx != 0 {
|
||||
return nil
|
||||
}
|
||||
res := make([]string, len(p.r.Cols))
|
||||
for i, c := range p.r.Cols {
|
||||
if c.Name != nil {
|
||||
res[i] = *c.Name
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *StmtResultRowsProvider) FieldValue(setIdx, rowIdx, colIdx int) driver.Value {
|
||||
if setIdx != 0 {
|
||||
return nil
|
||||
}
|
||||
return p.r.Rows[rowIdx][colIdx].ToValue(p.r.Cols[colIdx].Type)
|
||||
}
|
||||
|
||||
func (p *StmtResultRowsProvider) Error(setIdx int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *StmtResultRowsProvider) HasResult(setIdx int) bool {
|
||||
return setIdx == 0
|
||||
}
|
||||
|
||||
type BatchResultRowsProvider struct {
|
||||
r *hrana.BatchResult
|
||||
}
|
||||
|
||||
func (p *BatchResultRowsProvider) SetsCount() int {
|
||||
return len(p.r.StepResults)
|
||||
}
|
||||
|
||||
func (p *BatchResultRowsProvider) RowsCount(setIdx int) int {
|
||||
if setIdx >= len(p.r.StepResults) || p.r.StepResults[setIdx] == nil {
|
||||
return 0
|
||||
}
|
||||
return len(p.r.StepResults[setIdx].Rows)
|
||||
}
|
||||
|
||||
func (p *BatchResultRowsProvider) Columns(setIdx int) []string {
|
||||
if setIdx >= len(p.r.StepResults) || p.r.StepResults[setIdx] == nil {
|
||||
return nil
|
||||
}
|
||||
res := make([]string, len(p.r.StepResults[setIdx].Cols))
|
||||
for i, c := range p.r.StepResults[setIdx].Cols {
|
||||
if c.Name != nil {
|
||||
res[i] = *c.Name
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *BatchResultRowsProvider) FieldValue(setIdx, rowIdx, colIdx int) driver.Value {
|
||||
if setIdx >= len(p.r.StepResults) || p.r.StepResults[setIdx] == nil {
|
||||
return nil
|
||||
}
|
||||
return p.r.StepResults[setIdx].Rows[rowIdx][colIdx].ToValue(p.r.StepResults[setIdx].Cols[colIdx].Type)
|
||||
}
|
||||
|
||||
func (p *BatchResultRowsProvider) Error(setIdx int) string {
|
||||
if setIdx >= len(p.r.StepErrors) || p.r.StepErrors[setIdx] == nil {
|
||||
return ""
|
||||
}
|
||||
return p.r.StepErrors[setIdx].Message
|
||||
}
|
||||
|
||||
func (p *BatchResultRowsProvider) HasResult(setIdx int) bool {
|
||||
return setIdx < len(p.r.StepResults) && p.r.StepResults[setIdx] != nil
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
result, err := h.executeStmt(ctx, query, args, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch result.Results[0].Response.Type {
|
||||
case "execute":
|
||||
res, err := result.Results[0].Response.ExecuteResult()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return shared.NewRows(&StmtResultRowsProvider{res}), nil
|
||||
case "batch":
|
||||
res, err := result.Results[0].Response.BatchResult()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !h.schemaDb {
|
||||
res.StepResults = res.StepResults[:len(res.StepResults)-1]
|
||||
res.StepErrors = res.StepErrors[:len(res.StepErrors)-1]
|
||||
}
|
||||
return shared.NewRows(&BatchResultRowsProvider{res}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, "unknown response type")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) closeStream() {
|
||||
if h.baton != "" {
|
||||
go func(baton, url, jwt, host, encryptionKey string) {
|
||||
msg := hrana.PipelineRequest{Baton: baton}
|
||||
msg.Add(hrana.CloseStream())
|
||||
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
|
||||
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
|
||||
h.baton = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hranaV2Conn) ResetSession(ctx context.Context) error {
|
||||
h.closeStream()
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package shared
|
||||
|
||||
type result struct {
|
||||
id int64
|
||||
changes int64
|
||||
}
|
||||
|
||||
func NewResult(id, changes int64) *result {
|
||||
return &result{id: id, changes: changes}
|
||||
}
|
||||
|
||||
func (r *result) LastInsertId() (int64, error) {
|
||||
return r.id, nil
|
||||
}
|
||||
|
||||
func (r *result) RowsAffected() (int64, error) {
|
||||
return r.changes, nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type rowsProvider interface {
|
||||
SetsCount() int
|
||||
RowsCount(setIdx int) int
|
||||
Columns(setIdx int) []string
|
||||
FieldValue(setIdx, rowIdx int, columnIdx int) driver.Value
|
||||
Error(setIdx int) string
|
||||
HasResult(setIdx int) bool
|
||||
}
|
||||
|
||||
func NewRows(result rowsProvider) driver.Rows {
|
||||
return &rows{result: result}
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
result rowsProvider
|
||||
currentResultSetIndex int
|
||||
currentRowIdx int
|
||||
}
|
||||
|
||||
func (r *rows) Columns() []string {
|
||||
return r.result.Columns(r.currentResultSetIndex)
|
||||
}
|
||||
|
||||
func (r *rows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
if r.currentRowIdx == r.result.RowsCount(r.currentResultSetIndex) {
|
||||
return io.EOF
|
||||
}
|
||||
count := len(r.result.Columns(r.currentResultSetIndex))
|
||||
for idx := 0; idx < count; idx++ {
|
||||
dest[idx] = r.result.FieldValue(r.currentResultSetIndex, r.currentRowIdx, idx)
|
||||
}
|
||||
r.currentRowIdx++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rows) HasNextResultSet() bool {
|
||||
return r.currentResultSetIndex < r.result.SetsCount()-1
|
||||
}
|
||||
|
||||
func (r *rows) NextResultSet() error {
|
||||
if !r.HasNextResultSet() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
r.currentResultSetIndex++
|
||||
r.currentRowIdx = 0
|
||||
|
||||
errStr := r.result.Error(r.currentResultSetIndex)
|
||||
if errStr != "" {
|
||||
return fmt.Errorf("failed to execute statement\n%s", errStr)
|
||||
}
|
||||
if !r.result.HasResult(r.currentResultSetIndex) {
|
||||
return fmt.Errorf("no results for statement")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
"github.com/tursodatabase/libsql-client-go/sqliteparser"
|
||||
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
|
||||
)
|
||||
|
||||
type ParamsInfo struct {
|
||||
NamedParameters []string
|
||||
PositionalParametersCount int
|
||||
}
|
||||
|
||||
func ParseStatement(sql string) ([]string, []ParamsInfo, error) {
|
||||
stmts, _ := sqliteparserutils.SplitStatement(sql)
|
||||
|
||||
stmtsParams := make([]ParamsInfo, len(stmts))
|
||||
for idx, stmt := range stmts {
|
||||
nameParams, positionalParamsCount, err := extractParameters(stmt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
stmtsParams[idx] = ParamsInfo{nameParams, positionalParamsCount}
|
||||
}
|
||||
return stmts, stmtsParams, nil
|
||||
}
|
||||
|
||||
func ParseStatementAndArgs(sql string, args []driver.NamedValue) ([]string, []Params, error) {
|
||||
stmts, _ := sqliteparserutils.SplitStatement(sql)
|
||||
|
||||
if len(args) == 0 {
|
||||
return stmts, nil, nil
|
||||
}
|
||||
parameters, err := ConvertArgs(args)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
stmtsParams := make([]Params, len(stmts))
|
||||
totalParametersAlreadyUsed := 0
|
||||
for idx, stmt := range stmts {
|
||||
stmtParams, err := generateStatementParameters(stmt, parameters, totalParametersAlreadyUsed)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("fail to generate statement parameter. statement: %s. error: %v", stmt, err)
|
||||
}
|
||||
stmtsParams[idx] = stmtParams
|
||||
totalParametersAlreadyUsed += stmtParams.Len()
|
||||
}
|
||||
return stmts, stmtsParams, nil
|
||||
}
|
||||
|
||||
type paramsType int
|
||||
|
||||
const (
|
||||
namedParameters paramsType = iota
|
||||
positionalParameters
|
||||
)
|
||||
|
||||
type Params struct {
|
||||
positional []any
|
||||
named map[string]any
|
||||
}
|
||||
|
||||
func (p *Params) MarshalJSON() ([]byte, error) {
|
||||
if len(p.named) > 0 {
|
||||
return json.Marshal(p.named)
|
||||
}
|
||||
if len(p.positional) > 0 {
|
||||
return json.Marshal(p.positional)
|
||||
}
|
||||
return json.Marshal(make([]any, 0))
|
||||
}
|
||||
|
||||
func (p *Params) Named() map[string]any {
|
||||
return p.named
|
||||
}
|
||||
|
||||
func (p *Params) Positional() []any {
|
||||
return p.positional
|
||||
}
|
||||
|
||||
func (p *Params) Len() int {
|
||||
if p.named != nil {
|
||||
return len(p.named)
|
||||
}
|
||||
|
||||
return len(p.positional)
|
||||
}
|
||||
|
||||
func (p *Params) Type() paramsType {
|
||||
if p.named != nil {
|
||||
return namedParameters
|
||||
}
|
||||
|
||||
return positionalParameters
|
||||
}
|
||||
|
||||
func NewParams(t paramsType) Params {
|
||||
p := Params{}
|
||||
switch t {
|
||||
case namedParameters:
|
||||
p.named = make(map[string]any)
|
||||
case positionalParameters:
|
||||
p.positional = make([]any, 0)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func getParamType(arg *driver.NamedValue) paramsType {
|
||||
if arg.Name == "" {
|
||||
return positionalParameters
|
||||
}
|
||||
return namedParameters
|
||||
}
|
||||
|
||||
func ConvertArgs(args []driver.NamedValue) (Params, error) {
|
||||
if len(args) == 0 {
|
||||
return NewParams(positionalParameters), nil
|
||||
}
|
||||
|
||||
var sortedArgs []*driver.NamedValue
|
||||
for idx := range args {
|
||||
sortedArgs = append(sortedArgs, &args[idx])
|
||||
}
|
||||
sort.Slice(sortedArgs, func(i, j int) bool {
|
||||
return sortedArgs[i].Ordinal < sortedArgs[j].Ordinal
|
||||
})
|
||||
|
||||
parametersType := getParamType(sortedArgs[0])
|
||||
parameters := NewParams(parametersType)
|
||||
for _, arg := range sortedArgs {
|
||||
if parametersType != getParamType(arg) {
|
||||
return Params{}, fmt.Errorf("driver does not accept positional and named parameters at the same time")
|
||||
}
|
||||
|
||||
switch parametersType {
|
||||
case positionalParameters:
|
||||
parameters.positional = append(parameters.positional, arg.Value)
|
||||
case namedParameters:
|
||||
parameters.named[arg.Name] = arg.Value
|
||||
}
|
||||
}
|
||||
return parameters, nil
|
||||
}
|
||||
|
||||
func isExplain(stmt string) bool {
|
||||
statementStream := antlr.NewInputStream(stmt)
|
||||
|
||||
lexer := sqliteparser.NewSQLiteLexer(statementStream)
|
||||
tokenStream := antlr.NewCommonTokenStream(lexer, 0)
|
||||
firstToken := tokenStream.LT(1)
|
||||
return firstToken.GetTokenType() == sqliteparser.SQLiteParserEXPLAIN_
|
||||
}
|
||||
|
||||
func generateStatementParameters(stmt string, queryParams Params, positionalParametersOffset int) (Params, error) {
|
||||
nameParams, positionalParamsCount, err := extractParameters(stmt)
|
||||
if err != nil {
|
||||
return Params{}, err
|
||||
}
|
||||
|
||||
stmtParams := NewParams(queryParams.Type())
|
||||
|
||||
switch queryParams.Type() {
|
||||
case positionalParameters:
|
||||
if positionalParametersOffset+positionalParamsCount > len(queryParams.positional) {
|
||||
if isExplain(stmt) {
|
||||
return Params{}, nil
|
||||
}
|
||||
// Positional parameters with indexes most of the time will have fewer args than parameters.
|
||||
stmtParams.positional = queryParams.positional[positionalParametersOffset:len(queryParams.positional)]
|
||||
} else {
|
||||
stmtParams.positional = queryParams.positional[positionalParametersOffset : positionalParametersOffset+positionalParamsCount]
|
||||
}
|
||||
case namedParameters:
|
||||
stmtParametersNeeded := make(map[string]bool)
|
||||
for _, stmtParametersName := range nameParams {
|
||||
stmtParametersNeeded[stmtParametersName] = true
|
||||
}
|
||||
for queryParamsName, queryParamsValue := range queryParams.named {
|
||||
if stmtParametersNeeded[queryParamsName] {
|
||||
stmtParams.named[queryParamsName] = queryParamsValue
|
||||
delete(stmtParametersNeeded, queryParamsName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stmtParams, nil
|
||||
}
|
||||
|
||||
func extractParameters(stmt string) (nameParams []string, positionalParamsCount int, err error) {
|
||||
statementStream := antlr.NewInputStream(stmt)
|
||||
sqliteparser.NewSQLiteLexer(statementStream)
|
||||
lexer := sqliteparser.NewSQLiteLexer(statementStream)
|
||||
|
||||
allTokens := lexer.GetAllTokens()
|
||||
|
||||
nameParamsSet := make(map[string]bool)
|
||||
positionalParamsWithIndexesSet := make(map[string]bool)
|
||||
|
||||
// ^: asserts the start of the string.
|
||||
// \?: matches a literal question mark character.
|
||||
// (\d+)? captures one more digits (0-9) in a group, but group is optional due to the ? quantifier.
|
||||
// $: asserts the end of thr string so as to avoid this scenario: ?123ABC.
|
||||
re := regexp.MustCompile(`^\?(\d+)?$`)
|
||||
|
||||
for _, token := range allTokens {
|
||||
tokenType := token.GetTokenType()
|
||||
if tokenType == sqliteparser.SQLiteLexerBIND_PARAMETER {
|
||||
parameter := token.GetText()
|
||||
|
||||
match := re.FindStringSubmatch(parameter)
|
||||
if match == nil {
|
||||
paramWithoutPrefix, err := removeParamPrefix(parameter)
|
||||
if err != nil {
|
||||
return []string{}, 0, err
|
||||
}
|
||||
nameParamsSet[paramWithoutPrefix] = true
|
||||
continue
|
||||
}
|
||||
|
||||
posS := string(match[1])
|
||||
if posS == "" {
|
||||
// When an empty string, it means the parameter is a
|
||||
// positional parameter without an index (e.g, ?).
|
||||
positionalParamsCount++
|
||||
} else {
|
||||
// Positional parameter with indexes (e.g., ?<number>)
|
||||
// must be deduped.
|
||||
positionalParamsWithIndexesSet[posS] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nameParams = make([]string, 0, len(nameParamsSet))
|
||||
for k := range nameParamsSet {
|
||||
nameParams = append(nameParams, k)
|
||||
}
|
||||
|
||||
// Only count unique number of positional parameters.
|
||||
positionalParamsCount += len(positionalParamsWithIndexesSet)
|
||||
|
||||
return nameParams, positionalParamsCount, nil
|
||||
}
|
||||
|
||||
func removeParamPrefix(paramName string) (string, error) {
|
||||
if paramName[0] == ':' || paramName[0] == '@' || paramName[0] == '$' {
|
||||
return paramName[1:], nil
|
||||
}
|
||||
return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractParameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
nameParams []string
|
||||
positionalParamsCount int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "OnlyColonNameParams",
|
||||
value: "select :column from :table",
|
||||
nameParams: []string{"column", "table"},
|
||||
},
|
||||
{
|
||||
name: "OnlyAtNameParams",
|
||||
value: "select @column from @table",
|
||||
nameParams: []string{"column", "table"},
|
||||
},
|
||||
{
|
||||
name: "OnlyDollarSignNameParams",
|
||||
value: "select $column from $table",
|
||||
nameParams: []string{"column", "table"},
|
||||
},
|
||||
{
|
||||
name: "RepeatedNamedParameter",
|
||||
value: "select :number, :number",
|
||||
nameParams: []string{"number"},
|
||||
},
|
||||
{
|
||||
name: "OnlyPositionalParams",
|
||||
value: "select ? from ?",
|
||||
nameParams: []string{},
|
||||
positionalParamsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "OnlyPositionalParamsWithoutIndexes",
|
||||
value: "select ? from ?",
|
||||
nameParams: []string{},
|
||||
positionalParamsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "OnlyPositionalParamsWithIndexes (dedup)",
|
||||
value: "select ?1 from ?1",
|
||||
nameParams: []string{},
|
||||
positionalParamsCount: 1,
|
||||
},
|
||||
{
|
||||
name: "PositionalParamsWithIndexes",
|
||||
value: "select ? from ?1",
|
||||
nameParams: []string{},
|
||||
positionalParamsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "MixedParams",
|
||||
value: "select :column1, @column2, $column3, ? from ?",
|
||||
nameParams: []string{"column1", "column2", "column3"},
|
||||
positionalParamsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "NoParams",
|
||||
value: "select myColumn from myTable",
|
||||
nameParams: []string{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotNameParams, gotUniquePositionalParamsCount, gotErr := extractParameters(tt.value)
|
||||
sort.Strings(gotNameParams)
|
||||
sort.Strings(tt.nameParams)
|
||||
if !reflect.DeepEqual(gotNameParams, tt.nameParams) {
|
||||
t.Errorf("got nameParams %#v, want %#v", gotNameParams, tt.nameParams)
|
||||
}
|
||||
if !reflect.DeepEqual(gotUniquePositionalParamsCount, tt.positionalParamsCount) {
|
||||
t.Errorf("got positionalParams %#v, want %#v", gotUniquePositionalParamsCount, tt.positionalParamsCount)
|
||||
}
|
||||
if !reflect.DeepEqual(gotUniquePositionalParamsCount, tt.positionalParamsCount) {
|
||||
t.Errorf("got err %v, want %v", gotErr, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
194
go_modules/libsql-client-go/libsql/internal/ws/driver.go
Normal file
194
go_modules/libsql-client-go/libsql/internal/ws/driver.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type result struct {
|
||||
id int64
|
||||
changes int64
|
||||
}
|
||||
|
||||
func (r *result) LastInsertId() (int64, error) {
|
||||
return r.id, nil
|
||||
}
|
||||
|
||||
func (r *result) RowsAffected() (int64, error) {
|
||||
return r.changes, nil
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
res *execResponse
|
||||
currentRowIdx int
|
||||
}
|
||||
|
||||
func (r *rows) Columns() []string {
|
||||
return r.res.columns()
|
||||
}
|
||||
|
||||
func (r *rows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
if r.currentRowIdx == r.res.rowsCount() {
|
||||
return io.EOF
|
||||
}
|
||||
count := r.res.rowLen(r.currentRowIdx)
|
||||
for idx := 0; idx < count; idx++ {
|
||||
v, err := r.res.value(r.currentRowIdx, idx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest[idx] = v
|
||||
}
|
||||
r.currentRowIdx++
|
||||
return nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
ws *websocketConn
|
||||
}
|
||||
|
||||
func Connect(url string, jwt string) (*conn, error) {
|
||||
c, err := connect(url, jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conn{c}, nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
c *conn
|
||||
query string
|
||||
}
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func convertToNamed(args []driver.Value) []driver.NamedValue {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := []driver.NamedValue{}
|
||||
for idx := range args {
|
||||
result = append(result, driver.NamedValue{Ordinal: idx, Value: args[idx]})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), convertToNamed(args))
|
||||
}
|
||||
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), convertToNamed(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
return s.c.ExecContext(ctx, s.query, args)
|
||||
}
|
||||
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
return s.c.QueryContext(ctx, s.query, args)
|
||||
}
|
||||
|
||||
func (c *conn) Ping() error {
|
||||
return c.PingContext(context.Background())
|
||||
}
|
||||
|
||||
func (c *conn) PingContext(ctx context.Context) error {
|
||||
_, err := c.ws.exec(ctx, "SELECT 1", params{}, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
|
||||
return stmt{c, query}, nil
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
return c.ws.Close()
|
||||
}
|
||||
|
||||
type tx struct {
|
||||
c *conn
|
||||
}
|
||||
|
||||
func (t tx) Commit() error {
|
||||
_, err := t.c.ExecContext(context.Background(), "COMMIT", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t tx) Rollback() error {
|
||||
_, err := t.c.ExecContext(context.Background(), "ROLLBACK", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c *conn) BeginTx(ctx context.Context, _ driver.TxOptions) (driver.Tx, error) {
|
||||
_, err := c.ExecContext(ctx, "BEGIN", nil)
|
||||
if err != nil {
|
||||
return tx{nil}, err
|
||||
}
|
||||
return tx{c}, nil
|
||||
}
|
||||
|
||||
func convertArgs(args []driver.NamedValue) params {
|
||||
if len(args) == 0 {
|
||||
return params{}
|
||||
}
|
||||
positionalArgs := [](*driver.NamedValue){}
|
||||
namedArgs := []namedParam{}
|
||||
for idx := range args {
|
||||
if len(args[idx].Name) > 0 {
|
||||
namedArgs = append(namedArgs, namedParam{args[idx].Name, args[idx].Value})
|
||||
} else {
|
||||
positionalArgs = append(positionalArgs, &args[idx])
|
||||
}
|
||||
}
|
||||
sort.Slice(positionalArgs, func(i, j int) bool {
|
||||
return positionalArgs[i].Ordinal < positionalArgs[j].Ordinal
|
||||
})
|
||||
posArgs := [](any){}
|
||||
for idx := range positionalArgs {
|
||||
posArgs = append(posArgs, positionalArgs[idx].Value)
|
||||
}
|
||||
return params{PositinalArgs: posArgs, NamedArgs: namedArgs}
|
||||
}
|
||||
|
||||
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
res, err := c.ws.exec(ctx, query, convertArgs(args), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result{res.lastInsertId(), res.affectedRowCount()}, nil
|
||||
}
|
||||
|
||||
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
res, err := c.ws.exec(ctx, query, convertArgs(args), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rows{res, 0}, nil
|
||||
}
|
||||
322
go_modules/libsql-client-go/libsql/internal/ws/websockets.go
Normal file
322
go_modules/libsql-client-go/libsql/internal/ws/websockets.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
// defaultWSTimeout specifies the timeout used for initial http connection
|
||||
var defaultWSTimeout = 120 * time.Second
|
||||
|
||||
func errorMsg(errorResp interface{}) string {
|
||||
return errorResp.(map[string]interface{})["error"].(map[string]interface{})["message"].(string)
|
||||
}
|
||||
|
||||
func isErrorResp(resp interface{}) bool {
|
||||
return resp.(map[string]interface{})["type"] == "response_error"
|
||||
}
|
||||
|
||||
type websocketConn struct {
|
||||
conn *websocket.Conn
|
||||
idPool *idPool
|
||||
}
|
||||
|
||||
type namedParam struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type params struct {
|
||||
PositinalArgs []any
|
||||
NamedArgs []namedParam
|
||||
}
|
||||
|
||||
func convertValue(v any) (map[string]interface{}, error) {
|
||||
res := map[string]interface{}{}
|
||||
if v == nil {
|
||||
res["type"] = "null"
|
||||
} else if integer, ok := v.(int64); ok {
|
||||
res["type"] = "integer"
|
||||
res["value"] = strconv.FormatInt(integer, 10)
|
||||
} else if text, ok := v.(string); ok {
|
||||
res["type"] = "text"
|
||||
res["value"] = text
|
||||
} else if blob, ok := v.([]byte); ok {
|
||||
res["type"] = "blob"
|
||||
res["base64"] = base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob)
|
||||
} else if float, ok := v.(float64); ok {
|
||||
res["type"] = "float"
|
||||
res["value"] = float
|
||||
} else if boolean, ok := v.(bool); ok {
|
||||
res["type"] = "integer"
|
||||
if boolean {
|
||||
res["value"] = "1"
|
||||
} else {
|
||||
res["value"] = "0"
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported value type: %s", v)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type execResponse struct {
|
||||
resp map[string]interface{}
|
||||
}
|
||||
|
||||
func (r *execResponse) affectedRowCount() int64 {
|
||||
return int64(r.resp["affected_row_count"].(float64))
|
||||
}
|
||||
|
||||
func (r *execResponse) lastInsertId() int64 {
|
||||
id, ok := r.resp["last_insert_rowid"].(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
value, _ := strconv.ParseInt(id, 10, 64)
|
||||
return value
|
||||
}
|
||||
|
||||
func (r *execResponse) columns() []string {
|
||||
res := []string{}
|
||||
cols := r.resp["cols"].([]interface{})
|
||||
for idx := range cols {
|
||||
var v string = ""
|
||||
if cols[idx].(map[string]interface{})["name"] != nil {
|
||||
v = cols[idx].(map[string]interface{})["name"].(string)
|
||||
}
|
||||
res = append(res, v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *execResponse) rowsCount() int {
|
||||
return len(r.resp["rows"].([]interface{}))
|
||||
}
|
||||
|
||||
func (r *execResponse) rowLen(rowIdx int) int {
|
||||
return len(r.resp["rows"].([]interface{})[rowIdx].([]interface{}))
|
||||
}
|
||||
|
||||
func (r *execResponse) value(rowIdx int, colIdx int) (any, error) {
|
||||
val := r.resp["rows"].([]interface{})[rowIdx].([]interface{})[colIdx].(map[string]interface{})
|
||||
switch val["type"] {
|
||||
case "null":
|
||||
return nil, nil
|
||||
case "integer":
|
||||
v, err := strconv.ParseInt(val["value"].(string), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
case "text":
|
||||
return val["value"].(string), nil
|
||||
case "blob":
|
||||
base64Encoded := val["base64"].(string)
|
||||
v, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(base64Encoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
case "float":
|
||||
return val["value"].(float64), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unrecognized value type: %s", val["type"])
|
||||
}
|
||||
|
||||
func (ws *websocketConn) exec(ctx context.Context, sql string, sqlParams params, wantRows bool) (*execResponse, error) {
|
||||
requestId := ws.idPool.Get()
|
||||
defer ws.idPool.Put(requestId)
|
||||
stmt := map[string]interface{}{
|
||||
"sql": sql,
|
||||
"want_rows": wantRows,
|
||||
}
|
||||
if len(sqlParams.PositinalArgs) > 0 {
|
||||
args := []map[string]interface{}{}
|
||||
for idx := range sqlParams.PositinalArgs {
|
||||
v, err := convertValue(sqlParams.PositinalArgs[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
stmt["args"] = args
|
||||
}
|
||||
if len(sqlParams.NamedArgs) > 0 {
|
||||
args := []map[string]interface{}{}
|
||||
for idx := range sqlParams.NamedArgs {
|
||||
v, err := convertValue(sqlParams.NamedArgs[idx].Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arg := map[string]interface{}{
|
||||
"name": sqlParams.NamedArgs[idx].Name,
|
||||
"value": v,
|
||||
}
|
||||
args = append(args, arg)
|
||||
}
|
||||
stmt["named_args"] = args
|
||||
}
|
||||
err := wsjson.Write(ctx, ws.conn, map[string]interface{}{
|
||||
"type": "request",
|
||||
"request_id": requestId,
|
||||
"request": map[string]interface{}{
|
||||
"type": "execute",
|
||||
"stream_id": 0,
|
||||
"stmt": stmt,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
|
||||
}
|
||||
|
||||
var resp interface{}
|
||||
if err = wsjson.Read(ctx, ws.conn, &resp); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
|
||||
}
|
||||
|
||||
if isErrorResp(resp) {
|
||||
err = fmt.Errorf("unable to execute %s: %s", sql, errorMsg(resp))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &execResponse{resp.(map[string]interface{})["response"].(map[string]interface{})["result"].(map[string]interface{})}, nil
|
||||
}
|
||||
|
||||
func (ws *websocketConn) Close() error {
|
||||
return ws.conn.Close(websocket.StatusNormalClosure, "All's good")
|
||||
}
|
||||
|
||||
func connect(url string, jwt string) (*websocketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultWSTimeout)
|
||||
defer cancel()
|
||||
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
Subprotocols: []string{"hrana1"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.SetReadLimit(1024 * 1024 * 16) // 16MB
|
||||
|
||||
err = wsjson.Write(ctx, c, map[string]interface{}{
|
||||
"type": "hello",
|
||||
"jwt": jwt,
|
||||
})
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wsjson.Write(ctx, c, map[string]interface{}{
|
||||
"type": "request",
|
||||
"request_id": 0,
|
||||
"request": map[string]interface{}{
|
||||
"type": "open_stream",
|
||||
"stream_id": 0,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var helloResp interface{}
|
||||
err = wsjson.Read(ctx, c, &helloResp)
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if helloResp.(map[string]interface{})["type"] == "hello_error" {
|
||||
err = fmt.Errorf("handshake error: %s", errorMsg(helloResp))
|
||||
c.Close(websocket.StatusProtocolError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var openStreamResp interface{}
|
||||
err = wsjson.Read(ctx, c, &openStreamResp)
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isErrorResp(openStreamResp) {
|
||||
err = fmt.Errorf("unable to open stream: %s", errorMsg(helloResp))
|
||||
c.Close(websocket.StatusProtocolError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return &websocketConn{c, newIDPool()}, nil
|
||||
}
|
||||
|
||||
// Below is modified IDPool from "vitess.io/vitess/go/pools"
|
||||
|
||||
// idPool is used to ensure that the set of IDs in use concurrently never
|
||||
// contains any duplicates. The IDs start at 1 and increase without bound, but
|
||||
// will never be larger than the peak number of concurrent uses.
|
||||
//
|
||||
// idPool's Get() and Put() methods can be used concurrently.
|
||||
type idPool struct {
|
||||
sync.Mutex
|
||||
|
||||
// used holds the set of values that have been returned to us with Put().
|
||||
used map[uint32]bool
|
||||
// maxUsed remembers the largest value we've given out.
|
||||
maxUsed uint32
|
||||
}
|
||||
|
||||
// NewIDPool creates and initializes an idPool.
|
||||
func newIDPool() *idPool {
|
||||
return &idPool{
|
||||
used: make(map[uint32]bool),
|
||||
maxUsed: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns an ID that is unique among currently active users of this pool.
|
||||
func (pool *idPool) Get() (id uint32) {
|
||||
pool.Lock()
|
||||
defer pool.Unlock()
|
||||
|
||||
// Pick a value that's been returned, if any.
|
||||
for key := range pool.used {
|
||||
delete(pool.used, key)
|
||||
return key
|
||||
}
|
||||
|
||||
// No recycled IDs are available, so increase the pool size.
|
||||
pool.maxUsed++
|
||||
return pool.maxUsed
|
||||
}
|
||||
|
||||
// Put recycles an ID back into the pool for others to use. Putting back a value
|
||||
// or 0, or a value that is not currently "checked out", will result in a panic
|
||||
// because that should never happen except in the case of a programming error.
|
||||
func (pool *idPool) Put(id uint32) {
|
||||
pool.Lock()
|
||||
defer pool.Unlock()
|
||||
|
||||
if id < 1 || id > pool.maxUsed {
|
||||
panic(fmt.Errorf("idPool.Put(%v): invalid value, must be in the range [1,%v]", id, pool.maxUsed))
|
||||
}
|
||||
|
||||
if pool.used[id] {
|
||||
panic(fmt.Errorf("idPool.Put(%v): can't put value that was already recycled", id))
|
||||
}
|
||||
|
||||
// If we're recycling maxUsed, just shrink the pool.
|
||||
if id == pool.maxUsed {
|
||||
pool.maxUsed = id - 1
|
||||
return
|
||||
}
|
||||
|
||||
// Add it to the set of recycled IDs.
|
||||
pool.used[id] = true
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value any
|
||||
want map[string]any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
value: nil,
|
||||
want: map[string]any{
|
||||
"type": "null",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "integer",
|
||||
value: int64(42),
|
||||
want: map[string]any{
|
||||
"type": "integer",
|
||||
"value": "42",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "text",
|
||||
value: "turso for win",
|
||||
want: map[string]any{
|
||||
"type": "text",
|
||||
"value": "turso for win",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "blob",
|
||||
value: []byte("hello world"),
|
||||
want: map[string]any{
|
||||
"type": "blob",
|
||||
// `hello world` encoded is `aGVsbG8gd29ybGQ=` but we want without padding
|
||||
"base64": "aGVsbG8gd29ybGQ",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
value: 3.14,
|
||||
want: map[string]any{
|
||||
"type": "float",
|
||||
"value": 3.14,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "boolean_true",
|
||||
value: true,
|
||||
want: map[string]any{
|
||||
"type": "integer",
|
||||
"value": "1",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "boolean_false",
|
||||
value: false,
|
||||
want: map[string]any{
|
||||
"type": "integer",
|
||||
"value": "0",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
value: struct{}{},
|
||||
want: nil,
|
||||
err: fmt.Errorf("unsupported value type: %s", struct{}{}),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := convertValue(tt.value)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.err) {
|
||||
t.Errorf("got error %v, want %v", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_execResponse_lastInsertId(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value map[string]interface{}
|
||||
want int64
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
value: map[string]interface{}{"last_insert_rowid": "42"},
|
||||
want: 42,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
value: map[string]interface{}{},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
value: map[string]interface{}{"last_insert_rowid": "invalid"},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid_type",
|
||||
value: map[string]interface{}{"last_insert_rowid": 42.0},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &execResponse{
|
||||
resp: tt.value,
|
||||
}
|
||||
if got := r.lastInsertId(); got != tt.want {
|
||||
t.Errorf("lastInsertId() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user