Files
YunShu/toolschema.go

134 lines
2.8 KiB
Go
Raw Normal View History

package main
import (
"encoding/json"
"fmt"
"reflect"
"strings"
)
// structToSchema returns a JSON Schema fragment approximating the JSON that
// encoding/json would decode into a value of type t.
//
// Pointers (at any depth) are dereferenced. Struct types become
// {"type": "object", "properties": ..., "required": ...}; every other type is
// mapped exactly as typeToSchema maps it, so the two functions agree on all
// inputs.
//
// Struct members follow encoding/json's rules where they matter for decoding:
//   - unexported fields and fields tagged `json:"-"` are skipped;
//   - the member name is the json tag name, or else the Go field name with its
//     first letter lowercased (json matches names case-insensitively when
//     decoding, so the lowercased name still maps back to the field);
//   - untagged embedded structs (exported or not, by value or by pointer) have
//     their fields promoted into the parent object, as encoding/json does; a
//     shallower member shadows a deeper one with the same name;
//   - a member is listed in "required" unless its json tag carries the
//     "omitempty" or "omitzero" option.
//
// A `description:"..."` tag becomes the member's "description", and an
// `enum:"a,b,c"` tag becomes an "enum" list of the space-trimmed values
// (always strings, whatever the field's type).
//
// Self-referential types are cut off: a struct type that is already being
// expanded on the current path is emitted as a bare {"type": "object"}
// instead of being expanded again.
func structToSchema(t reflect.Type) Schema {
	return schemaFor(t, map[reflect.Type]bool{})
}

// typeToSchema maps a Go type to a JSON Schema fragment:
// strings → "string"; signed and unsigned integers → "integer";
// floats → "number"; bool → "boolean"; slices and arrays → "array" with
// "items"; interfaces → {} (any value); maps → "object", with
// "additionalProperties" describing the element type unless that element type
// is an interface; structs → the object schema described on structToSchema.
// Any remaining kind falls back to "string". Pointers at any depth are
// dereferenced first.
func typeToSchema(t reflect.Type) Schema {
	return schemaFor(t, map[reflect.Type]bool{})
}

// schemaFor implements typeToSchema. inProgress holds the struct types being
// expanded on the current recursion path; it exists only to break cycles.
func schemaFor(t reflect.Type, inProgress map[reflect.Type]bool) Schema {
	// Strip every pointer level so that **T is described like T.
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	switch t.Kind() {
	case reflect.String:
		return Schema{"type": "string"}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return Schema{"type": "integer"}
	case reflect.Float32, reflect.Float64:
		return Schema{"type": "number"}
	case reflect.Bool:
		return Schema{"type": "boolean"}
	case reflect.Slice, reflect.Array:
		return Schema{"type": "array", "items": schemaFor(t.Elem(), inProgress)}
	case reflect.Interface:
		return Schema{}
	case reflect.Map:
		m := Schema{"type": "object"}
		if t.Elem().Kind() != reflect.Interface {
			m["additionalProperties"] = schemaFor(t.Elem(), inProgress)
		}
		return m
	case reflect.Struct:
		return objectSchema(t, inProgress)
	default:
		// Kinds without a natural JSON form (uintptr, complex, chan, func, ...).
		return Schema{"type": "string"}
	}
}

// objectSchema builds the "object" schema for the struct type t.
func objectSchema(t reflect.Type, inProgress map[reflect.Type]bool) Schema {
	if inProgress[t] {
		// t contains itself (directly or indirectly); expanding it again
		// would never terminate.
		return Schema{"type": "object"}
	}
	inProgress[t] = true
	defer delete(inProgress, t)

	properties := map[string]any{}
	var required []any
	for _, member := range structFieldsForJSON(t) {
		fieldSchema := schemaFor(member.field.Type, inProgress)
		if desc := member.field.Tag.Get("description"); desc != "" {
			fieldSchema["description"] = desc
		}
		if enum := member.field.Tag.Get("enum"); enum != "" {
			vals := strings.Split(enum, ",")
			enumVals := make([]any, len(vals))
			for i, v := range vals {
				enumVals[i] = strings.TrimSpace(v)
			}
			fieldSchema["enum"] = enumVals
		}
		properties[member.name] = fieldSchema
		if !member.optional {
			required = append(required, member.name)
		}
	}

	schema := Schema{"type": "object", "properties": properties}
	if len(required) > 0 {
		schema["required"] = required
	}
	return schema
}

// jsonStructField is one member of a struct's JSON object after
// embedded-struct promotion.
type jsonStructField struct {
	name     string              // JSON member name
	optional bool                // json tag has omitempty/omitzero
	field    reflect.StructField // source field (type and tags)
}

// structFieldsForJSON lists the JSON members of the struct type t. Fields of
// untagged embedded structs are promoted breadth-first, so a shallower member
// shadows a deeper one with the same name. Among members at the same depth the
// first declared wins; encoding/json would instead drop an ambiguous untagged
// name, but that only matters for conflicting embeds.
func structFieldsForJSON(t reflect.Type) []jsonStructField {
	var out []jsonStructField
	taken := map[string]bool{}
	visited := map[reflect.Type]bool{}
	level := []reflect.Type{t}
	for len(level) > 0 {
		var next []reflect.Type
		for _, st := range level {
			if visited[st] {
				continue // embedding cycle, e.g. type A struct{ *A }
			}
			visited[st] = true
			for i := 0; i < st.NumField(); i++ {
				f := st.Field(i)
				tag := f.Tag.Get("json")
				if tag == "-" {
					continue
				}
				tagName, opts, _ := strings.Cut(tag, ",")
				if f.Anonymous && tagName == "" {
					embedded := f.Type
					if embedded.Kind() == reflect.Ptr {
						embedded = embedded.Elem()
					}
					if embedded.Kind() == reflect.Struct {
						// encoding/json flattens these fields into the parent.
						next = append(next, embedded)
						continue
					}
				}
				if !f.IsExported() {
					continue
				}
				name := tagName
				if name == "" {
					name = lowerFirst(f.Name)
				}
				if taken[name] {
					continue
				}
				taken[name] = true
				out = append(out, jsonStructField{name: name, optional: hasOmitOption(opts), field: f})
			}
		}
		level = next
	}
	return out
}

// hasOmitOption reports whether a json tag's option list (the text after the
// first comma) contains "omitempty" or "omitzero" as a whole option.
func hasOmitOption(opts string) bool {
	for _, opt := range strings.Split(opts, ",") {
		if opt == "omitempty" || opt == "omitzero" {
			return true
		}
	}
	return false
}

// lowerFirst lowercases the first rune of s. It is rune-aware, so a non-ASCII
// identifier such as "Ärger" is not split in the middle of a character.
func lowerFirst(s string) string {
	// Ranging over a string yields rune start offsets; the second offset is
	// the byte length of the first rune.
	for i := range s {
		if i > 0 {
			return strings.ToLower(s[:i]) + s[i:]
		}
	}
	return strings.ToLower(s)
}
// NewTool builds a ToolDef whose Parameters schema is derived from T (a struct
// or pointer-to-struct type) via structToSchema. Its Execute callback decodes
// the argument map into a T and then calls fn.
//
// Execute round-trips args through encoding/json, so decoding honours T's
// json tags. A nil args map is treated as an empty JSON object. As a result,
// when T is a pointer type, fn receives a non-nil zero value rather than nil.
//
// NewTool panics if T is neither a struct nor a pointer to a struct. This is
// a programming error, and the panic surfaces it when the tool is created.
func NewTool[T any](name, description string, fn func(T) (string, error)) *ToolDef {
	// reflect.TypeOf on a zero T returns nil when T is an interface type, and
	// calling Kind on that nil Type would crash. Going through *T always
	// yields T's static type.
	declared := reflect.TypeOf((*T)(nil)).Elem()
	t := declared
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("NewTool: %v 不是结构体类型", declared))
	}
	schema := structToSchema(t)
	return &ToolDef{
		Name:        name,
		Description: description,
		Parameters:  schema,
		Execute: func(args map[string]any) (string, error) {
			if args == nil {
				// A nil map marshals to "null". Decoding "null" leaves a
				// pointer T nil, so send an empty object instead.
				args = map[string]any{}
			}
			data, err := json.Marshal(args)
			if err != nil {
				return "", fmt.Errorf("序列化参数失败: %w", err)
			}
			var typed T
			if err := json.Unmarshal(data, &typed); err != nil {
				return "", fmt.Errorf("参数解析失败: %w", err)
			}
			return fn(typed)
		},
	}
}