feat: basic chat client with ai

Connector part is almost finished.
This commit is contained in:
2026-01-18 03:39:41 +08:00
parent f57a2141d8
commit e8e0f421e0
7 changed files with 676 additions and 0 deletions

11
.gitignore vendored
View File

@@ -16,3 +16,14 @@ target/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Cache directory for conversation sessions
.cache/
# Added by cargo
Cargo.lock
/target
# Environment
.env
.env.*

15
Cargo.toml Normal file
View File

@@ -0,0 +1,15 @@
[package]
name = "user-content-agent"
version = "0.0.1"
edition = "2024"
[dependencies]
reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false }
tokio = { version = "1", features = ["full", "signal"] }
anyhow = "1.0"
dotenvy = "0.15"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
futures = "0.3"
async-stream = "0.3"
chrono = { version = "0.4", features = ["serde"] }

145
src/cache.rs Normal file
View File

@@ -0,0 +1,145 @@
use anyhow::{Result, Context};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
/// A chat session: the full message history plus creation/update timestamps.
///
/// Timestamps are RFC 3339 strings in local time (see `Session::new`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Session {
    /// Ordered (role, content) pairs; role is "user" or "assistant".
    pub messages: Vec<(String, String)>, // (role, content)
    /// RFC 3339 local timestamp set when the session was created.
    pub created_at: String,
    /// RFC 3339 local timestamp refreshed on every `add_message`.
    pub updated_at: String,
}
impl Session {
/// 创建新会话
pub fn new() -> Self {
let now = chrono::Local::now().to_rfc3339();
Self {
messages: Vec::new(),
created_at: now.clone(),
updated_at: now,
}
}
/// 添加消息
pub fn add_message(&mut self, role: String, content: String) {
self.messages.push((role, content));
self.updated_at = chrono::Local::now().to_rfc3339();
}
/// 获取消息列表
pub fn get_messages(&self) -> Vec<(String, String)> {
self.messages.clone()
}
}
/// Manages on-disk persistence of sessions under `.cache/sessions`.
pub struct CacheManager {
    /// Directory that holds one JSON file per session.
    cache_dir: PathBuf,
}
impl CacheManager {
    /// Create the manager, ensuring the `.cache/sessions` directory exists.
    ///
    /// # Errors
    /// Returns an error if the cache directory cannot be created.
    pub fn new() -> Result<Self> {
        let cache_dir = PathBuf::from(".cache/sessions");
        if !cache_dir.exists() {
            fs::create_dir_all(&cache_dir)
                .context("创建缓存目录失败")?;
        }
        Ok(Self { cache_dir })
    }

    /// Path of the most recently modified `*.json` session file, if any.
    ///
    /// Returns `None` when the cache directory is missing or unreadable,
    /// or when it contains no JSON files.
    pub fn get_latest_session_path(&self) -> Option<PathBuf> {
        fs::read_dir(&self.cache_dir)
            .ok()?
            .filter_map(|entry| entry.ok())
            .filter(|entry| {
                entry.path().extension()
                    .map(|ext| ext == "json")
                    .unwrap_or(false)
            })
            // Single O(n) max scan replaces the original sort + reverse + first.
            .max_by_key(|entry| {
                entry.metadata()
                    .and_then(|m| m.modified())
                    .unwrap_or(std::time::SystemTime::UNIX_EPOCH)
            })
            .map(|entry| entry.path())
    }

    /// Build a fresh session file path named after the current local time.
    ///
    /// NOTE(review): the timestamp has second resolution, so two sessions
    /// created within the same second would collide on the same path.
    pub fn create_session_path(&self) -> PathBuf {
        let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
        self.cache_dir.join(format!("session_{}.json", timestamp))
    }

    /// Serialize `session` as pretty-printed JSON and write it to `path`.
    pub fn save_session(&self, session: &Session, path: &Path) -> Result<()> {
        let json = serde_json::to_string_pretty(session)
            .context("序列化会话失败")?;
        fs::write(path, json)
            .context("保存会话文件失败")?;
        Ok(())
    }

    /// Read and deserialize the session stored at `path`.
    pub fn load_session(&self, path: &Path) -> Result<Session> {
        let json = fs::read_to_string(path)
            .context("读取会话文件失败")?;
        let session: Session = serde_json::from_str(&json)
            .context("解析会话文件失败")?;
        Ok(session)
    }

    /// List all session files as ("YYYY-MM-DD HH:MM:SS", path) pairs,
    /// newest first. Files whose metadata cannot be read are skipped.
    pub fn list_sessions(&self) -> Result<Vec<(String, PathBuf)>> {
        if !self.cache_dir.exists() {
            return Ok(Vec::new());
        }
        let mut sessions = Vec::new();
        for entry in fs::read_dir(&self.cache_dir)? {
            let entry = entry?;
            let path = entry.path();
            if path.extension().map(|ext| ext == "json").unwrap_or(false) {
                // Flattened from two nested if-lets: both fallible steps are io::Results.
                if let Ok(modified) = entry.metadata().and_then(|m| m.modified()) {
                    let datetime = chrono::DateTime::<chrono::Local>::from(modified);
                    sessions.push((datetime.format("%Y-%m-%d %H:%M:%S").to_string(), path));
                }
            }
        }
        // The zero-padded format sorts lexicographically as chronologically;
        // reversed comparison puts the newest first.
        sessions.sort_by(|a, b| b.0.cmp(&a.0));
        Ok(sessions)
    }
}
// 为Session实现Default trait
impl Default for Session {
fn default() -> Self {
Self::new()
}
}

183
src/chat.rs Normal file
View File

@@ -0,0 +1,183 @@
use anyhow::{Result, Context};
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
/// One unit of streamed model output.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of the model's reasoning ("thinking") text.
    Reasoning(String),
    /// A fragment of the model's reply content.
    Content(String),
    /// The stream has finished (`[DONE]` marker or a "stop" finish_reason).
    Done,
}
/// Chat client for OpenAI-compatible APIs.
pub struct ChatClient {
    /// Bearer token sent in the Authorization header.
    api_key: String,
    /// API base URL, e.g. "https://api.openai.com/v1".
    api_base: String,
    /// Model used when the caller does not specify one.
    default_model: String,
    /// Reused reqwest HTTP client.
    client: reqwest::Client,
}
/// One chat message in the request body (role + content).
#[derive(Serialize, Debug, Clone)]
struct Message {
    role: String,
    content: String,
}
/// JSON body for POST {api_base}/chat/completions.
#[derive(Serialize, Debug)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    /// Always set to true here: replies are consumed as an SSE stream.
    stream: bool,
}
/// One SSE `data:` payload from the streaming chat-completions endpoint.
#[derive(Deserialize, Debug)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize, Debug)]
struct StreamChoice {
    /// Incremental delta carried by this chunk.
    delta: StreamDelta,
    /// Set (e.g. "stop") on the final chunk of a choice.
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
struct StreamDelta {
    /// New reply text, if any.
    content: Option<String>,
    /// New reasoning ("thinking") text, if any — emitted by some providers.
    #[serde(rename = "reasoning_content")]
    reasoning_content: Option<String>,
}
impl ChatClient {
    /// Create a new chat client.
    ///
    /// Configuration is read from environment variables:
    /// - OPENAI_API_KEY: API key (required)
    /// - OPENAI_API_BASE: API base URL (optional, defaults to https://api.openai.com/v1)
    /// - OPENAI_MODEL: default model (optional, defaults to gpt-3.5-turbo)
    ///
    /// # Errors
    /// Fails when OPENAI_API_KEY is unset.
    pub fn new() -> Result<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .context("未设置环境变量 OPENAI_API_KEY")?;
        let api_base = std::env::var("OPENAI_API_BASE")
            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
        let default_model = std::env::var("OPENAI_MODEL")
            .unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
        let client = reqwest::Client::new();
        Ok(Self {
            api_key,
            api_base,
            default_model,
            client,
        })
    }

    /// Stream a reply for a full conversation history using the default model.
    pub async fn chat_stream(&self, messages: Vec<(String, String)>) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
        self.chat_stream_with_model(&self.default_model, messages).await
    }

    /// Stream a reply for a conversation history with an explicit model.
    /// `messages` holds (role, content) pairs; role is "user" or "assistant".
    ///
    /// # Errors
    /// Fails if the HTTP request cannot be sent or the API returns a
    /// non-success status; per-chunk errors surface through the stream items.
    pub async fn chat_stream_with_model(&self, model: &str, messages: Vec<(String, String)>) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
        let request = ChatRequest {
            model: model.to_string(),
            messages: messages.into_iter().map(|(role, content)| Message { role, content }).collect(),
            stream: true,
        };
        let url = format!("{}/chat/completions", self.api_base);
        let response = self.client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await?;
            anyhow::bail!("API请求失败: {} - {}", status, error_text);
        }
        let stream = response.bytes_stream();
        Ok(Box::pin(async_stream::stream! {
            use futures::StreamExt;
            // Raw-byte accumulator; SSE lines are parsed out of it as they complete.
            let mut buffer = String::new();
            for await chunk_result in stream {
                let chunk = chunk_result?;
                buffer.push_str(&String::from_utf8_lossy(&chunk));
                // Process every complete line currently in the buffer.
                while let Some(pos) = buffer.find("\n") {
                    let line = buffer[..pos].trim().to_string();
                    buffer.drain(..pos + 1);
                    if line.starts_with("data: ") {
                        let data = &line[6..];
                        if data == "[DONE]" {
                            yield Ok(StreamChunk::Done);
                            continue;
                        }
                        match serde_json::from_str::<StreamResponse>(data) {
                            Ok(response) => {
                                if let Some(choice) = response.choices.first() {
                                    // NOTE(review): when finish_reason is set, any delta in the
                                    // same chunk is dropped — confirm that is intended.
                                    if let Some(ref finish_reason) = choice.finish_reason {
                                        if finish_reason == "stop" {
                                            yield Ok(StreamChunk::Done);
                                        }
                                    } else {
                                        // Re-emit reasoning text one character at a time
                                        // (typewriter effect for the UI).
                                        if let Some(ref reasoning) = choice.delta.reasoning_content {
                                            for ch in reasoning.chars() {
                                                yield Ok(StreamChunk::Reasoning(ch.to_string()));
                                            }
                                        }
                                        // Re-emit reply content one character at a time.
                                        if let Some(ref content) = choice.delta.content {
                                            for ch in content.chars() {
                                                yield Ok(StreamChunk::Content(ch.to_string()));
                                            }
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                // Ignore parse errors and keep processing later lines.
                                eprintln!("解析警告: {} - 数据: {}", e, data);
                            }
                        }
                    }
                }
            }
        }))
    }

    /// Stream a reply for a single user message (no history); kept for
    /// backward compatibility. Uses the default model.
    pub async fn chat_stream_simple(&self, message: &str) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
        self.chat_stream(vec![("user".to_string(), message.to_string())]).await
    }
}
// Default construction for ChatClient.
//
// NOTE(review): panics (via expect) when OPENAI_API_KEY is missing;
// prefer ChatClient::new() where the error can be handled.
impl Default for ChatClient {
    fn default() -> Self {
        Self::new().expect("创建ChatClient失败")
    }
}

238
src/interactive.rs Normal file
View File

@@ -0,0 +1,238 @@
use crate::cache::{CacheManager, Session};
use crate::chat::ChatClient;
use anyhow::Result;
use futures::StreamExt;
use std::io::{self, Write};
use std::sync::{Arc, Mutex};
use tokio::signal;
/// Interactive chat loop: offers to restore the latest saved session,
/// streams AI replies with a gray "reasoning" section, autosaves, and
/// installs a Ctrl+C handler that persists the session before exiting.
pub async fn run_interactive() -> Result<()> {
    println!("=== 交互式AI对话 ===");
    println!("输入 'exit' 或 'quit' 退出对话");
    println!("按 Ctrl+C 也可以安全退出并保存会话\n");
    // Cache manager for on-disk session persistence.
    let cache_manager = CacheManager::new()?;
    // AI chat client (configured via environment variables).
    let client = ChatClient::new()?;
    // Conversation history; Arc<Mutex<..>> so the Ctrl+C task can also reach it.
    let session = Arc::new(Mutex::new(Session::new()));
    let session_path = cache_manager.create_session_path();
    let mut message_count: u32 = 0;
    // Offer to resume the most recent previous session, if one exists.
    if let Some(latest_session_path) = cache_manager.get_latest_session_path() {
        println!("发现上次的对话会话!");
        println!("是否要恢复上次的对话? (y/n)");
        print!("> ");
        io::stdout().flush()?;
        let mut input = String::new();
        io::stdin().read_line(&mut input)?;
        if input.trim().eq_ignore_ascii_case("y") || input.trim().eq_ignore_ascii_case("yes") {
            match cache_manager.load_session(&latest_session_path) {
                Ok(loaded_session) => {
                    let mut session_guard = session.lock().unwrap();
                    *session_guard = loaded_session;
                    let msg_count = session_guard.messages.len();
                    println!("✓ 已恢复 {} 条对话消息\n", msg_count);
                    drop(session_guard); // release the lock
                    // Show the last few messages as a reminder.
                    let session_guard = session.lock().unwrap();
                    let recent_messages = session_guard.messages.len().saturating_sub(4);
                    if recent_messages < session_guard.messages.len() {
                        println!("最近的对话: ");
                        for (role, content) in &session_guard.messages[recent_messages..] {
                            let prefix = if role == "user" { "" } else { "AI" };
                            // Truncate long messages (by chars, not bytes) for display.
                            let display_content = if content.chars().count() > 50 {
                                let truncated: String = content.chars().take(50).collect();
                                format!("{}...", truncated)
                            } else {
                                content.replace('\n', " ")
                            };
                            println!(" {}: {}", prefix, display_content);
                        }
                        println!();
                    }
                    drop(session_guard); // release the lock
                }
                Err(e) => {
                    eprintln!("恢复会话失败: {}", e);
                    eprintln!("将开始新的对话会话\n");
                }
            }
        } else {
            println!("将开始新的对话会话\n");
        }
    }
    // Install the Ctrl+C handler: save the session, then exit the process.
    let session_clone = Arc::clone(&session);
    let session_path_clone = session_path.clone();
    let cache_manager_clone = CacheManager::new()?;
    tokio::spawn(async move {
        match signal::ctrl_c().await {
            Ok(()) => {
                println!("\n\n检测到 Ctrl+C正在保存会话并退出...");
                let session_guard = session_clone.lock().unwrap();
                if !session_guard.messages.is_empty() {
                    match cache_manager_clone.save_session(&session_guard, &session_path_clone) {
                        Ok(_) => {
                            // NOTE(review): ai_messages is computed but never printed here.
                            let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count();
                            let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count();
                            println!("✓ 会话已保存(共 {} 轮对话)", user_messages);
                            println!("文件: {:?}", session_path_clone);
                        }
                        Err(e) => eprintln!("保存会话失败: {}", e),
                    }
                }
                println!("再见!");
                // Terminates the whole process from inside the spawned task.
                std::process::exit(0);
            }
            Err(err) => {
                eprintln!("设置信号处理器失败: {}", err);
            }
        }
    });
    loop {
        // Prompt.
        print!("\n你: ");
        io::stdout().flush()?;
        // Read one line of user input (blocking read on the async runtime thread).
        let mut input = String::new();
        io::stdin().read_line(&mut input)?;
        // Trim surrounding whitespace.
        let input = input.trim();
        // Exit commands.
        if input.eq_ignore_ascii_case("exit") || input.eq_ignore_ascii_case("quit") {
            println!("\n对话结束,再见!");
            break;
        }
        // Skip empty input.
        if input.is_empty() {
            continue;
        }
        // Record the user message in the session (scoped so the lock drops early).
        {
            let mut session_guard = session.lock().unwrap();
            session_guard.add_message("user".to_string(), input.to_string());
        }
        message_count += 1;
        // Autosave every 5 messages.
        if message_count % 5 == 0 {
            let session_guard = session.lock().unwrap();
            if let Err(e) = cache_manager.save_session(&session_guard, &session_path) {
                eprintln!("警告: 自动保存失败: {}", e);
            }
        }
        // Snapshot the history, then request a streamed reply (lock not held
        // across the await).
        let messages = {
            let session_guard = session.lock().unwrap();
            session_guard.get_messages()
        };
        match client.chat_stream(messages).await {
            Ok(mut stream) => {
                let mut is_first_content = true;
                let mut is_reasoning = true;
                let mut assistant_response = String::new();
                // \x1b[90m = gray: reasoning is rendered dimmed.
                print!("\x1b[90mAI思维: ");
                io::stdout().flush()?;
                while let Some(chunk_result) = stream.next().await {
                    match chunk_result {
                        Ok(crate::chat::StreamChunk::Reasoning(ch)) => {
                            if is_reasoning {
                                print!("{}", ch);
                                io::stdout().flush()?;
                            }
                        }
                        Ok(crate::chat::StreamChunk::Content(ch)) => {
                            if is_first_content {
                                // End the gray reasoning section, switch to reply mode.
                                println!("\x1b[0m");
                                println!();
                                print!("AI: ");
                                io::stdout().flush()?;
                                is_first_content = false;
                                is_reasoning = false;
                            }
                            print!("{}", ch);
                            io::stdout().flush()?;
                            assistant_response.push_str(&ch);
                        }
                        Ok(crate::chat::StreamChunk::Done) => {
                            break;
                        }
                        Err(e) => {
                            eprintln!("\n\x1b[31m错误: {}\x1b[0m", e);
                            break;
                        }
                    }
                }
                // Finish the line and reset the terminal color.
                // NOTE(review): both branches are identical; the if/else is redundant.
                if is_reasoning {
                    // No content arrived — only reasoning was printed.
                    println!("\x1b[0m");
                } else {
                    println!("\x1b[0m");
                }
                // Record the assistant reply in the session.
                if !assistant_response.is_empty() {
                    let mut session_guard = session.lock().unwrap();
                    session_guard.add_message("assistant".to_string(), assistant_response);
                }
                // Persist immediately after each completed reply.
                {
                    let session_guard = session.lock().unwrap();
                    if let Err(e) = cache_manager.save_session(&session_guard, &session_path) {
                        eprintln!("\n警告: 保存会话失败: {}", e);
                    }
                }
            }
            Err(e) => {
                eprintln!("\n\x1b[31m发送消息失败: {}\x1b[0m", e);
                // On failure, roll back the just-added user message.
                let mut session_guard = session.lock().unwrap();
                session_guard.messages.pop();
                message_count = message_count.saturating_sub(1);
            }
        }
    }
    // Save on normal exit.
    let session_guard = session.lock().unwrap();
    if !session_guard.messages.is_empty() {
        println!("\n正在保存对话会话...");
        match cache_manager.save_session(&session_guard, &session_path) {
            Ok(_) => println!("✓ 会话已保存到 {:?}", session_path),
            Err(e) => eprintln!("保存会话失败: {}", e),
        }
        // Session statistics.
        let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count();
        let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count();
        // NOTE(review): the format string is missing its closing ")".
        println!("本次对话共 {} 轮(你: {}, AI: {}", user_messages, user_messages, ai_messages);
    }
    Ok(())
}

34
src/main.rs Normal file
View File

@@ -0,0 +1,34 @@
mod cache;
mod chat;
mod interactive;
mod test;
use anyhow::Result;
use std::env;
#[tokio::main]
async fn main() -> Result<()> {
    // Load variables from a .env file if one exists; absence is fine.
    dotenvy::dotenv().ok();
    // Dispatch on the first CLI argument.
    let args: Vec<String> = env::args().collect();
    let mode = args.get(1).map(String::as_str);
    if mode == Some("test") {
        // Single-shot test conversation.
        test::run_test().await?;
    } else if mode.is_none() || mode == Some("interactive") {
        // Interactive chat loop (the default).
        interactive::run_interactive().await?;
    } else {
        // Unknown argument: print usage.
        println!("用法: {} [test|interactive]", args[0]);
        println!(" test - 运行测试模式(单次对话)");
        println!(" interactive - 运行交互式对话模式(默认)");
    }
    Ok(())
}

50
src/test.rs Normal file
View File

@@ -0,0 +1,50 @@
use crate::chat::ChatClient;
use anyhow::Result;
use futures::StreamExt;
/// 测试AI对话功能
pub async fn run_test() -> Result<()> {
println!("=== 测试AI对话功能 ===\n");
// 创建AI对话客户端
let client = ChatClient::new()?;
// 创建消息列表(只包含一条用户消息)
let messages = vec![("user".to_string(), "你好!请介绍一下你自己。".to_string())];
// 发送消息并获取流式回复
let mut stream = client.chat_stream(messages).await?;
let mut is_first_content = true;
let mut is_reasoning = true;
print!("\x1b[90m思维过程: ");
while let Some(chunk_result) = stream.next().await {
match chunk_result? {
crate::chat::StreamChunk::Reasoning(text) => {
if is_reasoning {
print!("{}", text);
}
}
crate::chat::StreamChunk::Content(text) => {
if is_first_content {
println!("\x1b[0m");
println!();
print!("AI回复: ");
is_first_content = false;
is_reasoning = false;
}
print!("{}", text);
}
crate::chat::StreamChunk::Done => {
break;
}
}
}
println!();
println!("\n=== 测试完成 ===");
Ok(())
}