443 lines
9.6 KiB
Go
443 lines
9.6 KiB
Go
|
|
package memory
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"fmt"
|
|||
|
|
"os"
|
|||
|
|
"path/filepath"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
var lastContext string
|
|||
|
|
|
|||
|
|
func GetContextPrompt(userInput string) string {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
memoryCfg := internal.GetProjectConfig().Memory
|
|||
|
|
if !memoryCfg.Enabled {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
session := currentSession
|
|||
|
|
if session == nil {
|
|||
|
|
var err error
|
|||
|
|
session, err = db.GetLatestSession()
|
|||
|
|
if err != nil || session == nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
currentSession = session
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var context string
|
|||
|
|
|
|||
|
|
// 读取 picoclaw 的长期记忆
|
|||
|
|
picoMemory := readPicoClawMemory()
|
|||
|
|
if picoMemory != "" {
|
|||
|
|
context += "=== 长期记忆 ===\n" + picoMemory + "\n============\n"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 添加会话摘要
|
|||
|
|
if session.Summary != "" {
|
|||
|
|
context += "=== 当前会话摘要 ===\n" + session.Summary + "\n============\n"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if shouldRecall(userInput) {
|
|||
|
|
recallResult := detectAndRecall(userInput)
|
|||
|
|
if recallResult != "" {
|
|||
|
|
context += recallResult
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
lastContext = context
|
|||
|
|
return context
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func readPicoClawMemory() string {
|
|||
|
|
memoryPath := filepath.Join(getPicoclawWorkspace(), "memory", "MEMORY.md")
|
|||
|
|
data, err := os.ReadFile(memoryPath)
|
|||
|
|
if err != nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
content := strings.TrimSpace(string(data))
|
|||
|
|
|
|||
|
|
lines := strings.Split(content, "\n")
|
|||
|
|
var sb strings.Builder
|
|||
|
|
inSection := false
|
|||
|
|
|
|||
|
|
for _, line := range lines {
|
|||
|
|
trimmed := strings.TrimSpace(line)
|
|||
|
|
if trimmed == "" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if strings.HasPrefix(trimmed, "#") {
|
|||
|
|
inSection = true
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if inSection {
|
|||
|
|
sb.WriteString(trimmed)
|
|||
|
|
sb.WriteString("\n")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return strings.TrimSpace(sb.String())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func getPicoclawWorkspace() string {
|
|||
|
|
home, err := os.UserHomeDir()
|
|||
|
|
if err != nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
return filepath.Join(home, ".picoclaw", "workspace")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func shouldRecall(input string) bool {
|
|||
|
|
memoryCfg := internal.GetProjectConfig().Memory
|
|||
|
|
|
|||
|
|
for _, cmd := range []string{"/recall", "/memory"} {
|
|||
|
|
if strings.HasPrefix(input, cmd) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
recallCfg := memoryCfg.Recall
|
|||
|
|
for _, keyword := range recallCfg.Keywords {
|
|||
|
|
if strings.Contains(input, keyword) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if recallCfg.AutoRecall {
|
|||
|
|
return shouldAutoRecall(input)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func shouldAutoRecall(input string) bool {
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs == nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
memoryCfg := internal.GetProjectConfig().Memory
|
|||
|
|
threshold := memoryCfg.Recall.SimilarityThreshold
|
|||
|
|
if threshold <= 0 {
|
|||
|
|
threshold = 0.7
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
session := currentSession
|
|||
|
|
if session == nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if session.Summary == "" || len(session.SummaryEmbedding) == 0 {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
inputEmb, err := vs.Generate(input)
|
|||
|
|
if err != nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
similarity := vs.CosineSimilarity(inputEmb, session.SummaryEmbedding)
|
|||
|
|
return similarity > threshold
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func detectAndRecall(input string) string {
|
|||
|
|
input = strings.TrimSpace(input)
|
|||
|
|
|
|||
|
|
if strings.HasPrefix(input, "/recall") {
|
|||
|
|
input = strings.TrimPrefix(input, "/recall")
|
|||
|
|
input = strings.TrimSpace(input)
|
|||
|
|
if input == "" {
|
|||
|
|
result, _ := RecallHistory()
|
|||
|
|
return FormatRecallResults(result)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
|
|||
|
|
if strings.Contains(input, keyword) {
|
|||
|
|
return recallByKeyword(input, keyword)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func recallByKeyword(input, keyword string) string {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if keyword == "之前" || keyword == "聊过" || keyword == "曾经" {
|
|||
|
|
if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") {
|
|||
|
|
result, _ := RecallHistory()
|
|||
|
|
return FormatRecallResults(result)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
topic := extractTopic(input)
|
|||
|
|
if topic != "" {
|
|||
|
|
result, _ := RecallTopic(topic, 5)
|
|||
|
|
return FormatRecallResults(result)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if strings.Contains(input, "那次") || strings.Contains(input, "那次") {
|
|||
|
|
result, _ := RecallHistory()
|
|||
|
|
return FormatRecallResults(result)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
topic := extractTopic(input)
|
|||
|
|
if topic != "" {
|
|||
|
|
result, _ := RecallTopic(topic, 5)
|
|||
|
|
return FormatRecallResults(result)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func extractTopic(input string) string {
|
|||
|
|
topic := input
|
|||
|
|
removePrefixes := []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"}
|
|||
|
|
for _, prefix := range removePrefixes {
|
|||
|
|
topic = strings.Replace(topic, prefix, "", -1)
|
|||
|
|
}
|
|||
|
|
topic = strings.Trim(topic, " ,,。.??!!::")
|
|||
|
|
return topic
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func ShouldSkipSummaryUpdate(input string) bool {
|
|||
|
|
return shouldRecall(input)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return 0, fmt.Errorf("数据库未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
memoryCfg := internal.GetProjectConfig().Memory
|
|||
|
|
if !memoryCfg.Enabled {
|
|||
|
|
return 0, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 获取或创建 Session
|
|||
|
|
session := currentSession
|
|||
|
|
if session == nil {
|
|||
|
|
var err error
|
|||
|
|
session, err = db.GetLatestSession()
|
|||
|
|
if err != nil || session == nil {
|
|||
|
|
session, err = createSession()
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("创建会话失败: %v", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
currentSession = session
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 保存原始 session summary(用于恢复)
|
|||
|
|
originalSummary := session.Summary
|
|||
|
|
|
|||
|
|
// 创建聊天记录
|
|||
|
|
chat, err := db.CreateChat(session.ID, userInput)
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("创建聊天记录失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 添加 AI 回复
|
|||
|
|
chat.AddAIReply(aiReply)
|
|||
|
|
chat.Summary = GenerateSummary(userInput, aiReply)
|
|||
|
|
chat.UpdatedAt = time.Now().Unix()
|
|||
|
|
|
|||
|
|
// 生成并保存向量
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs != nil && memoryCfg.Vector.APIKey != "" {
|
|||
|
|
go generateEmbedding(chat)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := db.UpdateChat(chat); err != nil {
|
|||
|
|
return 0, fmt.Errorf("更新聊天记录失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新 Session
|
|||
|
|
session.AddChatID(chat.ID)
|
|||
|
|
|
|||
|
|
// 只有普通对话才更新 session summary,recall 查询保持原 summary
|
|||
|
|
if useSessionSummary {
|
|||
|
|
session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply)
|
|||
|
|
}
|
|||
|
|
session.UpdatedAt = time.Now().Unix()
|
|||
|
|
|
|||
|
|
// Session 也要生成向量
|
|||
|
|
if vs != nil && memoryCfg.Vector.APIKey != "" && useSessionSummary {
|
|||
|
|
go generateSessionEmbedding(session)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := db.UpdateSession(session); err != nil {
|
|||
|
|
return 0, fmt.Errorf("更新会话失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 恢复原始 session summary(避免 recall 结果污染)
|
|||
|
|
session.Summary = originalSummary
|
|||
|
|
currentSession = session
|
|||
|
|
|
|||
|
|
return len(session.ChatIDs), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func generateEmbedding(chat *Chat) {
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
text := chat.UserInput
|
|||
|
|
if len(chat.AIReplies) > 0 {
|
|||
|
|
text = text + "\n" + chat.AIReplies[len(chat.AIReplies)-1]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
embedding, err := vs.Generate(text)
|
|||
|
|
if err != nil {
|
|||
|
|
fmt.Printf("[memory] 向量生成失败: %v\n", err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
chat.SummaryEmbedding = embedding
|
|||
|
|
chat.UpdatedAt = time.Now().Unix()
|
|||
|
|
|
|||
|
|
if err := GetDB().UpdateChat(chat); err != nil {
|
|||
|
|
fmt.Printf("[memory] 向量保存失败: %v\n", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func generateSessionEmbedding(session *Session) {
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
embedding, err := vs.Generate(session.Summary)
|
|||
|
|
if err != nil {
|
|||
|
|
fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
session.SummaryEmbedding = embedding
|
|||
|
|
session.UpdatedAt = time.Now().Unix()
|
|||
|
|
|
|||
|
|
if err := GetDB().UpdateSession(session); err != nil {
|
|||
|
|
fmt.Printf("[memory] Session 向量保存失败: %v\n", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func createSession() (*Session, error) {
|
|||
|
|
return CreateNewSession()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type SearchResult struct {
|
|||
|
|
Chat *Chat
|
|||
|
|
Similarity float64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return nil, fmt.Errorf("数据库未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs == nil {
|
|||
|
|
return nil, fmt.Errorf("向量服务未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 获取配置的最大检索数
|
|||
|
|
maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
|
|||
|
|
if maxResults <= 0 {
|
|||
|
|
maxResults = 10
|
|||
|
|
}
|
|||
|
|
if topK > maxResults {
|
|||
|
|
topK = maxResults
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 生成查询向量
|
|||
|
|
queryEmbedding, err := vs.Generate(query)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("查询向量生成失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 获取所有有向量的 chat
|
|||
|
|
chats, err := db.GetAllChats()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("获取聊天记录失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 计算相似度并排序
|
|||
|
|
type scoredChat struct {
|
|||
|
|
chat *Chat
|
|||
|
|
similarity float64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var scored []scoredChat
|
|||
|
|
for _, chat := range chats {
|
|||
|
|
if len(chat.SummaryEmbedding) == 0 {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
sim := vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding)
|
|||
|
|
scored = append(scored, scoredChat{chat: chat, similarity: sim})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 按相似度排序
|
|||
|
|
for i := 0; i < len(scored); i++ {
|
|||
|
|
for j := i + 1; j < len(scored); j++ {
|
|||
|
|
if scored[j].similarity > scored[i].similarity {
|
|||
|
|
scored[i], scored[j] = scored[j], scored[i]
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 取前 K 个
|
|||
|
|
if topK > len(scored) {
|
|||
|
|
topK = len(scored)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
results := make([]SearchResult, topK)
|
|||
|
|
for i := 0; i < topK; i++ {
|
|||
|
|
results[i] = SearchResult{
|
|||
|
|
Chat: scored[i].chat,
|
|||
|
|
Similarity: scored[i].similarity,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return results, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (db *DB) GetAllChats() ([]*Chat, error) {
|
|||
|
|
rows, err := db.db.Query(
|
|||
|
|
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`,
|
|||
|
|
)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
defer rows.Close()
|
|||
|
|
|
|||
|
|
var chats []*Chat
|
|||
|
|
for rows.Next() {
|
|||
|
|
var c Chat
|
|||
|
|
var aiRepliesJSON string
|
|||
|
|
var summaryEmbedding []byte
|
|||
|
|
if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
c.SummaryEmbedding = summaryEmbedding
|
|||
|
|
c.SetAIRepliesFromJSON(aiRepliesJSON)
|
|||
|
|
chats = append(chats, &c)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return chats, nil
|
|||
|
|
}
|