- 流式输出: SSE 逐 token 接收, `\n\n` 段落缓冲后 mdprint 彩色渲染
- 日志系统: charmbracelet/log v2 双写(stderr + log.yml), yunshu log 命令
- 会议室架构: dialog(main) + weather/profile/note(sub) 多 Agent 编排
- 泛型工具注册: NewTool[T] 反射推导 JSON Schema, 类型安全
- 安全加固: safeMemoryPath 三段校验(EvalSymlinks+Rel), maxToolCalls=2
- 性能优化: sync.Once 延迟加载, note 一步完成, obs/summary 合并
- Prompt 适配: 流式输出原则(先调工具不说话), 单 Agent 查询跳过 obs+summary
- 文档: AGENTS.md + architecture.md + changelog.md 全部同步至 v2.3.0
441 lines
12 KiB
Go
441 lines
12 KiB
Go
package main
|
||
|
||
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"gopkg.in/yaml.v3"
)
|
||
|
||
// registeredTools is the process-wide tool registry, keyed by tool name.
// RegisterTool writes to it (from init); ListRegisteredTools, GetToolDefs and
// ExecuteTool read from it.
// NOTE(review): the map is not guarded by a mutex — safe only as long as all
// registrations happen during package initialization.
var registeredTools = map[string]*ToolDef{}
|
||
|
||
func RegisterTool(td *ToolDef) {
|
||
registeredTools[td.Name] = td
|
||
}
|
||
|
||
func ListRegisteredTools() []*ToolDef {
|
||
list := make([]*ToolDef, 0, len(registeredTools))
|
||
for _, td := range registeredTools {
|
||
list = append(list, td)
|
||
}
|
||
return list
|
||
}
|
||
|
||
func GetToolDefs(names []string) []ToolDef {
|
||
defs := make([]ToolDef, 0, len(names))
|
||
for _, name := range names {
|
||
if td, ok := registeredTools[name]; ok {
|
||
defs = append(defs, *td)
|
||
}
|
||
}
|
||
return defs
|
||
}
|
||
|
||
func safeMemoryPath(path string) (string, error) {
|
||
cleanPath := filepath.Clean(path)
|
||
fullPath := filepath.Join(ConfigDir(), cleanPath)
|
||
|
||
realPath, err := filepath.EvalSymlinks(fullPath)
|
||
if err != nil {
|
||
if !os.IsNotExist(err) {
|
||
return "", fmt.Errorf("路径解析失败: %w", err)
|
||
}
|
||
realPath = fullPath
|
||
}
|
||
|
||
rel, err := filepath.Rel(ConfigDir(), realPath)
|
||
if err != nil || strings.HasPrefix(rel, "..") {
|
||
return "", fmt.Errorf("路径越界: %s", path)
|
||
}
|
||
|
||
return fullPath, nil
|
||
}
|
||
|
||
func ExecuteTool(tc ToolCall) (string, error) {
|
||
td, ok := registeredTools[tc.Function.Name]
|
||
if !ok {
|
||
return "", fmt.Errorf("未知工具: %s", tc.Function.Name)
|
||
}
|
||
|
||
var args map[string]any
|
||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||
return "", fmt.Errorf("解析工具参数失败: %w", err)
|
||
}
|
||
|
||
return td.Execute(args)
|
||
}
|
||
|
||
// ============================================================
|
||
// 工具输入结构体
|
||
// ============================================================
|
||
|
||
// HTTPGetInput holds the arguments of the "http-get" tool.
// NOTE(review): the json/description tags appear to feed NewTool's
// reflection-based JSON Schema, i.e. they are what the model sees — confirm
// before renaming fields or editing tags.
type HTTPGetInput struct {
	// URL is the full URL requested with GET.
	URL string `json:"url" description:"请求的完整 URL 地址"`
	// Headers are optional request headers. Each is applied with
	// Header.Set, so there is one value per header name.
	Headers map[string]string `json:"headers,omitempty" description:"请求头键值对,如 {\"User-Agent\": \"...\"}"`
}
|
||
|
||
// SkillInput holds the arguments of the "skill" tool.
type SkillInput struct {
	// Name is the Skill name handed to LoadSkill (e.g. "msn-weather-api").
	Name string `json:"name" description:"Skill 名称,如 msn-weather-api"`
}
|
||
|
||
// ReadFileInput holds the arguments of the "read-file" tool.
type ReadFileInput struct {
	// Path is the file to read.
	// NOTE(review): the tool passes Path straight to os.ReadFile. A relative
	// path therefore resolves against the process working directory, not
	// necessarily the project root the description promises, and absolute
	// paths are accepted as-is.
	Path string `json:"path" description:"文件路径,相对于项目根目录"`
}
|
||
|
||
// TaskInput holds the arguments of the "task" tool.
type TaskInput struct {
	// Agent names the sub-agent to run, looked up via GetSub (e.g. "weather").
	Agent string `json:"agent" description:"子 Agent 名称,如 weather"`
	// Args is forwarded to the sub-agent under "args". When the sub-agent
	// has caching configured, it is also the input to the cache key.
	Args map[string]any `json:"args" description:"子 Agent 参数对象"`
}
|
||
|
||
// MemoryReadInput holds the arguments of the "memory.read" tool.
type MemoryReadInput struct {
	// Path is relative to ConfigDir() and validated by safeMemoryPath. An
	// empty path resolves to ConfigDir() itself. For a directory, the tool
	// returns the names of its non-directory entries as YAML.
	Path string `json:"path" description:"文件路径。config/user.md(画像)、config/soul.md(AI灵魂)、session/dialog.yml(对话摘要)、notes/*.md(备忘录) 等。留空返回可用文件列表"`
}
|
||
|
||
// MemoryWriteInput holds the arguments of the "memory.write" tool.
type MemoryWriteInput struct {
	// Path is relative to ConfigDir() and validated by safeMemoryPath. Its
	// extension selects the merge strategy: .yml/.yaml merge by top-level
	// key, and anything else is merged as Markdown by "## " heading.
	Path string `json:"path" description:"文件路径。.md 按 ## 标题合并(value 传字符串);.yml 按 key 合并(value 传对象)"`
	// Value is the content to merge. It is a string for Markdown files and is
	// expected to arrive as a map[string]any (a JSON object) for YAML files.
	Value interface{} `json:"value" description:"内容。格式取决于文件类型"`
}
|
||
|
||
// GeocodeInput holds the arguments of the "geocode" tool.
type GeocodeInput struct {
	// City is the place name looked up on wttr.in. Per the description,
	// Chinese and English names are both accepted.
	City string `json:"city" description:"城市名称,支持中文(如 北京)或英文(如 Beijing)"`
}
|
||
|
||
// ============================================================
|
||
// 工具注册
|
||
// ============================================================
|
||
|
||
// mdMerge merges two Markdown documents section by section. A section starts
// at a line "## X" (X non-empty) and runs until the next such line. Text
// before the first heading forms a section with an empty heading.
//
// Each existing section whose heading occurs in incoming is replaced, in
// place, by the FIRST incoming section with that heading. If existing repeats
// a heading, every occurrence is replaced by that same section. All other
// incoming sections are appended in their incoming order, so no incoming
// content is ever dropped. These include new headings and any repeat of a
// heading after its first occurrence. Existing sections with no matching
// incoming heading are kept.
//
// If either argument is empty the other is returned unchanged. Otherwise the
// result is re-rendered with the following normalization:
//   - sections are separated by one blank line;
//   - trailing newlines inside a section are trimmed;
//   - blank lines directly after a heading are dropped;
//   - the output ends with a newline.
//
// NOTE(review): a heading-less incoming section with no heading-less
// counterpart in existing is appended after the last section, so on the next
// parse it becomes part of that section's content. Confirm this is intended.
func mdMerge(existing, incoming string) string {
	if existing == "" {
		return incoming
	}
	if incoming == "" {
		return existing
	}

	type section struct {
		heading string
		content string
	}

	// parse splits text into sections. A bare "## " line (empty heading text)
	// is ordinary content.
	parse := func(text string) []section {
		var secs []section
		var heading string
		var lines []string
		flush := func() {
			if heading != "" || len(lines) > 0 {
				secs = append(secs, section{heading: heading, content: strings.Join(lines, "\n")})
			}
		}
		for _, line := range strings.Split(text, "\n") {
			if h, ok := strings.CutPrefix(line, "## "); ok && h != "" {
				flush()
				heading, lines = h, nil
				continue
			}
			if len(lines) == 0 && line == "" {
				continue // skip blank lines before a section's first content line
			}
			lines = append(lines, line)
		}
		flush()
		return secs
	}

	existingSecs := parse(existing)
	incomingSecs := parse(incoming)

	// Index of the first incoming section for each heading.
	firstIncoming := make(map[string]int, len(incomingSecs))
	for i, s := range incomingSecs {
		if _, dup := firstIncoming[s.heading]; !dup {
			firstIncoming[s.heading] = i
		}
	}

	consumed := make([]bool, len(incomingSecs))
	merged := make([]section, 0, len(existingSecs)+len(incomingSecs))
	for _, s := range existingSecs {
		if i, ok := firstIncoming[s.heading]; ok {
			merged = append(merged, incomingSecs[i])
			consumed[i] = true
		} else {
			merged = append(merged, s)
		}
	}
	for i, s := range incomingSecs {
		if !consumed[i] {
			merged = append(merged, s)
		}
	}

	var out strings.Builder
	for i, s := range merged {
		if i > 0 {
			out.WriteString("\n")
		}
		if s.heading != "" {
			out.WriteString("## ")
			out.WriteString(s.heading)
			out.WriteString("\n")
		}
		if strings.TrimSpace(s.content) != "" {
			out.WriteString(strings.TrimRight(s.content, "\n"))
			out.WriteString("\n")
		}
	}
	return out.String()
}
|
||
|
||
func init() {
|
||
RegisterTool(NewTool[HTTPGetInput]("http-get",
|
||
"发送 HTTP GET 请求获取数据",
|
||
func(args HTTPGetInput) (string, error) {
|
||
req, err := http.NewRequest("GET", args.URL, nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
for k, v := range args.Headers {
|
||
req.Header.Set(k, v)
|
||
}
|
||
|
||
client := &http.Client{Timeout: 15 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != 200 {
|
||
return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(body)), nil
|
||
}
|
||
return string(body), nil
|
||
},
|
||
))
|
||
|
||
RegisterTool(NewTool[SkillInput]("skill",
|
||
"加载指定名称的 Skill 知识内容到当前上下文,获取专业知识",
|
||
func(args SkillInput) (string, error) {
|
||
return LoadSkill(args.Name)
|
||
},
|
||
))
|
||
|
||
RegisterTool(NewTool[ReadFileInput]("read-file",
|
||
"读取本地文件内容",
|
||
func(args ReadFileInput) (string, error) {
|
||
data, err := os.ReadFile(args.Path)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取文件失败: %w", err)
|
||
}
|
||
return string(data), nil
|
||
},
|
||
))
|
||
|
||
RegisterTool(NewTool[TaskInput]("task",
|
||
"调度子 Agent 执行领域任务。sub-agent 加载后自动查缓存,有缓存直接返回,无缓存调 LLM + 工具链获取新数据",
|
||
func(args TaskInput) (string, error) {
|
||
infoLog("task("+args.Agent+") 开始")
|
||
registry := ScanAgents()
|
||
sub := registry.GetSub(args.Agent)
|
||
if sub == nil {
|
||
errorLog("未找到子 Agent", "agent", args.Agent)
|
||
return "", fmt.Errorf("未找到子 Agent: %s", args.Agent)
|
||
}
|
||
|
||
var cacheKey string
|
||
var cacheData interface{}
|
||
if sub.Cache != nil && len(sub.Cache.Keys) > 0 {
|
||
cacheKey = buildCacheKey(sub.Cache.Keys, args.Args)
|
||
if entry := readCache(args.Agent, cacheKey); entry != nil {
|
||
cacheData = entry.Data
|
||
}
|
||
}
|
||
|
||
subInput := map[string]any{
|
||
"args": args.Args,
|
||
"cache_data": cacheData,
|
||
}
|
||
subInputBytes, _ := json.Marshal(subInput)
|
||
|
||
result, err := RunSubAgent(sub, string(subInputBytes))
|
||
if err != nil {
|
||
errorLog("task("+args.Agent+") 失败", "err", err)
|
||
return "", fmt.Errorf("子 Agent %s 执行失败: %w", args.Agent, err)
|
||
}
|
||
|
||
text, resultData := parseSubResult(result)
|
||
|
||
if cacheKey != "" && resultData != nil && sub.Cache != nil {
|
||
writeCache(args.Agent, cacheKey, resultData, args.Args, sub.Cache.TTL)
|
||
}
|
||
|
||
infoLog("task("+args.Agent+") 完成")
|
||
return text, nil
|
||
},
|
||
))
|
||
|
||
RegisterTool(NewTool[MemoryReadInput]("memory.read",
|
||
"读取记忆文件。路径支持 config/, session/, log.yml, notes/ 等。返回文件原始内容",
|
||
func(args MemoryReadInput) (string, error) {
|
||
fullPath, err := safeMemoryPath(args.Path)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
info, err := os.Stat(fullPath)
|
||
if err != nil {
|
||
if os.IsNotExist(err) {
|
||
return "null", nil
|
||
}
|
||
return "", fmt.Errorf("读取失败: %w", err)
|
||
}
|
||
if info.IsDir() {
|
||
entries, err := os.ReadDir(fullPath)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取目录失败: %w", err)
|
||
}
|
||
names := make([]string, 0)
|
||
for _, e := range entries {
|
||
if !e.IsDir() {
|
||
names = append(names, e.Name())
|
||
}
|
||
}
|
||
out, _ := yaml.Marshal(names)
|
||
return string(out), nil
|
||
}
|
||
|
||
data, err := os.ReadFile(fullPath)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取失败: %w", err)
|
||
}
|
||
return string(data), nil
|
||
},
|
||
))
|
||
|
||
RegisterTool(NewTool[MemoryWriteInput]("memory.write",
|
||
"写入记忆文件。.md 按 ## 标题合并(value 传字符串),.yml 按 key 合并(value 传对象)。目录自动创建",
|
||
func(args MemoryWriteInput) (string, error) {
|
||
fullPath, err := safeMemoryPath(args.Path)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
os.MkdirAll(filepath.Dir(fullPath), 0755)
|
||
|
||
ext := filepath.Ext(args.Path)
|
||
switch ext {
|
||
case ".yaml", ".yml":
|
||
existing := make(map[string]any)
|
||
if data, err := os.ReadFile(fullPath); err == nil {
|
||
if err := yaml.Unmarshal(data, &existing); err != nil {
|
||
warnLog("解析 yml 失败", "path", args.Path, "err", err)
|
||
}
|
||
}
|
||
if m, ok := args.Value.(map[string]any); ok {
|
||
for k, v := range m {
|
||
existing[k] = v
|
||
}
|
||
}
|
||
out, err := yaml.Marshal(existing)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化 yml 失败: %w", err)
|
||
}
|
||
if err := os.WriteFile(fullPath, out, 0644); err != nil {
|
||
return "", fmt.Errorf("写入失败: %w", err)
|
||
}
|
||
default:
|
||
str, ok := args.Value.(string)
|
||
if !ok {
|
||
return "", fmt.Errorf(".md 文件 value 必须是字符串")
|
||
}
|
||
existing := ""
|
||
if data, err := os.ReadFile(fullPath); err == nil {
|
||
existing = string(data)
|
||
} else if !os.IsNotExist(err) {
|
||
warnLog("读取 md 失败", "path", args.Path, "err", err)
|
||
}
|
||
merged := mdMerge(existing, str)
|
||
if err := os.WriteFile(fullPath, []byte(merged), 0644); err != nil {
|
||
return "", fmt.Errorf("写入失败: %w", err)
|
||
}
|
||
}
|
||
return "ok", nil
|
||
},
|
||
))
|
||
|
||
RegisterTool(NewTool[GeocodeInput]("geocode",
|
||
"查询城市或地点的经纬度坐标,返回 lat/lon/name/country。支持中文城市名(如 北京、上海、成都)和英文名",
|
||
func(args GeocodeInput) (string, error) {
|
||
url := fmt.Sprintf("https://wttr.in/%s?format=j1", args.City)
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
client := &http.Client{Timeout: 15 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求 wttr.in 失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != 200 {
|
||
return "", fmt.Errorf("wttr.in 返回 HTTP %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var result struct {
|
||
NearestArea []struct {
|
||
AreaName []struct{ Value string `json:"value"` } `json:"areaName"`
|
||
Country []struct{ Value string `json:"value"` } `json:"country"`
|
||
Latitude string `json:"latitude"`
|
||
Longitude string `json:"longitude"`
|
||
} `json:"nearest_area"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &result); err != nil {
|
||
return "", fmt.Errorf("解析 wttr.in 响应失败: %w", err)
|
||
}
|
||
|
||
if len(result.NearestArea) == 0 {
|
||
return "", fmt.Errorf("未找到城市: %s", args.City)
|
||
}
|
||
|
||
area := result.NearestArea[0]
|
||
name := ""
|
||
if len(area.AreaName) > 0 {
|
||
name = area.AreaName[0].Value
|
||
}
|
||
country := ""
|
||
if len(area.Country) > 0 {
|
||
country = area.Country[0].Value
|
||
}
|
||
|
||
out, _ := json.Marshal(map[string]any{
|
||
"lat": area.Latitude,
|
||
"lon": area.Longitude,
|
||
"name": name,
|
||
"country": country,
|
||
})
|
||
return string(out), nil
|
||
},
|
||
))
|
||
}
|