package ws import ( "context" "database/sql/driver" "encoding/base64" "fmt" "strconv" "sync" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) // defaultWSTimeout specifies the timeout used for initial http connection var defaultWSTimeout = 120 * time.Second func errorMsg(errorResp interface{}) string { return errorResp.(map[string]interface{})["error"].(map[string]interface{})["message"].(string) } func isErrorResp(resp interface{}) bool { return resp.(map[string]interface{})["type"] == "response_error" } type websocketConn struct { conn *websocket.Conn idPool *idPool } type namedParam struct { Name string Value any } type params struct { PositinalArgs []any NamedArgs []namedParam } func convertValue(v any) (map[string]interface{}, error) { res := map[string]interface{}{} if v == nil { res["type"] = "null" } else if integer, ok := v.(int64); ok { res["type"] = "integer" res["value"] = strconv.FormatInt(integer, 10) } else if text, ok := v.(string); ok { res["type"] = "text" res["value"] = text } else if blob, ok := v.([]byte); ok { res["type"] = "blob" res["base64"] = base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(blob) } else if float, ok := v.(float64); ok { res["type"] = "float" res["value"] = float } else if boolean, ok := v.(bool); ok { res["type"] = "integer" if boolean { res["value"] = "1" } else { res["value"] = "0" } } else { return nil, fmt.Errorf("unsupported value type: %s", v) } return res, nil } type execResponse struct { resp map[string]interface{} } func (r *execResponse) affectedRowCount() int64 { return int64(r.resp["affected_row_count"].(float64)) } func (r *execResponse) lastInsertId() int64 { id, ok := r.resp["last_insert_rowid"].(string) if !ok { return 0 } value, _ := strconv.ParseInt(id, 10, 64) return value } func (r *execResponse) columns() []string { res := []string{} cols := r.resp["cols"].([]interface{}) for idx := range cols { var v string = "" if cols[idx].(map[string]interface{})["name"] != nil { v = cols[idx].(map[string]interface{})["name"].(string) } res = append(res, v) } return res } func (r *execResponse) rowsCount() int { return len(r.resp["rows"].([]interface{})) } func (r *execResponse) rowLen(rowIdx int) int { return len(r.resp["rows"].([]interface{})[rowIdx].([]interface{})) } func (r *execResponse) value(rowIdx int, colIdx int) (any, error) { val := r.resp["rows"].([]interface{})[rowIdx].([]interface{})[colIdx].(map[string]interface{}) switch val["type"] { case "null": return nil, nil case "integer": v, err := strconv.ParseInt(val["value"].(string), 10, 64) if err != nil { return nil, err } return v, nil case "text": return val["value"].(string), nil case "blob": base64Encoded := val["base64"].(string) v, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(base64Encoded) if err != nil { return nil, err } return v, nil case "float": return val["value"].(float64), nil } return nil, fmt.Errorf("unrecognized value type: %s", val["type"]) } func (ws *websocketConn) exec(ctx context.Context, sql string, sqlParams params, wantRows bool) (*execResponse, error) { requestId := ws.idPool.Get() defer ws.idPool.Put(requestId) stmt := map[string]interface{}{ "sql": sql, "want_rows": wantRows, } if len(sqlParams.PositinalArgs) > 0 { args := []map[string]interface{}{} for idx := range sqlParams.PositinalArgs { v, err := convertValue(sqlParams.PositinalArgs[idx]) if err != nil { return nil, err } args = append(args, v) } stmt["args"] = args } if len(sqlParams.NamedArgs) > 0 { args := []map[string]interface{}{} for idx := range sqlParams.NamedArgs { v, err := convertValue(sqlParams.NamedArgs[idx].Value) if err != nil { return nil, err } arg := map[string]interface{}{ "name": sqlParams.NamedArgs[idx].Name, "value": v, } args = append(args, arg) } stmt["named_args"] = args } err := wsjson.Write(ctx, ws.conn, map[string]interface{}{ "type": "request", "request_id": requestId, "request": map[string]interface{}{ "type": "execute", "stream_id": 0, "stmt": stmt, }, }) if err != nil { return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error()) } var resp interface{} if err = wsjson.Read(ctx, ws.conn, &resp); err != nil { return nil, fmt.Errorf("%w: %s", driver.ErrBadConn, err.Error()) } if isErrorResp(resp) { err = fmt.Errorf("unable to execute %s: %s", sql, errorMsg(resp)) return nil, err } return &execResponse{resp.(map[string]interface{})["response"].(map[string]interface{})["result"].(map[string]interface{})}, nil } func (ws *websocketConn) Close() error { return ws.conn.Close(websocket.StatusNormalClosure, "All's good") } func connect(url string, jwt string) (*websocketConn, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultWSTimeout) defer cancel() c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ Subprotocols: []string{"hrana1"}, }) if err != nil { return nil, err } c.SetReadLimit(1024 * 1024 * 16) // 16MB err = wsjson.Write(ctx, c, map[string]interface{}{ "type": "hello", "jwt": jwt, }) if err != nil { c.Close(websocket.StatusInternalError, err.Error()) return nil, err } err = wsjson.Write(ctx, c, map[string]interface{}{ "type": "request", "request_id": 0, "request": map[string]interface{}{ "type": "open_stream", "stream_id": 0, }, }) if err != nil { c.Close(websocket.StatusInternalError, err.Error()) return nil, err } var helloResp interface{} err = wsjson.Read(ctx, c, &helloResp) if err != nil { c.Close(websocket.StatusInternalError, err.Error()) return nil, err } if helloResp.(map[string]interface{})["type"] == "hello_error" { err = fmt.Errorf("handshake error: %s", errorMsg(helloResp)) c.Close(websocket.StatusProtocolError, err.Error()) return nil, err } var openStreamResp interface{} err = wsjson.Read(ctx, c, &openStreamResp) if err != nil { c.Close(websocket.StatusInternalError, err.Error()) return nil, err } if isErrorResp(openStreamResp) { err = fmt.Errorf("unable to open stream: %s", errorMsg(helloResp)) c.Close(websocket.StatusProtocolError, err.Error()) return nil, err } return &websocketConn{c, newIDPool()}, nil } // Below is modified IDPool from "vitess.io/vitess/go/pools" // idPool is used to ensure that the set of IDs in use concurrently never // contains any duplicates. The IDs start at 1 and increase without bound, but // will never be larger than the peak number of concurrent uses. // // idPool's Get() and Put() methods can be used concurrently. type idPool struct { sync.Mutex // used holds the set of values that have been returned to us with Put(). used map[uint32]bool // maxUsed remembers the largest value we've given out. maxUsed uint32 } // NewIDPool creates and initializes an idPool. func newIDPool() *idPool { return &idPool{ used: make(map[uint32]bool), maxUsed: 0, } } // Get returns an ID that is unique among currently active users of this pool. func (pool *idPool) Get() (id uint32) { pool.Lock() defer pool.Unlock() // Pick a value that's been returned, if any. for key := range pool.used { delete(pool.used, key) return key } // No recycled IDs are available, so increase the pool size. pool.maxUsed++ return pool.maxUsed } // Put recycles an ID back into the pool for others to use. Putting back a value // or 0, or a value that is not currently "checked out", will result in a panic // because that should never happen except in the case of a programming error. func (pool *idPool) Put(id uint32) { pool.Lock() defer pool.Unlock() if id < 1 || id > pool.maxUsed { panic(fmt.Errorf("idPool.Put(%v): invalid value, must be in the range [1,%v]", id, pool.maxUsed)) } if pool.used[id] { panic(fmt.Errorf("idPool.Put(%v): can't put value that was already recycled", id)) } // If we're recycling maxUsed, just shrink the pool. if id == pool.maxUsed { pool.maxUsed = id - 1 return } // Add it to the set of recycled IDs. pool.used[id] = true }