feat: basic chat client with ai
connector part almostly finished.
This commit is contained in:
11
.gitignore
vendored
11
.gitignore
vendored
@@ -16,3 +16,14 @@ target/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Cache directory for conversation sessions
|
||||
.cache/
|
||||
|
||||
# Added by cargo
|
||||
Cargo.lock
|
||||
/target
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.*
|
||||
|
||||
15
Cargo.toml
Normal file
15
Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "user-content-agent"
|
||||
version = "0.0.1"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false }
|
||||
tokio = { version = "1", features = ["full", "signal"] }
|
||||
anyhow = "1.0"
|
||||
dotenvy = "0.15"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
futures = "0.3"
|
||||
async-stream = "0.3"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
145
src/cache.rs
Normal file
145
src/cache.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use anyhow::{Result, Context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// 会话数据
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Session {
|
||||
pub messages: Vec<(String, String)>, // (role, content)
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// 创建新会话
|
||||
pub fn new() -> Self {
|
||||
let now = chrono::Local::now().to_rfc3339();
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
created_at: now.clone(),
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// 添加消息
|
||||
pub fn add_message(&mut self, role: String, content: String) {
|
||||
self.messages.push((role, content));
|
||||
self.updated_at = chrono::Local::now().to_rfc3339();
|
||||
}
|
||||
|
||||
/// 获取消息列表
|
||||
pub fn get_messages(&self) -> Vec<(String, String)> {
|
||||
self.messages.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// 缓存管理器
|
||||
pub struct CacheManager {
|
||||
cache_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl CacheManager {
|
||||
/// 创建缓存管理器
|
||||
pub fn new() -> Result<Self> {
|
||||
let cache_dir = PathBuf::from(".cache/sessions");
|
||||
|
||||
// 创建缓存目录
|
||||
if !cache_dir.exists() {
|
||||
fs::create_dir_all(&cache_dir)
|
||||
.context("创建缓存目录失败")?;
|
||||
}
|
||||
|
||||
Ok(Self { cache_dir })
|
||||
}
|
||||
|
||||
/// 获取最新的会话文件路径
|
||||
pub fn get_latest_session_path(&self) -> Option<PathBuf> {
|
||||
if !self.cache_dir.exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut entries: Vec<_> = fs::read_dir(&self.cache_dir)
|
||||
.ok()?
|
||||
.filter_map(|entry| entry.ok())
|
||||
.filter(|entry| {
|
||||
entry.path().extension()
|
||||
.map(|ext| ext == "json")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 按修改时间排序(最新的在前)
|
||||
entries.sort_by_key(|entry| {
|
||||
entry.metadata()
|
||||
.and_then(|m| m.modified())
|
||||
.unwrap_or(std::time::SystemTime::UNIX_EPOCH)
|
||||
});
|
||||
entries.reverse();
|
||||
|
||||
entries.first().map(|entry| entry.path())
|
||||
}
|
||||
|
||||
/// 创建新的会话文件路径
|
||||
pub fn create_session_path(&self) -> PathBuf {
|
||||
let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
|
||||
self.cache_dir.join(format!("session_{}.json", timestamp))
|
||||
}
|
||||
|
||||
/// 保存会话
|
||||
pub fn save_session(&self, session: &Session, path: &Path) -> Result<()> {
|
||||
let json = serde_json::to_string_pretty(session)
|
||||
.context("序列化会话失败")?;
|
||||
|
||||
fs::write(path, json)
|
||||
.context("保存会话文件失败")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 加载会话
|
||||
pub fn load_session(&self, path: &Path) -> Result<Session> {
|
||||
let json = fs::read_to_string(path)
|
||||
.context("读取会话文件失败")?;
|
||||
|
||||
let session: Session = serde_json::from_str(&json)
|
||||
.context("解析会话文件失败")?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// 获取所有会话文件
|
||||
pub fn list_sessions(&self) -> Result<Vec<(String, PathBuf)>> {
|
||||
if !self.cache_dir.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut sessions = Vec::new();
|
||||
|
||||
for entry in fs::read_dir(&self.cache_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().map(|ext| ext == "json").unwrap_or(false) {
|
||||
if let Ok(metadata) = entry.metadata() {
|
||||
if let Ok(modified) = metadata.modified() {
|
||||
let datetime = chrono::DateTime::<chrono::Local>::from(modified);
|
||||
sessions.push((datetime.format("%Y-%m-%d %H:%M:%S").to_string(), path));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 按时间排序(最新的在前)
|
||||
sessions.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
|
||||
Ok(sessions)
|
||||
}
|
||||
}
|
||||
|
||||
// 为Session实现Default trait
|
||||
impl Default for Session {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
183
src/chat.rs
Normal file
183
src/chat.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use anyhow::{Result, Context};
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::pin::Pin;
|
||||
|
||||
/// 流式输出块
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamChunk {
|
||||
Reasoning(String),
|
||||
Content(String),
|
||||
Done,
|
||||
}
|
||||
|
||||
/// AI对话客户端,支持OpenAI兼容的API
|
||||
pub struct ChatClient {
|
||||
api_key: String,
|
||||
api_base: String,
|
||||
default_model: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
/// 消息结构
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
/// 请求结构
|
||||
#[derive(Serialize, Debug)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
/// 流式响应结构
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamResponse {
|
||||
choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamChoice {
|
||||
delta: StreamDelta,
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamDelta {
|
||||
content: Option<String>,
|
||||
#[serde(rename = "reasoning_content")]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatClient {
|
||||
/// 创建新的对话客户端
|
||||
/// 从环境变量读取配置:
|
||||
/// - OPENAI_API_KEY: API密钥(必需)
|
||||
/// - OPENAI_API_BASE: API地址(可选,默认:https://api.openai.com/v1)
|
||||
/// - OPENAI_MODEL: 默认模型(可选,默认:gpt-3.5-turbo)
|
||||
pub fn new() -> Result<Self> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY")
|
||||
.context("未设置环境变量 OPENAI_API_KEY")?;
|
||||
|
||||
let api_base = std::env::var("OPENAI_API_BASE")
|
||||
.unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
|
||||
|
||||
let default_model = std::env::var("OPENAI_MODEL")
|
||||
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
Ok(Self {
|
||||
api_key,
|
||||
api_base,
|
||||
default_model,
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
/// 流式发送消息(带对话历史)
|
||||
/// 使用默认模型
|
||||
pub async fn chat_stream(&self, messages: Vec<(String, String)>) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
|
||||
self.chat_stream_with_model(&self.default_model, messages).await
|
||||
}
|
||||
|
||||
/// 流式发送消息(带对话历史)
|
||||
/// 指定使用的模型
|
||||
/// messages: Vec<(role, content)>,role可以是"user"或"assistant"
|
||||
pub async fn chat_stream_with_model(&self, model: &str, messages: Vec<(String, String)>) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: messages.into_iter().map(|(role, content)| Message { role, content }).collect(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let url = format!("{}/chat/completions", self.api_base);
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await?;
|
||||
anyhow::bail!("API请求失败: {} - {}", status, error_text);
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
|
||||
Ok(Box::pin(async_stream::stream! {
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut buffer = String::new();
|
||||
|
||||
for await chunk_result in stream {
|
||||
let chunk = chunk_result?;
|
||||
buffer.push_str(&String::from_utf8_lossy(&chunk));
|
||||
|
||||
// 处理buffer中的完整行
|
||||
while let Some(pos) = buffer.find("\n") {
|
||||
let line = buffer[..pos].trim().to_string();
|
||||
buffer.drain(..pos + 1);
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..];
|
||||
|
||||
if data == "[DONE]" {
|
||||
yield Ok(StreamChunk::Done);
|
||||
continue;
|
||||
}
|
||||
|
||||
match serde_json::from_str::<StreamResponse>(data) {
|
||||
Ok(response) => {
|
||||
if let Some(choice) = response.choices.first() {
|
||||
if let Some(ref finish_reason) = choice.finish_reason {
|
||||
if finish_reason == "stop" {
|
||||
yield Ok(StreamChunk::Done);
|
||||
}
|
||||
} else {
|
||||
// 将思维过程拆分成单个字符逐个输出
|
||||
if let Some(ref reasoning) = choice.delta.reasoning_content {
|
||||
for ch in reasoning.chars() {
|
||||
yield Ok(StreamChunk::Reasoning(ch.to_string()));
|
||||
}
|
||||
}
|
||||
// 将回复内容拆分成单个字符逐个输出
|
||||
if let Some(ref content) = choice.delta.content {
|
||||
for ch in content.chars() {
|
||||
yield Ok(StreamChunk::Content(ch.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// 忽略解析错误,继续处理下一行
|
||||
eprintln!("解析警告: {} - 数据: {}", e, data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// 流式发送单条消息(无对话历史,兼容旧接口)
|
||||
/// 使用默认模型
|
||||
pub async fn chat_stream_simple(&self, message: &str) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>> {
|
||||
self.chat_stream(vec![("user".to_string(), message.to_string())]).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChatClient {
|
||||
fn default() -> Self {
|
||||
Self::new().expect("创建ChatClient失败")
|
||||
}
|
||||
}
|
||||
238
src/interactive.rs
Normal file
238
src/interactive.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use crate::cache::{CacheManager, Session};
|
||||
use crate::chat::ChatClient;
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt;
|
||||
use std::io::{self, Write};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::signal;
|
||||
|
||||
/// 交互式对话模式
|
||||
pub async fn run_interactive() -> Result<()> {
|
||||
println!("=== 交互式AI对话 ===");
|
||||
println!("输入 'exit' 或 'quit' 退出对话");
|
||||
println!("按 Ctrl+C 也可以安全退出并保存会话\n");
|
||||
|
||||
// 创建缓存管理器
|
||||
let cache_manager = CacheManager::new()?;
|
||||
|
||||
// 创建AI对话客户端
|
||||
let client = ChatClient::new()?;
|
||||
|
||||
// 维护对话历史(使用Arc<Mutex<>>以便在信号处理函数中访问)
|
||||
let session = Arc::new(Mutex::new(Session::new()));
|
||||
let session_path = cache_manager.create_session_path();
|
||||
let mut message_count: u32 = 0;
|
||||
|
||||
// 检查是否有上次会话
|
||||
if let Some(latest_session_path) = cache_manager.get_latest_session_path() {
|
||||
println!("发现上次的对话会话!");
|
||||
println!("是否要恢复上次的对话? (y/n)");
|
||||
|
||||
print!("> ");
|
||||
io::stdout().flush()?;
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input)?;
|
||||
|
||||
if input.trim().eq_ignore_ascii_case("y") || input.trim().eq_ignore_ascii_case("yes") {
|
||||
match cache_manager.load_session(&latest_session_path) {
|
||||
Ok(loaded_session) => {
|
||||
let mut session_guard = session.lock().unwrap();
|
||||
*session_guard = loaded_session;
|
||||
let msg_count = session_guard.messages.len();
|
||||
println!("✓ 已恢复 {} 条对话消息\n", msg_count);
|
||||
drop(session_guard); // 释放锁
|
||||
|
||||
// 显示最近的几条消息作为提醒
|
||||
let session_guard = session.lock().unwrap();
|
||||
let recent_messages = session_guard.messages.len().saturating_sub(4);
|
||||
if recent_messages < session_guard.messages.len() {
|
||||
println!("最近的对话: ");
|
||||
for (role, content) in &session_guard.messages[recent_messages..] {
|
||||
let prefix = if role == "user" { "你" } else { "AI" };
|
||||
let display_content = if content.chars().count() > 50 {
|
||||
let truncated: String = content.chars().take(50).collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
content.replace('\n', " ")
|
||||
};
|
||||
println!(" {}: {}", prefix, display_content);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
drop(session_guard); // 释放锁
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("恢复会话失败: {}", e);
|
||||
eprintln!("将开始新的对话会话\n");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("将开始新的对话会话\n");
|
||||
}
|
||||
}
|
||||
|
||||
// 设置Ctrl+C信号处理
|
||||
let session_clone = Arc::clone(&session);
|
||||
let session_path_clone = session_path.clone();
|
||||
let cache_manager_clone = CacheManager::new()?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
match signal::ctrl_c().await {
|
||||
Ok(()) => {
|
||||
println!("\n\n检测到 Ctrl+C,正在保存会话并退出...");
|
||||
let session_guard = session_clone.lock().unwrap();
|
||||
if !session_guard.messages.is_empty() {
|
||||
match cache_manager_clone.save_session(&session_guard, &session_path_clone) {
|
||||
Ok(_) => {
|
||||
let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count();
|
||||
let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count();
|
||||
println!("✓ 会话已保存(共 {} 轮对话)", user_messages);
|
||||
println!("文件: {:?}", session_path_clone);
|
||||
}
|
||||
Err(e) => eprintln!("保存会话失败: {}", e),
|
||||
}
|
||||
}
|
||||
println!("再见!");
|
||||
std::process::exit(0);
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("设置信号处理器失败: {}", err);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
// 显示提示符
|
||||
print!("\n你: ");
|
||||
io::stdout().flush()?;
|
||||
|
||||
// 读取用户输入
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input)?;
|
||||
|
||||
// 去除首尾空白
|
||||
let input = input.trim();
|
||||
|
||||
// 检查是否退出
|
||||
if input.eq_ignore_ascii_case("exit") || input.eq_ignore_ascii_case("quit") {
|
||||
println!("\n对话结束,再见!");
|
||||
break;
|
||||
}
|
||||
|
||||
// 检查输入是否为空
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 添加用户消息到会话
|
||||
{
|
||||
let mut session_guard = session.lock().unwrap();
|
||||
session_guard.add_message("user".to_string(), input.to_string());
|
||||
}
|
||||
message_count += 1;
|
||||
|
||||
// 每5条消息自动保存一次
|
||||
if message_count % 5 == 0 {
|
||||
let session_guard = session.lock().unwrap();
|
||||
if let Err(e) = cache_manager.save_session(&session_guard, &session_path) {
|
||||
eprintln!("警告: 自动保存失败: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 发送消息并获取流式回复
|
||||
let messages = {
|
||||
let session_guard = session.lock().unwrap();
|
||||
session_guard.get_messages()
|
||||
};
|
||||
|
||||
match client.chat_stream(messages).await {
|
||||
Ok(mut stream) => {
|
||||
let mut is_first_content = true;
|
||||
let mut is_reasoning = true;
|
||||
let mut assistant_response = String::new();
|
||||
|
||||
print!("\x1b[90mAI思维: ");
|
||||
io::stdout().flush()?;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(crate::chat::StreamChunk::Reasoning(ch)) => {
|
||||
if is_reasoning {
|
||||
print!("{}", ch);
|
||||
io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
Ok(crate::chat::StreamChunk::Content(ch)) => {
|
||||
if is_first_content {
|
||||
// 结束思维过程,切换到回复模式
|
||||
println!("\x1b[0m");
|
||||
println!();
|
||||
print!("AI: ");
|
||||
io::stdout().flush()?;
|
||||
is_first_content = false;
|
||||
is_reasoning = false;
|
||||
}
|
||||
print!("{}", ch);
|
||||
io::stdout().flush()?;
|
||||
assistant_response.push_str(&ch);
|
||||
}
|
||||
Ok(crate::chat::StreamChunk::Done) => {
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n\x1b[31m错误: {}\x1b[0m", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 确保输出完成
|
||||
if is_reasoning {
|
||||
// 如果没有收到内容,说明只有思维过程
|
||||
println!("\x1b[0m");
|
||||
} else {
|
||||
println!("\x1b[0m");
|
||||
}
|
||||
|
||||
// 添加AI回复到会话
|
||||
if !assistant_response.is_empty() {
|
||||
let mut session_guard = session.lock().unwrap();
|
||||
session_guard.add_message("assistant".to_string(), assistant_response);
|
||||
}
|
||||
|
||||
// AI回复完成后立即保存会话
|
||||
{
|
||||
let session_guard = session.lock().unwrap();
|
||||
if let Err(e) = cache_manager.save_session(&session_guard, &session_path) {
|
||||
eprintln!("\n警告: 保存会话失败: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n\x1b[31m发送消息失败: {}\x1b[0m", e);
|
||||
// 出错时移除最后添加的用户消息
|
||||
let mut session_guard = session.lock().unwrap();
|
||||
session_guard.messages.pop();
|
||||
message_count = message_count.saturating_sub(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 退出时保存会话
|
||||
let session_guard = session.lock().unwrap();
|
||||
if !session_guard.messages.is_empty() {
|
||||
println!("\n正在保存对话会话...");
|
||||
match cache_manager.save_session(&session_guard, &session_path) {
|
||||
Ok(_) => println!("✓ 会话已保存到 {:?}", session_path),
|
||||
Err(e) => eprintln!("保存会话失败: {}", e),
|
||||
}
|
||||
|
||||
// 显示会话统计
|
||||
let user_messages = session_guard.messages.iter().filter(|(role, _)| role == "user").count();
|
||||
let ai_messages = session_guard.messages.iter().filter(|(role, _)| role == "assistant").count();
|
||||
println!("本次对话共 {} 轮(你: {}, AI: {})", user_messages, user_messages, ai_messages);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
34
src/main.rs
Normal file
34
src/main.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
mod cache;
|
||||
mod chat;
|
||||
mod interactive;
|
||||
mod test;
|
||||
|
||||
use anyhow::Result;
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// 加载环境变量
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// 获取命令行参数
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
match args.get(1).map(|s| s.as_str()) {
|
||||
Some("test") => {
|
||||
// 运行测试模式
|
||||
test::run_test().await?;
|
||||
}
|
||||
Some("interactive") | None => {
|
||||
// 运行交互式对话模式(默认)
|
||||
interactive::run_interactive().await?;
|
||||
}
|
||||
_ => {
|
||||
println!("用法: {} [test|interactive]", args[0]);
|
||||
println!(" test - 运行测试模式(单次对话)");
|
||||
println!(" interactive - 运行交互式对话模式(默认)");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
50
src/test.rs
Normal file
50
src/test.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use crate::chat::ChatClient;
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt;
|
||||
|
||||
/// 测试AI对话功能
|
||||
pub async fn run_test() -> Result<()> {
|
||||
println!("=== 测试AI对话功能 ===\n");
|
||||
|
||||
// 创建AI对话客户端
|
||||
let client = ChatClient::new()?;
|
||||
|
||||
// 创建消息列表(只包含一条用户消息)
|
||||
let messages = vec![("user".to_string(), "你好!请介绍一下你自己。".to_string())];
|
||||
|
||||
// 发送消息并获取流式回复
|
||||
let mut stream = client.chat_stream(messages).await?;
|
||||
|
||||
let mut is_first_content = true;
|
||||
let mut is_reasoning = true;
|
||||
|
||||
print!("\x1b[90m思维过程: ");
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result? {
|
||||
crate::chat::StreamChunk::Reasoning(text) => {
|
||||
if is_reasoning {
|
||||
print!("{}", text);
|
||||
}
|
||||
}
|
||||
crate::chat::StreamChunk::Content(text) => {
|
||||
if is_first_content {
|
||||
println!("\x1b[0m");
|
||||
println!();
|
||||
print!("AI回复: ");
|
||||
is_first_content = false;
|
||||
is_reasoning = false;
|
||||
}
|
||||
print!("{}", text);
|
||||
}
|
||||
crate::chat::StreamChunk::Done => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
println!("\n=== 测试完成 ===");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user