mirror of
https://github.com/awfufu/go-hurobot.git
synced 2026-03-01 05:29:43 +08:00
refactor: centralize LLM supplier config in DB and auto-switch model
This commit is contained in:
53
cmds/llm.go
53
cmds/llm.go
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
func cmd_llm(c *qbot.Client, msg *qbot.Message, args *ArgsList) {
|
||||
if args.Size < 2 {
|
||||
c.SendMsg(msg, "Usage:\nllm prompt [新提示词]\nllm max-history [能看见的历史消息数]\nllm enable/disable\nllm status")
|
||||
c.SendMsg(msg, "Usage:\nllm prompt [新提示词]\nllm max-history [能看见的历史消息数]\nllm enable/disable\nllm status\nllm model [模型]\nllm supplier [API供应商]")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,12 +36,12 @@ func cmd_llm(c *qbot.Client, msg *qbot.Message, args *ArgsList) {
|
||||
Supplier string
|
||||
Model string
|
||||
}{
|
||||
Prompt: "你是一个群聊机器人,请你陪伴群友们聊天,注意请不要使用Markdown语法。",
|
||||
Prompt: "",
|
||||
MaxHistory: 200,
|
||||
Enabled: true,
|
||||
Debug: false,
|
||||
Supplier: "siliconflow",
|
||||
Model: "deepseek-ai/DeepSeek-V2.5",
|
||||
Model: "deepseek-ai/DeepSeek-V3",
|
||||
}
|
||||
qbot.PsqlDB.Table("group_llm_configs").Create(map[string]any{
|
||||
"group_id": msg.GroupID,
|
||||
@@ -49,8 +49,8 @@ func cmd_llm(c *qbot.Client, msg *qbot.Message, args *ArgsList) {
|
||||
"max_history": llmConfig.MaxHistory,
|
||||
"enabled": llmConfig.Enabled,
|
||||
"debug": llmConfig.Debug,
|
||||
"supplier": "siliconflow",
|
||||
"model": "deepseek-ai/DeepSeek-V2.5",
|
||||
"supplier": llmConfig.Supplier,
|
||||
"model": llmConfig.Model,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -182,6 +182,49 @@ func cmd_llm(c *qbot.Client, msg *qbot.Message, args *ArgsList) {
|
||||
}
|
||||
}
|
||||
|
||||
case "supplier":
|
||||
if args.Size == 2 {
|
||||
c.SendMsg(msg, fmt.Sprintf("supplier: %s", llmConfig.Supplier))
|
||||
} else {
|
||||
newSupplier := args.Contents[2]
|
||||
|
||||
var exists int64
|
||||
qbot.PsqlDB.Table("suppliers").
|
||||
Where("name = ?", newSupplier).
|
||||
Count(&exists)
|
||||
if exists == 0 {
|
||||
c.SendMsg(msg, fmt.Sprintf("unknown supplier: %s", newSupplier))
|
||||
return
|
||||
}
|
||||
|
||||
var sup struct {
|
||||
DefaultModel string `psql:"default_model"`
|
||||
}
|
||||
qbot.PsqlDB.Table("suppliers").
|
||||
Select("default_model").
|
||||
Where("name = ?", newSupplier).
|
||||
Scan(&sup)
|
||||
|
||||
// Update supplier
|
||||
err := qbot.PsqlDB.Table("group_llm_configs").
|
||||
Where("group_id = ?", msg.GroupID).
|
||||
Update("supplier", newSupplier).Error
|
||||
if err != nil {
|
||||
c.SendMsg(msg, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-switch model to supplier default if provided
|
||||
if strings.TrimSpace(sup.DefaultModel) != "" {
|
||||
_ = qbot.PsqlDB.Table("group_llm_configs").
|
||||
Where("group_id = ?", msg.GroupID).
|
||||
Update("model", sup.DefaultModel).Error
|
||||
c.SendMsg(msg, fmt.Sprintf("supplier updated to %s, model -> %s", newSupplier, sup.DefaultModel))
|
||||
} else {
|
||||
c.SendMsg(msg, fmt.Sprintf("supplier updated to %s", newSupplier))
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
c.SendMsg(msg, fmt.Sprintf("Unrecognized parameter >>%s<<", args.Contents[1]))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,14 @@ CREATE TABLE users (
|
||||
PRIMARY KEY ("user_id")
|
||||
);
|
||||
|
||||
CREATE TABLE suppliers (
|
||||
"name" TEXT NOT NULL,
|
||||
"base_url" TEXT NOT NULL,
|
||||
"api_key" TEXT,
|
||||
"default_model" TEXT,
|
||||
PRIMARY KEY ("name")
|
||||
);
|
||||
|
||||
CREATE TABLE messages (
|
||||
"msg_id" BIGINT NOT NULL,
|
||||
"user_id" BIGINT NOT NULL,
|
||||
@@ -29,7 +37,8 @@ CREATE TABLE group_llm_configs (
|
||||
"debug" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"supplier" TEXT,
|
||||
"model" TEXT,
|
||||
PRIMARY KEY ("group_id")
|
||||
PRIMARY KEY ("group_id"),
|
||||
FOREIGN KEY ("supplier") REFERENCES suppliers(name)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_messages_covering ON messages("group_id", "is_cmd", "time" DESC, "user_id", "content", "msg_id");
|
||||
@@ -54,3 +63,6 @@ CREATE TABLE group_rcon_configs (
|
||||
"enabled" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
PRIMARY KEY ("group_id")
|
||||
);
|
||||
|
||||
INSERT INTO suppliers ("name", "base_url", "api_key", "default_model") VALUES
|
||||
('siliconflow', 'https://api.siliconflow.cn/v1', '', 'deepseek-ai/DeepSeek-V3');
|
||||
|
||||
34
llm/llm.go
34
llm/llm.go
@@ -17,17 +17,33 @@ import (
|
||||
func SendLLMRequest(supplier string, messages []openai.ChatCompletionMessageParamUnion, model string, temperature float64) (*openai.ChatCompletion, error) {
|
||||
var client *openai.Client
|
||||
|
||||
switch supplier {
|
||||
case "siliconflow":
|
||||
clientVal := openai.NewClient(
|
||||
option.WithBaseURL("https://api.siliconflow.cn/v1"),
|
||||
option.WithAPIKey(config.ApiKey),
|
||||
)
|
||||
client = &clientVal
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid supplier: %s", supplier)
|
||||
var supplierConf struct {
|
||||
BaseURL string `psql:"base_url"`
|
||||
APIKey string `psql:"api_key"`
|
||||
}
|
||||
|
||||
err := qbot.PsqlDB.Table("suppliers").
|
||||
Select("base_url, api_key").
|
||||
Where("name = ?", supplier).
|
||||
First(&supplierConf).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("supplier not found: %s", supplier)
|
||||
}
|
||||
|
||||
apiKey := supplierConf.APIKey
|
||||
if apiKey == "" {
|
||||
apiKey = config.ApiKey
|
||||
}
|
||||
if supplierConf.BaseURL == "" {
|
||||
return nil, fmt.Errorf("supplier %s base_url is empty", supplier)
|
||||
}
|
||||
|
||||
clientVal := openai.NewClient(
|
||||
option.WithBaseURL(supplierConf.BaseURL),
|
||||
option.WithAPIKey(apiKey),
|
||||
)
|
||||
client = &clientVal
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
|
||||
|
||||
Reference in New Issue
Block a user