feat: implement core architecture (v0.0.2)
- Implement Config class with YAML loading and environment variable support - Implement Provider interface and factory pattern - Implement SiliconFlow provider as example - Implement Translator core class with prompt management - Create CLI entry point - Add configuration template and unit tests - Update changelog and discussion records Version: 0.0.2
This commit is contained in:
99
internal/provider/factory.go
Normal file
99
internal/provider/factory.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ProviderFactory 厂商工厂
|
||||
type ProviderFactory struct {
|
||||
providers map[string]func(ProviderConfig) (Provider, error)
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProviderFactory 创建工厂实例
|
||||
func NewProviderFactory() *ProviderFactory {
|
||||
factory := &ProviderFactory{
|
||||
providers: make(map[string]func(ProviderConfig) (Provider, error)),
|
||||
}
|
||||
|
||||
// 注册所有厂商(延迟注册,避免循环依赖)
|
||||
// 实际注册在init()函数中完成
|
||||
|
||||
return factory
|
||||
}
|
||||
|
||||
// Register 注册厂商构造函数
|
||||
func (f *ProviderFactory) Register(name string, creator func(ProviderConfig) (Provider, error)) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.providers[name] = creator
|
||||
}
|
||||
|
||||
// Create 创建厂商实例
|
||||
func (f *ProviderFactory) Create(name string, config ProviderConfig) (Provider, error) {
|
||||
f.mu.RLock()
|
||||
creator, exists := f.providers[name]
|
||||
f.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("不支持的厂商: %s", name)
|
||||
}
|
||||
|
||||
provider, err := creator(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建厂商实例失败: %w", err)
|
||||
}
|
||||
|
||||
if err := provider.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("厂商配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// HasProvider 检查是否支持指定厂商
|
||||
func (f *ProviderFactory) HasProvider(name string) bool {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
_, exists := f.providers[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetSupportedProviders 获取所有支持的厂商
|
||||
func (f *ProviderFactory) GetSupportedProviders() []string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
providers := make([]string, 0, len(f.providers))
|
||||
for name := range f.providers {
|
||||
providers = append(providers, name)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// 全局工厂实例
|
||||
var defaultFactory *ProviderFactory
|
||||
|
||||
// init 初始化默认工厂
|
||||
func init() {
|
||||
defaultFactory = NewProviderFactory()
|
||||
|
||||
// 这里可以注册厂商,但由于循环依赖,实际注册会在各厂商的init()中完成
|
||||
// 或者在main包中手动注册
|
||||
}
|
||||
|
||||
// GetDefaultFactory 获取默认工厂
|
||||
func GetDefaultFactory() *ProviderFactory {
|
||||
return defaultFactory
|
||||
}
|
||||
|
||||
// CreateProvider 使用默认工厂创建厂商
|
||||
func CreateProvider(name string, config ProviderConfig) (Provider, error) {
|
||||
return defaultFactory.Create(name, config)
|
||||
}
|
||||
|
||||
// RegisterProvider 注册厂商到默认工厂
|
||||
func RegisterProvider(name string, creator func(ProviderConfig) (Provider, error)) {
|
||||
defaultFactory.Register(name, creator)
|
||||
}
|
||||
108
internal/provider/provider.go
Normal file
108
internal/provider/provider.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Provider 厂商接口
|
||||
type Provider interface {
|
||||
// Translate 调用厂商API进行翻译
|
||||
Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error)
|
||||
|
||||
// Name 返回厂商名称
|
||||
Name() string
|
||||
|
||||
// Validate 验证配置是否有效
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// TranslateRequest 翻译请求
|
||||
type TranslateRequest struct {
|
||||
Text string `json:"text"`
|
||||
FromLang string `json:"from_lang"`
|
||||
ToLang string `json:"to_lang"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
// TranslateResponse 翻译响应
|
||||
type TranslateResponse struct {
|
||||
Text string `json:"text"`
|
||||
FromLang string `json:"from_lang"`
|
||||
ToLang string `json:"to_lang"`
|
||||
Model string `json:"model"`
|
||||
Usage *Usage `json:"usage"`
|
||||
RawResponse []byte `json:"raw_response,omitempty"`
|
||||
}
|
||||
|
||||
// Usage 用量统计
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// ProviderConfig 厂商配置
|
||||
type ProviderConfig struct {
|
||||
APIHost string `json:"api_host"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// TranslateError 翻译错误
|
||||
type TranslateError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (e *TranslateError) Error() string {
|
||||
if e.Details != "" {
|
||||
return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Details)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// NewTranslateError 创建翻译错误
|
||||
func NewTranslateError(code, message string) *TranslateError {
|
||||
return &TranslateError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTranslateErrorWithDetails 创建带详情的翻译错误
|
||||
func NewTranslateErrorWithDetails(code, message, details string) *TranslateError {
|
||||
return &TranslateError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
}
|
||||
|
||||
// IsNetworkError 检查是否为网络错误
|
||||
func IsNetworkError(err error) bool {
|
||||
if translateErr, ok := err.(*TranslateError); ok {
|
||||
return translateErr.Code == "NETWORK_ERROR"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsRateLimitError 检查是否为限流错误
|
||||
func IsRateLimitError(err error) bool {
|
||||
if translateErr, ok := err.(*TranslateError); ok {
|
||||
return translateErr.Code == "RATE_LIMIT"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAuthError 检查是否为认证错误
|
||||
func IsAuthError(err error) bool {
|
||||
if translateErr, ok := err.(*TranslateError); ok {
|
||||
return translateErr.Code == "AUTH_ERROR" || translateErr.Code == "INVALID_API_KEY"
|
||||
}
|
||||
return false
|
||||
}
|
||||
185
internal/provider/siliconflow.go
Normal file
185
internal/provider/siliconflow.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SiliconFlowProvider 硅基流动厂商实现
|
||||
type SiliconFlowProvider struct {
|
||||
config ProviderConfig
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewSiliconFlowProvider 创建硅基流动厂商实例
|
||||
func NewSiliconFlowProvider(config ProviderConfig) (Provider, error) {
|
||||
return &SiliconFlowProvider{
|
||||
config: config,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name 返回厂商名称
|
||||
func (p *SiliconFlowProvider) Name() string {
|
||||
return "siliconflow"
|
||||
}
|
||||
|
||||
// Validate 验证配置
|
||||
func (p *SiliconFlowProvider) Validate() error {
|
||||
if p.config.APIKey == "" {
|
||||
return fmt.Errorf("siliconflow: API key 不能为空")
|
||||
}
|
||||
if p.config.APIHost == "" {
|
||||
p.config.APIHost = "https://api.siliconflow.cn/v1"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Translate 调用硅基流动API
|
||||
func (p *SiliconFlowProvider) Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error) {
|
||||
// 构建请求体
|
||||
requestBody := map[string]interface{}{
|
||||
"model": p.config.Model,
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "user",
|
||||
"content": req.Text,
|
||||
},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
// 如果有Prompt,添加到系统消息
|
||||
if req.Prompt != "" {
|
||||
messages := requestBody["messages"].([]map[string]string)
|
||||
requestBody["messages"] = append([]map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": req.Prompt,
|
||||
},
|
||||
}, messages...)
|
||||
}
|
||||
|
||||
// 序列化请求体
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("SERIALIZATION_ERROR", "请求序列化失败", err.Error())
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", p.config.APIHost)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("REQUEST_ERROR", "创建请求失败", err.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.config.APIKey))
|
||||
|
||||
// 发送请求
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("NETWORK_ERROR", "请求失败", err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("RESPONSE_ERROR", "读取响应失败", err.Error())
|
||||
}
|
||||
|
||||
// 检查HTTP状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, NewTranslateErrorWithDetails(
|
||||
"HTTP_ERROR",
|
||||
fmt.Sprintf("HTTP错误: %d", resp.StatusCode),
|
||||
string(body),
|
||||
)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var apiResp SiliconFlowResponse
|
||||
if err := json.Unmarshal(body, &apiResp); err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("PARSE_ERROR", "解析响应失败", err.Error())
|
||||
}
|
||||
|
||||
// 检查API错误
|
||||
if apiResp.Error != nil {
|
||||
return nil, NewTranslateErrorWithDetails(
|
||||
"API_ERROR",
|
||||
apiResp.Error.Message,
|
||||
apiResp.Error.Code,
|
||||
)
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, NewTranslateError("NO_RESPONSE", "API返回空响应")
|
||||
}
|
||||
|
||||
translatedText := apiResp.Choices[0].Message.Content
|
||||
|
||||
return &TranslateResponse{
|
||||
Text: translatedText,
|
||||
FromLang: req.FromLang,
|
||||
ToLang: req.ToLang,
|
||||
Model: apiResp.Model,
|
||||
Usage: &Usage{
|
||||
PromptTokens: apiResp.Usage.PromptTokens,
|
||||
CompletionTokens: apiResp.Usage.CompletionTokens,
|
||||
TotalTokens: apiResp.Usage.TotalTokens,
|
||||
},
|
||||
RawResponse: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SiliconFlowResponse 硅基流动API响应
|
||||
type SiliconFlowResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []SiliconFlowChoice `json:"choices"`
|
||||
Usage SiliconFlowUsage `json:"usage"`
|
||||
Error *SiliconFlowError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// SiliconFlowChoice 选择项
|
||||
type SiliconFlowChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message SiliconFlowMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// SiliconFlowMessage 消息
|
||||
type SiliconFlowMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// SiliconFlowUsage 用量
|
||||
type SiliconFlowUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// SiliconFlowError 错误
|
||||
type SiliconFlowError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// init 注册硅基流动厂商
|
||||
func init() {
|
||||
RegisterProvider("siliconflow", NewSiliconFlowProvider)
|
||||
}
|
||||
Reference in New Issue
Block a user