Files

214 lines
4.7 KiB
Go
Raw Permalink Normal View History

package memory
import (
	"fmt"
	"sort"
	"strings"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
// RecallResult is the value returned by the Recall* functions in this file.
type RecallResult struct {
	// Type names the recall that produced the result; this file sets
	// "history", "topic", "session" or "within_session".
	Type string `json:"type"`
	// Message is the human-readable text built for the caller (uses **bold** markers).
	Message string `json:"message"`
	// Sessions carries session records for history/session recalls; omitted from JSON when empty.
	Sessions []*Session `json:"sessions,omitempty"`
	// Chats carries matched chat records for topic/within-session recalls; omitted from JSON when empty.
	Chats []*Chat `json:"chats,omitempty"`
}
// RecallHistory lists the summary of every stored session.
//
// It returns an error if the database is not initialized or the session query
// fails (the underlying error is wrapped with %w). When there are no sessions
// it returns a non-error result whose Message says so. Sessions without a
// summary are shown with the placeholder "(暂无摘要)".
func RecallHistory() (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	sessions, err := db.GetAllSessions()
	if err != nil {
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if len(sessions) == 0 {
		return &RecallResult{
			Type:    "history",
			Message: "暂无历史会话记录",
		}, nil
	}

	// strings.Builder avoids quadratic copying when there are many sessions.
	var b strings.Builder
	b.WriteString("以下是所有历史会话摘要:\n\n")
	for i, s := range sessions {
		summary := s.Summary
		if summary == "" {
			summary = "(暂无摘要)"
		}
		fmt.Fprintf(&b, "**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary)
	}
	return &RecallResult{
		Type:     "history",
		Message:  b.String(),
		Sessions: sessions,
	}, nil
}
// RecallTopic runs a vector similarity search for query. For each session that
// produced at least one hit, it reports the single best-matching chat.
//
// topK is passed to SearchSimilar as the number of raw hits to request. When
// topK <= 0, the project's Memory.Vector.MaxSearchResults setting is used, or
// 5 if that is not positive either. Hits are collapsed per session, so the
// result may contain fewer than topK chats.
//
// Sessions are listed in descending order of their best similarity, with ties
// broken by ascending session ID, so the output is deterministic. Errors from
// SearchSimilar are wrapped with %w. Hits without a chat record are ignored.
func RecallTopic(query string, topK int) (*RecallResult, error) {
	if GetVectorService() == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}
	cfg := internal.GetProjectConfig().Memory
	if topK <= 0 {
		topK = cfg.Vector.MaxSearchResults
		if topK <= 0 {
			topK = 5
		}
	}
	results, err := SearchSimilar(query, topK)
	if err != nil {
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	type sessionHit struct {
		chat       *Chat
		similarity float64
	}

	// Keep only the highest-scoring chat for each session.
	best := make(map[int64]sessionHit, len(results))
	for _, r := range results {
		if r.Chat == nil {
			continue
		}
		id := r.Chat.SessionID
		if cur, ok := best[id]; !ok || r.Similarity > cur.similarity {
			best[id] = sessionHit{chat: r.Chat, similarity: r.Similarity}
		}
	}
	if len(best) == 0 {
		return &RecallResult{
			Type:    "topic",
			Message: fmt.Sprintf("未找到与「%s」相关的内容", query),
		}, nil
	}

	// Map iteration order is random; sort so the message and Chats are stable.
	hits := make([]sessionHit, 0, len(best))
	for _, h := range best {
		hits = append(hits, h)
	}
	sort.Slice(hits, func(i, j int) bool {
		if hits[i].similarity != hits[j].similarity {
			return hits[i].similarity > hits[j].similarity
		}
		return hits[i].chat.SessionID < hits[j].chat.SessionID
	})

	var b strings.Builder
	fmt.Fprintf(&b, "找到与「%s」相关的内容\n\n", query)
	chats := make([]*Chat, 0, len(hits))
	for _, h := range hits {
		fmt.Fprintf(&b, "**会话 %d** (相似度: %.2f)\n%s\n\n",
			h.chat.SessionID, h.similarity, h.chat.Summary)
		chats = append(chats, h.chat)
	}
	return &RecallResult{
		Type:    "topic",
		Message: b.String(),
		Chats:   chats,
	}, nil
}
// RecallSession looks up one session by ID and returns its summary.
//
// If GetSessionByID reports no session (nil with no error), the result is a
// non-error RecallResult whose Message says the session was not found.
// Database failures are returned wrapped with %w. An empty summary is shown as
// "(暂无摘要)", matching RecallHistory.
func RecallSession(sessionID int64) (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	session, err := db.GetSessionByID(sessionID)
	if err != nil {
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if session == nil {
		return &RecallResult{
			Type:    "session",
			Message: fmt.Sprintf("未找到会话 ID: %d", sessionID),
		}, nil
	}
	summary := session.Summary
	if summary == "" {
		summary = "(暂无摘要)"
	}
	return &RecallResult{
		Type:     "session",
		Message:  fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, summary),
		Sessions: []*Session{session},
	}, nil
}
// RecallWithinSession ranks the chats of one session by cosine similarity
// between query and each chat's summary embedding. It returns up to the 3 best
// matches in descending order of similarity.
//
// Only chats that already have a SummaryEmbedding are considered. If the
// session has no chats, or none of them carries an embedding, a non-error
// result explaining that is returned; in the latter case the query embedding
// is never generated. Chats with equal similarity keep their stored order.
// Database and embedding errors are wrapped with %w.
func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) {
	// Maximum number of matches reported.
	const topK = 3

	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	chats, err := db.GetChatsBySessionID(sessionID)
	if err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	if len(chats) == 0 {
		return &RecallResult{
			Type:    "within_session",
			Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID),
		}, nil
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}

	// Filter first so we skip the embedding call when nothing can be ranked.
	candidates := make([]*Chat, 0, len(chats))
	for _, c := range chats {
		if c != nil && len(c.SummaryEmbedding) > 0 {
			candidates = append(candidates, c)
		}
	}
	if len(candidates) == 0 {
		return &RecallResult{
			Type:    "within_session",
			Message: fmt.Sprintf("会话 %d 中未找到与「%s」相关的内容", sessionID, query),
		}, nil
	}

	queryEmb, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("生成查询向量失败: %w", err)
	}

	type scoredChat struct {
		chat       *Chat
		similarity float64
	}
	scored := make([]scoredChat, 0, len(candidates))
	for _, c := range candidates {
		scored = append(scored, scoredChat{
			chat:       c,
			similarity: vs.CosineSimilarity(queryEmb, c.SummaryEmbedding),
		})
	}
	// O(n log n) stable sort replaces the former O(n^2) swap sort.
	sort.SliceStable(scored, func(i, j int) bool {
		return scored[i].similarity > scored[j].similarity
	})
	if len(scored) > topK {
		scored = scored[:topK]
	}

	var b strings.Builder
	fmt.Fprintf(&b, "会话 %d 中与「%s」相关内容\n\n", sessionID, query)
	matchedChats := make([]*Chat, 0, len(scored))
	for _, s := range scored {
		fmt.Fprintf(&b, "**相似度: %.2f**\n%s\n\n", s.similarity, s.chat.Summary)
		matchedChats = append(matchedChats, s.chat)
	}
	return &RecallResult{
		Type:    "within_session",
		Message: b.String(),
		Chats:   matchedChats,
	}, nil
}
// FormatRecallResults returns the user-facing text of a recall result, i.e.
// its prepared Message. A nil result yields "" instead of panicking.
func FormatRecallResults(result *RecallResult) string {
	if result == nil {
		return ""
	}
	return result.Message
}