fix: 将 libsql-client-go 作为普通目录提交,而非 gitlink
This commit is contained in:
562
go_modules/libsql-client-go/tests/http/driver_test.go
Normal file
562
go_modules/libsql-client-go/tests/http/driver_test.go
Normal file
@@ -0,0 +1,562 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tursodatabase/libsql-client-go/libsql"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
_ "github.com/tursodatabase/libsql-client-go/libsql"
|
||||
)
|
||||
|
||||
type T struct {
|
||||
*testing.T
|
||||
}
|
||||
|
||||
func (t T) FatalWithMsg(msg string) {
|
||||
t.Log(string(debug.Stack()))
|
||||
t.Fatal(msg)
|
||||
}
|
||||
|
||||
func (t T) FatalOnError(err error) {
|
||||
if err != nil {
|
||||
t.Log(string(debug.Stack()))
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
*sql.DB
|
||||
t T
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func getDb(t T) Database {
|
||||
dbURL := os.Getenv("LIBSQL_TEST_HTTP_DB_URL")
|
||||
authToken := os.Getenv("LIBSQL_TEST_HTTP_AUTH_TOKEN")
|
||||
var connector driver.Connector
|
||||
var err error
|
||||
if authToken == "" {
|
||||
connector, err = libsql.NewConnector(dbURL)
|
||||
} else {
|
||||
connector, err = libsql.NewConnector(dbURL, libsql.WithAuthToken(authToken))
|
||||
}
|
||||
t.FatalOnError(err)
|
||||
db := sql.OpenDB(connector)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
cancel()
|
||||
})
|
||||
|
||||
return Database{db, t, ctx}
|
||||
}
|
||||
|
||||
func (db Database) exec(sql string, args ...any) sql.Result {
|
||||
res, err := db.ExecContext(db.ctx, sql, args...)
|
||||
db.t.FatalOnError(err)
|
||||
return res
|
||||
}
|
||||
|
||||
func (db Database) query(sql string, args ...any) *sql.Rows {
|
||||
rows, err := db.QueryContext(db.ctx, sql, args...)
|
||||
db.t.FatalOnError(err)
|
||||
return rows
|
||||
}
|
||||
|
||||
type Table struct {
|
||||
name string
|
||||
db Database
|
||||
}
|
||||
|
||||
func (db Database) createTable() Table {
|
||||
name := "test_" + fmt.Sprint(rand.Int()) + "_" + time.Now().Format("20060102150405")
|
||||
db.exec("CREATE TABLE " + name + " (a int, b int)")
|
||||
db.t.Cleanup(func() {
|
||||
db.exec("DROP TABLE " + name)
|
||||
})
|
||||
return Table{name, db}
|
||||
}
|
||||
|
||||
func (db Database) assertTable(name string) {
|
||||
rows, err := db.QueryContext(db.ctx, "select 1 from "+name)
|
||||
db.t.FatalOnError(err)
|
||||
defer rows.Close()
|
||||
}
|
||||
|
||||
func (t Table) insertRows(start, count int) {
|
||||
t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.db.exec("INSERT INTO " + t.name + " (a, b) VALUES (" + fmt.Sprint(i) + ", " + fmt.Sprint(i) + ")")
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) insertRowsWithArgs(start, count int) {
|
||||
t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.db.exec("INSERT INTO "+t.name+" (a, b) VALUES (?, ?)", i, i)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) insertRowsWithColonNamedArgs(start, count int) {
|
||||
t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.db.exec("INSERT INTO "+t.name+" (a, b) VALUES (:Arg1, :Arg2)", sql.Named("Arg1", i), sql.Named("Arg2", i))
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) insertRowsWithAtNamedArgs(start, count int) {
|
||||
t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.db.exec("INSERT INTO "+t.name+" (a, b) VALUES (@Arg1, @Arg2)", sql.Named("Arg1", i), sql.Named("Arg2", i))
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) insertRowsWithDollarNamedArgs(start, count int) {
|
||||
t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.db.exec("INSERT INTO "+t.name+" (a, b) VALUES ($Arg1, $Arg2)", sql.Named("Arg1", i), sql.Named("Arg2", i))
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) insertRowsInternal(start, count int, execFn func(i int) sql.Result) {
|
||||
for i := 0; i < count; i++ {
|
||||
res := execFn(i + start)
|
||||
affected, err := res.RowsAffected()
|
||||
t.db.t.FatalOnError(err)
|
||||
if affected != 1 {
|
||||
t.db.t.FatalWithMsg("expected 1 row affected")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t Table) assertRowsCount(count int) {
|
||||
t.assertCount(count, func() *sql.Rows {
|
||||
return t.db.query("SELECT COUNT(*) FROM " + t.name)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) assertRowDoesNotExist(id int) {
|
||||
t.assertCount(0, func() *sql.Rows {
|
||||
return t.db.query("SELECT COUNT(*) FROM "+t.name+" WHERE a = ?", id)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) assertRowExists(id int) {
|
||||
t.assertCount(1, func() *sql.Rows {
|
||||
return t.db.query("SELECT COUNT(*) FROM "+t.name+" WHERE a = ?", id)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Table) assertCount(expectedCount int, queryFn func() *sql.Rows) {
|
||||
rows := queryFn()
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
t.db.t.FatalWithMsg("expected at least one row")
|
||||
}
|
||||
var rowCount int
|
||||
t.db.t.FatalOnError(rows.Scan(&rowCount))
|
||||
if rowCount != expectedCount {
|
||||
t.db.t.FatalWithMsg(fmt.Sprintf("expected %d rows, got %d", expectedCount, rowCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (t Table) beginTx() Tx {
|
||||
tx, err := t.db.BeginTx(t.db.ctx, nil)
|
||||
t.db.t.FatalOnError(err)
|
||||
return Tx{tx, t, nil}
|
||||
}
|
||||
|
||||
func (t Table) beginTxWithContext(ctx context.Context) Tx {
|
||||
tx, err := t.db.BeginTx(ctx, nil)
|
||||
t.db.t.FatalOnError(err)
|
||||
return Tx{tx, t, &ctx}
|
||||
}
|
||||
|
||||
func (t Table) prepareInsertStmt() PreparedStmt {
|
||||
stmt, err := t.db.Prepare("INSERT INTO " + t.name + " (a, b) VALUES (?, ?)")
|
||||
t.db.t.FatalOnError(err)
|
||||
return PreparedStmt{stmt, t}
|
||||
}
|
||||
|
||||
type PreparedStmt struct {
|
||||
*sql.Stmt
|
||||
t Table
|
||||
}
|
||||
|
||||
func (s PreparedStmt) exec(args ...any) sql.Result {
|
||||
res, err := s.ExecContext(s.t.db.ctx, args...)
|
||||
s.t.db.t.FatalOnError(err)
|
||||
return res
|
||||
}
|
||||
|
||||
type Tx struct {
|
||||
*sql.Tx
|
||||
t Table
|
||||
ctx *context.Context
|
||||
}
|
||||
|
||||
func (t Tx) context() context.Context {
|
||||
if t.ctx != nil {
|
||||
return *t.ctx
|
||||
}
|
||||
return t.t.db.ctx
|
||||
}
|
||||
|
||||
func (t Tx) exec(sql string, args ...any) sql.Result {
|
||||
res, err := t.ExecContext(t.context(), sql, args...)
|
||||
t.t.db.t.FatalOnError(err)
|
||||
return res
|
||||
}
|
||||
|
||||
func (t Tx) query(sql string, args ...any) *sql.Rows {
|
||||
rows, err := t.QueryContext(t.context(), sql, args...)
|
||||
t.t.db.t.FatalOnError(err)
|
||||
return rows
|
||||
}
|
||||
|
||||
func (t Tx) insertRows(start, count int) {
|
||||
t.t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.exec("INSERT INTO " + t.t.name + " (a, b) VALUES (" + fmt.Sprint(i) + ", '" + fmt.Sprint(i) + "')")
|
||||
})
|
||||
}
|
||||
|
||||
func (t Tx) insertRowsWithArgs(start, count int) {
|
||||
t.t.insertRowsInternal(start, count, func(i int) sql.Result {
|
||||
return t.exec("INSERT INTO "+t.t.name+" (a, b) VALUES (?, ?)", i, fmt.Sprint(i))
|
||||
})
|
||||
}
|
||||
|
||||
func (t Tx) assertRowsCount(count int) {
|
||||
t.t.assertCount(count, func() *sql.Rows {
|
||||
return t.query("SELECT COUNT(*) FROM " + t.t.name)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Tx) assertRowDoesNotExist(id int) {
|
||||
t.t.assertCount(0, func() *sql.Rows {
|
||||
return t.query("SELECT COUNT(*) FROM "+t.t.name+" WHERE a = ?", id)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Tx) assertRowExists(id int) {
|
||||
t.t.assertCount(1, func() *sql.Rows {
|
||||
return t.query("SELECT COUNT(*) FROM "+t.t.name+" WHERE a = ?", id)
|
||||
})
|
||||
}
|
||||
|
||||
func (t Tx) prepareInsertStmt() PreparedStmt {
|
||||
stmt, err := t.Prepare("INSERT INTO " + t.t.name + " (a, b) VALUES (?, ?)")
|
||||
t.t.db.t.FatalOnError(err)
|
||||
return PreparedStmt{stmt, t.t}
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
|
||||
// This ping should succeed because the database is up and running
|
||||
db.t.FatalOnError(db.Ping())
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
|
||||
// This ping should return an error because the database is already closed
|
||||
err := db.Ping()
|
||||
if err == nil {
|
||||
db.t.Fatal("db.Ping succeeded when it should have failed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExecAndQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
table.insertRows(0, 10)
|
||||
table.insertRowsWithArgs(10, 10)
|
||||
table.assertRowsCount(20)
|
||||
table.assertRowDoesNotExist(20)
|
||||
table.assertRowExists(0)
|
||||
table.assertRowExists(19)
|
||||
}
|
||||
|
||||
func TestPreparedStatements(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
stmt := table.prepareInsertStmt()
|
||||
stmt.exec(1, "1")
|
||||
db.t.FatalOnError(stmt.Close())
|
||||
table.assertRowsCount(1)
|
||||
table.assertRowExists(1)
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
tx := table.beginTx()
|
||||
tx.insertRows(0, 10)
|
||||
tx.insertRowsWithArgs(10, 10)
|
||||
tx.assertRowsCount(20)
|
||||
tx.assertRowDoesNotExist(20)
|
||||
tx.assertRowExists(0)
|
||||
tx.assertRowExists(19)
|
||||
db.t.FatalOnError(tx.Commit())
|
||||
table.assertRowsCount(20)
|
||||
table.assertRowDoesNotExist(20)
|
||||
table.assertRowExists(0)
|
||||
table.assertRowExists(19)
|
||||
}
|
||||
|
||||
func TestMultiLineStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
db.exec("CREATE TABLE IF NOT EXISTS my_table (my_data TEXT); INSERT INTO my_table (my_data) VALUES ('hello');")
|
||||
t.Cleanup(func() {
|
||||
db.exec("DROP TABLE my_table")
|
||||
})
|
||||
table := Table{"my_table", db}
|
||||
db.assertTable("my_table")
|
||||
table.assertRowsCount(1)
|
||||
}
|
||||
|
||||
func TestPreparedStatementInTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
tx := table.beginTx()
|
||||
stmt := tx.prepareInsertStmt()
|
||||
stmt.exec(1, "1")
|
||||
db.t.FatalOnError(stmt.Close())
|
||||
tx.assertRowsCount(1)
|
||||
tx.assertRowExists(1)
|
||||
db.t.FatalOnError(tx.Commit())
|
||||
table.assertRowsCount(1)
|
||||
table.assertRowExists(1)
|
||||
}
|
||||
|
||||
func TestPreparedStatementInTransactionRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
tx := table.beginTx()
|
||||
stmt := tx.prepareInsertStmt()
|
||||
stmt.exec(1, "1")
|
||||
db.t.FatalOnError(stmt.Close())
|
||||
tx.assertRowsCount(1)
|
||||
tx.assertRowExists(1)
|
||||
db.t.FatalOnError(tx.Rollback())
|
||||
table.assertRowsCount(0)
|
||||
table.assertRowDoesNotExist(1)
|
||||
}
|
||||
|
||||
func TestTransactionRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
tx := table.beginTx()
|
||||
tx.insertRows(0, 10)
|
||||
tx.insertRowsWithArgs(10, 10)
|
||||
tx.assertRowsCount(20)
|
||||
tx.assertRowDoesNotExist(20)
|
||||
tx.assertRowExists(0)
|
||||
tx.assertRowExists(19)
|
||||
db.t.FatalOnError(tx.Rollback())
|
||||
table.assertRowsCount(0)
|
||||
}
|
||||
|
||||
func TestCancelContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
if err == nil {
|
||||
db.t.FatalWithMsg("should have failed")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
db.t.FatalWithMsg("should have failed with context.Canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelContextWithTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tx := table.beginTxWithContext(ctx)
|
||||
tx.insertRows(0, 10)
|
||||
tx.insertRowsWithArgs(10, 10)
|
||||
tx.assertRowsCount(20)
|
||||
tx.assertRowDoesNotExist(20)
|
||||
tx.assertRowExists(0)
|
||||
tx.assertRowExists(19)
|
||||
// let's cancel the context before the commit
|
||||
cancel()
|
||||
err := tx.Commit()
|
||||
if err == nil {
|
||||
db.t.FatalWithMsg("should have failed")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
db.t.FatalWithMsg("should have failed with context.Canceled")
|
||||
}
|
||||
// rolling back the transaction should not result in any error
|
||||
db.t.FatalOnError(tx.Rollback())
|
||||
}
|
||||
|
||||
func TestDataTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
var (
|
||||
text string
|
||||
nullText sql.NullString
|
||||
integer sql.NullInt64
|
||||
nullInteger sql.NullInt64
|
||||
boolean bool
|
||||
float8 float64
|
||||
nullFloat sql.NullFloat64
|
||||
bytea []byte
|
||||
)
|
||||
db.t.FatalOnError(db.QueryRowContext(db.ctx, "SELECT 'foobar' as text, NULL as text, NULL as integer, 42 as integer, 1 as boolean, X'000102' as bytea, 3.14 as float8, NULL as float8;").Scan(&text, &nullText, &nullInteger, &integer, &boolean, &bytea, &float8, &nullFloat))
|
||||
switch {
|
||||
case text != "foobar":
|
||||
t.Error("value mismatch - text")
|
||||
case nullText.Valid:
|
||||
t.Error("null text is valid")
|
||||
case nullInteger.Valid:
|
||||
t.Error("null integer is valid")
|
||||
case !integer.Valid:
|
||||
t.Error("integer is not valid")
|
||||
case integer.Int64 != 42:
|
||||
t.Error("value mismatch - integer")
|
||||
case !boolean:
|
||||
t.Error("value mismatch - boolean")
|
||||
case float8 != 3.14:
|
||||
t.Error("value mismatch - float8")
|
||||
case !bytes.Equal(bytea, []byte{0, 1, 2}):
|
||||
t.Error("value mismatch - bytea")
|
||||
case nullFloat.Valid:
|
||||
t.Error("null float is valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentOnSingleConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
t1 := db.createTable()
|
||||
t2 := db.createTable()
|
||||
t3 := db.createTable()
|
||||
t1.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t1.db.exec("INSERT INTO "+t1.name+" VALUES(?, ?)", i, i)
|
||||
})
|
||||
t2.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t2.db.exec("INSERT INTO "+t2.name+" VALUES(?, ?)", i, -1*i)
|
||||
})
|
||||
t3.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t3.db.exec("INSERT INTO "+t3.name+" VALUES(?, ?)", i, 0)
|
||||
})
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
conn, err := db.Conn(context.Background())
|
||||
db.t.FatalOnError(err)
|
||||
defer conn.Close()
|
||||
worker := func(t Table, check func(int) error) func() error {
|
||||
return func() error {
|
||||
for i := 1; i < 100; i++ {
|
||||
// Each iteration is wrapped into a function to make sure that `defer rows.Close()`
|
||||
// is called after each iteration not at the end of the outer function
|
||||
err := func() error {
|
||||
rows, err := conn.QueryContext(ctx, "SELECT b FROM "+t.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(debug.Stack()))
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var v int
|
||||
err := rows.Scan(&v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(debug.Stack()))
|
||||
}
|
||||
if err := check(v); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(debug.Stack()))
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
g.Go(worker(t1, func(v int) error {
|
||||
if v <= 0 {
|
||||
return fmt.Errorf("got non-positive value from table1: %d", v)
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
g.Go(worker(t2, func(v int) error {
|
||||
if v >= 0 {
|
||||
return fmt.Errorf("got non-negative value from table2: %d", v)
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
g.Go(worker(t3, func(v int) error {
|
||||
if v != 0 {
|
||||
return fmt.Errorf("got non-zero value from table3: %d", v)
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
db.t.FatalOnError(g.Wait())
|
||||
}
|
||||
|
||||
func TestParameterSyntaxSupport(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
t1 := db.createTable()
|
||||
t2 := db.createTable()
|
||||
t3 := db.createTable()
|
||||
t4 := db.createTable()
|
||||
t5 := db.createTable()
|
||||
|
||||
t1.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t1.db.exec("INSERT INTO "+t1.name+" VALUES(?1, ?1)", i)
|
||||
})
|
||||
t2.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t2.db.exec("INSERT INTO "+t2.name+" VALUES(?, ?)", i, -1*i)
|
||||
})
|
||||
t3.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t3.db.exec("INSERT INTO "+t3.name+" VALUES(:a, :b)", sql.Named("a", i), sql.Named("b", 0))
|
||||
})
|
||||
t4.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t3.db.exec("INSERT INTO "+t4.name+" VALUES(@a, @a)", sql.Named("a", i), sql.Named("b", 0))
|
||||
})
|
||||
t5.insertRowsInternal(1, 10, func(i int) sql.Result {
|
||||
return t3.db.exec("INSERT INTO "+t5.name+" VALUES($a, $b)", sql.Named("a", i), sql.Named("b", 0))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNamedArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := getDb(T{t})
|
||||
table := db.createTable()
|
||||
table.insertRowsWithColonNamedArgs(0, 10)
|
||||
table.insertRowsWithAtNamedArgs(10, 10)
|
||||
table.insertRowsWithDollarNamedArgs(20, 10)
|
||||
table.assertRowsCount(30)
|
||||
table.assertRowDoesNotExist(30)
|
||||
table.assertRowExists(0)
|
||||
table.assertRowExists(29)
|
||||
}
|
||||
408
go_modules/libsql-client-go/tests/ws/driver_test.go
Normal file
408
go_modules/libsql-client-go/tests/ws/driver_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/tursodatabase/libsql-client-go/libsql"
|
||||
|
||||
_ "github.com/tursodatabase/libsql-client-go/libsql"
|
||||
)
|
||||
|
||||
// setupDB sets up a test database by connecting to libsql server and creates a `test` table
|
||||
func setupDB(ctx context.Context, t *testing.T) *sql.DB {
|
||||
dbURL := os.Getenv("LIBSQL_TEST_WS_DB_URL")
|
||||
authToken := os.Getenv("LIBSQL_TEST_WS_AUTH_TOKEN")
|
||||
var connector driver.Connector
|
||||
var err error
|
||||
if authToken == "" {
|
||||
connector, err = libsql.NewConnector(dbURL)
|
||||
} else {
|
||||
connector, err = libsql.NewConnector(dbURL, libsql.WithAuthToken(authToken))
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := sql.OpenDB(connector)
|
||||
_, err = db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func cleanupDB(ctx context.Context, t *testing.T, db *sql.DB) {
|
||||
_, err := db.ExecContext(ctx, "DROP TABLE test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func assertRows(ctx context.Context, t *testing.T, db *sql.DB) {
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var expectedId int64 = 1
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var name string
|
||||
err := rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != expectedId {
|
||||
t.Fatalf("id should be %d", expectedId)
|
||||
}
|
||||
if name != "hello world" {
|
||||
t.Fatal("name should be hello world")
|
||||
}
|
||||
expectedId++
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
|
||||
// This ping should succeed because the database is up and running
|
||||
err := db.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
// This ping should return an error because the database is already closed
|
||||
err = db.PingContext(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("db.Ping succeeded when it should have failed")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestQueryExec(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
_, err := db.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreparedStatements(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
stmt, err := db.PrepareContext(ctx, "INSERT INTO test (name) VALUES (?)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err := stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
_, err = stmt.ExecContext(ctx, "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreparedStatementsWithTransactions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stmt, err := tx.PrepareContext(ctx, "INSERT INTO test (name) VALUES (?)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
func() { // needed to surpress linter
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionsRollback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var count int
|
||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM test").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("count should be 0")
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreparedStatementsWithTransactionsRollback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stmt, err := tx.PrepareContext(ctx, "INSERT INTO test (name) VALUES (?)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
func() { // needed to surpress linter
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var count int
|
||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM test").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("count should be 0")
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCancelContext(t *testing.T) {
|
||||
db := setupDB(context.Background(), t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
if err == nil {
|
||||
t.Fatal("should have failed")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("should have failed with context.Canceled")
|
||||
}
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func TestCancelTransactionWithContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db := setupDB(ctx, t)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// let's cancel the context before the commit
|
||||
cancel()
|
||||
err = tx.Commit()
|
||||
if err == nil {
|
||||
t.Fatal("should have failed")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("should have failed with context.Canceled")
|
||||
}
|
||||
// rolling back the transaction should not result in any error
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(context.Background(), t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDataTypes(t *testing.T) {
|
||||
db := setupDB(context.Background(), t)
|
||||
defer db.Close()
|
||||
ctx := context.Background()
|
||||
var (
|
||||
text string
|
||||
nullText sql.NullString
|
||||
integer sql.NullInt64
|
||||
nullInteger sql.NullInt64
|
||||
boolean bool
|
||||
float8 float64
|
||||
nullFloat sql.NullFloat64
|
||||
bytea []byte
|
||||
)
|
||||
err := db.QueryRowContext(ctx, "SELECT 'foobar' as text, NULL as text, NULL as integer, 42 as integer, 1 as boolean, X'000102' as bytea, 3.14 as float8, NULL as float8;").Scan(&text, &nullText, &nullInteger, &integer, &boolean, &bytea, &float8, &nullFloat)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
switch {
|
||||
case text != "foobar":
|
||||
t.Error("value mismatch - text")
|
||||
case nullText.Valid:
|
||||
t.Error("null text is valid")
|
||||
case nullInteger.Valid:
|
||||
t.Error("null integer is valid")
|
||||
case !integer.Valid:
|
||||
t.Error("integer is not valid")
|
||||
case integer.Int64 != 42:
|
||||
t.Error("value mismatch - integer")
|
||||
case !boolean:
|
||||
t.Error("value mismatch - boolean")
|
||||
case float8 != 3.14:
|
||||
t.Error("value mismatch - float8")
|
||||
case !bytes.Equal(bytea, []byte{0, 1, 2}):
|
||||
t.Error("value mismatch - bytea")
|
||||
case nullFloat.Valid:
|
||||
t.Error("null float is valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecResult(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
result, err := db.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if affected != 1 {
|
||||
t.Fatal("affected should be 1")
|
||||
}
|
||||
lastInsertID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if lastInsertID != 1 {
|
||||
t.Fatal("lastInsertID should be 1")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParameterSyntaxSupport(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
db := setupDB(ctx, t)
|
||||
|
||||
_, err := db.ExecContext(ctx, "INSERT INTO test (name) VALUES (?1)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "hello world")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO test (name) VALUES (:a)", sql.Named("a", "hello world"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO test (name) VALUES (@a)", sql.Named("a", "hello world"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO test (name) VALUES ($a)", sql.Named("a", "hello world"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertRows(ctx, t, db)
|
||||
|
||||
t.Cleanup(func() {
|
||||
cleanupDB(ctx, t, db)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user