Files
HxClaw/cmd/hxclaw/internal/memory/save.go

438 lines
9.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package memory
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
var (
	// lastContext holds the context string most recently produced by
	// GetContextPrompt. Early returns from GetContextPrompt leave it unchanged.
	lastContext string

	// ErrNeedNewSession is returned by SaveChat when there is no current
	// session, so that the user can create one first.
	ErrNeedNewSession = fmt.Errorf("需要创建新会话")
)
// GetContextPrompt builds the memory context for userInput.
//
// It returns "" when the database is unavailable, memory is disabled in the
// project config, or no session can be found. If no session is cached yet, the
// latest session is loaded from the database and cached in currentSession.
//
// The context is made of up to three sections, in this order:
//   - long-term memory read from the picoclaw workspace;
//   - the current session summary;
//   - recall results, only when shouldRecall(userInput) reports true.
//
// The returned string is also stored in lastContext.
func GetContextPrompt(userInput string) string {
	db := GetDB()
	if db == nil {
		return ""
	}
	if !internal.GetProjectConfig().Memory.Enabled {
		return ""
	}

	// Lazily pick up the most recent session when none is active.
	if currentSession == nil {
		latest, err := db.GetLatestSession()
		if err != nil || latest == nil {
			return ""
		}
		currentSession = latest
	}
	session := currentSession

	var sb strings.Builder

	// Long-term memory maintained by picoclaw.
	if longTerm := readPicoClawMemory(); longTerm != "" {
		sb.WriteString("=== 长期记忆 ===\n")
		sb.WriteString(longTerm)
		sb.WriteString("\n============\n")
	}

	// Rolling summary of the current session.
	if session.Summary != "" {
		sb.WriteString("=== 当前会话摘要 ===\n")
		sb.WriteString(session.Summary)
		sb.WriteString("\n============\n")
	}

	// Recalled history; detectAndRecall may return "", which writes nothing.
	if shouldRecall(userInput) {
		sb.WriteString(detectAndRecall(userInput))
	}

	lastContext = sb.String()
	return lastContext
}
// readPicoClawMemory returns picoclaw's long-term memory notes from
// <workspace>/memory/MEMORY.md, where the workspace comes from
// getPicoclawWorkspace.
//
// Only text that follows a Markdown heading is kept:
//   - lines before the first line starting with "#" are dropped;
//   - heading lines themselves are dropped;
//   - blank lines are skipped;
//   - the remaining lines are trimmed and joined with "\n".
//
// It returns "" when the workspace cannot be located or the file cannot be
// read.
func readPicoClawMemory() string {
	workspace := getPicoclawWorkspace()
	if workspace == "" {
		// getPicoclawWorkspace returns "" when the home directory is unknown.
		// Joining onto "" would yield the relative path "memory/MEMORY.md" and
		// read whatever file happens to exist in the working directory.
		return ""
	}

	data, err := os.ReadFile(filepath.Join(workspace, "memory", "MEMORY.md"))
	if err != nil {
		return ""
	}

	var kept []string
	afterHeading := false
	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSpace(line)
		switch {
		case line == "":
			// Blank lines carry no content.
		case strings.HasPrefix(line, "#"):
			// Headings only mark the start of memory content; they are not kept.
			afterHeading = true
		case afterHeading:
			kept = append(kept, line)
		}
	}
	return strings.Join(kept, "\n")
}
// getPicoclawWorkspace returns the picoclaw workspace directory,
// <home>/.picoclaw/workspace. It returns "" when the user's home directory
// cannot be determined.
func getPicoclawWorkspace() string {
	if home, err := os.UserHomeDir(); err == nil {
		return filepath.Join(home, ".picoclaw", "workspace")
	}
	return ""
}
// shouldRecall reports whether input should trigger a memory recall. That is
// the case when any of the following holds:
//   - input starts with the /recall or /memory command;
//   - input contains one of the configured recall keywords;
//   - auto-recall is enabled and shouldAutoRecall reports a match.
func shouldRecall(input string) bool {
	memoryCfg := internal.GetProjectConfig().Memory

	// Ignore leading whitespace so this agrees with detectAndRecall, which
	// trims its input before looking for the /recall command.
	trimmed := strings.TrimSpace(input)
	for _, cmd := range []string{"/recall", "/memory"} {
		if strings.HasPrefix(trimmed, cmd) {
			return true
		}
	}

	recallCfg := memoryCfg.Recall
	for _, keyword := range recallCfg.Keywords {
		// strings.Contains(s, "") is always true, so a blank keyword in the
		// config would force recall on every input. Skip such entries.
		if strings.TrimSpace(keyword) == "" {
			continue
		}
		if strings.Contains(input, keyword) {
			return true
		}
	}

	if recallCfg.AutoRecall {
		return shouldAutoRecall(input)
	}
	return false
}
// shouldAutoRecall reports whether input is semantically close enough to the
// current session summary to trigger an automatic recall.
//
// It generates an embedding for input through the vector service and compares
// it with the session's summary embedding by cosine similarity. The result is
// true only when the score is strictly above Memory.Recall.SimilarityThreshold,
// which defaults to 0.7 when it is unset or non-positive.
//
// It returns false when any of the following holds:
//   - there is no vector service or no current session;
//   - the session lacks a summary or a summary embedding;
//   - embedding generation fails.
func shouldAutoRecall(input string) bool {
	vs := GetVectorService()
	if vs == nil {
		return false
	}

	threshold := internal.GetProjectConfig().Memory.Recall.SimilarityThreshold
	if threshold <= 0 {
		threshold = 0.7 // default used when the config leaves it unset
	}

	s := currentSession
	if s == nil || s.Summary == "" || len(s.SummaryEmbedding) == 0 {
		return false
	}

	emb, err := vs.Generate(input)
	if err != nil {
		return false
	}
	return vs.CosineSimilarity(emb, s.SummaryEmbedding) > threshold
}
// detectAndRecall turns a recall request into formatted recall results.
//
// Input is handled as follows:
//   - A bare "/recall" recalls the session history.
//   - Otherwise, if the input (with any "/recall" prefix removed) contains one
//     of the built-in recall phrases, it is delegated to recallByKeyword.
//   - An explicit "/recall <topic>" with no such phrase recalls by topic
//     (via extractTopic).
//
// Any other input yields "".
//
// NOTE(review): shouldRecall also treats "/memory" as a recall command, but it
// is not handled here. Such input only recalls when it contains one of the
// phrases below.
func detectAndRecall(input string) string {
	input = strings.TrimSpace(input)

	isRecallCmd := strings.HasPrefix(input, "/recall")
	if isRecallCmd {
		input = strings.TrimSpace(strings.TrimPrefix(input, "/recall"))
		if input == "" {
			// Errors are ignored; FormatRecallResults receives whatever was returned.
			result, _ := RecallHistory()
			return FormatRecallResults(result)
		}
	}

	for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
		if strings.Contains(input, keyword) {
			return recallByKeyword(input, keyword)
		}
	}

	// The user explicitly asked for a recall: treat the remaining text as the
	// topic instead of silently returning nothing.
	if isRecallCmd {
		if topic := extractTopic(input); topic != "" {
			result, _ := RecallTopic(topic, 5)
			return FormatRecallResults(result)
		}
	}
	return ""
}
// recallByKeyword resolves a natural-language recall request in input that was
// matched on keyword by detectAndRecall. It returns formatted recall results,
// or "" when the database is unavailable or nothing can be recalled.
//
// Rules, in order:
//   - For the broad keywords 之前/聊过/曾经, a question about the whole history
//     (contains 什么 or 都聊, or input is exactly 之前) recalls the full history.
//   - For broad keywords, an extracted topic is recalled (up to 5 results)
//     even when the input also mentions 那次.
//   - For other keywords, mentioning 那次 recalls the full history; otherwise
//     an extracted topic is recalled.
//   - For broad keywords with no extractable topic, mentioning 那次 recalls
//     the full history.
//
// Errors from RecallHistory/RecallTopic are discarded; FormatRecallResults
// receives whatever result was returned.
func recallByKeyword(input, keyword string) string {
	if GetDB() == nil {
		return ""
	}

	broad := keyword == "之前" || keyword == "聊过" || keyword == "曾经"
	asksForEverything := strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊")
	if broad && asksForEverything {
		result, _ := RecallHistory()
		return FormatRecallResults(result)
	}

	// The original condition tested Contains(input, "那次") twice; the duplicate
	// has been collapsed. NOTE(review): the second alternative may have been
	// meant as a different phrase; confirm.
	mentionsThatTime := strings.Contains(input, "那次")

	topic := extractTopic(input)
	if topic != "" && (broad || !mentionsThatTime) {
		result, _ := RecallTopic(topic, 5)
		return FormatRecallResults(result)
	}
	if mentionsThatTime {
		result, _ := RecallHistory()
		return FormatRecallResults(result)
	}
	return ""
}
// extractTopic strips recall filler phrases (之前, 聊过, 谈论过, 提过, 记得,
// 曾经, 关于, 我问过) from input and trims surrounding whitespace and
// punctuation. The result is the topic to search for, or "" if nothing is left.
//
// Every occurrence of a filler phrase is removed, not only a leading one.
// Trimming covers ASCII punctuation as well as full-width CJK punctuation,
// since inputs are typically Chinese questions.
func extractTopic(input string) string {
	fillers := []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"}

	topic := input
	for _, filler := range fillers {
		topic = strings.ReplaceAll(topic, filler, "")
	}

	// ASCII and full-width punctuation plus whitespace (including U+3000).
	const cutset = " \t\r\n\u3000,.?!:;，。？！：；"
	return strings.Trim(topic, cutset)
}
// ShouldSkipSummaryUpdate reports whether the exchange for input should leave
// the session summary untouched. It is true exactly when input triggers a
// memory recall (see shouldRecall), so recall queries keep the existing
// summary.
func ShouldSkipSummaryUpdate(input string) bool {
	isRecall := shouldRecall(input)
	return isRecall
}
// SaveChat persists one user/AI exchange in the current session. It returns
// the session's chat count after the save.
//
// Returns:
//   - (0, nil) without saving when memory is disabled in the project config;
//   - ErrNeedNewSession when there is no current session, so the user can
//     create one first.
//
// When useSessionSummary is true, the session summary is regenerated from the
// previous summary plus aiReply. Recall queries should pass false so recalled
// text does not leak into the summary.
//
// When a vector service and an API key are configured, embeddings are
// generated in background goroutines after the rows have been written:
//   - always for the chat;
//   - for the session only when its summary was updated.
//
// Database errors are wrapped with %w.
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
	db := GetDB()
	if db == nil {
		return 0, fmt.Errorf("数据库未初始化")
	}
	memoryCfg := internal.GetProjectConfig().Memory
	if !memoryCfg.Enabled {
		return 0, nil
	}

	session := currentSession
	if session == nil {
		// No active session: let the caller ask the user to create one.
		return 0, ErrNeedNewSession
	}

	// Create the chat record and fill in the reply and summary.
	chat, err := db.CreateChat(session.ID, userInput)
	if err != nil {
		return 0, fmt.Errorf("创建聊天记录失败: %w", err)
	}
	chat.AddAIReply(aiReply)
	chat.Summary = GenerateSummary(userInput, aiReply)
	chat.UpdatedAt = time.Now().Unix()
	if err := db.UpdateChat(chat); err != nil {
		return 0, fmt.Errorf("更新聊天记录失败: %w", err)
	}

	// Embeddings are optional and need both a vector service and an API key.
	canEmbed := GetVectorService() != nil && memoryCfg.Vector.APIKey != ""

	// Attach the chat to the session. Only normal conversation turns update
	// the summary; recall turns leave it as is.
	session.AddChatID(chat.ID)
	if useSessionSummary {
		session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply)
	}
	session.UpdatedAt = time.Now().Unix()
	if err := db.UpdateSession(session); err != nil {
		return 0, fmt.Errorf("更新会话失败: %w", err)
	}
	// The in-memory session keeps the summary that was just persisted.
	// Resetting it here would make GetContextPrompt use a stale summary, and
	// the background session update below would write that stale summary back.
	currentSession = session

	// Start background work only after chat and session have been written and
	// this goroutine no longer mutates them; earlier launches race with the
	// writes above.
	if canEmbed {
		go generateEmbedding(chat)
		if useSessionSummary {
			go generateSessionEmbedding(session)
		}
	}
	return len(session.ChatIDs), nil
}
// generateEmbedding computes an embedding for chat and stores it on chat and in
// the database. The embedded text is the user input, followed by a newline and
// the latest AI reply when there is one.
//
// SaveChat runs it in its own goroutine, so failures are printed to stdout
// with a "[memory]" prefix instead of being returned.
func generateEmbedding(chat *Chat) {
	vs := GetVectorService()
	if vs == nil {
		return
	}

	text := chat.UserInput
	if n := len(chat.AIReplies); n > 0 {
		text += "\n" + chat.AIReplies[n-1]
	}

	embedding, err := vs.Generate(text)
	if err != nil {
		fmt.Printf("[memory] 向量生成失败: %v\n", err)
		return
	}
	chat.SummaryEmbedding = embedding
	chat.UpdatedAt = time.Now().Unix()

	db := GetDB()
	if db == nil {
		// GetDB can return nil (other callers check for it). Calling UpdateChat
		// on a nil *DB would most likely panic this goroutine and, with it,
		// the whole process.
		fmt.Printf("[memory] 向量保存失败: %v\n", "数据库未初始化")
		return
	}
	if err := db.UpdateChat(chat); err != nil {
		fmt.Printf("[memory] 向量保存失败: %v\n", err)
	}
}
// generateSessionEmbedding computes an embedding of session.Summary and stores
// it on session and in the database.
//
// It does nothing when there is no vector service or the summary is empty;
// shouldAutoRecall ignores sessions without a summary anyway. SaveChat runs it
// in its own goroutine, so failures are printed to stdout with a "[memory]"
// prefix instead of being returned.
func generateSessionEmbedding(session *Session) {
	vs := GetVectorService()
	if vs == nil {
		return
	}

	summary := session.Summary
	if summary == "" {
		return
	}

	embedding, err := vs.Generate(summary)
	if err != nil {
		fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
		return
	}
	session.SummaryEmbedding = embedding
	session.UpdatedAt = time.Now().Unix()

	db := GetDB()
	if db == nil {
		// Avoid calling UpdateSession on a nil *DB from a background goroutine.
		fmt.Printf("[memory] Session 向量保存失败: %v\n", "数据库未初始化")
		return
	}
	if err := db.UpdateSession(session); err != nil {
		fmt.Printf("[memory] Session 向量保存失败: %v\n", err)
	}
}
// createSession starts a new session through CreateNewSession. It is a thin
// package-private alias that adds no behavior of its own.
func createSession() (*Session, error) {
	s, err := CreateNewSession()
	return s, err
}
// SearchResult is one hit returned by SearchSimilar.
type SearchResult struct {
	// Chat is the matched chat record.
	Chat *Chat
	// Similarity is the cosine similarity between the query embedding and the
	// chat's summary embedding, as computed by the vector service. Larger
	// values rank first.
	Similarity float64
}
// SearchSimilar returns up to topK chats whose stored summary embedding is most
// similar to query, best match first. Similarity is the vector service's cosine
// similarity.
//
// topK is capped by the Memory.Vector.MaxSearchResults config value, which
// defaults to 10 when unset or non-positive. A negative topK is treated as 0.
// Chats without an embedding are skipped.
//
// It returns an error when:
//   - the database or vector service is not initialized;
//   - the query embedding cannot be generated;
//   - the chats cannot be loaded.
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}

	// Apply the configured upper bound on the number of results.
	maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
	if maxResults <= 0 {
		maxResults = 10
	}
	if topK > maxResults {
		topK = maxResults
	}
	if topK < 0 {
		// make([]SearchResult, topK) below panics on a negative length.
		topK = 0
	}

	queryEmbedding, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("查询向量生成失败: %w", err)
	}

	chats, err := db.GetAllChats()
	if err != nil {
		return nil, fmt.Errorf("获取聊天记录失败: %w", err)
	}

	type scoredChat struct {
		chat       *Chat
		similarity float64
	}
	scored := make([]scoredChat, 0, len(chats))
	for _, chat := range chats {
		if len(chat.SummaryEmbedding) == 0 {
			continue
		}
		sim := vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding)
		scored = append(scored, scoredChat{chat: chat, similarity: sim})
	}

	if topK > len(scored) {
		topK = len(scored)
	}

	// Partial selection sort: only the first topK positions need ordering, so
	// this is O(len(scored)*topK) instead of O(n^2) over every chat. Among
	// equal scores, the earliest remaining entry wins.
	for i := 0; i < topK; i++ {
		best := i
		for j := i + 1; j < len(scored); j++ {
			if scored[j].similarity > scored[best].similarity {
				best = j
			}
		}
		scored[i], scored[best] = scored[best], scored[i]
	}

	results := make([]SearchResult, topK)
	for i := range results {
		results[i] = SearchResult{
			Chat:       scored[i].chat,
			Similarity: scored[i].similarity,
		}
	}
	return results, nil
}
// GetAllChats loads every chat that has a non-NULL, non-empty summary
// embedding, ordered by id descending.
//
// Rows that fail to scan are skipped as a best effort and reported on stdout.
// An error is returned if the query fails or if row iteration fails.
func (db *DB) GetAllChats() ([]*Chat, error) {
	rows, err := db.db.Query(
		`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		var c Chat
		var aiRepliesJSON string
		var summaryEmbedding []byte
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			// Skip one unreadable row rather than failing the whole load,
			// but do not drop it silently.
			fmt.Printf("[memory] 跳过无法读取的聊天记录: %v\n", err)
			continue
		}
		c.SummaryEmbedding = summaryEmbedding
		c.SetAIRepliesFromJSON(aiRepliesJSON)
		chats = append(chats, &c)
	}
	// rows.Next returns false on an iteration error as well as at the end of
	// the result set. Without this check, a failed iteration would look like
	// a short, successful result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chats, nil
}