diff --git a/pom.xml b/pom.xml index 511e2ce..cc73b4c 100644 --- a/pom.xml +++ b/pom.xml @@ -28,6 +28,22 @@ false + + spring-milestones + Spring Milestones + https://repo.spring.io/milestone + + false + + + + spring-snapshots + Spring Snapshots + https://repo.spring.io/snapshot + + false + + @@ -35,7 +51,7 @@ UTF-8 17 17 - 1.0.0 + 1.0.0-M6 UTF-8 @@ -172,6 +188,8 @@ + + dev diff --git a/visual-novel-server-app/pom.xml b/visual-novel-server-app/pom.xml index 8e1c22d..333a30b 100644 --- a/visual-novel-server-app/pom.xml +++ b/visual-novel-server-app/pom.xml @@ -12,19 +12,6 @@ jar - - org.springframework.ai - spring-ai-ollama - - - - org.springframework.ai - spring-ai-starter-model-deepseek - - - org.springframework.boot - spring-boot-starter-web - org.springframework.boot spring-boot-starter-test @@ -143,4 +130,6 @@ + + diff --git a/visual-novel-server-app/src/main/java/com/touka/config/package-info.java b/visual-novel-server-app/src/main/java/com/touka/config/package-info.java deleted file mode 100644 index e58f90e..0000000 --- a/visual-novel-server-app/src/main/java/com/touka/config/package-info.java +++ /dev/null @@ -1,6 +0,0 @@ -/** - * 1. 用于管理引入的Jar所需的资源启动或者初始化处理 - * 2. 如果有AOP切面,可以再建一个aop包,来写切面逻辑 - */ -package com.touka.config; - diff --git a/visual-novel-server-app/src/main/resources/application-dev.yml b/visual-novel-server-app/src/main/resources/application-dev.yml index c0254c0..671bac1 100644 --- a/visual-novel-server-app/src/main/resources/application-dev.yml +++ b/visual-novel-server-app/src/main/resources/application-dev.yml @@ -12,7 +12,7 @@ thread: block-queue-size: 5000 policy: CallerRunsPolicy -# 数据库配置;启动时配置数据库资源信息 +# 数据库配置 spring: datasource: username: root @@ -21,25 +21,45 @@ spring: driver-class-name: com.mysql.cj.jdbc.Driver hikari: pool-name: Retail_HikariCP - minimum-idle: 15 #最小空闲连接数量 - idle-timeout: 180000 #空闲连接存活最大时间,默认600000(10分钟) - maximum-pool-size: 25 #连接池最大连接数,默认是10 - auto-commit: true #此属性控制从池返回的连接的默认自动提交行为,默认值:true - max-lifetime: 1800000 #此属性控制池中连接的最长生命周期,值0表示无限生命周期,默认1800000即30分钟 - connection-timeout: 30000 #数据库连接超时时间,默认30秒,即30000 + minimum-idle: 15 + idle-timeout: 180000 + maximum-pool-size: 25 + auto-commit: true + max-lifetime: 1800000 + connection-timeout: 30000 connection-test-query: SELECT 1 type: com.zaxxer.hikari.HikariDataSource ai: deepseek: api-key: your-api-key - base-url: https://api.deepseek.com # DeepSeek 的请求 URL, 可不填,默认值为 api.deepseek.com + base-url: https://api.deepseek.com chat: options: - model: deepseek-reasoner # 使用深度思考模型 - temperature: 0.8 # 温度值 + model: deepseek-reasoner + temperature: 0.8 + openai: # 注意:这个应该在spring: ai: 下面,而不是与spring平级 + api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq + base-url: https://api.chatanywhere.tech # 您使用的是第三方API + chat: + options: + model: gpt-3.5-turbo + temperature: 0.7 + max-tokens: 2048 -# MyBatis 配置【如需使用记得打开】 +llm: + # 默认连接器类型 + default-connector: openai + # 各连接器详细配置 + connectors: + # OpenAI连接器配置 + openai: + api-key: sk-EcUhlzeo44aQLBJqNncwBxw8yQuAE8rlbe8fNWLC3RXxfDGq + base-url: https://api.chatanywhere.tech + timeout-ms: 30000 + enabled: true + +# MyBatis 配置 #mybatis: # mapper-locations: classpath:/mybatis/mapper/*.xml # config-location: classpath:/mybatis/config/mybatis-config.xml @@ -49,4 +69,3 @@ logging: level: root: info config: classpath:logback-spring.xml - diff --git a/visual-novel-server-app/src/test/java/com/touka/test/ApiTest.java b/visual-novel-server-app/src/test/java/com/touka/test/ApiTest.java index a11a145..3267ca2 100644 --- a/visual-novel-server-app/src/test/java/com/touka/test/ApiTest.java +++ b/visual-novel-server-app/src/test/java/com/touka/test/ApiTest.java @@ -1,81 +1,82 @@ -package com.touka.test; - -import jakarta.annotation.Resource; -import lombok.extern.slf4j.Slf4j; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.springframework.ai.chat.model.ChatResponse; - -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.deepseek.DeepSeekChatModel; - -import org.springframework.ai.ollama.api.OllamaOptions; - -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.test.context.junit4.SpringRunner; - -import reactor.core.publisher.Flux; - -@Slf4j -@RunWith(SpringRunner.class) -@SpringBootTest -public class ApiTest { - - @Resource - private DeepSeekChatModel chatModel; - - /** - * 测试同步生成响应 - */ - @Test - public void testGenerate() { - String message = "Tell me a joke"; - - // 方式1: 直接传入字符串 - String response1 = chatModel.call(message); - System.out.println("Response1: " + response1); - - // 方式2: 使用 Prompt - ChatResponse response2 = chatModel.call(new Prompt(message)); - System.out.println("Response2: " + response2.getResult().getOutput().getText()); - } - - /** - * 测试流式生成响应 - */ - @Test - public void testGenerateStream() { - String message = "Tell me a joke"; - - // 使用 PromptTemplate 构建提示词 - Prompt prompt = new PromptTemplate(message).create(); - - // 流式输出 - Flux responseFlux = chatModel.stream(prompt); - - // 订阅并打印流式响应 - responseFlux.doOnNext(response -> { - System.out.println("Stream Response: " + response.getResult().getOutput().getText()); - }).blockLast(); // 在测试中阻塞等待完成 - } - - /** - * 测试使用特定模型选项 - */ - @Test - public void testGenerateWithModelOptions() { - String message = "1+1=?"; - String model = "deepseek-r1:1.5b"; - - ChatResponse response = chatModel.call(new Prompt( - message, - OllamaOptions.builder() - .model(model) - .build() - )); - - System.out.println("Response with model options: " + response.getResult().getOutput().getText()); - } - -} +//package com.touka.test; +// +//import jakarta.annotation.Resource; +//import lombok.extern.slf4j.Slf4j; +//import org.junit.Test; +//import org.junit.runner.RunWith; +//import org.springframework.ai.chat.model.ChatResponse; +// +//import org.springframework.ai.chat.prompt.Prompt; +//import org.springframework.ai.chat.prompt.PromptTemplate; +// +// +//import org.springframework.ai.deepseek.DeepSeekChatModel; +//import org.springframework.ai.ollama.api.OllamaOptions; +// +//import org.springframework.boot.test.context.SpringBootTest; +//import org.springframework.test.context.junit4.SpringRunner; +// +//import reactor.core.publisher.Flux; +// +//@Slf4j +//@RunWith(SpringRunner.class) +//@SpringBootTest +//public class ApiTest { +// +// @Resource +// private DeepSeekChatModel chatModel; +// +// /** +// * 测试同步生成响应 +// */ +// @Test +// public void testGenerate() { +// String message = "Tell me a joke"; +// +// // 方式1: 直接传入字符串 +// String response1 = chatModel.call(message); +// System.out.println("Response1: " + response1); +// +// // 方式2: 使用 Prompt +// ChatResponse response2 = chatModel.call(new Prompt(message)); +// System.out.println("Response2: " + response2.getResult().getOutput().getText()); +// } +// +// /** +// * 测试流式生成响应 +// */ +// @Test +// public void testGenerateStream() { +// String message = "Tell me a joke"; +// +// // 使用 PromptTemplate 构建提示词 +// Prompt prompt = new PromptTemplate(message).create(); +// +// // 流式输出 +// Flux responseFlux = chatModel.stream(prompt); +// +// // 订阅并打印流式响应 +// responseFlux.doOnNext(response -> { +// System.out.println("Stream Response: " + response.getResult().getOutput().getText()); +// }).blockLast(); // 在测试中阻塞等待完成 +// } +// +// /** +// * 测试使用特定模型选项 +// */ +// @Test +// public void testGenerateWithModelOptions() { +// String message = "1+1=?"; +// String model = "deepseek-r1:1.5b"; +// +// ChatResponse response = chatModel.call(new Prompt( +// message, +// OllamaOptions.builder() +// .model(model) +// .build() +// )); +// +// System.out.println("Response with model options: " + response.getResult().getOutput().getText()); +// } +// +//} diff --git a/visual-novel-server-trigger/pom.xml b/visual-novel-server-trigger/pom.xml index 01b257f..377c29d 100644 --- a/visual-novel-server-trigger/pom.xml +++ b/visual-novel-server-trigger/pom.xml @@ -10,6 +10,16 @@ visual-novel-server-trigger + + org.springframework.ai + spring-ai-openai-spring-boot-starter + + + + org.springframework.ai + spring-ai-ollama-spring-boot-starter + + org.springframework.boot spring-boot-starter-web diff --git a/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/controller/LLMController.java b/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/controller/LLMController.java new file mode 100644 index 0000000..43c0108 --- /dev/null +++ b/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/controller/LLMController.java @@ -0,0 +1,64 @@ +package com.touka.trigger.http.controller; + + +import java.util.Map; + +import com.touka.trigger.http.service.LLMService; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + + +/** + * LLM控制器,提供REST API接口 + */ +@RestController +@RequestMapping("/api/llm") +public class LLMController { + + private final LLMService llmService; + + @Autowired + public LLMController(LLMService llmService) { + this.llmService = llmService; + } + + @PostMapping("/generate") + public String generate(@RequestBody Map request) { + String prompt = (String) request.get("prompt"); + if (prompt == null) { + throw new IllegalArgumentException("Prompt is required"); + } + + @SuppressWarnings("unchecked") + Map params = (Map) request.getOrDefault("params", Map.of()); + + String connectorType = (String) request.get("connectorType"); + if (connectorType != null) { + return llmService.generateText(connectorType, prompt, params); + } else { + return llmService.generateText(prompt, params); + } + } + + @GetMapping("/stream") + public SseEmitter stream( + @RequestParam String prompt, + @RequestParam(required = false) String connectorType) { + SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); + llmService.streamText(prompt, Map.of(), emitter); + return emitter; + } + + @SuppressWarnings("unchecked") + @GetMapping("/connectors") + public Map getConnectors() { + return (Map) (Map) llmService.getAllAvailableConnectors(); + } +} diff --git a/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/controller/OllamaController.java b/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/controller/OllamaController.java new file mode 100644 index 0000000..0bce8d0 --- /dev/null +++ b/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/controller/OllamaController.java @@ -0,0 +1,42 @@ +//package com.touka.trigger.http.controller; +// +//import jakarta.annotation.Resource; +//import org.springframework.ai.chat.messages.UserMessage; +//import org.springframework.ai.chat.prompt.Prompt; +//import org.springframework.ai.deepseek.DeepSeekChatModel; +//import org.springframework.web.bind.annotation.CrossOrigin; +//import org.springframework.web.bind.annotation.GetMapping; +//import org.springframework.web.bind.annotation.RequestMapping; +//import org.springframework.web.bind.annotation.RequestParam; +//import org.springframework.web.bind.annotation.RestController; +//import reactor.core.publisher.Flux; +// +//import java.util.Map; +// +//@RestController() +//@CrossOrigin("*") +//@RequestMapping("/api/ollama/") +//public class OllamaController { +// @Resource +// private DeepSeekChatModel chatModel; +// +// @GetMapping("/generate") +// public Map generate(@RequestParam(value = "message") String message) { +// return Map.of("generation", chatModel.call(message)); +// } +// +// /** +// * 流式对话 +// * @param message +// * @return +// */ +// @GetMapping(value = "/generateStream", produces = "text/html;charset=utf-8") +// public Flux generateStream(@RequestParam(value = "message") String message) { +// // 构建提示词 +// Prompt prompt = new Prompt(new UserMessage(message)); +// // 流式输出 +// return chatModel.stream(prompt) +// .mapNotNull(chatResponse -> chatResponse.getResult().getOutput().getText()); +// } +// +//} diff --git a/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/service/LLMService.java b/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/service/LLMService.java new file mode 100644 index 0000000..a90c527 --- /dev/null +++ b/visual-novel-server-trigger/src/main/java/com/touka/trigger/http/service/LLMService.java @@ -0,0 +1,148 @@ +package com.touka.trigger.http.service; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import com.touka.types.config.ILLMConnector; +import com.touka.types.config.LLMConnectorFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import reactor.core.publisher.Flux; + + +/** + * LLM服务,提供简化的接口来调用不同的LLM连接器 + */ +@Service +public class LLMService { + + private final LLMConnectorFactory connectorFactory; + + @Autowired + public LLMService(LLMConnectorFactory connectorFactory) { + this.connectorFactory = connectorFactory; + } + + /** + * 使用默认连接器生成文本 + * @param prompt 提示词 + * @return 生成的文本 + */ + public String generateText(String prompt) { + return generateText(prompt, Collections.emptyMap()); + } + + /** + * 使用默认连接器生成文本 + * @param prompt 提示词 + * @param params 额外参数 + * @return 生成的文本 + */ + public String generateText(String prompt, Map params) { + ILLMConnector connector = connectorFactory.getDefaultConnector(); + if (connector == null) { + throw new IllegalStateException("No available LLM connector"); + } + return connector.generateText(prompt, params); + } + + /** + * 使用指定连接器生成文本 + * @param connectorType 连接器类型 + * @param prompt 提示词 + * @param params 额外参数 + * @return 生成的文本 + */ + public String generateText(String connectorType, String prompt, Map params) { + ILLMConnector connector = connectorFactory.getConnector(connectorType); + if (connector == null || !connector.isAvailable()) { + throw new IllegalArgumentException("Connector not found or not available: " + connectorType); + } + return connector.generateText(prompt, params); + } + + /** + * 流式生成文本 + * @param prompt 提示词 + * @param emitter SSE发射器 + */ + public void streamText(String prompt, SseEmitter emitter) { + streamText(prompt, Collections.emptyMap(), emitter); + } + + /** + * 流式生成文本 + * @param prompt 提示词 + * @param params 额外参数 + * @param emitter SSE发射器 + */ + public void streamText(String prompt, Map params, SseEmitter emitter) { + ILLMConnector connector = connectorFactory.getDefaultConnector(); + if (connector == null) { + emitter.completeWithError(new IllegalStateException("No available LLM connector")); + return; + } + + // 使用Flux处理流式响应 + Flux textFlux = connector.streamText(prompt, params); + + textFlux.subscribe( + content -> { + try { + emitter.send(SseEmitter.event().data(content)); + } catch (Exception e) { + emitter.completeWithError(e); + } + }, + error -> emitter.completeWithError(error), + () -> emitter.complete() + ); + } + + /** + * 使用指定连接器流式生成文本 + * @param connectorType 连接器类型 + * @param prompt 提示词 + * @param params 额外参数 + * @param emitter SSE发射器 + */ + public void streamText(String connectorType, String prompt, Map params, SseEmitter emitter) { + ILLMConnector connector = connectorFactory.getConnector(connectorType); + if (connector == null || !connector.isAvailable()) { + emitter.completeWithError(new IllegalArgumentException("Connector not found or not available: " + connectorType)); + return; + } + + // 使用Flux处理流式响应 + Flux textFlux = connector.streamText(prompt, params); + + textFlux.subscribe( + content -> { + try { + emitter.send(SseEmitter.event().data(content)); + } catch (Exception e) { + emitter.completeWithError(e); + } + }, + error -> emitter.completeWithError(error), + () -> emitter.complete() + ); + } + + /** + * 获取所有可用的连接器类型 + * @return 连接器类型列表 + */ + @SuppressWarnings("unchecked") + public Map getAllAvailableConnectors() { + Map availableConnectors = new HashMap<>(); + connectorFactory.getAllConnectors().forEach((type, connector) -> { + if (connector.isAvailable()) { + availableConnectors.put(type, connector); + } + }); + return availableConnectors; + } +} diff --git a/visual-novel-server-types/pom.xml b/visual-novel-server-types/pom.xml index 31b9f5a..3730bd5 100644 --- a/visual-novel-server-types/pom.xml +++ b/visual-novel-server-types/pom.xml @@ -10,6 +10,22 @@ visual-novel-server-types + + org.springframework.ai + spring-ai-ollama-spring-boot-starter + + + + org.springframework.ai + spring-ai-openai-spring-boot-starter + + + + com.theokanning.openai-gpt3-java + client + 0.18.0 + + org.springframework.boot spring-boot-starter-web diff --git a/visual-novel-server-types/src/main/java/com/touka/types/config/BaseLLMConnector.java b/visual-novel-server-types/src/main/java/com/touka/types/config/BaseLLMConnector.java new file mode 100644 index 0000000..5186258 --- /dev/null +++ b/visual-novel-server-types/src/main/java/com/touka/types/config/BaseLLMConnector.java @@ -0,0 +1,54 @@ +package com.touka.types.config; + + + +/** + * LLM连接器的基础实现,提供一些通用功能 + */ +public abstract class BaseLLMConnector implements ILLMConnector { + + protected String apiKey; + protected String baseUrl; + protected int timeoutMs; + protected boolean enabled; + + public BaseLLMConnector(String apiKey, String baseUrl, int timeoutMs) { + this.apiKey = apiKey; + this.baseUrl = baseUrl; + this.timeoutMs = timeoutMs; + this.enabled = true; + } + + @Override + public boolean isAvailable() { + return enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + + public int getTimeoutMs() { + return timeoutMs; + } + + public void setTimeoutMs(int timeoutMs) { + this.timeoutMs = timeoutMs; + } +} diff --git a/visual-novel-server-types/src/main/java/com/touka/types/config/ILLMConnector.java b/visual-novel-server-types/src/main/java/com/touka/types/config/ILLMConnector.java new file mode 100644 index 0000000..e36b561 --- /dev/null +++ b/visual-novel-server-types/src/main/java/com/touka/types/config/ILLMConnector.java @@ -0,0 +1,55 @@ +package com.touka.types.config; + +import java.util.Map; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + +/** + * 大语言模型连接器接口,定义了与各类LLM API交互的统一方法 + */ +public interface ILLMConnector { + + /** + * 使用Prompt对象生成文本响应 + * @param prompt 提示对象 + * @return 聊天响应 + */ + ChatResponse generateResponse(Prompt prompt); + + /** + * 使用文本提示生成文本 + * @param textPrompt 文本提示 + * @param params 额外参数 + * @return 生成的文本内容 + */ + String generateText(String textPrompt, Map params); + + /** + * 流式生成文本响应 + * @param prompt 提示对象 + * @return 响应流 + */ + Flux streamResponse(Prompt prompt); + + /** + * 流式生成文本 + * @param textPrompt 文本提示 + * @param params 额外参数 + * @return 文本流 + */ + Flux streamText(String textPrompt, Map params); + + /** + * 获取连接器类型 + * @return 连接器类型标识 + */ + String getConnectorType(); + + /** + * 检查连接器是否可用 + * @return 连接器是否可用 + */ + boolean isAvailable(); +} \ No newline at end of file diff --git a/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConfig.java b/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConfig.java new file mode 100644 index 0000000..e702afc --- /dev/null +++ b/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConfig.java @@ -0,0 +1,42 @@ +package com.touka.types.config; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +// 在配置类中添加 +@Configuration +public class LLMConfig { + @Bean + public ChatClient chatClient(ChatClient.Builder builder) { + return builder.build(); + } + + @Bean + public ChatClient.Builder chatClientBuilder(OpenAiChatModel openAiChatModel) { + // 根据配置创建Builder + return ChatClient.builder(openAiChatModel); + } + + @Bean + public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) { + OpenAiChatOptions options = OpenAiChatOptions.builder() + .model("gpt-3.5-turbo") + .temperature(0.7) + .maxCompletionTokens(2048) + .build(); + return new OpenAiChatModel(openAiApi, options); + } + + @Bean + public OpenAiApi openAiApi(@Value("${spring.ai.openai.api-key}") String apiKey, @Value("${spring.ai.openai.base-url}") String baseUrl) { + return OpenAiApi.builder(). + baseUrl(baseUrl). + apiKey(apiKey). + build(); + } +} diff --git a/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConnectorConfig.java b/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConnectorConfig.java new file mode 100644 index 0000000..25a222f --- /dev/null +++ b/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConnectorConfig.java @@ -0,0 +1,34 @@ +package com.touka.types.config; + +import java.util.Map; +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +/** + * LLM连接器的配置类,用于从配置文件中读取连接器配置 + */ +@Data +@Component +@ConfigurationProperties(prefix = "llm") +public class LLMConnectorConfig { + + /** + * 默认连接器类型 + */ + private String defaultConnector; + + /** + * 各连接器的详细配置 + */ + private Map connectors; + + @Data + public static class ConnectorDetailConfig { + private String apiKey; + private String baseUrl; + private int timeoutMs = 30000; + private boolean enabled = true; + private Map additionalParams; + } +} \ No newline at end of file diff --git a/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConnectorFactory.java b/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConnectorFactory.java new file mode 100644 index 0000000..11142c8 --- /dev/null +++ b/visual-novel-server-types/src/main/java/com/touka/types/config/LLMConnectorFactory.java @@ -0,0 +1,124 @@ +package com.touka.types.config; + + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + + +import com.touka.types.config.impl.OpenAIConnector; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.beans.factory.annotation.Autowired; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + + +/** + * LLM连接器工厂,用于创建和管理不同类型的连接器实例 + */ +@Component +public class LLMConnectorFactory { + + @Value("${spring.ai.openai.api-key:}") + private String defaultApiKey; + + @Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") + private String defaultBaseUrl; + + private final LLMConnectorConfig config; + private final ChatClient chatClient; + private final Map connectors = new ConcurrentHashMap<>(); + + @Autowired + public LLMConnectorFactory(LLMConnectorConfig config, ChatClient chatClient) { + this.config = config; + this.chatClient = chatClient; + initializeConnectors(); + } + + private void initializeConnectors() { + // 注册默认连接器 + registerDefaultConnectors(); + + // 根据配置初始化连接器 + if (config.getConnectors() != null) { + for (Map.Entry entry : config.getConnectors().entrySet()) { + String connectorType = entry.getKey(); + LLMConnectorConfig.ConnectorDetailConfig detailConfig = entry.getValue(); + + if (detailConfig.isEnabled()) { + ILLMConnector connector = createConnector(connectorType, detailConfig); + if (connector != null) { + connectors.put(connectorType, connector); + } + } + } + } + } + + private void registerDefaultConnectors() { + // 注册默认连接器实现,传入所有必要的参数 + connectors.put("openai", new OpenAIConnector(chatClient, defaultApiKey, defaultBaseUrl)); + } + + private ILLMConnector createConnector(String connectorType, LLMConnectorConfig.ConnectorDetailConfig detailConfig) { + switch (connectorType.toLowerCase()) { + case "openai": + return new OpenAIConnector( + detailConfig.getApiKey(), + detailConfig.getBaseUrl(), + detailConfig.getTimeoutMs(), + chatClient); + default: + // 可以扩展支持更多连接器类型 + return null; + } + } + + /** + * 获取指定类型的连接器 + * @param connectorType 连接器类型 + * @return 连接器实例,如果不存在则返回null + */ + public ILLMConnector getConnector(String connectorType) { + return connectors.get(connectorType); + } + + /** + * 获取默认连接器 + * @return 默认连接器实例 + */ + public ILLMConnector getDefaultConnector() { + String defaultType = config.getDefaultConnector(); + if (defaultType != null) { + ILLMConnector connector = connectors.get(defaultType); + if (connector != null && connector.isAvailable()) { + return connector; + } + } + + // 如果没有配置默认连接器或默认连接器不可用,返回第一个可用的连接器 + return connectors.values().stream() + .filter(ILLMConnector::isAvailable) + .findFirst() + .orElse(null); + } + + /** + * 注册新的连接器 + * @param connector 连接器实例 + */ + public void registerConnector(ILLMConnector connector) { + if (connector != null) { + connectors.put(connector.getConnectorType(), connector); + } + } + + /** + * 获取所有可用的连接器类型 + * @return 连接器类型列表 + */ + public Map getAllConnectors() { + return new ConcurrentHashMap<>(connectors); + } +} diff --git a/visual-novel-server-types/src/main/java/com/touka/types/config/impl/OpenAIConnector.java b/visual-novel-server-types/src/main/java/com/touka/types/config/impl/OpenAIConnector.java new file mode 100644 index 0000000..60aad75 --- /dev/null +++ b/visual-novel-server-types/src/main/java/com/touka/types/config/impl/OpenAIConnector.java @@ -0,0 +1,159 @@ +package com.touka.types.config.impl; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import com.touka.types.config.BaseLLMConnector; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import reactor.core.publisher.Flux; + +/** + * OpenAI API连接器实现,基于Spring AI的ChatModel和ChatClient + */ +@Component +public class OpenAIConnector extends BaseLLMConnector { + + private static final String CONNECTOR_TYPE = "openai"; + + private final ChatClient chatClient; + + @Autowired + public OpenAIConnector(ChatClient chatClient, + @Value("${spring.ai.openai.api-key:}") String apiKey, + @Value("${spring.ai.openai.base-url:https://api.openai.com/v1}") String baseUrl) { + // 直接在构造函数中使用注入的值 + super(apiKey, baseUrl, 30000); + this.chatClient = chatClient; + } + + public OpenAIConnector(String apiKey, String baseUrl, int timeoutMs, ChatClient chatClient) { + super(apiKey, baseUrl, timeoutMs); + this.chatClient = chatClient; + } + + @Override + public ChatResponse generateResponse(Prompt prompt) { + return chatClient.prompt(prompt).call().chatResponse(); + } + + @Override + public String generateText(String textPrompt, Map params) { + // 构建消息列表 + List messages = new ArrayList<>(); + + // 添加系统消息(如果有) + if (params != null && params.containsKey("systemPrompt")) { + String systemPrompt = (String) params.get("systemPrompt"); + messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params)); + } + + // 添加用户消息 + messages.add(new UserMessage(textPrompt)); + + // 创建提示 + Prompt prompt = new Prompt(messages, createOptions(params)); + + // 调用ChatClient生成响应 + ChatResponse response = chatClient.prompt(prompt).call().chatResponse(); + + // 提取生成的文本 + if (response.getResults() != null && !response.getResults().isEmpty()) { + Generation generation = response.getResults().get(0); + return generation.getOutput().getText(); + } + + return ""; + } + + @Override + public Flux streamResponse(Prompt prompt) { + return chatClient.prompt(prompt).stream().chatResponse(); + } + + @Override + public Flux streamText(String textPrompt, Map params) { + // 构建消息列表 + List messages = new ArrayList<>(); + + // 添加系统消息(如果有) + if (params != null && params.containsKey("systemPrompt")) { + String systemPrompt = (String) params.get("systemPrompt"); + messages.add(new SystemPromptTemplate(systemPrompt).createMessage(params)); + } + + // 添加用户消息 + messages.add(new UserMessage(textPrompt)); + + // 创建提示 + Prompt prompt = new Prompt(messages, createOptions(params)); + + // 流式调用并处理响应 + return chatClient.prompt(prompt).stream().chatResponse() + .mapNotNull(response -> { + if (response.getResults() != null && !response.getResults().isEmpty()) { + Generation generation = response.getResults().get(0); + return generation.getOutput().getText(); + } + return ""; + }) + .filter(content -> !content.isEmpty()); + } + + @Override + public String getConnectorType() { + return CONNECTOR_TYPE; + } + + /** + * 根据参数创建OpenAI选项 + */ + private OpenAiChatOptions createOptions(Map params) { + OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder(); + + if (params != null) { + // 设置模型 + if (params.containsKey("model")) { + builder.model((String) params.get("model")); + } + + // 设置温度 + if (params.containsKey("temperature")) { + builder.temperature(Double.parseDouble(params.get("temperature").toString())); + } + + // 设置最大生成长度 + if (params.containsKey("maxTokens")) { + builder.maxTokens(Integer.parseInt(params.get("maxTokens").toString())); + } + + // 设置topP + if (params.containsKey("topP")) { + builder.topP(Double.parseDouble(params.get("topP").toString())); + } + + // 设置停止词 + if (params.containsKey("stop")) { + if (params.get("stop") instanceof String) { + builder.stop(List.of((String) params.get("stop"))); + } else if (params.get("stop") instanceof List) { + @SuppressWarnings("unchecked") + List stopList = (List) params.get("stop"); + builder.stop(stopList); + } + } + } + + return builder.build(); + } +} \ No newline at end of file