diff --git a/.gitignore b/.gitignore index 0b188bc..592ebfd 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,14 @@ target/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Cache directory for conversation sessions +.cache/ + +# Added by cargo +Cargo.lock +/target + +# Environment +.env +.env.* diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..de76095 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "user-content-agent" +version = "0.0.1" +edition = "2024" + +[dependencies] +reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false } +tokio = { version = "1", features = ["full", "signal"] } +anyhow = "1.0" +dotenvy = "0.15" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +futures = "0.3" +async-stream = "0.3" +chrono = { version = "0.4", features = ["serde"] } diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..1bfe1dd --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,145 @@ +use anyhow::{Result, Context}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::{Path, PathBuf}; + +/// 会话数据 +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Session { + pub messages: Vec<(String, String)>, // (role, content) + pub created_at: String, + pub updated_at: String, +} + +impl Session { + /// 创建新会话 + pub fn new() -> Self { + let now = chrono::Local::now().to_rfc3339(); + Self { + messages: Vec::new(), + created_at: now.clone(), + updated_at: now, + } + } + + /// 添加消息 + pub fn add_message(&mut self, role: String, content: String) { + self.messages.push((role, content)); + self.updated_at = chrono::Local::now().to_rfc3339(); + } + + /// 获取消息列表 + pub fn get_messages(&self) -> Vec<(String, String)> { + self.messages.clone() + } +} + +/// 缓存管理器 +pub struct CacheManager { + cache_dir: PathBuf, +} + +impl CacheManager { + /// 创建缓存管理器 + pub fn new() -> Result { + let cache_dir = PathBuf::from(".cache/sessions"); + + // 创建缓存目录 + if !cache_dir.exists() { + fs::create_dir_all(&cache_dir) + .context("创建缓存目录失败")?; + } + + Ok(Self { cache_dir }) + } + + /// 获取最新的会话文件路径 + pub fn get_latest_session_path(&self) -> Option { + if !self.cache_dir.exists() { + return None; + } + + let mut entries: Vec<_> = fs::read_dir(&self.cache_dir) + .ok()? + .filter_map(|entry| entry.ok()) + .filter(|entry| { + entry.path().extension() + .map(|ext| ext == "json") + .unwrap_or(false) + }) + .collect(); + + // 按修改时间排序(最新的在前) + entries.sort_by_key(|entry| { + entry.metadata() + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH) + }); + entries.reverse(); + + entries.first().map(|entry| entry.path()) + } + + /// 创建新的会话文件路径 + pub fn create_session_path(&self) -> PathBuf { + let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S"); + self.cache_dir.join(format!("session_{}.json", timestamp)) + } + + /// 保存会话 + pub fn save_session(&self, session: &Session, path: &Path) -> Result<()> { + let json = serde_json::to_string_pretty(session) + .context("序列化会话失败")?; + + fs::write(path, json) + .context("保存会话文件失败")?; + + Ok(()) + } + + /// 加载会话 + pub fn load_session(&self, path: &Path) -> Result { + let json = fs::read_to_string(path) + .context("读取会话文件失败")?; + + let session: Session = serde_json::from_str(&json) + .context("解析会话文件失败")?; + + Ok(session) + } + + /// 获取所有会话文件 + pub fn list_sessions(&self) -> Result> { + if !self.cache_dir.exists() { + return Ok(Vec::new()); + } + + let mut sessions = Vec::new(); + + for entry in fs::read_dir(&self.cache_dir)? { + let entry = entry?; + let path = entry.path(); + + if path.extension().map(|ext| ext == "json").unwrap_or(false) { + if let Ok(metadata) = entry.metadata() { + if let Ok(modified) = metadata.modified() { + let datetime = chrono::DateTime::::from(modified); + sessions.push((datetime.format("%Y-%m-%d %H:%M:%S").to_string(), path)); + } + } + } + } + + // 按时间排序(最新的在前) + sessions.sort_by(|a, b| b.0.cmp(&a.0)); + + Ok(sessions) + } +} + +// 为Session实现Default trait +impl Default for Session { + fn default() -> Self { + Self::new() + } +} diff --git a/src/chat.rs b/src/chat.rs new file mode 100644 index 0000000..b353340 --- /dev/null +++ b/src/chat.rs @@ -0,0 +1,183 @@ +use anyhow::{Result, Context}; +use futures::Stream; +use serde::{Deserialize, Serialize}; +use std::pin::Pin; + +/// 流式输出块 +#[derive(Debug, Clone)] +pub enum StreamChunk { + Reasoning(String), + Content(String), + Done, +} + +/// AI对话客户端,支持OpenAI兼容的API +pub struct ChatClient { + api_key: String, + api_base: String, + default_model: String, + client: reqwest::Client, +} + +/// 消息结构 +#[derive(Serialize, Debug, Clone)] +struct Message { + role: String, + content: String, +} + +/// 请求结构 +#[derive(Serialize, Debug)] +struct ChatRequest { + model: String, + messages: Vec, + stream: bool, +} + +/// 流式响应结构 +#[derive(Deserialize, Debug)] +struct StreamResponse { + choices: Vec, +} + +#[derive(Deserialize, Debug)] +struct StreamChoice { + delta: StreamDelta, + finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +struct StreamDelta { + content: Option, + #[serde(rename = "reasoning_content")] + reasoning_content: Option, +} + +impl ChatClient { + /// 创建新的对话客户端 + /// 从环境变量读取配置: + /// - OPENAI_API_KEY: API密钥(必需) + /// - OPENAI_API_BASE: API地址(可选,默认:https://api.openai.com/v1) + /// - OPENAI_MODEL: 默认模型(可选,默认:gpt-3.5-turbo) + pub fn new() -> Result { + let api_key = std::env::var("OPENAI_API_KEY") + .context("未设置环境变量 OPENAI_API_KEY")?; + + let api_base = std::env::var("OPENAI_API_BASE") + .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()); + + let default_model = std::env::var("OPENAI_MODEL") + .unwrap_or_else(|_| "gpt-3.5-turbo".to_string()); + + let client = reqwest::Client::new(); + + Ok(Self { + api_key, + api_base, + default_model, + client, + }) + } + + /// 流式发送消息(带对话历史) + /// 使用默认模型 + pub async fn chat_stream(&self, messages: Vec<(String, String)>) -> Result> + Send + '_>>> { + self.chat_stream_with_model(&self.default_model, messages).await + } + + /// 流式发送消息(带对话历史) + /// 指定使用的模型 + /// messages: Vec<(role, content)>,role可以是"user"或"assistant" + pub async fn chat_stream_with_model(&self, model: &str, messages: Vec<(String, String)>) -> Result> + Send + '_>>> { + let request = ChatRequest { + model: model.to_string(), + messages: messages.into_iter().map(|(role, content)| Message { role, content }).collect(), + stream: true, + }; + + let url = format!("{}/chat/completions", self.api_base); + + let response = self.client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&request) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await?; + anyhow::bail!("API请求失败: {} - {}", status, error_text); + } + + let stream = response.bytes_stream(); + + Ok(Box::pin(async_stream::stream! { + use futures::StreamExt; + + let mut buffer = String::new(); + + for await chunk_result in stream { + let chunk = chunk_result?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + // 处理buffer中的完整行 + while let Some(pos) = buffer.find("\n") { + let line = buffer[..pos].trim().to_string(); + buffer.drain(..pos + 1); + + if line.starts_with("data: ") { + let data = &line[6..]; + + if data == "[DONE]" { + yield Ok(StreamChunk::Done); + continue; + } + + match serde_json::from_str::(data) { + Ok(response) => { + if let Some(choice) = response.choices.first() { + if let Some(ref finish_reason) = choice.finish_reason { + if finish_reason == "stop" { + yield Ok(StreamChunk::Done); + } + } else { + // 将思维过程拆分成单个字符逐个输出 + if let Some(ref reasoning) = choice.delta.reasoning_content { + for ch in reasoning.chars() { + yield Ok(StreamChunk::Reasoning(ch.to_string())); + } + } + // 将回复内容拆分成单个字符逐个输出 + if let Some(ref content) = choice.delta.content { + for ch in content.chars() { + yield Ok(StreamChunk::Content(ch.to_string())); + } + } + } + } + } + Err(e) => { + // 忽略解析错误,继续处理下一行 + eprintln!("解析警告: {} - 数据: {}", e, data); + } + } + } + } + } + })) + } + + /// 流式发送单条消息(无对话历史,兼容旧接口) + /// 使用默认模型 + pub async fn chat_stream_simple(&self, message: &str) -> Result> + Send + '_>>> { + self.chat_stream(vec![("user".to_string(), message.to_string())]).await + } +} + +impl Default for ChatClient { + fn default() -> Self { + Self::new().expect("创建ChatClient失败") + } +} diff --git a/src/interactive.rs b/src/interactive.rs new file mode 100644 index 0000000..d48d551 --- /dev/null +++ b/src/interactive.rs @@ -0,0 +1,238 @@ +use crate::cache::{CacheManager, Session}; +use crate::chat::ChatClient; +use anyhow::Result; +use futures::StreamExt; +use std::io::{self, Write}; +use std::sync::{Arc, Mutex}; +use tokio::signal; + +/// 交互式对话模式 +pub async fn run_interactive() -> Result<()> { + println!("=== 交互式AI对话 ==="); + println!("输入 'exit' 或 'quit' 退出对话"); + println!("按 Ctrl+C 也可以安全退出并保存会话\n"); + + // 创建缓存管理器 + let cache_manager = CacheManager::new()?; + + // 创建AI对话客户端 + let client = ChatClient::new()?; + + // 维护对话历史(使用Arc>以便在信号处理函数中访问) + let session = Arc::new(Mutex::new(Session::new())); + let session_path = cache_manager.create_session_path(); + let mut message_count: u32 = 0; + + // 检查是否有上次会话 + if let Some(latest_session_path) = cache_manager.get_latest_session_path() { + println!("发现上次的对话会话!"); + println!("是否要恢复上次的对话? (y/n)"); + + print!("> "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + if input.trim().eq_ignore_ascii_case("y") || input.trim().eq_ignore_ascii_case("yes") { + match cache_manager.load_session(&latest_session_path) { + Ok(loaded_session) => { + let mut session_guard = session.lock().unwrap(); + *session_guard = loaded_session; + let msg_count = session_guard.messages.len(); + println!("✓ 已恢复 {} 条对话消息\n", msg_count); + drop(session_guard); // 释放锁 + + // 显示最近的几条消息作为提醒 + let session_guard = session.lock().unwrap(); + let recent_messages = session_guard.messages.len().saturating_sub(4); + if recent_messages < session_guard.messages.len() { + println!("最近的对话: "); + for (role, content) in &session_guard.messages[recent_messages..] { + let prefix = if role == "user" { "你" } else { "AI" }; + let display_content = if content.chars().count() > 50 { + let truncated: String = content.chars().take(50).collect(); + format!("{}...", truncated) + } else { + content.replace('\n', " ") + }; + println!(" {}: {}", prefix, display_content); + } + println!(); + } + drop(session_guard); // 释放锁 + } + Err(e) => { + eprintln!("恢复会话失败: {}", e); + eprintln!("将开始新的对话会话\n"); + } + } + } else { + println!("将开始新的对话会话\n"); + } + } + + // 设置Ctrl+C信号处理 + let session_clone = Arc::clone(&session); + let session_path_clone = session_path.clone(); + let cache_manager_clone = CacheManager::new()?; + + tokio::spawn(async move { + match signal::ctrl_c().await { + Ok(()) => { + println!("\n\n检测到 Ctrl+C,正在保存会话并退出..."); + let session_guard = session_clone.lock().unwrap(); + if !session_guard.messages.is_empty() { + match cache_manager_clone.save_session(&session_guard, &session_path_clone) { + Ok(_) => { + let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count(); + let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count(); + println!("✓ 会话已保存(共 {} 轮对话)", user_messages); + println!("文件: {:?}", session_path_clone); + } + Err(e) => eprintln!("保存会话失败: {}", e), + } + } + println!("再见!"); + std::process::exit(0); + } + Err(err) => { + eprintln!("设置信号处理器失败: {}", err); + } + } + }); + + loop { + // 显示提示符 + print!("\n你: "); + io::stdout().flush()?; + + // 读取用户输入 + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + // 去除首尾空白 + let input = input.trim(); + + // 检查是否退出 + if input.eq_ignore_ascii_case("exit") || input.eq_ignore_ascii_case("quit") { + println!("\n对话结束,再见!"); + break; + } + + // 检查输入是否为空 + if input.is_empty() { + continue; + } + + // 添加用户消息到会话 + { + let mut session_guard = session.lock().unwrap(); + session_guard.add_message("user".to_string(), input.to_string()); + } + message_count += 1; + + // 每5条消息自动保存一次 + if message_count % 5 == 0 { + let session_guard = session.lock().unwrap(); + if let Err(e) = cache_manager.save_session(&session_guard, &session_path) { + eprintln!("警告: 自动保存失败: {}", e); + } + } + + // 发送消息并获取流式回复 + let messages = { + let session_guard = session.lock().unwrap(); + session_guard.get_messages() + }; + + match client.chat_stream(messages).await { + Ok(mut stream) => { + let mut is_first_content = true; + let mut is_reasoning = true; + let mut assistant_response = String::new(); + + print!("\x1b[90mAI思维: "); + io::stdout().flush()?; + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(crate::chat::StreamChunk::Reasoning(ch)) => { + if is_reasoning { + print!("{}", ch); + io::stdout().flush()?; + } + } + Ok(crate::chat::StreamChunk::Content(ch)) => { + if is_first_content { + // 结束思维过程,切换到回复模式 + println!("\x1b[0m"); + println!(); + print!("AI: "); + io::stdout().flush()?; + is_first_content = false; + is_reasoning = false; + } + print!("{}", ch); + io::stdout().flush()?; + assistant_response.push_str(&ch); + } + Ok(crate::chat::StreamChunk::Done) => { + break; + } + Err(e) => { + eprintln!("\n\x1b[31m错误: {}\x1b[0m", e); + break; + } + } + } + + // 确保输出完成 + if is_reasoning { + // 如果没有收到内容,说明只有思维过程 + println!("\x1b[0m"); + } else { + println!("\x1b[0m"); + } + + // 添加AI回复到会话 + if !assistant_response.is_empty() { + let mut session_guard = session.lock().unwrap(); + session_guard.add_message("assistant".to_string(), assistant_response); + } + + // AI回复完成后立即保存会话 + { + let session_guard = session.lock().unwrap(); + if let Err(e) = cache_manager.save_session(&session_guard, &session_path) { + eprintln!("\n警告: 保存会话失败: {}", e); + } + } + } + Err(e) => { + eprintln!("\n\x1b[31m发送消息失败: {}\x1b[0m", e); + // 出错时移除最后添加的用户消息 + let mut session_guard = session.lock().unwrap(); + session_guard.messages.pop(); + message_count = message_count.saturating_sub(1); + } + } + } + + // 退出时保存会话 + let session_guard = session.lock().unwrap(); + if !session_guard.messages.is_empty() { + println!("\n正在保存对话会话..."); + match cache_manager.save_session(&session_guard, &session_path) { + Ok(_) => println!("✓ 会话已保存到 {:?}", session_path), + Err(e) => eprintln!("保存会话失败: {}", e), + } + + // 显示会话统计 + let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count(); + let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count(); + println!("本次对话共 {} 轮(你: {}, AI: {})", user_messages, user_messages, ai_messages); + } + + Ok(()) +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..ec27416 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,34 @@ +mod cache; +mod chat; +mod interactive; +mod test; + +use anyhow::Result; +use std::env; + +#[tokio::main] +async fn main() -> Result<()> { + // 加载环境变量 + dotenvy::dotenv().ok(); + + // 获取命令行参数 + let args: Vec = env::args().collect(); + + match args.get(1).map(|s| s.as_str()) { + Some("test") => { + // 运行测试模式 + test::run_test().await?; + } + Some("interactive") | None => { + // 运行交互式对话模式(默认) + interactive::run_interactive().await?; + } + _ => { + println!("用法: {} [test|interactive]", args[0]); + println!(" test - 运行测试模式(单次对话)"); + println!(" interactive - 运行交互式对话模式(默认)"); + } + } + + Ok(()) +} diff --git a/src/test.rs b/src/test.rs new file mode 100644 index 0000000..ba4e31c --- /dev/null +++ b/src/test.rs @@ -0,0 +1,50 @@ +use crate::chat::ChatClient; +use anyhow::Result; +use futures::StreamExt; + +/// 测试AI对话功能 +pub async fn run_test() -> Result<()> { + println!("=== 测试AI对话功能 ===\n"); + + // 创建AI对话客户端 + let client = ChatClient::new()?; + + // 创建消息列表(只包含一条用户消息) + let messages = vec![("user".to_string(), "你好!请介绍一下你自己。".to_string())]; + + // 发送消息并获取流式回复 + let mut stream = client.chat_stream(messages).await?; + + let mut is_first_content = true; + let mut is_reasoning = true; + + print!("\x1b[90m思维过程: "); + + while let Some(chunk_result) = stream.next().await { + match chunk_result? { + crate::chat::StreamChunk::Reasoning(text) => { + if is_reasoning { + print!("{}", text); + } + } + crate::chat::StreamChunk::Content(text) => { + if is_first_content { + println!("\x1b[0m"); + println!(); + print!("AI回复: "); + is_first_content = false; + is_reasoning = false; + } + print!("{}", text); + } + crate::chat::StreamChunk::Done => { + break; + } + } + } + + println!(); + println!("\n=== 测试完成 ==="); + + Ok(()) +}