Files
HxClaw/cmd/hxclaw/main.go

217 lines
5.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"fmt"
"os"
"strings"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
"github.com/muesli/termenv"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
	// Logo is the emoji printed in front of the startup banner, the input
	// prompt and plain-text (non-Markdown-rendered) replies.
	Logo = "🦐"
)
// main is the HxClaw CLI entry point. It loads the configuration, creates the
// LLM provider, the message bus and the agent loop, logs a short startup
// summary and then hands control to the interactive REPL. A failure to load
// the config or create the provider is reported on stderr and terminates the
// process with exit status 1.
func main() {
	fmt.Printf("%s HxClaw - PicoClaw 增强版 CLI\n\n", Logo)

	cfg, err := internal.LoadConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:加载配置失败: %v\n", err)
		os.Exit(1)
	}

	logger.ConfigureFromEnv()

	provider, modelID, err := providers.CreateProvider(cfg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:创建 Provider 失败: %v\n", err)
		os.Exit(1)
	}
	// A non-empty model ID returned by CreateProvider overrides the
	// configured default model name.
	if modelID != "" {
		cfg.Agents.Defaults.ModelName = modelID
	}

	msgBus := bus.NewMessageBus()
	defer msgBus.Close()

	agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
	defer agentLoop.Close()

	// The startup info is a loosely typed nested map. Read it with comma-ok
	// assertions so a missing section or an upstream shape change is logged
	// as nil instead of panicking the CLI at startup.
	startupInfo := agentLoop.GetStartupInfo()
	logger.InfoCF("hxclaw", "HxClaw 已初始化",
		map[string]any{
			"tools_count":      startupField(startupInfo, "tools", "count"),
			"skills_total":     startupField(startupInfo, "skills", "total"),
			"skills_available": startupField(startupInfo, "skills", "available"),
		})

	fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", Logo)
	interactiveMode(agentLoop, "cli:default")
}

// startupField returns info[section][key] when info[section] holds a
// map[string]any. It returns nil when the section is missing, has a different
// type, or lacks the key. Unlike a bare type assertion it never panics, even
// for a nil info map.
func startupField(info map[string]any, section, key string) any {
	sub, ok := info[section].(map[string]any)
	if !ok {
		return nil
	}
	return sub[key]
}
// interactiveMode runs the REPL on top of the line editor returned by
// internal.NewReadline. Each non-blank line is sent to the agent through
// runWithStreaming under sessionKey.
//
// The loop ends on "exit"/"quit", on internal.ErrInterrupt or internal.ErrEOF,
// or after maxConsecutiveReadErrors other read errors in a row. If the line
// editor cannot be created, it falls back to simpleInteractiveMode.
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	// Bound retries so a permanently failing input stream cannot make the
	// loop spin forever printing the same error.
	const maxConsecutiveReadErrors = 3

	prompt := fmt.Sprintf("%s You: ", Logo)
	rl, err := internal.NewReadline(prompt)
	if err != nil {
		fmt.Printf("初始化 readline 失败: %v\n", err)
		fmt.Println("回退到简单输入模式...")
		simpleInteractiveMode(agentLoop, sessionKey)
		return
	}
	defer rl.Close()

	readErrors := 0
	for {
		line, err := rl.Readline()
		if err != nil {
			if err == internal.ErrInterrupt || err == internal.ErrEOF {
				fmt.Println("\n再见!")
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			readErrors++
			if readErrors >= maxConsecutiveReadErrors {
				return
			}
			continue
		}
		readErrors = 0

		// Trim so whitespace-only lines are skipped and " exit " still quits,
		// instead of either being sent to the model as a prompt.
		input := strings.TrimSpace(line)
		switch input {
		case "":
			continue
		case "exit", "quit":
			fmt.Println("再见!")
			return
		}
		runWithStreaming(agentLoop, input, sessionKey)
	}
}
// simpleInteractiveMode is the fallback REPL used when the line editor cannot
// be initialised. It reads lines with internal.NewSimpleReader and forwards
// each non-blank one to runWithStreaming under sessionKey.
//
// It stops on "exit"/"quit", on internal.ErrEOF, or after
// maxConsecutiveReadErrors other read errors in a row.
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	// Bound retries so a permanently failing input stream cannot make the
	// loop spin forever printing the same error.
	const maxConsecutiveReadErrors = 3

	reader := internal.NewSimpleReader()
	prompt := fmt.Sprintf("%s You: ", Logo)
	readErrors := 0
	for {
		fmt.Print(prompt)
		line, err := reader.ReadString()
		if err != nil {
			if err == internal.ErrEOF {
				fmt.Println("\n再见!")
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			readErrors++
			if readErrors >= maxConsecutiveReadErrors {
				return
			}
			continue
		}
		readErrors = 0

		// Trim surrounding whitespace so blank lines are skipped and
		// "exit"/"quit" match. NOTE(review): this also removes any line
		// terminator the simple reader might leave in place; confirm
		// ReadString's contract.
		input := strings.TrimSpace(line)
		switch input {
		case "":
			continue
		case "exit", "quit":
			fmt.Println("再见!")
			return
		}
		runWithStreaming(agentLoop, input, sessionKey)
	}
}
// runWithStreaming sends one user turn to the default agent and prints the
// reply.
//
// If the agent's provider implements providers.StreamingProvider, the reply
// is streamed to stdout behind a "思考中..." spinner. Once complete, it is
// redrawn as rendered Markdown (when rendering changes it), and the
// user/assistant pair is appended to the session history.
//
// The turn goes through agentLoop.ProcessDirect instead in two cases:
//   - the provider cannot stream;
//   - the stream finished without producing any text. Presumably this is a
//     tool-call-only answer, which this streaming path does not execute.
//     Note that this re-sends the prompt.
func runWithStreaming(agentLoop *agent.AgentLoop, input, sessionKey string) {
	agentInstance := agentLoop.GetRegistry().GetDefaultAgent()
	if agentInstance == nil {
		fmt.Println("错误:无法获取 Agent 实例")
		return
	}
	ctx := context.Background()

	sp, ok := agentInstance.Provider.(providers.StreamingProvider)
	if !ok {
		printDirectReply(ctx, agentLoop, input, sessionKey)
		return
	}

	// Build the messages from the stored session history and summary so the
	// streamed turn keeps the conversation context.
	history := agentInstance.Sessions.GetHistory(sessionKey)
	summary := agentInstance.Sessions.GetSummary(sessionKey)
	messages := agentInstance.ContextBuilder.BuildMessages(
		history,
		summary,
		input,
		nil,   // media
		"cli", // channel
		sessionKey,
		"", // senderID
		"", // senderDisplayName
	)
	toolDefs := agentInstance.Tools.ToProviderDefs()

	spinner := internal.NewSpinner("思考中...")
	spinner.Start()
	// stopSpinner stops the spinner at most once. It is called on the first
	// streamed text and again unconditionally after ChatStream returns, so a
	// failed or empty stream cannot leave the spinner running.
	spinnerStopped := false
	stopSpinner := func() {
		if !spinnerStopped {
			spinner.Stop()
			spinnerStopped = true
		}
	}
	fmt.Print("\n")

	// streamed holds exactly what has been printed so far; its length is the
	// offset of the not-yet-printed suffix of the accumulated text.
	var streamed strings.Builder
	// NOTE(review): the first return value is ignored, so any tool calls it
	// carries alongside text are neither executed nor shown.
	_, err := sp.ChatStream(ctx, messages, toolDefs, agentInstance.Model, nil, func(accumulated string) {
		if len(accumulated) <= streamed.Len() {
			return
		}
		stopSpinner()
		chunk := accumulated[streamed.Len():]
		// os.Stdout is unbuffered, so no explicit flush is needed per chunk.
		fmt.Print(chunk)
		streamed.WriteString(chunk)
	})
	stopSpinner()

	if err != nil {
		// Start the error on its own line if partial output was printed.
		if streamed.Len() > 0 {
			fmt.Println()
		}
		fmt.Printf("流式调用错误: %v\n", err)
		return
	}
	if streamed.Len() == 0 {
		printDirectReply(ctx, agentLoop, input, sessionKey)
		return
	}

	allOutput := streamed.String()
	if rendered := internal.RenderMarkdown(allOutput); rendered != allOutput && rendered != "" {
		// Erase the raw streamed text, then print its rendered form.
		// NOTE(review): this counts logical lines, not terminal rows, so
		// soft-wrapped lines are not fully erased; the exact cursor
		// arithmetic also depends on the spinner's output. Verify on a
		// real terminal before changing it.
		lines := strings.Count(allOutput, "\n") + 1
		output := termenv.DefaultOutput()
		output.CursorUp(1)
		output.ClearLine()
		output.ClearLines(lines)
		fmt.Print(rendered)
	}
	fmt.Println()
	fmt.Println()

	agentInstance.Sessions.AddMessage(sessionKey, "user", input)
	agentInstance.Sessions.AddMessage(sessionKey, "assistant", allOutput)
}

// printDirectReply runs one turn through the agent loop's non-streaming path
// (agentLoop.ProcessDirect) and prints the reply. The reply is shown as
// rendered Markdown when rendering changes it, otherwise prefixed with Logo.
// Errors are printed and the turn is abandoned.
func printDirectReply(ctx context.Context, agentLoop *agent.AgentLoop, input, sessionKey string) {
	response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
	if err != nil {
		fmt.Printf("错误: %v\n", err)
		return
	}
	if rendered := internal.RenderMarkdown(response); rendered != "" && rendered != response {
		fmt.Printf("\n%s\n\n", rendered)
		return
	}
	fmt.Printf("\n%s %s\n\n", Logo, response)
}