Files
haibao-tts-cli/src/main.rs

448 lines
13 KiB
Rust
Raw Normal View History

mod cli;
mod config;
mod api;
mod ui;
mod tone;
use anyhow::{Context, Result};
use clap::Parser;
use cli::{Cli, Commands, ConfigAction};
use config::ConfigManager;
use rodio;
use std::fs;
use std::io::Read;
use std::process;
/// 程序退出码定义
///
/// 遵循 agents.md 中定义的退出码规范
#[derive(Debug)]
enum ExitCode {
Success = 0,
ArgumentError = 1,
ConfigError = 2,
ApiError = 3,
FileError = 4,
}
impl From<ExitCode> for i32 {
fn from(code: ExitCode) -> Self {
code as i32
}
}
/// 主函数
///
/// 使用 tokio 运行时处理异步 API 调用
#[tokio::main]
async fn main() {
// 解析命令行参数
let cli = Cli::parse();
// 执行程序逻辑,如果出错则处理错误并返回对应退出码
let exit_code = match run(cli).await {
Ok(_) => ExitCode::Success,
Err(e) => {
// 根据错误类型返回对应的退出码
eprintln!("错误: {:#}", e);
// 简化错误处理,根据错误信息判断类型
let error_msg = e.to_string();
if error_msg.contains("API") || error_msg.contains("请求") {
ExitCode::ApiError
} else if error_msg.contains("配置") {
ExitCode::ConfigError
} else if error_msg.contains("文件") || error_msg.contains("读取") || error_msg.contains("写入") {
ExitCode::FileError
} else {
ExitCode::ArgumentError
}
}
};
process::exit(exit_code.into());
}
/// 程序主逻辑
async fn run(cli: Cli) -> Result<()> {
match cli.command {
// 处理子命令
Some(Commands::Voices) => {
list_voices();
Ok(())
}
Some(Commands::ShowConfig) => {
show_config()
}
Some(Commands::Config { action }) => {
handle_config_command(action)
}
Some(Commands::Onboard) => {
// 引导式配置初始化
onboard().await
}
// 没有子命令时,执行语音合成
None => {
// 检查参数组合
if cli.play && cli.output.is_some() {
return Err(anyhow::anyhow!("--play 和 --output 不能同时使用"));
}
// 检查是否有输入text 或 file
if cli.text.is_none() && cli.file.is_none() {
return Err(anyhow::anyhow!(
"必须提供 --text 或 --file 参数\n使用 --help 查看帮助信息"
));
}
// 执行语音合成
let audio_data = synthesize(
cli.text,
cli.file,
&cli.voice,
&cli.format,
cli.style.as_deref(),
cli.stream,
)
.await?;
// 根据参数决定处理方式
if cli.play {
// 播放音频(流式数据需要封装成 WAV 格式)
ui::show_playback_start();
if cli.stream {
// 流式返回的是 PCM16 原始数据,需要添加 WAV 头
let wav_data = pcm16_to_wav(&audio_data);
play_audio(&wav_data)?;
} else {
play_audio(&audio_data)?;
}
ui::show_playback_complete();
} else if let Some(output_path) = cli.output {
// 保存到文件
fs::write(&output_path, &audio_data)
.with_context(|| format!("无法写入文件: {:?}", output_path))?;
ui::show_save_complete(&output_path.to_string_lossy());
} else {
// 输出到 stdout二进制流
let stdout = std::io::stdout();
let mut handle = stdout.lock();
use std::io::Write;
handle.write_all(&audio_data)
.context("无法写入标准输出")?;
handle.flush()
.context("无法刷新标准输出")?;
}
Ok(())
}
}
}
/// 列出所有可用的音色
///
/// 显示详细的音色信息,包括 Voice ID、语言、性别
fn list_voices() {
ui::show_voices();
}
/// 合法的音色列表mimo-v2.5-tts 支持)
const VALID_VOICES: &[&str] = &[
"mimo_default",
"冰糖",
"茉莉",
"苏打",
"白桦",
"Mia",
"Chloe",
"Milo",
"Dean",
];
/// 验证音色是否合法
///
/// 如果音色不在合法列表中,输出警告并使用默认音色 mimo_default
fn validate_voice(voice: &str) -> String {
if VALID_VOICES.contains(&voice) {
voice.to_string()
} else {
eprintln!("警告:无效音色 '{}',使用默认音色 'mimo_default'", voice);
"mimo_default".to_string()
}
}
/// 显示当前配置
fn show_config() -> Result<()> {
let config_manager = ConfigManager::new()
.context("无法加载配置")?;
let config = config_manager.get_config();
ui::show_config(
&config.api_key,
&config.default_voice,
&config_manager.get_config_path().to_string_lossy(),
);
Ok(())
}
/// 处理配置相关子命令
fn handle_config_command(action: ConfigAction) -> Result<()> {
match action {
ConfigAction::Set { api_key, voice, .. } => {
let mut config_manager = ConfigManager::new()
.context("无法加载配置")?;
if let Some(key) = api_key {
config_manager.set_api_key(key);
ui::show_success("API Key 已更新");
}
if let Some(v) = voice {
config_manager.set_default_voice(v);
ui::show_success("默认音色已更新");
}
config_manager.save()
.context("无法保存配置")?;
ui::show_info("📁 配置已保存到:", &config_manager.get_config_path().to_string_lossy());
}
ConfigAction::Show => {
show_config()?;
}
ConfigAction::Init => {
// 交互式初始化
ui::show_info("初始化配置...", "");
let config_manager = ConfigManager::new()
.context("无法创建配置")?;
ui::show_info("请使用以下命令设置 API Key:", "");
println!(" mimo-tts config set --api-key <YOUR_API_KEY>");
ui::show_info("配置文件将保存在:", &config_manager.get_config_path().to_string_lossy());
}
}
Ok(())
}
/// 引导式配置初始化
///
/// 交互式引导用户完成配置设置
async fn onboard() -> Result<()> {
let config_manager = ConfigManager::new()
.context("无法创建配置管理器")?;
let current_config = config_manager.get_config();
// 使用 UI 模块显示交互式表单
let result = ui::show_onboard_form(
&current_config.api_key,
&current_config.default_voice,
);
let (api_key, default_voice) = result
.map_err(|e| anyhow::anyhow!("表单输入错误: {}", e))?;
// 保存配置
let mut config_manager = ConfigManager::new()
.context("无法创建配置管理器")?;
if !api_key.is_empty() {
config_manager.set_api_key(api_key);
}
if !default_voice.is_empty() {
config_manager.set_default_voice(default_voice);
}
config_manager.save()
.context("无法保存配置")?;
ui::show_info("📁 配置已保存到:", &config_manager.get_config_path().to_string_lossy());
Ok(())
}
/// 执行语音合成
///
/// # 参数
/// - text: 直接提供的文本(可选)
/// - file: 文本文件路径(可选)
/// - voice: 音色名称
/// - format: 音频格式
/// - style: 风格描述(可选,会放在 user 消息中)
/// - stream: 是否使用流式输出
///
/// # 返回
/// 返回合成的音频数据WAV 或 PCM16 格式)
async fn synthesize(
text: Option<String>,
file: Option<std::path::PathBuf>,
voice: &str,
format: &str,
style: Option<&str>,
stream: bool,
) -> Result<Vec<u8>> {
// 获取要合成的文本
let content = if let Some(t) = text {
tone::apply_tone(&t)
} else if let Some(f) = file {
// 从文件读取文本
let mut file = fs::File::open(&f)
.with_context(|| format!("无法打开文件: {:?}", f))?;
let mut content = String::new();
file.read_to_string(&mut content)
.with_context(|| format!("无法读取文件: {:?}", f))?;
tone::apply_tone(&content)
} else {
return Err(anyhow::anyhow!("没有提供文本内容"));
};
// 验证音色是否合法,不合法则使用默认值
let validated_voice = validate_voice(voice);
// 加载配置
let config_manager = ConfigManager::new()
.context("无法加载配置")?;
let config = config_manager.get_config();
// 检查 API Key 是否设置
if config.api_key.is_empty() {
return Err(anyhow::anyhow!(
"API Key 未设置\n请使用: mimo-tts config set --api-key <YOUR_API_KEY>"
));
}
// 创建 TTS 客户端
let client = api::TtsClient::builder()
.base_url(config.base_url.clone())
.api_key(config.api_key.clone())
.build()
.context("无法创建 TTS 客户端")?;
// 流式输出时自动使用 pcm16 格式
let actual_format = if stream { "pcm16" } else { format };
// 构建请求(如果指定了风格,添加到 user 消息)
let mut builder = api::TtsRequest::builder()
.audio(api::AudioConfig {
format: actual_format.to_string(),
voice: validated_voice,
});
// 添加消息:如果指定了风格,先添加 user 消息描述风格
if let Some(s) = style {
builder = builder.add_message(api::Message {
role: "user".to_string(),
content: s.to_string(),
});
}
// 添加 assistant 消息(实际要合成的文本)
builder = builder.add_message(api::Message {
role: "assistant".to_string(),
content: content.clone(),
});
let request = builder.build();
// 调用 API 合成语音
let audio_data = if stream {
// 流式请求已在 api.rs 中处理
client
.synthesize_with_request(&request)
.await
.context("流式语音合成失败")?
} else {
client
.synthesize_with_request(&request)
.await
.context("语音合成失败")?
};
Ok(audio_data)
}
/// 单元测试模块
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exit_code() {
assert_eq!(i32::from(ExitCode::Success), 0);
assert_eq!(i32::from(ExitCode::ArgumentError), 1);
assert_eq!(i32::from(ExitCode::ConfigError), 2);
assert_eq!(i32::from(ExitCode::ApiError), 3);
assert_eq!(i32::from(ExitCode::FileError), 4);
}
}
/// 播放音频数据
///
/// 使用 rodio 直接从内存播放 WAV 音频
/// # 参数
/// - data: WAV 格式的音频数据
fn play_audio(data: &[u8]) -> Result<()> {
// 创建 rodio 音频输出流
let (_stream, stream_handle) = rodio::OutputStream::try_default()
.context("无法创建音频输出流")?;
// 从内存数据创建音频源
let cursor = std::io::Cursor::new(data.to_vec());
let source = rodio::Decoder::new(cursor)
.context("无法解码音频数据")?;
// 创建播放器并播放(单次播放,不循环)
let sink = rodio::Sink::try_new(&stream_handle)
.context("无法创建音频播放器")?;
sink.append(source);
// 等待播放完成
sink.sleep_until_end();
Ok(())
}
/// 将 PCM16 原始数据转换为 WAV 格式
///
/// # 参数
/// - pcm_data: PCM16 原始音频数据16bit, 单声道, 24000Hz
///
/// # 返回
/// 完整的 WAV 格式数据(包含 44 字节头部)
fn pcm16_to_wav(pcm_data: &[u8]) -> Vec<u8> {
let sample_rate: u32 = 24000; // Mimo-TTS PCM16 输出通常是 24kHz
let bits_per_sample: u16 = 16;
let channels: u16 = 1;
let byte_rate = sample_rate * channels as u32 * bits_per_sample as u32 / 8;
let block_align = channels * bits_per_sample / 8;
let data_size = pcm_data.len() as u32;
let file_size = 36 + data_size;
let mut wav = Vec::with_capacity(44 + pcm_data.len());
// RIFF 头
wav.extend_from_slice(b"RIFF");
wav.extend_from_slice(&file_size.to_le_bytes());
wav.extend_from_slice(b"WAVE");
// fmt 子块
wav.extend_from_slice(b"fmt ");
wav.extend_from_slice(&16u32.to_le_bytes()); // PCM 格式大小
wav.extend_from_slice(&1u16.to_le_bytes()); // PCM 格式
wav.extend_from_slice(&channels.to_le_bytes());
wav.extend_from_slice(&sample_rate.to_le_bytes());
wav.extend_from_slice(&byte_rate.to_le_bytes());
wav.extend_from_slice(&block_align.to_le_bytes());
wav.extend_from_slice(&bits_per_sample.to_le_bytes());
// data 子块
wav.extend_from_slice(b"data");
wav.extend_from_slice(&data_size.to_le_bytes());
wav.extend_from_slice(pcm_data);
wav
}