package memory import ( "fmt" "github.com/hxclaw/hxclaw/cmd/hxclaw/internal" ) type RecallResult struct { Type string `json:"type"` Message string `json:"message"` Sessions []*Session `json:"sessions,omitempty"` Chats []*Chat `json:"chats,omitempty"` } func RecallHistory() (*RecallResult, error) { db := GetDB() if db == nil { return nil, fmt.Errorf("数据库未初始化") } sessions, err := db.GetAllSessions() if err != nil { return nil, fmt.Errorf("查询会话失败: %v", err) } if len(sessions) == 0 { return &RecallResult{ Type: "history", Message: "暂无历史会话记录", }, nil } msg := "以下是所有历史会话摘要:\n\n" for i, s := range sessions { summary := s.Summary if summary == "" { summary = "(暂无摘要)" } msg += fmt.Sprintf("**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary) } return &RecallResult{ Type: "history", Message: msg, Sessions: sessions, }, nil } func RecallTopic(query string, topK int) (*RecallResult, error) { vs := GetVectorService() if vs == nil { return nil, fmt.Errorf("向量服务未初始化") } cfg := internal.GetProjectConfig().Memory if topK <= 0 { topK = cfg.Vector.MaxSearchResults if topK <= 0 { topK = 5 } } results, err := SearchSimilar(query, topK) if err != nil { return nil, fmt.Errorf("向量检索失败: %v", err) } if len(results) == 0 { return &RecallResult{ Type: "topic", Message: fmt.Sprintf("未找到与「%s」相关的内容", query), }, nil } type sessionChat struct { sessionID int64 chat *Chat similarity float64 } sessionChats := make(map[int64]*sessionChat) for _, r := range results { if existing, ok := sessionChats[r.Chat.SessionID]; ok { if r.Similarity > existing.similarity { sessionChats[r.Chat.SessionID] = &sessionChat{ sessionID: r.Chat.SessionID, chat: r.Chat, similarity: r.Similarity, } } } else { sessionChats[r.Chat.SessionID] = &sessionChat{ sessionID: r.Chat.SessionID, chat: r.Chat, similarity: r.Similarity, } } } msg := fmt.Sprintf("找到与「%s」相关的内容:\n\n", query) var chats []*Chat for _, sc := range sessionChats { msg += fmt.Sprintf("**会话 %d** (相似度: %.2f)\n%s\n\n", sc.sessionID, sc.similarity, sc.chat.Summary) chats = append(chats, sc.chat) } return &RecallResult{ Type: "topic", Message: msg, Chats: chats, }, nil } func RecallSession(sessionID int64) (*RecallResult, error) { db := GetDB() if db == nil { return nil, fmt.Errorf("数据库未初始化") } session, err := db.GetSessionByID(sessionID) if err != nil { return nil, fmt.Errorf("查询会话失败: %v", err) } if session == nil { return &RecallResult{ Type: "session", Message: fmt.Sprintf("未找到会话 ID: %d", sessionID), }, nil } msg := fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, session.Summary) return &RecallResult{ Type: "session", Message: msg, Sessions: []*Session{session}, }, nil } func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) { db := GetDB() if db == nil { return nil, fmt.Errorf("数据库未初始化") } chats, err := db.GetChatsBySessionID(sessionID) if err != nil { return nil, fmt.Errorf("查询聊天记录失败: %v", err) } if len(chats) == 0 { return &RecallResult{ Type: "within_session", Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID), }, nil } vs := GetVectorService() if vs == nil { return nil, fmt.Errorf("向量服务未初始化") } queryEmb, err := vs.Generate(query) if err != nil { return nil, fmt.Errorf("生成查询向量失败: %v", err) } type scoredChat struct { chat *Chat similarity float64 } var scored []scoredChat for _, c := range chats { if len(c.SummaryEmbedding) == 0 { continue } sim := vs.CosineSimilarity(queryEmb, c.SummaryEmbedding) scored = append(scored, scoredChat{chat: c, similarity: sim}) } for i := 0; i < len(scored); i++ { for j := i + 1; j < len(scored); j++ { if scored[j].similarity > scored[i].similarity { scored[i], scored[j] = scored[j], scored[i] } } } topK := 3 if topK > len(scored) { topK = len(scored) } msg := fmt.Sprintf("会话 %d 中与「%s」相关内容:\n\n", sessionID, query) var matchedChats []*Chat for i := 0; i < topK; i++ { msg += fmt.Sprintf("**相似度: %.2f**\n%s\n\n", scored[i].similarity, scored[i].chat.Summary) matchedChats = append(matchedChats, scored[i].chat) } return &RecallResult{ Type: "within_session", Message: msg, Chats: matchedChats, }, nil } func FormatRecallResults(result *RecallResult) string { return result.Message }