178 lines
4.1 KiB
Go
178 lines
4.1 KiB
Go
|
|
package config
|
||
|
|
|
||
|
|
import (
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestConfigLoader_Load(t *testing.T) {
|
||
|
|
// 创建临时配置文件
|
||
|
|
tmpDir := t.TempDir()
|
||
|
|
configFile := filepath.Join(tmpDir, "config.yaml")
|
||
|
|
|
||
|
|
configContent := `
|
||
|
|
default_provider: "siliconflow"
|
||
|
|
default_model: "gpt-3.5-turbo"
|
||
|
|
timeout: 30
|
||
|
|
|
||
|
|
providers:
|
||
|
|
siliconflow:
|
||
|
|
api_host: "https://api.siliconflow.cn/v1"
|
||
|
|
api_key: "${TEST_API_KEY}"
|
||
|
|
model: "siliconflow-base"
|
||
|
|
enabled: true
|
||
|
|
|
||
|
|
prompts:
|
||
|
|
simple: "请用简单易懂的语言翻译以下内容。"
|
||
|
|
`
|
||
|
|
|
||
|
|
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||
|
|
t.Fatalf("创建配置文件失败: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 设置环境变量
|
||
|
|
os.Setenv("TEST_API_KEY", "test-key-123")
|
||
|
|
defer os.Unsetenv("TEST_API_KEY")
|
||
|
|
|
||
|
|
// 加载配置
|
||
|
|
loader := &YAMLConfigLoader{}
|
||
|
|
config, err := loader.Load(configFile)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("加载配置失败: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 验证配置
|
||
|
|
if config.DefaultProvider != "siliconflow" {
|
||
|
|
t.Errorf("Expected default provider 'siliconflow', got '%s'", config.DefaultProvider)
|
||
|
|
}
|
||
|
|
|
||
|
|
if config.DefaultModel != "gpt-3.5-turbo" {
|
||
|
|
t.Errorf("Expected default model 'gpt-3.5-turbo', got '%s'", config.DefaultModel)
|
||
|
|
}
|
||
|
|
|
||
|
|
if config.Timeout != 30 {
|
||
|
|
t.Errorf("Expected timeout 30, got %d", config.Timeout)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 验证厂商配置
|
||
|
|
providerConfig, err := config.GetProviderConfig("siliconflow")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("获取厂商配置失败: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if providerConfig.APIKey != "test-key-123" {
|
||
|
|
t.Errorf("Expected API key 'test-key-123', got '%s'", providerConfig.APIKey)
|
||
|
|
}
|
||
|
|
|
||
|
|
if providerConfig.APIHost != "https://api.siliconflow.cn/v1" {
|
||
|
|
t.Errorf("Expected API host 'https://api.siliconflow.cn/v1', got '%s'", providerConfig.APIHost)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 验证Prompt
|
||
|
|
prompt, exists := config.GetPrompt("simple")
|
||
|
|
if !exists {
|
||
|
|
t.Error("Prompt 'simple' should exist")
|
||
|
|
}
|
||
|
|
if prompt != "请用简单易懂的语言翻译以下内容。" {
|
||
|
|
t.Errorf("Unexpected prompt content: %s", prompt)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestConfig_Validate(t *testing.T) {
|
||
|
|
config := &Config{
|
||
|
|
DefaultProvider: "siliconflow",
|
||
|
|
DefaultModel: "gpt-3.5-turbo",
|
||
|
|
Timeout: 30,
|
||
|
|
Providers: map[string]ProviderConfig{
|
||
|
|
"siliconflow": {
|
||
|
|
APIHost: "https://api.siliconflow.cn/v1",
|
||
|
|
APIKey: "test-key",
|
||
|
|
Model: "siliconflow-base",
|
||
|
|
Enabled: true,
|
||
|
|
},
|
||
|
|
},
|
||
|
|
Prompts: map[string]string{
|
||
|
|
"simple": "test prompt",
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := config.Validate(); err != nil {
|
||
|
|
t.Errorf("Config validation failed: %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestConfig_GetProviderConfig(t *testing.T) {
|
||
|
|
config := &Config{
|
||
|
|
Providers: map[string]ProviderConfig{
|
||
|
|
"siliconflow": {
|
||
|
|
APIHost: "https://api.siliconflow.cn/v1",
|
||
|
|
APIKey: "test-key",
|
||
|
|
Model: "siliconflow-base",
|
||
|
|
Enabled: true,
|
||
|
|
},
|
||
|
|
"volcano": {
|
||
|
|
APIHost: "https://api.volcengine.com/v1",
|
||
|
|
APIKey: "test-key",
|
||
|
|
Model: "volcano-chat",
|
||
|
|
Enabled: false,
|
||
|
|
},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
// 测试获取启用的厂商
|
||
|
|
_, err := config.GetProviderConfig("siliconflow")
|
||
|
|
if err != nil {
|
||
|
|
t.Errorf("Should get enabled provider: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 测试获取禁用的厂商
|
||
|
|
_, err = config.GetProviderConfig("volcano")
|
||
|
|
if err == nil {
|
||
|
|
t.Error("Should return error for disabled provider")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 测试获取不存在的厂商
|
||
|
|
_, err = config.GetProviderConfig("nonexistent")
|
||
|
|
if err == nil {
|
||
|
|
t.Error("Should return error for non-existent provider")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestConfig_GetEnabledProviders(t *testing.T) {
|
||
|
|
config := &Config{
|
||
|
|
Providers: map[string]ProviderConfig{
|
||
|
|
"siliconflow": {Enabled: true},
|
||
|
|
"volcano": {Enabled: false},
|
||
|
|
"qwen": {Enabled: true},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
enabled := config.GetEnabledProviders()
|
||
|
|
if len(enabled) != 2 {
|
||
|
|
t.Errorf("Expected 2 enabled providers, got %d", len(enabled))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestConfig_String(t *testing.T) {
|
||
|
|
config := &Config{
|
||
|
|
DefaultProvider: "siliconflow",
|
||
|
|
DefaultModel: "gpt-3.5-turbo",
|
||
|
|
Timeout: 30,
|
||
|
|
Providers: map[string]ProviderConfig{
|
||
|
|
"siliconflow": {
|
||
|
|
Enabled: true,
|
||
|
|
Model: "siliconflow-base",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
Prompts: map[string]string{
|
||
|
|
"simple": "test prompt",
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
str := config.String()
|
||
|
|
if str == "" {
|
||
|
|
t.Error("String representation should not be empty")
|
||
|
|
}
|
||
|
|
}
|