214 lines
4.7 KiB
Go
214 lines
4.7 KiB
Go
package memory
|
||
|
||
import (
	"fmt"
	"sort"
	"strings"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
|
||
|
||
type RecallResult struct {
|
||
Type string `json:"type"`
|
||
Message string `json:"message"`
|
||
Sessions []*Session `json:"sessions,omitempty"`
|
||
Chats []*Chat `json:"chats,omitempty"`
|
||
}
|
||
|
||
func RecallHistory() (*RecallResult, error) {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("数据库未初始化")
|
||
}
|
||
|
||
sessions, err := db.GetAllSessions()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询会话失败: %v", err)
|
||
}
|
||
|
||
if len(sessions) == 0 {
|
||
return &RecallResult{
|
||
Type: "history",
|
||
Message: "暂无历史会话记录",
|
||
}, nil
|
||
}
|
||
|
||
msg := "以下是所有历史会话摘要:\n\n"
|
||
for i, s := range sessions {
|
||
summary := s.Summary
|
||
if summary == "" {
|
||
summary = "(暂无摘要)"
|
||
}
|
||
msg += fmt.Sprintf("**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary)
|
||
}
|
||
|
||
return &RecallResult{
|
||
Type: "history",
|
||
Message: msg,
|
||
Sessions: sessions,
|
||
}, nil
|
||
}
|
||
|
||
func RecallTopic(query string, topK int) (*RecallResult, error) {
|
||
vs := GetVectorService()
|
||
if vs == nil {
|
||
return nil, fmt.Errorf("向量服务未初始化")
|
||
}
|
||
|
||
cfg := internal.GetProjectConfig().Memory
|
||
if topK <= 0 {
|
||
topK = cfg.Vector.MaxSearchResults
|
||
if topK <= 0 {
|
||
topK = 5
|
||
}
|
||
}
|
||
|
||
results, err := SearchSimilar(query, topK)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("向量检索失败: %v", err)
|
||
}
|
||
|
||
if len(results) == 0 {
|
||
return &RecallResult{
|
||
Type: "topic",
|
||
Message: fmt.Sprintf("未找到与「%s」相关的内容", query),
|
||
}, nil
|
||
}
|
||
|
||
type sessionChat struct {
|
||
sessionID int64
|
||
chat *Chat
|
||
similarity float64
|
||
}
|
||
|
||
sessionChats := make(map[int64]*sessionChat)
|
||
for _, r := range results {
|
||
if existing, ok := sessionChats[r.Chat.SessionID]; ok {
|
||
if r.Similarity > existing.similarity {
|
||
sessionChats[r.Chat.SessionID] = &sessionChat{
|
||
sessionID: r.Chat.SessionID,
|
||
chat: r.Chat,
|
||
similarity: r.Similarity,
|
||
}
|
||
}
|
||
} else {
|
||
sessionChats[r.Chat.SessionID] = &sessionChat{
|
||
sessionID: r.Chat.SessionID,
|
||
chat: r.Chat,
|
||
similarity: r.Similarity,
|
||
}
|
||
}
|
||
}
|
||
|
||
msg := fmt.Sprintf("找到与「%s」相关的内容:\n\n", query)
|
||
var chats []*Chat
|
||
for _, sc := range sessionChats {
|
||
msg += fmt.Sprintf("**会话 %d** (相似度: %.2f)\n%s\n\n",
|
||
sc.sessionID, sc.similarity, sc.chat.Summary)
|
||
chats = append(chats, sc.chat)
|
||
}
|
||
|
||
return &RecallResult{
|
||
Type: "topic",
|
||
Message: msg,
|
||
Chats: chats,
|
||
}, nil
|
||
}
|
||
|
||
func RecallSession(sessionID int64) (*RecallResult, error) {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("数据库未初始化")
|
||
}
|
||
|
||
session, err := db.GetSessionByID(sessionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询会话失败: %v", err)
|
||
}
|
||
if session == nil {
|
||
return &RecallResult{
|
||
Type: "session",
|
||
Message: fmt.Sprintf("未找到会话 ID: %d", sessionID),
|
||
}, nil
|
||
}
|
||
|
||
msg := fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, session.Summary)
|
||
|
||
return &RecallResult{
|
||
Type: "session",
|
||
Message: msg,
|
||
Sessions: []*Session{session},
|
||
}, nil
|
||
}
|
||
|
||
func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) {
|
||
db := GetDB()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("数据库未初始化")
|
||
}
|
||
|
||
chats, err := db.GetChatsBySessionID(sessionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询聊天记录失败: %v", err)
|
||
}
|
||
|
||
if len(chats) == 0 {
|
||
return &RecallResult{
|
||
Type: "within_session",
|
||
Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID),
|
||
}, nil
|
||
}
|
||
|
||
vs := GetVectorService()
|
||
if vs == nil {
|
||
return nil, fmt.Errorf("向量服务未初始化")
|
||
}
|
||
|
||
queryEmb, err := vs.Generate(query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成查询向量失败: %v", err)
|
||
}
|
||
|
||
type scoredChat struct {
|
||
chat *Chat
|
||
similarity float64
|
||
}
|
||
|
||
var scored []scoredChat
|
||
for _, c := range chats {
|
||
if len(c.SummaryEmbedding) == 0 {
|
||
continue
|
||
}
|
||
sim := vs.CosineSimilarity(queryEmb, c.SummaryEmbedding)
|
||
scored = append(scored, scoredChat{chat: c, similarity: sim})
|
||
}
|
||
|
||
for i := 0; i < len(scored); i++ {
|
||
for j := i + 1; j < len(scored); j++ {
|
||
if scored[j].similarity > scored[i].similarity {
|
||
scored[i], scored[j] = scored[j], scored[i]
|
||
}
|
||
}
|
||
}
|
||
|
||
topK := 3
|
||
if topK > len(scored) {
|
||
topK = len(scored)
|
||
}
|
||
|
||
msg := fmt.Sprintf("会话 %d 中与「%s」相关内容:\n\n", sessionID, query)
|
||
var matchedChats []*Chat
|
||
for i := 0; i < topK; i++ {
|
||
msg += fmt.Sprintf("**相似度: %.2f**\n%s\n\n",
|
||
scored[i].similarity, scored[i].chat.Summary)
|
||
matchedChats = append(matchedChats, scored[i].chat)
|
||
}
|
||
|
||
return &RecallResult{
|
||
Type: "within_session",
|
||
Message: msg,
|
||
Chats: matchedChats,
|
||
}, nil
|
||
}
|
||
|
||
func FormatRecallResults(result *RecallResult) string {
|
||
return result.Message
|
||
} |