Files
YunShu/tool.go

441 lines
12 KiB
Go
Raw Permalink Normal View History

2026-05-08 10:12:31 +08:00
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
2026-05-08 10:12:31 +08:00
"strings"
"time"
"gopkg.in/yaml.v3"
2026-05-08 10:12:31 +08:00
)
// registeredTools is the process-wide tool registry, keyed by ToolDef.Name.
var registeredTools = make(map[string]*ToolDef)

// RegisterTool stores td under td.Name. Registering a second tool with the
// same name replaces the first one.
func RegisterTool(td *ToolDef) {
	name := td.Name
	registeredTools[name] = td
}

// ListRegisteredTools returns every registered tool as a non-nil slice.
// The order follows Go map iteration and is therefore unspecified.
func ListRegisteredTools() []*ToolDef {
	all := make([]*ToolDef, len(registeredTools))
	i := 0
	for _, def := range registeredTools {
		all[i] = def
		i++
	}
	return all
}

// GetToolDefs returns copies of the registered tools named in names, in the
// order given. Names that are not registered are skipped. The result is never
// nil, even when nothing matches.
func GetToolDefs(names []string) []ToolDef {
	out := make([]ToolDef, 0, len(names))
	for i := range names {
		def, found := registeredTools[names[i]]
		if !found {
			continue
		}
		out = append(out, *def)
	}
	return out
}
func safeMemoryPath(path string) (string, error) {
cleanPath := filepath.Clean(path)
fullPath := filepath.Join(ConfigDir(), cleanPath)
realPath, err := filepath.EvalSymlinks(fullPath)
if err != nil {
if !os.IsNotExist(err) {
return "", fmt.Errorf("路径解析失败: %w", err)
}
realPath = fullPath
}
rel, err := filepath.Rel(ConfigDir(), realPath)
if err != nil || strings.HasPrefix(rel, "..") {
return "", fmt.Errorf("路径越界: %s", path)
}
return fullPath, nil
}
2026-05-08 10:12:31 +08:00
func ExecuteTool(tc ToolCall) (string, error) {
td, ok := registeredTools[tc.Function.Name]
if !ok {
return "", fmt.Errorf("未知工具: %s", tc.Function.Name)
}
var args map[string]any
2026-05-08 10:12:31 +08:00
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
return "", fmt.Errorf("解析工具参数失败: %w", err)
}
return td.Execute(args)
}
// ============================================================
// Tool input structs
// ============================================================
//
// Each struct below is the typed argument object of one tool registered in
// init. The `json` tag is the argument name. The `description` tag is the
// per-argument text meant for the model; presumably NewTool (defined
// elsewhere) turns it into the tool's parameter schema — confirm there.

// HTTPGetInput holds the arguments of the "http-get" tool.
type HTTPGetInput struct {
	URL     string            `json:"url" description:"请求的完整 URL 地址"`
	Headers map[string]string `json:"headers,omitempty" description:"请求头键值对,如 {\"User-Agent\": \"...\"}"`
}

// SkillInput holds the arguments of the "skill" tool.
type SkillInput struct {
	Name string `json:"name" description:"Skill 名称,如 msn-weather-api"`
}

// ReadFileInput holds the arguments of the "read-file" tool.
type ReadFileInput struct {
	Path string `json:"path" description:"文件路径,相对于项目根目录"`
}

// TaskInput holds the arguments of the "task" tool: the sub-agent to run and
// the arguments forwarded to it.
type TaskInput struct {
	Agent string         `json:"agent" description:"子 Agent 名称,如 weather"`
	Args  map[string]any `json:"args" description:"子 Agent 参数对象"`
}

// MemoryReadInput holds the arguments of the "memory.read" tool. An empty
// Path lists the files of the memory root directory.
type MemoryReadInput struct {
	Path string `json:"path" description:"文件路径。config/user.md(画像)、config/soul.md(AI灵魂)、session/dialog.yml(对话摘要)、notes/*.md(备忘录) 等。留空返回可用文件列表"`
}

// MemoryWriteInput holds the arguments of the "memory.write" tool. Value must
// be a string for .md files and a JSON object for .yml/.yaml files.
type MemoryWriteInput struct {
	Path  string `json:"path" description:"文件路径。.md 按 ## 标题合并(value 传字符串);.yml 按 key 合并(value 传对象)"`
	Value any    `json:"value" description:"内容。格式取决于文件类型"`
}

// GeocodeInput holds the arguments of the "geocode" tool.
type GeocodeInput struct {
	City string `json:"city" description:"城市名称,支持中文(如 北京)或英文(如 Beijing)"`
}
// ============================================================
// Markdown merge helper
// ============================================================

// mdMerge merges two Markdown documents section by section. A section starts
// at a line beginning with "## " followed by a non-blank heading.
//
// Sections of incoming replace same-heading sections of existing in place.
// Sections that appear only in existing are kept. Sections that appear only in
// incoming are appended at the end, in their incoming order. Text before the
// first heading forms a section with an empty heading, so an incoming preamble
// replaces the existing one. If existing has a heading that incoming repeats,
// the first incoming occurrence is used.
//
// Before comparing, CRLF line endings are normalized to LF and heading text is
// trimmed. A file saved with Windows line endings, or a heading with a stray
// trailing space, therefore still merges instead of gaining a duplicate
// section on every write. If either input is empty, the other one is returned
// unchanged.
func mdMerge(existing, incoming string) string {
	if existing == "" {
		return incoming
	}
	if incoming == "" {
		return existing
	}

	type section struct {
		heading string   // trimmed heading text; "" for the preamble
		body    []string // body lines; leading empty lines are dropped
	}

	parse := func(text string) []section {
		var secs []section
		var cur section
		flush := func() {
			if cur.heading != "" || len(cur.body) > 0 {
				secs = append(secs, cur)
			}
		}
		for _, line := range strings.Split(strings.ReplaceAll(text, "\r\n", "\n"), "\n") {
			if rest, ok := strings.CutPrefix(line, "## "); ok && strings.TrimSpace(rest) != "" {
				flush()
				cur = section{heading: strings.TrimSpace(rest)}
				continue
			}
			if len(cur.body) == 0 && line == "" {
				continue
			}
			cur.body = append(cur.body, line)
		}
		flush()
		return secs
	}

	existingSecs := parse(existing)
	incomingSecs := parse(incoming)

	// Index of the first incoming section for each heading.
	firstIncoming := make(map[string]int, len(incomingSecs))
	for i, s := range incomingSecs {
		if _, dup := firstIncoming[s.heading]; !dup {
			firstIncoming[s.heading] = i
		}
	}

	used := make(map[string]bool)
	merged := make([]section, 0, len(existingSecs)+len(incomingSecs))
	for _, s := range existingSecs {
		if i, ok := firstIncoming[s.heading]; ok {
			merged = append(merged, incomingSecs[i])
			used[s.heading] = true
			continue
		}
		merged = append(merged, s)
	}
	for _, in := range incomingSecs {
		if !used[in.heading] {
			merged = append(merged, in)
		}
	}

	// Sections are separated by one blank line. Whitespace-only bodies are
	// omitted, leaving just the heading.
	var out strings.Builder
	for i, s := range merged {
		if i > 0 {
			out.WriteString("\n")
		}
		if s.heading != "" {
			out.WriteString("## ")
			out.WriteString(s.heading)
			out.WriteString("\n")
		}
		if body := strings.Join(s.body, "\n"); strings.TrimSpace(body) != "" {
			out.WriteString(strings.TrimRight(body, "\n"))
			out.WriteString("\n")
		}
	}
	return out.String()
}
func init() {
RegisterTool(NewTool[HTTPGetInput]("http-get",
"发送 HTTP GET 请求获取数据",
func(args HTTPGetInput) (string, error) {
req, err := http.NewRequest("GET", args.URL, nil)
2026-05-08 10:12:31 +08:00
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
for k, v := range args.Headers {
req.Header.Set(k, v)
2026-05-08 10:12:31 +08:00
}
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != 200 {
return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(body)), nil
}
return string(body), nil
},
))
RegisterTool(NewTool[SkillInput]("skill",
"加载指定名称的 Skill 知识内容到当前上下文,获取专业知识",
func(args SkillInput) (string, error) {
return LoadSkill(args.Name)
2026-05-08 10:12:31 +08:00
},
))
RegisterTool(NewTool[ReadFileInput]("read-file",
"读取本地文件内容",
func(args ReadFileInput) (string, error) {
data, err := os.ReadFile(args.Path)
if err != nil {
return "", fmt.Errorf("读取文件失败: %w", err)
2026-05-08 10:12:31 +08:00
}
return string(data), nil
2026-05-08 10:12:31 +08:00
},
))
RegisterTool(NewTool[TaskInput]("task",
"调度子 Agent 执行领域任务。sub-agent 加载后自动查缓存,有缓存直接返回,无缓存调 LLM + 工具链获取新数据",
func(args TaskInput) (string, error) {
infoLog("task("+args.Agent+") 开始")
registry := ScanAgents()
sub := registry.GetSub(args.Agent)
if sub == nil {
errorLog("未找到子 Agent", "agent", args.Agent)
return "", fmt.Errorf("未找到子 Agent: %s", args.Agent)
}
var cacheKey string
var cacheData interface{}
if sub.Cache != nil && len(sub.Cache.Keys) > 0 {
cacheKey = buildCacheKey(sub.Cache.Keys, args.Args)
if entry := readCache(args.Agent, cacheKey); entry != nil {
cacheData = entry.Data
}
2026-05-08 10:12:31 +08:00
}
subInput := map[string]any{
"args": args.Args,
"cache_data": cacheData,
}
subInputBytes, _ := json.Marshal(subInput)
result, err := RunSubAgent(sub, string(subInputBytes))
2026-05-08 10:12:31 +08:00
if err != nil {
errorLog("task("+args.Agent+") 失败", "err", err)
return "", fmt.Errorf("子 Agent %s 执行失败: %w", args.Agent, err)
}
text, resultData := parseSubResult(result)
if cacheKey != "" && resultData != nil && sub.Cache != nil {
writeCache(args.Agent, cacheKey, resultData, args.Args, sub.Cache.TTL)
2026-05-08 10:12:31 +08:00
}
infoLog("task("+args.Agent+") 完成")
return text, nil
2026-05-08 10:12:31 +08:00
},
))
RegisterTool(NewTool[MemoryReadInput]("memory.read",
"读取记忆文件。路径支持 config/, session/, log.yml, notes/ 等。返回文件原始内容",
func(args MemoryReadInput) (string, error) {
fullPath, err := safeMemoryPath(args.Path)
if err != nil {
return "", err
}
info, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return "null", nil
}
return "", fmt.Errorf("读取失败: %w", err)
}
if info.IsDir() {
entries, err := os.ReadDir(fullPath)
if err != nil {
return "", fmt.Errorf("读取目录失败: %w", err)
}
names := make([]string, 0)
for _, e := range entries {
if !e.IsDir() {
names = append(names, e.Name())
}
}
out, _ := yaml.Marshal(names)
return string(out), nil
}
data, err := os.ReadFile(fullPath)
if err != nil {
return "", fmt.Errorf("读取失败: %w", err)
}
return string(data), nil
2026-05-08 10:12:31 +08:00
},
))
RegisterTool(NewTool[MemoryWriteInput]("memory.write",
"写入记忆文件。.md 按 ## 标题合并value 传字符串),.yml 按 key 合并value 传对象)。目录自动创建",
func(args MemoryWriteInput) (string, error) {
fullPath, err := safeMemoryPath(args.Path)
if err != nil {
return "", err
2026-05-08 10:12:31 +08:00
}
os.MkdirAll(filepath.Dir(fullPath), 0755)
ext := filepath.Ext(args.Path)
switch ext {
case ".yaml", ".yml":
existing := make(map[string]any)
if data, err := os.ReadFile(fullPath); err == nil {
if err := yaml.Unmarshal(data, &existing); err != nil {
warnLog("解析 yml 失败", "path", args.Path, "err", err)
}
}
if m, ok := args.Value.(map[string]any); ok {
for k, v := range m {
existing[k] = v
}
}
out, err := yaml.Marshal(existing)
if err != nil {
return "", fmt.Errorf("序列化 yml 失败: %w", err)
}
if err := os.WriteFile(fullPath, out, 0644); err != nil {
return "", fmt.Errorf("写入失败: %w", err)
}
default:
str, ok := args.Value.(string)
if !ok {
return "", fmt.Errorf(".md 文件 value 必须是字符串")
}
existing := ""
if data, err := os.ReadFile(fullPath); err == nil {
existing = string(data)
} else if !os.IsNotExist(err) {
warnLog("读取 md 失败", "path", args.Path, "err", err)
}
merged := mdMerge(existing, str)
if err := os.WriteFile(fullPath, []byte(merged), 0644); err != nil {
return "", fmt.Errorf("写入失败: %w", err)
}
}
return "ok", nil
},
))
2026-05-08 10:12:31 +08:00
RegisterTool(NewTool[GeocodeInput]("geocode",
"查询城市或地点的经纬度坐标,返回 lat/lon/name/country。支持中文城市名如 北京、上海、成都)和英文名",
func(args GeocodeInput) (string, error) {
url := fmt.Sprintf("https://wttr.in/%s?format=j1", args.City)
2026-05-08 10:12:31 +08:00
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("请求 wttr.in 失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("wttr.in 返回 HTTP %d: %s", resp.StatusCode, string(body))
}
var result struct {
NearestArea []struct {
AreaName []struct{ Value string `json:"value"` } `json:"areaName"`
Country []struct{ Value string `json:"value"` } `json:"country"`
Latitude string `json:"latitude"`
Longitude string `json:"longitude"`
2026-05-08 10:12:31 +08:00
} `json:"nearest_area"`
}
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("解析 wttr.in 响应失败: %w", err)
}
if len(result.NearestArea) == 0 {
return "", fmt.Errorf("未找到城市: %s", args.City)
2026-05-08 10:12:31 +08:00
}
area := result.NearestArea[0]
name := ""
if len(area.AreaName) > 0 {
name = area.AreaName[0].Value
}
country := ""
if len(area.Country) > 0 {
country = area.Country[0].Value
}
out, _ := json.Marshal(map[string]any{
2026-05-08 10:12:31 +08:00
"lat": area.Latitude,
"lon": area.Longitude,
"name": name,
"country": country,
})
return string(out), nil
},
))
2026-05-08 10:12:31 +08:00
}