package main import ( "bufio" "bytes" "encoding/json" "fmt" "io" "net/http" "os" "strings" "sync" "time" "hub.gaomia.site/titor/YunShu/pkg/mdprint" ) var ( llmOnce sync.Once llmHost = "https://ark.cn-beijing.volces.com/api/v3/chat/completions" llmModel = "doubao-seed-2-0-pro-260215" llmKey = "" ) func loadLLMConfig() { llmOnce.Do(func() { cfg, err := LoadConfig() if err == nil { if cfg.LLM.Host != "" { llmHost = cfg.LLM.Host } if cfg.LLM.Model != "" { llmModel = cfg.LLM.Model } if cfg.LLM.Key != "" { llmKey = cfg.LLM.Key } } if v := os.Getenv("LLM_ENDPOINT"); v != "" { llmHost = v } if v := os.Getenv("LLM_MODEL"); v != "" { llmModel = v } if v := os.Getenv("LLM_API_KEY"); v != "" { llmKey = v } if v := os.Getenv("OPENAI_API_KEY"); v != "" && llmKey == "" { llmKey = v } }) } // GetLLMKey 获取 API Key,优先使用已加载的密钥 func GetLLMKey() (string, error) { loadLLMConfig() if llmKey == "" { return "", fmt.Errorf("未配置 API Key。请运行 'weather-cia onboard' 初始化,或设置 LLM_API_KEY 环境变量") } return llmKey, nil } // CallLLM 调用大模型 API(兼容 OpenAI Chat Completion 格式) func CallLLM(messages []Message, toolDefs []ToolDef) (*OpenAIResponse, error) { loadLLMConfig() apiKey, err := GetLLMKey() if err != nil { return nil, err } start := time.Now() reqBody := map[string]interface{}{ "model": llmModel, "messages": messages, } // 注册工具定义 if len(toolDefs) > 0 { tools := make([]OpenAITool, 0, len(toolDefs)) for _, td := range toolDefs { tools = append(tools, OpenAITool{ Type: "function", Function: OpenAIToolFunc{ Name: td.Name, Description: td.Description, Parameters: td.Parameters, }, }) } reqBody["tools"] = tools reqBody["tool_choice"] = "auto" } body, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("序列化请求失败: %w", err) } req, err := http.NewRequest("POST", llmHost, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) client := &http.Client{Timeout: 120 * time.Second} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("请求 LLM 失败: %w", err) } defer resp.Body.Close() respData, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("读取响应失败: %w", err) } if resp.StatusCode != 200 { var errResp OpenAIErrorResponse if json.Unmarshal(respData, &errResp) == nil && errResp.Error.Message != "" { return nil, fmt.Errorf("LLM API 错误 [%s]: %s", errResp.Error.Type, errResp.Error.Message) } return nil, fmt.Errorf("LLM API 返回 HTTP %d: %s", resp.StatusCode, string(respData)) } var result OpenAIResponse if err := json.Unmarshal(respData, &result); err != nil { return nil, fmt.Errorf("解析响应失败: %w", err) } if len(result.Choices) == 0 { return nil, fmt.Errorf("LLM 返回空结果") } if result.Usage.TotalTokens > 0 { infoLog("LLM 调用完成", "tokens", result.Usage.TotalTokens, "duration", time.Since(start).Round(time.Millisecond*100).String(), ) } return &result, nil } // ============================================================ // 流式输出 (SSE) // ============================================================ type accumulatedToolCall struct { ID string Type string Name string Args string } type sseChunk struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` Choices []sseChoice `json:"choices"` Usage *OpenAIUsage `json:"usage,omitempty"` } type sseChoice struct { Index int `json:"index"` Delta sseDelta `json:"delta"` FinishReason *string `json:"finish_reason"` } type sseDelta struct { Role string `json:"role,omitempty"` Content string `json:"content,omitempty"` ToolCalls []sseToolCallDelta `json:"tool_calls,omitempty"` } type sseToolCallDelta struct { Index int `json:"index"` ID string `json:"id,omitempty"` Type string `json:"type,omitempty"` Function *sseToolCallFunctionDelta `json:"function,omitempty"` } type sseToolCallFunctionDelta struct { Name string `json:"name,omitempty"` Arguments string `json:"arguments,omitempty"` } // CallLLMStream 流式调用 LLM,按 \n\n 段落边界缓冲后通过 mdprint 渲染到 stdout func CallLLMStream(messages []Message, toolDefs []ToolDef) (*OpenAIResponse, error) { loadLLMConfig() apiKey, err := GetLLMKey() if err != nil { return nil, err } start := time.Now() reqBody := map[string]any{ "model": llmModel, "messages": messages, "stream": true, } if len(toolDefs) > 0 { tools := make([]OpenAITool, 0, len(toolDefs)) for _, td := range toolDefs { tools = append(tools, OpenAITool{ Type: "function", Function: OpenAIToolFunc{ Name: td.Name, Description: td.Description, Parameters: td.Parameters, }, }) } reqBody["tools"] = tools reqBody["tool_choice"] = "auto" } body, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("序列化请求失败: %w", err) } req, err := http.NewRequest("POST", llmHost, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) client := &http.Client{Timeout: 120 * time.Second} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("请求 LLM 失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != 200 { respData, _ := io.ReadAll(resp.Body) var errResp OpenAIErrorResponse if json.Unmarshal(respData, &errResp) == nil && errResp.Error.Message != "" { return nil, fmt.Errorf("LLM API 错误 [%s]: %s", errResp.Error.Type, errResp.Error.Message) } return nil, fmt.Errorf("LLM API 返回 HTTP %d: %s", resp.StatusCode, string(respData)) } reader := bufio.NewReader(resp.Body) var fullContent strings.Builder var blockBuf strings.Builder toolCallAccums := make(map[int]*accumulatedToolCall) var responseID, responseModel string var responseCreated int64 var usage *OpenAIUsage for { line, err := reader.ReadString('\n') if err != nil { if err == io.EOF { break } return nil, fmt.Errorf("读取流响应失败: %w", err) } line = strings.TrimSpace(line) if line == "" { continue } if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chunk sseChunk if err := json.Unmarshal([]byte(data), &chunk); err != nil { continue } if responseID == "" && chunk.ID != "" { responseID = chunk.ID } if responseModel == "" && chunk.Model != "" { responseModel = chunk.Model } if responseCreated == 0 && chunk.Created != 0 { responseCreated = chunk.Created } if chunk.Usage != nil { usage = chunk.Usage } for _, choice := range chunk.Choices { delta := choice.Delta if delta.Content != "" { fullContent.WriteString(delta.Content) blockBuf.WriteString(delta.Content) tryFlushBlocks(&blockBuf) } for _, tc := range delta.ToolCalls { acc, ok := toolCallAccums[tc.Index] if !ok { acc = &accumulatedToolCall{} toolCallAccums[tc.Index] = acc } if tc.ID != "" { acc.ID = tc.ID } if tc.Type != "" { acc.Type = tc.Type } if tc.Function != nil { if tc.Function.Name != "" { acc.Name = tc.Function.Name } acc.Args += tc.Function.Arguments } } } } // 流结束,刷残段 if blockBuf.Len() > 0 { mdprint.Print(blockBuf.String()) } // 重建响应 var choice OpenAIChoice if len(toolCallAccums) > 0 { var tcs []ToolCall for i := 0; i < len(toolCallAccums); i++ { acc := toolCallAccums[i] if acc == nil { continue } tcs = append(tcs, ToolCall{ ID: acc.ID, Type: acc.Type, Function: ToolCallFunction{ Name: acc.Name, Arguments: acc.Args, }, }) } choice.Message.ToolCalls = tcs } else { content := fullContent.String() if content != "" { choice.Message.Content = &content } } result := &OpenAIResponse{ ID: responseID, Object: "chat.completion", Created: responseCreated, Model: responseModel, Choices: []OpenAIChoice{choice}, } if usage != nil && usage.TotalTokens > 0 { result.Usage = *usage infoLog("LLM 调用完成", "tokens", usage.TotalTokens, "duration", time.Since(start).Round(time.Millisecond*100).String(), ) } return result, nil } // tryFlushBlocks 检测 blockBuf 中是否有完整的 Markdown block(以 \n\n 为界) // 有则通过 mdprint 渲染到 stdout,剩余残段留在 buf 中继续缓冲 func tryFlushBlocks(buf *strings.Builder) { content := buf.String() idx := strings.LastIndex(content, "\n\n") if idx < 0 { return } complete := strings.TrimRight(content[:idx], "\n\r\t ") if complete == "" { return } mdprint.Print(complete) remainder := content[idx+2:] buf.Reset() if remainder != "" { buf.WriteString(remainder) } }