214 lines
4.7 KiB
Go
214 lines
4.7 KiB
Go
|
|
package memory
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"fmt"
|
|||
|
|
|
|||
|
|
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type RecallResult struct {
|
|||
|
|
Type string `json:"type"`
|
|||
|
|
Message string `json:"message"`
|
|||
|
|
Sessions []*Session `json:"sessions,omitempty"`
|
|||
|
|
Chats []*Chat `json:"chats,omitempty"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func RecallHistory() (*RecallResult, error) {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return nil, fmt.Errorf("数据库未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
sessions, err := db.GetAllSessions()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("查询会话失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(sessions) == 0 {
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "history",
|
|||
|
|
Message: "暂无历史会话记录",
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
msg := "以下是所有历史会话摘要:\n\n"
|
|||
|
|
for i, s := range sessions {
|
|||
|
|
summary := s.Summary
|
|||
|
|
if summary == "" {
|
|||
|
|
summary = "(暂无摘要)"
|
|||
|
|
}
|
|||
|
|
msg += fmt.Sprintf("**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "history",
|
|||
|
|
Message: msg,
|
|||
|
|
Sessions: sessions,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func RecallTopic(query string, topK int) (*RecallResult, error) {
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs == nil {
|
|||
|
|
return nil, fmt.Errorf("向量服务未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
cfg := internal.GetProjectConfig().Memory
|
|||
|
|
if topK <= 0 {
|
|||
|
|
topK = cfg.Vector.MaxSearchResults
|
|||
|
|
if topK <= 0 {
|
|||
|
|
topK = 5
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
results, err := SearchSimilar(query, topK)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("向量检索失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(results) == 0 {
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "topic",
|
|||
|
|
Message: fmt.Sprintf("未找到与「%s」相关的内容", query),
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type sessionChat struct {
|
|||
|
|
sessionID int64
|
|||
|
|
chat *Chat
|
|||
|
|
similarity float64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
sessionChats := make(map[int64]*sessionChat)
|
|||
|
|
for _, r := range results {
|
|||
|
|
if existing, ok := sessionChats[r.Chat.SessionID]; ok {
|
|||
|
|
if r.Similarity > existing.similarity {
|
|||
|
|
sessionChats[r.Chat.SessionID] = &sessionChat{
|
|||
|
|
sessionID: r.Chat.SessionID,
|
|||
|
|
chat: r.Chat,
|
|||
|
|
similarity: r.Similarity,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
sessionChats[r.Chat.SessionID] = &sessionChat{
|
|||
|
|
sessionID: r.Chat.SessionID,
|
|||
|
|
chat: r.Chat,
|
|||
|
|
similarity: r.Similarity,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
msg := fmt.Sprintf("找到与「%s」相关的内容:\n\n", query)
|
|||
|
|
var chats []*Chat
|
|||
|
|
for _, sc := range sessionChats {
|
|||
|
|
msg += fmt.Sprintf("**会话 %d** (相似度: %.2f)\n%s\n\n",
|
|||
|
|
sc.sessionID, sc.similarity, sc.chat.Summary)
|
|||
|
|
chats = append(chats, sc.chat)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "topic",
|
|||
|
|
Message: msg,
|
|||
|
|
Chats: chats,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func RecallSession(sessionID int64) (*RecallResult, error) {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return nil, fmt.Errorf("数据库未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
session, err := db.GetSessionByID(sessionID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("查询会话失败: %v", err)
|
|||
|
|
}
|
|||
|
|
if session == nil {
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "session",
|
|||
|
|
Message: fmt.Sprintf("未找到会话 ID: %d", sessionID),
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
msg := fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, session.Summary)
|
|||
|
|
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "session",
|
|||
|
|
Message: msg,
|
|||
|
|
Sessions: []*Session{session},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) {
|
|||
|
|
db := GetDB()
|
|||
|
|
if db == nil {
|
|||
|
|
return nil, fmt.Errorf("数据库未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
chats, err := db.GetChatsBySessionID(sessionID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("查询聊天记录失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(chats) == 0 {
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "within_session",
|
|||
|
|
Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID),
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
vs := GetVectorService()
|
|||
|
|
if vs == nil {
|
|||
|
|
return nil, fmt.Errorf("向量服务未初始化")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
queryEmb, err := vs.Generate(query)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("生成查询向量失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type scoredChat struct {
|
|||
|
|
chat *Chat
|
|||
|
|
similarity float64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var scored []scoredChat
|
|||
|
|
for _, c := range chats {
|
|||
|
|
if len(c.SummaryEmbedding) == 0 {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
sim := vs.CosineSimilarity(queryEmb, c.SummaryEmbedding)
|
|||
|
|
scored = append(scored, scoredChat{chat: c, similarity: sim})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for i := 0; i < len(scored); i++ {
|
|||
|
|
for j := i + 1; j < len(scored); j++ {
|
|||
|
|
if scored[j].similarity > scored[i].similarity {
|
|||
|
|
scored[i], scored[j] = scored[j], scored[i]
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
topK := 3
|
|||
|
|
if topK > len(scored) {
|
|||
|
|
topK = len(scored)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
msg := fmt.Sprintf("会话 %d 中与「%s」相关内容:\n\n", sessionID, query)
|
|||
|
|
var matchedChats []*Chat
|
|||
|
|
for i := 0; i < topK; i++ {
|
|||
|
|
msg += fmt.Sprintf("**相似度: %.2f**\n%s\n\n",
|
|||
|
|
scored[i].similarity, scored[i].chat.Summary)
|
|||
|
|
matchedChats = append(matchedChats, scored[i].chat)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &RecallResult{
|
|||
|
|
Type: "within_session",
|
|||
|
|
Message: msg,
|
|||
|
|
Chats: matchedChats,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func FormatRecallResults(result *RecallResult) string {
|
|||
|
|
return result.Message
|
|||
|
|
}
|