```
feat(llm): 引入多LLM连接器支持及配置管理 - 添加OpenAI和Ollama的Spring AI Starter依赖 - 实现LLM连接器接口及工厂模式,支持动态加载不同LLM配置 - 新增LLMController提供统一REST接口用于文本生成和流式输出- 配置文件中增加ai相关配置项,支持多模型参数设置 - 更新pom.xml引入Spring Milestones和Snapshots仓库- 升级spring-ai版本至1.0.0-M6 - 移除旧有的DeepSeek和Ollama直接依赖,改为通过连接器调用 - 注释掉部分未使用的测试代码与配置说明 ```
This commit is contained in:
parent
7e43fd59a4
commit
46d676a8c1
20
pom.xml
20
pom.xml
|
@ -28,6 +28,22 @@
|
||||||
<enabled>false</enabled>
|
<enabled>false</enabled>
|
||||||
</snapshots>
|
</snapshots>
|
||||||
</repository>
|
</repository>
|
||||||
|
<repository>
|
||||||
|
<id>spring-milestones</id>
|
||||||
|
<name>Spring Milestones</name>
|
||||||
|
<url>https://repo.spring.io/milestone</url>
|
||||||
|
<snapshots>
|
||||||
|
<enabled>false</enabled>
|
||||||
|
</snapshots>
|
||||||
|
</repository>
|
||||||
|
<repository>
|
||||||
|
<id>spring-snapshots</id>
|
||||||
|
<name>Spring Snapshots</name>
|
||||||
|
<url>https://repo.spring.io/snapshot</url>
|
||||||
|
<releases>
|
||||||
|
<enabled>false</enabled>
|
||||||
|
</releases>
|
||||||
|
</repository>
|
||||||
</repositories>
|
</repositories>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
|
@ -35,7 +51,7 @@
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<maven.compiler.source>17</maven.compiler.source>
|
<maven.compiler.source>17</maven.compiler.source>
|
||||||
<maven.compiler.target>17</maven.compiler.target>
|
<maven.compiler.target>17</maven.compiler.target>
|
||||||
<spring-ai.version>1.0.0</spring-ai.version>
|
<spring-ai.version>1.0.0-M6</spring-ai.version>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
|
@ -172,6 +188,8 @@
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>dev</id>
|
<id>dev</id>
|
||||||
|
|
|
@ -12,19 +12,6 @@
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.ai</groupId>
|
|
||||||
<artifactId>spring-ai-ollama</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.ai</groupId>
|
|
||||||
<artifactId>spring-ai-starter-model-deepseek</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.boot</groupId>
|
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-test</artifactId>
|
<artifactId>spring-boot-starter-test</artifactId>
|
||||||
|
@ -143,4 +130,6 @@
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
/**
|
|
||||||
* 1. 用于管理引入的Jar所需的资源启动或者初始化处理
|
|
||||||
* 2. 如果有AOP切面,可以再建一个aop包,来写切面逻辑
|
|
||||||
*/
|
|
||||||
package com.touka.config;
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ thread:
|
||||||
block-queue-size: 5000
|
block-queue-size: 5000
|
||||||
policy: CallerRunsPolicy
|
policy: CallerRunsPolicy
|
||||||
|
|
||||||
# 数据库配置;启动时配置数据库资源信息
|
# 数据库配置
|
||||||
spring:
|
spring:
|
||||||
datasource:
|
datasource:
|
||||||
username: root
|
username: root
|
||||||
|
@ -21,25 +21,45 @@ spring:
|
||||||
driver-class-name: com.mysql.cj.jdbc.Driver
|
driver-class-name: com.mysql.cj.jdbc.Driver
|
||||||
hikari:
|
hikari:
|
||||||
pool-name: Retail_HikariCP
|
pool-name: Retail_HikariCP
|
||||||
minimum-idle: 15 #最小空闲连接数量
|
minimum-idle: 15
|
||||||
idle-timeout: 180000 #空闲连接存活最大时间,默认600000(10分钟)
|
idle-timeout: 180000
|
||||||
maximum-pool-size: 25 #连接池最大连接数,默认是10
|
maximum-pool-size: 25
|
||||||
auto-commit: true #此属性控制从池返回的连接的默认自动提交行为,默认值:true
|
auto-commit: true
|
||||||
max-lifetime: 1800000 #此属性控制池中连接的最长生命周期,值0表示无限生命周期,默认1800000即30分钟
|
max-lifetime: 1800000
|
||||||
connection-timeout: 30000 #数据库连接超时时间,默认30秒,即30000
|
connection-timeout: 30000
|
||||||
connection-test-query: SELECT 1
|
connection-test-query: SELECT 1
|
||||||
type: com.zaxxer.hikari.HikariDataSource
|
type: com.zaxxer.hikari.HikariDataSource
|
||||||
|
|
||||||
ai:
|
ai:
|
||||||
deepseek:
|
deepseek:
|
||||||
api-key: your-api-key
|
api-key: your-api-key
|
||||||
base-url: https://api.deepseek.com # DeepSeek 的请求 URL, 可不填,默认值为 api.deepseek.com
|
base-url: https://api.deepseek.com
|
||||||
chat:
|
chat:
|
||||||
options:
|
options:
|
||||||
model: deepseek-reasoner # 使用深度思考模型
|
model: deepseek-reasoner
|
||||||
temperature: 0.8 # 温度值
|
temperature: 0.8
|
||||||
|
openai: # 注意:这个应该在spring: ai: 下面,而不是与spring平级
|
||||||
|
api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
|
||||||
|
base-url: https://api.chatanywhere.tech # 您使用的是第三方API
|
||||||
|
chat:
|
||||||
|
options:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.7
|
||||||
|
max-tokens: 2048
|
||||||
|
|
||||||
# MyBatis 配置【如需使用记得打开】
|
llm:
|
||||||
|
# 默认连接器类型
|
||||||
|
default-connector: openai
|
||||||
|
# 各连接器详细配置
|
||||||
|
connectors:
|
||||||
|
# OpenAI连接器配置
|
||||||
|
openai:
|
||||||
|
api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
|
||||||
|
base-url: https://api.chatanywhere.tech
|
||||||
|
timeout-ms: 30000
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
# MyBatis 配置
|
||||||
#mybatis:
|
#mybatis:
|
||||||
# mapper-locations: classpath:/mybatis/mapper/*.xml
|
# mapper-locations: classpath:/mybatis/mapper/*.xml
|
||||||
# config-location: classpath:/mybatis/config/mybatis-config.xml
|
# config-location: classpath:/mybatis/config/mybatis-config.xml
|
||||||
|
@ -49,4 +69,3 @@ logging:
|
||||||
level:
|
level:
|
||||||
root: info
|
root: info
|
||||||
config: classpath:logback-spring.xml
|
config: classpath:logback-spring.xml
|
||||||
|
|
||||||
|
|
|
@ -1,81 +1,82 @@
|
||||||
package com.touka.test;
|
//package com.touka.test;
|
||||||
|
//
|
||||||
import jakarta.annotation.Resource;
|
//import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
//import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.Test;
|
//import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
//import org.junit.runner.RunWith;
|
||||||
import org.springframework.ai.chat.model.ChatResponse;
|
//import org.springframework.ai.chat.model.ChatResponse;
|
||||||
|
//
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
//import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
//import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||||
import org.springframework.ai.deepseek.DeepSeekChatModel;
|
//
|
||||||
|
//
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
//import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||||
|
//import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
import org.springframework.boot.test.context.SpringBootTest;
|
//
|
||||||
import org.springframework.test.context.junit4.SpringRunner;
|
//import org.springframework.boot.test.context.SpringBootTest;
|
||||||
|
//import org.springframework.test.context.junit4.SpringRunner;
|
||||||
import reactor.core.publisher.Flux;
|
//
|
||||||
|
//import reactor.core.publisher.Flux;
|
||||||
@Slf4j
|
//
|
||||||
@RunWith(SpringRunner.class)
|
//@Slf4j
|
||||||
@SpringBootTest
|
//@RunWith(SpringRunner.class)
|
||||||
public class ApiTest {
|
//@SpringBootTest
|
||||||
|
//public class ApiTest {
|
||||||
@Resource
|
//
|
||||||
private DeepSeekChatModel chatModel;
|
// @Resource
|
||||||
|
// private DeepSeekChatModel chatModel;
|
||||||
/**
|
//
|
||||||
* 测试同步生成响应
|
// /**
|
||||||
*/
|
// * 测试同步生成响应
|
||||||
@Test
|
// */
|
||||||
public void testGenerate() {
|
// @Test
|
||||||
String message = "Tell me a joke";
|
// public void testGenerate() {
|
||||||
|
// String message = "Tell me a joke";
|
||||||
// 方式1: 直接传入字符串
|
//
|
||||||
String response1 = chatModel.call(message);
|
// // 方式1: 直接传入字符串
|
||||||
System.out.println("Response1: " + response1);
|
// String response1 = chatModel.call(message);
|
||||||
|
// System.out.println("Response1: " + response1);
|
||||||
// 方式2: 使用 Prompt
|
//
|
||||||
ChatResponse response2 = chatModel.call(new Prompt(message));
|
// // 方式2: 使用 Prompt
|
||||||
System.out.println("Response2: " + response2.getResult().getOutput().getText());
|
// ChatResponse response2 = chatModel.call(new Prompt(message));
|
||||||
}
|
// System.out.println("Response2: " + response2.getResult().getOutput().getText());
|
||||||
|
// }
|
||||||
/**
|
//
|
||||||
* 测试流式生成响应
|
// /**
|
||||||
*/
|
// * 测试流式生成响应
|
||||||
@Test
|
// */
|
||||||
public void testGenerateStream() {
|
// @Test
|
||||||
String message = "Tell me a joke";
|
// public void testGenerateStream() {
|
||||||
|
// String message = "Tell me a joke";
|
||||||
// 使用 PromptTemplate 构建提示词
|
//
|
||||||
Prompt prompt = new PromptTemplate(message).create();
|
// // 使用 PromptTemplate 构建提示词
|
||||||
|
// Prompt prompt = new PromptTemplate(message).create();
|
||||||
// 流式输出
|
//
|
||||||
Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
|
// // 流式输出
|
||||||
|
// Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
|
||||||
// 订阅并打印流式响应
|
//
|
||||||
responseFlux.doOnNext(response -> {
|
// // 订阅并打印流式响应
|
||||||
System.out.println("Stream Response: " + response.getResult().getOutput().getText());
|
// responseFlux.doOnNext(response -> {
|
||||||
}).blockLast(); // 在测试中阻塞等待完成
|
// System.out.println("Stream Response: " + response.getResult().getOutput().getText());
|
||||||
}
|
// }).blockLast(); // 在测试中阻塞等待完成
|
||||||
|
// }
|
||||||
/**
|
//
|
||||||
* 测试使用特定模型选项
|
// /**
|
||||||
*/
|
// * 测试使用特定模型选项
|
||||||
@Test
|
// */
|
||||||
public void testGenerateWithModelOptions() {
|
// @Test
|
||||||
String message = "1+1=?";
|
// public void testGenerateWithModelOptions() {
|
||||||
String model = "deepseek-r1:1.5b";
|
// String message = "1+1=?";
|
||||||
|
// String model = "deepseek-r1:1.5b";
|
||||||
ChatResponse response = chatModel.call(new Prompt(
|
//
|
||||||
message,
|
// ChatResponse response = chatModel.call(new Prompt(
|
||||||
OllamaOptions.builder()
|
// message,
|
||||||
.model(model)
|
// OllamaOptions.builder()
|
||||||
.build()
|
// .model(model)
|
||||||
));
|
// .build()
|
||||||
|
// ));
|
||||||
System.out.println("Response with model options: " + response.getResult().getOutput().getText());
|
//
|
||||||
}
|
// System.out.println("Response with model options: " + response.getResult().getOutput().getText());
|
||||||
|
// }
|
||||||
}
|
//
|
||||||
|
//}
|
||||||
|
|
|
@ -10,6 +10,16 @@
|
||||||
<artifactId>visual-novel-server-trigger</artifactId>
|
<artifactId>visual-novel-server-trigger</artifactId>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.ai</groupId>
|
||||||
|
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.ai</groupId>
|
||||||
|
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
package com.touka.trigger.http.controller;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import com.touka.trigger.http.service.LLMService;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestParam;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLM控制器,提供REST API接口
|
||||||
|
*/
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/api/llm")
|
||||||
|
public class LLMController {
|
||||||
|
|
||||||
|
private final LLMService llmService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
public LLMController(LLMService llmService) {
|
||||||
|
this.llmService = llmService;
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/generate")
|
||||||
|
public String generate(@RequestBody Map<String, Object> request) {
|
||||||
|
String prompt = (String) request.get("prompt");
|
||||||
|
if (prompt == null) {
|
||||||
|
throw new IllegalArgumentException("Prompt is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> params = (Map<String, Object>) request.getOrDefault("params", Map.of());
|
||||||
|
|
||||||
|
String connectorType = (String) request.get("connectorType");
|
||||||
|
if (connectorType != null) {
|
||||||
|
return llmService.generateText(connectorType, prompt, params);
|
||||||
|
} else {
|
||||||
|
return llmService.generateText(prompt, params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping("/stream")
|
||||||
|
public SseEmitter stream(
|
||||||
|
@RequestParam String prompt,
|
||||||
|
@RequestParam(required = false) String connectorType) {
|
||||||
|
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
|
||||||
|
llmService.streamText(prompt, Map.of(), emitter);
|
||||||
|
return emitter;
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
@GetMapping("/connectors")
|
||||||
|
public Map<String, Object> getConnectors() {
|
||||||
|
return (Map<String, Object>) (Map<?, ?>) llmService.getAllAvailableConnectors();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
//package com.touka.trigger.http.controller;
|
||||||
|
//
|
||||||
|
//import jakarta.annotation.Resource;
|
||||||
|
//import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
//import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
//import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||||
|
//import org.springframework.web.bind.annotation.CrossOrigin;
|
||||||
|
//import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
//import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
//import org.springframework.web.bind.annotation.RequestParam;
|
||||||
|
//import org.springframework.web.bind.annotation.RestController;
|
||||||
|
//import reactor.core.publisher.Flux;
|
||||||
|
//
|
||||||
|
//import java.util.Map;
|
||||||
|
//
|
||||||
|
//@RestController()
|
||||||
|
//@CrossOrigin("*")
|
||||||
|
//@RequestMapping("/api/ollama/")
|
||||||
|
//public class OllamaController {
|
||||||
|
// @Resource
|
||||||
|
// private DeepSeekChatModel chatModel;
|
||||||
|
//
|
||||||
|
// @GetMapping("/generate")
|
||||||
|
// public Map generate(@RequestParam(value = "message") String message) {
|
||||||
|
// return Map.of("generation", chatModel.call(message));
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * 流式对话
|
||||||
|
// * @param message
|
||||||
|
// * @return
|
||||||
|
// */
|
||||||
|
// @GetMapping(value = "/generateStream", produces = "text/html;charset=utf-8")
|
||||||
|
// public Flux<String> generateStream(@RequestParam(value = "message") String message) {
|
||||||
|
// // 构建提示词
|
||||||
|
// Prompt prompt = new Prompt(new UserMessage(message));
|
||||||
|
// // 流式输出
|
||||||
|
// return chatModel.stream(prompt)
|
||||||
|
// .mapNotNull(chatResponse -> chatResponse.getResult().getOutput().getText());
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
//}
|
|
@ -0,0 +1,148 @@
|
||||||
|
package com.touka.trigger.http.service;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import com.touka.types.config.ILLMConnector;
|
||||||
|
import com.touka.types.config.LLMConnectorFactory;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLM服务,提供简化的接口来调用不同的LLM连接器
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class LLMService {
|
||||||
|
|
||||||
|
private final LLMConnectorFactory connectorFactory;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
public LLMService(LLMConnectorFactory connectorFactory) {
|
||||||
|
this.connectorFactory = connectorFactory;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用默认连接器生成文本
|
||||||
|
* @param prompt 提示词
|
||||||
|
* @return 生成的文本
|
||||||
|
*/
|
||||||
|
public String generateText(String prompt) {
|
||||||
|
return generateText(prompt, Collections.emptyMap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用默认连接器生成文本
|
||||||
|
* @param prompt 提示词
|
||||||
|
* @param params 额外参数
|
||||||
|
* @return 生成的文本
|
||||||
|
*/
|
||||||
|
public String generateText(String prompt, Map<String, Object> params) {
|
||||||
|
ILLMConnector connector = connectorFactory.getDefaultConnector();
|
||||||
|
if (connector == null) {
|
||||||
|
throw new IllegalStateException("No available LLM connector");
|
||||||
|
}
|
||||||
|
return connector.generateText(prompt, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用指定连接器生成文本
|
||||||
|
* @param connectorType 连接器类型
|
||||||
|
* @param prompt 提示词
|
||||||
|
* @param params 额外参数
|
||||||
|
* @return 生成的文本
|
||||||
|
*/
|
||||||
|
public String generateText(String connectorType, String prompt, Map<String, Object> params) {
|
||||||
|
ILLMConnector connector = connectorFactory.getConnector(connectorType);
|
||||||
|
if (connector == null || !connector.isAvailable()) {
|
||||||
|
throw new IllegalArgumentException("Connector not found or not available: " + connectorType);
|
||||||
|
}
|
||||||
|
return connector.generateText(prompt, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成文本
|
||||||
|
* @param prompt 提示词
|
||||||
|
* @param emitter SSE发射器
|
||||||
|
*/
|
||||||
|
public void streamText(String prompt, SseEmitter emitter) {
|
||||||
|
streamText(prompt, Collections.emptyMap(), emitter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成文本
|
||||||
|
* @param prompt 提示词
|
||||||
|
* @param params 额外参数
|
||||||
|
* @param emitter SSE发射器
|
||||||
|
*/
|
||||||
|
public void streamText(String prompt, Map<String, Object> params, SseEmitter emitter) {
|
||||||
|
ILLMConnector connector = connectorFactory.getDefaultConnector();
|
||||||
|
if (connector == null) {
|
||||||
|
emitter.completeWithError(new IllegalStateException("No available LLM connector"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用Flux处理流式响应
|
||||||
|
Flux<String> textFlux = connector.streamText(prompt, params);
|
||||||
|
|
||||||
|
textFlux.subscribe(
|
||||||
|
content -> {
|
||||||
|
try {
|
||||||
|
emitter.send(SseEmitter.event().data(content));
|
||||||
|
} catch (Exception e) {
|
||||||
|
emitter.completeWithError(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
error -> emitter.completeWithError(error),
|
||||||
|
() -> emitter.complete()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用指定连接器流式生成文本
|
||||||
|
* @param connectorType 连接器类型
|
||||||
|
* @param prompt 提示词
|
||||||
|
* @param params 额外参数
|
||||||
|
* @param emitter SSE发射器
|
||||||
|
*/
|
||||||
|
public void streamText(String connectorType, String prompt, Map<String, Object> params, SseEmitter emitter) {
|
||||||
|
ILLMConnector connector = connectorFactory.getConnector(connectorType);
|
||||||
|
if (connector == null || !connector.isAvailable()) {
|
||||||
|
emitter.completeWithError(new IllegalArgumentException("Connector not found or not available: " + connectorType));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用Flux处理流式响应
|
||||||
|
Flux<String> textFlux = connector.streamText(prompt, params);
|
||||||
|
|
||||||
|
textFlux.subscribe(
|
||||||
|
content -> {
|
||||||
|
try {
|
||||||
|
emitter.send(SseEmitter.event().data(content));
|
||||||
|
} catch (Exception e) {
|
||||||
|
emitter.completeWithError(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
error -> emitter.completeWithError(error),
|
||||||
|
() -> emitter.complete()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有可用的连接器类型
|
||||||
|
* @return 连接器类型列表
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public Map<String, Object> getAllAvailableConnectors() {
|
||||||
|
Map<String, Object> availableConnectors = new HashMap<>();
|
||||||
|
connectorFactory.getAllConnectors().forEach((type, connector) -> {
|
||||||
|
if (connector.isAvailable()) {
|
||||||
|
availableConnectors.put(type, connector);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return availableConnectors;
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,6 +10,22 @@
|
||||||
<artifactId>visual-novel-server-types</artifactId>
|
<artifactId>visual-novel-server-types</artifactId>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.ai</groupId>
|
||||||
|
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.ai</groupId>
|
||||||
|
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<!-- 添加OpenAI SDK依赖 -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||||
|
<artifactId>client</artifactId>
|
||||||
|
<version>0.18.0</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
package com.touka.types.config;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLM连接器的基础实现,提供一些通用功能
|
||||||
|
*/
|
||||||
|
public abstract class BaseLLMConnector implements ILLMConnector {
|
||||||
|
|
||||||
|
protected String apiKey;
|
||||||
|
protected String baseUrl;
|
||||||
|
protected int timeoutMs;
|
||||||
|
protected boolean enabled;
|
||||||
|
|
||||||
|
public BaseLLMConnector(String apiKey, String baseUrl, int timeoutMs) {
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
this.baseUrl = baseUrl;
|
||||||
|
this.timeoutMs = timeoutMs;
|
||||||
|
this.enabled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isAvailable() {
|
||||||
|
return enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setEnabled(boolean enabled) {
|
||||||
|
this.enabled = enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getApiKey() {
|
||||||
|
return apiKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setApiKey(String apiKey) {
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getBaseUrl() {
|
||||||
|
return baseUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBaseUrl(String baseUrl) {
|
||||||
|
this.baseUrl = baseUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getTimeoutMs() {
|
||||||
|
return timeoutMs;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTimeoutMs(int timeoutMs) {
|
||||||
|
this.timeoutMs = timeoutMs;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package com.touka.types.config;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 大语言模型连接器接口,定义了与各类LLM API交互的统一方法
|
||||||
|
*/
|
||||||
|
public interface ILLMConnector {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用Prompt对象生成文本响应
|
||||||
|
* @param prompt 提示对象
|
||||||
|
* @return 聊天响应
|
||||||
|
*/
|
||||||
|
ChatResponse generateResponse(Prompt prompt);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用文本提示生成文本
|
||||||
|
* @param textPrompt 文本提示
|
||||||
|
* @param params 额外参数
|
||||||
|
* @return 生成的文本内容
|
||||||
|
*/
|
||||||
|
String generateText(String textPrompt, Map<String, Object> params);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成文本响应
|
||||||
|
* @param prompt 提示对象
|
||||||
|
* @return 响应流
|
||||||
|
*/
|
||||||
|
Flux<ChatResponse> streamResponse(Prompt prompt);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成文本
|
||||||
|
* @param textPrompt 文本提示
|
||||||
|
* @param params 额外参数
|
||||||
|
* @return 文本流
|
||||||
|
*/
|
||||||
|
Flux<String> streamText(String textPrompt, Map<String, Object> params);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取连接器类型
|
||||||
|
* @return 连接器类型标识
|
||||||
|
*/
|
||||||
|
String getConnectorType();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查连接器是否可用
|
||||||
|
* @return 连接器是否可用
|
||||||
|
*/
|
||||||
|
boolean isAvailable();
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
package com.touka.types.config;
|
||||||
|
|
||||||
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatModel;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
|
import org.springframework.ai.openai.api.OpenAiApi;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
// 在配置类中添加
|
||||||
|
@Configuration
|
||||||
|
public class LLMConfig {
|
||||||
|
@Bean
|
||||||
|
public ChatClient chatClient(ChatClient.Builder builder) {
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public ChatClient.Builder chatClientBuilder(OpenAiChatModel openAiChatModel) {
|
||||||
|
// 根据配置创建Builder
|
||||||
|
return ChatClient.builder(openAiChatModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
|
||||||
|
OpenAiChatOptions options = OpenAiChatOptions.builder()
|
||||||
|
.model("gpt-3.5-turbo")
|
||||||
|
.temperature(0.7)
|
||||||
|
.maxCompletionTokens(2048)
|
||||||
|
.build();
|
||||||
|
return new OpenAiChatModel(openAiApi, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public OpenAiApi openAiApi(@Value("${spring.ai.openai.api-key}") String apiKey, @Value("${spring.ai.openai.base-url}") String baseUrl) {
|
||||||
|
return OpenAiApi.builder().
|
||||||
|
baseUrl(baseUrl).
|
||||||
|
apiKey(apiKey).
|
||||||
|
build();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
package com.touka.types.config;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLM连接器的配置类,用于从配置文件中读取连接器配置
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Component
|
||||||
|
@ConfigurationProperties(prefix = "llm")
|
||||||
|
public class LLMConnectorConfig {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 默认连接器类型
|
||||||
|
*/
|
||||||
|
private String defaultConnector;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 各连接器的详细配置
|
||||||
|
*/
|
||||||
|
private Map<String, ConnectorDetailConfig> connectors;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class ConnectorDetailConfig {
|
||||||
|
private String apiKey;
|
||||||
|
private String baseUrl;
|
||||||
|
private int timeoutMs = 30000;
|
||||||
|
private boolean enabled = true;
|
||||||
|
private Map<String, Object> additionalParams;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,124 @@
|
||||||
|
package com.touka.types.config;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
|
||||||
|
import com.touka.types.config.impl.OpenAIConnector;
|
||||||
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLM连接器工厂,用于创建和管理不同类型的连接器实例
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class LLMConnectorFactory {
|
||||||
|
|
||||||
|
@Value("${spring.ai.openai.api-key:}")
|
||||||
|
private String defaultApiKey;
|
||||||
|
|
||||||
|
@Value("${spring.ai.openai.base-url:https://api.openai.com/v1}")
|
||||||
|
private String defaultBaseUrl;
|
||||||
|
|
||||||
|
private final LLMConnectorConfig config;
|
||||||
|
private final ChatClient chatClient;
|
||||||
|
private final Map<String, ILLMConnector> connectors = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
public LLMConnectorFactory(LLMConnectorConfig config, ChatClient chatClient) {
|
||||||
|
this.config = config;
|
||||||
|
this.chatClient = chatClient;
|
||||||
|
initializeConnectors();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initializeConnectors() {
|
||||||
|
// 注册默认连接器
|
||||||
|
registerDefaultConnectors();
|
||||||
|
|
||||||
|
// 根据配置初始化连接器
|
||||||
|
if (config.getConnectors() != null) {
|
||||||
|
for (Map.Entry<String, LLMConnectorConfig.ConnectorDetailConfig> entry : config.getConnectors().entrySet()) {
|
||||||
|
String connectorType = entry.getKey();
|
||||||
|
LLMConnectorConfig.ConnectorDetailConfig detailConfig = entry.getValue();
|
||||||
|
|
||||||
|
if (detailConfig.isEnabled()) {
|
||||||
|
ILLMConnector connector = createConnector(connectorType, detailConfig);
|
||||||
|
if (connector != null) {
|
||||||
|
connectors.put(connectorType, connector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void registerDefaultConnectors() {
|
||||||
|
// 注册默认连接器实现,传入所有必要的参数
|
||||||
|
connectors.put("openai", new OpenAIConnector(chatClient, defaultApiKey, defaultBaseUrl));
|
||||||
|
}
|
||||||
|
|
||||||
|
private ILLMConnector createConnector(String connectorType, LLMConnectorConfig.ConnectorDetailConfig detailConfig) {
|
||||||
|
switch (connectorType.toLowerCase()) {
|
||||||
|
case "openai":
|
||||||
|
return new OpenAIConnector(
|
||||||
|
detailConfig.getApiKey(),
|
||||||
|
detailConfig.getBaseUrl(),
|
||||||
|
detailConfig.getTimeoutMs(),
|
||||||
|
chatClient);
|
||||||
|
default:
|
||||||
|
// 可以扩展支持更多连接器类型
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取指定类型的连接器
|
||||||
|
* @param connectorType 连接器类型
|
||||||
|
* @return 连接器实例,如果不存在则返回null
|
||||||
|
*/
|
||||||
|
public ILLMConnector getConnector(String connectorType) {
|
||||||
|
return connectors.get(connectorType);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取默认连接器
|
||||||
|
* @return 默认连接器实例
|
||||||
|
*/
|
||||||
|
public ILLMConnector getDefaultConnector() {
|
||||||
|
String defaultType = config.getDefaultConnector();
|
||||||
|
if (defaultType != null) {
|
||||||
|
ILLMConnector connector = connectors.get(defaultType);
|
||||||
|
if (connector != null && connector.isAvailable()) {
|
||||||
|
return connector;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有配置默认连接器或默认连接器不可用,返回第一个可用的连接器
|
||||||
|
return connectors.values().stream()
|
||||||
|
.filter(ILLMConnector::isAvailable)
|
||||||
|
.findFirst()
|
||||||
|
.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 注册新的连接器
|
||||||
|
* @param connector 连接器实例
|
||||||
|
*/
|
||||||
|
public void registerConnector(ILLMConnector connector) {
|
||||||
|
if (connector != null) {
|
||||||
|
connectors.put(connector.getConnectorType(), connector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有可用的连接器类型
|
||||||
|
* @return 连接器类型列表
|
||||||
|
*/
|
||||||
|
public Map<String, ILLMConnector> getAllConnectors() {
|
||||||
|
return new ConcurrentHashMap<>(connectors);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,159 @@
|
||||||
|
package com.touka.types.config.impl;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import com.touka.types.config.BaseLLMConnector;
|
||||||
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
|
import org.springframework.ai.chat.model.Generation;
|
||||||
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||||
|
import org.springframework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpenAI API连接器实现,基于Spring AI的ChatModel和ChatClient
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class OpenAIConnector extends BaseLLMConnector {
|
||||||
|
|
||||||
|
private static final String CONNECTOR_TYPE = "openai";
|
||||||
|
|
||||||
|
private final ChatClient chatClient;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
public OpenAIConnector(ChatClient chatClient,
|
||||||
|
@Value("${spring.ai.openai.api-key:}") String apiKey,
|
||||||
|
@Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") String baseUrl) {
|
||||||
|
// 直接在构造函数中使用注入的值
|
||||||
|
super(apiKey, baseUrl, 30000);
|
||||||
|
this.chatClient = chatClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OpenAIConnector(String apiKey, String baseUrl, int timeoutMs, ChatClient chatClient) {
|
||||||
|
super(apiKey, baseUrl, timeoutMs);
|
||||||
|
this.chatClient = chatClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatResponse generateResponse(Prompt prompt) {
|
||||||
|
return chatClient.prompt(prompt).call().chatResponse();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String generateText(String textPrompt, Map<String, Object> params) {
|
||||||
|
// 构建消息列表
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
|
||||||
|
// 添加系统消息(如果有)
|
||||||
|
if (params != null && params.containsKey("systemPrompt")) {
|
||||||
|
String systemPrompt = (String) params.get("systemPrompt");
|
||||||
|
messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加用户消息
|
||||||
|
messages.add(new UserMessage(textPrompt));
|
||||||
|
|
||||||
|
// 创建提示
|
||||||
|
Prompt prompt = new Prompt(messages, createOptions(params));
|
||||||
|
|
||||||
|
// 调用ChatClient生成响应
|
||||||
|
ChatResponse response = chatClient.prompt(prompt).call().chatResponse();
|
||||||
|
|
||||||
|
// 提取生成的文本
|
||||||
|
if (response.getResults() != null && !response.getResults().isEmpty()) {
|
||||||
|
Generation generation = response.getResults().get(0);
|
||||||
|
return generation.getOutput().getText();
|
||||||
|
}
|
||||||
|
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Flux<ChatResponse> streamResponse(Prompt prompt) {
|
||||||
|
return chatClient.prompt(prompt).stream().chatResponse();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Flux<String> streamText(String textPrompt, Map<String, Object> params) {
|
||||||
|
// 构建消息列表
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
|
||||||
|
// 添加系统消息(如果有)
|
||||||
|
if (params != null && params.containsKey("systemPrompt")) {
|
||||||
|
String systemPrompt = (String) params.get("systemPrompt");
|
||||||
|
messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加用户消息
|
||||||
|
messages.add(new UserMessage(textPrompt));
|
||||||
|
|
||||||
|
// 创建提示
|
||||||
|
Prompt prompt = new Prompt(messages, createOptions(params));
|
||||||
|
|
||||||
|
// 流式调用并处理响应
|
||||||
|
return chatClient.prompt(prompt).stream().chatResponse()
|
||||||
|
.mapNotNull(response -> {
|
||||||
|
if (response.getResults() != null && !response.getResults().isEmpty()) {
|
||||||
|
Generation generation = response.getResults().get(0);
|
||||||
|
return generation.getOutput().getText();
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
})
|
||||||
|
.filter(content -> !content.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getConnectorType() {
|
||||||
|
return CONNECTOR_TYPE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据参数创建OpenAI选项
|
||||||
|
*/
|
||||||
|
private OpenAiChatOptions createOptions(Map<String, Object> params) {
|
||||||
|
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder();
|
||||||
|
|
||||||
|
if (params != null) {
|
||||||
|
// 设置模型
|
||||||
|
if (params.containsKey("model")) {
|
||||||
|
builder.model((String) params.get("model"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置温度
|
||||||
|
if (params.containsKey("temperature")) {
|
||||||
|
builder.temperature(Double.parseDouble(params.get("temperature").toString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置最大生成长度
|
||||||
|
if (params.containsKey("maxTokens")) {
|
||||||
|
builder.maxTokens(Integer.parseInt(params.get("maxTokens").toString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置topP
|
||||||
|
if (params.containsKey("topP")) {
|
||||||
|
builder.topP(Double.parseDouble(params.get("topP").toString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置停止词
|
||||||
|
if (params.containsKey("stop")) {
|
||||||
|
if (params.get("stop") instanceof String) {
|
||||||
|
builder.stop(List.of((String) params.get("stop")));
|
||||||
|
} else if (params.get("stop") instanceof List) {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
List<String> stopList = (List<String>) params.get("stop");
|
||||||
|
builder.stop(stopList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue