Compare commits
3 Commits
Author | SHA1 | Date |
---|---|---|
|
9022a082c5 | |
|
46d676a8c1 | |
|
7e43fd59a4 |
26
pom.xml
26
pom.xml
|
@ -28,6 +28,22 @@
|
|||
<enabled>false</enabled>
|
||||
</snapshots>
|
||||
</repository>
|
||||
<repository>
|
||||
<id>spring-milestones</id>
|
||||
<name>Spring Milestones</name>
|
||||
<url>https://repo.spring.io/milestone</url>
|
||||
<snapshots>
|
||||
<enabled>false</enabled>
|
||||
</snapshots>
|
||||
</repository>
|
||||
<repository>
|
||||
<id>spring-snapshots</id>
|
||||
<name>Spring Snapshots</name>
|
||||
<url>https://repo.spring.io/snapshot</url>
|
||||
<releases>
|
||||
<enabled>false</enabled>
|
||||
</releases>
|
||||
</repository>
|
||||
</repositories>
|
||||
|
||||
<properties>
|
||||
|
@ -35,6 +51,7 @@
|
|||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
<maven.compiler.target>17</maven.compiler.target>
|
||||
<spring-ai.version>1.0.0-M6</spring-ai.version>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
|
@ -53,6 +70,13 @@
|
|||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-bom</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mybatis.spring.boot</groupId>
|
||||
<artifactId>mybatis-spring-boot-starter</artifactId>
|
||||
|
@ -164,6 +188,8 @@
|
|||
</plugins>
|
||||
</build>
|
||||
|
||||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>dev</id>
|
||||
|
|
|
@ -12,10 +12,6 @@
|
|||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
|
@ -134,4 +130,6 @@
|
|||
</plugins>
|
||||
</build>
|
||||
|
||||
|
||||
|
||||
</project>
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
/**
|
||||
* 1. 用于管理引入的Jar所需的资源启动或者初始化处理
|
||||
* 2. 如果有AOP切面,可以再建一个aop包,来写切面逻辑
|
||||
*/
|
||||
package com.touka.config;
|
||||
|
|
@ -12,7 +12,7 @@ thread:
|
|||
block-queue-size: 5000
|
||||
policy: CallerRunsPolicy
|
||||
|
||||
# 数据库配置;启动时配置数据库资源信息
|
||||
# 数据库配置
|
||||
spring:
|
||||
datasource:
|
||||
username: root
|
||||
|
@ -21,16 +21,45 @@ spring:
|
|||
driver-class-name: com.mysql.cj.jdbc.Driver
|
||||
hikari:
|
||||
pool-name: Retail_HikariCP
|
||||
minimum-idle: 15 #最小空闲连接数量
|
||||
idle-timeout: 180000 #空闲连接存活最大时间,默认600000(10分钟)
|
||||
maximum-pool-size: 25 #连接池最大连接数,默认是10
|
||||
auto-commit: true #此属性控制从池返回的连接的默认自动提交行为,默认值:true
|
||||
max-lifetime: 1800000 #此属性控制池中连接的最长生命周期,值0表示无限生命周期,默认1800000即30分钟
|
||||
connection-timeout: 30000 #数据库连接超时时间,默认30秒,即30000
|
||||
minimum-idle: 15
|
||||
idle-timeout: 180000
|
||||
maximum-pool-size: 25
|
||||
auto-commit: true
|
||||
max-lifetime: 1800000
|
||||
connection-timeout: 30000
|
||||
connection-test-query: SELECT 1
|
||||
type: com.zaxxer.hikari.HikariDataSource
|
||||
|
||||
# MyBatis 配置【如需使用记得打开】
|
||||
ai:
|
||||
deepseek:
|
||||
api-key: your-api-key
|
||||
base-url: https://api.deepseek.com
|
||||
chat:
|
||||
options:
|
||||
model: deepseek-reasoner
|
||||
temperature: 0.8
|
||||
openai: # 注意:这个应该在spring: ai: 下面,而不是与spring平级
|
||||
api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
|
||||
base-url: https://api.chatanywhere.tech # 您使用的是第三方API
|
||||
chat:
|
||||
options:
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.7
|
||||
max-tokens: 2048
|
||||
|
||||
llm:
|
||||
# 默认连接器类型
|
||||
default-connector: openai
|
||||
# 各连接器详细配置
|
||||
connectors:
|
||||
# OpenAI连接器配置
|
||||
openai:
|
||||
api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq
|
||||
base-url: https://api.chatanywhere.tech
|
||||
timeout-ms: 30000
|
||||
enabled: true
|
||||
|
||||
# MyBatis 配置
|
||||
#mybatis:
|
||||
# mapper-locations: classpath:/mybatis/mapper/*.xml
|
||||
# config-location: classpath:/mybatis/config/mybatis-config.xml
|
||||
|
@ -39,4 +68,4 @@ spring:
|
|||
logging:
|
||||
level:
|
||||
root: info
|
||||
config: classpath:logback-spring.xml
|
||||
config: classpath:logback-spring.xml
|
||||
|
|
|
@ -1,19 +1,82 @@
|
|||
package com.touka.test;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest
|
||||
public class ApiTest {
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
log.info("测试完成");
|
||||
}
|
||||
|
||||
}
|
||||
//package com.touka.test;
|
||||
//
|
||||
//import jakarta.annotation.Resource;
|
||||
//import lombok.extern.slf4j.Slf4j;
|
||||
//import org.junit.Test;
|
||||
//import org.junit.runner.RunWith;
|
||||
//import org.springframework.ai.chat.model.ChatResponse;
|
||||
//
|
||||
//import org.springframework.ai.chat.prompt.Prompt;
|
||||
//import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
//
|
||||
//
|
||||
//import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||
//import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
//
|
||||
//import org.springframework.boot.test.context.SpringBootTest;
|
||||
//import org.springframework.test.context.junit4.SpringRunner;
|
||||
//
|
||||
//import reactor.core.publisher.Flux;
|
||||
//
|
||||
//@Slf4j
|
||||
//@RunWith(SpringRunner.class)
|
||||
//@SpringBootTest
|
||||
//public class ApiTest {
|
||||
//
|
||||
// @Resource
|
||||
// private DeepSeekChatModel chatModel;
|
||||
//
|
||||
// /**
|
||||
// * 测试同步生成响应
|
||||
// */
|
||||
// @Test
|
||||
// public void testGenerate() {
|
||||
// String message = "Tell me a joke";
|
||||
//
|
||||
// // 方式1: 直接传入字符串
|
||||
// String response1 = chatModel.call(message);
|
||||
// System.out.println("Response1: " + response1);
|
||||
//
|
||||
// // 方式2: 使用 Prompt
|
||||
// ChatResponse response2 = chatModel.call(new Prompt(message));
|
||||
// System.out.println("Response2: " + response2.getResult().getOutput().getText());
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 测试流式生成响应
|
||||
// */
|
||||
// @Test
|
||||
// public void testGenerateStream() {
|
||||
// String message = "Tell me a joke";
|
||||
//
|
||||
// // 使用 PromptTemplate 构建提示词
|
||||
// Prompt prompt = new PromptTemplate(message).create();
|
||||
//
|
||||
// // 流式输出
|
||||
// Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
|
||||
//
|
||||
// // 订阅并打印流式响应
|
||||
// responseFlux.doOnNext(response -> {
|
||||
// System.out.println("Stream Response: " + response.getResult().getOutput().getText());
|
||||
// }).blockLast(); // 在测试中阻塞等待完成
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 测试使用特定模型选项
|
||||
// */
|
||||
// @Test
|
||||
// public void testGenerateWithModelOptions() {
|
||||
// String message = "1+1=?";
|
||||
// String model = "deepseek-r1:1.5b";
|
||||
//
|
||||
// ChatResponse response = chatModel.call(new Prompt(
|
||||
// message,
|
||||
// OllamaOptions.builder()
|
||||
// .model(model)
|
||||
// .build()
|
||||
// ));
|
||||
//
|
||||
// System.out.println("Response with model options: " + response.getResult().getOutput().getText());
|
||||
// }
|
||||
//
|
||||
//}
|
||||
|
|
|
@ -10,6 +10,16 @@
|
|||
<artifactId>visual-novel-server-trigger</artifactId>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package com.touka.trigger.http.controller;
|
||||
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import com.touka.trigger.http.service.LLMService;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
|
||||
/**
|
||||
* LLM控制器,提供REST API接口
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/api/llm")
|
||||
public class LLMController {
|
||||
|
||||
private final LLMService llmService;
|
||||
|
||||
@Autowired
|
||||
public LLMController(LLMService llmService) {
|
||||
this.llmService = llmService;
|
||||
}
|
||||
|
||||
@PostMapping("/generate")
|
||||
public String generate(@RequestBody Map<String, Object> request) {
|
||||
String prompt = (String) request.get("prompt");
|
||||
if (prompt == null) {
|
||||
throw new IllegalArgumentException("Prompt is required");
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> params = (Map<String, Object>) request.getOrDefault("params", Map.of());
|
||||
|
||||
String connectorType = (String) request.get("connectorType");
|
||||
if (connectorType != null) {
|
||||
return llmService.generateText(connectorType, prompt, params);
|
||||
} else {
|
||||
return llmService.generateText(prompt, params);
|
||||
}
|
||||
}
|
||||
|
||||
@GetMapping("/stream")
|
||||
public SseEmitter stream(
|
||||
@RequestParam String prompt,
|
||||
@RequestParam(required = false) String connectorType) {
|
||||
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
|
||||
llmService.streamText(prompt, Map.of(), emitter);
|
||||
return emitter;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@GetMapping("/connectors")
|
||||
public Map<String, Object> getConnectors() {
|
||||
return (Map<String, Object>) (Map<?, ?>) llmService.getAllAvailableConnectors();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
//package com.touka.trigger.http.controller;
|
||||
//
|
||||
//import jakarta.annotation.Resource;
|
||||
//import org.springframework.ai.chat.messages.UserMessage;
|
||||
//import org.springframework.ai.chat.prompt.Prompt;
|
||||
//import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||
//import org.springframework.web.bind.annotation.CrossOrigin;
|
||||
//import org.springframework.web.bind.annotation.GetMapping;
|
||||
//import org.springframework.web.bind.annotation.RequestMapping;
|
||||
//import org.springframework.web.bind.annotation.RequestParam;
|
||||
//import org.springframework.web.bind.annotation.RestController;
|
||||
//import reactor.core.publisher.Flux;
|
||||
//
|
||||
//import java.util.Map;
|
||||
//
|
||||
//@RestController()
|
||||
//@CrossOrigin("*")
|
||||
//@RequestMapping("/api/ollama/")
|
||||
//public class OllamaController {
|
||||
// @Resource
|
||||
// private DeepSeekChatModel chatModel;
|
||||
//
|
||||
// @GetMapping("/generate")
|
||||
// public Map generate(@RequestParam(value = "message") String message) {
|
||||
// return Map.of("generation", chatModel.call(message));
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 流式对话
|
||||
// * @param message
|
||||
// * @return
|
||||
// */
|
||||
// @GetMapping(value = "/generateStream", produces = "text/html;charset=utf-8")
|
||||
// public Flux<String> generateStream(@RequestParam(value = "message") String message) {
|
||||
// // 构建提示词
|
||||
// Prompt prompt = new Prompt(new UserMessage(message));
|
||||
// // 流式输出
|
||||
// return chatModel.stream(prompt)
|
||||
// .mapNotNull(chatResponse -> chatResponse.getResult().getOutput().getText());
|
||||
// }
|
||||
//
|
||||
//}
|
|
@ -1,4 +0,0 @@
|
|||
/**
|
||||
* HTTP 接口服务
|
||||
*/
|
||||
package com.touka.trigger.http;
|
|
@ -0,0 +1,148 @@
|
|||
package com.touka.trigger.http.service;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import com.touka.types.config.ILLMConnector;
|
||||
import com.touka.types.config.LLMConnectorFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
|
||||
/**
|
||||
* LLM服务,提供简化的接口来调用不同的LLM连接器
|
||||
*/
|
||||
@Service
|
||||
public class LLMService {
|
||||
|
||||
private final LLMConnectorFactory connectorFactory;
|
||||
|
||||
@Autowired
|
||||
public LLMService(LLMConnectorFactory connectorFactory) {
|
||||
this.connectorFactory = connectorFactory;
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用默认连接器生成文本
|
||||
* @param prompt 提示词
|
||||
* @return 生成的文本
|
||||
*/
|
||||
public String generateText(String prompt) {
|
||||
return generateText(prompt, Collections.emptyMap());
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用默认连接器生成文本
|
||||
* @param prompt 提示词
|
||||
* @param params 额外参数
|
||||
* @return 生成的文本
|
||||
*/
|
||||
public String generateText(String prompt, Map<String, Object> params) {
|
||||
ILLMConnector connector = connectorFactory.getDefaultConnector();
|
||||
if (connector == null) {
|
||||
throw new IllegalStateException("No available LLM connector");
|
||||
}
|
||||
return connector.generateText(prompt, params);
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用指定连接器生成文本
|
||||
* @param connectorType 连接器类型
|
||||
* @param prompt 提示词
|
||||
* @param params 额外参数
|
||||
* @return 生成的文本
|
||||
*/
|
||||
public String generateText(String connectorType, String prompt, Map<String, Object> params) {
|
||||
ILLMConnector connector = connectorFactory.getConnector(connectorType);
|
||||
if (connector == null || !connector.isAvailable()) {
|
||||
throw new IllegalArgumentException("Connector not found or not available: " + connectorType);
|
||||
}
|
||||
return connector.generateText(prompt, params);
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式生成文本
|
||||
* @param prompt 提示词
|
||||
* @param emitter SSE发射器
|
||||
*/
|
||||
public void streamText(String prompt, SseEmitter emitter) {
|
||||
streamText(prompt, Collections.emptyMap(), emitter);
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式生成文本
|
||||
* @param prompt 提示词
|
||||
* @param params 额外参数
|
||||
* @param emitter SSE发射器
|
||||
*/
|
||||
public void streamText(String prompt, Map<String, Object> params, SseEmitter emitter) {
|
||||
ILLMConnector connector = connectorFactory.getDefaultConnector();
|
||||
if (connector == null) {
|
||||
emitter.completeWithError(new IllegalStateException("No available LLM connector"));
|
||||
return;
|
||||
}
|
||||
|
||||
// 使用Flux处理流式响应
|
||||
Flux<String> textFlux = connector.streamText(prompt, params);
|
||||
|
||||
textFlux.subscribe(
|
||||
content -> {
|
||||
try {
|
||||
emitter.send(SseEmitter.event().data(content));
|
||||
} catch (Exception e) {
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
},
|
||||
error -> emitter.completeWithError(error),
|
||||
() -> emitter.complete()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用指定连接器流式生成文本
|
||||
* @param connectorType 连接器类型
|
||||
* @param prompt 提示词
|
||||
* @param params 额外参数
|
||||
* @param emitter SSE发射器
|
||||
*/
|
||||
public void streamText(String connectorType, String prompt, Map<String, Object> params, SseEmitter emitter) {
|
||||
ILLMConnector connector = connectorFactory.getConnector(connectorType);
|
||||
if (connector == null || !connector.isAvailable()) {
|
||||
emitter.completeWithError(new IllegalArgumentException("Connector not found or not available: " + connectorType));
|
||||
return;
|
||||
}
|
||||
|
||||
// 使用Flux处理流式响应
|
||||
Flux<String> textFlux = connector.streamText(prompt, params);
|
||||
|
||||
textFlux.subscribe(
|
||||
content -> {
|
||||
try {
|
||||
emitter.send(SseEmitter.event().data(content));
|
||||
} catch (Exception e) {
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
},
|
||||
error -> emitter.completeWithError(error),
|
||||
() -> emitter.complete()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有可用的连接器类型
|
||||
* @return 连接器类型列表
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public Map<String, Object> getAllAvailableConnectors() {
|
||||
Map<String, Object> availableConnectors = new HashMap<>();
|
||||
connectorFactory.getAllConnectors().forEach((type, connector) -> {
|
||||
if (connector.isAvailable()) {
|
||||
availableConnectors.put(type, connector);
|
||||
}
|
||||
});
|
||||
return availableConnectors;
|
||||
}
|
||||
}
|
|
@ -1,4 +0,0 @@
|
|||
/**
|
||||
* 任务服务,可以选择使用 Spring 默认提供的 Schedule https://bugstack.cn/md/road-map/quartz.html
|
||||
*/
|
||||
package com.touka.trigger.job;
|
|
@ -1,5 +0,0 @@
|
|||
/**
|
||||
* 监听服务;在单体服务中,解耦流程。类似MQ的使用,如Spring的Event,Guava的事件总线都可以。如果使用了 Redis 那么也可以有发布/订阅使用。
|
||||
* Guava:https://bugstack.cn/md/road-map/guava.html
|
||||
*/
|
||||
package com.touka.trigger.listener;
|
|
@ -10,6 +10,22 @@
|
|||
<artifactId>visual-novel-server-types</artifactId>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
<!-- 添加OpenAI SDK依赖 -->
|
||||
<dependency>
|
||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||
<artifactId>client</artifactId>
|
||||
<version>0.18.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package com.touka.types.config;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* LLM连接器的基础实现,提供一些通用功能
|
||||
*/
|
||||
public abstract class BaseLLMConnector implements ILLMConnector {
|
||||
|
||||
protected String apiKey;
|
||||
protected String baseUrl;
|
||||
protected int timeoutMs;
|
||||
protected boolean enabled;
|
||||
|
||||
public BaseLLMConnector(String apiKey, String baseUrl, int timeoutMs) {
|
||||
this.apiKey = apiKey;
|
||||
this.baseUrl = baseUrl;
|
||||
this.timeoutMs = timeoutMs;
|
||||
this.enabled = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAvailable() {
|
||||
return enabled;
|
||||
}
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
public String getApiKey() {
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
public void setApiKey(String apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
|
||||
public String getBaseUrl() {
|
||||
return baseUrl;
|
||||
}
|
||||
|
||||
public void setBaseUrl(String baseUrl) {
|
||||
this.baseUrl = baseUrl;
|
||||
}
|
||||
|
||||
public int getTimeoutMs() {
|
||||
return timeoutMs;
|
||||
}
|
||||
|
||||
public void setTimeoutMs(int timeoutMs) {
|
||||
this.timeoutMs = timeoutMs;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package com.touka.types.config;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 大语言模型连接器接口,定义了与各类LLM API交互的统一方法
|
||||
*/
|
||||
public interface ILLMConnector {
|
||||
|
||||
/**
|
||||
* 使用Prompt对象生成文本响应
|
||||
* @param prompt 提示对象
|
||||
* @return 聊天响应
|
||||
*/
|
||||
ChatResponse generateResponse(Prompt prompt);
|
||||
|
||||
/**
|
||||
* 使用文本提示生成文本
|
||||
* @param textPrompt 文本提示
|
||||
* @param params 额外参数
|
||||
* @return 生成的文本内容
|
||||
*/
|
||||
String generateText(String textPrompt, Map<String, Object> params);
|
||||
|
||||
/**
|
||||
* 流式生成文本响应
|
||||
* @param prompt 提示对象
|
||||
* @return 响应流
|
||||
*/
|
||||
Flux<ChatResponse> streamResponse(Prompt prompt);
|
||||
|
||||
/**
|
||||
* 流式生成文本
|
||||
* @param textPrompt 文本提示
|
||||
* @param params 额外参数
|
||||
* @return 文本流
|
||||
*/
|
||||
Flux<String> streamText(String textPrompt, Map<String, Object> params);
|
||||
|
||||
/**
|
||||
* 获取连接器类型
|
||||
* @return 连接器类型标识
|
||||
*/
|
||||
String getConnectorType();
|
||||
|
||||
/**
|
||||
* 检查连接器是否可用
|
||||
* @return 连接器是否可用
|
||||
*/
|
||||
boolean isAvailable();
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package com.touka.types.config;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
// 在配置类中添加
|
||||
@Configuration
|
||||
public class LLMConfig {
|
||||
@Bean
|
||||
public ChatClient chatClient(ChatClient.Builder builder) {
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ChatClient.Builder chatClientBuilder(OpenAiChatModel openAiChatModel) {
|
||||
// 根据配置创建Builder
|
||||
return ChatClient.builder(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
|
||||
OpenAiChatOptions options = OpenAiChatOptions.builder()
|
||||
.model("gpt-3.5-turbo")
|
||||
.temperature(0.7)
|
||||
.maxCompletionTokens(2048)
|
||||
.build();
|
||||
return new OpenAiChatModel(openAiApi, options);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiApi openAiApi(@Value("${spring.ai.openai.api-key}") String apiKey, @Value("${spring.ai.openai.base-url}") String baseUrl) {
|
||||
return OpenAiApi.builder().
|
||||
baseUrl(baseUrl).
|
||||
apiKey(apiKey).
|
||||
build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package com.touka.types.config;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* LLM连接器的配置类,用于从配置文件中读取连接器配置
|
||||
*/
|
||||
@Data
|
||||
@Component
|
||||
@ConfigurationProperties(prefix = "llm")
|
||||
public class LLMConnectorConfig {
|
||||
|
||||
/**
|
||||
* 默认连接器类型
|
||||
*/
|
||||
private String defaultConnector;
|
||||
|
||||
/**
|
||||
* 各连接器的详细配置
|
||||
*/
|
||||
private Map<String, ConnectorDetailConfig> connectors;
|
||||
|
||||
@Data
|
||||
public static class ConnectorDetailConfig {
|
||||
private String apiKey;
|
||||
private String baseUrl;
|
||||
private int timeoutMs = 30000;
|
||||
private boolean enabled = true;
|
||||
private Map<String, Object> additionalParams;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
package com.touka.types.config;
|
||||
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
||||
import com.touka.types.config.impl.OpenAIConnector;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
|
||||
/**
|
||||
* LLM连接器工厂,用于创建和管理不同类型的连接器实例
|
||||
*/
|
||||
@Component
|
||||
public class LLMConnectorFactory {
|
||||
|
||||
@Value("${spring.ai.openai.api-key:}")
|
||||
private String defaultApiKey;
|
||||
|
||||
@Value("${spring.ai.openai.base-url:https://api.openai.com/v1}")
|
||||
private String defaultBaseUrl;
|
||||
|
||||
private final LLMConnectorConfig config;
|
||||
private final ChatClient chatClient;
|
||||
private final Map<String, ILLMConnector> connectors = new ConcurrentHashMap<>();
|
||||
|
||||
@Autowired
|
||||
public LLMConnectorFactory(LLMConnectorConfig config, ChatClient chatClient) {
|
||||
this.config = config;
|
||||
this.chatClient = chatClient;
|
||||
initializeConnectors();
|
||||
}
|
||||
|
||||
private void initializeConnectors() {
|
||||
// 注册默认连接器
|
||||
registerDefaultConnectors();
|
||||
|
||||
// 根据配置初始化连接器
|
||||
if (config.getConnectors() != null) {
|
||||
for (Map.Entry<String, LLMConnectorConfig.ConnectorDetailConfig> entry : config.getConnectors().entrySet()) {
|
||||
String connectorType = entry.getKey();
|
||||
LLMConnectorConfig.ConnectorDetailConfig detailConfig = entry.getValue();
|
||||
|
||||
if (detailConfig.isEnabled()) {
|
||||
ILLMConnector connector = createConnector(connectorType, detailConfig);
|
||||
if (connector != null) {
|
||||
connectors.put(connectorType, connector);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void registerDefaultConnectors() {
|
||||
// 注册默认连接器实现,传入所有必要的参数
|
||||
connectors.put("openai", new OpenAIConnector(chatClient, defaultApiKey, defaultBaseUrl));
|
||||
}
|
||||
|
||||
private ILLMConnector createConnector(String connectorType, LLMConnectorConfig.ConnectorDetailConfig detailConfig) {
|
||||
switch (connectorType.toLowerCase()) {
|
||||
case "openai":
|
||||
return new OpenAIConnector(
|
||||
detailConfig.getApiKey(),
|
||||
detailConfig.getBaseUrl(),
|
||||
detailConfig.getTimeoutMs(),
|
||||
chatClient);
|
||||
default:
|
||||
// 可以扩展支持更多连接器类型
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定类型的连接器
|
||||
* @param connectorType 连接器类型
|
||||
* @return 连接器实例,如果不存在则返回null
|
||||
*/
|
||||
public ILLMConnector getConnector(String connectorType) {
|
||||
return connectors.get(connectorType);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取默认连接器
|
||||
* @return 默认连接器实例
|
||||
*/
|
||||
public ILLMConnector getDefaultConnector() {
|
||||
String defaultType = config.getDefaultConnector();
|
||||
if (defaultType != null) {
|
||||
ILLMConnector connector = connectors.get(defaultType);
|
||||
if (connector != null && connector.isAvailable()) {
|
||||
return connector;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有配置默认连接器或默认连接器不可用,返回第一个可用的连接器
|
||||
return connectors.values().stream()
|
||||
.filter(ILLMConnector::isAvailable)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册新的连接器
|
||||
* @param connector 连接器实例
|
||||
*/
|
||||
public void registerConnector(ILLMConnector connector) {
|
||||
if (connector != null) {
|
||||
connectors.put(connector.getConnectorType(), connector);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有可用的连接器类型
|
||||
* @return 连接器类型列表
|
||||
*/
|
||||
public Map<String, ILLMConnector> getAllConnectors() {
|
||||
return new ConcurrentHashMap<>(connectors);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
package com.touka.types.config.impl;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.touka.types.config.BaseLLMConnector;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* OpenAI API连接器实现,基于Spring AI的ChatModel和ChatClient
|
||||
*/
|
||||
@Component
|
||||
public class OpenAIConnector extends BaseLLMConnector {
|
||||
|
||||
private static final String CONNECTOR_TYPE = "openai";
|
||||
|
||||
private final ChatClient chatClient;
|
||||
|
||||
@Autowired
|
||||
public OpenAIConnector(ChatClient chatClient,
|
||||
@Value("${spring.ai.openai.api-key:}") String apiKey,
|
||||
@Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") String baseUrl) {
|
||||
// 直接在构造函数中使用注入的值
|
||||
super(apiKey, baseUrl, 30000);
|
||||
this.chatClient = chatClient;
|
||||
}
|
||||
|
||||
public OpenAIConnector(String apiKey, String baseUrl, int timeoutMs, ChatClient chatClient) {
|
||||
super(apiKey, baseUrl, timeoutMs);
|
||||
this.chatClient = chatClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse generateResponse(Prompt prompt) {
|
||||
return chatClient.prompt(prompt).call().chatResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String generateText(String textPrompt, Map<String, Object> params) {
|
||||
// 构建消息列表
|
||||
List<Message> messages = new ArrayList<>();
|
||||
|
||||
// 添加系统消息(如果有)
|
||||
if (params != null && params.containsKey("systemPrompt")) {
|
||||
String systemPrompt = (String) params.get("systemPrompt");
|
||||
messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
|
||||
}
|
||||
|
||||
// 添加用户消息
|
||||
messages.add(new UserMessage(textPrompt));
|
||||
|
||||
// 创建提示
|
||||
Prompt prompt = new Prompt(messages, createOptions(params));
|
||||
|
||||
// 调用ChatClient生成响应
|
||||
ChatResponse response = chatClient.prompt(prompt).call().chatResponse();
|
||||
|
||||
// 提取生成的文本
|
||||
if (response.getResults() != null && !response.getResults().isEmpty()) {
|
||||
Generation generation = response.getResults().get(0);
|
||||
return generation.getOutput().getText();
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> streamResponse(Prompt prompt) {
|
||||
return chatClient.prompt(prompt).stream().chatResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<String> streamText(String textPrompt, Map<String, Object> params) {
|
||||
// 构建消息列表
|
||||
List<Message> messages = new ArrayList<>();
|
||||
|
||||
// 添加系统消息(如果有)
|
||||
if (params != null && params.containsKey("systemPrompt")) {
|
||||
String systemPrompt = (String) params.get("systemPrompt");
|
||||
messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params));
|
||||
}
|
||||
|
||||
// 添加用户消息
|
||||
messages.add(new UserMessage(textPrompt));
|
||||
|
||||
// 创建提示
|
||||
Prompt prompt = new Prompt(messages, createOptions(params));
|
||||
|
||||
// 流式调用并处理响应
|
||||
return chatClient.prompt(prompt).stream().chatResponse()
|
||||
.mapNotNull(response -> {
|
||||
if (response.getResults() != null && !response.getResults().isEmpty()) {
|
||||
Generation generation = response.getResults().get(0);
|
||||
return generation.getOutput().getText();
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.filter(content -> !content.isEmpty());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getConnectorType() {
|
||||
return CONNECTOR_TYPE;
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据参数创建OpenAI选项
|
||||
*/
|
||||
private OpenAiChatOptions createOptions(Map<String, Object> params) {
|
||||
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder();
|
||||
|
||||
if (params != null) {
|
||||
// 设置模型
|
||||
if (params.containsKey("model")) {
|
||||
builder.model((String) params.get("model"));
|
||||
}
|
||||
|
||||
// 设置温度
|
||||
if (params.containsKey("temperature")) {
|
||||
builder.temperature(Double.parseDouble(params.get("temperature").toString()));
|
||||
}
|
||||
|
||||
// 设置最大生成长度
|
||||
if (params.containsKey("maxTokens")) {
|
||||
builder.maxTokens(Integer.parseInt(params.get("maxTokens").toString()));
|
||||
}
|
||||
|
||||
// 设置topP
|
||||
if (params.containsKey("topP")) {
|
||||
builder.topP(Double.parseDouble(params.get("topP").toString()));
|
||||
}
|
||||
|
||||
// 设置停止词
|
||||
if (params.containsKey("stop")) {
|
||||
if (params.get("stop") instanceof String) {
|
||||
builder.stop(List.of((String) params.get("stop")));
|
||||
} else if (params.get("stop") instanceof List) {
|
||||
@SuppressWarnings("unchecked")
|
||||
List<String> stopList = (List<String>) params.get("stop");
|
||||
builder.stop(stopList);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue