Files
YunShu/tool.go
titor c4a0e3ef53 feat: v2.3.0 流式输出 + 日志系统 + 会议室架构全面升级
- 流式输出: SSE 逐 token 接收, 按 `\n\n` 段落缓冲后 mdprint 彩色渲染
- 日志系统: charmbracelet/log v2 双写(stderr + log.yml), yunshu log 命令
- 会议室架构: dialog(main) + weather/profile/note(sub) 多 Agent 编排
- 泛型工具注册: NewTool[T] 反射推导 JSON Schema, 类型安全
- 安全加固: safeMemoryPath 三段校验(EvalSymlinks+Rel), maxToolCalls=2
- 性能优化: sync.Once 延迟加载, note 一步完成, obs/summary 合并
- Prompt 适配: 流式输出原则(先调工具不说话), 单 Agent 查询跳过 obs+summary
- 文档: AGENTS.md + architecture.md + changelog.md 全部同步至 v2.3.0
2026-05-16 17:21:29 +08:00

441 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"gopkg.in/yaml.v3"
)
// registeredTools is the process-wide tool registry, keyed by tool name.
var registeredTools = map[string]*ToolDef{}

// RegisterTool stores td under td.Name. Registering a second tool with the
// same name replaces the earlier one.
func RegisterTool(td *ToolDef) {
	name := td.Name
	registeredTools[name] = td
}
// ListRegisteredTools returns every registered tool definition. The result is
// never nil; its order follows Go map iteration and is therefore unspecified.
func ListRegisteredTools() []*ToolDef {
	out := make([]*ToolDef, len(registeredTools))
	i := 0
	for _, def := range registeredTools {
		out[i] = def
		i++
	}
	return out
}
// GetToolDefs returns value copies of the registered definitions for names,
// in the order requested. Names that are not registered are skipped without
// error. The returned slice is never nil.
func GetToolDefs(names []string) []ToolDef {
	result := make([]ToolDef, 0, len(names))
	for i := range names {
		td, found := registeredTools[names[i]]
		if !found {
			continue
		}
		result = append(result, *td)
	}
	return result
}
// safeMemoryPath maps a memory path supplied by the model onto a location
// inside ConfigDir() and rejects anything that would escape it.
//
// path is interpreted relative to ConfigDir(); "" (or ".") maps to
// ConfigDir() itself. The returned path is the lexical join of ConfigDir() and
// the cleaned path; symlinks are not resolved in the return value.
//
// Containment is checked on the symlink-resolved form of BOTH the base
// directory and the target, so that:
//   - a symlink inside ConfigDir() that points elsewhere is rejected even when
//     the final file does not exist yet (e.g. "link/new.md" with link -> /etc);
//   - a dangling symlink is rejected instead of being written through;
//   - ConfigDir() itself living under a symlink does not cause every path to
//     be rejected.
func safeMemoryPath(path string) (string, error) {
	base := ConfigDir()
	fullPath := filepath.Join(base, filepath.Clean(path))

	realBase, err := evalSymlinksPrefix(base)
	if err != nil {
		return "", fmt.Errorf("路径解析失败: %w", err)
	}
	realPath, err := evalSymlinksPrefix(fullPath)
	if err != nil {
		return "", fmt.Errorf("路径解析失败: %w", err)
	}
	rel, err := filepath.Rel(realBase, realPath)
	// Compare whole path elements: a bare HasPrefix(rel, "..") would also
	// reject legitimate names such as "..draft.md".
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("路径越界: %s", path)
	}
	return fullPath, nil
}

// evalSymlinksPrefix behaves like filepath.EvalSymlinks for paths that may
// not exist yet. It resolves the longest existing prefix of p and re-appends
// the missing trailing elements verbatim.
//
// An element that exists but cannot be resolved (a dangling symlink) is
// reported as an error. Writing through such a link would create its target
// wherever it points.
func evalSymlinksPrefix(p string) (string, error) {
	p = filepath.Clean(p)
	missing := "" // trailing elements that do not exist yet
	for {
		resolved, err := filepath.EvalSymlinks(p)
		if err == nil {
			return filepath.Join(resolved, missing), nil
		}
		if !os.IsNotExist(err) {
			return "", err
		}
		// Lstat does not follow the final element: if it succeeds here, p is
		// a symlink whose target is missing.
		if _, lerr := os.Lstat(p); lerr == nil {
			return "", fmt.Errorf("悬空符号链接: %s", p)
		}
		parent := filepath.Dir(p)
		if parent == p {
			// Reached the root without finding an existing element.
			return filepath.Join(p, missing), nil
		}
		missing = filepath.Join(filepath.Base(p), missing)
		p = parent
	}
}
// ExecuteTool runs the registered tool named by tc.Function.Name.
//
// tc.Function.Arguments is the JSON-encoded argument object. An empty or
// whitespace-only string is treated as "{}". This means a call that carries
// no arguments is not rejected as malformed JSON; for example, memory.read
// documents an empty path as a request to list files.
//
// Returns an error for an unknown tool or undecodable arguments. Otherwise it
// returns whatever the tool's Execute returns.
func ExecuteTool(tc ToolCall) (string, error) {
	td, ok := registeredTools[tc.Function.Name]
	if !ok {
		return "", fmt.Errorf("未知工具: %s", tc.Function.Name)
	}
	args := map[string]any{}
	if raw := strings.TrimSpace(tc.Function.Arguments); raw != "" {
		if err := json.Unmarshal([]byte(raw), &args); err != nil {
			return "", fmt.Errorf("解析工具参数失败: %w", err)
		}
	}
	return td.Execute(args)
}
// ============================================================
// 工具输入结构体
// ============================================================
// Argument payloads for the built-in tools. Each type is passed to NewTool[T]
// when its tool is registered. The json and description tags describe the
// tool's arguments, so field order and tag text should be kept stable.
type (
	// HTTPGetInput holds the arguments of the "http-get" tool.
	HTTPGetInput struct {
		URL     string            `json:"url" description:"请求的完整 URL 地址"`
		Headers map[string]string `json:"headers,omitempty" description:"请求头键值对,如 {\"User-Agent\": \"...\"}"`
	}

	// SkillInput holds the arguments of the "skill" tool.
	SkillInput struct {
		Name string `json:"name" description:"Skill 名称,如 msn-weather-api"`
	}

	// ReadFileInput holds the arguments of the "read-file" tool.
	ReadFileInput struct {
		Path string `json:"path" description:"文件路径,相对于项目根目录"`
	}

	// TaskInput holds the arguments of the "task" tool: the sub-agent to run
	// and the argument object forwarded to it.
	TaskInput struct {
		Agent string         `json:"agent" description:"子 Agent 名称,如 weather"`
		Args  map[string]any `json:"args" description:"子 Agent 参数对象"`
	}

	// MemoryReadInput holds the arguments of the "memory.read" tool.
	MemoryReadInput struct {
		Path string `json:"path" description:"文件路径。config/user.md(画像)、config/soul.md(AI灵魂)、session/dialog.yml(对话摘要)、notes/*.md(备忘录) 等。留空返回可用文件列表"`
	}

	// MemoryWriteInput holds the arguments of the "memory.write" tool. Value
	// is expected to be an object for .yml/.yaml paths and a string otherwise.
	MemoryWriteInput struct {
		Path  string `json:"path" description:"文件路径。.md 按 ## 标题合并value 传字符串);.yml 按 key 合并value 传对象)"`
		Value any    `json:"value" description:"内容。格式取决于文件类型"`
	}

	// GeocodeInput holds the arguments of the "geocode" tool.
	GeocodeInput struct {
		City string `json:"city" description:"城市名称,支持中文(如 北京)或英文(如 Beijing"`
	}
)
// ============================================================
// 工具注册
// ============================================================
// mdMerge merges incoming Markdown into existing Markdown section by section.
//
// A section starts at a line of the form "## <heading>" (non-empty heading)
// and runs until the next such line. Text before the first heading forms a
// section with an empty heading. Merge rules:
//   - an existing section whose heading also appears in incoming is replaced,
//     in place, by the FIRST incoming section with that heading;
//   - existing sections with no counterpart are kept unchanged;
//   - incoming sections whose heading matched no existing section are
//     appended at the end in incoming order (duplicates included).
//
// Sections are written separated by a blank line. Trailing newlines of each
// body are collapsed to one, and a whitespace-only body is omitted. If either
// argument is empty, the other is returned verbatim.
func mdMerge(existing, incoming string) string {
	if existing == "" {
		return incoming
	}
	if incoming == "" {
		return existing
	}
	type section struct {
		heading string
		content string
	}
	// parse splits text into sections. A "## " line with an empty heading is
	// treated as ordinary content.
	parse := func(text string) []section {
		var secs []section
		var cur section
		for _, line := range strings.Split(text, "\n") {
			if h, ok := strings.CutPrefix(line, "## "); ok && h != "" {
				if cur.heading != "" || cur.content != "" {
					secs = append(secs, cur)
				}
				cur = section{heading: h}
				continue
			}
			if cur.content != "" {
				cur.content += "\n"
			}
			cur.content += line
		}
		if cur.heading != "" || cur.content != "" {
			secs = append(secs, cur)
		}
		return secs
	}
	existingSecs := parse(existing)
	incomingSecs := parse(incoming)

	// Index the first incoming section per heading so each existing section
	// is matched with one lookup instead of rescanning all of incoming.
	replacement := make(map[string]section, len(incomingSecs))
	for _, in := range incomingSecs {
		if _, dup := replacement[in.heading]; !dup {
			replacement[in.heading] = in
		}
	}

	used := make(map[string]bool)
	merged := make([]section, 0, len(existingSecs)+len(incomingSecs))
	for _, s := range existingSecs {
		if in, ok := replacement[s.heading]; ok {
			merged = append(merged, in)
			used[s.heading] = true
			continue
		}
		merged = append(merged, s)
	}
	// Headings already used as replacements are skipped entirely, including
	// any later duplicates of them.
	for _, in := range incomingSecs {
		if !used[in.heading] {
			merged = append(merged, in)
		}
	}

	var out strings.Builder
	for i, s := range merged {
		if i > 0 {
			out.WriteString("\n")
		}
		if s.heading != "" {
			out.WriteString("## ")
			out.WriteString(s.heading)
			out.WriteString("\n")
		}
		if strings.TrimSpace(s.content) != "" {
			out.WriteString(strings.TrimRight(s.content, "\n"))
			out.WriteString("\n")
		}
	}
	return out.String()
}
// init registers the built-in tools with the global registry. Each handler
// receives its arguments as the matching *Input struct (via NewTool[T]) and
// returns the text that becomes the tool result.
func init() {
	// http-get: plain GET with caller-supplied headers and a 15s timeout. A
	// non-200 status is reported in the returned text rather than as an error,
	// so the caller still sees the status code and body.
	RegisterTool(NewTool[HTTPGetInput]("http-get",
		"发送 HTTP GET 请求获取数据",
		func(args HTTPGetInput) (string, error) {
			req, err := http.NewRequest("GET", args.URL, nil)
			if err != nil {
				return "", fmt.Errorf("创建请求失败: %w", err)
			}
			for k, v := range args.Headers {
				req.Header.Set(k, v)
			}
			client := &http.Client{Timeout: 15 * time.Second}
			resp, err := client.Do(req)
			if err != nil {
				return "", fmt.Errorf("请求失败: %w", err)
			}
			defer resp.Body.Close()
			// NOTE(review): the whole body is buffered with no size limit.
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				return "", fmt.Errorf("读取响应失败: %w", err)
			}
			if resp.StatusCode != 200 {
				return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(body)), nil
			}
			return string(body), nil
		},
	))

	// skill: delegates to LoadSkill.
	RegisterTool(NewTool[SkillInput]("skill",
		"加载指定名称的 Skill 知识内容到当前上下文,获取专业知识",
		func(args SkillInput) (string, error) {
			return LoadSkill(args.Name)
		},
	))

	// read-file: returns the raw contents of args.Path.
	// NOTE(review): the path goes straight to os.ReadFile, so absolute paths
	// and ".." are not confined to the project root that the argument
	// description mentions. Consider a containment check like safeMemoryPath
	// if tool input is untrusted.
	RegisterTool(NewTool[ReadFileInput]("read-file",
		"读取本地文件内容",
		func(args ReadFileInput) (string, error) {
			data, err := os.ReadFile(args.Path)
			if err != nil {
				return "", fmt.Errorf("读取文件失败: %w", err)
			}
			return string(data), nil
		},
	))

	// task: runs the sub-agent found via ScanAgents. If the sub-agent declares
	// cache keys, any cached entry is forwarded as "cache_data". When
	// parseSubResult yields data, that data is written back to the cache.
	RegisterTool(NewTool[TaskInput]("task",
		"调度子 Agent 执行领域任务。sub-agent 加载后自动查缓存,有缓存直接返回,无缓存调 LLM + 工具链获取新数据",
		func(args TaskInput) (string, error) {
			infoLog("task("+args.Agent+") 开始")
			registry := ScanAgents()
			sub := registry.GetSub(args.Agent)
			if sub == nil {
				errorLog("未找到子 Agent", "agent", args.Agent)
				return "", fmt.Errorf("未找到子 Agent: %s", args.Agent)
			}
			var cacheKey string
			var cacheData interface{}
			if sub.Cache != nil && len(sub.Cache.Keys) > 0 {
				cacheKey = buildCacheKey(sub.Cache.Keys, args.Args)
				if entry := readCache(args.Agent, cacheKey); entry != nil {
					cacheData = entry.Data
				}
			}
			subInput := map[string]any{
				"args":       args.Args,
				"cache_data": cacheData,
			}
			// cache_data is whatever readCache produced and may not be
			// JSON-encodable. Fail loudly instead of handing the sub-agent an
			// empty input.
			subInputBytes, err := json.Marshal(subInput)
			if err != nil {
				errorLog("task("+args.Agent+") 失败", "err", err)
				return "", fmt.Errorf("序列化子 Agent %s 输入失败: %w", args.Agent, err)
			}
			result, err := RunSubAgent(sub, string(subInputBytes))
			if err != nil {
				errorLog("task("+args.Agent+") 失败", "err", err)
				return "", fmt.Errorf("子 Agent %s 执行失败: %w", args.Agent, err)
			}
			text, resultData := parseSubResult(result)
			if cacheKey != "" && resultData != nil && sub.Cache != nil {
				writeCache(args.Agent, cacheKey, resultData, args.Args, sub.Cache.TTL)
			}
			infoLog("task("+args.Agent+") 完成")
			return text, nil
		},
	))

	// memory.read: the path is confined to ConfigDir() by safeMemoryPath.
	// A missing file yields the literal "null". A directory yields a YAML list
	// of its non-directory entries. Anything else yields the raw file contents.
	RegisterTool(NewTool[MemoryReadInput]("memory.read",
		"读取记忆文件。路径支持 config/, session/, log.yml, notes/ 等。返回文件原始内容",
		func(args MemoryReadInput) (string, error) {
			fullPath, err := safeMemoryPath(args.Path)
			if err != nil {
				return "", err
			}
			info, err := os.Stat(fullPath)
			if err != nil {
				if os.IsNotExist(err) {
					return "null", nil
				}
				return "", fmt.Errorf("读取失败: %w", err)
			}
			if info.IsDir() {
				entries, err := os.ReadDir(fullPath)
				if err != nil {
					return "", fmt.Errorf("读取目录失败: %w", err)
				}
				names := make([]string, 0)
				for _, e := range entries {
					if !e.IsDir() {
						names = append(names, e.Name())
					}
				}
				// Marshalling a []string is not expected to fail.
				out, _ := yaml.Marshal(names)
				return string(out), nil
			}
			data, err := os.ReadFile(fullPath)
			if err != nil {
				return "", fmt.Errorf("读取失败: %w", err)
			}
			return string(data), nil
		},
	))

	// memory.write: merges args.Value into a file under ConfigDir().
	//   - .yml/.yaml: Value must be an object; its top-level keys overwrite
	//     the existing keys.
	//   - any other extension: Value must be a string, merged by "## " heading
	//     via mdMerge.
	// If an existing file cannot be read or parsed, it is left untouched and
	// an error is returned, rather than being overwritten with the new
	// content alone.
	RegisterTool(NewTool[MemoryWriteInput]("memory.write",
		"写入记忆文件。.md 按 ## 标题合并value 传字符串),.yml 按 key 合并value 传对象)。目录自动创建",
		func(args MemoryWriteInput) (string, error) {
			fullPath, err := safeMemoryPath(args.Path)
			if err != nil {
				return "", err
			}
			if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
				return "", fmt.Errorf("创建目录失败: %w", err)
			}
			ext := filepath.Ext(args.Path)
			switch ext {
			case ".yaml", ".yml":
				m, ok := args.Value.(map[string]any)
				if !ok {
					return "", fmt.Errorf(".yml 文件 value 必须是对象")
				}
				existing := make(map[string]any)
				data, err := os.ReadFile(fullPath)
				switch {
				case err == nil:
					if err := yaml.Unmarshal(data, &existing); err != nil {
						warnLog("解析 yml 失败", "path", args.Path, "err", err)
						return "", fmt.Errorf("解析 yml 失败: %w", err)
					}
				case !os.IsNotExist(err):
					return "", fmt.Errorf("读取失败: %w", err)
				}
				// A YAML "null" document decodes to a nil map; assigning into
				// it would panic.
				if existing == nil {
					existing = make(map[string]any)
				}
				for k, v := range m {
					existing[k] = v
				}
				out, err := yaml.Marshal(existing)
				if err != nil {
					return "", fmt.Errorf("序列化 yml 失败: %w", err)
				}
				if err := os.WriteFile(fullPath, out, 0644); err != nil {
					return "", fmt.Errorf("写入失败: %w", err)
				}
			default:
				str, ok := args.Value.(string)
				if !ok {
					return "", fmt.Errorf(".md 文件 value 必须是字符串")
				}
				existing := ""
				data, err := os.ReadFile(fullPath)
				switch {
				case err == nil:
					existing = string(data)
				case !os.IsNotExist(err):
					warnLog("读取 md 失败", "path", args.Path, "err", err)
					return "", fmt.Errorf("读取失败: %w", err)
				}
				merged := mdMerge(existing, str)
				if err := os.WriteFile(fullPath, []byte(merged), 0644); err != nil {
					return "", fmt.Errorf("写入失败: %w", err)
				}
			}
			return "ok", nil
		},
	))

	// geocode: queries wttr.in's j1 JSON format and returns the first
	// nearest_area entry as {"lat","lon","name","country"}. lat/lon are passed
	// through as the strings wttr.in returns.
	RegisterTool(NewTool[GeocodeInput]("geocode",
		"查询城市或地点的经纬度坐标,返回 lat/lon/name/country。支持中文城市名如 北京、上海、成都)和英文名",
		func(args GeocodeInput) (string, error) {
			city := strings.TrimSpace(args.City)
			if city == "" {
				return "", fmt.Errorf("城市名称不能为空")
			}
			req, err := http.NewRequest("GET", "https://wttr.in/?format=j1", nil)
			if err != nil {
				return "", fmt.Errorf("创建请求失败: %w", err)
			}
			// Set the unescaped path and let net/url escape it when the request
			// is sent. Splicing the raw name into the URL string breaks on
			// names containing '%', '?' or '#'.
			req.URL.Path = "/" + city
			client := &http.Client{Timeout: 15 * time.Second}
			resp, err := client.Do(req)
			if err != nil {
				return "", fmt.Errorf("请求 wttr.in 失败: %w", err)
			}
			defer resp.Body.Close()
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				return "", fmt.Errorf("读取响应失败: %w", err)
			}
			if resp.StatusCode != 200 {
				return "", fmt.Errorf("wttr.in 返回 HTTP %d: %s", resp.StatusCode, string(body))
			}
			var result struct {
				NearestArea []struct {
					AreaName  []struct{ Value string `json:"value"` } `json:"areaName"`
					Country   []struct{ Value string `json:"value"` } `json:"country"`
					Latitude  string                                 `json:"latitude"`
					Longitude string                                 `json:"longitude"`
				} `json:"nearest_area"`
			}
			if err := json.Unmarshal(body, &result); err != nil {
				return "", fmt.Errorf("解析 wttr.in 响应失败: %w", err)
			}
			if len(result.NearestArea) == 0 {
				return "", fmt.Errorf("未找到城市: %s", args.City)
			}
			area := result.NearestArea[0]
			name := ""
			if len(area.AreaName) > 0 {
				name = area.AreaName[0].Value
			}
			country := ""
			if len(area.Country) > 0 {
				country = area.Country[0].Value
			}
			out, _ := json.Marshal(map[string]any{
				"lat":     area.Latitude,
				"lon":     area.Longitude,
				"name":    name,
				"country": country,
			})
			return string(out), nil
		},
	))
}