记一次使用Java调用本地BERT模型,做文本内容实体提取,运行环境中不需要Python

BERT文本分割-中文-通用领域

BERT文本分割-中文-通用领域

NLP
StructBERT

使用modelscope和gradio加载BERT文本分割-中文-通用领域的文本分割模型并前端推理。

背景

为什么使用Java加载模型?

在生产环境中没有任何必要使用Python在单独起一个服务提供服务接口,那会增加每次服务调用的时间,造成用户不好的体验。

同时为了减少部署的工作量,与其他业务功能都使用Java提供统一的服务接口,会减少很多的工作量,维护成本也相对减少。

环境说明

Java版本:11

操作系统:Windows

利用Python将模型本地化

下载模型

前往https://huggingface.co/中查找自己需要的开源模型,复制模型标识,比如如下图所示:

img

将模型标识替换掉下方代码中的mode_id内容,target_dir是输出目录,这里指定一个目录即可,目录不存在的话会自动创建目录。

import os, sys

model_id = "uer/roberta-base-finetuned-cluener2020-chinese"
target_dir = r"E:\Work\BERT\models\roberta-base-finetuned-cluener2020-chinese"
os.makedirs(target_dir, exist_ok=True)

# 常见需要的文件(可能有些模型文件名不同)
files = ["config.json", "vocab.txt", "tokenizer.json", "special_tokens_map.json", "pytorch_model.bin"]

api = HfApi()
for fname in files:
    try:
        print("Downloading", fname)
        path = hf_hub_download(repo_id=model_id, filename=fname, cache_dir=target_dir, local_dir=target_dir)
        print("Saved:", path)
    except Exception as e:
        print("Failed to download", fname, ":", e)
print("done")

验证是否下载成功

打开输出目录,查看文件是否有下载完成。至少需要包含pytorch_model.binconfig.jsonvocab.txt等以下文件。

img

生成tokenizer.json文件

有些模型是没有tokenizer.json文件的,就像我们现在所用的这个模型。但我们后续使用Java去加载这个模型时是需要用到tokenizer.json文件。下面是使用Python去根据下载的模型生成tokenizer.json文件代码:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese")
tokenizer.save_pretrained("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese", legacy_format=False)

Python加载模型测试

通过以下代码加载刚刚下载到本地的模型。指定目录即可,保证目录中模型存在会自动读取模型文件的。

from transformers import BertTokenizerFast, BertForTokenClassification
import torch

model_dir = "E:\\Work\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"
tokenizer = BertTokenizerFast.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)

text = "程序员范宁在北京大学的燕园看了中国男篮的一场比赛。"
tokens = list(text)
inputs = tokenizer(tokens, return_tensors="pt", is_split_into_words=True)

with torch.no_grad():
    outputs = model(**inputs)
predictions = torch.argmax(outputs.logits, dim=2)

id2label = model.config.id2label
print([id2label[i.item()] for i in predictions[0]])

输出内容

上面的示例代码输出结果如下所示:

['O', 'B-position', 'I-position', 'I-position', 'B-name', 'I-name', 'O', 'B-organization', 'I-organization', 'I-organization', 'I-address', 'O', 'I-address', 'I-address', 'O', 'O', 'O', 'O', 'O', 'I-organization', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

这个输出是一个典型的 命名实体识别(NER)任务的标签序列,使用的是 BIO 标注格式(有时也叫 IOB 格式)。

B-XXX:表示一个实体的开始(Begin),XXX 是实体类型(如 name、position、organization、address 等)。

I-XXX:表示该词属于 XXX 类型实体的中间或结尾部分(Inside),且前面已经有同类型的 B 或 I。

O:表示“Outside”,即不属于任何命名实体。

导出pt模型

常见的模型文件:

格式主要框架是否含结构是否跨语言是否可读典型文件名
.ptPyTorch✅(TorchScript) ❌(state_dict)✅(TorchScript) ❌(pickle)model.pt, traced_model.pt
.binPyTorch (HF)pytorch_model.bin
.h5TensorFlow/Keras❌(限 TF)⚠️(需工具)tf_model.h5
.msgpackFlax/JAX✅(数据)⚠️(二进制)flax_model.msgpack

如果你需要将模型转换为.pt标准格式模型用于Java服务,下面是用来将.bin模型导出为.pt模型的Python代码:

from transformers import BertTokenizerFast, BertForTokenClassification
import torch

model_dir = "E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"
tokenizer = BertTokenizerFast.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)
model.eval()

# 示例输入(必须和实际输入格式一致)
text = "程序员范宁在北京大学的燕园看了中国男篮的一场比赛。"
tokens = list(text)
inputs = tokenizer(tokens, return_tensors="pt", is_split_into_words=True)

# 导出为 TorchScript
traced_model = torch.jit.trace(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    strict=False
)
traced_model.save("roberta-cluener-traced.pt")

Java加载模型

引入工程依赖

以下为pom.xml文件的核心片段内容:

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.34.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <!-- 系统依赖-->
        <dependency>
            <groupId>cn.tworice</groupId>
            <artifactId>tworice-system</artifactId>
        </dependency>
             
        <!-- 单元测试 -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
        </dependency>

        <!-- DJL API -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>

        <!-- DJL PyTorch engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>

        <!-- PyTorch native CPU binding -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <version>2.7.1</version> <!-- 与 BOM 配合的 native 版本(可从 BOM 确认) -->
            <classifier>win-x86_64</classifier>
        </dependency>

        <!-- HuggingFace tokenizers helper(用于加载 tokenizer) -->
        <dependency>
            <groupId>ai.djl.huggingface</groupId>
            <artifactId>tokenizers</artifactId>
        </dependency>
    </dependencies>

加载模型

初始化标签映射

在实例初始化块中对标签映射关系进行初始化,这些标签对应关系一般在模型文件夹下的config.json文件中。config.json文件示例如下图所示:

img

将该内容转换成Java中的Map存储,可以编写一个自动化内容,也可以手动转一下,我这里手动转了一下,核心代码:

private final Map<Integer, String> ID2LABEL = new HashMap<>();

{
    // 标签映射
    ID2LABEL.put(0, "O");
    ID2LABEL.put(1, "B-address");
    ID2LABEL.put(2, "I-address");
    ID2LABEL.put(3, "B-book");
    ID2LABEL.put(4, "I-book");
    ID2LABEL.put(5, "B-company");
    // 这里其他类似内容省略.......
}

加载模型

先加载模型配置文件,这里就用到了上文中生成的tokenizer.json文件,将tokenizer.json文件所在目录替换掉代码中的目录,之后替换掉代码中的pt文件绝对路径。

@PostConstruct
public void init() throws IOException, MalformedModelException {
    tokenizer = HuggingFaceTokenizer.builder()
            .optTokenizerPath(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"))
            .optAddSpecialTokens(true)
            .build();

    model = Model.newInstance("ner");
    model.load(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese\\roberta-cluener-traced.pt"));
}

完整代码

下面是Java加载模型服务提供类的完整代码:

package cn.tworice.djl;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.huggingface.tokenizers.Encoding;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.*;

@Service
public class NerService {

    private Model model;
    private HuggingFaceTokenizer tokenizer;
    private final Map<Integer, String> ID2LABEL = new HashMap<>();

    {
        // CLUEner2020 标签映射(请根据你的模型 config 确认)
        ID2LABEL.put(0, "O");
        ID2LABEL.put(1, "B-address");
        ID2LABEL.put(2, "I-address");
        ID2LABEL.put(3, "B-book");
        ID2LABEL.put(4, "I-book");
        ID2LABEL.put(5, "B-company");
        // 这里其他类似内容省略.......
    }

    @PostConstruct
    public void init() throws IOException, MalformedModelException {
        tokenizer = HuggingFaceTokenizer.builder()
                .optTokenizerPath(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"))
                .optAddSpecialTokens(true)
                .build();

        model = Model.newInstance("ner");
        model.load(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese\\roberta-cluener-traced.pt"));
    }

    /**
     * 预测并打印每个 token 及其 label,同时返回实体列表
     */
    public List<NerEntity> predict(String text) throws TranslateException {
        try (NDManager manager = NDManager.newBaseManager()) {
            Encoding encoding = tokenizer.encode(text);
            long[] inputIds = encoding.getIds();
            long[] attentionMask = encoding.getAttentionMask();
            String[] tokens = encoding.getTokens(); // 实际分词结果

            // === 打印输入 ===
            System.out.println(">>> 输入文本: " + text);
            System.out.println(">>> 分词结果 (含 [CLS]/[SEP]): " + Arrays.toString(tokens));

            // 转为 NDArray
            NDArray inputIdsArr = manager.create(new Shape(1, inputIds.length), DataType.INT64);
            inputIdsArr.set(inputIds);
            NDArray attentionMaskArr = manager.create(new Shape(1, attentionMask.length), DataType.INT64);
            attentionMaskArr.set(attentionMask);

            // 推理
            try (Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
                NDList inputs = new NDList(inputIdsArr, attentionMaskArr);
                NDList outputs = predictor.predict(inputs);
                NDArray logits = outputs.singletonOrThrow(); // [1, seq_len, num_labels]
                NDArray predictions = logits.argMax(2);      // [1, seq_len]
                long[] predIds = predictions.toLongArray();  // length = seq_len

                // === 构建 token -> label 映射(跳过 [CLS] 和 [SEP])===
                System.out.println("\n>>> Token 与 Label 对应关系:");
                List<String> tokenLabels = new ArrayList<>();
                // tokens[0] = [CLS], tokens[tokens.length-1] = [SEP]
                for (int i = 1; i < tokens.length - 1; i++) {
                    String token = tokens[i];
                    String label = ID2LABEL.getOrDefault((int) predIds[i], "O");
                    tokenLabels.add(label);
                    System.out.printf("  %-12s -> %s%n", token, label);
                }
                System.out.println(); // 空行分隔

                // === 提取实体(使用 tokenLabels 转为 long[])===
                long[] labelIds = tokenLabels.stream()
                        .mapToLong(label -> {
                            for (Map.Entry<Integer, String> entry : ID2LABEL.entrySet()) {
                                if (entry.getValue().equals(label)) return entry.getKey();
                            }
                            return 0L;
                        })
                        .toArray();

                // 注意:这里 tokens[1:-1] 对应原始分词,但中文通常按字,可直接用于 decode
                String[] contentTokens = Arrays.copyOfRange(tokens, 1, tokens.length - 1);
                List<NerEntity> entities = decodeEntities(contentTokens, labelIds);

                return entities;
            }
        }
    }

    private List<NerEntity> decodeEntities(String[] tokens, long[] labels) {
        List<NerEntity> entities = new ArrayList<>();
        StringBuilder currentEntity = new StringBuilder();
        String currentType = null;
        int start = -1;

        for (int i = 0; i < tokens.length && i < labels.length; i++) {
            String token = tokens[i];
            String label = ID2LABEL.getOrDefault((int) labels[i], "O");

            if (label.startsWith("B-")) {
                if (currentType != null) {
                    entities.add(new NerEntity(currentEntity.toString(), currentType, start, i));
                }
                currentEntity = new StringBuilder(token);
                currentType = label.substring(2);
                start = i;
            } else if (label.startsWith("I-") && currentType != null && label.substring(2).equals(currentType)) {
                currentEntity.append(token);
            } else {
                if (currentType != null) {
                    entities.add(new NerEntity(currentEntity.toString(), currentType, start, i));
                    currentType = null;
                    currentEntity.setLength(0);
                }
            }
        }

        if (currentType != null) {
            entities.add(new NerEntity(currentEntity.toString(), currentType, start, tokens.length));
        }

        return entities;
    }

    public static class NerEntity {
        public String entity;
        public String type;
        public int start;
        public int end;

        public NerEntity(String entity, String type, int start, int end) {
            this.entity = entity;
            this.type = type;
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return String.format("{'entity': '%s', 'type': '%s', 'start': %d, 'end': %d}", entity, type, start, end);
        }
    }
}

测试使用

利用单元测试,传入一段文字查看输出结果。

@SpringBootTest
public class BertTest {

    @Autowired
    private NerService nerService;

    @Test
    void testNerPrediction() throws Exception {
        String text = "2025年10月1日,程序员范宁在中国北京看了中国男篮的一场比赛。";
        List<NerService.NerEntity> entities = nerService.predict(text);

        System.out.println("输入文本: " + text);
        System.out.println("识别实体:");
        for (var entity : entities) {
            System.out.println(entity);
        }
    }
}

结果输出:

>>> 输入文本: 程序员范宁在中国北京看了中国男篮的一场比赛。
>>> 分词结果 ([CLS]/[SEP]): [[CLS],,,,,,,,,,,,,,,,,,,,,,, [SEP]]
[W1022 18:57:51.000000000 LegacyTypeDispatch.h:79] Warning: AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, If you are looking for a user facing API to enable running your inference-only workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code is under risk of producing silent wrong result in some edge cases. See Note [AutoDispatchBelowAutograd] for more details. (function operator ())

>>> TokenLabel 对应关系:-> B-position
  序            -> I-position
  员            -> I-position
  范            -> B-name
  宁            -> I-name
  在            -> O-> B-address
  国            -> I-address
  北            -> I-address
  京            -> I-address
  看            -> O-> O-> O-> O-> O-> I-organization
  的            -> O-> O-> O-> O-> O-> O

输入文本: 程序员范宁在中国北京看了中国男篮的一场比赛。
识别实体:
{'entity': '程序员', 'type': 'position', 'start': 0, 'end': 3}
{'entity': '范宁', 'type': 'name', 'start': 3, 'end': 5}
{'entity': '中国北京', 'type': 'address', 'start': 6, 'end': 10}

您可能感兴趣的与本文相关的镜像

BERT文本分割-中文-通用领域

BERT文本分割-中文-通用领域

NLP
StructBERT

使用modelscope和gradio加载BERT文本分割-中文-通用领域的文本分割模型并前端推理。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二饭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值