fix: 将 libsql-client-go 作为普通目录提交,而非 gitlink

This commit is contained in:
2026-04-27 07:04:46 +08:00
parent 2359d6c9fa
commit f6332fbaaf
52 changed files with 42333 additions and 1 deletions

View File

@@ -0,0 +1,194 @@
package ws
import (
"context"
"database/sql/driver"
"io"
"sort"
)
type result struct {
id int64
changes int64
}
func (r *result) LastInsertId() (int64, error) {
return r.id, nil
}
func (r *result) RowsAffected() (int64, error) {
return r.changes, nil
}
type rows struct {
res *execResponse
currentRowIdx int
}
func (r *rows) Columns() []string {
return r.res.columns()
}
func (r *rows) Close() error {
return nil
}
func (r *rows) Next(dest []driver.Value) error {
if r.currentRowIdx == r.res.rowsCount() {
return io.EOF
}
count := r.res.rowLen(r.currentRowIdx)
for idx := 0; idx < count; idx++ {
v, err := r.res.value(r.currentRowIdx, idx)
if err != nil {
return err
}
dest[idx] = v
}
r.currentRowIdx++
return nil
}
type conn struct {
ws *websocketConn
}
func Connect(url string, jwt string) (*conn, error) {
c, err := connect(url, jwt)
if err != nil {
return nil, err
}
return &conn{c}, nil
}
type stmt struct {
c *conn
query string
}
func (s stmt) Close() error {
return nil
}
func (s stmt) NumInput() int {
return -1
}
func convertToNamed(args []driver.Value) []driver.NamedValue {
if len(args) == 0 {
return nil
}
result := []driver.NamedValue{}
for idx := range args {
result = append(result, driver.NamedValue{Ordinal: idx, Value: args[idx]})
}
return result
}
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), convertToNamed(args))
}
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), convertToNamed(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return s.c.ExecContext(ctx, s.query, args)
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return s.c.QueryContext(ctx, s.query, args)
}
func (c *conn) Ping() error {
return c.PingContext(context.Background())
}
func (c *conn) PingContext(ctx context.Context) error {
_, err := c.ws.exec(ctx, "SELECT 1", params{}, false)
return err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return stmt{c, query}, nil
}
func (c *conn) Close() error {
return c.ws.Close()
}
type tx struct {
c *conn
}
func (t tx) Commit() error {
_, err := t.c.ExecContext(context.Background(), "COMMIT", nil)
if err != nil {
return err
}
return nil
}
func (t tx) Rollback() error {
_, err := t.c.ExecContext(context.Background(), "ROLLBACK", nil)
if err != nil {
return err
}
return nil
}
func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c *conn) BeginTx(ctx context.Context, _ driver.TxOptions) (driver.Tx, error) {
_, err := c.ExecContext(ctx, "BEGIN", nil)
if err != nil {
return tx{nil}, err
}
return tx{c}, nil
}
func convertArgs(args []driver.NamedValue) params {
if len(args) == 0 {
return params{}
}
positionalArgs := [](*driver.NamedValue){}
namedArgs := []namedParam{}
for idx := range args {
if len(args[idx].Name) > 0 {
namedArgs = append(namedArgs, namedParam{args[idx].Name, args[idx].Value})
} else {
positionalArgs = append(positionalArgs, &args[idx])
}
}
sort.Slice(positionalArgs, func(i, j int) bool {
return positionalArgs[i].Ordinal < positionalArgs[j].Ordinal
})
posArgs := [](any){}
for idx := range positionalArgs {
posArgs = append(posArgs, positionalArgs[idx].Value)
}
return params{PositinalArgs: posArgs, NamedArgs: namedArgs}
}
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
res, err := c.ws.exec(ctx, query, convertArgs(args), false)
if err != nil {
return nil, err
}
return &result{res.lastInsertId(), res.affectedRowCount()}, nil
}
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
res, err := c.ws.exec(ctx, query, convertArgs(args), true)
if err != nil {
return nil, err
}
return &rows{res, 0}, nil
}

View File

@@ -0,0 +1,322 @@
package ws
import (
"context"
"database/sql/driver"
"encoding/base64"
"fmt"
"strconv"
"sync"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// defaultWSTimeout specifies the timeout used for initial http connection
var defaultWSTimeout = 120 * time.Second
func errorMsg(errorResp interface{}) string {
return errorResp.(map[string]interface{})["error"].(map[string]interface{})["message"].(string)
}
func isErrorResp(resp interface{}) bool {
return resp.(map[string]interface{})["type"] == "response_error"
}
type websocketConn struct {
conn *websocket.Conn
idPool *idPool
}
type namedParam struct {
Name string
Value any
}
type params struct {
PositinalArgs []any
NamedArgs []namedParam
}
func convertValue(v any) (map[string]interface{}, error) {
res := map[string]interface{}{}
if v == nil {
res["type"] = "null"
} else if integer, ok := v.(int64); ok {
res["type"] = "integer"
res["value"] = strconv.FormatInt(integer, 10)
} else if text, ok := v.(string); ok {
res["type"] = "text"
res["value"] = text
} else if blob, ok := v.([]byte); ok {
res["type"] = "blob"
res["base64"] = base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob)
} else if float, ok := v.(float64); ok {
res["type"] = "float"
res["value"] = float
} else if boolean, ok := v.(bool); ok {
res["type"] = "integer"
if boolean {
res["value"] = "1"
} else {
res["value"] = "0"
}
} else {
return nil, fmt.Errorf("unsupported value type: %s", v)
}
return res, nil
}
type execResponse struct {
resp map[string]interface{}
}
func (r *execResponse) affectedRowCount() int64 {
return int64(r.resp["affected_row_count"].(float64))
}
func (r *execResponse) lastInsertId() int64 {
id, ok := r.resp["last_insert_rowid"].(string)
if !ok {
return 0
}
value, _ := strconv.ParseInt(id, 10, 64)
return value
}
func (r *execResponse) columns() []string {
res := []string{}
cols := r.resp["cols"].([]interface{})
for idx := range cols {
var v string = ""
if cols[idx].(map[string]interface{})["name"] != nil {
v = cols[idx].(map[string]interface{})["name"].(string)
}
res = append(res, v)
}
return res
}
func (r *execResponse) rowsCount() int {
return len(r.resp["rows"].([]interface{}))
}
func (r *execResponse) rowLen(rowIdx int) int {
return len(r.resp["rows"].([]interface{})[rowIdx].([]interface{}))
}
func (r *execResponse) value(rowIdx int, colIdx int) (any, error) {
val := r.resp["rows"].([]interface{})[rowIdx].([]interface{})[colIdx].(map[string]interface{})
switch val["type"] {
case "null":
return nil, nil
case "integer":
v, err := strconv.ParseInt(val["value"].(string), 10, 64)
if err != nil {
return nil, err
}
return v, nil
case "text":
return val["value"].(string), nil
case "blob":
base64Encoded := val["base64"].(string)
v, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(base64Encoded)
if err != nil {
return nil, err
}
return v, nil
case "float":
return val["value"].(float64), nil
}
return nil, fmt.Errorf("unrecognized value type: %s", val["type"])
}
func (ws *websocketConn) exec(ctx context.Context, sql string, sqlParams params, wantRows bool) (*execResponse, error) {
requestId := ws.idPool.Get()
defer ws.idPool.Put(requestId)
stmt := map[string]interface{}{
"sql": sql,
"want_rows": wantRows,
}
if len(sqlParams.PositinalArgs) > 0 {
args := []map[string]interface{}{}
for idx := range sqlParams.PositinalArgs {
v, err := convertValue(sqlParams.PositinalArgs[idx])
if err != nil {
return nil, err
}
args = append(args, v)
}
stmt["args"] = args
}
if len(sqlParams.NamedArgs) > 0 {
args := []map[string]interface{}{}
for idx := range sqlParams.NamedArgs {
v, err := convertValue(sqlParams.NamedArgs[idx].Value)
if err != nil {
return nil, err
}
arg := map[string]interface{}{
"name": sqlParams.NamedArgs[idx].Name,
"value": v,
}
args = append(args, arg)
}
stmt["named_args"] = args
}
err := wsjson.Write(ctx, ws.conn, map[string]interface{}{
"type": "request",
"request_id": requestId,
"request": map[string]interface{}{
"type": "execute",
"stream_id": 0,
"stmt": stmt,
},
})
if err != nil {
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
}
var resp interface{}
if err = wsjson.Read(ctx, ws.conn, &resp); err != nil {
return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error())
}
if isErrorResp(resp) {
err = fmt.Errorf("unable to execute %s: %s", sql, errorMsg(resp))
return nil, err
}
return &execResponse{resp.(map[string]interface{})["response"].(map[string]interface{})["result"].(map[string]interface{})}, nil
}
func (ws *websocketConn) Close() error {
return ws.conn.Close(websocket.StatusNormalClosure, "All's good")
}
func connect(url string, jwt string) (*websocketConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultWSTimeout)
defer cancel()
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
Subprotocols: []string{"hrana1"},
})
if err != nil {
return nil, err
}
c.SetReadLimit(1024 * 1024 * 16) // 16MB
err = wsjson.Write(ctx, c, map[string]interface{}{
"type": "hello",
"jwt": jwt,
})
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
err = wsjson.Write(ctx, c, map[string]interface{}{
"type": "request",
"request_id": 0,
"request": map[string]interface{}{
"type": "open_stream",
"stream_id": 0,
},
})
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
var helloResp interface{}
err = wsjson.Read(ctx, c, &helloResp)
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
if helloResp.(map[string]interface{})["type"] == "hello_error" {
err = fmt.Errorf("handshake error: %s", errorMsg(helloResp))
c.Close(websocket.StatusProtocolError, err.Error())
return nil, err
}
var openStreamResp interface{}
err = wsjson.Read(ctx, c, &openStreamResp)
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return nil, err
}
if isErrorResp(openStreamResp) {
err = fmt.Errorf("unable to open stream: %s", errorMsg(helloResp))
c.Close(websocket.StatusProtocolError, err.Error())
return nil, err
}
return &websocketConn{c, newIDPool()}, nil
}
// Below is modified IDPool from "vitess.io/vitess/go/pools"
// idPool is used to ensure that the set of IDs in use concurrently never
// contains any duplicates. The IDs start at 1 and increase without bound, but
// will never be larger than the peak number of concurrent uses.
//
// idPool's Get() and Put() methods can be used concurrently.
type idPool struct {
sync.Mutex
// used holds the set of values that have been returned to us with Put().
used map[uint32]bool
// maxUsed remembers the largest value we've given out.
maxUsed uint32
}
// NewIDPool creates and initializes an idPool.
func newIDPool() *idPool {
return &idPool{
used: make(map[uint32]bool),
maxUsed: 0,
}
}
// Get returns an ID that is unique among currently active users of this pool.
func (pool *idPool) Get() (id uint32) {
pool.Lock()
defer pool.Unlock()
// Pick a value that's been returned, if any.
for key := range pool.used {
delete(pool.used, key)
return key
}
// No recycled IDs are available, so increase the pool size.
pool.maxUsed++
return pool.maxUsed
}
// Put recycles an ID back into the pool for others to use. Putting back a value
// or 0, or a value that is not currently "checked out", will result in a panic
// because that should never happen except in the case of a programming error.
func (pool *idPool) Put(id uint32) {
pool.Lock()
defer pool.Unlock()
if id < 1 || id > pool.maxUsed {
panic(fmt.Errorf("idPool.Put(%v): invalid value, must be in the range [1,%v]", id, pool.maxUsed))
}
if pool.used[id] {
panic(fmt.Errorf("idPool.Put(%v): can't put value that was already recycled", id))
}
// If we're recycling maxUsed, just shrink the pool.
if id == pool.maxUsed {
pool.maxUsed = id - 1
return
}
// Add it to the set of recycled IDs.
pool.used[id] = true
}

View File

@@ -0,0 +1,136 @@
package ws
import (
"fmt"
"reflect"
"testing"
)
func TestConvertValue(t *testing.T) {
tests := []struct {
name string
value any
want map[string]any
err error
}{
{
name: "nil",
value: nil,
want: map[string]any{
"type": "null",
},
err: nil,
},
{
name: "integer",
value: int64(42),
want: map[string]any{
"type": "integer",
"value": "42",
},
err: nil,
},
{
name: "text",
value: "turso for win",
want: map[string]any{
"type": "text",
"value": "turso for win",
},
err: nil,
},
{
name: "blob",
value: []byte("hello world"),
want: map[string]any{
"type": "blob",
// `hello world` encoded is `aGVsbG8gd29ybGQ=` but we want without padding
"base64": "aGVsbG8gd29ybGQ",
},
err: nil,
},
{
name: "float",
value: 3.14,
want: map[string]any{
"type": "float",
"value": 3.14,
},
err: nil,
},
{
name: "boolean_true",
value: true,
want: map[string]any{
"type": "integer",
"value": "1",
},
err: nil,
},
{
name: "boolean_false",
value: false,
want: map[string]any{
"type": "integer",
"value": "0",
},
err: nil,
},
{
name: "unsupported",
value: struct{}{},
want: nil,
err: fmt.Errorf("unsupported value type: %s", struct{}{}),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertValue(tt.value)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v, want %v", got, tt.want)
}
if !reflect.DeepEqual(err, tt.err) {
t.Errorf("got error %v, want %v", err, tt.err)
}
})
}
}
func Test_execResponse_lastInsertId(t *testing.T) {
tests := []struct {
name string
value map[string]interface{}
want int64
}{
{
name: "valid",
value: map[string]interface{}{"last_insert_rowid": "42"},
want: 42,
},
{
name: "empty",
value: map[string]interface{}{},
want: 0,
},
{
name: "invalid",
value: map[string]interface{}{"last_insert_rowid": "invalid"},
want: 0,
},
{
name: "invalid_type",
value: map[string]interface{}{"last_insert_rowid": 42.0},
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &execResponse{
resp: tt.value,
}
if got := r.lastInsertId(); got != tt.want {
t.Errorf("lastInsertId() = %v, want %v", got, tt.want)
}
})
}
}