Files
HxClaw/cmd/hxclaw/internal/memory/save.go

433 lines
9.4 KiB
Go
Raw Normal View History

package memory
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
// Package-level state used by the memory helpers in this file.
var (
	// lastContext holds the most recent context string built by
	// GetContextPrompt.
	lastContext string

	// ErrNeedNewSession is returned by SaveChat when there is no current
	// session to attach a chat to.
	ErrNeedNewSession = fmt.Errorf("需要创建新会话")
)
// GetContextPrompt builds the memory context that accompanies userInput.
//
// It returns "" when the database is not available, when memory is disabled
// in the project configuration, or when there is no current session.
// Otherwise the result is the concatenation of, in order:
//   - the picoclaw long-term memory section (if any),
//   - the current session summary section (if any),
//   - recall results, when shouldRecall reports that userInput asks for them.
//
// The returned string is also stored in lastContext.
func GetContextPrompt(userInput string) string {
	if GetDB() == nil {
		return ""
	}
	if !internal.GetProjectConfig().Memory.Enabled {
		return ""
	}
	// Without an active session there is no context to offer.
	sess := currentSession
	if sess == nil {
		return ""
	}

	var prompt strings.Builder

	// Long-term memory read from the picoclaw workspace.
	if longTerm := readPicoClawMemory(); longTerm != "" {
		prompt.WriteString("=== 长期记忆 ===\n")
		prompt.WriteString(longTerm)
		prompt.WriteString("\n============\n")
	}

	// Summary of the current session.
	if sess.Summary != "" {
		prompt.WriteString("=== 当前会话摘要 ===\n")
		prompt.WriteString(sess.Summary)
		prompt.WriteString("\n============\n")
	}

	// Recall results are appended verbatim; an empty result adds nothing.
	if shouldRecall(userInput) {
		prompt.WriteString(detectAndRecall(userInput))
	}

	lastContext = prompt.String()
	return lastContext
}
// readPicoClawMemory returns the body text of picoclaw's long-term memory
// file, <workspace>/memory/MEMORY.md.
//
// Heading lines (those starting with "#") and blank lines are dropped, and
// any text that appears before the first heading is ignored. Every kept line
// is whitespace-trimmed and the lines are joined with "\n".
//
// It returns "" when the workspace cannot be resolved, when the file cannot
// be read, or when no body text remains.
func readPicoClawMemory() string {
	workspace := getPicoclawWorkspace()
	if workspace == "" {
		// filepath.Join ignores empty elements, so an unresolved workspace
		// would silently become the relative path "memory/MEMORY.md" and
		// pick up whatever file sits in the current working directory.
		return ""
	}

	data, err := os.ReadFile(filepath.Join(workspace, "memory", "MEMORY.md"))
	if err != nil {
		// A missing or unreadable file simply means there is no memory.
		return ""
	}

	var body strings.Builder
	inSection := false
	for _, line := range strings.Split(string(data), "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case trimmed == "":
			// Blank lines carry no content.
		case strings.HasPrefix(trimmed, "#"):
			// A heading opens a section; the heading text itself is dropped.
			inSection = true
		case inSection:
			body.WriteString(trimmed)
			body.WriteByte('\n')
		}
	}
	return strings.TrimSpace(body.String())
}
// getPicoclawWorkspace returns the picoclaw workspace directory,
// <home>/.picoclaw/workspace, or "" when the user's home directory cannot be
// determined.
func getPicoclawWorkspace() string {
	if homeDir, err := os.UserHomeDir(); err == nil {
		return filepath.Join(homeDir, ".picoclaw", "workspace")
	}
	return ""
}
// shouldRecall reports whether input should trigger a memory recall.
//
// It is true when input starts with the "/recall" or "/memory" command, when
// input contains any keyword from the configured recall keyword list, or,
// if auto-recall is enabled in the configuration, when shouldAutoRecall
// accepts the input.
func shouldRecall(input string) bool {
	recallCfg := internal.GetProjectConfig().Memory.Recall

	// Explicit slash commands always trigger recall.
	if strings.HasPrefix(input, "/recall") || strings.HasPrefix(input, "/memory") {
		return true
	}

	// Configured trigger keywords may appear anywhere in the input.
	for _, kw := range recallCfg.Keywords {
		if strings.Contains(input, kw) {
			return true
		}
	}

	// Similarity-based recall is only consulted when it is switched on.
	return recallCfg.AutoRecall && shouldAutoRecall(input)
}
// shouldAutoRecall reports whether input is similar enough to the current
// session summary to trigger recall automatically.
//
// It compares an embedding generated for input with the session's summary
// embedding and returns true when the cosine similarity is strictly greater
// than the configured threshold (0.7 when the configured value is not
// positive). It returns false when the vector service is unavailable, when
// there is no current session, when the session has no summary or summary
// embedding, or when the embedding for input cannot be generated.
func shouldAutoRecall(input string) bool {
	vectors := GetVectorService()
	if vectors == nil {
		return false
	}

	minSimilarity := internal.GetProjectConfig().Memory.Recall.SimilarityThreshold
	if minSimilarity <= 0 {
		minSimilarity = 0.7
	}

	sess := currentSession
	if sess == nil || sess.Summary == "" || len(sess.SummaryEmbedding) == 0 {
		return false
	}

	inputEmbedding, err := vectors.Generate(input)
	if err != nil {
		return false
	}
	return vectors.CosineSimilarity(inputEmbedding, sess.SummaryEmbedding) > minSimilarity
}
// detectAndRecall inspects input and returns formatted recall results, or ""
// when nothing in the input asks for a recall.
//
// Handling, in order:
//   - "/recall" with no argument returns the recall history.
//   - If the (remaining) input contains one of the built-in recall phrases,
//     the request is delegated to recallByKeyword with the first phrase that
//     matches.
//   - "/recall <topic>" whose topic contains none of those phrases falls back
//     to a topic search, so the explicit command never silently returns
//     nothing just because the topic lacks a trigger phrase.
//
// Errors from the recall helpers are dropped; whatever result they return is
// passed to FormatRecallResults.
func detectAndRecall(input string) string {
	input = strings.TrimSpace(input)

	explicit := strings.HasPrefix(input, "/recall")
	if explicit {
		input = strings.TrimSpace(strings.TrimPrefix(input, "/recall"))
		if input == "" {
			result, _ := RecallHistory()
			return FormatRecallResults(result)
		}
	}

	for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
		if strings.Contains(input, keyword) {
			return recallByKeyword(input, keyword)
		}
	}

	if explicit {
		// The user asked for a recall explicitly; treat the argument as the
		// topic to search for.
		result, _ := RecallTopic(input, 5)
		return FormatRecallResults(result)
	}
	return ""
}
// recallByKeyword resolves a recall request that was triggered by keyword
// appearing in input and returns the formatted results.
//
// For the keywords "之前", "聊过" and "曾经", a broad question (input contains
// "什么" or "都聊", or is exactly "之前") returns the recall history; otherwise
// a non-empty topic extracted from input is searched for. For every keyword,
// an input containing "那次" returns the recall history, and failing that a
// non-empty extracted topic is searched for. It returns "" when the database
// is unavailable or no topic can be extracted.
//
// Errors from the recall helpers are dropped; whatever result they return is
// passed to FormatRecallResults.
func recallByKeyword(input, keyword string) string {
	if GetDB() == nil {
		return ""
	}

	history := func() string {
		result, _ := RecallHistory()
		return FormatRecallResults(result)
	}
	byTopic := func(topic string) string {
		result, _ := RecallTopic(topic, 5)
		return FormatRecallResults(result)
	}

	// The topic depends only on input, so compute it once.
	topic := extractTopic(input)

	switch keyword {
	case "之前", "聊过", "曾经":
		if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") {
			return history()
		}
		if topic != "" {
			return byTopic(topic)
		}
	}

	// NOTE(review): this condition used to test "那次" twice, joined by "||".
	// The second operand was possibly meant to be a different phrase;
	// confirm the intended wording.
	if strings.Contains(input, "那次") {
		return history()
	}
	if topic != "" {
		return byTopic(topic)
	}
	return ""
}
// extractTopic strips the recall filler phrases from input and returns what
// is left as the topic to search for, or "" when nothing remains.
//
// Every occurrence of each filler phrase is removed, one phrase after the
// other in the listed order. Leading and trailing spaces, tabs and
// punctuation are then trimmed. Both ASCII and full-width CJK punctuation
// are trimmed, because a Chinese question normally ends in a full-width
// mark that would otherwise stay attached to the topic.
func extractTopic(input string) string {
	topic := input
	for _, filler := range []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"} {
		topic = strings.ReplaceAll(topic, filler, "")
	}

	// ASCII marks plus the ideographic full stop, followed by the full-width
	// comma, question mark, exclamation mark, colon, semicolon and the
	// ideographic comma (written as escapes so they survive re-encoding).
	const cutset = " \t,。.?!:;" + "\uFF0C\uFF1F\uFF01\uFF1A\uFF1B\u3001"
	return strings.Trim(topic, cutset)
}
// ShouldSkipSummaryUpdate reports whether input is a recall request, as
// decided by shouldRecall.
//
// NOTE(review): presumably callers use this to choose the useSessionSummary
// argument of SaveChat; confirm against the call sites.
func ShouldSkipSummaryUpdate(input string) bool {
	isRecallRequest := shouldRecall(input)
	return isRecallRequest
}
// SaveChat persists one exchange (userInput and aiReply) as a chat in the
// current session and returns the number of chats the session then holds.
//
// When useSessionSummary is true the session summary is regenerated from the
// previous summary plus aiReply, and its embedding is refreshed if the vector
// service is configured. When it is false (for example for recall queries)
// the session summary and its embedding are left untouched.
//
// The in-memory session keeps the summary that was persisted, so later calls
// build on it and the summary stays consistent with the refreshed embedding.
//
// It returns (0, nil) when memory is disabled, ErrNeedNewSession when there
// is no current session, and a wrapped error when the database is not
// initialised or a database operation fails.
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
	db := GetDB()
	if db == nil {
		return 0, fmt.Errorf("数据库未初始化")
	}
	memoryCfg := internal.GetProjectConfig().Memory
	if !memoryCfg.Enabled {
		return 0, nil
	}

	// A session must already exist; the caller is asked to create one.
	session := currentSession
	if session == nil {
		return 0, ErrNeedNewSession
	}

	// Create the chat record and attach the AI reply and its summary.
	chat, err := db.CreateChat(session.ID, userInput)
	if err != nil {
		return 0, fmt.Errorf("创建聊天记录失败: %w", err)
	}
	chat.AddAIReply(aiReply)
	chat.Summary = GenerateSummary(userInput, aiReply)
	chat.UpdatedAt = time.Now().Unix()

	// Embeddings are only generated when the vector service is available
	// and an API key is configured.
	embeddingsEnabled := GetVectorService() != nil && memoryCfg.Vector.APIKey != ""
	if embeddingsEnabled {
		generateEmbedding(chat)
	}
	if err := db.UpdateChat(chat); err != nil {
		return 0, fmt.Errorf("更新聊天记录失败: %w", err)
	}

	// Update the session. Only ordinary conversation turns change the
	// session summary; recall queries keep the existing one.
	session.AddChatID(chat.ID)
	if useSessionSummary {
		session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply)
	}
	session.UpdatedAt = time.Now().Unix()
	if embeddingsEnabled && useSessionSummary {
		generateSessionEmbedding(session)
	}
	if err := db.UpdateSession(session); err != nil {
		return 0, fmt.Errorf("更新会话失败: %w", err)
	}

	currentSession = session
	return len(session.ChatIDs), nil
}
// generateEmbedding computes an embedding for chat and stores it.
//
// The embedded text is the user input followed, when the chat has replies,
// by a newline and the most recent AI reply. On success the embedding is
// written to chat.SummaryEmbedding, chat.UpdatedAt is refreshed and the chat
// is saved through the database. Failures are printed and otherwise ignored;
// the function does nothing when the vector service is unavailable.
func generateEmbedding(chat *Chat) {
	vectors := GetVectorService()
	if vectors == nil {
		return
	}

	source := chat.UserInput
	if n := len(chat.AIReplies); n > 0 {
		source = source + "\n" + chat.AIReplies[n-1]
	}

	vec, err := vectors.Generate(source)
	if err != nil {
		fmt.Printf("[memory] 向量生成失败: %v\n", err)
		return
	}

	chat.SummaryEmbedding = vec
	chat.UpdatedAt = time.Now().Unix()
	if saveErr := GetDB().UpdateChat(chat); saveErr != nil {
		fmt.Printf("[memory] 向量保存失败: %v\n", saveErr)
	}
}
// generateSessionEmbedding computes an embedding of session.Summary and
// stores it.
//
// On success the embedding is written to session.SummaryEmbedding,
// session.UpdatedAt is refreshed and the session is saved through the
// database. Failures are printed and otherwise ignored; the function does
// nothing when the vector service is unavailable.
func generateSessionEmbedding(session *Session) {
	vectors := GetVectorService()
	if vectors == nil {
		return
	}

	vec, err := vectors.Generate(session.Summary)
	if err != nil {
		fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
		return
	}

	session.SummaryEmbedding = vec
	session.UpdatedAt = time.Now().Unix()
	if saveErr := GetDB().UpdateSession(session); saveErr != nil {
		fmt.Printf("[memory] Session 向量保存失败: %v\n", saveErr)
	}
}
// createSession is a package-private shorthand that forwards to
// CreateNewSession and returns its result unchanged.
func createSession() (*Session, error) {
	sess, err := CreateNewSession()
	return sess, err
}
// SearchResult pairs a stored chat with its similarity score for a query.
type SearchResult struct {
// Chat is the matched chat record.
Chat *Chat
// Similarity is the cosine similarity between the query embedding and the
// chat's embedding, as computed by SearchSimilar.
Similarity float64
}
// SearchSimilar returns up to topK stored chats whose embeddings are most
// similar to query, ordered by descending cosine similarity. Chats with equal
// similarity keep the order in which the database returned them.
//
// topK is capped at the configured maximum number of search results (10 when
// the configured value is not positive). A topK of zero or less yields an
// empty, non-nil result without generating a query embedding; a negative
// value used to panic when the result slice was allocated.
//
// It returns an error when the database or the vector service is not
// initialised, when the query embedding cannot be generated, or when the
// chats cannot be loaded.
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}

	// Clamp topK to the configured maximum number of results.
	maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
	if maxResults <= 0 {
		maxResults = 10
	}
	if topK > maxResults {
		topK = maxResults
	}
	if topK <= 0 {
		return []SearchResult{}, nil
	}

	// Embed the query.
	queryEmbedding, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("查询向量生成失败: %w", err)
	}

	// Load the candidate chats.
	chats, err := db.GetAllChats()
	if err != nil {
		return nil, fmt.Errorf("获取聊天记录失败: %w", err)
	}

	// Score every chat that carries an embedding.
	scored := make([]SearchResult, 0, len(chats))
	for _, chat := range chats {
		if len(chat.SummaryEmbedding) == 0 {
			continue
		}
		scored = append(scored, SearchResult{
			Chat:       chat,
			Similarity: vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding),
		})
	}

	// Highest similarity first.
	sort.SliceStable(scored, func(i, j int) bool {
		return scored[i].Similarity > scored[j].Similarity
	})

	// Keep the top K, copied out so the returned slice does not pin the
	// whole candidate list in memory.
	if topK > len(scored) {
		topK = len(scored)
	}
	results := make([]SearchResult, topK)
	copy(results, scored)
	return results, nil
}
// GetAllChats loads every chat that has a non-empty summary embedding, newest
// id first. Despite its name it does not return chats without an embedding;
// the query filters them out.
//
// Rows that fail to scan are skipped so that a single bad record does not
// hide the rest. An error that interrupts the iteration itself is returned,
// so a truncated result is never reported as success.
func (db *DB) GetAllChats() ([]*Chat, error) {
	rows, err := db.db.Query(
		`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		var (
			c                Chat
			aiRepliesJSON    string
			summaryEmbedding []byte
		)
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			// Best effort: drop the unreadable row and keep going.
			continue
		}
		c.SummaryEmbedding = summaryEmbedding
		c.SetAIRepliesFromJSON(aiRepliesJSON)
		chats = append(chats, &c)
	}
	// rows.Next returns false both at the end of the data and on failure;
	// rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chats, nil
}