package memory

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"

	"github.com/google/uuid"
	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
	"github.com/sipeed/picoclaw/pkg/agent"
	_ "github.com/tursodatabase/libsql-client-go/libsql"
)

// DB wraps the *sql.DB handle used to persist sessions and chats.
type DB struct {
	db *sql.DB
}

var (
	// db is the package-wide handle, assigned only by a successful Init.
	db *DB
	// cfg is the configuration resolved by the most recent Init call.
	cfg *DBConfig
	// summaryAgent is the agent loop handed to Init.
	summaryAgent *agent.AgentLoop
)

// DBConfig holds the settings used when opening the database.
type DBConfig struct {
	// DBPath is the database file location; empty means getDefaultDBPath().
	DBPath string
}

// DBSOption mutates a DBConfig; options are applied in order by Init.
type DBSOption func(*DBConfig)

// WithDBPath returns an option that overrides the database file path.
func WithDBPath(path string) DBSOption {
	return func(c *DBConfig) {
		c.DBPath = path
	}
}

// getDefaultDBPath returns the database file path reported by internal.GetDBFile.
func getDefaultDBPath() string {
	return internal.GetDBFile()
}

// GetDefaultDBPath returns the default database path (exported form of
// getDefaultDBPath).
func GetDefaultDBPath() string {
	return getDefaultDBPath()
}

// Init opens the database, creates the schema, and publishes the handle as
// the package-level DB returned by GetDB.
//
// The path starts from the project config's Memory.DBPath, is overridden by
// opts in order, and falls back to getDefaultDBPath() when still empty. The
// parent directory is created if missing. agentLoop is stored for use
// elsewhere in the package.
//
// On any error the package-level DB is left unchanged and the newly opened
// handle (if any) is closed.
func Init(agentLoop *agent.AgentLoop, opts ...DBSOption) error {
	summaryAgent = agentLoop

	memoryCfg := internal.GetProjectConfig().Memory
	cfg = &DBConfig{
		DBPath: memoryCfg.DBPath,
	}
	for _, opt := range opts {
		opt(cfg)
	}

	path := cfg.DBPath
	if path == "" {
		path = getDefaultDBPath()
	}

	// filepath.Dir never returns "" (it yields "." for a bare file name), and
	// MkdirAll is a no-op when the directory already exists.
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return fmt.Errorf("创建数据库目录失败: %w", err)
	}

	dbInst, err := sql.Open("libsql", "file:"+path)
	if err != nil {
		return fmt.Errorf("打开数据库失败: %w", err)
	}
	if err := dbInst.Ping(); err != nil {
		// Release the handle; the ping failure is the error worth reporting.
		_ = dbInst.Close()
		return fmt.Errorf("ping 数据库失败: %w", err)
	}

	inst := &DB{db: dbInst}
	if err := inst.createTables(); err != nil {
		// Do not publish a handle whose schema could not be created.
		_ = dbInst.Close()
		return err
	}

	// NOTE(review): a handle from an earlier Init call is replaced here
	// without being closed — confirm Init is only called once per process.
	db = inst
	return nil
}

// GetDB returns the package-level DB set by Init, or nil before a successful Init.
func GetDB() *DB {
	return db
}

// createTables creates the sessions and chats tables and their indexes if
// they do not already exist. Statements run in order; the first failure aborts.
func (d *DB) createTables() error {
	queries := []string{
		`CREATE TABLE IF NOT EXISTS sessions (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			uuid TEXT UNIQUE NOT NULL,
			summary TEXT,
			summary_embedding BLOB,
			chat_ids TEXT,
			created_at INTEGER NOT NULL,
			updated_at INTEGER NOT NULL
		)`,
		`CREATE TABLE IF NOT EXISTS chats (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			session_id INTEGER NOT NULL,
			user_input TEXT NOT NULL,
			ai_replies TEXT,
			summary TEXT,
			summary_embedding BLOB,
			created_at INTEGER NOT NULL,
			updated_at INTEGER NOT NULL,
			FOREIGN KEY(session_id) REFERENCES sessions(id)
		)`,
		`CREATE INDEX IF NOT EXISTS idx_chats_session_id ON chats(session_id)`,
		// NOTE(review): the UNIQUE constraint on uuid likely already creates an
		// implicit index in SQLite/libsql, making this one redundant — confirm.
		`CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid)`,
	}

	for _, q := range queries {
		if _, err := d.db.ExecContext(context.Background(), q); err != nil {
			return fmt.Errorf("创建表失败: %w", err)
		}
	}
	return nil
}

// sessionColumns is the sessions column list, in the order scanSession reads it.
const sessionColumns = `id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at`

// chatColumns is the chats column list, in the order scanChat reads it.
const chatColumns = `id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at`

// rowScanner is the Scan method shared by *sql.Row and *sql.Rows.
type rowScanner interface {
	Scan(dest ...any) error
}

// scanSession reads one sessionColumns row into a new Session.
// Errors from Scan are wrapped with scanErrMsg (so errors.Is still detects
// sql.ErrNoRows); a chat_ids decode failure gets its own message.
func scanSession(rs rowScanner, scanErrMsg string) (*Session, error) {
	s := &Session{}
	var chatIDsJSON string
	if err := rs.Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt); err != nil {
		return nil, fmt.Errorf("%s: %w", scanErrMsg, err)
	}
	if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", err)
	}
	return s, nil
}

// scanChat reads one chatColumns row into a new Chat.
// Errors from Scan are wrapped with scanErrMsg (so errors.Is still detects
// sql.ErrNoRows); an ai_replies decode failure gets its own message.
func scanChat(rs rowScanner, scanErrMsg string) (*Chat, error) {
	c := &Chat{}
	var aiRepliesJSON string
	if err := rs.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
		return nil, fmt.Errorf("%s: %w", scanErrMsg, err)
	}
	if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
		return nil, fmt.Errorf("解析ai_replies失败: %w", err)
	}
	return c, nil
}

// querySession runs a single-row session query and returns (nil, nil) when
// it matches no row. errMsg prefixes any other query/scan error.
func (d *DB) querySession(errMsg, query string, args ...any) (*Session, error) {
	row := d.db.QueryRowContext(context.Background(), query, args...)
	s, err := scanSession(row, errMsg)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	return s, err
}

// CreateSession inserts a new session built by NewSession(uuid) and returns
// it with ID set from the inserted row. chat_ids is always stored as the
// empty JSON array "[]".
func (d *DB) CreateSession(uuid string) (*Session, error) {
	s := NewSession(uuid)
	result, err := d.db.ExecContext(
		context.Background(),
		`INSERT INTO sessions (uuid, summary, summary_embedding, chat_ids, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?)`,
		s.UUID, s.Summary, s.SummaryEmbedding, "[]", s.CreatedAt, s.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("创建会话失败: %w", err)
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("获取ID失败: %w", err)
	}
	s.ID = id
	return s, nil
}

// GetSessionByUUID returns the session with the given uuid, or (nil, nil)
// when none exists.
func (d *DB) GetSessionByUUID(uuid string) (*Session, error) {
	return d.querySession("查询会话失败",
		`SELECT `+sessionColumns+` FROM sessions WHERE uuid = ?`, uuid)
}

// GetSessionByID returns the session with the given row id, or (nil, nil)
// when none exists.
func (d *DB) GetSessionByID(id int64) (*Session, error) {
	return d.querySession("查询会话失败",
		`SELECT `+sessionColumns+` FROM sessions WHERE id = ?`, id)
}

// UpdateSession writes s.Summary, s.SummaryEmbedding, the JSON-encoded
// chat IDs and s.UpdatedAt to the row whose id is s.ID. The caller is
// responsible for setting UpdatedAt. Matching no row is not an error.
func (d *DB) UpdateSession(s *Session) error {
	chatIDsJSON, err := s.GetChatIDsJSON()
	if err != nil {
		return fmt.Errorf("序列化chat_ids失败: %w", err)
	}
	_, err = d.db.ExecContext(
		context.Background(),
		`UPDATE sessions SET summary = ?, summary_embedding = ?, chat_ids = ?, updated_at = ?
		WHERE id = ?`,
		s.Summary, s.SummaryEmbedding, chatIDsJSON, s.UpdatedAt, s.ID,
	)
	if err != nil {
		return fmt.Errorf("更新会话失败: %w", err)
	}
	return nil
}

// CreateChat inserts a new chat built by NewChat(sessionID, userInput) and
// returns it with ID set from the inserted row. ai_replies is always stored
// as the empty JSON array "[]".
func (d *DB) CreateChat(sessionID int64, userInput string) (*Chat, error) {
	c := NewChat(sessionID, userInput)
	result, err := d.db.ExecContext(
		context.Background(),
		`INSERT INTO chats (session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?, ?)`,
		c.SessionID, c.UserInput, "[]", c.Summary, c.SummaryEmbedding, c.CreatedAt, c.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("创建聊天记录失败: %w", err)
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("获取ID失败: %w", err)
	}
	c.ID = id
	return c, nil
}

// GetChatByID returns the chat with the given row id, or (nil, nil) when
// none exists.
func (d *DB) GetChatByID(id int64) (*Chat, error) {
	row := d.db.QueryRowContext(
		context.Background(),
		`SELECT `+chatColumns+` FROM chats WHERE id = ?`,
		id,
	)
	c, err := scanChat(row, "查询聊天记录失败")
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	return c, err
}

// UpdateChat writes the JSON-encoded AI replies, c.Summary,
// c.SummaryEmbedding and c.UpdatedAt to the row whose id is c.ID. The caller
// is responsible for setting UpdatedAt. Matching no row is not an error.
func (d *DB) UpdateChat(c *Chat) error {
	aiRepliesJSON, err := c.GetAIRepliesJSON()
	if err != nil {
		return fmt.Errorf("序列化ai_replies失败: %w", err)
	}
	_, err = d.db.ExecContext(
		context.Background(),
		`UPDATE chats SET ai_replies = ?, summary = ?, summary_embedding = ?, updated_at = ?
		WHERE id = ?`,
		aiRepliesJSON, c.Summary, c.SummaryEmbedding, c.UpdatedAt, c.ID,
	)
	if err != nil {
		return fmt.Errorf("更新聊天记录失败: %w", err)
	}
	return nil
}

// GetLatestSession returns the session with the highest id, or (nil, nil)
// when the table is empty.
func (d *DB) GetLatestSession() (*Session, error) {
	return d.querySession("查询最新会话失败",
		`SELECT `+sessionColumns+` FROM sessions ORDER BY id DESC LIMIT 1`)
}

// GetAllSessions returns every session ordered by descending id. An empty
// table yields a nil slice.
func (d *DB) GetAllSessions() ([]*Session, error) {
	rows, err := d.db.QueryContext(
		context.Background(),
		`SELECT `+sessionColumns+` FROM sessions ORDER BY id DESC`,
	)
	if err != nil {
		return nil, fmt.Errorf("查询会话列表失败: %w", err)
	}
	defer rows.Close()

	var sessions []*Session
	for rows.Next() {
		s, err := scanSession(rows, "扫描会话失败")
		if err != nil {
			return nil, err
		}
		sessions = append(sessions, s)
	}
	// rows.Next returns false both at the end and on error; only Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("查询会话列表失败: %w", err)
	}
	return sessions, nil
}

// GetChatsBySessionID returns the chats belonging to sessionID ordered by
// ascending id. No matches yields a nil slice.
func (d *DB) GetChatsBySessionID(sessionID int64) ([]*Chat, error) {
	rows, err := d.db.QueryContext(
		context.Background(),
		`SELECT `+chatColumns+` FROM chats WHERE session_id = ? ORDER BY id`,
		sessionID,
	)
	if err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		c, err := scanChat(rows, "扫描聊天记录失败")
		if err != nil {
			return nil, err
		}
		chats = append(chats, c)
	}
	// rows.Next returns false both at the end and on error; only Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	return chats, nil
}

// Close closes d's underlying database handle. It is a no-op returning nil
// for a nil *DB or one without an open handle.
func (d *DB) Close() error {
	if d == nil || d.db == nil {
		return nil
	}
	return d.db.Close()
}

// CreateNewSession creates a session with a freshly generated UUID on the
// package-level DB. It returns an error if Init has not succeeded.
func CreateNewSession() (*Session, error) {
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	return db.CreateSession(generateUUID())
}

// GetLatestSession returns the newest session from the package-level DB, or
// (nil, nil) when Init has not succeeded or no session exists.
func GetLatestSession() (*Session, error) {
	if db == nil {
		return nil, nil
	}
	return db.GetLatestSession()
}

// GetAllSessions returns all sessions from the package-level DB, or
// (nil, nil) when Init has not succeeded.
func GetAllSessions() ([]*Session, error) {
	if db == nil {
		return nil, nil
	}
	return db.GetAllSessions()
}

// ListSessions is an alias for GetAllSessions.
func ListSessions() ([]*Session, error) {
	return GetAllSessions()
}

// GetCurrentSession returns the package-level currentSession.
// NOTE(review): currentSession is never assigned in this file; presumably it
// is set elsewhere in the package — confirm.
func GetCurrentSession() *Session {
	return currentSession
}

// currentSession is the value returned by GetCurrentSession.
var currentSession *Session

// generateUUID returns a new random UUID in its canonical string form.
func generateUUID() string {
	return uuid.New().String()
}