- 流式输出: SSE 逐 token 接收, `\n\n` 段落缓冲后 mdprint 彩色渲染 - 日志系统: charmbracelet/log v2 双写(stderr + log.yml), yunshu log 命令 - 会议室架构: dialog(main) + weather/profile/note(sub) 多 Agent 编排 - 泛型工具注册: NewTool[T] 反射推导 JSON Schema, 类型安全 - 安全加固: safeMemoryPath 三段校验(EvalSymlinks+Rel), maxToolCalls=2 - 性能优化: sync.Once 延迟加载, note 一步完成, obs/summary 合并 - Prompt 适配: 流式输出原则(先调工具不说话), 单 Agent 查询跳过 obs+summary - 文档: AGENTS.md + architecture.md + changelog.md 全部同步至 v2.3.0
282 lines
6.6 KiB
Go
282 lines
6.6 KiB
Go
package main
|
||
|
||
import (
|
||
"crypto/sha256"
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
func sessionPath() string {
|
||
return filepath.Join(ConfigDir(), "session", "session.json")
|
||
}
|
||
|
||
func ClearSession() {
|
||
os.Remove(sessionPath())
|
||
infoLog("会话已清空")
|
||
}
|
||
|
||
const maxSessionMessages = 40
|
||
|
||
func LoadSession() []Message {
|
||
data, err := os.ReadFile(sessionPath())
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
|
||
var messages []Message
|
||
if err := json.Unmarshal(data, &messages); err != nil {
|
||
warnLog("解析 session.json 失败", "err", err)
|
||
return nil
|
||
}
|
||
if len(messages) > maxSessionMessages {
|
||
messages = messages[len(messages)-maxSessionMessages:]
|
||
}
|
||
return messages
|
||
|
||
}
|
||
|
||
func AppendToSession(msg Message) {
|
||
|
||
messages := LoadSession()
|
||
messages = append(messages, msg)
|
||
|
||
data, err := json.MarshalIndent(messages, "", " ")
|
||
if err != nil {
|
||
warnLog("序列化 session 失败", "err", err)
|
||
return
|
||
}
|
||
os.WriteFile(sessionPath(), data, 0644)
|
||
}
|
||
|
||
// ============================================================
|
||
// Cache 辅助
|
||
// ============================================================
|
||
|
||
func cacheDir() string {
|
||
return filepath.Join(ConfigDir(), "cache")
|
||
}
|
||
|
||
func cacheFilePath(agentName string) string {
|
||
return filepath.Join(cacheDir(), agentName+".json")
|
||
}
|
||
|
||
// cacheEntry is one record in an agent's cache file (a JSON object keyed by
// cache key). readCache treats it as expired once more than TTL seconds have
// passed since CreatedAt.
type cacheEntry struct {
	CreatedAt time.Time      `json:"created_at"` // set to time.Now() by writeCache
	TTL       int            `json:"ttl"`        // lifetime in seconds
	Data      any            `json:"data"`       // cached payload, any JSON-serializable value
	Raw       map[string]any `json:"raw"`        // extra free-form map stored alongside Data
}
|
||
|
||
// buildCacheKey derives a short, deterministic cache key from the values of
// the named arguments.
//
// Names are visited in the order given in keys; names absent from args are
// skipped. It returns "" when none of the names is present, and readCache and
// writeCache treat "" as "do not cache". The key is the first 6 bytes (12 hex
// characters) of the SHA-256 of the rendered arguments.
//
// Each name and value is rendered with %q, so separator characters inside a
// value ('&', '=') cannot make two different argument sets render to the same
// string. Without quoting, a="1&b=2" and a="1",b="2" both render as "a=1&b=2".
func buildCacheKey(keys []string, args map[string]interface{}) string {
	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		v, ok := args[k]
		if !ok {
			continue
		}
		parts = append(parts, fmt.Sprintf("%q=%q", k, fmt.Sprint(v)))
	}
	if len(parts) == 0 {
		return ""
	}
	sum := sha256.Sum256([]byte(strings.Join(parts, "&")))
	return fmt.Sprintf("%x", sum[:6])
}
|
||
|
||
func readCache(agentName, key string) *cacheEntry {
|
||
if key == "" {
|
||
return nil
|
||
}
|
||
data, err := os.ReadFile(cacheFilePath(agentName))
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
var store map[string]cacheEntry
|
||
if err := json.Unmarshal(data, &store); err != nil {
|
||
warnLog("解析缓存失败", "agent", agentName, "err", err)
|
||
return nil
|
||
}
|
||
entry, ok := store[key]
|
||
if !ok {
|
||
return nil
|
||
}
|
||
if time.Since(entry.CreatedAt) > time.Duration(entry.TTL)*time.Second {
|
||
delete(store, key)
|
||
return nil
|
||
}
|
||
return &entry
|
||
}
|
||
|
||
func writeCache(agentName, key string, data interface{}, raw map[string]interface{}, ttl int) {
|
||
if key == "" {
|
||
return
|
||
}
|
||
store := make(map[string]cacheEntry)
|
||
existing, err := os.ReadFile(cacheFilePath(agentName))
|
||
if err == nil {
|
||
if err := json.Unmarshal(existing, &store); err != nil {
|
||
warnLog("读取旧缓存解析失败", "agent", agentName, "err", err)
|
||
}
|
||
}
|
||
store[key] = cacheEntry{
|
||
CreatedAt: time.Now(),
|
||
TTL: ttl,
|
||
Data: data,
|
||
Raw: raw,
|
||
}
|
||
dir := cacheDir()
|
||
os.MkdirAll(dir, 0755)
|
||
out, err := json.MarshalIndent(store, "", " ")
|
||
if err != nil {
|
||
warnLog("序列化缓存失败", "agent", agentName, "err", err)
|
||
return
|
||
}
|
||
os.WriteFile(cacheFilePath(agentName), out, 0644)
|
||
}
|
||
|
||
// ============================================================
|
||
// 子 Agent 返回解析
|
||
// ============================================================
|
||
|
||
// parseSubResult splits a sub-agent reply into its human-readable text and an
// optional structured result. The full framing is:
//
//	---RESULT---
//	<JSON>
//	---TEXT---
//	<text>
//
// How each kind of reply is handled:
//   - Full frame: anything before the RESULT marker is discarded, and the
//     returned text is trimmed of surrounding whitespace. If the JSON part
//     does not parse, resultData is nil but the text is still returned.
//   - Text only (starts with a ---TEXT--- marker, as RunSubAgent returns when
//     its tool-call budget is exceeded): the marker is stripped and the rest
//     is trimmed.
//   - Anything else without a complete RESULT/TEXT frame: returned unchanged,
//     with nil data.
func parseSubResult(raw string) (text string, resultData interface{}) {
	const (
		resultMarker   = "---RESULT---\n"
		textMarker     = "\n---TEXT---"
		textOnlyPrefix = "---TEXT---"
	)

	_, afterResult, found := strings.Cut(raw, resultMarker)
	if !found {
		if trimmed := strings.TrimSpace(raw); strings.HasPrefix(trimmed, textOnlyPrefix) {
			return strings.TrimSpace(strings.TrimPrefix(trimmed, textOnlyPrefix)), nil
		}
		return raw, nil
	}

	jsonPart, textPart, found := strings.Cut(afterResult, textMarker)
	if !found {
		return raw, nil
	}

	var data interface{}
	if err := json.Unmarshal([]byte(strings.TrimSpace(jsonPart)), &data); err != nil {
		// Malformed structured data is non-fatal: the text is still usable.
		data = nil
	}
	return strings.TrimSpace(textPart), data
}
|
||
|
||
// ============================================================
|
||
// RunSubAgent — 隔离的子 Agent 执行(不读写 session)
|
||
// ============================================================
|
||
|
||
func RunSubAgent(def *AgentDef, userInput string) (string, error) {
|
||
infoLog("子 Agent 开始", "agent", def.Name)
|
||
messages := []Message{
|
||
{Role: RoleSystem, Content: def.SystemPrompt},
|
||
{Role: RoleUser, Content: userInput},
|
||
}
|
||
|
||
toolDefs := GetToolDefs(def.Tools)
|
||
maxToolCalls := 2
|
||
toolCallCount := 0
|
||
|
||
for {
|
||
resp, err := CallLLM(messages, toolDefs)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
choice := resp.Choices[0]
|
||
|
||
if len(choice.Message.ToolCalls) > 0 {
|
||
toolCallCount++
|
||
if toolCallCount > maxToolCalls {
|
||
warnLog("子 Agent 执行轮次超限", "agent", def.Name, "rounds", toolCallCount)
|
||
return "---TEXT---\n(子 Agent 执行轮次超限,已终止)", nil
|
||
}
|
||
assistantMsg := Message{
|
||
Role: RoleAssistant,
|
||
ToolCalls: choice.Message.ToolCalls,
|
||
}
|
||
messages = append(messages, assistantMsg)
|
||
|
||
for _, tc := range choice.Message.ToolCalls {
|
||
result, err := ExecuteTool(tc)
|
||
if err != nil {
|
||
result = fmt.Sprintf("工具执行错误: %v", err)
|
||
}
|
||
toolMsg := Message{
|
||
Role: RoleTool,
|
||
Content: result,
|
||
ToolCallID: tc.ID,
|
||
}
|
||
messages = append(messages, toolMsg)
|
||
}
|
||
} else {
|
||
content := ""
|
||
if choice.Message.Content != nil {
|
||
content = *choice.Message.Content
|
||
}
|
||
return content, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
func RunAgent(def *AgentDef, userInput string) error {
|
||
messages := LoadSession()
|
||
|
||
fullMessages := []Message{
|
||
{Role: RoleSystem, Content: def.SystemPrompt},
|
||
}
|
||
fullMessages = append(fullMessages, messages...)
|
||
fullMessages = append(fullMessages, Message{Role: RoleUser, Content: userInput})
|
||
|
||
AppendToSession(Message{Role: RoleUser, Content: userInput})
|
||
|
||
toolDefs := GetToolDefs(def.Tools)
|
||
|
||
for {
|
||
fmt.Println()
|
||
resp, err := CallLLMStream(fullMessages, toolDefs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
choice := resp.Choices[0]
|
||
|
||
if len(choice.Message.ToolCalls) > 0 {
|
||
assistantMsg := Message{
|
||
Role: RoleAssistant,
|
||
ToolCalls: choice.Message.ToolCalls,
|
||
}
|
||
fullMessages = append(fullMessages, assistantMsg)
|
||
AppendToSession(assistantMsg)
|
||
|
||
for _, tc := range choice.Message.ToolCalls {
|
||
result, err := ExecuteTool(tc)
|
||
if err != nil {
|
||
result = fmt.Sprintf("工具执行错误: %v", err)
|
||
}
|
||
toolMsg := Message{
|
||
Role: RoleTool,
|
||
Content: result,
|
||
ToolCallID: tc.ID,
|
||
}
|
||
fullMessages = append(fullMessages, toolMsg)
|
||
AppendToSession(toolMsg)
|
||
}
|
||
} else {
|
||
content := ""
|
||
if choice.Message.Content != nil {
|
||
content = *choice.Message.Content
|
||
}
|
||
|
||
assistantMsg := Message{
|
||
Role: RoleAssistant,
|
||
Content: content,
|
||
}
|
||
fullMessages = append(fullMessages, assistantMsg)
|
||
AppendToSession(assistantMsg)
|
||
|
||
return nil
|
||
}
|
||
}
|
||
}
|