diff --git a/cmd/hxclaw/internal/memory/db.go b/cmd/hxclaw/internal/memory/db.go new file mode 100644 index 0000000..c9979c2 --- /dev/null +++ b/cmd/hxclaw/internal/memory/db.go @@ -0,0 +1,371 @@ +package memory + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + + "github.com/google/uuid" + "github.com/hxclaw/hxclaw/cmd/hxclaw/internal" + + _ "github.com/tursodatabase/libsql-client-go/libsql" +) + +type DB struct { + db *sql.DB +} + +var ( + db *DB + cfg *DBConfig +) + +type DBConfig struct { + DBPath string +} + +type DBSOption func(*DBConfig) + +func WithDBPath(path string) DBSOption { + return func(c *DBConfig) { + c.DBPath = path + } +} + +func getDefaultDBPath() string { + return internal.GetDBFile() +} + +// GetDefaultDBPath 获取默认数据库路径(导出) +func GetDefaultDBPath() string { + return getDefaultDBPath() +} + +func Init(opts ...DBSOption) error { + memoryCfg := internal.GetProjectConfig().Memory + + cfg = &DBConfig{ + DBPath: memoryCfg.DBPath, + } + + for _, opt := range opts { + opt(cfg) + } + + path := cfg.DBPath + if path == "" { + path = getDefaultDBPath() + } + + if dir := filepath.Dir(path); dir != "" { + os.MkdirAll(dir, 0755) + } + + dbInst, err := sql.Open("libsql", "file:"+path) + if err != nil { + return fmt.Errorf("打开数据库失败: %w", err) + } + + if err := dbInst.Ping(); err != nil { + return fmt.Errorf("ping 数据库失败: %w", err) + } + + db = &DB{db: dbInst} + return db.createTables() +} + +func GetDB() *DB { + return db +} + +func (d *DB) createTables() error { + queries := []string{ + `CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + uuid TEXT UNIQUE NOT NULL, + summary TEXT, + summary_embedding BLOB, + chat_ids TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS chats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + user_input TEXT NOT NULL, + ai_replies TEXT, + summary TEXT, + summary_embedding BLOB, + created_at INTEGER NOT 
NULL, + updated_at INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) + )`, + `CREATE INDEX IF NOT EXISTS idx_chats_session_id ON chats(session_id)`, + `CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid)`, + } + + for _, q := range queries { + if _, err := d.db.ExecContext(context.Background(), q); err != nil { + return fmt.Errorf("创建表失败: %w", err) + } + } + return nil +} + +func (d *DB) CreateSession(uuid string) (*Session, error) { + s := NewSession(uuid) + + result, err := d.db.ExecContext( + context.Background(), + `INSERT INTO sessions (uuid, summary, summary_embedding, chat_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, + s.UUID, s.Summary, s.SummaryEmbedding, "[]", s.CreatedAt, s.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("创建会话失败: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("获取ID失败: %w", err) + } + + s.ID = id + return s, nil +} + +func (d *DB) GetSessionByUUID(uuid string) (*Session, error) { + s := &Session{} + var chatIDsJSON string + err := d.db.QueryRowContext( + context.Background(), + `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE uuid = ?`, + uuid, + ).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询会话失败: %w", err) + } + + if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil { + return nil, fmt.Errorf("解析chat_ids失败: %w", err) + } + + return s, nil +} + +func (d *DB) GetSessionByID(id int64) (*Session, error) { + s := &Session{} + var chatIDsJSON string + err := d.db.QueryRowContext( + context.Background(), + `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE id = ?`, + id, + ).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt) + if err == sql.ErrNoRows { + 
return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询会话失败: %w", err) + } + + if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil { + return nil, fmt.Errorf("解析chat_ids失败: %w", err) + } + + return s, nil +} + +func (d *DB) UpdateSession(s *Session) error { + chatIDsJSON, err := s.GetChatIDsJSON() + if err != nil { + return fmt.Errorf("序列化chat_ids失败: %w", err) + } + + _, err = d.db.ExecContext( + context.Background(), + `UPDATE sessions SET summary = ?, summary_embedding = ?, chat_ids = ?, updated_at = ? WHERE id = ?`, + s.Summary, s.SummaryEmbedding, chatIDsJSON, s.UpdatedAt, s.ID, + ) + if err != nil { + return fmt.Errorf("更新会话失败: %w", err) + } + return nil +} + +func (d *DB) CreateChat(sessionID int64, userInput string) (*Chat, error) { + c := NewChat(sessionID, userInput) + + result, err := d.db.ExecContext( + context.Background(), + `INSERT INTO chats (session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, + c.SessionID, c.UserInput, "[]", c.Summary, c.SummaryEmbedding, c.CreatedAt, c.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("创建聊天记录失败: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("获取ID失败: %w", err) + } + + c.ID = id + return c, nil +} + +func (d *DB) GetChatByID(id int64) (*Chat, error) { + c := &Chat{} + var aiRepliesJSON string + err := d.db.QueryRowContext( + context.Background(), + `SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE id = ?`, + id, + ).Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询聊天记录失败: %w", err) + } + + if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil { + return nil, fmt.Errorf("解析ai_replies失败: %w", err) + } + + return c, nil +} + +func (d *DB) UpdateChat(c 
*Chat) error { + aiRepliesJSON, err := c.GetAIRepliesJSON() + if err != nil { + return fmt.Errorf("序列化ai_replies失败: %w", err) + } + + _, err = d.db.ExecContext( + context.Background(), + `UPDATE chats SET ai_replies = ?, summary = ?, summary_embedding = ?, updated_at = ? WHERE id = ?`, + aiRepliesJSON, c.Summary, c.SummaryEmbedding, c.UpdatedAt, c.ID, + ) + if err != nil { + return fmt.Errorf("更新聊天记录失败: %w", err) + } + return nil +} + +func (d *DB) GetLatestSession() (*Session, error) { + s := &Session{} + var chatIDsJSON string + err := d.db.QueryRowContext( + context.Background(), + `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC LIMIT 1`, + ).Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询最新会话失败: %w", err) + } + + if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil { + return nil, fmt.Errorf("解析chat_ids失败: %w", err) + } + + return s, nil +} + +func (d *DB) GetAllSessions() ([]*Session, error) { + rows, err := d.db.QueryContext( + context.Background(), + `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC`, + ) + if err != nil { + return nil, fmt.Errorf("查询会话列表失败: %w", err) + } + defer rows.Close() + + var sessions []*Session + for rows.Next() { + s := &Session{} + var chatIDsJSON string + if err := rows.Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt); err != nil { + return nil, fmt.Errorf("扫描会话失败: %w", err) + } + if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil { + return nil, fmt.Errorf("解析chat_ids失败: %w", err) + } + sessions = append(sessions, s) + } + return sessions, nil +} + +func (d *DB) GetChatsBySessionID(sessionID int64) ([]*Chat, error) { + rows, err := d.db.QueryContext( + context.Background(), + `SELECT id, session_id, user_input, 
ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE session_id = ? ORDER BY id`, + sessionID, + ) + if err != nil { + return nil, fmt.Errorf("查询聊天记录失败: %w", err) + } + defer rows.Close() + + var chats []*Chat + for rows.Next() { + c := &Chat{} + var aiRepliesJSON string + if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil { + return nil, fmt.Errorf("扫描聊天记录失败: %w", err) + } + if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil { + return nil, fmt.Errorf("解析ai_replies失败: %w", err) + } + chats = append(chats, c) + } + return chats, nil +} + +func (d *DB) Close() error { + if db != nil && db.db != nil { + return db.db.Close() + } + return nil +} + +func CreateNewSession() (*Session, error) { + if db == nil { + return nil, fmt.Errorf("数据库未初始化") + } + return db.CreateSession(generateUUID()) +} + +func GetLatestSession() (*Session, error) { + if db == nil { + return nil, nil + } + return db.GetLatestSession() +} + +func GetAllSessions() ([]*Session, error) { + if db == nil { + return nil, nil + } + return db.GetAllSessions() +} + +func ListSessions() ([]*Session, error) { + return GetAllSessions() +} + +func GetCurrentSession() *Session { + return currentSession +} + +var currentSession *Session + +func generateUUID() string { + return uuid.New().String() +} \ No newline at end of file diff --git a/cmd/hxclaw/internal/memory/export.go b/cmd/hxclaw/internal/memory/export.go new file mode 100644 index 0000000..6bbf4e2 --- /dev/null +++ b/cmd/hxclaw/internal/memory/export.go @@ -0,0 +1,110 @@ +package memory + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/hxclaw/hxclaw/cmd/hxclaw/internal" +) + +type ExportDocument struct { + Version int `json:"version"` + ExportedAt string `json:"exported_at"` + Sessions []*SessionData `json:"sessions"` +} + +type SessionData struct { + ID int64 `json:"id"` + UUID string 
`json:"uuid"` + Summary string `json:"summary"` + ChatIDs []int64 `json:"chat_ids"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + Chats []*Chat `json:"chats"` +} + +func ExportIfNeeded() { + memoryCfg := internal.GetProjectConfig().Memory + if !memoryCfg.Enabled || !memoryCfg.AutoExport { + return + } + + db := GetDB() + if db == nil { + return + } + + exportPath := filepath.Join(internal.GetConfigDir(), "export-data.json") + if err := ExportToFile(exportPath); err != nil { + fmt.Printf("警告:导出数据失败: %v\n", err) + return + } + fmt.Printf("数据已导出: %s\n", exportPath) +} + +func ExportToFile(filename string) error { + db := GetDB() + if db == nil { + return fmt.Errorf("数据库未初始化") + } + + sessions, err := db.GetAllSessions() + if err != nil { + return fmt.Errorf("查询会话失败: %w", err) + } + + var doc ExportDocument + doc.Version = 1 + doc.ExportedAt = time.Now().Format(time.RFC3339) + + if _, err := os.Stat(filename); err == nil { + data, readErr := os.ReadFile(filename) + if readErr == nil { + json.Unmarshal(data, &doc) + } + } + + sessionMap := make(map[string]*SessionData) + for i := range doc.Sessions { + sessionMap[doc.Sessions[i].UUID] = doc.Sessions[i] + } + + for _, s := range sessions { + if existing, ok := sessionMap[s.UUID]; ok { + existing.Summary = s.Summary + existing.ChatIDs = s.ChatIDs + existing.UpdatedAt = s.UpdatedAt + } else { + sessionData := &SessionData{ + ID: s.ID, + UUID: s.UUID, + Summary: s.Summary, + ChatIDs: s.ChatIDs, + CreatedAt: s.CreatedAt, + UpdatedAt: s.UpdatedAt, + Chats: []*Chat{}, + } + doc.Sessions = append(doc.Sessions, sessionData) + sessionMap[s.UUID] = sessionData + } + + chats, err := db.GetChatsBySessionID(s.ID) + if err != nil { + continue + } + sessionMap[s.UUID].Chats = append(sessionMap[s.UUID].Chats, chats...) 
+ } + + file, err := os.Create(filename) + if err != nil { + return fmt.Errorf("创建文件失败: %w", err) + } + defer file.Close() + + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + return encoder.Encode(doc) +} \ No newline at end of file diff --git a/cmd/hxclaw/internal/memory/model.go b/cmd/hxclaw/internal/memory/model.go new file mode 100644 index 0000000..97ce7c1 --- /dev/null +++ b/cmd/hxclaw/internal/memory/model.go @@ -0,0 +1,98 @@ +package memory + +import ( + "encoding/json" + "time" +) + +type Session struct { + ID int64 `json:"id"` + UUID string `json:"uuid"` + Summary string `json:"summary"` + SummaryEmbedding []byte `json:"-"` + ChatIDs []int64 `json:"chat_ids"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type Chat struct { + ID int64 `json:"id"` + SessionID int64 `json:"session_id"` + UserInput string `json:"user_input"` + AIReplies []string `json:"ai_replies"` + Summary string `json:"summary"` + SummaryEmbedding []byte `json:"-"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +func NewSession(uuid string) *Session { + now := time.Now().Unix() + return &Session{ + UUID: uuid, + ChatIDs: []int64{}, + CreatedAt: now, + UpdatedAt: now, + } +} + +func NewChat(sessionID int64, userInput string) *Chat { + now := time.Now().Unix() + return &Chat{ + SessionID: sessionID, + UserInput: userInput, + AIReplies: []string{}, + CreatedAt: now, + UpdatedAt: now, + } +} + +func (s *Session) AddChatID(chatID int64) { + s.ChatIDs = append(s.ChatIDs, chatID) + s.UpdatedAt = time.Now().Unix() +} + +func (s *Session) GetChatIDsJSON() (string, error) { + data, err := json.Marshal(s.ChatIDs) + if err != nil { + return "", err + } + return string(data), nil +} + +func (s *Session) SetChatIDsFromJSON(data string) error { + if data == "" { + s.ChatIDs = []int64{} + return nil + } + return json.Unmarshal([]byte(data), &s.ChatIDs) +} + +func (c *Chat) AddAIReply(reply string) { + c.AIReplies = 
// GenerateSummary builds a short summary by joining userInput and aiReply
// with a newline and limiting the result to 200 bytes.
//
// The cut never lands in the middle of a multi-byte UTF-8 character: when
// byte 200 falls inside one, the whole character is dropped, so the result
// may be up to three bytes shorter than the limit. Text that fits is returned
// unchanged.
func GenerateSummary(userInput, aiReply string) string {
	const maxSummaryBytes = 200

	fullText := userInput + "\n" + aiReply
	if len(fullText) <= maxSummaryBytes {
		return fullText
	}

	// fullText[cut] is the first byte that would be dropped. If it is a
	// UTF-8 continuation byte (bit pattern 10xxxxxx) the cut would split a
	// character, so step back to that character's first byte. A character is
	// at most four bytes long, hence at most three steps.
	cut := maxSummaryBytes
	for cut > maxSummaryBytes-3 && fullText[cut]&0xC0 == 0x80 {
		cut--
	}
	if fullText[cut]&0xC0 == 0x80 {
		// Four continuation bytes in a row: the input is not valid UTF-8
		// here, so there is no boundary to protect; keep the plain limit.
		cut = maxSummaryBytes
	}
	return fullText[:cut]
}
// getPicoclawWorkspace returns the path <home>/.picoclaw/workspace for the
// current user, or an empty string when the home directory cannot be
// determined.
func getPicoclawWorkspace() string {
	homeDir, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(homeDir, ".picoclaw", "workspace")
	}
	return ""
}
range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} { + if strings.Contains(input, keyword) { + return recallByKeyword(input, keyword) + } + } + + return "" +} + +func recallByKeyword(input, keyword string) string { + db := GetDB() + if db == nil { + return "" + } + + if keyword == "之前" || keyword == "聊过" || keyword == "曾经" { + if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") { + result, _ := RecallHistory() + return FormatRecallResults(result) + } + + topic := extractTopic(input) + if topic != "" { + result, _ := RecallTopic(topic, 5) + return FormatRecallResults(result) + } + } + + if strings.Contains(input, "那次") || strings.Contains(input, "那次") { + result, _ := RecallHistory() + return FormatRecallResults(result) + } + + topic := extractTopic(input) + if topic != "" { + result, _ := RecallTopic(topic, 5) + return FormatRecallResults(result) + } + + return "" +} + +func extractTopic(input string) string { + topic := input + removePrefixes := []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"} + for _, prefix := range removePrefixes { + topic = strings.Replace(topic, prefix, "", -1) + } + topic = strings.Trim(topic, " ,,。.??!!::") + return topic +} + +func ShouldSkipSummaryUpdate(input string) bool { + return shouldRecall(input) +} + +func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) { + db := GetDB() + if db == nil { + return 0, fmt.Errorf("数据库未初始化") + } + + memoryCfg := internal.GetProjectConfig().Memory + if !memoryCfg.Enabled { + return 0, nil + } + + // 获取或创建 Session + session := currentSession + if session == nil { + var err error + session, err = db.GetLatestSession() + if err != nil || session == nil { + session, err = createSession() + if err != nil { + return 0, fmt.Errorf("创建会话失败: %v", err) + } + } + currentSession = session + } + + // 保存原始 session summary(用于恢复) + originalSummary := session.Summary + + // 创建聊天记录 + chat, err := db.CreateChat(session.ID, userInput) + if err != nil { + 
return 0, fmt.Errorf("创建聊天记录失败: %v", err) + } + + // 添加 AI 回复 + chat.AddAIReply(aiReply) + chat.Summary = GenerateSummary(userInput, aiReply) + chat.UpdatedAt = time.Now().Unix() + + // 生成并保存向量 + vs := GetVectorService() + if vs != nil && memoryCfg.Vector.APIKey != "" { + go generateEmbedding(chat) + } + + if err := db.UpdateChat(chat); err != nil { + return 0, fmt.Errorf("更新聊天记录失败: %v", err) + } + + // 更新 Session + session.AddChatID(chat.ID) + + // 只有普通对话才更新 session summary,recall 查询保持原 summary + if useSessionSummary { + session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply) + } + session.UpdatedAt = time.Now().Unix() + + // Session 也要生成向量 + if vs != nil && memoryCfg.Vector.APIKey != "" && useSessionSummary { + go generateSessionEmbedding(session) + } + + if err := db.UpdateSession(session); err != nil { + return 0, fmt.Errorf("更新会话失败: %v", err) + } + + // 恢复原始 session summary(避免 recall 结果污染) + session.Summary = originalSummary + currentSession = session + + return len(session.ChatIDs), nil +} + +func generateEmbedding(chat *Chat) { + vs := GetVectorService() + if vs == nil { + return + } + + text := chat.UserInput + if len(chat.AIReplies) > 0 { + text = text + "\n" + chat.AIReplies[len(chat.AIReplies)-1] + } + + embedding, err := vs.Generate(text) + if err != nil { + fmt.Printf("[memory] 向量生成失败: %v\n", err) + return + } + + chat.SummaryEmbedding = embedding + chat.UpdatedAt = time.Now().Unix() + + if err := GetDB().UpdateChat(chat); err != nil { + fmt.Printf("[memory] 向量保存失败: %v\n", err) + } +} + +func generateSessionEmbedding(session *Session) { + vs := GetVectorService() + if vs == nil { + return + } + + embedding, err := vs.Generate(session.Summary) + if err != nil { + fmt.Printf("[memory] Session 向量生成失败: %v\n", err) + return + } + + session.SummaryEmbedding = embedding + session.UpdatedAt = time.Now().Unix() + + if err := GetDB().UpdateSession(session); err != nil { + fmt.Printf("[memory] Session 向量保存失败: %v\n", err) + } +} + +func createSession() 
(*Session, error) { + return CreateNewSession() +} + +type SearchResult struct { + Chat *Chat + Similarity float64 +} + +func SearchSimilar(query string, topK int) ([]SearchResult, error) { + db := GetDB() + if db == nil { + return nil, fmt.Errorf("数据库未初始化") + } + + vs := GetVectorService() + if vs == nil { + return nil, fmt.Errorf("向量服务未初始化") + } + + // 获取配置的最大检索数 + maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults + if maxResults <= 0 { + maxResults = 10 + } + if topK > maxResults { + topK = maxResults + } + + // 生成查询向量 + queryEmbedding, err := vs.Generate(query) + if err != nil { + return nil, fmt.Errorf("查询向量生成失败: %w", err) + } + + // 获取所有有向量的 chat + chats, err := db.GetAllChats() + if err != nil { + return nil, fmt.Errorf("获取聊天记录失败: %w", err) + } + + // 计算相似度并排序 + type scoredChat struct { + chat *Chat + similarity float64 + } + + var scored []scoredChat + for _, chat := range chats { + if len(chat.SummaryEmbedding) == 0 { + continue + } + sim := vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding) + scored = append(scored, scoredChat{chat: chat, similarity: sim}) + } + + // 按相似度排序 + for i := 0; i < len(scored); i++ { + for j := i + 1; j < len(scored); j++ { + if scored[j].similarity > scored[i].similarity { + scored[i], scored[j] = scored[j], scored[i] + } + } + } + + // 取前 K 个 + if topK > len(scored) { + topK = len(scored) + } + + results := make([]SearchResult, topK) + for i := 0; i < topK; i++ { + results[i] = SearchResult{ + Chat: scored[i].chat, + Similarity: scored[i].similarity, + } + } + + return results, nil +} + +func (db *DB) GetAllChats() ([]*Chat, error) { + rows, err := db.db.Query( + `SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var chats []*Chat + for rows.Next() { + var c Chat + var aiRepliesJSON 
string + var summaryEmbedding []byte + if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil { + continue + } + c.SummaryEmbedding = summaryEmbedding + c.SetAIRepliesFromJSON(aiRepliesJSON) + chats = append(chats, &c) + } + + return chats, nil +} \ No newline at end of file diff --git a/cmd/hxclaw/internal/memory/skill.go b/cmd/hxclaw/internal/memory/skill.go new file mode 100644 index 0000000..77c1803 --- /dev/null +++ b/cmd/hxclaw/internal/memory/skill.go @@ -0,0 +1,214 @@ +package memory + +import ( + "fmt" + + "github.com/hxclaw/hxclaw/cmd/hxclaw/internal" +) + +type RecallResult struct { + Type string `json:"type"` + Message string `json:"message"` + Sessions []*Session `json:"sessions,omitempty"` + Chats []*Chat `json:"chats,omitempty"` +} + +func RecallHistory() (*RecallResult, error) { + db := GetDB() + if db == nil { + return nil, fmt.Errorf("数据库未初始化") + } + + sessions, err := db.GetAllSessions() + if err != nil { + return nil, fmt.Errorf("查询会话失败: %v", err) + } + + if len(sessions) == 0 { + return &RecallResult{ + Type: "history", + Message: "暂无历史会话记录", + }, nil + } + + msg := "以下是所有历史会话摘要:\n\n" + for i, s := range sessions { + summary := s.Summary + if summary == "" { + summary = "(暂无摘要)" + } + msg += fmt.Sprintf("**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary) + } + + return &RecallResult{ + Type: "history", + Message: msg, + Sessions: sessions, + }, nil +} + +func RecallTopic(query string, topK int) (*RecallResult, error) { + vs := GetVectorService() + if vs == nil { + return nil, fmt.Errorf("向量服务未初始化") + } + + cfg := internal.GetProjectConfig().Memory + if topK <= 0 { + topK = cfg.Vector.MaxSearchResults + if topK <= 0 { + topK = 5 + } + } + + results, err := SearchSimilar(query, topK) + if err != nil { + return nil, fmt.Errorf("向量检索失败: %v", err) + } + + if len(results) == 0 { + return &RecallResult{ + Type: "topic", + Message: fmt.Sprintf("未找到与「%s」相关的内容", query), 
+ }, nil + } + + type sessionChat struct { + sessionID int64 + chat *Chat + similarity float64 + } + + sessionChats := make(map[int64]*sessionChat) + for _, r := range results { + if existing, ok := sessionChats[r.Chat.SessionID]; ok { + if r.Similarity > existing.similarity { + sessionChats[r.Chat.SessionID] = &sessionChat{ + sessionID: r.Chat.SessionID, + chat: r.Chat, + similarity: r.Similarity, + } + } + } else { + sessionChats[r.Chat.SessionID] = &sessionChat{ + sessionID: r.Chat.SessionID, + chat: r.Chat, + similarity: r.Similarity, + } + } + } + + msg := fmt.Sprintf("找到与「%s」相关的内容:\n\n", query) + var chats []*Chat + for _, sc := range sessionChats { + msg += fmt.Sprintf("**会话 %d** (相似度: %.2f)\n%s\n\n", + sc.sessionID, sc.similarity, sc.chat.Summary) + chats = append(chats, sc.chat) + } + + return &RecallResult{ + Type: "topic", + Message: msg, + Chats: chats, + }, nil +} + +func RecallSession(sessionID int64) (*RecallResult, error) { + db := GetDB() + if db == nil { + return nil, fmt.Errorf("数据库未初始化") + } + + session, err := db.GetSessionByID(sessionID) + if err != nil { + return nil, fmt.Errorf("查询会话失败: %v", err) + } + if session == nil { + return &RecallResult{ + Type: "session", + Message: fmt.Sprintf("未找到会话 ID: %d", sessionID), + }, nil + } + + msg := fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, session.Summary) + + return &RecallResult{ + Type: "session", + Message: msg, + Sessions: []*Session{session}, + }, nil +} + +func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) { + db := GetDB() + if db == nil { + return nil, fmt.Errorf("数据库未初始化") + } + + chats, err := db.GetChatsBySessionID(sessionID) + if err != nil { + return nil, fmt.Errorf("查询聊天记录失败: %v", err) + } + + if len(chats) == 0 { + return &RecallResult{ + Type: "within_session", + Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID), + }, nil + } + + vs := GetVectorService() + if vs == nil { + return nil, fmt.Errorf("向量服务未初始化") + } + + queryEmb, err := vs.Generate(query) + if 
err != nil { + return nil, fmt.Errorf("生成查询向量失败: %v", err) + } + + type scoredChat struct { + chat *Chat + similarity float64 + } + + var scored []scoredChat + for _, c := range chats { + if len(c.SummaryEmbedding) == 0 { + continue + } + sim := vs.CosineSimilarity(queryEmb, c.SummaryEmbedding) + scored = append(scored, scoredChat{chat: c, similarity: sim}) + } + + for i := 0; i < len(scored); i++ { + for j := i + 1; j < len(scored); j++ { + if scored[j].similarity > scored[i].similarity { + scored[i], scored[j] = scored[j], scored[i] + } + } + } + + topK := 3 + if topK > len(scored) { + topK = len(scored) + } + + msg := fmt.Sprintf("会话 %d 中与「%s」相关内容:\n\n", sessionID, query) + var matchedChats []*Chat + for i := 0; i < topK; i++ { + msg += fmt.Sprintf("**相似度: %.2f**\n%s\n\n", + scored[i].similarity, scored[i].chat.Summary) + matchedChats = append(matchedChats, scored[i].chat) + } + + return &RecallResult{ + Type: "within_session", + Message: msg, + Chats: matchedChats, + }, nil +} + +func FormatRecallResults(result *RecallResult) string { + return result.Message +} \ No newline at end of file diff --git a/cmd/hxclaw/internal/memory/vector.go b/cmd/hxclaw/internal/memory/vector.go new file mode 100644 index 0000000..83d7c45 --- /dev/null +++ b/cmd/hxclaw/internal/memory/vector.go @@ -0,0 +1,234 @@ +package memory + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "sync" +) + +type VectorService struct { + APIKey string + BaseURL string + Model string + Dimension int +} + +var ( + vectorSvc *VectorService + vectorMu sync.RWMutex +) + +type VectorOption func(*VectorService) + +func WithAPIKey(key string) VectorOption { + return func(v *VectorService) { + v.APIKey = key + } +} + +func WithBaseURL(url string) VectorOption { + return func(v *VectorService) { + v.BaseURL = url + } +} + +func WithModel(model string) VectorOption { + return func(v *VectorService) { + v.Model = model + } +} + +func WithDimension(dim 
int) VectorOption { + return func(v *VectorService) { + v.Dimension = dim + } +} + +func InitVector(opts ...VectorOption) error { + vectorSvc = &VectorService{ + APIKey: "", + BaseURL: "https://api.siliconflow.cn/v1", + Model: "BAAI/bge-m3", + Dimension: 1024, + } + for _, opt := range opts { + opt(vectorSvc) + } + return nil +} + +func GetVectorService() *VectorService { + vectorMu.RLock() + defer vectorMu.RUnlock() + return vectorSvc +} + +func (v *VectorService) Generate(text string) ([]byte, error) { + if text == "" { + return make([]byte, v.Dimension*4), nil + } + + reqBody := map[string]interface{}{ + "input": text, + "model": v.Model, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + req, err := http.NewRequestWithContext( + context.Background(), + "POST", + v.BaseURL+"/embeddings", + bytes.NewReader(bodyBytes), + ) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+v.APIKey) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + + if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 { + return nil, fmt.Errorf("未获取到向量") + } + + embedding := result.Data[0].Embedding + buf := new(bytes.Buffer) + if err := binary.Write(buf, binary.LittleEndian, embedding); err != nil { + return nil, fmt.Errorf("编码向量失败: %w", err) + } + + return buf.Bytes(), 
nil +} + +func (v *VectorService) GenerateBatch(texts []string) ([][]byte, error) { + if len(texts) == 0 { + return [][]byte{}, nil + } + + reqBody := map[string]interface{}{ + "input": texts, + "model": v.Model, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + req, err := http.NewRequestWithContext( + context.Background(), + "POST", + v.BaseURL+"/embeddings", + bytes.NewReader(bodyBytes), + ) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+v.APIKey) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + + embeddings := make([][]byte, len(result.Data)) + for i, data := range result.Data { + embedding := data.Embedding + buf := new(bytes.Buffer) + if err := binary.Write(buf, binary.LittleEndian, embedding); err != nil { + continue + } + embeddings[i] = buf.Bytes() + } + + return embeddings, nil +} + +func (v *VectorService) CosineSimilarity(a, b []byte) float64 { + if len(a) == 0 || len(b) == 0 { + return 0 + } + + vecLen := len(a) / 4 + if len(b)/4 < vecLen { + vecLen = len(b) / 4 + } + + var dotProduct float32 + var normA float32 + var normB float32 + + for i := 0; i < vecLen; i++ { + offset := i * 4 + fa := binary.LittleEndian.Uint32(a[offset : offset+4]) + fb := binary.LittleEndian.Uint32(b[offset : offset+4]) + f32a := math.Float32frombits(fa) + 
f32b := math.Float32frombits(fb) + dotProduct += f32a * f32b + normA += f32a * f32a + normB += f32b * f32b + } + + if normA == 0 || normB == 0 { + return 0 + } + + return float64(dotProduct / (normA * normB)) +} \ No newline at end of file diff --git a/cmd/hxclaw/internal/tts.go b/cmd/hxclaw/internal/tts.go new file mode 100644 index 0000000..6e849c1 --- /dev/null +++ b/cmd/hxclaw/internal/tts.go @@ -0,0 +1,145 @@ +package internal + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "net" + "os/exec" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +var ( + ttsEnabled bool + ttsEnabledMu sync.RWMutex +) + +type TTSRequest struct { + Text string `json:"text"` + Voice *string `json:"voice,omitempty"` + Format *string `json:"format,omitempty"` + Style *string `json:"style,omitempty"` +} + +type TTSResponse struct { + Status string `json:"status"` + Message string `json:"message"` +} + +func SetTTSEnabled(enabled bool) { + ttsEnabledMu.Lock() + defer ttsEnabledMu.Unlock() + ttsEnabled = enabled +} + +func IsTTSEnabled() bool { + ttsEnabledMu.RLock() + defer ttsEnabledMu.RUnlock() + return ttsEnabled +} + +func ToggleTTS() bool { + ttsEnabledMu.Lock() + defer ttsEnabledMu.Unlock() + ttsEnabled = !ttsEnabled + return ttsEnabled +} + +func GetTTSPrompt(basePrompt string) string { + if IsTTSEnabled() { + return basePrompt + "🔊 " + } + return basePrompt +} + +func SpeakText(text string) { + cfg := GetProjectConfig() + if !cfg.TTS.Enabled && !IsTTSEnabled() { + return + } + + port := cfg.TTS.Port + if port <= 0 { + port = 9876 + } + + addr := fmt.Sprintf("127.0.0.1:%d", port) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := speakTextWithContext(ctx, text, addr) + if err != nil { + logger.WarnCF("tts", "网络语音暂时异常", map[string]any{ + "error": err.Error(), + }) + } +} + +func speakTextWithContext(ctx context.Context, text, addr string) error { + conn, err := dialWithContext(ctx, addr) + if err != 
// GetUserHome returns the user's home directory as reported by the
// environment: the value of HOME when it is non-empty, otherwise the value of
// USERPROFILE (which is "" when that variable is unset as well).
func GetUserHome() string {
	if dir := os.Getenv("HOME"); dir != "" {
		return dir
	}
	return os.Getenv("USERPROFILE")
}
+func EnsureConfigDir() error { + dir := GetConfigDir() + return os.MkdirAll(dir, 0755) +} \ No newline at end of file