Files
HxClaw/cmd/hxclaw/internal/memory/db.go

371 lines
9.1 KiB
Go
Raw Normal View History

package memory
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"github.com/google/uuid"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
_ "github.com/tursodatabase/libsql-client-go/libsql"
)
// DB wraps the *sql.DB connection pool that backs the memory store (the
// sessions and chats tables created by createTables).
type DB struct {
	db *sql.DB // handle opened by Init via sql.Open("libsql", ...)
}
// Package-level state managed by Init.
var (
	// db is the shared handle returned by GetDB; it is nil before Init has
	// stored one.
	db *DB
	// cfg holds the configuration resolved by the most recent Init call.
	cfg *DBConfig
)
// DBConfig holds the settings Init uses to locate the database.
type DBConfig struct {
	// DBPath is the database file location. When it is empty after all
	// options are applied, Init falls back to the default path.
	DBPath string
}

// DBSOption is a functional option that mutates a DBConfig before Init
// opens the database.
type DBSOption func(*DBConfig)

// WithDBPath returns an option that sets DBConfig.DBPath to path. Init applies
// options after loading the project's memory config, so this takes precedence
// over the configured value.
func WithDBPath(path string) DBSOption {
	setPath := func(conf *DBConfig) {
		conf.DBPath = path
	}
	return setPath
}
// getDefaultDBPath returns the database file location provided by
// internal.GetDBFile; Init uses it when no explicit path is configured.
func getDefaultDBPath() string {
	defaultPath := internal.GetDBFile()
	return defaultPath
}
// GetDefaultDBPath returns the default database path (exported), i.e. the
// location Init falls back to when neither the project config nor a
// WithDBPath option supplies one.
func GetDefaultDBPath() string {
	p := getDefaultDBPath()
	return p
}
// Init opens the memory database and makes it available through GetDB.
//
// Path resolution:
//   - start from the project's memory config (internal.GetProjectConfig().Memory.DBPath);
//   - apply opts in order, so WithDBPath overrides the config;
//   - fall back to GetDefaultDBPath when the result is empty.
//
// Init then creates the parent directory if it is missing. It opens the
// database with the "libsql" driver as "file:<path>", pings it, and creates
// the schema via createTables.
//
// The package-level handle is replaced only once every step has succeeded.
// On failure, any connection opened by this call is closed and the previous
// handle (if any) is left in place. The package-level cfg is updated even
// when Init returns an error.
//
// NOTE(review): a successful re-Init replaces the previous handle without
// closing it — confirm whether callers rely on that.
func Init(opts ...DBSOption) error {
	memoryCfg := internal.GetProjectConfig().Memory
	cfg = &DBConfig{
		DBPath: memoryCfg.DBPath,
	}
	for _, opt := range opts {
		opt(cfg)
	}
	path := cfg.DBPath
	if path == "" {
		path = getDefaultDBPath()
	}
	// filepath.Dir never returns "" (a bare file name yields "."), and
	// MkdirAll is a no-op for an existing directory, so call it
	// unconditionally. Report its error rather than letting Open/Ping fail
	// later with a less specific message.
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return fmt.Errorf("创建数据库目录失败: %w", err)
	}
	dbInst, err := sql.Open("libsql", "file:"+path)
	if err != nil {
		return fmt.Errorf("打开数据库失败: %w", err)
	}
	if err := dbInst.PingContext(context.Background()); err != nil {
		_ = dbInst.Close() // best effort; the ping error is the one worth reporting
		return fmt.Errorf("ping 数据库失败: %w", err)
	}
	inst := &DB{db: dbInst}
	if err := inst.createTables(); err != nil {
		// Neither publish nor leak a handle whose schema could not be created.
		_ = dbInst.Close()
		return err
	}
	db = inst
	return nil
}
// GetDB returns the package-level handle stored by Init, or nil if Init has
// not stored one yet.
func GetDB() *DB {
	current := db
	return current
}
// createTables creates the sessions and chats tables and their indexes.
// Every statement uses IF NOT EXISTS, so running it against an existing
// database leaves existing objects alone. Execution stops at the first
// failing statement.
//
// NOTE(review): SQLite only enforces the chats.session_id FOREIGN KEY when
// PRAGMA foreign_keys=ON is set on the connection, and nothing here enables
// it. Also, idx_sessions_uuid duplicates the automatic index created for the
// UNIQUE constraint on sessions.uuid. Confirm both are intended.
func (d *DB) createTables() error {
	ctx := context.Background()
	schema := [...]string{
		`CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT UNIQUE NOT NULL,
summary TEXT,
summary_embedding BLOB,
chat_ids TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
)`,
		`CREATE TABLE IF NOT EXISTS chats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL,
user_input TEXT NOT NULL,
ai_replies TEXT,
summary TEXT,
summary_embedding BLOB,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
FOREIGN KEY(session_id) REFERENCES sessions(id)
)`,
		`CREATE INDEX IF NOT EXISTS idx_chats_session_id ON chats(session_id)`,
		`CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid)`,
	}
	for i := 0; i < len(schema); i++ {
		if _, execErr := d.db.ExecContext(ctx, schema[i]); execErr != nil {
			return fmt.Errorf("创建表失败: %w", execErr)
		}
	}
	return nil
}
// CreateSession builds a session with NewSession(uuid), inserts it, and
// returns it with ID set from the generated primary key. The chat_ids column
// is always written as an empty JSON array ("[]").
func (d *DB) CreateSession(uuid string) (*Session, error) {
	const insertSQL = `INSERT INTO sessions (uuid, summary, summary_embedding, chat_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`
	sess := NewSession(uuid)
	res, execErr := d.db.ExecContext(context.Background(), insertSQL,
		sess.UUID, sess.Summary, sess.SummaryEmbedding, "[]", sess.CreatedAt, sess.UpdatedAt)
	if execErr != nil {
		return nil, fmt.Errorf("创建会话失败: %w", execErr)
	}
	newID, idErr := res.LastInsertId()
	if idErr != nil {
		return nil, fmt.Errorf("获取ID失败: %w", idErr)
	}
	sess.ID = newID
	return sess, nil
}
// GetSessionByUUID looks up a session by its UUID and decodes the stored
// chat_ids JSON into it. It returns (nil, nil) when no row matches.
//
// NOTE(review): chat_ids is a nullable column. A NULL value would make Scan
// into a string fail; CreateSession always writes "[]", so this only matters
// for rows written elsewhere.
func (d *DB) GetSessionByUUID(uuid string) (*Session, error) {
	const selectSQL = `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE uuid = ?`
	var (
		sess     = &Session{}
		rawChats string
	)
	row := d.db.QueryRowContext(context.Background(), selectSQL, uuid)
	switch scanErr := row.Scan(&sess.ID, &sess.UUID, &sess.Summary, &sess.SummaryEmbedding, &rawChats, &sess.CreatedAt, &sess.UpdatedAt); {
	case scanErr == sql.ErrNoRows:
		return nil, nil
	case scanErr != nil:
		return nil, fmt.Errorf("查询会话失败: %w", scanErr)
	}
	if parseErr := sess.SetChatIDsFromJSON(rawChats); parseErr != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", parseErr)
	}
	return sess, nil
}
// GetSessionByID fetches the session whose primary key is id and decodes its
// chat_ids JSON. A missing row yields (nil, nil) rather than an error.
func (d *DB) GetSessionByID(id int64) (*Session, error) {
	var rawChatIDs string
	out := new(Session)
	dest := []interface{}{&out.ID, &out.UUID, &out.Summary, &out.SummaryEmbedding, &rawChatIDs, &out.CreatedAt, &out.UpdatedAt}
	err := d.db.QueryRowContext(
		context.Background(),
		`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE id = ?`,
		id,
	).Scan(dest...)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if err = out.SetChatIDsFromJSON(rawChatIDs); err != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", err)
	}
	return out, nil
}
// UpdateSession writes s's summary, summary embedding, chat IDs (as JSON) and
// UpdatedAt to the row whose id is s.ID. UpdatedAt is stored as given, not
// refreshed here. UUID and CreatedAt are not written. The affected-row count
// is not checked, so updating a non-existent ID is not reported as an error.
func (d *DB) UpdateSession(s *Session) error {
	encoded, encErr := s.GetChatIDsJSON()
	if encErr != nil {
		return fmt.Errorf("序列化chat_ids失败: %w", encErr)
	}
	const updateSQL = `UPDATE sessions SET summary = ?, summary_embedding = ?, chat_ids = ?, updated_at = ? WHERE id = ?`
	if _, execErr := d.db.ExecContext(context.Background(), updateSQL,
		s.Summary, s.SummaryEmbedding, encoded, s.UpdatedAt, s.ID); execErr != nil {
		return fmt.Errorf("更新会话失败: %w", execErr)
	}
	return nil
}
// CreateChat builds a chat with NewChat(sessionID, userInput), inserts it, and
// returns it with ID set from the generated primary key. ai_replies is always
// written as an empty JSON array ("[]").
//
// NOTE(review): sessionID is not validated here; whether a dangling ID is
// rejected depends on SQLite foreign-key enforcement being enabled.
func (d *DB) CreateChat(sessionID int64, userInput string) (*Chat, error) {
	chat := NewChat(sessionID, userInput)
	const insertSQL = `INSERT INTO chats (session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`
	res, execErr := d.db.ExecContext(context.Background(), insertSQL,
		chat.SessionID, chat.UserInput, "[]", chat.Summary, chat.SummaryEmbedding, chat.CreatedAt, chat.UpdatedAt)
	if execErr != nil {
		return nil, fmt.Errorf("创建聊天记录失败: %w", execErr)
	}
	newID, idErr := res.LastInsertId()
	if idErr != nil {
		return nil, fmt.Errorf("获取ID失败: %w", idErr)
	}
	chat.ID = newID
	return chat, nil
}
// GetChatByID fetches the chat whose primary key is id and decodes its
// ai_replies JSON. A missing row yields (nil, nil) rather than an error.
func (d *DB) GetChatByID(id int64) (*Chat, error) {
	const selectSQL = `SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE id = ?`
	chat := new(Chat)
	var rawReplies string
	row := d.db.QueryRowContext(context.Background(), selectSQL, id)
	scanErr := row.Scan(&chat.ID, &chat.SessionID, &chat.UserInput, &rawReplies, &chat.Summary, &chat.SummaryEmbedding, &chat.CreatedAt, &chat.UpdatedAt)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, nil
	case scanErr != nil:
		return nil, fmt.Errorf("查询聊天记录失败: %w", scanErr)
	}
	if parseErr := chat.SetAIRepliesFromJSON(rawReplies); parseErr != nil {
		return nil, fmt.Errorf("解析ai_replies失败: %w", parseErr)
	}
	return chat, nil
}
// UpdateChat writes c's AI replies (as JSON), summary, summary embedding and
// UpdatedAt to the row whose id is c.ID. session_id, user_input and created_at
// are not written. The affected-row count is not checked.
func (d *DB) UpdateChat(c *Chat) error {
	encoded, encErr := c.GetAIRepliesJSON()
	if encErr != nil {
		return fmt.Errorf("序列化ai_replies失败: %w", encErr)
	}
	const updateSQL = `UPDATE chats SET ai_replies = ?, summary = ?, summary_embedding = ?, updated_at = ? WHERE id = ?`
	if _, execErr := d.db.ExecContext(context.Background(), updateSQL,
		encoded, c.Summary, c.SummaryEmbedding, c.UpdatedAt, c.ID); execErr != nil {
		return fmt.Errorf("更新聊天记录失败: %w", execErr)
	}
	return nil
}
// GetLatestSession returns the session with the highest id (the most recently
// inserted one), or (nil, nil) when the sessions table is empty.
func (d *DB) GetLatestSession() (*Session, error) {
	const selectSQL = `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC LIMIT 1`
	latest := new(Session)
	var rawChatIDs string
	row := d.db.QueryRowContext(context.Background(), selectSQL)
	if scanErr := row.Scan(&latest.ID, &latest.UUID, &latest.Summary, &latest.SummaryEmbedding, &rawChatIDs, &latest.CreatedAt, &latest.UpdatedAt); scanErr != nil {
		if scanErr == sql.ErrNoRows {
			return nil, nil
		}
		return nil, fmt.Errorf("查询最新会话失败: %w", scanErr)
	}
	if parseErr := latest.SetChatIDsFromJSON(rawChatIDs); parseErr != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", parseErr)
	}
	return latest, nil
}
// GetAllSessions returns every stored session ordered by id descending
// (newest first), with chat_ids decoded. It returns a nil slice when the
// table is empty. Any query, scan, decode or iteration error aborts the call
// and returns no partial results.
func (d *DB) GetAllSessions() ([]*Session, error) {
	rows, err := d.db.QueryContext(
		context.Background(),
		`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC`,
	)
	if err != nil {
		return nil, fmt.Errorf("查询会话列表失败: %w", err)
	}
	defer rows.Close()
	var sessions []*Session
	for rows.Next() {
		s := &Session{}
		var chatIDsJSON string
		if err := rows.Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt); err != nil {
			return nil, fmt.Errorf("扫描会话失败: %w", err)
		}
		if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
			return nil, fmt.Errorf("解析chat_ids失败: %w", err)
		}
		sessions = append(sessions, s)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails (e.g. the connection drops mid-read). Without this
	// check, a truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历会话列表失败: %w", err)
	}
	return sessions, nil
}
// GetChatsBySessionID returns every chat belonging to sessionID in ascending
// id order (insertion order), with ai_replies decoded. It returns a nil slice
// when the session has no chats. Any query, scan, decode or iteration error
// aborts the call and returns no partial results.
func (d *DB) GetChatsBySessionID(sessionID int64) ([]*Chat, error) {
	rows, err := d.db.QueryContext(
		context.Background(),
		`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE session_id = ? ORDER BY id`,
		sessionID,
	)
	if err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	defer rows.Close()
	var chats []*Chat
	for rows.Next() {
		c := &Chat{}
		var aiRepliesJSON string
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			return nil, fmt.Errorf("扫描聊天记录失败: %w", err)
		}
		if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
			return nil, fmt.Errorf("解析ai_replies失败: %w", err)
		}
		chats = append(chats, c)
	}
	// rows.Next also returns false on an iteration error. Surface it so a
	// partially read history is not mistaken for the full one.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历聊天记录失败: %w", err)
	}
	return chats, nil
}
// Close releases the connection pool held by d. Calling it on a nil *DB, or
// on a DB whose handle was never set, is a no-op that returns nil.
func (d *DB) Close() error {
	// Act on the receiver, not the package-level db. Otherwise closing any
	// *DB (even a nil one) would tear down the shared handle returned by
	// GetDB.
	if d == nil || d.db == nil {
		return nil
	}
	return d.db.Close()
}
// CreateNewSession inserts a session keyed by a new UUID from generateUUID,
// using the package-level handle. It returns an error if Init has not stored
// a handle yet.
func CreateNewSession() (*Session, error) {
	if db != nil {
		return db.CreateSession(generateUUID())
	}
	return nil, fmt.Errorf("数据库未初始化")
}
// GetLatestSession returns the most recently inserted session via the
// package-level handle. Unlike CreateNewSession, a missing handle is not an
// error: before Init it returns (nil, nil).
func GetLatestSession() (*Session, error) {
	if handle := db; handle != nil {
		return handle.GetLatestSession()
	}
	return nil, nil
}
// GetAllSessions returns all sessions (newest first) via the package-level
// handle. Before Init it returns (nil, nil) rather than an error.
func GetAllSessions() ([]*Session, error) {
	if handle := db; handle != nil {
		return handle.GetAllSessions()
	}
	return nil, nil
}
// ListSessions is an alias for GetAllSessions.
func ListSessions() ([]*Session, error) {
	sessions, err := GetAllSessions()
	return sessions, err
}
// currentSession is the value reported by GetCurrentSession. Nothing in the
// visible code assigns it.
// NOTE(review): presumably set elsewhere in package memory — confirm.
var currentSession *Session

// GetCurrentSession returns the package's current session pointer, which may
// be nil.
func GetCurrentSession() *Session {
	cur := currentSession
	return cur
}
// generateUUID returns a new random UUID (github.com/google/uuid New) in its
// canonical hyphenated string form.
func generateUUID() string {
	id := uuid.New()
	return id.String()
}