fix: 补充提交 memory 模块和 tts/userdir 文件
This commit is contained in:
443
cmd/hxclaw/internal/memory/save.go
Normal file
443
cmd/hxclaw/internal/memory/save.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
|
||||
)
|
||||
|
||||
var lastContext string
|
||||
|
||||
func GetContextPrompt(userInput string) string {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
memoryCfg := internal.GetProjectConfig().Memory
|
||||
if !memoryCfg.Enabled {
|
||||
return ""
|
||||
}
|
||||
|
||||
session := currentSession
|
||||
if session == nil {
|
||||
var err error
|
||||
session, err = db.GetLatestSession()
|
||||
if err != nil || session == nil {
|
||||
return ""
|
||||
}
|
||||
currentSession = session
|
||||
}
|
||||
|
||||
var context string
|
||||
|
||||
// 读取 picoclaw 的长期记忆
|
||||
picoMemory := readPicoClawMemory()
|
||||
if picoMemory != "" {
|
||||
context += "=== 长期记忆 ===\n" + picoMemory + "\n============\n"
|
||||
}
|
||||
|
||||
// 添加会话摘要
|
||||
if session.Summary != "" {
|
||||
context += "=== 当前会话摘要 ===\n" + session.Summary + "\n============\n"
|
||||
}
|
||||
|
||||
if shouldRecall(userInput) {
|
||||
recallResult := detectAndRecall(userInput)
|
||||
if recallResult != "" {
|
||||
context += recallResult
|
||||
}
|
||||
}
|
||||
|
||||
lastContext = context
|
||||
return context
|
||||
}
|
||||
|
||||
func readPicoClawMemory() string {
|
||||
memoryPath := filepath.Join(getPicoclawWorkspace(), "memory", "MEMORY.md")
|
||||
data, err := os.ReadFile(memoryPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(string(data))
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
var sb strings.Builder
|
||||
inSection := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
inSection = true
|
||||
continue
|
||||
}
|
||||
if inSection {
|
||||
sb.WriteString(trimmed)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
func getPicoclawWorkspace() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw", "workspace")
|
||||
}
|
||||
|
||||
func shouldRecall(input string) bool {
|
||||
memoryCfg := internal.GetProjectConfig().Memory
|
||||
|
||||
for _, cmd := range []string{"/recall", "/memory"} {
|
||||
if strings.HasPrefix(input, cmd) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
recallCfg := memoryCfg.Recall
|
||||
for _, keyword := range recallCfg.Keywords {
|
||||
if strings.Contains(input, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if recallCfg.AutoRecall {
|
||||
return shouldAutoRecall(input)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldAutoRecall(input string) bool {
|
||||
vs := GetVectorService()
|
||||
if vs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
memoryCfg := internal.GetProjectConfig().Memory
|
||||
threshold := memoryCfg.Recall.SimilarityThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 0.7
|
||||
}
|
||||
|
||||
session := currentSession
|
||||
if session == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if session.Summary == "" || len(session.SummaryEmbedding) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
inputEmb, err := vs.Generate(input)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
similarity := vs.CosineSimilarity(inputEmb, session.SummaryEmbedding)
|
||||
return similarity > threshold
|
||||
}
|
||||
|
||||
func detectAndRecall(input string) string {
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if strings.HasPrefix(input, "/recall") {
|
||||
input = strings.TrimPrefix(input, "/recall")
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
result, _ := RecallHistory()
|
||||
return FormatRecallResults(result)
|
||||
}
|
||||
}
|
||||
|
||||
for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
|
||||
if strings.Contains(input, keyword) {
|
||||
return recallByKeyword(input, keyword)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func recallByKeyword(input, keyword string) string {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if keyword == "之前" || keyword == "聊过" || keyword == "曾经" {
|
||||
if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") {
|
||||
result, _ := RecallHistory()
|
||||
return FormatRecallResults(result)
|
||||
}
|
||||
|
||||
topic := extractTopic(input)
|
||||
if topic != "" {
|
||||
result, _ := RecallTopic(topic, 5)
|
||||
return FormatRecallResults(result)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(input, "那次") || strings.Contains(input, "那次") {
|
||||
result, _ := RecallHistory()
|
||||
return FormatRecallResults(result)
|
||||
}
|
||||
|
||||
topic := extractTopic(input)
|
||||
if topic != "" {
|
||||
result, _ := RecallTopic(topic, 5)
|
||||
return FormatRecallResults(result)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractTopic(input string) string {
|
||||
topic := input
|
||||
removePrefixes := []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"}
|
||||
for _, prefix := range removePrefixes {
|
||||
topic = strings.Replace(topic, prefix, "", -1)
|
||||
}
|
||||
topic = strings.Trim(topic, " ,,。.??!!::")
|
||||
return topic
|
||||
}
|
||||
|
||||
func ShouldSkipSummaryUpdate(input string) bool {
|
||||
return shouldRecall(input)
|
||||
}
|
||||
|
||||
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return 0, fmt.Errorf("数据库未初始化")
|
||||
}
|
||||
|
||||
memoryCfg := internal.GetProjectConfig().Memory
|
||||
if !memoryCfg.Enabled {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 获取或创建 Session
|
||||
session := currentSession
|
||||
if session == nil {
|
||||
var err error
|
||||
session, err = db.GetLatestSession()
|
||||
if err != nil || session == nil {
|
||||
session, err = createSession()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("创建会话失败: %v", err)
|
||||
}
|
||||
}
|
||||
currentSession = session
|
||||
}
|
||||
|
||||
// 保存原始 session summary(用于恢复)
|
||||
originalSummary := session.Summary
|
||||
|
||||
// 创建聊天记录
|
||||
chat, err := db.CreateChat(session.ID, userInput)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("创建聊天记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 添加 AI 回复
|
||||
chat.AddAIReply(aiReply)
|
||||
chat.Summary = GenerateSummary(userInput, aiReply)
|
||||
chat.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// 生成并保存向量
|
||||
vs := GetVectorService()
|
||||
if vs != nil && memoryCfg.Vector.APIKey != "" {
|
||||
go generateEmbedding(chat)
|
||||
}
|
||||
|
||||
if err := db.UpdateChat(chat); err != nil {
|
||||
return 0, fmt.Errorf("更新聊天记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 更新 Session
|
||||
session.AddChatID(chat.ID)
|
||||
|
||||
// 只有普通对话才更新 session summary,recall 查询保持原 summary
|
||||
if useSessionSummary {
|
||||
session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply)
|
||||
}
|
||||
session.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// Session 也要生成向量
|
||||
if vs != nil && memoryCfg.Vector.APIKey != "" && useSessionSummary {
|
||||
go generateSessionEmbedding(session)
|
||||
}
|
||||
|
||||
if err := db.UpdateSession(session); err != nil {
|
||||
return 0, fmt.Errorf("更新会话失败: %v", err)
|
||||
}
|
||||
|
||||
// 恢复原始 session summary(避免 recall 结果污染)
|
||||
session.Summary = originalSummary
|
||||
currentSession = session
|
||||
|
||||
return len(session.ChatIDs), nil
|
||||
}
|
||||
|
||||
func generateEmbedding(chat *Chat) {
|
||||
vs := GetVectorService()
|
||||
if vs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
text := chat.UserInput
|
||||
if len(chat.AIReplies) > 0 {
|
||||
text = text + "\n" + chat.AIReplies[len(chat.AIReplies)-1]
|
||||
}
|
||||
|
||||
embedding, err := vs.Generate(text)
|
||||
if err != nil {
|
||||
fmt.Printf("[memory] 向量生成失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
chat.SummaryEmbedding = embedding
|
||||
chat.UpdatedAt = time.Now().Unix()
|
||||
|
||||
if err := GetDB().UpdateChat(chat); err != nil {
|
||||
fmt.Printf("[memory] 向量保存失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func generateSessionEmbedding(session *Session) {
|
||||
vs := GetVectorService()
|
||||
if vs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := vs.Generate(session.Summary)
|
||||
if err != nil {
|
||||
fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
session.SummaryEmbedding = embedding
|
||||
session.UpdatedAt = time.Now().Unix()
|
||||
|
||||
if err := GetDB().UpdateSession(session); err != nil {
|
||||
fmt.Printf("[memory] Session 向量保存失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createSession() (*Session, error) {
|
||||
return CreateNewSession()
|
||||
}
|
||||
|
||||
type SearchResult struct {
|
||||
Chat *Chat
|
||||
Similarity float64
|
||||
}
|
||||
|
||||
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化")
|
||||
}
|
||||
|
||||
vs := GetVectorService()
|
||||
if vs == nil {
|
||||
return nil, fmt.Errorf("向量服务未初始化")
|
||||
}
|
||||
|
||||
// 获取配置的最大检索数
|
||||
maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
|
||||
if maxResults <= 0 {
|
||||
maxResults = 10
|
||||
}
|
||||
if topK > maxResults {
|
||||
topK = maxResults
|
||||
}
|
||||
|
||||
// 生成查询向量
|
||||
queryEmbedding, err := vs.Generate(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询向量生成失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有有向量的 chat
|
||||
chats, err := db.GetAllChats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取聊天记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算相似度并排序
|
||||
type scoredChat struct {
|
||||
chat *Chat
|
||||
similarity float64
|
||||
}
|
||||
|
||||
var scored []scoredChat
|
||||
for _, chat := range chats {
|
||||
if len(chat.SummaryEmbedding) == 0 {
|
||||
continue
|
||||
}
|
||||
sim := vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding)
|
||||
scored = append(scored, scoredChat{chat: chat, similarity: sim})
|
||||
}
|
||||
|
||||
// 按相似度排序
|
||||
for i := 0; i < len(scored); i++ {
|
||||
for j := i + 1; j < len(scored); j++ {
|
||||
if scored[j].similarity > scored[i].similarity {
|
||||
scored[i], scored[j] = scored[j], scored[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 取前 K 个
|
||||
if topK > len(scored) {
|
||||
topK = len(scored)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, topK)
|
||||
for i := 0; i < topK; i++ {
|
||||
results[i] = SearchResult{
|
||||
Chat: scored[i].chat,
|
||||
Similarity: scored[i].similarity,
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (db *DB) GetAllChats() ([]*Chat, error) {
|
||||
rows, err := db.db.Query(
|
||||
`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var chats []*Chat
|
||||
for rows.Next() {
|
||||
var c Chat
|
||||
var aiRepliesJSON string
|
||||
var summaryEmbedding []byte
|
||||
if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
c.SummaryEmbedding = summaryEmbedding
|
||||
c.SetAIRepliesFromJSON(aiRepliesJSON)
|
||||
chats = append(chats, &c)
|
||||
}
|
||||
|
||||
return chats, nil
|
||||
}
|
||||
Reference in New Issue
Block a user