Files
yoyo/internal/provider/siliconflow.go

186 lines
4.6 KiB
Go
Raw Permalink Normal View History

package provider
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
// SiliconFlowProvider is the Provider implementation backed by the
// SiliconFlow HTTP API.
type SiliconFlowProvider struct {
	// config carries the credentials, API host and model name; Validate may
	// fill in a default APIHost.
	config ProviderConfig
	// client is the HTTP client used for every outgoing API call.
	client *http.Client
}
// NewSiliconFlowProvider builds a SiliconFlow provider around config.
//
// The provider gets its own HTTP client with a 30-second overall request
// timeout. The configuration is not checked here (see Validate), so the
// returned error is always nil; the error result presumably exists to match
// the factory signature expected by RegisterProvider.
func NewSiliconFlowProvider(config ProviderConfig) (Provider, error) {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	provider := &SiliconFlowProvider{
		client: httpClient,
		config: config,
	}
	return provider, nil
}
// Name reports the provider identifier, "siliconflow".
func (p *SiliconFlowProvider) Name() string {
	const providerName = "siliconflow"
	return providerName
}
// Validate checks the provider configuration and fills in defaults.
//
// It returns an error when APIKey is empty (in that case nothing else is
// touched). Otherwise, if APIHost is empty it is set to
// "https://api.siliconflow.cn/v1" — note that this mutates the provider's
// stored configuration.
func (p *SiliconFlowProvider) Validate() error {
	switch {
	case p.config.APIKey == "":
		return fmt.Errorf("siliconflow: API key 不能为空")
	case p.config.APIHost == "":
		p.config.APIHost = "https://api.siliconflow.cn/v1"
	}
	return nil
}
// siliconFlowDefaultAPIHost is the base URL used when APIHost is empty; it is
// the same value Validate fills in.
const siliconFlowDefaultAPIHost = "https://api.siliconflow.cn/v1"

// Translate sends req to the SiliconFlow chat completions endpoint
// (<APIHost>/chat/completions) as a single non-streaming request and returns
// the content of the first returned choice.
//
// req.Text is sent as the user message; a non-empty req.Prompt is sent first
// as a system message. req.FromLang and req.ToLang are not sent to the API,
// they are only echoed back in the response. The raw response body is kept
// in RawResponse.
//
// Failures are returned as translate errors with one of these codes:
// INVALID_REQUEST (nil req), SERIALIZATION_ERROR, REQUEST_ERROR,
// NETWORK_ERROR, RESPONSE_ERROR, HTTP_ERROR (non-200 status; the body is
// attached as details), PARSE_ERROR, API_ERROR (an "error" object in a 200
// body) and NO_RESPONSE (no choices returned).
func (p *SiliconFlowProvider) Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error) {
	if req == nil {
		return nil, NewTranslateError("INVALID_REQUEST", "翻译请求不能为空")
	}

	// The optional system prompt must precede the user message.
	messages := make([]SiliconFlowMessage, 0, 2)
	if req.Prompt != "" {
		messages = append(messages, SiliconFlowMessage{Role: "system", Content: req.Prompt})
	}
	messages = append(messages, SiliconFlowMessage{Role: "user", Content: req.Text})

	payload := struct {
		Model    string               `json:"model"`
		Messages []SiliconFlowMessage `json:"messages"`
		Stream   bool                 `json:"stream"`
	}{
		Model:    p.config.Model,
		Messages: messages,
		Stream:   false,
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, NewTranslateErrorWithDetails("SERIALIZATION_ERROR", "请求序列化失败", err.Error())
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.chatCompletionsURL(), bytes.NewReader(jsonData))
	if err != nil {
		return nil, NewTranslateErrorWithDetails("REQUEST_ERROR", "创建请求失败", err.Error())
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)

	resp, err := p.client.Do(httpReq)
	if err != nil {
		return nil, NewTranslateErrorWithDetails("NETWORK_ERROR", "请求失败", err.Error())
	}
	defer resp.Body.Close()

	// Read the whole body up front: it is reported verbatim on HTTP errors
	// and returned as RawResponse on success.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, NewTranslateErrorWithDetails("RESPONSE_ERROR", "读取响应失败", err.Error())
	}

	if resp.StatusCode != http.StatusOK {
		return nil, NewTranslateErrorWithDetails(
			"HTTP_ERROR",
			fmt.Sprintf("HTTP错误: %d", resp.StatusCode),
			string(body),
		)
	}

	var apiResp SiliconFlowResponse
	if err := json.Unmarshal(body, &apiResp); err != nil {
		return nil, NewTranslateErrorWithDetails("PARSE_ERROR", "解析响应失败", err.Error())
	}

	// A 200 response may still carry an error object.
	if apiResp.Error != nil {
		return nil, NewTranslateErrorWithDetails(
			"API_ERROR",
			apiResp.Error.Message,
			apiResp.Error.Code,
		)
	}

	if len(apiResp.Choices) == 0 {
		return nil, NewTranslateError("NO_RESPONSE", "API返回空响应")
	}

	// If the response omits "model", report the model that was requested.
	model := apiResp.Model
	if model == "" {
		model = p.config.Model
	}

	return &TranslateResponse{
		Text:     apiResp.Choices[0].Message.Content,
		FromLang: req.FromLang,
		ToLang:   req.ToLang,
		Model:    model,
		Usage: &Usage{
			PromptTokens:     apiResp.Usage.PromptTokens,
			CompletionTokens: apiResp.Usage.CompletionTokens,
			TotalTokens:      apiResp.Usage.TotalTokens,
		},
		RawResponse: body,
	}, nil
}

// chatCompletionsURL returns the chat completions endpoint for the configured
// APIHost. Trailing slashes on APIHost are dropped so the path is not doubled,
// and an empty APIHost (e.g. Validate was never called) falls back to
// siliconFlowDefaultAPIHost.
func (p *SiliconFlowProvider) chatCompletionsURL() string {
	host := strings.TrimRight(p.config.APIHost, "/")
	if host == "" {
		host = siliconFlowDefaultAPIHost
	}
	return host + "/chat/completions"
}
// Wire types used to decode SiliconFlow chat completion responses.
type (
	// SiliconFlowResponse is the top-level JSON body returned by the chat
	// completions endpoint. Error is non-nil only when the body contains an
	// "error" object; it is omitted when marshaling if nil.
	SiliconFlowResponse struct {
		ID      string              `json:"id"`
		Object  string              `json:"object"`
		Created int64               `json:"created"`
		Model   string              `json:"model"`
		Choices []SiliconFlowChoice `json:"choices"`
		Usage   SiliconFlowUsage    `json:"usage"`
		Error   *SiliconFlowError   `json:"error,omitempty"`
	}

	// SiliconFlowChoice is one generated candidate in a response.
	SiliconFlowChoice struct {
		Index        int                `json:"index"`
		Message      SiliconFlowMessage `json:"message"`
		FinishReason string             `json:"finish_reason"`
	}

	// SiliconFlowMessage is a chat message: a role plus its text content.
	SiliconFlowMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// SiliconFlowUsage holds the token counts reported for a request.
	SiliconFlowUsage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}

	// SiliconFlowError is the error object the API may embed in a response.
	// NOTE(review): Code is decoded as a string; a numeric code in the JSON
	// would make decoding fail — confirm against the API's error format.
	SiliconFlowError struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	}
)
// init registers NewSiliconFlowProvider via RegisterProvider under the key
// "siliconflow" when the package is loaded.
func init() {
	const key = "siliconflow"
	RegisterProvider(key, NewSiliconFlowProvider)
}