java自定义skill技能接入大模型

该文章已生成可运行项目,

简单demo,自己写了玩,不喜勿喷

功能

使用 Java 开发 skill(技能)并接入大模型,让大模型在指定场景下调用对应的 skill 完成任务。比如,大模型本身无法直接访问本地文件,这时就可以写一个简单的 skill,接入后即可读取本地文件(如 txt、word、excel、pdf 等),或者将指定内容写入本地文件。注意:下文给出的 FileReadSkill 只对 doc/docx/xls/xlsx 做了解析,其余格式一律按文本读取,PDF 需要额外引入解析库才能正确读取。

环境

使用本地部署的大模型,或者自行购买大模型服务的 API Key;其余只需要 Java 基础环境。注意:下文的 OllamaClient 调用的是 Ollama 的原生接口(/api/chat、/api/tags),如果使用其他服务的 API Key,需要自行改写客户端。

我自己使用的本地搭建的模型

直接上代码

pom依赖

<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
    <version>4.12.0</version>
</dependency>
<dependency>
    <groupId>com.google.code.gson</groupId>
    <artifactId>gson</artifactId>
    <version>2.10.1</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-api</artifactId>
    <version>2.0.9</version>
</dependency>
<!-- slf4j-api 2.x discovers its backend through ServiceLoader providers.
     logback-classic 1.2.x is a legacy 1.7-style binding, so it is not found and all logging
     silently falls back to NOP. The 1.3.x line is the SLF4J 2 compatible release that
     still runs on Java 8 (use 1.4.x/1.5.x if the project targets Java 11+). -->
<dependency>
    <groupId>ch.qos.logback</groupId>
    <artifactId>logback-classic</artifactId>
    <version>1.3.14</version>
</dependency>
<dependency>
    <groupId>org.reflections</groupId>
    <artifactId>reflections</artifactId>
    <version>0.10.2</version>
</dependency>
<dependency>
    <groupId>org.apache.poi</groupId>
    <artifactId>poi</artifactId>
    <version>5.2.3</version>
</dependency>
<dependency>
    <groupId>org.apache.poi</groupId>
    <artifactId>poi-ooxml</artifactId>
    <version>5.2.3</version>
</dependency>
<dependency>
    <groupId>org.apache.poi</groupId>
    <artifactId>poi-scratchpad</artifactId>
    <version>5.2.3</version>
</dependency>
<dependency>
    <groupId>org.apache.logging.log4j</groupId>
    <artifactId>log4j-core</artifactId>
    <version>2.18.0</version>
</dependency>

入口

public static void main(String[] args) {
    // Ollama listens on port 11434 by default; point this at the host running your Ollama service.
    String baseUrl = "http://localhost:11434";
    String model = "deepseek-r1:32b";

    try {
        System.out.println("正在初始化AI Agent...");

        OllamaClient llmClient = new OllamaClient(baseUrl, model);

        if (!llmClient.testConnection()) {
            // Plain concatenation: the previous "{}" was an SLF4J placeholder that println never fills in.
            System.out.println("无法连接到本地大模型服务: " + baseUrl);
            // The client talks to Ollama's native /api/chat and /api/tags endpoints, not an OpenAI-style API.
            System.out.println("请确保Ollama服务正在运行");
            return;
        }

        System.out.println("成功连接到本地大模型服务");

        SkillManager skillManager = new SkillManager();

        //skillManager.registerSkill(new CalculatorSkill());
        skillManager.registerSkill(new FileReadSkill());
        skillManager.registerSkill(new FileWriteSkill());
        skillManager.registerSkill(new HttpGetSkill());

        System.out.println("已注册" + skillManager.getAllSkills().size() + "个技能");

        AIAgent agent = new AIAgent(llmClient, skillManager);

        System.out.println("AI Agent初始化完成,开始交互");
        System.out.println("\n========================================");
        System.out.println("AI Agent 已启动 (输入 'exit' 退出)");
        System.out.println("========================================\n");

        try (Scanner scanner = new Scanner(System.in)) {
            while (true) {
                System.out.print("你: ");

                // End of input (Ctrl+D / Ctrl+Z, or the end of piped input) ends the session like 'exit'
                // instead of letting nextLine() throw NoSuchElementException.
                if (!scanner.hasNextLine()) {
                    System.out.println();
                    break;
                }
                String userInput = scanner.nextLine().trim();

                if (userInput.equalsIgnoreCase("exit")) {
                    System.out.println("再见!");
                    break;
                }

                if (userInput.isEmpty()) {
                    continue;
                }

                try {
                    System.out.print("AI: ");
                    String response = agent.chat(userInput);
                    System.out.println(response);
                    System.out.println();
                } catch (Exception e) {
                    // A failed turn is reported and the conversation continues.
                    System.out.println("处理用户输入时出错:" + e);
                    System.out.println("抱歉,处理您的请求时出现错误: " + e.getMessage());
                    System.out.println();
                }
            }
        }

    } catch (Exception e) {
        System.err.println("程序运行出错: " + e);
    }
}
OllamaClient客户端
public class OllamaClient {
    private static final Logger logger = LoggerFactory.getLogger(OllamaClient.class);
    
    private final String baseUrl;
    private final String model;
    private final OkHttpClient client;
    private final Gson gson;
    
    public OllamaClient(String baseUrl, String model) {
        this.baseUrl = baseUrl.endsWith("/") ? baseUrl.substring(0, baseUrl.length() - 1) : baseUrl;
        this.model = model;
        this.client = new OkHttpClient.Builder()
                .connectTimeout(60, TimeUnit.SECONDS)
                .readTimeout(300, TimeUnit.SECONDS)
                .writeTimeout(60, TimeUnit.SECONDS)
                .build();
        this.gson = new Gson();
    }
    
    public String chat(String userMessage) throws IOException {
        return chat(userMessage, null);
    }
    
    public String chat(String userMessage, String systemPrompt) throws IOException {
        List<ChatMessage> messages = new ArrayList<>();
        
        if (systemPrompt != null && !systemPrompt.isEmpty()) {
            messages.add(new ChatMessage("system", systemPrompt));
        }
        
        messages.add(new ChatMessage("user", userMessage));
        
        return sendRequest(messages);
    }
    
    public String chat(List<ChatMessage> messages) throws IOException {
        return sendRequest(messages);
    }
    
    private String sendRequest(List<ChatMessage> messages) throws IOException {
        JsonObject requestBody = new JsonObject();
        requestBody.addProperty("model", model);
        requestBody.addProperty("stream", false);
        
        requestBody.add("messages", gson.toJsonTree(messages));
        
        String jsonBody = gson.toJson(requestBody);
        logger.debug("Sending request to Ollama: {}", jsonBody);
        
        RequestBody body = RequestBody.create(
                jsonBody,
                MediaType.parse("application/json; charset=utf-8")
        );
        
        Request httpRequest = new Request.Builder()
                .url(https://blog.csdn.net/baseUrl%20+%20"/api/chat")
                .post(body)
                .build();
        
        try (Response response = client.newCall(httpRequest).execute()) {
            if (!response.isSuccessful()) {
                String errorBody = response.body() != null ? response.body().string() : "No response body";
                logger.error("Ollama API error: {} - {}", response.code(), errorBody);
                throw new IOException("Unexpected code " + response + ", body: " + errorBody);
            }
            
            String responseBody = response.body().string();
            logger.debug("Received response from Ollama: {}", responseBody);
            
            JsonObject jsonResponse = gson.fromJson(responseBody, JsonObject.class);
            JsonObject message = jsonResponse.getAsJsonObject("message");
            if (message != null) {
                return message.get("content").getAsString();
            }

            throw new IOException("Invalid response format from Ollama");
        }
    }
    
    public boolean testConnection() {
        try {
            Request request = new Request.Builder()
                    .url(https://blog.csdn.net/baseUrl%20+%20"/api/tags")
                    .get()
                    .build();
            
            try (Response response = client.newCall(request).execute()) {
                if (response.isSuccessful()) {
                    logger.info("Successfully connected to Ollama at {}", baseUrl);
                    return true;
                } else {
                    logger.error("Connection test failed with code: {}", response.code());
                    return false;
                }
            }
        } catch (IOException e) {
            logger.error("Connection test failed", e);
            return false;
        }
    }
    
    public List<String> listModels() {
        try {
            Request request = new Request.Builder()
                    .url(https://blog.csdn.net/baseUrl%20+%20"/api/tags")
                    .get()
                    .build();
            
            try (Response response = client.newCall(request).execute()) {
                if (response.isSuccessful()) {
                    String responseBody = response.body().string();
                    JsonObject jsonResponse = gson.fromJson(responseBody, JsonObject.class);
                    return gson.fromJson(jsonResponse.getAsJsonArray("models"), List.class);
                }
            }
        } catch (Exception e) {
            logger.error("Failed to list models", e);
        }
        return new ArrayList<>();
    }
    
    private String filterThinkingProcess(String content) {
        if (content == null || content.isEmpty()) {
            return content;
        }
        
        String filtered = content.replaceAll("<think\\s*>.*?</think\\s*>", "");
        
        filtered = filtered.trim();
        
        if (!filtered.isEmpty()) {
            logger.debug("Filtered thinking process from response");
        }
        
        return filtered;
    }
/**
 * A single chat turn sent to the model: who is speaking ({@code "system"}, {@code "user"}, ...)
 * and what was said. The two fields are serialized by Gson under their own names.
 */
public class ChatMessage {
    /** Speaker role of this turn. */
    private String role;
    /** Text of this turn. */
    private String content;

    public ChatMessage(String role, String content) {
        this.role = role;
        this.content = content;
    }

    /** @return the speaker role, possibly {@code null} */
    public String getRole() {
        return this.role;
    }

    /** @return the message text, possibly {@code null} */
    public String getContent() {
        return this.content;
    }

    /** Replaces the speaker role. */
    public void setRole(String newRole) {
        this.role = newRole;
    }

    /** Replaces the message text. */
    public void setContent(String newContent) {
        this.content = newContent;
    }
}
package com.example.aiagent.skill;

import org.reflections.Reflections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

public class SkillManager {
    private static final Logger logger = LoggerFactory.getLogger(SkillManager.class);
    
    private final Map<String, Skill> skills = new ConcurrentHashMap<>();
    
    public void registerSkill(Skill skill) {
        skills.put(skill.getName(), skill);
        logger.info("Registered skill: {} - {}", skill.getName(), skill.getDescription());
    }
    
    public void unregisterSkill(String skillName) {
        Skill removed = skills.remove(skillName);
        if (removed != null) {
            logger.info("Unregistered skill: {}", skillName);
        }
    }
    
    public Skill getSkill(String skillName) {
        return skills.get(skillName);
    }
    
    public List<Skill> getAllSkills() {
        return new ArrayList<>(skills.values());
    }
    
    public boolean hasSkill(String skillName) {
        return skills.containsKey(skillName);
    }
    
    public SkillResult executeSkill(String skillName, Map<String, Object> parameters) {
        Skill skill = getSkill(skillName);
        if (skill == null) {
            return SkillResult.error("Skill not found: " + skillName);
        }
        
        try {
            logger.info("Executing skill: {} with parameters: {}", skillName, parameters);
            return skill.execute(parameters);
        } catch (Exception e) {
            logger.error("Error executing skill: " + skillName, e);
            return SkillResult.error("Execution error: " + e.getMessage());
        }
    }
    
    public String getSkillsDescription() {
        StringBuilder sb = new StringBuilder();
        sb.append("Available Skills:\n");
        for (Skill skill : getAllSkills()) {
            sb.append(String.format("- %s: %s\n", skill.getName(), skill.getDescription()));
        }
        return sb.toString();
    }
    
    public void autoRegisterSkills(String packageName) {
        Reflections reflections = new Reflections(packageName);
        Set<Class<? extends Skill>> skillClasses = reflections.getSubTypesOf(Skill.class);
        
        for (Class<? extends Skill> skillClass : skillClasses) {
            try {
                Skill skill = skillClass.getDeclaredConstructor().newInstance();
                registerSkill(skill);
            } catch (Exception e) {
                logger.error("Failed to instantiate skill: " + skillClass.getName(), e);
            }
        }
    }
}
package com.example.aiagent.skill.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declarative metadata attached to a skill class: its identifier, a human-readable description,
 * and free-form parameter descriptions such as {@code "filePath: ..."}.
 * Kept at runtime so it can be read reflectively from the class.
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface SkillInfo {

    /** Identifier of the skill. */
    String name();

    /** Human-readable description of what the skill does. */
    String description();

    /** One entry per parameter, "name: description"; none by default. */
    String[] parameters() default {};
}
FileReadSkill:读取指定文件
package com.example.aiagent.skill.impl;

import com.example.aiagent.skill.Skill;
import com.example.aiagent.skill.SkillResult;
import com.example.aiagent.skill.annotation.SkillInfo;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;

import org.apache.poi.xwpf.usermodel.XWPFDocument;
import org.apache.poi.xwpf.usermodel.XWPFParagraph;
import org.apache.poi.hwpf.HWPFDocument;
import org.apache.poi.hwpf.extractor.WordExtractor;
import org.apache.poi.ss.usermodel.*;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;

@SkillInfo(
    name = "file_read",
    description = "读取指定路径的文件内容",
    parameters = {"filePath: 要读取的文件路径"}
)
public class FileReadSkill implements Skill {

    @Override
    public String getName() {
        return "file_read";
    }

    @Override
    public String getDescription() {
        return "读取指定路径的文件内容";
    }

    @Override
    public SkillResult execute(Map<String, Object> parameters) {
        try {
            String filePath = (String) parameters.get("filePath");
            if (filePath == null || filePath.isEmpty()) {
                return SkillResult.error("文件路径不能为空");
            }

            filePath = normalizeFilePath(filePath);

            String content;
            String fileExtension = getFileExtension(filePath).toLowerCase();

            if (fileExtension.equals("docx")) {
                content = readDocxFile(filePath);
            } else if (fileExtension.equals("doc")) {
                content = readDocFile(filePath);
            } else if (fileExtension.equals("xlsx")) {
                content = readXlsxFile(filePath);
            } else if (fileExtension.equals("xls")) {
                content = readXlsFile(filePath);
            } else {
                content = readTextFile(filePath);
            }

            return SkillResult.success(content);
        } catch (FileNotFoundException e) {
            return SkillResult.error("文件不存在: " + e.getMessage());
        } catch (IOException e) {
            return SkillResult.error("读取文件失败: " + e.getMessage());
        } catch (Exception e) {
            return SkillResult.error("读取文件失败: " + e.getMessage());
        }
    }

    private String readTextFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                content.append(line).append("\n");
            }
        } catch (Exception e) {
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(new FileInputStream(filePath)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    content.append(line).append("\n");
                }
            }
        }
        return content.toString();
    }

    private String readDocxFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             XWPFDocument document = new XWPFDocument(fis)) {

            for (XWPFParagraph paragraph : document.getParagraphs()) {
                content.append(paragraph.getText()).append("\n");
            }
        }
        return content.toString();
    }

    private String readDocFile(String filePath) throws IOException {
        try (FileInputStream fis = new FileInputStream(filePath);
             HWPFDocument document = new HWPFDocument(fis);
             WordExtractor extractor = new WordExtractor(document)) {
            
            String[] paragraphs = extractor.getParagraphText();
            StringBuilder content = new StringBuilder();
            for (String paragraph : paragraphs) {
                content.append(paragraph).append("\n");
            }
            return content.toString();
        }
    }

    private String readXlsxFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             Workbook workbook = new XSSFWorkbook(fis)) {
            
            for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
                Sheet sheet = workbook.getSheetAt(i);
                content.append("工作表 ").append(i + 1).append(": ").append(sheet.getSheetName()).append("\n");
                
                for (Row row : sheet) {
                    for (Cell cell : row) {
                        content.append(getCellValue(cell)).append("\t");
                    }
                    content.append("\n");
                }
                content.append("\n");
            }
        }
        return content.toString();
    }

    private String readXlsFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             Workbook workbook = new HSSFWorkbook(fis)) {
            
            for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
                Sheet sheet = workbook.getSheetAt(i);
                content.append("工作表 ").append(i + 1).append(": ").append(sheet.getSheetName()).append("\n");
                
                for (Row row : sheet) {
                    for (Cell cell : row) {
                        content.append(getCellValue(cell)).append("\t");
                    }
                    content.append("\n");
                }
                content.append("\n");
            }
        }
        return content.toString();
    }

    private String getCellValue(Cell cell) {
        if (cell == null) {
            return "";
        }
        
        switch (cell.getCellType()) {
            case STRING:
                return cell.getStringCellValue();
            case NUMERIC:
                if (DateUtil.isCellDateFormatted(cell)) {
                    return cell.getDateCellValue().toString();
                } else {
                    return String.valueOf(cell.getNumericCellValue());
                }
            case BOOLEAN:
                return String.valueOf(cell.getBooleanCellValue());
            case FORMULA:
                return cell.getCellFormula();
            case BLANK:
                return "";
            default:
                return "";
        }
    }

    private String getFileExtension(String filePath) {
        int lastDotIndex = filePath.lastIndexOf('.');
        if (lastDotIndex == -1 || lastDotIndex == filePath.length() - 1) {
            return "";
        }
        return filePath.substring(lastDotIndex + 1);
    }

    private String normalizeFilePath(String filePath) {
        if (filePath == null || filePath.isEmpty()) {
            return filePath;
        }

        filePath = filePath.replace("/", "\\");

        while (filePath.contains("\\\\")) {
            filePath = filePath.replace("\\\\", "\\");
        }

        return filePath;
    }
}
FileReadSkill:读取指定文件(注:下面这段代码与上文的 FileReadSkill 完全相同,疑为重复粘贴)
package com.example.aiagent.skill.impl;

import com.example.aiagent.skill.Skill;
import com.example.aiagent.skill.SkillResult;
import com.example.aiagent.skill.annotation.SkillInfo;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;

import org.apache.poi.xwpf.usermodel.XWPFDocument;
import org.apache.poi.xwpf.usermodel.XWPFParagraph;
import org.apache.poi.hwpf.HWPFDocument;
import org.apache.poi.hwpf.extractor.WordExtractor;
import org.apache.poi.ss.usermodel.*;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;

@SkillInfo(
    name = "file_read",
    description = "读取指定路径的文件内容",
    parameters = {"filePath: 要读取的文件路径"}
)
public class FileReadSkill implements Skill {

    @Override
    public String getName() {
        return "file_read";
    }

    @Override
    public String getDescription() {
        return "读取指定路径的文件内容";
    }

    @Override
    public SkillResult execute(Map<String, Object> parameters) {
        try {
            String filePath = (String) parameters.get("filePath");
            if (filePath == null || filePath.isEmpty()) {
                return SkillResult.error("文件路径不能为空");
            }

            filePath = normalizeFilePath(filePath);

            String content;
            String fileExtension = getFileExtension(filePath).toLowerCase();

            if (fileExtension.equals("docx")) {
                content = readDocxFile(filePath);
            } else if (fileExtension.equals("doc")) {
                content = readDocFile(filePath);
            } else if (fileExtension.equals("xlsx")) {
                content = readXlsxFile(filePath);
            } else if (fileExtension.equals("xls")) {
                content = readXlsFile(filePath);
            } else {
                content = readTextFile(filePath);
            }

            return SkillResult.success(content);
        } catch (FileNotFoundException e) {
            return SkillResult.error("文件不存在: " + e.getMessage());
        } catch (IOException e) {
            return SkillResult.error("读取文件失败: " + e.getMessage());
        } catch (Exception e) {
            return SkillResult.error("读取文件失败: " + e.getMessage());
        }
    }

    private String readTextFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                content.append(line).append("\n");
            }
        } catch (Exception e) {
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(new FileInputStream(filePath)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    content.append(line).append("\n");
                }
            }
        }
        return content.toString();
    }

    private String readDocxFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             XWPFDocument document = new XWPFDocument(fis)) {

            for (XWPFParagraph paragraph : document.getParagraphs()) {
                content.append(paragraph.getText()).append("\n");
            }
        }
        return content.toString();
    }

    private String readDocFile(String filePath) throws IOException {
        try (FileInputStream fis = new FileInputStream(filePath);
             HWPFDocument document = new HWPFDocument(fis);
             WordExtractor extractor = new WordExtractor(document)) {
            
            String[] paragraphs = extractor.getParagraphText();
            StringBuilder content = new StringBuilder();
            for (String paragraph : paragraphs) {
                content.append(paragraph).append("\n");
            }
            return content.toString();
        }
    }

    private String readXlsxFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             Workbook workbook = new XSSFWorkbook(fis)) {
            
            for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
                Sheet sheet = workbook.getSheetAt(i);
                content.append("工作表 ").append(i + 1).append(": ").append(sheet.getSheetName()).append("\n");
                
                for (Row row : sheet) {
                    for (Cell cell : row) {
                        content.append(getCellValue(cell)).append("\t");
                    }
                    content.append("\n");
                }
                content.append("\n");
            }
        }
        return content.toString();
    }

    private String readXlsFile(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             Workbook workbook = new HSSFWorkbook(fis)) {
            
            for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
                Sheet sheet = workbook.getSheetAt(i);
                content.append("工作表 ").append(i + 1).append(": ").append(sheet.getSheetName()).append("\n");
                
                for (Row row : sheet) {
                    for (Cell cell : row) {
                        content.append(getCellValue(cell)).append("\t");
                    }
                    content.append("\n");
                }
                content.append("\n");
            }
        }
        return content.toString();
    }

    private String getCellValue(Cell cell) {
        if (cell == null) {
            return "";
        }
        
        switch (cell.getCellType()) {
            case STRING:
                return cell.getStringCellValue();
            case NUMERIC:
                if (DateUtil.isCellDateFormatted(cell)) {
                    return cell.getDateCellValue().toString();
                } else {
                    return String.valueOf(cell.getNumericCellValue());
                }
            case BOOLEAN:
                return String.valueOf(cell.getBooleanCellValue());
            case FORMULA:
                return cell.getCellFormula();
            case BLANK:
                return "";
            default:
                return "";
        }
    }

    private String getFileExtension(String filePath) {
        int lastDotIndex = filePath.lastIndexOf('.');
        if (lastDotIndex == -1 || lastDotIndex == filePath.length() - 1) {
            return "";
        }
        return filePath.substring(lastDotIndex + 1);
    }

    private String normalizeFilePath(String filePath) {
        if (filePath == null || filePath.isEmpty()) {
            return filePath;
        }

        filePath = filePath.replace("/", "\\");

        while (filePath.contains("\\\\")) {
            filePath = filePath.replace("\\\\", "\\");
        }

        return filePath;
    }
}
package com.example.aiagent.skill;

import java.util.Map;

/**
 * A capability the AI agent can invoke on the model's behalf.
 * Implementations are registered and looked up by {@link #getName()}.
 */
public interface Skill {

    /** @return identifier used to register and call this skill */
    String getName();

    /** @return short human-readable description listed in the model's system prompt */
    String getDescription();

    /**
     * Performs the skill.
     *
     * @param parameters named arguments for the call; presumably parsed from the JSON object in the
     *                   model's skill-call text — confirm against the caller
     * @return a success result carrying output text, or an error result carrying a message
     */
    SkillResult execute(Map<String, Object> parameters);
}
package com.example.aiagent.skill;

/**
 * Immutable outcome of a skill execution: either a success carrying result text, or a failure
 * carrying an error message. Use {@link #success(String)} / {@link #error(String)} to build one.
 */
public class SkillResult {
    // final: the class exposes no setters, so results are value objects and safe to share.
    private final boolean success;
    private final String result;
    private final String error;

    public SkillResult(boolean success, String result, String error) {
        this.success = success;
        this.result = result;
        this.error = error;
    }

    /** @return a successful result with {@code result} as output and no error */
    public static SkillResult success(String result) {
        return new SkillResult(true, result, null);
    }

    /** @return a failed result with {@code error} as message and no output */
    public static SkillResult error(String error) {
        return new SkillResult(false, null, error);
    }

    public boolean isSuccess() {
        return success;
    }

    /** @return output text; {@code null} for results built with {@link #error(String)} */
    public String getResult() {
        return result;
    }

    /** @return error message; {@code null} for results built with {@link #success(String)} */
    public String getError() {
        return error;
    }

    /** Readable form for logs, e.g. {@code SkillResult{success=true, result=..., error=null}}. */
    @Override
    public String toString() {
        return "SkillResult{success=" + success + ", result=" + result + ", error=" + error + "}";
    }
}
AIAgent:这个比较重要,特别是prompt
package com.example.aiagent.agent;

import com.example.aiagent.llm.OllamaClient;
import com.example.aiagent.model.ChatMessage;
import com.example.aiagent.skill.Skill;
import com.example.aiagent.skill.SkillManager;
import com.example.aiagent.skill.SkillResult;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Chat agent that lets the LLM invoke local skills.
 *
 * <p>The default system prompt ({@link #buildSystemPrompt()}) tells the model to request a skill
 * by emitting {@code [SKILL_CALL skill_name {json-params}]}. Every such marker in a reply is
 * removed from the text, the skill is executed through the {@link SkillManager}, and the outcome
 * is appended to the conversation as a {@code system} message. The model is then queried again so
 * it can report the result in natural language. The first reply that triggers no skill activity
 * is returned to the caller.
 */
public class AIAgent {
    private static final Logger logger = LoggerFactory.getLogger(AIAgent.class);

    /** Upper bound on LLM round-trips per user message, so a misbehaving model cannot loop forever. */
    private static final int MAX_ITERATIONS = 10;

    /**
     * Matches {@code [SKILL_CALL name {...}]}: group 1 is the skill name, group 2 the text between
     * the braces. Group 2 stops at the first '}', so parameter values cannot contain '}' or nested
     * objects.
     */
    private static final Pattern SKILL_CALL_PATTERN =
        Pattern.compile("\\[SKILL_CALL\\s+(\\w+)\\s*\\{([^}]*)\\}\\]");

    private final OllamaClient llmClient;
    private final SkillManager skillManager;

    /**
     * @param llmClient    client used for every model round-trip
     * @param skillManager describes the available skills (for the prompt) and executes them
     */
    public AIAgent(OllamaClient llmClient, SkillManager skillManager) {
        this.llmClient = llmClient;
        this.skillManager = skillManager;
    }

    /**
     * Answers a user message using the default, skill-aware system prompt.
     *
     * @param userMessage the user's input
     * @return the model's final reply with skill-call markers stripped, or a fixed notice when the
     *         iteration limit is reached
     * @throws Exception propagated from the LLM client
     */
    public String chat(String userMessage) throws Exception {
        return chat(userMessage, null);
    }

    /**
     * Answers a user message.
     *
     * @param userMessage  the user's input
     * @param systemPrompt custom system prompt; when {@code null} or empty the prompt built by
     *                     {@link #buildSystemPrompt()} is used instead
     * @return the model's final reply with skill-call markers stripped, or a fixed notice when the
     *         iteration limit is reached
     * @throws Exception propagated from the LLM client
     */
    public String chat(String userMessage, String systemPrompt) throws Exception {
        List<ChatMessage> messages = new ArrayList<>();

        String effectivePrompt = (systemPrompt != null && !systemPrompt.isEmpty())
            ? systemPrompt
            : buildSystemPrompt();
        messages.add(new ChatMessage("system", effectivePrompt));
        messages.add(new ChatMessage("user", userMessage));

        return processConversation(messages);
    }

    /**
     * Builds the default system prompt. It lists the registered skills and explains the
     * {@code [SKILL_CALL ...]} syntax with a worked example.
     */
    private String buildSystemPrompt() {
        StringBuilder prompt = new StringBuilder();
        prompt.append("你是一个 AI 助手,可以使用以下技能来帮助用户完成任务。\n\n");
        prompt.append(skillManager.getSkillsDescription());
        prompt.append("\n\n当需要使用技能时,请严格使用以下 JSON 格式:\n");
        prompt.append("[SKILL_CALL skill_name {\"param1\": \"value1\", \"param2\": \"value2\"}]\n\n");
        prompt.append("重要提示:\n");
        prompt.append("1. 必须使用双引号包裹参数值\n");
        prompt.append("2. 文件路径必须使用完整路径,例如:D:\\\\test.txt 或 C:\\\\Users\\\\user\\\\file.txt\n");
        prompt.append("3. 在 JSON 字符串中,反斜杠需要转义为双反斜杠\n");
        prompt.append("4. 执行完技能后,必须用自然语言向用户汇报结果\n");
        prompt.append("5. 收到技能执行结果后,绝对不要再次调用任何技能,直接用自然语言向用户汇报结果\n");
        // The actual feedback reads "技能 <name> 执行成功...", so quote the part that really appears.
        prompt.append("6. 如果收到系统消息包含\"执行成功\",说明技能已经执行完成,直接汇报结果即可\n");
        prompt.append("7. 获取到文件内容后,如果用户说了需求对内容进行整理、美化或总结,就不能直接返回原始内容,需要根据需求进行处理\n");
        prompt.append("8. 如果用户没有提新的需求,就直接返回内容\n");
        prompt.append("9. 如果用户要求总结,应该提取关键信息,用简洁的语言概括\n");
        prompt.append("10. 每次对话最多调用一次技能,调用技能后不要再调用其他技能\n");
        prompt.append("11. 调用完技能之后的结果,根据用户需求决定是否需要整理,如果不需要请直接返回数据\n\n");
        prompt.append("示例:\n");
        prompt.append("用户:请帮我读取 D:\\\\test.docx 的内容\n");
        prompt.append("助手:[SKILL_CALL file_read {\"filePath\": \"D:\\\\test.docx\"}]\n");
        prompt.append("系统:技能 file_read 执行成功,结果:文件内容\n");
        prompt.append("助手:我已经读取了文件内容,为您整理如下:\n整理后的内容\n\n");
        prompt.append("请根据用户的需求智能地选择合适的技能,确保参数格式正确,并在执行完技能后根据用户需求对内容进行整理、美化或总结。记住:收到技能结果后不要再调用任何技能!");
        return prompt.toString();
    }

    /**
     * Runs the model/skill loop until the model answers without requesting a skill, or until
     * {@link #MAX_ITERATIONS} round-trips have been made.
     *
     * @param messages conversation so far; mutated in place with assistant replies and skill feedback
     */
    private String processConversation(List<ChatMessage> messages) throws Exception {
        Set<String> calledSkills = new HashSet<>();

        for (int iteration = 1; iteration <= MAX_ITERATIONS; iteration++) {
            logger.info("Iteration {}: Processing conversation", iteration);

            String response = llmClient.chat(messages);
            logger.info("LLM Response: {}", response);

            List<ChatMessage> skillResults = new ArrayList<>();
            String processedResponse = processSkillCalls(response, skillResults, calledSkills);

            if (skillResults.isEmpty()) {
                // Nothing to feed back: the reply (minus any markers) is the final answer.
                logger.info("No skill results, returning response");
                return processedResponse;
            }

            logger.info("Detected {} skill(s), adding to messages", skillResults.size());
            messages.add(new ChatMessage("assistant", response));
            messages.addAll(skillResults);
            logger.info("Current message count: {}", messages.size());
        }

        return "达到最大迭代次数,无法完成任务。";
    }

    /**
     * Executes every skill-call marker in {@code response} and strips the markers from the text.
     *
     * <p>Each marker produces exactly one {@code system} message in {@code skillResults}: the
     * skill's outcome, or a reminder not to repeat a skill already run in this conversation.
     * Without that reminder, a reply consisting only of a repeated call would reach the user as an
     * empty string.
     *
     * @param response     raw model reply
     * @param skillResults output list receiving one feedback message per marker
     * @param calledSkills names of skills already executed in this conversation; updated in place
     * @return {@code response} with all markers removed
     */
    private String processSkillCalls(String response, List<ChatMessage> skillResults, Set<String> calledSkills) {
        logger.info("Processing skill calls from response: {}", response);
        Matcher matcher = SKILL_CALL_PATTERN.matcher(response);
        StringBuffer result = new StringBuffer();

        while (matcher.find()) {
            String skillName = matcher.group(1);
            String paramsJson = matcher.group(2);
            logger.info("Detected skill call: {} with params: {}", skillName, paramsJson);

            // The marker is never shown to the user, whatever happens with the call itself.
            matcher.appendReplacement(result, "");

            if (!calledSkills.add(skillName)) {
                logger.warn("Skill {} already called, skipping to prevent infinite loop", skillName);
                skillResults.add(new ChatMessage("system", String.format(
                    "技能 %s 已经执行过,请不要重复调用,直接根据已有结果用自然语言回答用户", skillName)));
                continue;
            }

            skillResults.add(new ChatMessage("system", executeSkillCall(skillName, paramsJson)));
        }

        matcher.appendTail(result);
        return result.toString();
    }

    /**
     * Runs a single skill and renders its outcome as the feedback message for the model.
     * Never throws; any exception becomes a "技能调用失败" message.
     */
    private String executeSkillCall(String skillName, String paramsJson) {
        try {
            // Group 2 of SKILL_CALL_PATTERN excludes the surrounding braces; restore them.
            String json = paramsJson.startsWith("{") ? paramsJson : "{" + paramsJson + "}";
            Map<String, Object> parameters = parseParameters(json);
            SkillResult skillResult = skillManager.executeSkill(skillName, parameters);

            String resultMessage = skillResult.isSuccess()
                ? String.format("技能 %s 执行成功,结果:%s", skillName, skillResult.getResult())
                : String.format("技能 %s 执行失败,错误:%s", skillName, skillResult.getError());
            logger.info("Skill result: {}", resultMessage);
            return resultMessage;
        } catch (Exception e) {
            logger.error("Error processing skill call", e);
            return String.format("技能调用失败:%s", e.getMessage());
        }
    }

    /**
     * Parses the JSON object from a skill-call marker into a parameter map.
     *
     * <p>String values are kept exactly as decoded by the JSON parser, so the escaped path
     * {@code "D:\\test.txt"} arrives as {@code D:\test.txt}. Numbers and booleans are stored in
     * their string form, nested objects/arrays as their JSON text, and JSON {@code null} values
     * are omitted.
     *
     * @param json a JSON object literal
     * @return the parsed parameters; empty when {@code json} is blank or not a valid JSON object
     */
    private Map<String, Object> parseParameters(String json) {
        logger.debug("Parsing parameters from JSON: {}", json);

        if (json == null || json.trim().isEmpty()) {
            logger.error("Empty or null JSON string");
            return new HashMap<>();
        }

        try {
            JsonObject jsonObject = JsonParser.parseString(json).getAsJsonObject();
            Map<String, Object> params = new HashMap<>();

            for (String key : jsonObject.keySet()) {
                if (jsonObject.get(key).isJsonNull()) {
                    logger.debug("Skipping null parameter: {}", key);
                    continue;
                }
                // Do not re-escape backslashes: the parser already decoded the JSON escapes, and
                // doubling them would corrupt every value that legitimately contains a backslash.
                String value = jsonObject.get(key).isJsonPrimitive()
                    ? jsonObject.get(key).getAsString()
                    : jsonObject.get(key).toString();
                params.put(key, value);
                logger.debug("Parsed parameter: {} = {}", key, value);
            }

            logger.info("Successfully parsed {} parameters: {}", params.size(), params);
            return params;
        } catch (Exception e) {
            logger.error("Error parsing parameters: {}, error: {}, class: {}",
                json, e.getMessage(), e.getClass().getName(), e);
            return new HashMap<>();
        }
    }

    /** @return the skill manager this agent executes skills through */
    public SkillManager getSkillManager() {
        return skillManager;
    }
}
package com.example.aiagent.skill;

/**
 * Immutable outcome of running a skill.
 *
 * <p>Normally created through {@link #success(String)} or {@link #error(String)}. When
 * {@link #isSuccess()} is {@code true} the payload is in {@link #getResult()}; otherwise the
 * failure description is in {@link #getError()}. The factories leave the unused side {@code null}.
 */
public class SkillResult {
    // final: a result is a value object and must not change after it has been handed to the agent.
    private final boolean success;
    private final String result;
    private final String error;

    /**
     * Creates a result with every field given explicitly.
     *
     * @param success whether the skill completed successfully
     * @param result  the payload reported back to the model; may be {@code null}
     * @param error   the failure description; may be {@code null}
     */
    public SkillResult(boolean success, String result, String error) {
        this.success = success;
        this.result = result;
        this.error = error;
    }

    /**
     * Creates a successful result.
     *
     * @param result the payload; may be {@code null}
     * @return a result with {@code success == true} and no error
     */
    public static SkillResult success(String result) {
        return new SkillResult(true, result, null);
    }

    /**
     * Creates a failed result.
     *
     * @param error the failure description; may be {@code null}
     * @return a result with {@code success == false} and no payload
     */
    public static SkillResult error(String error) {
        return new SkillResult(false, null, error);
    }

    /** @return {@code true} if the skill completed successfully */
    public boolean isSuccess() {
        return success;
    }

    /** @return the payload, or {@code null} for results built with {@link #error(String)} */
    public String getResult() {
        return result;
    }

    /** @return the failure description, or {@code null} for results built with {@link #success(String)} */
    public String getError() {
        return error;
    }
}

结果展示

我只测试了读取文件的 skill,写入文件及其他 skill 尚未测试,毕竟只是 demo。不过感觉使用场景还是挺多的,特别是在只能使用内网的情况下。

这是txt文档

这是docx文档(并且大模型可以帮你整理里面的内容)

这是excel文件:内容较多,会自动归纳总结整理

备注

只是个小 demo,不喜勿喷。文档格式多种多样,这里只实现了 3 种常用格式的读取,如有需要可自行补充其他格式,也可以自行添加其他常用的 skill。

本文章已经生成可运行项目
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值