diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index 2298c190..dec0262c 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -372,17 +372,22 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
         log_enable();
     }
 
-    if (!sparams.system_prompt.empty())
-    {
-        ctx_server->system_prompt_set(sparams.system_prompt);
-    }
-
     if (params.model_alias == "unknown")
     {
         params.model_alias = params.model;
    }
 
-    llama_numa_init(params.numa);
+    if (json_value(json_params, "vocab_only", false)) {
+        if (!ctx_server->load_tokenizer(params))
+        {
+            env->ThrowNew(c_llama_error, "could not load model from given file path");
+        }
+        else
+        {
+            env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
+        }
+        return;
+    }
 
     LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
 
@@ -393,6 +398,13 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
                  {"system_info", llama_print_system_info()},
              });
 
+    if (!sparams.system_prompt.empty())
+    {
+        ctx_server->system_prompt_set(sparams.system_prompt);
+    }
+
+    llama_numa_init(params.numa);
+
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
     // load the model
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index d3d4750a..558a709d 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -726,6 +726,21 @@ struct server_context
         llama_batch_free(batch);
     }
 
+    bool load_tokenizer(const gpt_params &params_)
+    {
+        params = params_;
+
+        llama_model_params model_params = llama_model_default_params();
+        model_params.vocab_only = true;
+        model = llama_load_model_from_file(params.model.c_str(), model_params);
+        if (model == nullptr)
+        {
+            LOG_ERROR("unable to load model", {{"model", params.model}});
+            return false;
+        }
+        return true;
+    }
+
     bool load_model(const gpt_params &params_)
     {
         params = params_;
diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java
index 1cbb6973..1aa7200e 100644
--- a/src/main/java/de/kherud/llama/ModelParameters.java
+++ b/src/main/java/de/kherud/llama/ModelParameters.java
@@ -66,6 +66,7 @@ public final class ModelParameters extends JsonParameters {
     private static final String PARAM_NO_KV_OFFLOAD = "no_kv_offload";
     private static final String PARAM_SYSTEM_PROMPT = "system_prompt";
     private static final String PARAM_CHAT_TEMPLATE = "chat_template";
+    private static final String PARAM_VOCAB_ONLY = "vocab_only";
 
     /**
      * Set the RNG seed
@@ -572,4 +573,13 @@ public ModelParameters setChatTemplate(String chatTemplate) {
         return this;
     }
 
+    /**
+     * Whether to load only the vocabulary for tokenization, without the model weights (default: false).
+     * Note that setting this to true disables most other options.
+     */
+    public ModelParameters setVocabOnly() {
+        parameters.put(PARAM_VOCAB_ONLY, "true");
+        return this;
+    }
+
 }
diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java
index a5454c59..b6a60508 100644
--- a/src/test/java/de/kherud/llama/LlamaModelTest.java
+++ b/src/test/java/de/kherud/llama/LlamaModelTest.java
@@ -1,18 +1,24 @@
 package de.kherud.llama;
 
-import java.io.*;
-import java.nio.charset.StandardCharsets;
-import java.util.*;
+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
 import java.util.regex.Pattern;
 
-import de.kherud.llama.args.LogFormat;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
+import de.kherud.llama.args.LogFormat;
+
 public class LlamaModelTest {
 
+    private static final String modelPath = "models/codellama-7b.Q2_K.gguf";
     private static final String prefix = "def remove_non_ascii(s: str) -> str:\n    \"\"\" ";
     private static final String suffix = "\n    return result\n";
     private static final int nPredict = 10;
@@ -24,7 +30,7 @@ public static void setup() {
         // LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg));
         model = new LlamaModel(
                 new ModelParameters()
-                        .setModelFilePath("models/codellama-7b.Q2_K.gguf")
+                        .setModelFilePath(modelPath)
 //                        .setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf")
                         .setNGpuLayers(43)
                         .setEmbedding(true)
@@ -166,6 +172,21 @@ public void testTokenization() {
         Assert.assertEquals(" " + prompt, decoded);
     }
 
+    @Test
+    public void testVocabOnly() {
+        try (LlamaModel model = new LlamaModel(
+                new ModelParameters()
+                        .setModelFilePath(modelPath)
+                        .setVocabOnly()
+        )) {
+            String prompt = "Hello, world!";
+            int[] encoded = model.encode(prompt);
+            String decoded = model.decode(encoded);
+            Assert.assertEquals(" " + prompt, decoded);
+        }
+    }
+
     @Test
     public void testLogText() {
         List messages = new ArrayList<>();
@@ -220,7 +241,8 @@ public void testLogStdout() {
         model.complete(params);
 
         System.out.println("########## Log None ##########");
-        LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> {});
+        LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> {
+        });
         model.complete(params);
         System.out.println("##############################");
 
@@ -237,7 +259,8 @@ private String completeAndReadStdOut() {
                     .setNPredict(nPredict)
                     .setSeed(42);
             model.complete(params);
-        } finally {
+        }
+        finally {
             System.out.flush();
             System.setOut(stdOut);
             printStream.close();
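
For reference, a minimal usage sketch of the new vocab-only mode (not part of the patch). It mirrors `testVocabOnly` above; the class name `VocabOnlyExample` and the model path are illustrative, and it assumes `LlamaModel` implements `AutoCloseable`, as the try-with-resources in the test suggests:

    import de.kherud.llama.LlamaModel;
    import de.kherud.llama.ModelParameters;

    public class VocabOnlyExample {
        public static void main(String[] args) {
            // With setVocabOnly(), only the vocabulary is read from the GGUF
            // file; the weights stay on disk, so loading is fast and cheap.
            try (LlamaModel model = new LlamaModel(
                    new ModelParameters()
                            .setModelFilePath("models/codellama-7b.Q2_K.gguf") // illustrative path
                            .setVocabOnly())) {
                int[] tokens = model.encode("Hello, world!");
                System.out.println(tokens.length + " tokens");
                System.out.println(model.decode(tokens)); // round-trips, with a leading space
            }
        }
    }

Since the native side returns right after `load_tokenizer` without creating an inference context, presumably only `encode`/`decode` work in this mode; generation and embedding calls should not be expected to.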