package memory

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)

// lastContext holds the most recent context string built by GetContextPrompt.
var lastContext string

// ErrNeedNewSession is returned by SaveChat when there is no current session;
// the caller is expected to create one before saving.
var ErrNeedNewSession = errors.New("需要创建新会话")

// GetContextPrompt assembles the memory context for userInput.
//
// It returns "" when the database is unavailable, memory is disabled in the
// project config, or there is no current session. Otherwise it concatenates,
// in order: picoclaw's long-term memory (MEMORY.md), the current session
// summary, and any recall results triggered by userInput. The result is also
// stored in lastContext.
func GetContextPrompt(userInput string) string {
	if GetDB() == nil {
		return ""
	}
	if !internal.GetProjectConfig().Memory.Enabled {
		return ""
	}
	session := currentSession
	if session == nil {
		return "" // no session, so no context to provide
	}

	var sb strings.Builder
	if picoMemory := readPicoClawMemory(); picoMemory != "" {
		sb.WriteString("=== 长期记忆 ===\n")
		sb.WriteString(picoMemory)
		sb.WriteString("\n============\n")
	}
	if session.Summary != "" {
		sb.WriteString("=== 当前会话摘要 ===\n")
		sb.WriteString(session.Summary)
		sb.WriteString("\n============\n")
	}
	if shouldRecall(userInput) {
		sb.WriteString(detectAndRecall(userInput))
	}

	lastContext = sb.String()
	return lastContext
}

// readPicoClawMemory returns the body of picoclaw's long-term memory file
// (<workspace>/memory/MEMORY.md). Blank lines and "#" heading lines are
// dropped, and text before the first heading is ignored. It returns "" when
// the workspace cannot be located or the file cannot be read.
func readPicoClawMemory() string {
	workspace := getPicoclawWorkspace()
	if workspace == "" {
		// Without this guard the joined path would be relative to the current
		// working directory and could pick up an unrelated memory/MEMORY.md.
		return ""
	}
	data, err := os.ReadFile(filepath.Join(workspace, "memory", "MEMORY.md"))
	if err != nil {
		return ""
	}

	var sb strings.Builder
	inSection := false
	for _, line := range strings.Split(string(data), "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case trimmed == "":
			// skip blank lines
		case strings.HasPrefix(trimmed, "#"):
			// Headings are not copied; they only mark the start of content.
			inSection = true
		case inSection:
			sb.WriteString(trimmed)
			sb.WriteString("\n")
		}
	}
	return strings.TrimSpace(sb.String())
}

// getPicoclawWorkspace returns ~/.picoclaw/workspace, or "" when the user's
// home directory cannot be determined.
func getPicoclawWorkspace() string {
	home, err := os.UserHomeDir()
	if err != nil {
		return ""
	}
	return filepath.Join(home, ".picoclaw", "workspace")
}

// shouldRecall reports whether input should trigger a memory recall. This is
// true when input starts with "/recall" or "/memory", when it contains one
// of the configured recall keywords, or (with AutoRecall enabled) when it is
// semantically close to the current session summary.
func shouldRecall(input string) bool {
	recallCfg := internal.GetProjectConfig().Memory.Recall

	for _, cmd := range []string{"/recall", "/memory"} {
		if strings.HasPrefix(input, cmd) {
			return true
		}
	}
	for _, keyword := range recallCfg.Keywords {
		// An empty keyword would match every input via strings.Contains.
		if keyword != "" && strings.Contains(input, keyword) {
			return true
		}
	}
	if recallCfg.AutoRecall {
		return shouldAutoRecall(input)
	}
	return false
}

// shouldAutoRecall reports whether the embedding of input has a cosine
// similarity to the current session's summary embedding strictly above the
// configured threshold (0.7 when unset or non-positive). It returns false
// when there is no vector service, no session, no summary or summary
// embedding, or when embedding input fails (best-effort).
func shouldAutoRecall(input string) bool {
	vs := GetVectorService()
	if vs == nil {
		return false
	}

	threshold := internal.GetProjectConfig().Memory.Recall.SimilarityThreshold
	if threshold <= 0 {
		threshold = 0.7
	}

	session := currentSession
	if session == nil || session.Summary == "" || len(session.SummaryEmbedding) == 0 {
		return false
	}

	inputEmb, err := vs.Generate(input)
	if err != nil {
		return false
	}
	return vs.CosineSimilarity(inputEmb, session.SummaryEmbedding) > threshold
}

// detectAndRecall turns a recall request into formatted recall text.
//
//   - A bare "/recall" recalls the general history.
//   - Input containing a built-in trigger word is handled by recallByKeyword.
//   - "/recall <topic>" without a trigger word recalls <topic>.
//
// Anything else yields "".
func detectAndRecall(input string) string {
	input = strings.TrimSpace(input)

	explicit := strings.HasPrefix(input, "/recall")
	if explicit {
		input = strings.TrimSpace(strings.TrimPrefix(input, "/recall"))
		if input == "" {
			return formatHistoryRecall()
		}
	}

	for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
		if strings.Contains(input, keyword) {
			return recallByKeyword(input, keyword)
		}
	}

	if explicit {
		// The user explicitly asked to recall something; without this
		// fallback "/recall <topic>" silently produced nothing.
		return formatTopicRecall(extractTopic(input))
	}
	return ""
}

// recallByKeyword picks a recall strategy for input given the trigger
// keyword that matched it.
//
//   - For "之前"/"聊过"/"曾经", a general question ("什么", "都聊", or a
//     bare "之前") recalls the history; otherwise the extracted topic is
//     recalled.
//   - Input mentioning "那次" recalls the history.
//   - Otherwise the extracted topic is recalled.
//
// It returns "" when the database is unavailable or nothing applies.
func recallByKeyword(input, keyword string) string {
	if GetDB() == nil {
		return ""
	}

	switch keyword {
	case "之前", "聊过", "曾经":
		if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") {
			return formatHistoryRecall()
		}
		if topic := extractTopic(input); topic != "" {
			return formatTopicRecall(topic)
		}
	}

	if strings.Contains(input, "那次") {
		return formatHistoryRecall()
	}
	return formatTopicRecall(extractTopic(input))
}

// formatHistoryRecall recalls the general chat history and formats it for the
// prompt. Errors are logged and yield "" so a failed lookup never injects
// bogus context.
func formatHistoryRecall() string {
	result, err := RecallHistory()
	if err != nil {
		fmt.Printf("[memory] 回忆历史失败: %v\n", err)
		return ""
	}
	return FormatRecallResults(result)
}

// formatTopicRecall calls RecallTopic(topic, 5) and formats the result for
// the prompt. An empty topic yields "" without querying. Errors are logged
// and yield "".
func formatTopicRecall(topic string) string {
	if topic == "" {
		return ""
	}
	result, err := RecallTopic(topic, 5)
	if err != nil {
		fmt.Printf("[memory] 主题回忆失败: %v\n", err)
		return ""
	}
	return FormatRecallResults(result)
}

// extractTopic strips recall trigger phrases (e.g. "之前", "聊过", "关于")
// from input and trims surrounding spaces and ASCII/full-width punctuation,
// leaving the topic being asked about.
func extractTopic(input string) string {
	topic := input
	for _, marker := range []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"} {
		topic = strings.ReplaceAll(topic, marker, "")
	}
	return strings.Trim(topic, " ,，。.?？!！:：")
}

// ShouldSkipSummaryUpdate reports whether the session summary should not be
// updated for input. It is true exactly when input triggers a recall, so
// recall results do not leak into the session summary.
func ShouldSkipSummaryUpdate(input string) bool {
	return shouldRecall(input)
}

// SaveChat persists one user/AI exchange into the current session and
// returns the number of chats the session now holds.
//
// It is a no-op returning (0, nil) when memory is disabled, and returns
// ErrNeedNewSession when there is no current session. When useSessionSummary
// is true the session summary is extended with aiReply (and re-embedded when
// a vector service and API key are configured). Recall queries should pass
// false so their results do not pollute the session summary.
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
	memoryCfg := internal.GetProjectConfig().Memory
	// Check Enabled first: with memory disabled the database may never have
	// been initialized, and saving should simply be skipped, not fail.
	if !memoryCfg.Enabled {
		return 0, nil
	}
	db := GetDB()
	if db == nil {
		return 0, errors.New("数据库未初始化")
	}

	session := currentSession
	if session == nil {
		// No session: report an error so the caller can create one.
		return 0, ErrNeedNewSession
	}

	chat, err := db.CreateChat(session.ID, userInput)
	if err != nil {
		return 0, fmt.Errorf("创建聊天记录失败: %w", err)
	}
	chat.AddAIReply(aiReply)
	chat.Summary = GenerateSummary(userInput, aiReply)
	chat.UpdatedAt = time.Now().Unix()

	canEmbed := GetVectorService() != nil && memoryCfg.Vector.APIKey != ""
	if canEmbed {
		// Compute in memory only; the single UpdateChat below persists it.
		computeChatEmbedding(chat)
	}
	if err := db.UpdateChat(chat); err != nil {
		return 0, fmt.Errorf("更新聊天记录失败: %w", err)
	}

	session.AddChatID(chat.ID)
	// Only regular conversation updates the session summary; recall queries
	// keep it unchanged. The new summary must stay in memory (not be reverted)
	// so it matches what is persisted and its embedding.
	if useSessionSummary {
		session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply)
		if canEmbed {
			computeSessionEmbedding(session)
		}
	}
	session.UpdatedAt = time.Now().Unix()
	if err := db.UpdateSession(session); err != nil {
		return 0, fmt.Errorf("更新会话失败: %w", err)
	}

	currentSession = session
	return len(session.ChatIDs), nil
}

// computeChatEmbedding embeds the chat's user input plus its latest AI reply
// and stores the vector on chat without persisting it. It returns false
// (logging the error, if any) when no embedding was produced.
func computeChatEmbedding(chat *Chat) bool {
	vs := GetVectorService()
	if vs == nil {
		return false
	}

	text := chat.UserInput
	if n := len(chat.AIReplies); n > 0 {
		text += "\n" + chat.AIReplies[n-1]
	}

	embedding, err := vs.Generate(text)
	if err != nil {
		fmt.Printf("[memory] 向量生成失败: %v\n", err)
		return false
	}
	chat.SummaryEmbedding = embedding
	chat.UpdatedAt = time.Now().Unix()
	return true
}

// generateEmbedding computes the chat's embedding and persists the chat.
// Failures are logged, not returned (best-effort).
func generateEmbedding(chat *Chat) {
	if !computeChatEmbedding(chat) {
		return
	}
	db := GetDB()
	if db == nil {
		return
	}
	if err := db.UpdateChat(chat); err != nil {
		fmt.Printf("[memory] 向量保存失败: %v\n", err)
	}
}

// computeSessionEmbedding embeds the session summary and stores the vector
// on session without persisting it. It returns false (logging the error, if
// any) when no embedding was produced.
func computeSessionEmbedding(session *Session) bool {
	vs := GetVectorService()
	if vs == nil {
		return false
	}
	embedding, err := vs.Generate(session.Summary)
	if err != nil {
		fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
		return false
	}
	session.SummaryEmbedding = embedding
	session.UpdatedAt = time.Now().Unix()
	return true
}

// generateSessionEmbedding computes the session summary embedding and
// persists the session. Failures are logged, not returned (best-effort).
func generateSessionEmbedding(session *Session) {
	if !computeSessionEmbedding(session) {
		return
	}
	db := GetDB()
	if db == nil {
		return
	}
	if err := db.UpdateSession(session); err != nil {
		fmt.Printf("[memory] Session 向量保存失败: %v\n", err)
	}
}

// createSession is a thin wrapper around CreateNewSession.
func createSession() (*Session, error) {
	return CreateNewSession()
}

// SearchResult is one chat returned by SearchSimilar together with its
// cosine similarity to the query.
type SearchResult struct {
	Chat       *Chat
	Similarity float64
}

// SearchSimilar embeds query and returns up to topK stored chats ranked by
// cosine similarity to it, most similar first.
//
// topK is capped at the configured Memory.Vector.MaxSearchResults (10 when
// unset or non-positive); a negative topK yields no results. Only chats that
// already have an embedding are ranked.
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
	db := GetDB()
	if db == nil {
		return nil, errors.New("数据库未初始化")
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, errors.New("向量服务未初始化")
	}

	maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
	if maxResults <= 0 {
		maxResults = 10
	}
	if topK > maxResults {
		topK = maxResults
	}

	queryEmbedding, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("查询向量生成失败: %w", err)
	}
	chats, err := db.GetAllChats()
	if err != nil {
		return nil, fmt.Errorf("获取聊天记录失败: %w", err)
	}

	scored := make([]SearchResult, 0, len(chats))
	for _, chat := range chats {
		if len(chat.SummaryEmbedding) == 0 {
			continue
		}
		scored = append(scored, SearchResult{
			Chat:       chat,
			Similarity: vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding),
		})
	}

	// O(n log n) instead of the previous O(n^2) exchange sort. Stable, so
	// ties keep GetAllChats' order (descending id).
	sort.SliceStable(scored, func(i, j int) bool {
		return scored[i].Similarity > scored[j].Similarity
	})

	if topK < 0 {
		topK = 0 // make() below would panic on a negative length
	}
	if topK > len(scored) {
		topK = len(scored)
	}
	// Copy instead of re-slicing so the result does not pin every scored chat.
	results := make([]SearchResult, topK)
	copy(results, scored[:topK])
	return results, nil
}

// GetAllChats returns every chat that has a non-empty summary embedding,
// ordered by id descending. Rows that fail to scan are skipped (best-effort);
// errors from the query itself or from row iteration are returned.
func (db *DB) GetAllChats() ([]*Chat, error) {
	rows, err := db.db.Query(
		`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at
		FROM chats
		WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0
		ORDER BY id DESC`,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		var c Chat
		var aiRepliesJSON string
		var summaryEmbedding []byte
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON,
			&c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			// Skip undecodable rows rather than failing the whole search.
			continue
		}
		c.SummaryEmbedding = summaryEmbedding
		c.SetAIRepliesFromJSON(aiRepliesJSON)
		chats = append(chats, &c)
	}
	// rows.Next returning false may hide an iteration error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chats, nil
}