Compare commits

..

No commits in common. "devlop" and "master" have entirely different histories.

19 changed files with 51 additions and 896 deletions

26
pom.xml
View File

@ -28,22 +28,6 @@
<enabled>false</enabled>
</snapshots>
</repository>
<repository>
<id>spring-milestones</id>
<name>Spring Milestones</name>
<url>https://repo.spring.io/milestone</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
<repository>
<id>spring-snapshots</id>
<name>Spring Snapshots</name>
<url>https://repo.spring.io/snapshot</url>
<releases>
<enabled>false</enabled>
</releases>
</repository>
</repositories>
<properties>
@ -51,7 +35,6 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<spring-ai.version>1.0.0-M6</spring-ai.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
@ -70,13 +53,6 @@
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>${spring-ai.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
@ -188,8 +164,6 @@
</plugins>
</build>
<profiles>
<profile>
<id>dev</id>

View File

@ -12,6 +12,10 @@
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
@ -130,6 +134,4 @@
</plugins>
</build>
</project>

View File

@ -0,0 +1,6 @@
/**
* 1. 用于管理引入的Jar所需的资源启动或者初始化处理
* 2. 如果有AOP切面可以再建一个aop包来写切面逻辑
*/
package com.touka.config;

View File

@ -12,7 +12,7 @@ thread:
block-queue-size: 5000
policy: CallerRunsPolicy
# 数据库配置
# 数据库配置;启动时配置数据库资源信息
spring:
datasource:
username: root
@ -21,45 +21,16 @@ spring:
driver-class-name: com.mysql.cj.jdbc.Driver
hikari:
pool-name: Retail_HikariCP
minimum-idle: 15
idle-timeout: 180000
maximum-pool-size: 25
auto-commit: true
max-lifetime: 1800000
connection-timeout: 30000
minimum-idle: 15 #最小空闲连接数量
idle-timeout: 180000 #空闲连接存活最大时间默认60000010分钟
maximum-pool-size: 25 #连接池最大连接数默认是10
auto-commit: true #此属性控制从池返回的连接的默认自动提交行为,默认值true
max-lifetime: 1800000 #此属性控制池中连接的最长生命周期值0表示无限生命周期默认1800000即30分钟
connection-timeout: 30000 #数据库连接超时时间,默认30秒即30000
connection-test-query: SELECT 1
type: com.zaxxer.hikari.HikariDataSource
ai:
deepseek:
api-key: your-api-key
base-url: https://api.deepseek.com
chat:
options:
model: deepseek-reasoner
temperature: 0.8
openai: # 注意这个应该在spring: ai: 下面而不是与spring平级
api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
base-url: https://api.chatanywhere.tech # 您使用的是第三方API
chat:
options:
model: gpt-3.5-turbo
temperature: 0.7
max-tokens: 2048
llm:
# 默认连接器类型
default-connector: openai
# 各连接器详细配置
connectors:
# OpenAI连接器配置
openai:
api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
base-url: https://api.chatanywhere.tech
timeout-ms: 30000
enabled: true
# MyBatis 配置
# MyBatis 配置【如需使用记得打开】
#mybatis:
# mapper-locations: classpath:/mybatis/mapper/*.xml
# config-location: classpath:/mybatis/config/mybatis-config.xml
@ -68,4 +39,4 @@ llm:
logging:
level:
root: info
config: classpath:logback-spring.xml
config: classpath:logback-spring.xml

View File

@ -1,82 +1,19 @@
//package com.touka.test;
//
//import jakarta.annotation.Resource;
//import lombok.extern.slf4j.Slf4j;
//import org.junit.Test;
//import org.junit.runner.RunWith;
//import org.springframework.ai.chat.model.ChatResponse;
//
//import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.chat.prompt.PromptTemplate;
//
//
//import org.springframework.ai.deepseek.DeepSeekChatModel;
//import org.springframework.ai.ollama.api.OllamaOptions;
//
//import org.springframework.boot.test.context.SpringBootTest;
//import org.springframework.test.context.junit4.SpringRunner;
//
//import reactor.core.publisher.Flux;
//
//@Slf4j
//@RunWith(SpringRunner.class)
//@SpringBootTest
//public class ApiTest {
//
// @Resource
// private DeepSeekChatModel chatModel;
//
// /**
// * 测试同步生成响应
// */
// @Test
// public void testGenerate() {
// String message = "Tell me a joke";
//
// // 方式1: 直接传入字符串
// String response1 = chatModel.call(message);
// System.out.println("Response1: " + response1);
//
// // 方式2: 使用 Prompt
// ChatResponse response2 = chatModel.call(new Prompt(message));
// System.out.println("Response2: " + response2.getResult().getOutput().getText());
// }
//
// /**
// * 测试流式生成响应
// */
// @Test
// public void testGenerateStream() {
// String message = "Tell me a joke";
//
// // 使用 PromptTemplate 构建提示词
// Prompt prompt = new PromptTemplate(message).create();
//
// // 流式输出
// Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
//
// // 订阅并打印流式响应
// responseFlux.doOnNext(response -> {
// System.out.println("Stream Response: " + response.getResult().getOutput().getText());
// }).blockLast(); // 在测试中阻塞等待完成
// }
//
// /**
// * 测试使用特定模型选项
// */
// @Test
// public void testGenerateWithModelOptions() {
// String message = "1+1=?";
// String model = "deepseek-r1:1.5b";
//
// ChatResponse response = chatModel.call(new Prompt(
// message,
// OllamaOptions.builder()
// .model(model)
// .build()
// ));
//
// System.out.println("Response with model options: " + response.getResult().getOutput().getText());
// }
//
//}
package com.touka.test;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;
@Slf4j
@RunWith(SpringRunner.class)
@SpringBootTest
public class ApiTest {
@Test
public void test() {
log.info("测试完成");
}
}

View File

@ -10,16 +10,6 @@
<artifactId>visual-novel-server-trigger</artifactId>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>

View File

@ -1,64 +0,0 @@
package com.touka.trigger.http.controller;
import java.util.Map;
import com.touka.trigger.http.service.LLMService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* LLM控制器提供REST API接口
*/
@RestController
@RequestMapping("/api/llm")
public class LLMController {
private final LLMService llmService;
@Autowired
public LLMController(LLMService llmService) {
this.llmService = llmService;
}
@PostMapping("/generate")
public String generate(@RequestBody Map<String, Object> request) {
String prompt = (String) request.get("prompt");
if (prompt == null) {
throw new IllegalArgumentException("Prompt is required");
}
@SuppressWarnings("unchecked")
Map<String, Object> params = (Map<String, Object>) request.getOrDefault("params", Map.of());
String connectorType = (String) request.get("connectorType");
if (connectorType != null) {
return llmService.generateText(connectorType, prompt, params);
} else {
return llmService.generateText(prompt, params);
}
}
@GetMapping("/stream")
public SseEmitter stream(
@RequestParam String prompt,
@RequestParam(required = false) String connectorType) {
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
llmService.streamText(prompt, Map.of(), emitter);
return emitter;
}
@SuppressWarnings("unchecked")
@GetMapping("/connectors")
public Map<String, Object> getConnectors() {
return (Map<String, Object>) (Map<?, ?>) llmService.getAllAvailableConnectors();
}
}

View File

@ -1,42 +0,0 @@
//package com.touka.trigger.http.controller;
//
//import jakarta.annotation.Resource;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.deepseek.DeepSeekChatModel;
//import org.springframework.web.bind.annotation.CrossOrigin;
//import org.springframework.web.bind.annotation.GetMapping;
//import org.springframework.web.bind.annotation.RequestMapping;
//import org.springframework.web.bind.annotation.RequestParam;
//import org.springframework.web.bind.annotation.RestController;
//import reactor.core.publisher.Flux;
//
//import java.util.Map;
//
//@RestController()
//@CrossOrigin("*")
//@RequestMapping("/api/ollama/")
//public class OllamaController {
// @Resource
// private DeepSeekChatModel chatModel;
//
// @GetMapping("/generate")
// public Map generate(@RequestParam(value = "message") String message) {
// return Map.of("generation", chatModel.call(message));
// }
//
// /**
// * 流式对话
// * @param message
// * @return
// */
// @GetMapping(value = "/generateStream", produces = "text/html;charset=utf-8")
// public Flux<String> generateStream(@RequestParam(value = "message") String message) {
// // 构建提示词
// Prompt prompt = new Prompt(new UserMessage(message));
// // 流式输出
// return chatModel.stream(prompt)
// .mapNotNull(chatResponse -> chatResponse.getResult().getOutput().getText());
// }
//
//}

View File

@ -0,0 +1,4 @@
/**
* HTTP 接口服务
*/
package com.touka.trigger.http;

View File

@ -1,148 +0,0 @@
package com.touka.trigger.http.service;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import com.touka.types.config.ILLMConnector;
import com.touka.types.config.LLMConnectorFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
/**
* LLM服务提供简化的接口来调用不同的LLM连接器
*/
@Service
public class LLMService {
private final LLMConnectorFactory connectorFactory;
@Autowired
public LLMService(LLMConnectorFactory connectorFactory) {
this.connectorFactory = connectorFactory;
}
/**
* 使用默认连接器生成文本
* @param prompt 提示词
* @return 生成的文本
*/
public String generateText(String prompt) {
return generateText(prompt, Collections.emptyMap());
}
/**
* 使用默认连接器生成文本
* @param prompt 提示词
* @param params 额外参数
* @return 生成的文本
*/
public String generateText(String prompt, Map<String, Object> params) {
ILLMConnector connector = connectorFactory.getDefaultConnector();
if (connector == null) {
throw new IllegalStateException("No available LLM connector");
}
return connector.generateText(prompt, params);
}
/**
* 使用指定连接器生成文本
* @param connectorType 连接器类型
* @param prompt 提示词
* @param params 额外参数
* @return 生成的文本
*/
public String generateText(String connectorType, String prompt, Map<String, Object> params) {
ILLMConnector connector = connectorFactory.getConnector(connectorType);
if (connector == null || !connector.isAvailable()) {
throw new IllegalArgumentException("Connector not found or not available: " + connectorType);
}
return connector.generateText(prompt, params);
}
/**
* 流式生成文本
* @param prompt 提示词
* @param emitter SSE发射器
*/
public void streamText(String prompt, SseEmitter emitter) {
streamText(prompt, Collections.emptyMap(), emitter);
}
/**
* 流式生成文本
* @param prompt 提示词
* @param params 额外参数
* @param emitter SSE发射器
*/
public void streamText(String prompt, Map<String, Object> params, SseEmitter emitter) {
ILLMConnector connector = connectorFactory.getDefaultConnector();
if (connector == null) {
emitter.completeWithError(new IllegalStateException("No available LLM connector"));
return;
}
// 使用Flux处理流式响应
Flux<String> textFlux = connector.streamText(prompt, params);
textFlux.subscribe(
content -> {
try {
emitter.send(SseEmitter.event().data(content));
} catch (Exception e) {
emitter.completeWithError(e);
}
},
error -> emitter.completeWithError(error),
() -> emitter.complete()
);
}
/**
* 使用指定连接器流式生成文本
* @param connectorType 连接器类型
* @param prompt 提示词
* @param params 额外参数
* @param emitter SSE发射器
*/
public void streamText(String connectorType, String prompt, Map<String, Object> params, SseEmitter emitter) {
ILLMConnector connector = connectorFactory.getConnector(connectorType);
if (connector == null || !connector.isAvailable()) {
emitter.completeWithError(new IllegalArgumentException("Connector not found or not available: " + connectorType));
return;
}
// 使用Flux处理流式响应
Flux<String> textFlux = connector.streamText(prompt, params);
textFlux.subscribe(
content -> {
try {
emitter.send(SseEmitter.event().data(content));
} catch (Exception e) {
emitter.completeWithError(e);
}
},
error -> emitter.completeWithError(error),
() -> emitter.complete()
);
}
/**
* 获取所有可用的连接器类型
* @return 连接器类型列表
*/
@SuppressWarnings("unchecked")
public Map<String, Object> getAllAvailableConnectors() {
Map<String, Object> availableConnectors = new HashMap<>();
connectorFactory.getAllConnectors().forEach((type, connector) -> {
if (connector.isAvailable()) {
availableConnectors.put(type, connector);
}
});
return availableConnectors;
}
}

View File

@ -0,0 +1,4 @@
/**
* 任务服务可以选择使用 Spring 默认提供的 Schedule https://bugstack.cn/md/road-map/quartz.html
*/
package com.touka.trigger.job;

View File

@ -0,0 +1,5 @@
/**
* 监听服务在单体服务中解耦流程类似MQ的使用如Spring的EventGuava的事件总线都可以如果使用了 Redis 那么也可以有发布/订阅使用
* Guavahttps://bugstack.cn/md/road-map/guava.html
*/
package com.touka.trigger.listener;

View File

@ -10,22 +10,6 @@
<artifactId>visual-novel-server-types</artifactId>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<!-- 添加OpenAI SDK依赖 -->
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>client</artifactId>
<version>0.18.0</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>

View File

@ -1,54 +0,0 @@
package com.touka.types.config;
/**
* LLM连接器的基础实现提供一些通用功能
*/
public abstract class BaseLLMConnector implements ILLMConnector {
protected String apiKey;
protected String baseUrl;
protected int timeoutMs;
protected boolean enabled;
public BaseLLMConnector(String apiKey, String baseUrl, int timeoutMs) {
this.apiKey = apiKey;
this.baseUrl = baseUrl;
this.timeoutMs = timeoutMs;
this.enabled = true;
}
@Override
public boolean isAvailable() {
return enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
public String getApiKey() {
return apiKey;
}
public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
public String getBaseUrl() {
return baseUrl;
}
public void setBaseUrl(String baseUrl) {
this.baseUrl = baseUrl;
}
public int getTimeoutMs() {
return timeoutMs;
}
public void setTimeoutMs(int timeoutMs) {
this.timeoutMs = timeoutMs;
}
}

View File

@ -1,55 +0,0 @@
package com.touka.types.config;
import java.util.Map;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
/**
* 大语言模型连接器接口定义了与各类LLM API交互的统一方法
*/
public interface ILLMConnector {
/**
* 使用Prompt对象生成文本响应
* @param prompt 提示对象
* @return 聊天响应
*/
ChatResponse generateResponse(Prompt prompt);
/**
* 使用文本提示生成文本
* @param textPrompt 文本提示
* @param params 额外参数
* @return 生成的文本内容
*/
String generateText(String textPrompt, Map<String, Object> params);
/**
* 流式生成文本响应
* @param prompt 提示对象
* @return 响应流
*/
Flux<ChatResponse> streamResponse(Prompt prompt);
/**
* 流式生成文本
* @param textPrompt 文本提示
* @param params 额外参数
* @return 文本流
*/
Flux<String> streamText(String textPrompt, Map<String, Object> params);
/**
* 获取连接器类型
* @return 连接器类型标识
*/
String getConnectorType();
/**
* 检查连接器是否可用
* @return 连接器是否可用
*/
boolean isAvailable();
}

View File

@ -1,42 +0,0 @@
package com.touka.types.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
// 在配置类中添加
@Configuration
public class LLMConfig {
@Bean
public ChatClient chatClient(ChatClient.Builder builder) {
return builder.build();
}
@Bean
public ChatClient.Builder chatClientBuilder(OpenAiChatModel openAiChatModel) {
// 根据配置创建Builder
return ChatClient.builder(openAiChatModel);
}
@Bean
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
OpenAiChatOptions options = OpenAiChatOptions.builder()
.model("gpt-3.5-turbo")
.temperature(0.7)
.maxCompletionTokens(2048)
.build();
return new OpenAiChatModel(openAiApi, options);
}
@Bean
public OpenAiApi openAiApi(@Value("${spring.ai.openai.api-key}") String apiKey, @Value("${spring.ai.openai.base-url}") String baseUrl) {
return OpenAiApi.builder().
baseUrl(baseUrl).
apiKey(apiKey).
build();
}
}

View File

@ -1,34 +0,0 @@
package com.touka.types.config;
import java.util.Map;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
* LLM连接器的配置类用于从配置文件中读取连接器配置
*/
@Data
@Component
@ConfigurationProperties(prefix = "llm")
public class LLMConnectorConfig {
/**
* 默认连接器类型
*/
private String defaultConnector;
/**
* 各连接器的详细配置
*/
private Map<String, ConnectorDetailConfig> connectors;
@Data
public static class ConnectorDetailConfig {
private String apiKey;
private String baseUrl;
private int timeoutMs = 30000;
private boolean enabled = true;
private Map<String, Object> additionalParams;
}
}

View File

@ -1,124 +0,0 @@
package com.touka.types.config;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import com.touka.types.config.impl.OpenAIConnector;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
/**
* LLM连接器工厂用于创建和管理不同类型的连接器实例
*/
@Component
public class LLMConnectorFactory {
@Value("${spring.ai.openai.api-key:}")
private String defaultApiKey;
@Value("${spring.ai.openai.base-url:https://api.openai.com/v1}")
private String defaultBaseUrl;
private final LLMConnectorConfig config;
private final ChatClient chatClient;
private final Map<String, ILLMConnector> connectors = new ConcurrentHashMap<>();
@Autowired
public LLMConnectorFactory(LLMConnectorConfig config, ChatClient chatClient) {
this.config = config;
this.chatClient = chatClient;
initializeConnectors();
}
private void initializeConnectors() {
// 注册默认连接器
registerDefaultConnectors();
// 根据配置初始化连接器
if (config.getConnectors() != null) {
for (Map.Entry<String, LLMConnectorConfig.ConnectorDetailConfig> entry : config.getConnectors().entrySet()) {
String connectorType = entry.getKey();
LLMConnectorConfig.ConnectorDetailConfig detailConfig = entry.getValue();
if (detailConfig.isEnabled()) {
ILLMConnector connector = createConnector(connectorType, detailConfig);
if (connector != null) {
connectors.put(connectorType, connector);
}
}
}
}
}
private void registerDefaultConnectors() {
// 注册默认连接器实现传入所有必要的参数
connectors.put("openai", new OpenAIConnector(chatClient, defaultApiKey, defaultBaseUrl));
}
private ILLMConnector createConnector(String connectorType, LLMConnectorConfig.ConnectorDetailConfig detailConfig) {
switch (connectorType.toLowerCase()) {
case "openai":
return new OpenAIConnector(
detailConfig.getApiKey(),
detailConfig.getBaseUrl(),
detailConfig.getTimeoutMs(),
chatClient);
default:
// 可以扩展支持更多连接器类型
return null;
}
}
/**
* 获取指定类型的连接器
* @param connectorType 连接器类型
* @return 连接器实例如果不存在则返回null
*/
public ILLMConnector getConnector(String connectorType) {
return connectors.get(connectorType);
}
/**
* 获取默认连接器
* @return 默认连接器实例
*/
public ILLMConnector getDefaultConnector() {
String defaultType = config.getDefaultConnector();
if (defaultType != null) {
ILLMConnector connector = connectors.get(defaultType);
if (connector != null && connector.isAvailable()) {
return connector;
}
}
// 如果没有配置默认连接器或默认连接器不可用返回第一个可用的连接器
return connectors.values().stream()
.filter(ILLMConnector::isAvailable)
.findFirst()
.orElse(null);
}
/**
* 注册新的连接器
* @param connector 连接器实例
*/
public void registerConnector(ILLMConnector connector) {
if (connector != null) {
connectors.put(connector.getConnectorType(), connector);
}
}
/**
* 获取所有可用的连接器类型
* @return 连接器类型列表
*/
public Map<String, ILLMConnector> getAllConnectors() {
return new ConcurrentHashMap<>(connectors);
}
}

View File

@ -1,159 +0,0 @@
package com.touka.types.config.impl;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.touka.types.config.BaseLLMConnector;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
/**
* OpenAI API连接器实现基于Spring AI的ChatModel和ChatClient
*/
@Component
public class OpenAIConnector extends BaseLLMConnector {
private static final String CONNECTOR_TYPE = "openai";
private final ChatClient chatClient;
@Autowired
public OpenAIConnector(ChatClient chatClient,
@Value("${spring.ai.openai.api-key:}") String apiKey,
@Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") String baseUrl) {
// 直接在构造函数中使用注入的值
super(apiKey, baseUrl, 30000);
this.chatClient = chatClient;
}
public OpenAIConnector(String apiKey, String baseUrl, int timeoutMs, ChatClient chatClient) {
super(apiKey, baseUrl, timeoutMs);
this.chatClient = chatClient;
}
@Override
public ChatResponse generateResponse(Prompt prompt) {
return chatClient.prompt(prompt).call().chatResponse();
}
@Override
public String generateText(String textPrompt, Map<String, Object> params) {
// 构建消息列表
List<Message> messages = new ArrayList<>();
// 添加系统消息如果有
if (params != null && params.containsKey("systemPrompt")) {
String systemPrompt = (String) params.get("systemPrompt");
messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
}
// 添加用户消息
messages.add(new UserMessage(textPrompt));
// 创建提示
Prompt prompt = new Prompt(messages, createOptions(params));
// 调用ChatClient生成响应
ChatResponse response = chatClient.prompt(prompt).call().chatResponse();
// 提取生成的文本
if (response.getResults() != null && !response.getResults().isEmpty()) {
Generation generation = response.getResults().get(0);
return generation.getOutput().getText();
}
return "";
}
@Override
public Flux<ChatResponse> streamResponse(Prompt prompt) {
return chatClient.prompt(prompt).stream().chatResponse();
}
@Override
public Flux<String> streamText(String textPrompt, Map<String, Object> params) {
// 构建消息列表
List<Message> messages = new ArrayList<>();
// 添加系统消息如果有
if (params != null && params.containsKey("systemPrompt")) {
String systemPrompt = (String) params.get("systemPrompt");
messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
}
// 添加用户消息
messages.add(new UserMessage(textPrompt));
// 创建提示
Prompt prompt = new Prompt(messages, createOptions(params));
// 流式调用并处理响应
return chatClient.prompt(prompt).stream().chatResponse()
.mapNotNull(response -> {
if (response.getResults() != null && !response.getResults().isEmpty()) {
Generation generation = response.getResults().get(0);
return generation.getOutput().getText();
}
return "";
})
.filter(content -> !content.isEmpty());
}
@Override
public String getConnectorType() {
return CONNECTOR_TYPE;
}
/**
* 根据参数创建OpenAI选项
*/
private OpenAiChatOptions createOptions(Map<String, Object> params) {
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder();
if (params != null) {
// 设置模型
if (params.containsKey("model")) {
builder.model((String) params.get("model"));
}
// 设置温度
if (params.containsKey("temperature")) {
builder.temperature(Double.parseDouble(params.get("temperature").toString()));
}
// 设置最大生成长度
if (params.containsKey("maxTokens")) {
builder.maxTokens(Integer.parseInt(params.get("maxTokens").toString()));
}
// 设置topP
if (params.containsKey("topP")) {
builder.topP(Double.parseDouble(params.get("topP").toString()));
}
// 设置停止词
if (params.containsKey("stop")) {
if (params.get("stop") instanceof String) {
builder.stop(List.of((String) params.get("stop")));
} else if (params.get("stop") instanceof List) {
@SuppressWarnings("unchecked")
List<String> stopList = (List<String>) params.get("stop");
builder.stop(stopList);
}
}
}
return builder.build();
}
}