fix: 将 libsql-client-go 作为普通目录提交,而非 gitlink
This commit is contained in:
194
go_modules/libsql-client-go/libsql/internal/ws/driver.go
Normal file
194
go_modules/libsql-client-go/libsql/internal/ws/driver.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type result struct {
|
||||
id int64
|
||||
changes int64
|
||||
}
|
||||
|
||||
func (r *result) LastInsertId() (int64, error) {
|
||||
return r.id, nil
|
||||
}
|
||||
|
||||
func (r *result) RowsAffected() (int64, error) {
|
||||
return r.changes, nil
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
res *execResponse
|
||||
currentRowIdx int
|
||||
}
|
||||
|
||||
func (r *rows) Columns() []string {
|
||||
return r.res.columns()
|
||||
}
|
||||
|
||||
func (r *rows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
if r.currentRowIdx == r.res.rowsCount() {
|
||||
return io.EOF
|
||||
}
|
||||
count := r.res.rowLen(r.currentRowIdx)
|
||||
for idx := 0; idx < count; idx++ {
|
||||
v, err := r.res.value(r.currentRowIdx, idx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest[idx] = v
|
||||
}
|
||||
r.currentRowIdx++
|
||||
return nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
ws *websocketConn
|
||||
}
|
||||
|
||||
func Connect(url string, jwt string) (*conn, error) {
|
||||
c, err := connect(url, jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conn{c}, nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
c *conn
|
||||
query string
|
||||
}
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func convertToNamed(args []driver.Value) []driver.NamedValue {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := []driver.NamedValue{}
|
||||
for idx := range args {
|
||||
result = append(result, driver.NamedValue{Ordinal: idx, Value: args[idx]})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), convertToNamed(args))
|
||||
}
|
||||
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), convertToNamed(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
return s.c.ExecContext(ctx, s.query, args)
|
||||
}
|
||||
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
return s.c.QueryContext(ctx, s.query, args)
|
||||
}
|
||||
|
||||
func (c *conn) Ping() error {
|
||||
return c.PingContext(context.Background())
|
||||
}
|
||||
|
||||
func (c *conn) PingContext(ctx context.Context) error {
|
||||
_, err := c.ws.exec(ctx, "SELECT 1", params{}, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
|
||||
return stmt{c, query}, nil
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
return c.ws.Close()
|
||||
}
|
||||
|
||||
type tx struct {
|
||||
c *conn
|
||||
}
|
||||
|
||||
func (t tx) Commit() error {
|
||||
_, err := t.c.ExecContext(context.Background(), "COMMIT", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t tx) Rollback() error {
|
||||
_, err := t.c.ExecContext(context.Background(), "ROLLBACK", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c *conn) BeginTx(ctx context.Context, _ driver.TxOptions) (driver.Tx, error) {
|
||||
_, err := c.ExecContext(ctx, "BEGIN", nil)
|
||||
if err != nil {
|
||||
return tx{nil}, err
|
||||
}
|
||||
return tx{c}, nil
|
||||
}
|
||||
|
||||
func convertArgs(args []driver.NamedValue) params {
|
||||
if len(args) == 0 {
|
||||
return params{}
|
||||
}
|
||||
positionalArgs := [](*driver.NamedValue){}
|
||||
namedArgs := []namedParam{}
|
||||
for idx := range args {
|
||||
if len(args[idx].Name) > 0 {
|
||||
namedArgs = append(namedArgs, namedParam{args[idx].Name, args[idx].Value})
|
||||
} else {
|
||||
positionalArgs = append(positionalArgs, &args[idx])
|
||||
}
|
||||
}
|
||||
sort.Slice(positionalArgs, func(i, j int) bool {
|
||||
return positionalArgs[i].Ordinal < positionalArgs[j].Ordinal
|
||||
})
|
||||
posArgs := [](any){}
|
||||
for idx := range positionalArgs {
|
||||
posArgs = append(posArgs, positionalArgs[idx].Value)
|
||||
}
|
||||
return params{PositinalArgs: posArgs, NamedArgs: namedArgs}
|
||||
}
|
||||
|
||||
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
res, err := c.ws.exec(ctx, query, convertArgs(args), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result{res.lastInsertId(), res.affectedRowCount()}, nil
|
||||
}
|
||||
|
||||
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
res, err := c.ws.exec(ctx, query, convertArgs(args), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rows{res, 0}, nil
|
||||
}
|
||||
322
go_modules/libsql-client-go/libsql/internal/ws/websockets.go
Normal file
322
go_modules/libsql-client-go/libsql/internal/ws/websockets.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
// defaultWSTimeout specifies the timeout used for initial http connection
|
||||
var defaultWSTimeout = 120 * time.Second
|
||||
|
||||
func errorMsg(errorResp interface{}) string {
|
||||
return errorResp.(map[string]interface{})["error"].(map[string]interface{})["message"].(string)
|
||||
}
|
||||
|
||||
func isErrorResp(resp interface{}) bool {
|
||||
return resp.(map[string]interface{})["type"] == "response_error"
|
||||
}
|
||||
|
||||
type websocketConn struct {
|
||||
conn *websocket.Conn
|
||||
idPool *idPool
|
||||
}
|
||||
|
||||
type namedParam struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type params struct {
|
||||
PositinalArgs []any
|
||||
NamedArgs []namedParam
|
||||
}
|
||||
|
||||
func convertValue(v any) (map[string]interface{}, error) {
|
||||
res := map[string]interface{}{}
|
||||
if v == nil {
|
||||
res["type"] = "null"
|
||||
} else if integer, ok := v.(int64); ok {
|
||||
res["type"] = "integer"
|
||||
res["value"] = strconv.FormatInt(integer, 10)
|
||||
} else if text, ok := v.(string); ok {
|
||||
res["type"] = "text"
|
||||
res["value"] = text
|
||||
} else if blob, ok := v.([]byte); ok {
|
||||
res["type"] = "blob"
|
||||
res["base64"] = base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob)
|
||||
} else if float, ok := v.(float64); ok {
|
||||
res["type"] = "float"
|
||||
res["value"] = float
|
||||
} else if boolean, ok := v.(bool); ok {
|
||||
res["type"] = "integer"
|
||||
if boolean {
|
||||
res["value"] = "1"
|
||||
} else {
|
||||
res["value"] = "0"
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported value type: %s", v)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type execResponse struct {
|
||||
resp map[string]interface{}
|
||||
}
|
||||
|
||||
func (r *execResponse) affectedRowCount() int64 {
|
||||
return int64(r.resp["affected_row_count"].(float64))
|
||||
}
|
||||
|
||||
func (r *execResponse) lastInsertId() int64 {
|
||||
id, ok := r.resp["last_insert_rowid"].(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
value, _ := strconv.ParseInt(id, 10, 64)
|
||||
return value
|
||||
}
|
||||
|
||||
func (r *execResponse) columns() []string {
|
||||
res := []string{}
|
||||
cols := r.resp["cols"].([]interface{})
|
||||
for idx := range cols {
|
||||
var v string = ""
|
||||
if cols[idx].(map[string]interface{})["name"] != nil {
|
||||
v = cols[idx].(map[string]interface{})["name"].(string)
|
||||
}
|
||||
res = append(res, v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *execResponse) rowsCount() int {
|
||||
return len(r.resp["rows"].([]interface{}))
|
||||
}
|
||||
|
||||
func (r *execResponse) rowLen(rowIdx int) int {
|
||||
return len(r.resp["rows"].([]interface{})[rowIdx].([]interface{}))
|
||||
}
|
||||
|
||||
func (r *execResponse) value(rowIdx int, colIdx int) (any, error) {
|
||||
val := r.resp["rows"].([]interface{})[rowIdx].([]interface{})[colIdx].(map[string]interface{})
|
||||
switch val["type"] {
|
||||
case "null":
|
||||
return nil, nil
|
||||
case "integer":
|
||||
v, err := strconv.ParseInt(val["value"].(string), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
case "text":
|
||||
return val["value"].(string), nil
|
||||
case "blob":
|
||||
base64Encoded := val["base64"].(string)
|
||||
v, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(base64Encoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
case "float":
|
||||
return val["value"].(float64), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unrecognized value type: %s", val["type"])
|
||||
}
|
||||
|
||||
func (ws *websocketConn) exec(ctx context.Context, sql string, sqlParams params, wantRows bool) (*execResponse, error) {
|
||||
requestId := ws.idPool.Get()
|
||||
defer ws.idPool.Put(requestId)
|
||||
stmt := map[string]interface{}{
|
||||
"sql": sql,
|
||||
"want_rows": wantRows,
|
||||
}
|
||||
if len(sqlParams.PositinalArgs) > 0 {
|
||||
args := []map[string]interface{}{}
|
||||
for idx := range sqlParams.PositinalArgs {
|
||||
v, err := convertValue(sqlParams.PositinalArgs[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
stmt["args"] = args
|
||||
}
|
||||
if len(sqlParams.NamedArgs) > 0 {
|
||||
args := []map[string]interface{}{}
|
||||
for idx := range sqlParams.NamedArgs {
|
||||
v, err := convertValue(sqlParams.NamedArgs[idx].Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arg := map[string]interface{}{
|
||||
"name": sqlParams.NamedArgs[idx].Name,
|
||||
"value": v,
|
||||
}
|
||||
args = append(args, arg)
|
||||
}
|
||||
stmt["named_args"] = args
|
||||
}
|
||||
err := wsjson.Write(ctx, ws.conn, map[string]interface{}{
|
||||
"type": "request",
|
||||
"request_id": requestId,
|
||||
"request": map[string]interface{}{
|
||||
"type": "execute",
|
||||
"stream_id": 0,
|
||||
"stmt": stmt,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
|
||||
}
|
||||
|
||||
var resp interface{}
|
||||
if err = wsjson.Read(ctx, ws.conn, &resp); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
|
||||
}
|
||||
|
||||
if isErrorResp(resp) {
|
||||
err = fmt.Errorf("unable to execute %s: %s", sql, errorMsg(resp))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &execResponse{resp.(map[string]interface{})["response"].(map[string]interface{})["result"].(map[string]interface{})}, nil
|
||||
}
|
||||
|
||||
func (ws *websocketConn) Close() error {
|
||||
return ws.conn.Close(websocket.StatusNormalClosure, "All's good")
|
||||
}
|
||||
|
||||
func connect(url string, jwt string) (*websocketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultWSTimeout)
|
||||
defer cancel()
|
||||
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
Subprotocols: []string{"hrana1"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.SetReadLimit(1024 * 1024 * 16) // 16MB
|
||||
|
||||
err = wsjson.Write(ctx, c, map[string]interface{}{
|
||||
"type": "hello",
|
||||
"jwt": jwt,
|
||||
})
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wsjson.Write(ctx, c, map[string]interface{}{
|
||||
"type": "request",
|
||||
"request_id": 0,
|
||||
"request": map[string]interface{}{
|
||||
"type": "open_stream",
|
||||
"stream_id": 0,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var helloResp interface{}
|
||||
err = wsjson.Read(ctx, c, &helloResp)
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if helloResp.(map[string]interface{})["type"] == "hello_error" {
|
||||
err = fmt.Errorf("handshake error: %s", errorMsg(helloResp))
|
||||
c.Close(websocket.StatusProtocolError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var openStreamResp interface{}
|
||||
err = wsjson.Read(ctx, c, &openStreamResp)
|
||||
if err != nil {
|
||||
c.Close(websocket.StatusInternalError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isErrorResp(openStreamResp) {
|
||||
err = fmt.Errorf("unable to open stream: %s", errorMsg(helloResp))
|
||||
c.Close(websocket.StatusProtocolError, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return &websocketConn{c, newIDPool()}, nil
|
||||
}
|
||||
|
||||
// Below is modified IDPool from "vitess.io/vitess/go/pools"
|
||||
|
||||
// idPool is used to ensure that the set of IDs in use concurrently never
|
||||
// contains any duplicates. The IDs start at 1 and increase without bound, but
|
||||
// will never be larger than the peak number of concurrent uses.
|
||||
//
|
||||
// idPool's Get() and Put() methods can be used concurrently.
|
||||
type idPool struct {
|
||||
sync.Mutex
|
||||
|
||||
// used holds the set of values that have been returned to us with Put().
|
||||
used map[uint32]bool
|
||||
// maxUsed remembers the largest value we've given out.
|
||||
maxUsed uint32
|
||||
}
|
||||
|
||||
// NewIDPool creates and initializes an idPool.
|
||||
func newIDPool() *idPool {
|
||||
return &idPool{
|
||||
used: make(map[uint32]bool),
|
||||
maxUsed: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns an ID that is unique among currently active users of this pool.
|
||||
func (pool *idPool) Get() (id uint32) {
|
||||
pool.Lock()
|
||||
defer pool.Unlock()
|
||||
|
||||
// Pick a value that's been returned, if any.
|
||||
for key := range pool.used {
|
||||
delete(pool.used, key)
|
||||
return key
|
||||
}
|
||||
|
||||
// No recycled IDs are available, so increase the pool size.
|
||||
pool.maxUsed++
|
||||
return pool.maxUsed
|
||||
}
|
||||
|
||||
// Put recycles an ID back into the pool for others to use. Putting back a value
|
||||
// or 0, or a value that is not currently "checked out", will result in a panic
|
||||
// because that should never happen except in the case of a programming error.
|
||||
func (pool *idPool) Put(id uint32) {
|
||||
pool.Lock()
|
||||
defer pool.Unlock()
|
||||
|
||||
if id < 1 || id > pool.maxUsed {
|
||||
panic(fmt.Errorf("idPool.Put(%v): invalid value, must be in the range [1,%v]", id, pool.maxUsed))
|
||||
}
|
||||
|
||||
if pool.used[id] {
|
||||
panic(fmt.Errorf("idPool.Put(%v): can't put value that was already recycled", id))
|
||||
}
|
||||
|
||||
// If we're recycling maxUsed, just shrink the pool.
|
||||
if id == pool.maxUsed {
|
||||
pool.maxUsed = id - 1
|
||||
return
|
||||
}
|
||||
|
||||
// Add it to the set of recycled IDs.
|
||||
pool.used[id] = true
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value any
|
||||
want map[string]any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
value: nil,
|
||||
want: map[string]any{
|
||||
"type": "null",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "integer",
|
||||
value: int64(42),
|
||||
want: map[string]any{
|
||||
"type": "integer",
|
||||
"value": "42",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "text",
|
||||
value: "turso for win",
|
||||
want: map[string]any{
|
||||
"type": "text",
|
||||
"value": "turso for win",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "blob",
|
||||
value: []byte("hello world"),
|
||||
want: map[string]any{
|
||||
"type": "blob",
|
||||
// `hello world` encoded is `aGVsbG8gd29ybGQ=` but we want without padding
|
||||
"base64": "aGVsbG8gd29ybGQ",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
value: 3.14,
|
||||
want: map[string]any{
|
||||
"type": "float",
|
||||
"value": 3.14,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "boolean_true",
|
||||
value: true,
|
||||
want: map[string]any{
|
||||
"type": "integer",
|
||||
"value": "1",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "boolean_false",
|
||||
value: false,
|
||||
want: map[string]any{
|
||||
"type": "integer",
|
||||
"value": "0",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
value: struct{}{},
|
||||
want: nil,
|
||||
err: fmt.Errorf("unsupported value type: %s", struct{}{}),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := convertValue(tt.value)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.err) {
|
||||
t.Errorf("got error %v, want %v", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_execResponse_lastInsertId(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value map[string]interface{}
|
||||
want int64
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
value: map[string]interface{}{"last_insert_rowid": "42"},
|
||||
want: 42,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
value: map[string]interface{}{},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
value: map[string]interface{}{"last_insert_rowid": "invalid"},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid_type",
|
||||
value: map[string]interface{}{"last_insert_rowid": 42.0},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &execResponse{
|
||||
resp: tt.value,
|
||||
}
|
||||
if got := r.lastInsertId(); got != tt.want {
|
||||
t.Errorf("lastInsertId() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user