fix: 补充提交 memory 模块和 tts/userdir 文件
This commit is contained in:
371
cmd/hxclaw/internal/memory/db.go
Normal file
371
cmd/hxclaw/internal/memory/db.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
|
||||
|
||||
_ "github.com/tursodatabase/libsql-client-go/libsql"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
var (
|
||||
db *DB
|
||||
cfg *DBConfig
|
||||
)
|
||||
|
||||
type DBConfig struct {
|
||||
DBPath string
|
||||
}
|
||||
|
||||
type DBSOption func(*DBConfig)
|
||||
|
||||
func WithDBPath(path string) DBSOption {
|
||||
return func(c *DBConfig) {
|
||||
c.DBPath = path
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultDBPath() string {
|
||||
return internal.GetDBFile()
|
||||
}
|
||||
|
||||
// GetDefaultDBPath 获取默认数据库路径(导出)
|
||||
func GetDefaultDBPath() string {
|
||||
return getDefaultDBPath()
|
||||
}
|
||||
|
||||
func Init(opts ...DBSOption) error {
|
||||
memoryCfg := internal.GetProjectConfig().Memory
|
||||
|
||||
cfg = &DBConfig{
|
||||
DBPath: memoryCfg.DBPath,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
path := cfg.DBPath
|
||||
if path == "" {
|
||||
path = getDefaultDBPath()
|
||||
}
|
||||
|
||||
if dir := filepath.Dir(path); dir != "" {
|
||||
os.MkdirAll(dir, 0755)
|
||||
}
|
||||
|
||||
dbInst, err := sql.Open("libsql", "file:"+path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := dbInst.Ping(); err != nil {
|
||||
return fmt.Errorf("ping 数据库失败: %w", err)
|
||||
}
|
||||
|
||||
db = &DB{db: dbInst}
|
||||
return db.createTables()
|
||||
}
|
||||
|
||||
func GetDB() *DB {
|
||||
return db
|
||||
}
|
||||
|
||||
func (d *DB) createTables() error {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uuid TEXT UNIQUE NOT NULL,
|
||||
summary TEXT,
|
||||
summary_embedding BLOB,
|
||||
chat_ids TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS chats (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL,
|
||||
user_input TEXT NOT NULL,
|
||||
ai_replies TEXT,
|
||||
summary TEXT,
|
||||
summary_embedding BLOB,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(session_id) REFERENCES sessions(id)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_chats_session_id ON chats(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := d.db.ExecContext(context.Background(), q); err != nil {
|
||||
return fmt.Errorf("创建表失败: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) CreateSession(uuid string) (*Session, error) {
|
||||
s := NewSession(uuid)
|
||||
|
||||
result, err := d.db.ExecContext(
|
||||
context.Background(),
|
||||
`INSERT INTO sessions (uuid, summary, summary_embedding, chat_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
s.UUID, s.Summary, s.SummaryEmbedding, "[]", s.CreatedAt, s.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建会话失败: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取ID失败: %w", err)
|
||||
}
|
||||
|
||||
s.ID = id
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetSessionByUUID(uuid string) (*Session, error) {
|
||||
s := &Session{}
|
||||
var chatIDsJSON string
|
||||
err := d.db.QueryRowContext(
|
||||
context.Background(),
|
||||
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE uuid = ?`,
|
||||
uuid,
|
||||
).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询会话失败: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||||
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetSessionByID(id int64) (*Session, error) {
|
||||
s := &Session{}
|
||||
var chatIDsJSON string
|
||||
err := d.db.QueryRowContext(
|
||||
context.Background(),
|
||||
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE id = ?`,
|
||||
id,
|
||||
).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询会话失败: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||||
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateSession(s *Session) error {
|
||||
chatIDsJSON, err := s.GetChatIDsJSON()
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化chat_ids失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(
|
||||
context.Background(),
|
||||
`UPDATE sessions SET summary = ?, summary_embedding = ?, chat_ids = ?, updated_at = ? WHERE id = ?`,
|
||||
s.Summary, s.SummaryEmbedding, chatIDsJSON, s.UpdatedAt, s.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新会话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) CreateChat(sessionID int64, userInput string) (*Chat, error) {
|
||||
c := NewChat(sessionID, userInput)
|
||||
|
||||
result, err := d.db.ExecContext(
|
||||
context.Background(),
|
||||
`INSERT INTO chats (session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
c.SessionID, c.UserInput, "[]", c.Summary, c.SummaryEmbedding, c.CreatedAt, c.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建聊天记录失败: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取ID失败: %w", err)
|
||||
}
|
||||
|
||||
c.ID = id
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetChatByID(id int64) (*Chat, error) {
|
||||
c := &Chat{}
|
||||
var aiRepliesJSON string
|
||||
err := d.db.QueryRowContext(
|
||||
context.Background(),
|
||||
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE id = ?`,
|
||||
id,
|
||||
).Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询聊天记录失败: %w", err)
|
||||
}
|
||||
|
||||
if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
|
||||
return nil, fmt.Errorf("解析ai_replies失败: %w", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateChat(c *Chat) error {
|
||||
aiRepliesJSON, err := c.GetAIRepliesJSON()
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化ai_replies失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(
|
||||
context.Background(),
|
||||
`UPDATE chats SET ai_replies = ?, summary = ?, summary_embedding = ?, updated_at = ? WHERE id = ?`,
|
||||
aiRepliesJSON, c.Summary, c.SummaryEmbedding, c.UpdatedAt, c.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新聊天记录失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) GetLatestSession() (*Session, error) {
|
||||
s := &Session{}
|
||||
var chatIDsJSON string
|
||||
err := d.db.QueryRowContext(
|
||||
context.Background(),
|
||||
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC LIMIT 1`,
|
||||
).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询最新会话失败: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||||
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetAllSessions() ([]*Session, error) {
|
||||
rows, err := d.db.QueryContext(
|
||||
context.Background(),
|
||||
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询会话列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*Session
|
||||
for rows.Next() {
|
||||
s := &Session{}
|
||||
var chatIDsJSON string
|
||||
if err := rows.Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描会话失败: %w", err)
|
||||
}
|
||||
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||||
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetChatsBySessionID(sessionID int64) ([]*Chat, error) {
|
||||
rows, err := d.db.QueryContext(
|
||||
context.Background(),
|
||||
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE session_id = ? ORDER BY id`,
|
||||
sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询聊天记录失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var chats []*Chat
|
||||
for rows.Next() {
|
||||
c := &Chat{}
|
||||
var aiRepliesJSON string
|
||||
if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描聊天记录失败: %w", err)
|
||||
}
|
||||
if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
|
||||
return nil, fmt.Errorf("解析ai_replies失败: %w", err)
|
||||
}
|
||||
chats = append(chats, c)
|
||||
}
|
||||
return chats, nil
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
if db != nil && db.db != nil {
|
||||
return db.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateNewSession() (*Session, error) {
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化")
|
||||
}
|
||||
return db.CreateSession(generateUUID())
|
||||
}
|
||||
|
||||
func GetLatestSession() (*Session, error) {
|
||||
if db == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return db.GetLatestSession()
|
||||
}
|
||||
|
||||
func GetAllSessions() ([]*Session, error) {
|
||||
if db == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return db.GetAllSessions()
|
||||
}
|
||||
|
||||
func ListSessions() ([]*Session, error) {
|
||||
return GetAllSessions()
|
||||
}
|
||||
|
||||
func GetCurrentSession() *Session {
|
||||
return currentSession
|
||||
}
|
||||
|
||||
var currentSession *Session
|
||||
|
||||
func generateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
Reference in New Issue
Block a user