feat(llm): introduce multi-LLM connector support and configuration management

- Add Spring AI starter dependencies for OpenAI and Ollama
- Implement an LLM connector interface with a factory pattern so different LLM configurations can be loaded dynamically
- Add LLMController, a unified REST interface for text generation and streaming output
- Add ai-related entries to the application config, supporting per-model parameter settings
- Update pom.xml to add the Spring Milestones and Snapshots repositories
- Upgrade spring-ai to 1.0.0-M6
- Remove the old direct DeepSeek and Ollama dependencies in favor of calls through the connectors
- Comment out unused test code and configuration notes
lijunming 2025-09-25 13:48:41 +08:00
parent 7e43fd59a4
commit 46d676a8c1
16 changed files with 882 additions and 113 deletions
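For orientation before the file-by-file diff: a minimal, hypothetical client for the new REST interface. The host, port and payload values are assumptions, and the request shape follows LLMController further down in this commit.

// Minimal client sketch for the new /api/llm/generate endpoint
// (assumes the app runs on http://localhost:8080; values are illustrative).
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class LLMApiSmokeTest {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        // connectorType selects a key under llm.connectors; params are passed through
        String body = """
                {"prompt": "Tell me a joke", "connectorType": "openai",
                 "params": {"model": "gpt-3.5-turbo", "temperature": 0.7}}""";
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/llm/generate"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body());
    }
}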

pom.xml

@@ -28,6 +28,22 @@
         <enabled>false</enabled>
       </snapshots>
     </repository>
+    <repository>
+      <id>spring-milestones</id>
+      <name>Spring Milestones</name>
+      <url>https://repo.spring.io/milestone</url>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+    </repository>
+    <repository>
+      <id>spring-snapshots</id>
+      <name>Spring Snapshots</name>
+      <url>https://repo.spring.io/snapshot</url>
+      <releases>
+        <enabled>false</enabled>
+      </releases>
+    </repository>
   </repositories>
   <properties>
@@ -35,7 +51,7 @@
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
     <maven.compiler.source>17</maven.compiler.source>
     <maven.compiler.target>17</maven.compiler.target>
-    <spring-ai.version>1.0.0</spring-ai.version>
+    <spring-ai.version>1.0.0-M6</spring-ai.version>
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
   </properties>
@@ -172,6 +188,8 @@
     </plugins>
   </build>
   <profiles>
     <profile>
       <id>dev</id>

pom.xml

@@ -12,19 +12,6 @@
   <packaging>jar</packaging>
   <dependencies>
-    <dependency>
-      <groupId>org.springframework.ai</groupId>
-      <artifactId>spring-ai-ollama</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>org.springframework.ai</groupId>
-      <artifactId>spring-ai-starter-model-deepseek</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>org.springframework.boot</groupId>
-      <artifactId>spring-boot-starter-web</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.springframework.boot</groupId>
       <artifactId>spring-boot-starter-test</artifactId>
@@ -143,4 +130,6 @@
     </plugins>
   </build>
 </project>

package-info.java

@@ -1,6 +0,0 @@
-/**
- * 1. Manages startup/initialization of resources required by imported jars
- * 2. If AOP aspects are needed, create an aop package for the aspect logic
- */
-package com.touka.config;


@@ -12,7 +12,7 @@ thread:
   block-queue-size: 5000
   policy: CallerRunsPolicy

-# Database configuration; configure datasource resources at startup
+# Database configuration
 spring:
   datasource:
     username: root
@@ -21,25 +21,45 @@ spring:
     driver-class-name: com.mysql.cj.jdbc.Driver
     hikari:
       pool-name: Retail_HikariCP
-      minimum-idle: 15 # minimum number of idle connections
-      idle-timeout: 180000 # max idle lifetime; default 600000 (10 minutes)
-      maximum-pool-size: 25 # max pool size; default 10
-      auto-commit: true # default auto-commit behavior for pooled connections; default true
-      max-lifetime: 1800000 # max connection lifetime; 0 = unlimited; default 1800000 (30 minutes)
-      connection-timeout: 30000 # connection timeout; default 30000 (30 seconds)
+      minimum-idle: 15
+      idle-timeout: 180000
+      maximum-pool-size: 25
+      auto-commit: true
+      max-lifetime: 1800000
+      connection-timeout: 30000
       connection-test-query: SELECT 1
       type: com.zaxxer.hikari.HikariDataSource
   ai:
     deepseek:
       api-key: your-api-key
-      base-url: https://api.deepseek.com # DeepSeek request URL; optional, defaults to api.deepseek.com
+      base-url: https://api.deepseek.com
       chat:
         options:
-          model: deepseek-reasoner # the deep-reasoning model
-          temperature: 0.8 # temperature
+          model: deepseek-reasoner
+          temperature: 0.8
+    openai: # note: this must sit under spring.ai, not as a sibling of spring
+      api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
+      base-url: https://api.chatanywhere.tech # a third-party API endpoint is in use
+      chat:
+        options:
+          model: gpt-3.5-turbo
+          temperature: 0.7
+          max-tokens: 2048

-# MyBatis configuration (uncomment to enable)
+llm:
+  # default connector type
+  default-connector: openai
+  # detailed per-connector configuration
+  connectors:
+    # OpenAI connector configuration
+    openai:
+      api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
+      base-url: https://api.chatanywhere.tech
+      timeout-ms: 30000
+      enabled: true
+
+# MyBatis configuration
 #mybatis:
 #  mapper-locations: classpath:/mybatis/mapper/*.xml
 #  config-location: classpath:/mybatis/config/mybatis-config.xml
@@ -49,4 +69,3 @@ logging:
   level:
     root: info
   config: classpath:logback-spring.xml

ApiTest.java

@@ -1,81 +1,82 @@
-package com.touka.test;
-
-import jakarta.annotation.Resource;
-import lombok.extern.slf4j.Slf4j;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.springframework.ai.chat.model.ChatResponse;
-
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.PromptTemplate;
-import org.springframework.ai.deepseek.DeepSeekChatModel;
-
-import org.springframework.ai.ollama.api.OllamaOptions;
-
-import org.springframework.boot.test.context.SpringBootTest;
-import org.springframework.test.context.junit4.SpringRunner;
-
-import reactor.core.publisher.Flux;
-
-@Slf4j
-@RunWith(SpringRunner.class)
-@SpringBootTest
-public class ApiTest {
-
-    @Resource
-    private DeepSeekChatModel chatModel;
-
-    /**
-     * Test synchronous response generation
-     */
-    @Test
-    public void testGenerate() {
-        String message = "Tell me a joke";
-
-        // Option 1: pass a plain string
-        String response1 = chatModel.call(message);
-        System.out.println("Response1: " + response1);
-
-        // Option 2: use a Prompt
-        ChatResponse response2 = chatModel.call(new Prompt(message));
-        System.out.println("Response2: " + response2.getResult().getOutput().getText());
-    }
-
-    /**
-     * Test streaming response generation
-     */
-    @Test
-    public void testGenerateStream() {
-        String message = "Tell me a joke";
-
-        // Build the prompt with a PromptTemplate
-        Prompt prompt = new PromptTemplate(message).create();
-
-        // Stream the output
-        Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
-
-        // Subscribe and print the streamed responses
-        responseFlux.doOnNext(response -> {
-            System.out.println("Stream Response: " + response.getResult().getOutput().getText());
-        }).blockLast(); // block in the test until the stream completes
-    }
-
-    /**
-     * Test model-specific options
-     */
-    @Test
-    public void testGenerateWithModelOptions() {
-        String message = "1+1=?";
-        String model = "deepseek-r1:1.5b";
-
-        ChatResponse response = chatModel.call(new Prompt(
-                message,
-                OllamaOptions.builder()
-                        .model(model)
-                        .build()
-        ));
-
-        System.out.println("Response with model options: " + response.getResult().getOutput().getText());
-    }
-}
+//package com.touka.test;
+//
+//import jakarta.annotation.Resource;
+//import lombok.extern.slf4j.Slf4j;
+//import org.junit.Test;
+//import org.junit.runner.RunWith;
+//import org.springframework.ai.chat.model.ChatResponse;
+//
+//import org.springframework.ai.chat.prompt.Prompt;
+//import org.springframework.ai.chat.prompt.PromptTemplate;
+//import org.springframework.ai.deepseek.DeepSeekChatModel;
+//
+//import org.springframework.ai.ollama.api.OllamaOptions;
+//
+//import org.springframework.boot.test.context.SpringBootTest;
+//import org.springframework.test.context.junit4.SpringRunner;
+//
+//import reactor.core.publisher.Flux;
+//
+//@Slf4j
+//@RunWith(SpringRunner.class)
+//@SpringBootTest
+//public class ApiTest {
+//
+//    @Resource
+//    private DeepSeekChatModel chatModel;
+//
+//    /**
+//     * Test synchronous response generation
+//     */
+//    @Test
+//    public void testGenerate() {
+//        String message = "Tell me a joke";
+//
+//        // Option 1: pass a plain string
+//        String response1 = chatModel.call(message);
+//        System.out.println("Response1: " + response1);
+//
+//        // Option 2: use a Prompt
+//        ChatResponse response2 = chatModel.call(new Prompt(message));
+//        System.out.println("Response2: " + response2.getResult().getOutput().getText());
+//    }
+//
+//    /**
+//     * Test streaming response generation
+//     */
+//    @Test
+//    public void testGenerateStream() {
+//        String message = "Tell me a joke";
+//
+//        // Build the prompt with a PromptTemplate
+//        Prompt prompt = new PromptTemplate(message).create();
+//
+//        // Stream the output
+//        Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
+//
+//        // Subscribe and print the streamed responses
+//        responseFlux.doOnNext(response -> {
+//            System.out.println("Stream Response: " + response.getResult().getOutput().getText());
+//        }).blockLast(); // block in the test until the stream completes
+//    }
+//
+//    /**
+//     * Test model-specific options
+//     */
+//    @Test
+//    public void testGenerateWithModelOptions() {
+//        String message = "1+1=?";
+//        String model = "deepseek-r1:1.5b";
+//
+//        ChatResponse response = chatModel.call(new Prompt(
+//                message,
+//                OllamaOptions.builder()
+//                        .model(model)
+//                        .build()
+//        ));
+//
+//        System.out.println("Response with model options: " + response.getResult().getOutput().getText());
+//    }
+//
+//}

pom.xml

@@ -10,6 +10,16 @@
   <artifactId>visual-novel-server-trigger</artifactId>
   <dependencies>
+    <dependency>
+      <groupId>org.springframework.ai</groupId>
+      <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.springframework.ai</groupId>
+      <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.springframework.boot</groupId>
       <artifactId>spring-boot-starter-web</artifactId>

LLMController.java

@@ -0,0 +1,64 @@
package com.touka.trigger.http.controller;

import java.util.Map;

import com.touka.trigger.http.service.LLMService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
 * LLM controller exposing the REST API.
 */
@RestController
@RequestMapping("/api/llm")
public class LLMController {

    private final LLMService llmService;

    @Autowired
    public LLMController(LLMService llmService) {
        this.llmService = llmService;
    }

    @PostMapping("/generate")
    public String generate(@RequestBody Map<String, Object> request) {
        String prompt = (String) request.get("prompt");
        if (prompt == null) {
            throw new IllegalArgumentException("Prompt is required");
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> params = (Map<String, Object>) request.getOrDefault("params", Map.of());
        String connectorType = (String) request.get("connectorType");
        if (connectorType != null) {
            return llmService.generateText(connectorType, prompt, params);
        } else {
            return llmService.generateText(prompt, params);
        }
    }

    @GetMapping("/stream")
    public SseEmitter stream(
            @RequestParam String prompt,
            @RequestParam(required = false) String connectorType) {
        SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
        // Route to the requested connector; fall back to the default one.
        if (connectorType != null) {
            llmService.streamText(connectorType, prompt, Map.of(), emitter);
        } else {
            llmService.streamText(prompt, Map.of(), emitter);
        }
        return emitter;
    }

    @GetMapping("/connectors")
    public Map<String, Object> getConnectors() {
        return llmService.getAllAvailableConnectors();
    }
}
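A rough sketch of consuming the SSE endpoint from plain Java; the host and port are assumptions, and SSE frames arrive as "data:" lines.

// Sketch: consume GET /api/llm/stream as a raw SSE line stream.
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

public class LLMStreamClient {
    public static void main(String[] args) throws Exception {
        String prompt = URLEncoder.encode("Tell me a joke", StandardCharsets.UTF_8);
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/llm/stream?prompt=" + prompt))
                .GET()
                .build();
        // Each SSE frame is a "data: ..." line; print the payloads as they arrive.
        HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofLines())
                .body()
                .filter(line -> line.startsWith("data:"))
                .forEach(line -> System.out.println(line.substring(5).trim()));
    }
}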

OllamaController.java

@@ -0,0 +1,42 @@
//package com.touka.trigger.http.controller;
//
//import jakarta.annotation.Resource;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.deepseek.DeepSeekChatModel;
//import org.springframework.web.bind.annotation.CrossOrigin;
//import org.springframework.web.bind.annotation.GetMapping;
//import org.springframework.web.bind.annotation.RequestMapping;
//import org.springframework.web.bind.annotation.RequestParam;
//import org.springframework.web.bind.annotation.RestController;
//import reactor.core.publisher.Flux;
//
//import java.util.Map;
//
//@RestController()
//@CrossOrigin("*")
//@RequestMapping("/api/ollama/")
//public class OllamaController {
//
//    @Resource
//    private DeepSeekChatModel chatModel;
//
//    @GetMapping("/generate")
//    public Map generate(@RequestParam(value = "message") String message) {
//        return Map.of("generation", chatModel.call(message));
//    }
//
//    /**
//     * Streaming chat
//     * @param message
//     * @return
//     */
//    @GetMapping(value = "/generateStream", produces = "text/html;charset=utf-8")
//    public Flux<String> generateStream(@RequestParam(value = "message") String message) {
//        // Build the prompt
//        Prompt prompt = new Prompt(new UserMessage(message));
//        // Stream the output
//        return chatModel.stream(prompt)
//                .mapNotNull(chatResponse -> chatResponse.getResult().getOutput().getText());
//    }
//
//}

LLMService.java

@@ -0,0 +1,148 @@
package com.touka.trigger.http.service;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import com.touka.types.config.ILLMConnector;
import com.touka.types.config.LLMConnectorFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;

/**
 * LLM service providing a simplified interface for calling the different LLM connectors.
 */
@Service
public class LLMService {

    private final LLMConnectorFactory connectorFactory;

    @Autowired
    public LLMService(LLMConnectorFactory connectorFactory) {
        this.connectorFactory = connectorFactory;
    }

    /**
     * Generate text with the default connector.
     * @param prompt the prompt
     * @return the generated text
     */
    public String generateText(String prompt) {
        return generateText(prompt, Collections.emptyMap());
    }

    /**
     * Generate text with the default connector.
     * @param prompt the prompt
     * @param params additional parameters
     * @return the generated text
     */
    public String generateText(String prompt, Map<String, Object> params) {
        ILLMConnector connector = connectorFactory.getDefaultConnector();
        if (connector == null) {
            throw new IllegalStateException("No available LLM connector");
        }
        return connector.generateText(prompt, params);
    }

    /**
     * Generate text with a specific connector.
     * @param connectorType the connector type
     * @param prompt the prompt
     * @param params additional parameters
     * @return the generated text
     */
    public String generateText(String connectorType, String prompt, Map<String, Object> params) {
        ILLMConnector connector = connectorFactory.getConnector(connectorType);
        if (connector == null || !connector.isAvailable()) {
            throw new IllegalArgumentException("Connector not found or not available: " + connectorType);
        }
        return connector.generateText(prompt, params);
    }

    /**
     * Stream generated text.
     * @param prompt the prompt
     * @param emitter the SSE emitter
     */
    public void streamText(String prompt, SseEmitter emitter) {
        streamText(prompt, Collections.emptyMap(), emitter);
    }

    /**
     * Stream generated text.
     * @param prompt the prompt
     * @param params additional parameters
     * @param emitter the SSE emitter
     */
    public void streamText(String prompt, Map<String, Object> params, SseEmitter emitter) {
        ILLMConnector connector = connectorFactory.getDefaultConnector();
        if (connector == null) {
            emitter.completeWithError(new IllegalStateException("No available LLM connector"));
            return;
        }
        subscribeToEmitter(connector.streamText(prompt, params), emitter);
    }

    /**
     * Stream generated text with a specific connector.
     * @param connectorType the connector type
     * @param prompt the prompt
     * @param params additional parameters
     * @param emitter the SSE emitter
     */
    public void streamText(String connectorType, String prompt, Map<String, Object> params, SseEmitter emitter) {
        ILLMConnector connector = connectorFactory.getConnector(connectorType);
        if (connector == null || !connector.isAvailable()) {
            emitter.completeWithError(new IllegalArgumentException("Connector not found or not available: " + connectorType));
            return;
        }
        subscribeToEmitter(connector.streamText(prompt, params), emitter);
    }

    /**
     * Forward a reactive text stream to an SSE emitter.
     */
    private void subscribeToEmitter(Flux<String> textFlux, SseEmitter emitter) {
        textFlux.subscribe(
                content -> {
                    try {
                        emitter.send(SseEmitter.event().data(content));
                    } catch (Exception e) {
                        emitter.completeWithError(e);
                    }
                },
                emitter::completeWithError,
                emitter::complete
        );
    }

    /**
     * Get all available connector types.
     * @return a map from connector key to connector type
     */
    public Map<String, Object> getAllAvailableConnectors() {
        Map<String, Object> availableConnectors = new HashMap<>();
        connectorFactory.getAllConnectors().forEach((type, connector) -> {
            if (connector.isAvailable()) {
                // Expose only the connector type, not the connector instance:
                // serializing the instance would leak api-key configuration.
                availableConnectors.put(type, connector.getConnectorType());
            }
        });
        return availableConnectors;
    }
}
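Inside the application, callers can inject the service directly rather than going through HTTP. A hypothetical consumer follows; the class name and prompt are illustrative only.

// Illustrative consumer bean; not part of this commit.
import java.util.Map;

import com.touka.trigger.http.service.LLMService;
import org.springframework.stereotype.Component;

@Component
public class StoryGenerator {

    private final LLMService llmService;

    public StoryGenerator(LLMService llmService) {
        this.llmService = llmService;
    }

    public String openingScene(String synopsis) {
        // Per-call parameters are passed through to the connector's options.
        return llmService.generateText(
                "Write the opening scene for: " + synopsis,
                Map.of("temperature", 0.9, "maxTokens", 512));
    }
}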

pom.xml

@@ -10,6 +10,22 @@
   <artifactId>visual-novel-server-types</artifactId>
   <dependencies>
+    <dependency>
+      <groupId>org.springframework.ai</groupId>
+      <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.springframework.ai</groupId>
+      <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
+    </dependency>
+    <!-- OpenAI SDK dependency -->
+    <dependency>
+      <groupId>com.theokanning.openai-gpt3-java</groupId>
+      <artifactId>client</artifactId>
+      <version>0.18.0</version>
+    </dependency>
     <dependency>
       <groupId>org.springframework.boot</groupId>
       <artifactId>spring-boot-starter-web</artifactId>

BaseLLMConnector.java

@@ -0,0 +1,54 @@
package com.touka.types.config;

/**
 * Base LLM connector implementation providing common functionality.
 */
public abstract class BaseLLMConnector implements ILLMConnector {

    protected String apiKey;
    protected String baseUrl;
    protected int timeoutMs;
    protected boolean enabled;

    public BaseLLMConnector(String apiKey, String baseUrl, int timeoutMs) {
        this.apiKey = apiKey;
        this.baseUrl = baseUrl;
        this.timeoutMs = timeoutMs;
        this.enabled = true;
    }

    @Override
    public boolean isAvailable() {
        return enabled;
    }

    public void setEnabled(boolean enabled) {
        this.enabled = enabled;
    }

    public String getApiKey() {
        return apiKey;
    }

    public void setApiKey(String apiKey) {
        this.apiKey = apiKey;
    }

    public String getBaseUrl() {
        return baseUrl;
    }

    public void setBaseUrl(String baseUrl) {
        this.baseUrl = baseUrl;
    }

    public int getTimeoutMs() {
        return timeoutMs;
    }

    public void setTimeoutMs(int timeoutMs) {
        this.timeoutMs = timeoutMs;
    }
}

ILLMConnector.java

@@ -0,0 +1,55 @@
package com.touka.types.config;

import java.util.Map;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;

/**
 * Large language model connector interface defining a unified way to interact with the various LLM APIs.
 */
public interface ILLMConnector {

    /**
     * Generate a response from a Prompt object.
     * @param prompt the prompt object
     * @return the chat response
     */
    ChatResponse generateResponse(Prompt prompt);

    /**
     * Generate text from a plain text prompt.
     * @param textPrompt the text prompt
     * @param params additional parameters
     * @return the generated text
     */
    String generateText(String textPrompt, Map<String, Object> params);

    /**
     * Stream a response from a Prompt object.
     * @param prompt the prompt object
     * @return the response stream
     */
    Flux<ChatResponse> streamResponse(Prompt prompt);

    /**
     * Stream generated text.
     * @param textPrompt the text prompt
     * @param params additional parameters
     * @return the text stream
     */
    Flux<String> streamText(String textPrompt, Map<String, Object> params);

    /**
     * Get the connector type.
     * @return the connector type identifier
     */
    String getConnectorType();

    /**
     * Check whether the connector is available.
     * @return true if the connector is available
     */
    boolean isAvailable();
}
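The interface leaves room for additional backends. A rough sketch of what a second connector could look like, backed by Spring AI's OllamaChatModel from the ollama starter added in this commit; the class itself, its wiring and the option handling are assumptions, not part of this commit.

// Hypothetical Ollama-backed connector; illustrative only.
package com.touka.types.config.impl;

import java.util.Map;

import com.touka.types.config.BaseLLMConnector;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
import reactor.core.publisher.Flux;

public class OllamaConnector extends BaseLLMConnector {

    private final OllamaChatModel chatModel;

    public OllamaConnector(OllamaChatModel chatModel, String baseUrl, int timeoutMs) {
        // A local Ollama instance needs no API key.
        super("", baseUrl, timeoutMs);
        this.chatModel = chatModel;
    }

    @Override
    public ChatResponse generateResponse(Prompt prompt) {
        return chatModel.call(prompt);
    }

    @Override
    public String generateText(String textPrompt, Map<String, Object> params) {
        ChatResponse response = chatModel.call(new Prompt(textPrompt));
        return response.getResult().getOutput().getText();
    }

    @Override
    public Flux<ChatResponse> streamResponse(Prompt prompt) {
        return chatModel.stream(prompt);
    }

    @Override
    public Flux<String> streamText(String textPrompt, Map<String, Object> params) {
        // Map each streamed chunk to its text content.
        return chatModel.stream(new Prompt(textPrompt))
                .mapNotNull(r -> r.getResult().getOutput().getText());
    }

    @Override
    public String getConnectorType() {
        return "ollama";
    }
}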

LLMConfig.java

@@ -0,0 +1,42 @@
package com.touka.types.config;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Spring configuration for the default ChatClient and the OpenAI beans.
 */
@Configuration
public class LLMConfig {

    @Bean
    public ChatClient chatClient(ChatClient.Builder builder) {
        return builder.build();
    }

    @Bean
    public ChatClient.Builder chatClientBuilder(OpenAiChatModel openAiChatModel) {
        // Create the builder from the configured chat model
        return ChatClient.builder(openAiChatModel);
    }

    @Bean
    public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
        OpenAiChatOptions options = OpenAiChatOptions.builder()
                .model("gpt-3.5-turbo")
                .temperature(0.7)
                .maxCompletionTokens(2048)
                .build();
        return new OpenAiChatModel(openAiApi, options);
    }

    @Bean
    public OpenAiApi openAiApi(@Value("${spring.ai.openai.api-key}") String apiKey,
                               @Value("${spring.ai.openai.base-url}") String baseUrl) {
        return OpenAiApi.builder()
                .baseUrl(baseUrl)
                .apiKey(apiKey)
                .build();
    }
}

LLMConnectorConfig.java

@@ -0,0 +1,34 @@
package com.touka.types.config;

import java.util.Map;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * Configuration class for the LLM connectors, populated from the application config.
 */
@Data
@Component
@ConfigurationProperties(prefix = "llm")
public class LLMConnectorConfig {

    /**
     * Default connector type.
     */
    private String defaultConnector;

    /**
     * Detailed per-connector configuration.
     */
    private Map<String, ConnectorDetailConfig> connectors;

    @Data
    public static class ConnectorDetailConfig {
        private String apiKey;
        private String baseUrl;
        private int timeoutMs = 30000;
        private boolean enabled = true;
        private Map<String, Object> additionalParams;
    }
}
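For reference, the llm: block added to the application config above binds here via Spring Boot's relaxed binding (default-connector to defaultConnector, timeout-ms to timeoutMs). A small illustrative check; the class name is hypothetical.

// Hypothetical sanity check of the bound properties; not part of this commit.
import com.touka.types.config.LLMConnectorConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

@Component
public class LLMConfigSanityCheck {

    @Autowired
    public LLMConfigSanityCheck(LLMConnectorConfig config) {
        // "openai" matches the key under llm.connectors in the application config.
        LLMConnectorConfig.ConnectorDetailConfig openai = config.getConnectors().get("openai");
        System.out.printf("default=%s, openai timeout=%dms, enabled=%s%n",
                config.getDefaultConnector(), openai.getTimeoutMs(), openai.isEnabled());
    }
}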

LLMConnectorFactory.java

@@ -0,0 +1,124 @@
package com.touka.types.config;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.touka.types.config.impl.OpenAIConnector;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

/**
 * LLM connector factory that creates and manages connector instances of different types.
 */
@Component
public class LLMConnectorFactory {

    private final String defaultApiKey;
    private final String defaultBaseUrl;
    private final LLMConnectorConfig config;
    private final ChatClient chatClient;
    private final Map<String, ILLMConnector> connectors = new ConcurrentHashMap<>();

    @Autowired
    public LLMConnectorFactory(LLMConnectorConfig config,
                               ChatClient chatClient,
                               @Value("${spring.ai.openai.api-key:}") String defaultApiKey,
                               @Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") String defaultBaseUrl) {
        this.config = config;
        this.chatClient = chatClient;
        // The defaults are taken as constructor parameters: field-level @Value
        // injection would only happen after the constructor (and thus after
        // initializeConnectors()) has already run.
        this.defaultApiKey = defaultApiKey;
        this.defaultBaseUrl = defaultBaseUrl;
        initializeConnectors();
    }

    private void initializeConnectors() {
        // Register the default connectors
        registerDefaultConnectors();
        // Initialize connectors from configuration (overriding the defaults)
        if (config.getConnectors() != null) {
            for (Map.Entry<String, LLMConnectorConfig.ConnectorDetailConfig> entry : config.getConnectors().entrySet()) {
                String connectorType = entry.getKey();
                LLMConnectorConfig.ConnectorDetailConfig detailConfig = entry.getValue();
                if (detailConfig.isEnabled()) {
                    ILLMConnector connector = createConnector(connectorType, detailConfig);
                    if (connector != null) {
                        connectors.put(connectorType, connector);
                    }
                }
            }
        }
    }

    private void registerDefaultConnectors() {
        // Register the default connector implementations with all required arguments
        connectors.put("openai", new OpenAIConnector(chatClient, defaultApiKey, defaultBaseUrl));
    }

    private ILLMConnector createConnector(String connectorType, LLMConnectorConfig.ConnectorDetailConfig detailConfig) {
        switch (connectorType.toLowerCase()) {
            case "openai":
                return new OpenAIConnector(
                        detailConfig.getApiKey(),
                        detailConfig.getBaseUrl(),
                        detailConfig.getTimeoutMs(),
                        chatClient);
            default:
                // Extension point for additional connector types
                return null;
        }
    }

    /**
     * Get a connector of the given type.
     * @param connectorType the connector type
     * @return the connector instance, or null if it does not exist
     */
    public ILLMConnector getConnector(String connectorType) {
        return connectors.get(connectorType);
    }

    /**
     * Get the default connector.
     * @return the default connector instance
     */
    public ILLMConnector getDefaultConnector() {
        String defaultType = config.getDefaultConnector();
        if (defaultType != null) {
            ILLMConnector connector = connectors.get(defaultType);
            if (connector != null && connector.isAvailable()) {
                return connector;
            }
        }
        // If no default connector is configured, or it is unavailable, return the first available one
        return connectors.values().stream()
                .filter(ILLMConnector::isAvailable)
                .findFirst()
                .orElse(null);
    }

    /**
     * Register a new connector.
     * @param connector the connector instance
     */
    public void registerConnector(ILLMConnector connector) {
        if (connector != null) {
            connectors.put(connector.getConnectorType(), connector);
        }
    }

    /**
     * Get all registered connectors.
     * @return a copy of the connector map
     */
    public Map<String, ILLMConnector> getAllConnectors() {
        return new ConcurrentHashMap<>(connectors);
    }
}
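A new backend can also be attached at runtime through registerConnector, without touching the createConnector switch. A sketch, assuming the hypothetical OllamaConnector shown after ILLMConnector and the OllamaChatModel bean that the ollama starter added in this commit would auto-configure; the local URL is an assumption.

// Hypothetical runtime registration of an extra connector; not part of this commit.
import com.touka.types.config.LLMConnectorFactory;
import com.touka.types.config.impl.OllamaConnector;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.boot.CommandLineRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class OllamaRegistration {

    // Registers the sketched OllamaConnector at startup so that
    // connectorType "ollama" becomes selectable through the REST API.
    @Bean
    public CommandLineRunner registerOllama(LLMConnectorFactory factory, OllamaChatModel chatModel) {
        return args -> factory.registerConnector(
                new OllamaConnector(chatModel, "http://localhost:11434", 30000));
    }
}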

OpenAIConnector.java

@@ -0,0 +1,159 @@
package com.touka.types.config.impl;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import com.touka.types.config.BaseLLMConnector;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

/**
 * OpenAI API connector implementation based on Spring AI's ChatModel and ChatClient.
 */
@Component
public class OpenAIConnector extends BaseLLMConnector {

    private static final String CONNECTOR_TYPE = "openai";

    private final ChatClient chatClient;

    @Autowired
    public OpenAIConnector(ChatClient chatClient,
                           @Value("${spring.ai.openai.api-key:}") String apiKey,
                           @Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") String baseUrl) {
        // Use the injected values directly in the constructor
        super(apiKey, baseUrl, 30000);
        this.chatClient = chatClient;
    }

    public OpenAIConnector(String apiKey, String baseUrl, int timeoutMs, ChatClient chatClient) {
        super(apiKey, baseUrl, timeoutMs);
        this.chatClient = chatClient;
    }

    @Override
    public ChatResponse generateResponse(Prompt prompt) {
        return chatClient.prompt(prompt).call().chatResponse();
    }

    @Override
    public String generateText(String textPrompt, Map<String, Object> params) {
        Prompt prompt = buildPrompt(textPrompt, params);
        // Call the ChatClient to generate a response
        ChatResponse response = chatClient.prompt(prompt).call().chatResponse();
        return extractText(response);
    }

    @Override
    public Flux<ChatResponse> streamResponse(Prompt prompt) {
        return chatClient.prompt(prompt).stream().chatResponse();
    }

    @Override
    public Flux<String> streamText(String textPrompt, Map<String, Object> params) {
        Prompt prompt = buildPrompt(textPrompt, params);
        // Stream the response and map each chunk to its text content
        return chatClient.prompt(prompt).stream().chatResponse()
                .mapNotNull(this::extractText)
                .filter(content -> !content.isEmpty());
    }

    @Override
    public String getConnectorType() {
        return CONNECTOR_TYPE;
    }

    /**
     * Build a Prompt from the text prompt, an optional system prompt, and the chat options.
     */
    private Prompt buildPrompt(String textPrompt, Map<String, Object> params) {
        List<Message> messages = new ArrayList<>();
        // Add a system message if one was supplied
        if (params != null && params.containsKey("systemPrompt")) {
            String systemPrompt = (String) params.get("systemPrompt");
            messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
        }
        // Add the user message
        messages.add(new UserMessage(textPrompt));
        return new Prompt(messages, createOptions(params));
    }

    /**
     * Extract the generated text from the first result, or return an empty string.
     */
    private String extractText(ChatResponse response) {
        if (response != null && response.getResults() != null && !response.getResults().isEmpty()) {
            Generation generation = response.getResults().get(0);
            return generation.getOutput().getText();
        }
        return "";
    }

    /**
     * Create OpenAI chat options from the request parameters.
     */
    private OpenAiChatOptions createOptions(Map<String, Object> params) {
        OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder();
        if (params != null) {
            // Model
            if (params.containsKey("model")) {
                builder.model((String) params.get("model"));
            }
            // Temperature
            if (params.containsKey("temperature")) {
                builder.temperature(Double.parseDouble(params.get("temperature").toString()));
            }
            // Maximum number of generated tokens
            if (params.containsKey("maxTokens")) {
                builder.maxTokens(Integer.parseInt(params.get("maxTokens").toString()));
            }
            // Nucleus sampling (top-p)
            if (params.containsKey("topP")) {
                builder.topP(Double.parseDouble(params.get("topP").toString()));
            }
            // Stop sequences
            if (params.containsKey("stop")) {
                if (params.get("stop") instanceof String) {
                    builder.stop(List.of((String) params.get("stop")));
                } else if (params.get("stop") instanceof List) {
                    @SuppressWarnings("unchecked")
                    List<String> stopList = (List<String>) params.get("stop");
                    builder.stop(stopList);
                }
            }
        }
        return builder.build();
    }
}
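The params map understood by createOptions can carry all of the recognized keys at once; an illustrative combination follows, with arbitrary values.

// Illustrative params for generateText/streamText; keys match createOptions above.
import java.util.List;
import java.util.Map;

public class ParamsExample {
    public static void main(String[] args) {
        Map<String, Object> params = Map.of(
                "systemPrompt", "You are a concise assistant.",
                "model", "gpt-3.5-turbo",
                "temperature", 0.2,
                "maxTokens", 256,
                "topP", 0.9,
                "stop", List.of("\n\n"));
        // connector.generateText("Summarize visual novels in one line.", params);
        System.out.println(params);
    }
}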