371 lines
9.1 KiB
Go
371 lines
9.1 KiB
Go
|
|
package memory
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"database/sql"
|
||
|
|
"fmt"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
|
||
|
|
"github.com/google/uuid"
|
||
|
|
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
|
||
|
|
|
||
|
|
_ "github.com/tursodatabase/libsql-client-go/libsql"
|
||
|
|
)
|
||
|
|
|
||
|
|
// DB is the memory store's database handle. It wraps the *sql.DB opened by
// Init; all session and chat persistence goes through its methods.
type DB struct {
	// db is the underlying database/sql handle (opened with the "libsql" driver in Init).
	db *sql.DB
}
|
||
|
|
|
||
|
|
var (
|
||
|
|
db *DB
|
||
|
|
cfg *DBConfig
|
||
|
|
)
|
||
|
|
|
||
|
|
// DBConfig carries the settings Init uses to open the database.
type DBConfig struct {
	// DBPath is the database file location. When it is empty, Init falls back
	// to the default path.
	DBPath string
}
|
||
|
|
|
||
|
|
// DBSOption customizes a DBConfig; Init applies the options in order after
// loading the project configuration.
type DBSOption func(c *DBConfig)
|
||
|
|
|
||
|
|
func WithDBPath(path string) DBSOption {
|
||
|
|
return func(c *DBConfig) {
|
||
|
|
c.DBPath = path
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func getDefaultDBPath() string {
|
||
|
|
return internal.GetDBFile()
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetDefaultDBPath 获取默认数据库路径(导出)
|
||
|
|
func GetDefaultDBPath() string {
|
||
|
|
return getDefaultDBPath()
|
||
|
|
}
|
||
|
|
|
||
|
|
func Init(opts ...DBSOption) error {
|
||
|
|
memoryCfg := internal.GetProjectConfig().Memory
|
||
|
|
|
||
|
|
cfg = &DBConfig{
|
||
|
|
DBPath: memoryCfg.DBPath,
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, opt := range opts {
|
||
|
|
opt(cfg)
|
||
|
|
}
|
||
|
|
|
||
|
|
path := cfg.DBPath
|
||
|
|
if path == "" {
|
||
|
|
path = getDefaultDBPath()
|
||
|
|
}
|
||
|
|
|
||
|
|
if dir := filepath.Dir(path); dir != "" {
|
||
|
|
os.MkdirAll(dir, 0755)
|
||
|
|
}
|
||
|
|
|
||
|
|
dbInst, err := sql.Open("libsql", "file:"+path)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("打开数据库失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := dbInst.Ping(); err != nil {
|
||
|
|
return fmt.Errorf("ping 数据库失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
db = &DB{db: dbInst}
|
||
|
|
return db.createTables()
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetDB() *DB {
|
||
|
|
return db
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) createTables() error {
|
||
|
|
queries := []string{
|
||
|
|
`CREATE TABLE IF NOT EXISTS sessions (
|
||
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
|
|
uuid TEXT UNIQUE NOT NULL,
|
||
|
|
summary TEXT,
|
||
|
|
summary_embedding BLOB,
|
||
|
|
chat_ids TEXT,
|
||
|
|
created_at INTEGER NOT NULL,
|
||
|
|
updated_at INTEGER NOT NULL
|
||
|
|
)`,
|
||
|
|
`CREATE TABLE IF NOT EXISTS chats (
|
||
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
|
|
session_id INTEGER NOT NULL,
|
||
|
|
user_input TEXT NOT NULL,
|
||
|
|
ai_replies TEXT,
|
||
|
|
summary TEXT,
|
||
|
|
summary_embedding BLOB,
|
||
|
|
created_at INTEGER NOT NULL,
|
||
|
|
updated_at INTEGER NOT NULL,
|
||
|
|
FOREIGN KEY(session_id) REFERENCES sessions(id)
|
||
|
|
)`,
|
||
|
|
`CREATE INDEX IF NOT EXISTS idx_chats_session_id ON chats(session_id)`,
|
||
|
|
`CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid)`,
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, q := range queries {
|
||
|
|
if _, err := d.db.ExecContext(context.Background(), q); err != nil {
|
||
|
|
return fmt.Errorf("创建表失败: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) CreateSession(uuid string) (*Session, error) {
|
||
|
|
s := NewSession(uuid)
|
||
|
|
|
||
|
|
result, err := d.db.ExecContext(
|
||
|
|
context.Background(),
|
||
|
|
`INSERT INTO sessions (uuid, summary, summary_embedding, chat_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`,
|
||
|
|
s.UUID, s.Summary, s.SummaryEmbedding, "[]", s.CreatedAt, s.UpdatedAt,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("创建会话失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
id, err := result.LastInsertId()
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("获取ID失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
s.ID = id
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) GetSessionByUUID(uuid string) (*Session, error) {
|
||
|
|
s := &Session{}
|
||
|
|
var chatIDsJSON string
|
||
|
|
err := d.db.QueryRowContext(
|
||
|
|
context.Background(),
|
||
|
|
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE uuid = ?`,
|
||
|
|
uuid,
|
||
|
|
).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt)
|
||
|
|
if err == sql.ErrNoRows {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("查询会话失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||
|
|
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) GetSessionByID(id int64) (*Session, error) {
|
||
|
|
s := &Session{}
|
||
|
|
var chatIDsJSON string
|
||
|
|
err := d.db.QueryRowContext(
|
||
|
|
context.Background(),
|
||
|
|
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE id = ?`,
|
||
|
|
id,
|
||
|
|
).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt)
|
||
|
|
if err == sql.ErrNoRows {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("查询会话失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||
|
|
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) UpdateSession(s *Session) error {
|
||
|
|
chatIDsJSON, err := s.GetChatIDsJSON()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("序列化chat_ids失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err = d.db.ExecContext(
|
||
|
|
context.Background(),
|
||
|
|
`UPDATE sessions SET summary = ?, summary_embedding = ?, chat_ids = ?, updated_at = ? WHERE id = ?`,
|
||
|
|
s.Summary, s.SummaryEmbedding, chatIDsJSON, s.UpdatedAt, s.ID,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("更新会话失败: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) CreateChat(sessionID int64, userInput string) (*Chat, error) {
|
||
|
|
c := NewChat(sessionID, userInput)
|
||
|
|
|
||
|
|
result, err := d.db.ExecContext(
|
||
|
|
context.Background(),
|
||
|
|
`INSERT INTO chats (session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||
|
|
c.SessionID, c.UserInput, "[]", c.Summary, c.SummaryEmbedding, c.CreatedAt, c.UpdatedAt,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("创建聊天记录失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
id, err := result.LastInsertId()
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("获取ID失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
c.ID = id
|
||
|
|
return c, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) GetChatByID(id int64) (*Chat, error) {
|
||
|
|
c := &Chat{}
|
||
|
|
var aiRepliesJSON string
|
||
|
|
err := d.db.QueryRowContext(
|
||
|
|
context.Background(),
|
||
|
|
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE id = ?`,
|
||
|
|
id,
|
||
|
|
).Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt)
|
||
|
|
if err == sql.ErrNoRows {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("查询聊天记录失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
|
||
|
|
return nil, fmt.Errorf("解析ai_replies失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return c, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) UpdateChat(c *Chat) error {
|
||
|
|
aiRepliesJSON, err := c.GetAIRepliesJSON()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("序列化ai_replies失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err = d.db.ExecContext(
|
||
|
|
context.Background(),
|
||
|
|
`UPDATE chats SET ai_replies = ?, summary = ?, summary_embedding = ?, updated_at = ? WHERE id = ?`,
|
||
|
|
aiRepliesJSON, c.Summary, c.SummaryEmbedding, c.UpdatedAt, c.ID,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("更新聊天记录失败: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) GetLatestSession() (*Session, error) {
|
||
|
|
s := &Session{}
|
||
|
|
var chatIDsJSON string
|
||
|
|
err := d.db.QueryRowContext(
|
||
|
|
context.Background(),
|
||
|
|
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC LIMIT 1`,
|
||
|
|
).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt)
|
||
|
|
if err == sql.ErrNoRows {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("查询最新会话失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||
|
|
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) GetAllSessions() ([]*Session, error) {
|
||
|
|
rows, err := d.db.QueryContext(
|
||
|
|
context.Background(),
|
||
|
|
`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC`,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("查询会话列表失败: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var sessions []*Session
|
||
|
|
for rows.Next() {
|
||
|
|
s := &Session{}
|
||
|
|
var chatIDsJSON string
|
||
|
|
if err := rows.Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt); err != nil {
|
||
|
|
return nil, fmt.Errorf("扫描会话失败: %w", err)
|
||
|
|
}
|
||
|
|
if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
|
||
|
|
return nil, fmt.Errorf("解析chat_ids失败: %w", err)
|
||
|
|
}
|
||
|
|
sessions = append(sessions, s)
|
||
|
|
}
|
||
|
|
return sessions, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) GetChatsBySessionID(sessionID int64) ([]*Chat, error) {
|
||
|
|
rows, err := d.db.QueryContext(
|
||
|
|
context.Background(),
|
||
|
|
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE session_id = ? ORDER BY id`,
|
||
|
|
sessionID,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("查询聊天记录失败: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var chats []*Chat
|
||
|
|
for rows.Next() {
|
||
|
|
c := &Chat{}
|
||
|
|
var aiRepliesJSON string
|
||
|
|
if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||
|
|
return nil, fmt.Errorf("扫描聊天记录失败: %w", err)
|
||
|
|
}
|
||
|
|
if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
|
||
|
|
return nil, fmt.Errorf("解析ai_replies失败: %w", err)
|
||
|
|
}
|
||
|
|
chats = append(chats, c)
|
||
|
|
}
|
||
|
|
return chats, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (d *DB) Close() error {
|
||
|
|
if db != nil && db.db != nil {
|
||
|
|
return db.db.Close()
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func CreateNewSession() (*Session, error) {
|
||
|
|
if db == nil {
|
||
|
|
return nil, fmt.Errorf("数据库未初始化")
|
||
|
|
}
|
||
|
|
return db.CreateSession(generateUUID())
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetLatestSession() (*Session, error) {
|
||
|
|
if db == nil {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
return db.GetLatestSession()
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetAllSessions() ([]*Session, error) {
|
||
|
|
if db == nil {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
return db.GetAllSessions()
|
||
|
|
}
|
||
|
|
|
||
|
|
func ListSessions() ([]*Session, error) {
|
||
|
|
return GetAllSessions()
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetCurrentSession() *Session {
|
||
|
|
return currentSession
|
||
|
|
}
|
||
|
|
|
||
|
|
// currentSession is the session reported by GetCurrentSession; it is nil
// until something assigns it.
var currentSession *Session
|
||
|
|
|
||
|
|
func generateUUID() string {
|
||
|
|
return uuid.New().String()
|
||
|
|
}
|