186 lines
4.6 KiB
Go
186 lines
4.6 KiB
Go
|
|
package provider
|
|||
|
|
|
|||
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
|
|||
|
|
|
|||
|
|
// SiliconFlowProvider 硅基流动厂商实现
|
|||
|
|
type SiliconFlowProvider struct {
|
|||
|
|
config ProviderConfig
|
|||
|
|
client *http.Client
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewSiliconFlowProvider 创建硅基流动厂商实例
|
|||
|
|
func NewSiliconFlowProvider(config ProviderConfig) (Provider, error) {
|
|||
|
|
return &SiliconFlowProvider{
|
|||
|
|
config: config,
|
|||
|
|
client: &http.Client{
|
|||
|
|
Timeout: 30 * time.Second,
|
|||
|
|
},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Name 返回厂商名称
|
|||
|
|
func (p *SiliconFlowProvider) Name() string {
|
|||
|
|
return "siliconflow"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Validate checks the provider configuration.
//
// It returns an error when no API key is configured. Otherwise, if the API
// host is empty, it writes the default host https://api.siliconflow.cn/v1 into
// the stored config (a side effect) and returns nil.
func (p *SiliconFlowProvider) Validate() error {
	switch {
	case p.config.APIKey == "":
		return fmt.Errorf("siliconflow: API key 不能为空")
	case p.config.APIHost == "":
		p.config.APIHost = "https://api.siliconflow.cn/v1"
	}
	return nil
}
|
|||
|
|
|
|||
|
|
// Translate sends req to the SiliconFlow chat-completions endpoint and returns
// the content of the first choice as the translated text.
//
// The request holds an optional system message (req.Prompt, when non-empty)
// followed by a user message carrying req.Text. req.FromLang and req.ToLang are
// NOT sent to the API; they are only copied onto the response. Any translation
// instruction must therefore come from req.Prompt.
//
// Failures are reported as translate errors with these codes:
//   - REQUEST_ERROR: req is nil, or the HTTP request could not be built
//   - SERIALIZATION_ERROR: the request body could not be encoded
//   - NETWORK_ERROR: the HTTP round trip failed (e.g. ctx cancelled, client timeout)
//   - RESPONSE_ERROR: the response body could not be read
//   - HTTP_ERROR: status other than 200; details carry the raw response body
//   - PARSE_ERROR: the body does not decode into SiliconFlowResponse
//   - API_ERROR: the decoded body carries an "error" object
//   - NO_RESPONSE: the decoded body has no choices
func (p *SiliconFlowProvider) Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error) {
	if req == nil {
		return nil, NewTranslateError("REQUEST_ERROR", "翻译请求不能为空")
	}

	// Typed request payload with the same JSON fields the API expects:
	// "model", "messages" and "stream".
	type chatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}
	type chatRequest struct {
		Model    string        `json:"model"`
		Messages []chatMessage `json:"messages"`
		Stream   bool          `json:"stream"`
	}

	messages := make([]chatMessage, 0, 2)
	if req.Prompt != "" {
		// When present, the prompt precedes the user text as a system message.
		messages = append(messages, chatMessage{Role: "system", Content: req.Prompt})
	}
	messages = append(messages, chatMessage{Role: "user", Content: req.Text})

	jsonData, err := json.Marshal(chatRequest{
		Model:    p.config.Model,
		Messages: messages,
		Stream:   false, // request a single, non-streamed response body
	})
	if err != nil {
		return nil, NewTranslateErrorWithDetails("SERIALIZATION_ERROR", "请求序列化失败", err.Error())
	}

	endpoint := siliconFlowChatURL(p.config.APIHost)
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, NewTranslateErrorWithDetails("REQUEST_ERROR", "创建请求失败", err.Error())
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.config.APIKey))

	resp, err := p.client.Do(httpReq)
	if err != nil {
		return nil, NewTranslateErrorWithDetails("NETWORK_ERROR", "请求失败", err.Error())
	}
	defer resp.Body.Close()

	// The body is read in full before the status check so that error
	// responses can be surfaced verbatim in the error details.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, NewTranslateErrorWithDetails("RESPONSE_ERROR", "读取响应失败", err.Error())
	}

	if resp.StatusCode != http.StatusOK {
		return nil, NewTranslateErrorWithDetails(
			"HTTP_ERROR",
			fmt.Sprintf("HTTP错误: %d", resp.StatusCode),
			string(body),
		)
	}

	var apiResp SiliconFlowResponse
	if err := json.Unmarshal(body, &apiResp); err != nil {
		return nil, NewTranslateErrorWithDetails("PARSE_ERROR", "解析响应失败", err.Error())
	}

	if apiResp.Error != nil {
		return nil, NewTranslateErrorWithDetails(
			"API_ERROR",
			apiResp.Error.Message,
			apiResp.Error.Code,
		)
	}

	if len(apiResp.Choices) == 0 {
		return nil, NewTranslateError("NO_RESPONSE", "API返回空响应")
	}

	return &TranslateResponse{
		Text:     apiResp.Choices[0].Message.Content,
		FromLang: req.FromLang,
		ToLang:   req.ToLang,
		Model:    apiResp.Model,
		Usage: &Usage{
			PromptTokens:     apiResp.Usage.PromptTokens,
			CompletionTokens: apiResp.Usage.CompletionTokens,
			TotalTokens:      apiResp.Usage.TotalTokens,
		},
		RawResponse: body,
	}, nil
}

// siliconFlowChatURL returns the chat-completions endpoint for apiHost.
//
// Surrounding whitespace and trailing slashes are stripped, so a host such as
// "https://api.siliconflow.cn/v1/" does not produce "//chat/completions". An
// empty host falls back to https://api.siliconflow.cn/v1, which is the same
// default Validate assigns. This lets Translate work even when Validate was
// never called.
func siliconFlowChatURL(apiHost string) string {
	host := strings.TrimRight(strings.TrimSpace(apiHost), "/")
	if host == "" {
		host = "https://api.siliconflow.cn/v1"
	}
	return host + "/chat/completions"
}
|
|||
|
|
|
|||
|
|
// SiliconFlowResponse 硅基流动API响应
|
|||
|
|
type SiliconFlowResponse struct {
|
|||
|
|
ID string `json:"id"`
|
|||
|
|
Object string `json:"object"`
|
|||
|
|
Created int64 `json:"created"`
|
|||
|
|
Model string `json:"model"`
|
|||
|
|
Choices []SiliconFlowChoice `json:"choices"`
|
|||
|
|
Usage SiliconFlowUsage `json:"usage"`
|
|||
|
|
Error *SiliconFlowError `json:"error,omitempty"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SiliconFlowChoice 选择项
|
|||
|
|
type SiliconFlowChoice struct {
|
|||
|
|
Index int `json:"index"`
|
|||
|
|
Message SiliconFlowMessage `json:"message"`
|
|||
|
|
FinishReason string `json:"finish_reason"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SiliconFlowMessage is a single chat message: a role plus its text content.
type SiliconFlowMessage struct {
	// Role identifies the speaker of the message (e.g. "user", "assistant").
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
|
|||
|
|
|
|||
|
|
// SiliconFlowUsage is the "usage" object of a chat-completions response,
// holding the token counts the API reports.
type SiliconFlowUsage struct {
	// PromptTokens is the token count attributed to the input messages.
	PromptTokens int `json:"prompt_tokens"`
	// CompletionTokens is the token count attributed to the generated output.
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is the total token count reported by the API.
	TotalTokens int `json:"total_tokens"`
}
|
|||
|
|
|
|||
|
|
// SiliconFlowError is the "error" object a SiliconFlow response may carry.
type SiliconFlowError struct {
	// Code is the API error code, decoded as a JSON string.
	Code string `json:"code"`
	// Message is the human-readable error description.
	Message string `json:"message"`
}
|
|||
|
|
|
|||
|
|
// init 注册硅基流动厂商
|
|||
|
|
func init() {
|
|||
|
|
RegisterProvider("siliconflow", NewSiliconFlowProvider)
|
|||
|
|
}
|