195 lines
4.1 KiB
Go
195 lines
4.1 KiB
Go
package ws
|
|
|
|
import (
|
|
"context"
|
|
"database/sql/driver"
|
|
"io"
|
|
"sort"
|
|
)
|
|
|
|
// result is the driver.Result handed back by conn.ExecContext. Both values are
// captured once from the exec response, so the accessors below never fail.
type result struct {
	id, changes int64
}

// LastInsertId returns the row ID captured when the statement ran. The error
// is always nil.
func (r *result) LastInsertId() (int64, error) {
	id := r.id
	return id, nil
}

// RowsAffected returns the affected-row count captured when the statement
// ran. The error is always nil.
func (r *result) RowsAffected() (int64, error) {
	n := r.changes
	return n, nil
}
|
|
|
|
type rows struct {
|
|
res *execResponse
|
|
currentRowIdx int
|
|
}
|
|
|
|
func (r *rows) Columns() []string {
|
|
return r.res.columns()
|
|
}
|
|
|
|
func (r *rows) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (r *rows) Next(dest []driver.Value) error {
|
|
if r.currentRowIdx == r.res.rowsCount() {
|
|
return io.EOF
|
|
}
|
|
count := r.res.rowLen(r.currentRowIdx)
|
|
for idx := 0; idx < count; idx++ {
|
|
v, err := r.res.value(r.currentRowIdx, idx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dest[idx] = v
|
|
}
|
|
r.currentRowIdx++
|
|
return nil
|
|
}
|
|
|
|
type conn struct {
|
|
ws *websocketConn
|
|
}
|
|
|
|
func Connect(url string, jwt string) (*conn, error) {
|
|
c, err := connect(url, jwt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &conn{c}, nil
|
|
}
|
|
|
|
type stmt struct {
|
|
c *conn
|
|
query string
|
|
}
|
|
|
|
func (s stmt) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (s stmt) NumInput() int {
|
|
return -1
|
|
}
|
|
|
|
// convertToNamed adapts the legacy positional []driver.Value arguments
// (used by stmt.Exec and stmt.Query) to the []driver.NamedValue form taken by
// the context-aware methods.
//
// Every result has an empty Name, so convertArgs treats it as positional.
// Ordinals are 1-based, as the driver.NamedValue contract requires
// ("Ordinal position of the parameter starting from one"). A nil or empty
// args returns nil.
func convertToNamed(args []driver.Value) []driver.NamedValue {
	if len(args) == 0 {
		return nil
	}
	named := make([]driver.NamedValue, len(args))
	for i, v := range args {
		named[i] = driver.NamedValue{Ordinal: i + 1, Value: v}
	}
	return named
}
|
|
|
|
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
return s.ExecContext(context.Background(), convertToNamed(args))
|
|
}
|
|
|
|
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
return s.QueryContext(context.Background(), convertToNamed(args))
|
|
}
|
|
|
|
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
|
return s.c.ExecContext(ctx, s.query, args)
|
|
}
|
|
|
|
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
return s.c.QueryContext(ctx, s.query, args)
|
|
}
|
|
|
|
func (c *conn) Ping() error {
|
|
return c.PingContext(context.Background())
|
|
}
|
|
|
|
func (c *conn) PingContext(ctx context.Context) error {
|
|
_, err := c.ws.exec(ctx, "SELECT 1", params{}, false)
|
|
return err
|
|
}
|
|
|
|
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
|
return c.PrepareContext(context.Background(), query)
|
|
}
|
|
|
|
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
|
|
return stmt{c, query}, nil
|
|
}
|
|
|
|
func (c *conn) Close() error {
|
|
return c.ws.Close()
|
|
}
|
|
|
|
type tx struct {
|
|
c *conn
|
|
}
|
|
|
|
func (t tx) Commit() error {
|
|
_, err := t.c.ExecContext(context.Background(), "COMMIT", nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t tx) Rollback() error {
|
|
_, err := t.c.ExecContext(context.Background(), "ROLLBACK", nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) Begin() (driver.Tx, error) {
|
|
return c.BeginTx(context.Background(), driver.TxOptions{})
|
|
}
|
|
|
|
func (c *conn) BeginTx(ctx context.Context, _ driver.TxOptions) (driver.Tx, error) {
|
|
_, err := c.ExecContext(ctx, "BEGIN", nil)
|
|
if err != nil {
|
|
return tx{nil}, err
|
|
}
|
|
return tx{c}, nil
|
|
}
|
|
|
|
func convertArgs(args []driver.NamedValue) params {
|
|
if len(args) == 0 {
|
|
return params{}
|
|
}
|
|
positionalArgs := [](*driver.NamedValue){}
|
|
namedArgs := []namedParam{}
|
|
for idx := range args {
|
|
if len(args[idx].Name) > 0 {
|
|
namedArgs = append(namedArgs, namedParam{args[idx].Name, args[idx].Value})
|
|
} else {
|
|
positionalArgs = append(positionalArgs, &args[idx])
|
|
}
|
|
}
|
|
sort.Slice(positionalArgs, func(i, j int) bool {
|
|
return positionalArgs[i].Ordinal < positionalArgs[j].Ordinal
|
|
})
|
|
posArgs := [](any){}
|
|
for idx := range positionalArgs {
|
|
posArgs = append(posArgs, positionalArgs[idx].Value)
|
|
}
|
|
return params{PositinalArgs: posArgs, NamedArgs: namedArgs}
|
|
}
|
|
|
|
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
|
res, err := c.ws.exec(ctx, query, convertArgs(args), false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &result{res.lastInsertId(), res.affectedRowCount()}, nil
|
|
}
|
|
|
|
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
|
res, err := c.ws.exec(ctx, query, convertArgs(args), true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &rows{res, 0}, nil
|
|
}
|