package config

import (
	"os"
	"path/filepath"
	"testing"
)

// TestConfigLoader_Load writes a YAML config into a temporary directory, loads
// it with YAMLConfigLoader, and checks the top-level fields, that the
// "${TEST_API_KEY}" placeholder in the provider's api_key is replaced by the
// environment variable's value, and that the "simple" prompt can be looked up.
func TestConfigLoader_Load(t *testing.T) {
	// Create a temporary config file; t.TempDir removes the directory when the
	// test finishes.
	tmpDir := t.TempDir()
	configFile := filepath.Join(tmpDir, "config.yaml")
	// YAML is indentation-sensitive: keep the document flush-left and indent
	// with spaces only (gofmt would otherwise tempt tab indentation, which YAML
	// rejects).
	configContent := `
default_provider: "siliconflow"
default_model: "gpt-3.5-turbo"
timeout: 30
providers:
  siliconflow:
    api_host: "https://api.siliconflow.cn/v1"
    api_key: "${TEST_API_KEY}"
    model: "siliconflow-base"
    enabled: true
prompts:
  simple: "请用简单易懂的语言翻译以下内容。"
`
	if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
		t.Fatalf("创建配置文件失败: %v", err)
	}

	// Provide the value referenced by ${TEST_API_KEY}. Unlike os.Setenv plus a
	// deferred os.Unsetenv, t.Setenv fails the test if the variable cannot be
	// set and restores any pre-existing value at cleanup instead of erasing it.
	t.Setenv("TEST_API_KEY", "test-key-123")

	// Load the config.
	loader := &YAMLConfigLoader{}
	config, err := loader.Load(configFile)
	if err != nil {
		t.Fatalf("加载配置失败: %v", err)
	}

	// Check top-level settings.
	if config.DefaultProvider != "siliconflow" {
		t.Errorf("Expected default provider 'siliconflow', got '%s'", config.DefaultProvider)
	}
	if config.DefaultModel != "gpt-3.5-turbo" {
		t.Errorf("Expected default model 'gpt-3.5-turbo', got '%s'", config.DefaultModel)
	}
	if config.Timeout != 30 {
		t.Errorf("Expected timeout 30, got %d", config.Timeout)
	}

	// Check the provider entry, including the env-var substitution in api_key.
	providerConfig, err := config.GetProviderConfig("siliconflow")
	if err != nil {
		t.Fatalf("获取厂商配置失败: %v", err)
	}
	if providerConfig.APIKey != "test-key-123" {
		t.Errorf("Expected API key 'test-key-123', got '%s'", providerConfig.APIKey)
	}
	if providerConfig.APIHost != "https://api.siliconflow.cn/v1" {
		t.Errorf("Expected API host 'https://api.siliconflow.cn/v1', got '%s'", providerConfig.APIHost)
	}

	// Check the prompt. Stop here if it is missing: comparing the zero value
	// would only add a second, misleading failure.
	prompt, exists := config.GetPrompt("simple")
	if !exists {
		t.Fatal("Prompt 'simple' should exist")
	}
	if prompt != "请用简单易懂的语言翻译以下内容。" {
		t.Errorf("Unexpected prompt content: %s", prompt)
	}
}

// TestConfig_Validate checks that a fully populated Config with one enabled
// provider and one prompt passes Validate without error.
func TestConfig_Validate(t *testing.T) {
	config := &Config{
		DefaultProvider: "siliconflow",
		DefaultModel:    "gpt-3.5-turbo",
		Timeout:         30,
		Providers: map[string]ProviderConfig{
			"siliconflow": {
				APIHost: "https://api.siliconflow.cn/v1",
				APIKey:  "test-key",
				Model:   "siliconflow-base",
				Enabled: true,
			},
		},
		Prompts: map[string]string{
			"simple": "test prompt",
		},
	}

	if err := config.Validate(); err != nil {
		t.Errorf("Config validation failed: %v", err)
	}
}

// TestConfig_GetProviderConfig checks that GetProviderConfig succeeds for an
// enabled provider and returns an error for a disabled or unknown provider.
func TestConfig_GetProviderConfig(t *testing.T) {
	config := &Config{
		Providers: map[string]ProviderConfig{
			"siliconflow": {
				APIHost: "https://api.siliconflow.cn/v1",
				APIKey:  "test-key",
				Model:   "siliconflow-base",
				Enabled: true,
			},
			"volcano": {
				APIHost: "https://api.volcengine.com/v1",
				APIKey:  "test-key",
				Model:   "volcano-chat",
				Enabled: false,
			},
		},
	}

	tests := []struct {
		name     string
		provider string
		wantErr  bool
	}{
		{name: "enabled", provider: "siliconflow", wantErr: false},
		{name: "disabled", provider: "volcano", wantErr: true},
		{name: "non-existent", provider: "nonexistent", wantErr: true},
	}
	for _, tt := range tests {
		// Subtests run sequentially here, so capturing tt is safe on every Go
		// version.
		t.Run(tt.name, func(t *testing.T) {
			_, err := config.GetProviderConfig(tt.provider)
			switch {
			case tt.wantErr && err == nil:
				t.Errorf("GetProviderConfig(%q): should return error for %s provider", tt.provider, tt.name)
			case !tt.wantErr && err != nil:
				t.Errorf("GetProviderConfig(%q): should get %s provider: %v", tt.provider, tt.name, err)
			}
		})
	}
}

// TestConfig_GetEnabledProviders checks that only providers with Enabled set
// are reported (two of the three configured here).
func TestConfig_GetEnabledProviders(t *testing.T) {
	config := &Config{
		Providers: map[string]ProviderConfig{
			"siliconflow": {Enabled: true},
			"volcano":     {Enabled: false},
			"qwen":        {Enabled: true},
		},
	}

	enabled := config.GetEnabledProviders()
	if len(enabled) != 2 {
		t.Errorf("Expected 2 enabled providers, got %d: %v", len(enabled), enabled)
	}
}

// TestConfig_String checks that String produces a non-empty representation of
// a populated Config. It does not pin the exact format.
func TestConfig_String(t *testing.T) {
	config := &Config{
		DefaultProvider: "siliconflow",
		DefaultModel:    "gpt-3.5-turbo",
		Timeout:         30,
		Providers: map[string]ProviderConfig{
			"siliconflow": {
				Enabled: true,
				Model:   "siliconflow-base",
			},
		},
		Prompts: map[string]string{
			"simple": "test prompt",
		},
	}

	if str := config.String(); str == "" {
		t.Error("String representation should not be empty")
	}
}