fix: 补充提交 memory 模块和 tts/userdir 文件
This commit is contained in:
214
cmd/hxclaw/internal/memory/skill.go
Normal file
214
cmd/hxclaw/internal/memory/skill.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
|
||||
)
|
||||
|
||||
type RecallResult struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
Sessions []*Session `json:"sessions,omitempty"`
|
||||
Chats []*Chat `json:"chats,omitempty"`
|
||||
}
|
||||
|
||||
func RecallHistory() (*RecallResult, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化")
|
||||
}
|
||||
|
||||
sessions, err := db.GetAllSessions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询会话失败: %v", err)
|
||||
}
|
||||
|
||||
if len(sessions) == 0 {
|
||||
return &RecallResult{
|
||||
Type: "history",
|
||||
Message: "暂无历史会话记录",
|
||||
}, nil
|
||||
}
|
||||
|
||||
msg := "以下是所有历史会话摘要:\n\n"
|
||||
for i, s := range sessions {
|
||||
summary := s.Summary
|
||||
if summary == "" {
|
||||
summary = "(暂无摘要)"
|
||||
}
|
||||
msg += fmt.Sprintf("**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary)
|
||||
}
|
||||
|
||||
return &RecallResult{
|
||||
Type: "history",
|
||||
Message: msg,
|
||||
Sessions: sessions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RecallTopic(query string, topK int) (*RecallResult, error) {
|
||||
vs := GetVectorService()
|
||||
if vs == nil {
|
||||
return nil, fmt.Errorf("向量服务未初始化")
|
||||
}
|
||||
|
||||
cfg := internal.GetProjectConfig().Memory
|
||||
if topK <= 0 {
|
||||
topK = cfg.Vector.MaxSearchResults
|
||||
if topK <= 0 {
|
||||
topK = 5
|
||||
}
|
||||
}
|
||||
|
||||
results, err := SearchSimilar(query, topK)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量检索失败: %v", err)
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return &RecallResult{
|
||||
Type: "topic",
|
||||
Message: fmt.Sprintf("未找到与「%s」相关的内容", query),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type sessionChat struct {
|
||||
sessionID int64
|
||||
chat *Chat
|
||||
similarity float64
|
||||
}
|
||||
|
||||
sessionChats := make(map[int64]*sessionChat)
|
||||
for _, r := range results {
|
||||
if existing, ok := sessionChats[r.Chat.SessionID]; ok {
|
||||
if r.Similarity > existing.similarity {
|
||||
sessionChats[r.Chat.SessionID] = &sessionChat{
|
||||
sessionID: r.Chat.SessionID,
|
||||
chat: r.Chat,
|
||||
similarity: r.Similarity,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sessionChats[r.Chat.SessionID] = &sessionChat{
|
||||
sessionID: r.Chat.SessionID,
|
||||
chat: r.Chat,
|
||||
similarity: r.Similarity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("找到与「%s」相关的内容:\n\n", query)
|
||||
var chats []*Chat
|
||||
for _, sc := range sessionChats {
|
||||
msg += fmt.Sprintf("**会话 %d** (相似度: %.2f)\n%s\n\n",
|
||||
sc.sessionID, sc.similarity, sc.chat.Summary)
|
||||
chats = append(chats, sc.chat)
|
||||
}
|
||||
|
||||
return &RecallResult{
|
||||
Type: "topic",
|
||||
Message: msg,
|
||||
Chats: chats,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RecallSession(sessionID int64) (*RecallResult, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化")
|
||||
}
|
||||
|
||||
session, err := db.GetSessionByID(sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询会话失败: %v", err)
|
||||
}
|
||||
if session == nil {
|
||||
return &RecallResult{
|
||||
Type: "session",
|
||||
Message: fmt.Sprintf("未找到会话 ID: %d", sessionID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, session.Summary)
|
||||
|
||||
return &RecallResult{
|
||||
Type: "session",
|
||||
Message: msg,
|
||||
Sessions: []*Session{session},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化")
|
||||
}
|
||||
|
||||
chats, err := db.GetChatsBySessionID(sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询聊天记录失败: %v", err)
|
||||
}
|
||||
|
||||
if len(chats) == 0 {
|
||||
return &RecallResult{
|
||||
Type: "within_session",
|
||||
Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
vs := GetVectorService()
|
||||
if vs == nil {
|
||||
return nil, fmt.Errorf("向量服务未初始化")
|
||||
}
|
||||
|
||||
queryEmb, err := vs.Generate(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成查询向量失败: %v", err)
|
||||
}
|
||||
|
||||
type scoredChat struct {
|
||||
chat *Chat
|
||||
similarity float64
|
||||
}
|
||||
|
||||
var scored []scoredChat
|
||||
for _, c := range chats {
|
||||
if len(c.SummaryEmbedding) == 0 {
|
||||
continue
|
||||
}
|
||||
sim := vs.CosineSimilarity(queryEmb, c.SummaryEmbedding)
|
||||
scored = append(scored, scoredChat{chat: c, similarity: sim})
|
||||
}
|
||||
|
||||
for i := 0; i < len(scored); i++ {
|
||||
for j := i + 1; j < len(scored); j++ {
|
||||
if scored[j].similarity > scored[i].similarity {
|
||||
scored[i], scored[j] = scored[j], scored[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
topK := 3
|
||||
if topK > len(scored) {
|
||||
topK = len(scored)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("会话 %d 中与「%s」相关内容:\n\n", sessionID, query)
|
||||
var matchedChats []*Chat
|
||||
for i := 0; i < topK; i++ {
|
||||
msg += fmt.Sprintf("**相似度: %.2f**\n%s\n\n",
|
||||
scored[i].similarity, scored[i].chat.Summary)
|
||||
matchedChats = append(matchedChats, scored[i].chat)
|
||||
}
|
||||
|
||||
return &RecallResult{
|
||||
Type: "within_session",
|
||||
Message: msg,
|
||||
Chats: matchedChats,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func FormatRecallResults(result *RecallResult) string {
|
||||
return result.Message
|
||||
}
|
||||
Reference in New Issue
Block a user