feat: basic chat client with ai
connector part almostly finished.
This commit is contained in:
11
.gitignore
vendored
11
.gitignore
vendored
@@ -16,3 +16,14 @@ target/
|
|||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
# Cache directory for conversation sessions
|
||||||
|
.cache/
|
||||||
|
|
||||||
|
# Added by cargo
|
||||||
|
Cargo.lock
|
||||||
|
/target
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
|||||||
15
Cargo.toml
Normal file
15
Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[package]
|
||||||
|
name = "user-content-agent"
|
||||||
|
version = "0.0.1"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false }
|
||||||
|
tokio = { version = "1", features = ["full", "signal"] }
|
||||||
|
anyhow = "1.0"
|
||||||
|
dotenvy = "0.15"
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
futures = "0.3"
|
||||||
|
async-stream = "0.3"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
145
src/cache.rs
Normal file
145
src/cache.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
use anyhow::{Result, Context};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
/// 会话数据
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct Session {
|
||||||
|
pub messages: Vec<(String, String)>, // (role, content)
|
||||||
|
pub created_at: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
/// 创建新会话
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let now = chrono::Local::now().to_rfc3339();
|
||||||
|
Self {
|
||||||
|
messages: Vec::new(),
|
||||||
|
created_at: now.clone(),
|
||||||
|
updated_at: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加消息
|
||||||
|
pub fn add_message(&mut self, role: String, content: String) {
|
||||||
|
self.messages.push((role, content));
|
||||||
|
self.updated_at = chrono::Local::now().to_rfc3339();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取消息列表
|
||||||
|
pub fn get_messages(&self) -> Vec<(String, String)> {
|
||||||
|
self.messages.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 缓存管理器
|
||||||
|
pub struct CacheManager {
|
||||||
|
cache_dir: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CacheManager {
|
||||||
|
/// 创建缓存管理器
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
let cache_dir = PathBuf::from(".cache/sessions");
|
||||||
|
|
||||||
|
// 创建缓存目录
|
||||||
|
if !cache_dir.exists() {
|
||||||
|
fs::create_dir_all(&cache_dir)
|
||||||
|
.context("创建缓存目录失败")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self { cache_dir })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取最新的会话文件路径
|
||||||
|
pub fn get_latest_session_path(&self) -> Option<PathBuf> {
|
||||||
|
if !self.cache_dir.exists() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut entries: Vec<_> = fs::read_dir(&self.cache_dir)
|
||||||
|
.ok()?
|
||||||
|
.filter_map(|entry| entry.ok())
|
||||||
|
.filter(|entry| {
|
||||||
|
entry.path().extension()
|
||||||
|
.map(|ext| ext == "json")
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 按修改时间排序(最新的在前)
|
||||||
|
entries.sort_by_key(|entry| {
|
||||||
|
entry.metadata()
|
||||||
|
.and_then(|m| m.modified())
|
||||||
|
.unwrap_or(std::time::SystemTime::UNIX_EPOCH)
|
||||||
|
});
|
||||||
|
entries.reverse();
|
||||||
|
|
||||||
|
entries.first().map(|entry| entry.path())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建新的会话文件路径
|
||||||
|
pub fn create_session_path(&self) -> PathBuf {
|
||||||
|
let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
|
||||||
|
self.cache_dir.join(format!("session_{}.json", timestamp))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 保存会话
|
||||||
|
pub fn save_session(&self, session: &Session, path: &Path) -> Result<()> {
|
||||||
|
let json = serde_json::to_string_pretty(session)
|
||||||
|
.context("序列化会话失败")?;
|
||||||
|
|
||||||
|
fs::write(path, json)
|
||||||
|
.context("保存会话文件失败")?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 加载会话
|
||||||
|
pub fn load_session(&self, path: &Path) -> Result<Session> {
|
||||||
|
let json = fs::read_to_string(path)
|
||||||
|
.context("读取会话文件失败")?;
|
||||||
|
|
||||||
|
let session: Session = serde_json::from_str(&json)
|
||||||
|
.context("解析会话文件失败")?;
|
||||||
|
|
||||||
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有会话文件
|
||||||
|
pub fn list_sessions(&self) -> Result<Vec<(String, PathBuf)>> {
|
||||||
|
if !self.cache_dir.exists() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sessions = Vec::new();
|
||||||
|
|
||||||
|
for entry in fs::read_dir(&self.cache_dir)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.extension().map(|ext| ext == "json").unwrap_or(false) {
|
||||||
|
if let Ok(metadata) = entry.metadata() {
|
||||||
|
if let Ok(modified) = metadata.modified() {
|
||||||
|
let datetime = chrono::DateTime::<chrono::Local>::from(modified);
|
||||||
|
sessions.push((datetime.format("%Y-%m-%d %H:%M:%S").to_string(), path));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按时间排序(最新的在前)
|
||||||
|
sessions.sort_by(|a, b| b.0.cmp(&a.0));
|
||||||
|
|
||||||
|
Ok(sessions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 为Session实现Default trait
|
||||||
|
impl Default for Session {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
183
src/chat.rs
Normal file
183
src/chat.rs
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
use anyhow::{Result, Context};
|
||||||
|
use futures::Stream;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
/// 流式输出块
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum StreamChunk {
|
||||||
|
Reasoning(String),
|
||||||
|
Content(String),
|
||||||
|
Done,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// AI对话客户端,支持OpenAI兼容的API
|
||||||
|
pub struct ChatClient {
|
||||||
|
api_key: String,
|
||||||
|
api_base: String,
|
||||||
|
default_model: String,
|
||||||
|
client: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 消息结构
|
||||||
|
#[derive(Serialize, Debug, Clone)]
|
||||||
|
struct Message {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 请求结构
|
||||||
|
#[derive(Serialize, Debug)]
|
||||||
|
struct ChatRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<Message>,
|
||||||
|
stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式响应结构
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct StreamResponse {
|
||||||
|
choices: Vec<StreamChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct StreamChoice {
|
||||||
|
delta: StreamDelta,
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct StreamDelta {
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(rename = "reasoning_content")]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatClient {
|
||||||
|
/// 创建新的对话客户端
|
||||||
|
/// 从环境变量读取配置:
|
||||||
|
/// - OPENAI_API_KEY: API密钥(必需)
|
||||||
|
/// - OPENAI_API_BASE: API地址(可选,默认:https://api.openai.com/v1)
|
||||||
|
/// - OPENAI_MODEL: 默认模型(可选,默认:gpt-3.5-turbo)
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
let api_key = std::env::var("OPENAI_API_KEY")
|
||||||
|
.context("未设置环境变量 OPENAI_API_KEY")?;
|
||||||
|
|
||||||
|
let api_base = std::env::var("OPENAI_API_BASE")
|
||||||
|
.unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
|
||||||
|
|
||||||
|
let default_model = std::env::var("OPENAI_MODEL")
|
||||||
|
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
api_key,
|
||||||
|
api_base,
|
||||||
|
default_model,
|
||||||
|
client,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式发送消息(带对话历史)
|
||||||
|
/// 使用默认模型
|
||||||
|
pub async fn chat_stream(&self, messages: Vec<(String, String)>) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
|
||||||
|
self.chat_stream_with_model(&self.default_model, messages).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式发送消息(带对话历史)
|
||||||
|
/// 指定使用的模型
|
||||||
|
/// messages: Vec<(role, content)>,role可以是"user"或"assistant"
|
||||||
|
pub async fn chat_stream_with_model(&self, model: &str, messages: Vec<(String, String)>) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: messages.into_iter().map(|(role, content)| Message { role, content }).collect(),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = format!("{}/chat/completions", self.api_base);
|
||||||
|
|
||||||
|
let response = self.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response.text().await?;
|
||||||
|
anyhow::bail!("API请求失败: {} - {}", status, error_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream = response.bytes_stream();
|
||||||
|
|
||||||
|
Ok(Box::pin(async_stream::stream! {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
let mut buffer = String::new();
|
||||||
|
|
||||||
|
for await chunk_result in stream {
|
||||||
|
let chunk = chunk_result?;
|
||||||
|
buffer.push_str(&String::from_utf8_lossy(&chunk));
|
||||||
|
|
||||||
|
// 处理buffer中的完整行
|
||||||
|
while let Some(pos) = buffer.find("\n") {
|
||||||
|
let line = buffer[..pos].trim().to_string();
|
||||||
|
buffer.drain(..pos + 1);
|
||||||
|
|
||||||
|
if line.starts_with("data: ") {
|
||||||
|
let data = &line[6..];
|
||||||
|
|
||||||
|
if data == "[DONE]" {
|
||||||
|
yield Ok(StreamChunk::Done);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<StreamResponse>(data) {
|
||||||
|
Ok(response) => {
|
||||||
|
if let Some(choice) = response.choices.first() {
|
||||||
|
if let Some(ref finish_reason) = choice.finish_reason {
|
||||||
|
if finish_reason == "stop" {
|
||||||
|
yield Ok(StreamChunk::Done);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 将思维过程拆分成单个字符逐个输出
|
||||||
|
if let Some(ref reasoning) = choice.delta.reasoning_content {
|
||||||
|
for ch in reasoning.chars() {
|
||||||
|
yield Ok(StreamChunk::Reasoning(ch.to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 将回复内容拆分成单个字符逐个输出
|
||||||
|
if let Some(ref content) = choice.delta.content {
|
||||||
|
for ch in content.chars() {
|
||||||
|
yield Ok(StreamChunk::Content(ch.to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// 忽略解析错误,继续处理下一行
|
||||||
|
eprintln!("解析警告: {} - 数据: {}", e, data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式发送单条消息(无对话历史,兼容旧接口)
|
||||||
|
/// 使用默认模型
|
||||||
|
pub async fn chat_stream_simple(&self, message: &str) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
|
||||||
|
self.chat_stream(vec![("user".to_string(), message.to_string())]).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ChatClient {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new().expect("创建ChatClient失败")
|
||||||
|
}
|
||||||
|
}
|
||||||
238
src/interactive.rs
Normal file
238
src/interactive.rs
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
use crate::cache::{CacheManager, Session};
|
||||||
|
use crate::chat::ChatClient;
|
||||||
|
use anyhow::Result;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use std::io::{self, Write};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
/// 交互式对话模式
|
||||||
|
pub async fn run_interactive() -> Result<()> {
|
||||||
|
println!("=== 交互式AI对话 ===");
|
||||||
|
println!("输入 'exit' 或 'quit' 退出对话");
|
||||||
|
println!("按 Ctrl+C 也可以安全退出并保存会话\n");
|
||||||
|
|
||||||
|
// 创建缓存管理器
|
||||||
|
let cache_manager = CacheManager::new()?;
|
||||||
|
|
||||||
|
// 创建AI对话客户端
|
||||||
|
let client = ChatClient::new()?;
|
||||||
|
|
||||||
|
// 维护对话历史(使用Arc<Mutex<>>以便在信号处理函数中访问)
|
||||||
|
let session = Arc::new(Mutex::new(Session::new()));
|
||||||
|
let session_path = cache_manager.create_session_path();
|
||||||
|
let mut message_count: u32 = 0;
|
||||||
|
|
||||||
|
// 检查是否有上次会话
|
||||||
|
if let Some(latest_session_path) = cache_manager.get_latest_session_path() {
|
||||||
|
println!("发现上次的对话会话!");
|
||||||
|
println!("是否要恢复上次的对话? (y/n)");
|
||||||
|
|
||||||
|
print!("> ");
|
||||||
|
io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut input = String::new();
|
||||||
|
io::stdin().read_line(&mut input)?;
|
||||||
|
|
||||||
|
if input.trim().eq_ignore_ascii_case("y") || input.trim().eq_ignore_ascii_case("yes") {
|
||||||
|
match cache_manager.load_session(&latest_session_path) {
|
||||||
|
Ok(loaded_session) => {
|
||||||
|
let mut session_guard = session.lock().unwrap();
|
||||||
|
*session_guard = loaded_session;
|
||||||
|
let msg_count = session_guard.messages.len();
|
||||||
|
println!("✓ 已恢复 {} 条对话消息\n", msg_count);
|
||||||
|
drop(session_guard); // 释放锁
|
||||||
|
|
||||||
|
// 显示最近的几条消息作为提醒
|
||||||
|
let session_guard = session.lock().unwrap();
|
||||||
|
let recent_messages = session_guard.messages.len().saturating_sub(4);
|
||||||
|
if recent_messages < session_guard.messages.len() {
|
||||||
|
println!("最近的对话: ");
|
||||||
|
for (role, content) in &session_guard.messages[recent_messages..] {
|
||||||
|
let prefix = if role == "user" { "你" } else { "AI" };
|
||||||
|
let display_content = if content.chars().count() > 50 {
|
||||||
|
let truncated: String = content.chars().take(50).collect();
|
||||||
|
format!("{}...", truncated)
|
||||||
|
} else {
|
||||||
|
content.replace('\n', " ")
|
||||||
|
};
|
||||||
|
println!(" {}: {}", prefix, display_content);
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
drop(session_guard); // 释放锁
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("恢复会话失败: {}", e);
|
||||||
|
eprintln!("将开始新的对话会话\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
println!("将开始新的对话会话\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置Ctrl+C信号处理
|
||||||
|
let session_clone = Arc::clone(&session);
|
||||||
|
let session_path_clone = session_path.clone();
|
||||||
|
let cache_manager_clone = CacheManager::new()?;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
match signal::ctrl_c().await {
|
||||||
|
Ok(()) => {
|
||||||
|
println!("\n\n检测到 Ctrl+C,正在保存会话并退出...");
|
||||||
|
let session_guard = session_clone.lock().unwrap();
|
||||||
|
if !session_guard.messages.is_empty() {
|
||||||
|
match cache_manager_clone.save_session(&session_guard, &session_path_clone) {
|
||||||
|
Ok(_) => {
|
||||||
|
let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count();
|
||||||
|
let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count();
|
||||||
|
println!("✓ 会话已保存(共 {} 轮对话)", user_messages);
|
||||||
|
println!("文件: {:?}", session_path_clone);
|
||||||
|
}
|
||||||
|
Err(e) => eprintln!("保存会话失败: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("再见!");
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("设置信号处理器失败: {}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// 显示提示符
|
||||||
|
print!("\n你: ");
|
||||||
|
io::stdout().flush()?;
|
||||||
|
|
||||||
|
// 读取用户输入
|
||||||
|
let mut input = String::new();
|
||||||
|
io::stdin().read_line(&mut input)?;
|
||||||
|
|
||||||
|
// 去除首尾空白
|
||||||
|
let input = input.trim();
|
||||||
|
|
||||||
|
// 检查是否退出
|
||||||
|
if input.eq_ignore_ascii_case("exit") || input.eq_ignore_ascii_case("quit") {
|
||||||
|
println!("\n对话结束,再见!");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查输入是否为空
|
||||||
|
if input.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加用户消息到会话
|
||||||
|
{
|
||||||
|
let mut session_guard = session.lock().unwrap();
|
||||||
|
session_guard.add_message("user".to_string(), input.to_string());
|
||||||
|
}
|
||||||
|
message_count += 1;
|
||||||
|
|
||||||
|
// 每5条消息自动保存一次
|
||||||
|
if message_count % 5 == 0 {
|
||||||
|
let session_guard = session.lock().unwrap();
|
||||||
|
if let Err(e) = cache_manager.save_session(&session_guard, &session_path) {
|
||||||
|
eprintln!("警告: 自动保存失败: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息并获取流式回复
|
||||||
|
let messages = {
|
||||||
|
let session_guard = session.lock().unwrap();
|
||||||
|
session_guard.get_messages()
|
||||||
|
};
|
||||||
|
|
||||||
|
match client.chat_stream(messages).await {
|
||||||
|
Ok(mut stream) => {
|
||||||
|
let mut is_first_content = true;
|
||||||
|
let mut is_reasoning = true;
|
||||||
|
let mut assistant_response = String::new();
|
||||||
|
|
||||||
|
print!("\x1b[90mAI思维: ");
|
||||||
|
io::stdout().flush()?;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(crate::chat::StreamChunk::Reasoning(ch)) => {
|
||||||
|
if is_reasoning {
|
||||||
|
print!("{}", ch);
|
||||||
|
io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(crate::chat::StreamChunk::Content(ch)) => {
|
||||||
|
if is_first_content {
|
||||||
|
// 结束思维过程,切换到回复模式
|
||||||
|
println!("\x1b[0m");
|
||||||
|
println!();
|
||||||
|
print!("AI: ");
|
||||||
|
io::stdout().flush()?;
|
||||||
|
is_first_content = false;
|
||||||
|
is_reasoning = false;
|
||||||
|
}
|
||||||
|
print!("{}", ch);
|
||||||
|
io::stdout().flush()?;
|
||||||
|
assistant_response.push_str(&ch);
|
||||||
|
}
|
||||||
|
Ok(crate::chat::StreamChunk::Done) => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("\n\x1b[31m错误: {}\x1b[0m", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保输出完成
|
||||||
|
if is_reasoning {
|
||||||
|
// 如果没有收到内容,说明只有思维过程
|
||||||
|
println!("\x1b[0m");
|
||||||
|
} else {
|
||||||
|
println!("\x1b[0m");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加AI回复到会话
|
||||||
|
if !assistant_response.is_empty() {
|
||||||
|
let mut session_guard = session.lock().unwrap();
|
||||||
|
session_guard.add_message("assistant".to_string(), assistant_response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// AI回复完成后立即保存会话
|
||||||
|
{
|
||||||
|
let session_guard = session.lock().unwrap();
|
||||||
|
if let Err(e) = cache_manager.save_session(&session_guard, &session_path) {
|
||||||
|
eprintln!("\n警告: 保存会话失败: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("\n\x1b[31m发送消息失败: {}\x1b[0m", e);
|
||||||
|
// 出错时移除最后添加的用户消息
|
||||||
|
let mut session_guard = session.lock().unwrap();
|
||||||
|
session_guard.messages.pop();
|
||||||
|
message_count = message_count.saturating_sub(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 退出时保存会话
|
||||||
|
let session_guard = session.lock().unwrap();
|
||||||
|
if !session_guard.messages.is_empty() {
|
||||||
|
println!("\n正在保存对话会话...");
|
||||||
|
match cache_manager.save_session(&session_guard, &session_path) {
|
||||||
|
Ok(_) => println!("✓ 会话已保存到 {:?}", session_path),
|
||||||
|
Err(e) => eprintln!("保存会话失败: {}", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显示会话统计
|
||||||
|
let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count();
|
||||||
|
let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count();
|
||||||
|
println!("本次对话共 {} 轮(你: {}, AI: {})", user_messages, user_messages, ai_messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
34
src/main.rs
Normal file
34
src/main.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
mod cache;
|
||||||
|
mod chat;
|
||||||
|
mod interactive;
|
||||||
|
mod test;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
// 加载环境变量
|
||||||
|
dotenvy::dotenv().ok();
|
||||||
|
|
||||||
|
// 获取命令行参数
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
|
||||||
|
match args.get(1).map(|s| s.as_str()) {
|
||||||
|
Some("test") => {
|
||||||
|
// 运行测试模式
|
||||||
|
test::run_test().await?;
|
||||||
|
}
|
||||||
|
Some("interactive") | None => {
|
||||||
|
// 运行交互式对话模式(默认)
|
||||||
|
interactive::run_interactive().await?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
println!("用法: {} [test|interactive]", args[0]);
|
||||||
|
println!(" test - 运行测试模式(单次对话)");
|
||||||
|
println!(" interactive - 运行交互式对话模式(默认)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
50
src/test.rs
Normal file
50
src/test.rs
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
use crate::chat::ChatClient;
|
||||||
|
use anyhow::Result;
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
/// 测试AI对话功能
|
||||||
|
pub async fn run_test() -> Result<()> {
|
||||||
|
println!("=== 测试AI对话功能 ===\n");
|
||||||
|
|
||||||
|
// 创建AI对话客户端
|
||||||
|
let client = ChatClient::new()?;
|
||||||
|
|
||||||
|
// 创建消息列表(只包含一条用户消息)
|
||||||
|
let messages = vec![("user".to_string(), "你好!请介绍一下你自己。".to_string())];
|
||||||
|
|
||||||
|
// 发送消息并获取流式回复
|
||||||
|
let mut stream = client.chat_stream(messages).await?;
|
||||||
|
|
||||||
|
let mut is_first_content = true;
|
||||||
|
let mut is_reasoning = true;
|
||||||
|
|
||||||
|
print!("\x1b[90m思维过程: ");
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result? {
|
||||||
|
crate::chat::StreamChunk::Reasoning(text) => {
|
||||||
|
if is_reasoning {
|
||||||
|
print!("{}", text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
crate::chat::StreamChunk::Content(text) => {
|
||||||
|
if is_first_content {
|
||||||
|
println!("\x1b[0m");
|
||||||
|
println!();
|
||||||
|
print!("AI回复: ");
|
||||||
|
is_first_content = false;
|
||||||
|
is_reasoning = false;
|
||||||
|
}
|
||||||
|
print!("{}", text);
|
||||||
|
}
|
||||||
|
crate::chat::StreamChunk::Done => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!();
|
||||||
|
println!("\n=== 测试完成 ===");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user