Files
HxClaw/cmd/hxclaw/internal/memory/skill.go

214 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package memory
import (
	"fmt"
	"sort"
	"strings"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
// RecallResult is the value returned by the Recall* functions in this file:
// a rendered text message plus the sessions or chats it was built from.
type RecallResult struct {
// Type names the kind of recall that produced the result; the functions in
// this file use "history", "topic", "session" and "within_session".
Type string `json:"type"`
// Message is the human-readable, already formatted text of the result.
Message string `json:"message"`
// Sessions carries the matched session records (set by RecallHistory and
// RecallSession); omitted from JSON when empty.
Sessions []*Session `json:"sessions,omitempty"`
// Chats carries the matched chat records (set by RecallTopic and
// RecallWithinSession); omitted from JSON when empty.
Chats []*Chat `json:"chats,omitempty"`
}
// RecallHistory lists every stored session together with its summary.
//
// Sessions are numbered in the order returned by GetAllSessions; a session
// with an empty summary is rendered as "(暂无摘要)". When there are no
// sessions, a result with a "no history" message and no Sessions is returned.
//
// An error is returned when the database is not initialized or the session
// query fails; the query error is wrapped with %w so callers can inspect it
// with errors.Is / errors.As.
func RecallHistory() (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	sessions, err := db.GetAllSessions()
	if err != nil {
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if len(sessions) == 0 {
		return &RecallResult{
			Type:    "history",
			Message: "暂无历史会话记录",
		}, nil
	}

	// strings.Builder avoids the quadratic cost of repeated += when the
	// session list is long.
	var b strings.Builder
	b.WriteString("以下是所有历史会话摘要:\n\n")
	for i, s := range sessions {
		summary := s.Summary
		if summary == "" {
			summary = "(暂无摘要)"
		}
		fmt.Fprintf(&b, "**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary)
	}
	return &RecallResult{
		Type:     "history",
		Message:  b.String(),
		Sessions: sessions,
	}, nil
}
// RecallTopic runs a vector similarity search for query and groups the hits
// by session, keeping only the best-matching chat of each session.
//
// topK is passed to SearchSimilar as the number of hits to request. When it
// is <= 0, the project's Memory.Vector.MaxSearchResults setting is used; if
// that is also <= 0, a default of 5 applies.
//
// Sessions are reported in descending order of their best similarity, with
// ties broken by ascending session ID. The output is therefore deterministic
// and the most relevant session comes first. When the search finds nothing,
// a "not found" message with no Chats is returned.
//
// An error is returned when the vector service is not initialized or the
// search fails; the search error is wrapped with %w.
func RecallTopic(query string, topK int) (*RecallResult, error) {
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}
	cfg := internal.GetProjectConfig().Memory
	if topK <= 0 {
		topK = cfg.Vector.MaxSearchResults
		if topK <= 0 {
			topK = 5
		}
	}
	results, err := SearchSimilar(query, topK)
	if err != nil {
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}
	if len(results) == 0 {
		return &RecallResult{
			Type:    "topic",
			Message: fmt.Sprintf("未找到与「%s」相关的内容", query),
		}, nil
	}

	type sessionHit struct {
		sessionID  int64
		chat       *Chat
		similarity float64
	}

	// Keep only the highest-scoring chat for each session.
	best := make(map[int64]*sessionHit, len(results))
	for _, r := range results {
		sid := r.Chat.SessionID
		if cur, ok := best[sid]; !ok || r.Similarity > cur.similarity {
			best[sid] = &sessionHit{sessionID: sid, chat: r.Chat, similarity: r.Similarity}
		}
	}

	// Map iteration order is randomized in Go; sort explicitly so the result
	// is stable across calls and ranked by relevance.
	hits := make([]*sessionHit, 0, len(best))
	for _, h := range best {
		hits = append(hits, h)
	}
	sort.Slice(hits, func(i, j int) bool {
		if hits[i].similarity != hits[j].similarity {
			return hits[i].similarity > hits[j].similarity
		}
		return hits[i].sessionID < hits[j].sessionID
	})

	var b strings.Builder
	fmt.Fprintf(&b, "找到与「%s」相关的内容\n\n", query)
	chats := make([]*Chat, 0, len(hits))
	for _, h := range hits {
		fmt.Fprintf(&b, "**会话 %d** (相似度: %.2f)\n%s\n\n",
			h.sessionID, h.similarity, h.chat.Summary)
		chats = append(chats, h.chat)
	}
	return &RecallResult{
		Type:    "topic",
		Message: b.String(),
		Chats:   chats,
	}, nil
}
// RecallSession returns the stored summary of the session with the given ID.
//
// If GetSessionByID yields no session, the result is not an error: its
// Message reports that the ID was not found and Sessions is empty. An empty
// summary is rendered as "(暂无摘要)", the same placeholder RecallHistory uses.
//
// An error is returned when the database is not initialized or the lookup
// fails; the lookup error is wrapped with %w.
func RecallSession(sessionID int64) (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	session, err := db.GetSessionByID(sessionID)
	if err != nil {
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if session == nil {
		return &RecallResult{
			Type:    "session",
			Message: fmt.Sprintf("未找到会话 ID: %d", sessionID),
		}, nil
	}

	summary := session.Summary
	if summary == "" {
		summary = "(暂无摘要)"
	}
	return &RecallResult{
		Type:     "session",
		Message:  fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, summary),
		Sessions: []*Session{session},
	}, nil
}
// RecallWithinSession ranks the chats of one session by similarity to query
// and reports up to three of the closest matches.
//
// Only chats with a non-empty SummaryEmbedding are scored, using the vector
// service's CosineSimilarity against the embedding generated for query.
// Matches are ordered by descending similarity; equally similar chats keep
// the order returned by GetChatsBySessionID.
//
// If the session has no chats, or none of its chats has an embedding, the
// result carries an explanatory message and no Chats.
//
// An error is returned when the database or vector service is not
// initialized, or when the chat lookup or query embedding fails. Underlying
// errors are wrapped with %w.
func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) {
	// Maximum number of chats reported for a single session.
	const topK = 3

	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	chats, err := db.GetChatsBySessionID(sessionID)
	if err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	if len(chats) == 0 {
		return &RecallResult{
			Type:    "within_session",
			Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID),
		}, nil
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}
	queryEmb, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("生成查询向量失败: %w", err)
	}

	type scoredChat struct {
		chat       *Chat
		similarity float64
	}
	scored := make([]scoredChat, 0, len(chats))
	for _, c := range chats {
		// A chat without a stored embedding cannot be compared; skip it.
		if len(c.SummaryEmbedding) == 0 {
			continue
		}
		scored = append(scored, scoredChat{
			chat:       c,
			similarity: vs.CosineSimilarity(queryEmb, c.SummaryEmbedding),
		})
	}
	if len(scored) == 0 {
		// Without this guard the caller would get a "related content" header
		// with nothing under it.
		return &RecallResult{
			Type:    "within_session",
			Message: fmt.Sprintf("会话 %d 中未找到与「%s」相关的内容", sessionID, query),
		}, nil
	}

	// Stable sort so that ties preserve the stored chat order.
	sort.SliceStable(scored, func(i, j int) bool {
		return scored[i].similarity > scored[j].similarity
	})
	n := topK
	if n > len(scored) {
		n = len(scored)
	}

	var b strings.Builder
	fmt.Fprintf(&b, "会话 %d 中与「%s」相关内容\n\n", sessionID, query)
	matchedChats := make([]*Chat, 0, n)
	for _, sc := range scored[:n] {
		fmt.Fprintf(&b, "**相似度: %.2f**\n%s\n\n", sc.similarity, sc.chat.Summary)
		matchedChats = append(matchedChats, sc.chat)
	}
	return &RecallResult{
		Type:    "within_session",
		Message: b.String(),
		Chats:   matchedChats,
	}, nil
}
// FormatRecallResults returns the display text of a recall result, which is
// its already formatted Message. A nil result yields the empty string rather
// than a nil-pointer panic.
func FormatRecallResults(result *RecallResult) string {
	if result == nil {
		return ""
	}
	return result.Message
}