496 lines
11 KiB
Go
496 lines
11 KiB
Go
package memory
|
||
|
||
import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
|
||
|
||
var (
	// lastContext holds the most recent string built by GetContextPrompt.
	// NOTE(review): package-level mutable state written without synchronization.
	lastContext string

	// ErrNeedNewSession is returned by SaveChat when there is no current
	// session and automatic session creation (Memory.AutoSession) is disabled.
	ErrNeedNewSession = fmt.Errorf("需要创建新会话")
)
|
||
|
||
func generateChatSummary(userInput, aiReply string) string {
|
||
if summaryAgent == nil {
|
||
return GenerateSummary(userInput, aiReply)
|
||
}
|
||
|
||
memoryCfg := internal.GetProjectConfig().Memory
|
||
timeout := time.Duration(memoryCfg.SummaryTimeout) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 30 * time.Second
|
||
}
|
||
|
||
prompt := fmt.Sprintf(`用极简文言文概括对话,一句话,只包含最关键信息(如人名、地点、事件、核心结论):
|
||
问:%s
|
||
答:%s`, userInput, aiReply)
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
resp, err := summaryAgent.ProcessDirect(ctx, prompt, "summary:chat")
|
||
if err != nil {
|
||
return GenerateSummary(userInput, aiReply)
|
||
}
|
||
return strings.TrimSpace(resp)
|
||
}
|
||
|
||
func generateSessionSummary(oldSummary, userInput, aiReply string) string {
|
||
if summaryAgent == nil {
|
||
return GenerateSummary(oldSummary, oldSummary+"\n"+aiReply)
|
||
}
|
||
|
||
memoryCfg := internal.GetProjectConfig().Memory
|
||
timeout := time.Duration(memoryCfg.SummaryTimeout) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 30 * time.Second
|
||
}
|
||
|
||
prompt := fmt.Sprintf(`精简整合以下历史,提取关键信息,去除冗余描述,保留核心要点:
|
||
历史:%s
|
||
新对话:问%s 答%s`, oldSummary, userInput, aiReply)
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
resp, err := summaryAgent.ProcessDirect(ctx, prompt, "summary:session")
|
||
if err != nil {
|
||
return GenerateSummary(oldSummary, oldSummary+"\n"+aiReply)
|
||
}
|
||
return strings.TrimSpace(resp)
|
||
}
|
||
|
||
func GetContextPrompt(userInput string) string {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return ""
|
||
}
|
||
|
||
memoryCfg := internal.GetProjectConfig().Memory
|
||
if !memoryCfg.Enabled {
|
||
return ""
|
||
}
|
||
|
||
session := currentSession
|
||
if session == nil {
|
||
return "" // 没有 session,不提供上下文
|
||
}
|
||
|
||
var context string
|
||
|
||
// 读取 picoclaw 的长期记忆
|
||
picoMemory := readPicoClawMemory()
|
||
if picoMemory != "" {
|
||
context += "=== 长期记忆 ===\n" + picoMemory + "\n============\n"
|
||
}
|
||
|
||
// 添加会话摘要
|
||
if session.Summary != "" {
|
||
context += "=== 当前会话摘要 ===\n" + session.Summary + "\n============\n"
|
||
}
|
||
|
||
if shouldRecall(userInput) {
|
||
recallResult := detectAndRecall(userInput)
|
||
if recallResult != "" {
|
||
context += recallResult
|
||
}
|
||
}
|
||
|
||
lastContext = context
|
||
return context
|
||
}
|
||
|
||
func readPicoClawMemory() string {
|
||
memoryPath := filepath.Join(getPicoclawWorkspace(), "memory", "MEMORY.md")
|
||
data, err := os.ReadFile(memoryPath)
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
|
||
content := strings.TrimSpace(string(data))
|
||
|
||
lines := strings.Split(content, "\n")
|
||
var sb strings.Builder
|
||
inSection := false
|
||
|
||
for _, line := range lines {
|
||
trimmed := strings.TrimSpace(line)
|
||
if trimmed == "" {
|
||
continue
|
||
}
|
||
if strings.HasPrefix(trimmed, "#") {
|
||
inSection = true
|
||
continue
|
||
}
|
||
if inSection {
|
||
sb.WriteString(trimmed)
|
||
sb.WriteString("\n")
|
||
}
|
||
}
|
||
|
||
return strings.TrimSpace(sb.String())
|
||
}
|
||
|
||
// getPicoclawWorkspace returns the picoclaw workspace directory,
// <home>/.picoclaw/workspace. It returns "" when the user's home directory
// cannot be determined.
func getPicoclawWorkspace() string {
	if home, err := os.UserHomeDir(); err == nil {
		return filepath.Join(home, ".picoclaw", "workspace")
	}
	return ""
}
|
||
|
||
func shouldRecall(input string) bool {
|
||
memoryCfg := internal.GetProjectConfig().Memory
|
||
|
||
for _, cmd := range []string{"/recall", "/memory"} {
|
||
if strings.HasPrefix(input, cmd) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
recallCfg := memoryCfg.Recall
|
||
for _, keyword := range recallCfg.Keywords {
|
||
if strings.Contains(input, keyword) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
if recallCfg.AutoRecall {
|
||
return shouldAutoRecall(input)
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
func shouldAutoRecall(input string) bool {
|
||
vs := GetVectorService()
|
||
if vs == nil {
|
||
return false
|
||
}
|
||
|
||
memoryCfg := internal.GetProjectConfig().Memory
|
||
threshold := memoryCfg.Recall.SimilarityThreshold
|
||
if threshold <= 0 {
|
||
threshold = 0.7
|
||
}
|
||
|
||
session := currentSession
|
||
if session == nil {
|
||
return false
|
||
}
|
||
|
||
if session.Summary == "" || len(session.SummaryEmbedding) == 0 {
|
||
return false
|
||
}
|
||
|
||
inputEmb, err := vs.Generate(input)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
similarity := vs.CosineSimilarity(inputEmb, session.SummaryEmbedding)
|
||
return similarity > threshold
|
||
}
|
||
|
||
func detectAndRecall(input string) string {
|
||
input = strings.TrimSpace(input)
|
||
|
||
if strings.HasPrefix(input, "/recall") {
|
||
input = strings.TrimPrefix(input, "/recall")
|
||
input = strings.TrimSpace(input)
|
||
if input == "" {
|
||
result, _ := RecallHistory()
|
||
return FormatRecallResults(result)
|
||
}
|
||
}
|
||
|
||
for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
|
||
if strings.Contains(input, keyword) {
|
||
return recallByKeyword(input, keyword)
|
||
}
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
func recallByKeyword(input, keyword string) string {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return ""
|
||
}
|
||
|
||
if keyword == "之前" || keyword == "聊过" || keyword == "曾经" {
|
||
if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") {
|
||
result, _ := RecallHistory()
|
||
return FormatRecallResults(result)
|
||
}
|
||
|
||
topic := extractTopic(input)
|
||
if topic != "" {
|
||
result, _ := RecallTopic(topic, 5)
|
||
return FormatRecallResults(result)
|
||
}
|
||
}
|
||
|
||
if strings.Contains(input, "那次") || strings.Contains(input, "那次") {
|
||
result, _ := RecallHistory()
|
||
return FormatRecallResults(result)
|
||
}
|
||
|
||
topic := extractTopic(input)
|
||
if topic != "" {
|
||
result, _ := RecallTopic(topic, 5)
|
||
return FormatRecallResults(result)
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// extractTopic strips recall phrasing (之前, 聊过, 谈论过, 提过, 记得, 曾经,
// 关于, 我问过) from input. It then trims surrounding whitespace and
// punctuation, leaving the topic to search for. The result may be "".
func extractTopic(input string) string {
	topic := input
	for _, phrase := range []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"} {
		topic = strings.ReplaceAll(topic, phrase, "")
	}
	// Trim both ASCII and full-width (CJK) punctuation and spaces: Chinese
	// input usually ends with full-width marks such as "？" or "。".
	return strings.Trim(topic, " \t\r\n\u3000,，。.、?？!！:：;；")
}
|
||
|
||
func ShouldSkipSummaryUpdate(input string) bool {
|
||
return shouldRecall(input)
|
||
}
|
||
|
||
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return 0, fmt.Errorf("数据库未初始化")
|
||
}
|
||
|
||
memoryCfg := internal.GetProjectConfig().Memory
|
||
if !memoryCfg.Enabled {
|
||
return 0, nil
|
||
}
|
||
|
||
// 获取或创建 Session
|
||
session := currentSession
|
||
if session == nil && memoryCfg.AutoSession {
|
||
// 自动创建 session
|
||
newSession, err := CreateNewSession()
|
||
if err != nil {
|
||
return 0, fmt.Errorf("自动创建会话失败: %v", err)
|
||
}
|
||
session = newSession
|
||
}
|
||
|
||
if session == nil {
|
||
return 0, ErrNeedNewSession
|
||
}
|
||
|
||
// 保存原始 session summary(用于恢复)
|
||
originalSummary := session.Summary
|
||
|
||
// 创建聊天记录
|
||
chat, err := db.CreateChat(session.ID, userInput)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("创建聊天记录失败: %v", err)
|
||
}
|
||
|
||
// 添加 AI 回复
|
||
chat.AddAIReply(aiReply)
|
||
chat.Summary = generateChatSummary(userInput, aiReply)
|
||
chat.UpdatedAt = time.Now().Unix()
|
||
|
||
// 生成并保存向量
|
||
vs := GetVectorService()
|
||
if vs != nil && memoryCfg.Vector.APIKey != "" {
|
||
generateEmbedding(chat)
|
||
}
|
||
|
||
if err := db.UpdateChat(chat); err != nil {
|
||
return 0, fmt.Errorf("更新聊天记录失败: %v", err)
|
||
}
|
||
|
||
// 更新 Session
|
||
session.AddChatID(chat.ID)
|
||
|
||
// 只有普通对话才更新 session summary,recall 查询保持原 summary
|
||
if useSessionSummary {
|
||
session.Summary = generateSessionSummary(session.Summary, userInput, aiReply)
|
||
}
|
||
session.UpdatedAt = time.Now().Unix()
|
||
|
||
// Session 也要生成向量
|
||
if vs != nil && memoryCfg.Vector.APIKey != "" && useSessionSummary {
|
||
generateSessionEmbedding(session)
|
||
}
|
||
|
||
if err := db.UpdateSession(session); err != nil {
|
||
return 0, fmt.Errorf("更新会话失败: %v", err)
|
||
}
|
||
|
||
// 恢复原始 session summary(避免 recall 结果污染)
|
||
session.Summary = originalSummary
|
||
|
||
// 首次自动创建后赋值 currentSession
|
||
if currentSession == nil {
|
||
currentSession = session
|
||
}
|
||
|
||
return len(session.ChatIDs), nil
|
||
}
|
||
|
||
func generateEmbedding(chat *Chat) {
|
||
vs := GetVectorService()
|
||
if vs == nil {
|
||
return
|
||
}
|
||
|
||
text := chat.UserInput
|
||
if len(chat.AIReplies) > 0 {
|
||
text = text + "\n" + chat.AIReplies[len(chat.AIReplies)-1]
|
||
}
|
||
|
||
embedding, err := vs.Generate(text)
|
||
if err != nil {
|
||
fmt.Printf("[memory] 向量生成失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
chat.SummaryEmbedding = embedding
|
||
chat.UpdatedAt = time.Now().Unix()
|
||
|
||
if err := GetDB().UpdateChat(chat); err != nil {
|
||
fmt.Printf("[memory] 向量保存失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
func generateSessionEmbedding(session *Session) {
|
||
vs := GetVectorService()
|
||
if vs == nil {
|
||
return
|
||
}
|
||
|
||
embedding, err := vs.Generate(session.Summary)
|
||
if err != nil {
|
||
fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
session.SummaryEmbedding = embedding
|
||
session.UpdatedAt = time.Now().Unix()
|
||
|
||
if err := GetDB().UpdateSession(session); err != nil {
|
||
fmt.Printf("[memory] Session 向量保存失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
func createSession() (*Session, error) {
|
||
return CreateNewSession()
|
||
}
|
||
|
||
// SearchResult pairs a stored chat with its similarity score against a query,
// as produced by SearchSimilar.
type SearchResult struct {
	// Chat is the matched chat record.
	Chat *Chat
	// Similarity is the vector service's cosine similarity between the query
	// embedding and the chat's SummaryEmbedding; higher means more similar.
	Similarity float64
}
|
||
|
||
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("数据库未初始化")
|
||
}
|
||
|
||
vs := GetVectorService()
|
||
if vs == nil {
|
||
return nil, fmt.Errorf("向量服务未初始化")
|
||
}
|
||
|
||
// 获取配置的最大检索数
|
||
maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
|
||
if maxResults <= 0 {
|
||
maxResults = 10
|
||
}
|
||
if topK > maxResults {
|
||
topK = maxResults
|
||
}
|
||
|
||
// 生成查询向量
|
||
queryEmbedding, err := vs.Generate(query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询向量生成失败: %w", err)
|
||
}
|
||
|
||
// 获取所有有向量的 chat
|
||
chats, err := db.GetAllChats()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取聊天记录失败: %w", err)
|
||
}
|
||
|
||
// 计算相似度并排序
|
||
type scoredChat struct {
|
||
chat *Chat
|
||
similarity float64
|
||
}
|
||
|
||
var scored []scoredChat
|
||
for _, chat := range chats {
|
||
if len(chat.SummaryEmbedding) == 0 {
|
||
continue
|
||
}
|
||
sim := vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding)
|
||
scored = append(scored, scoredChat{chat: chat, similarity: sim})
|
||
}
|
||
|
||
// 按相似度排序
|
||
for i := 0; i < len(scored); i++ {
|
||
for j := i + 1; j < len(scored); j++ {
|
||
if scored[j].similarity > scored[i].similarity {
|
||
scored[i], scored[j] = scored[j], scored[i]
|
||
}
|
||
}
|
||
}
|
||
|
||
// 取前 K 个
|
||
if topK > len(scored) {
|
||
topK = len(scored)
|
||
}
|
||
|
||
results := make([]SearchResult, topK)
|
||
for i := 0; i < topK; i++ {
|
||
results[i] = SearchResult{
|
||
Chat: scored[i].chat,
|
||
Similarity: scored[i].similarity,
|
||
}
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
func (db *DB) GetAllChats() ([]*Chat, error) {
|
||
rows, err := db.db.Query(
|
||
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var chats []*Chat
|
||
for rows.Next() {
|
||
var c Chat
|
||
var aiRepliesJSON string
|
||
var summaryEmbedding []byte
|
||
if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||
continue
|
||
}
|
||
c.SummaryEmbedding = summaryEmbedding
|
||
c.SetAIRepliesFromJSON(aiRepliesJSON)
|
||
chats = append(chats, &c)
|
||
}
|
||
|
||
return chats, nil
|
||
} |