diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8f402fa..6542a3a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -7,6 +7,8 @@ include(FetchContent)
 set(BUILD_SHARED_LIBS ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(BUILD_SHARED_LIBS OFF)
+set(LLAMA_BUILD_TOOLS ON)
+set(LLAMA_CURL OFF)
 
 option(LLAMA_VERBOSE "llama: verbose output" OFF)
 
@@ -15,7 +17,7 @@ option(LLAMA_VERBOSE "llama: verbose output" OFF)
 FetchContent_Declare(
     json
     GIT_REPOSITORY https://github.com/nlohmann/json
-    GIT_TAG v3.11.3
+    GIT_TAG v3.12.0
 )
 FetchContent_MakeAvailable(json)
 
@@ -25,7 +27,7 @@ set(LLAMA_BUILD_COMMON ON)
 FetchContent_Declare(
     llama.cpp
     GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-    GIT_TAG b4916
+    GIT_TAG b5688
 )
 FetchContent_MakeAvailable(llama.cpp)
 
@@ -96,7 +98,7 @@ add_library(jllama SHARED src/main/cpp/jllama.cpp src/main/cpp/server.hpp src/ma
 set_target_properties(jllama PROPERTIES POSITION_INDEPENDENT_CODE ON)
 target_include_directories(jllama PRIVATE src/main/cpp ${JNI_INCLUDE_DIRS})
 
-target_link_libraries(jllama PRIVATE common llama nlohmann_json)
+target_link_libraries(jllama PRIVATE common mtmd llama nlohmann_json)
 target_compile_features(jllama PRIVATE cxx_std_11)
 
 target_compile_definitions(jllama PRIVATE
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index 11c80ae..b8dd94a 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -377,7 +377,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
         return;
     }
 
-    SRV_INF("loading model '%s'\n", params.model.c_str());
+    SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
     common_init();
 
@@ -413,15 +413,12 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
     const auto model_meta = ctx_server->model_meta();
 
-    if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) {
-        SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
+    if (!params.speculative.model.path.empty() || !params.speculative.model.hf_repo.empty()) {
+        SRV_INF("loading draft model '%s'\n", params.speculative.model.path.c_str());
         auto params_dft = params;
 
         params_dft.devices = params.speculative.devices;
-        params_dft.hf_file = params.speculative.hf_file;
-        params_dft.hf_repo = params.speculative.hf_repo;
         params_dft.model = params.speculative.model;
-        params_dft.model_url = params.speculative.model_url;
         params_dft.n_ctx = params.speculative.n_ctx == 0 ?
params.n_ctx / params.n_parallel : params.speculative.n_ctx; params_dft.n_gpu_layers = params.speculative.n_gpu_layers; params_dft.n_parallel = 1; @@ -431,12 +428,12 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo llama_model *model_dft = llama_init_dft.model.get(); if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str()); + SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.path.c_str()); } if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) { SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", - params.speculative.model.c_str(), params.model.c_str()); + params.speculative.model.path.c_str(), params.model.path.c_str()); } const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); @@ -511,7 +508,7 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv task.id = ctx_server->queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = server_tokens(tokenized_prompts[i], false); task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, data); task.id_selected_slot = json_value(data, "id_slot", -1); @@ -520,7 +517,7 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl - tasks.push_back(task); + tasks.push_back(std::move(task)); } } catch (const std::exception &e) { const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); @@ -529,10 +526,10 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv } ctx_server->queue_results.add_waiting_tasks(tasks); - ctx_server->queue_tasks.post(tasks); - const auto task_ids = server_task::get_list_id(tasks); + ctx_server->queue_tasks.post(std::move(tasks)); + if (task_ids.size() != 1) { env->ThrowNew(c_llama_error, "multitasking currently not supported"); return 0; @@ -600,24 +597,24 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, SRV_INF("Calling embedding '%s'\n", prompt.c_str()); - const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); + auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); std::vector tasks; server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); task.id = ctx_server->queue_tasks.get_new_id(); task.index = 0; - task.prompt_tokens = std::move(tokens); + task.prompt_tokens = server_tokens(tokens, false); // OAI-compat task.params.oaicompat = OAICOMPAT_TYPE_NONE; - tasks.push_back(task); + tasks.push_back(std::move(task)); ctx_server->queue_results.add_waiting_tasks(tasks); - ctx_server->queue_tasks.post(tasks); - std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server->queue_tasks.post(std::move(tasks)); const auto id_task = *task_ids.begin(); json responses = json::array(); @@ -677,7 +674,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { + if (!ctx_server->params_base.embedding || ctx_server->params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { env->ThrowNew(c_llama_error, "This server does not support 
reranking. Start it with `--reranking` and without `--embedding`"); return nullptr; @@ -702,14 +699,15 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo auto task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server->queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); - tasks.push_back(task); + auto tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); + task.prompt_tokens = server_tokens(tokens, false); + tasks.push_back(std::move(task)); } ctx_server->queue_results.add_waiting_tasks(tasks); - ctx_server->queue_tasks.post(tasks); + std::unordered_set task_ids = server_task::get_list_id(tasks); + ctx_server->queue_tasks.post(std::move(tasks)); // get the result - std::unordered_set task_ids = server_task::get_list_id(tasks); std::vector results(task_ids.size()); // Create a new HashMap instance @@ -754,14 +752,14 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) { jlong server_handle = env->GetLongField(obj, f_model_pointer); - auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + const auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) std::string c_params = parse_jstring(env, jparams); json data = json::parse(c_params); - json templateData = - oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, - ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get()); + std::vector files; + json templateData = oaicompat_chat_params_parse(data, ctx_server->oai_parser_opt, files); + std::string tok_str = templateData.at("prompt"); jstring jtok_str = env->NewStringUTF(tok_str.c_str()); diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 9686f2a..8efc5ae 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -1,6 +1,13 @@ +#include "chat.h" #include "utils.hpp" +#include "arg.h" +#include "common.h" #include "json-schema-to-grammar.h" +#include "llama.h" +#include "log.h" +#include "mtmd-helper.h" +#include "mtmd.h" #include "sampling.h" #include "speculative.h" @@ -75,6 +82,26 @@ enum error_type { ERROR_TYPE_NOT_SUPPORTED, // custom error }; +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } +} + +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } +} + struct slot_params { bool stream = true; bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt @@ -105,7 +132,7 @@ struct slot_params { oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; - common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_chat_syntax oaicompat_chat_syntax; json to_json() const { std::vector samplers; @@ -121,7 +148,8 @@ struct slot_params { auto grammar_triggers = json::array(); for (const auto &trigger : sampling.grammar_triggers) { - grammar_triggers.push_back(trigger.to_json()); + server_grammar_trigger ct(std::move(trigger)); + grammar_triggers.push_back(ct.to_json()); } return json{ @@ -133,6 
+161,7 @@ struct slot_params { {"top_k", sampling.top_k}, {"top_p", sampling.top_p}, {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, {"xtc_probability", sampling.xtc_probability}, {"xtc_threshold", sampling.xtc_threshold}, {"typical_p", sampling.typ_p}, @@ -161,7 +190,10 @@ struct slot_params { {"grammar_lazy", sampling.grammar_lazy}, {"grammar_triggers", grammar_triggers}, {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_format)}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -184,7 +216,7 @@ struct server_task { // used by SERVER_TASK_TYPE_INFERENCE slot_params params; - llama_tokens prompt_tokens; + server_tokens prompt_tokens; int id_selected_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE @@ -215,6 +247,7 @@ struct server_task { slot_params defaults; defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -235,6 +268,7 @@ struct server_task { params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); @@ -340,11 +374,16 @@ struct server_task { { auto it = data.find("chat_format"); if (it != data.end()) { - params.oaicompat_chat_format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); + params.oaicompat_chat_syntax.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); } else { - params.oaicompat_chat_format = defaults.oaicompat_chat_format; + params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; } + params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; + params.oaicompat_chat_syntax.reasoning_in_content = + params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); } { @@ -366,9 +405,9 @@ struct server_task { const auto grammar_triggers = data.find("grammar_triggers"); if (grammar_triggers != data.end()) { for (const auto &t : *grammar_triggers) { - auto ct = common_grammar_trigger::from_json(t); - if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto &word = ct.value; + server_grammar_trigger ct(t); + if 
(ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto &word = ct.value.value; auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); if (ids.size() == 1) { auto token = ids[0]; @@ -381,14 +420,22 @@ struct server_task { SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); common_grammar_trigger trigger; trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = (llama_token)token; - params.sampling.grammar_triggers.push_back(trigger); + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); } else { SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); } } else { - params.sampling.grammar_triggers.push_back(ct); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } else { + throw std::runtime_error("Unknown grammar trigger type"); + } + params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); } } } @@ -485,8 +532,12 @@ struct result_timings { double predicted_per_token_ms; double predicted_per_second; + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + json to_json() const { - return { + json base = { {"prompt_n", prompt_n}, {"prompt_ms", prompt_ms}, {"prompt_per_token_ms", prompt_per_token_ms}, @@ -497,6 +548,13 @@ struct result_timings { {"predicted_per_token_ms", predicted_per_token_ms}, {"predicted_per_second", predicted_per_second}, }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; } }; @@ -617,7 +675,8 @@ struct server_task_result_cmpl_final : server_task_result { oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; - common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; virtual int get_index() override { return index; } @@ -706,45 +765,20 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_oaicompat_chat() { std::string finish_reason = "length"; common_chat_msg msg; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - SRV_DBG("Parsing chat message: %s\n", content.c_str()); - msg = common_chat_parse(content, oaicompat_chat_format); - finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; } else { + msg.role = "assistant"; msg.content = content; } - - json message{ - {"role", "assistant"}, - }; - if (!msg.reasoning_content.empty()) { - message["reasoning_content"] = msg.reasoning_content; - } - if (msg.content.empty() && !msg.tool_calls.empty()) { - message["content"] = json(); - } else { - message["content"] = msg.content; - } - if (!msg.tool_calls.empty()) { - auto tool_calls = json::array(); - for (const auto &tc : msg.tool_calls) { - tool_calls.push_back({ - {"type", "function"}, - {"function", - { - {"name", tc.name}, - {"arguments", tc.arguments}, - }}, - {"id", tc.id}, - }); - } - message["tool_calls"] = tool_calls; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = msg.tool_calls.empty() ? 
"stop" : "tool_calls"; } json choice{ {"finish_reason", finish_reason}, {"index", 0}, - {"message", message}, + {"message", msg.to_json_oaicompat()}, }; if (!stream && probs_output.size() > 0) { @@ -780,13 +814,35 @@ struct server_task_result_cmpl_final : server_task_result { std::time_t t = std::time(0); std::string finish_reason = "length"; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto &diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); } - json choice = json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}; - - json ret = json{ - {"choices", json::array({choice})}, + deltas.push_back({ + {"choices", json::array({ + json{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, {"created", t}, {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, @@ -798,13 +854,18 @@ struct server_task_result_cmpl_final : server_task_result { {"prompt_tokens", n_prompt_tokens}, {"total_tokens", n_decoded + n_prompt_tokens}, }}, - }; + }); if (timings.prompt_n >= 0) { - ret.push_back({"timings", timings.to_json()}); + deltas.back().push_back({"timings", timings.to_json()}); } - return ret; + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat(); + } + + return deltas; } }; @@ -826,6 +887,7 @@ struct server_task_result_cmpl_partial : server_task_result { oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; + std::vector oaicompat_msg_diffs; virtual int get_index() override { return index; } @@ -900,66 +962,54 @@ struct server_task_result_cmpl_partial : server_task_result { } json to_json_oaicompat_chat() { - bool first = n_decoded == 0; + bool first = n_decoded == 1; std::time_t t = std::time(0); json choices; + std::vector deltas; + auto add_delta = [&](const json &delta) { + deltas.push_back({ + {"choices", json::array({ + json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + }; + // We have to send an initial update to conform to openai behavior if (first) { - if (content.empty()) { - choices = json::array( - {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}})}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}}; - - json second_ret = - json{{"choices", - json::array( - {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - 
} else { - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); } - GGML_ASSERT(choices.size() >= 1); - - if (prob_output.probs.size() > 0) { - choices[0]["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; + for (const auto &diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); } - json ret = json{{"choices", choices}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}}; + if (!deltas.empty()) { + GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1); - if (timings.prompt_n >= 0) { - ret.push_back({"timings", timings.to_json()}); + if (prob_output.probs.size() > 0) { + deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + deltas[deltas.size() - 1].push_back({"timings", timings.to_json()}); + } } - return std::vector({ret}); + return deltas; } }; @@ -1068,9 +1118,6 @@ struct server_task_result_metrics : server_task_result { int n_tasks_deferred; int64_t t_start; - int32_t kv_cache_tokens_count; - int32_t kv_cache_used_cells; - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields uint64_t n_prompt_tokens_processed_total = 0; uint64_t t_prompt_processing_total = 0; @@ -1110,9 +1157,6 @@ struct server_task_result_metrics : server_task_result { {"n_decode_total", n_decode_total}, {"n_busy_slots_total", n_busy_slots_total}, - {"kv_cache_tokens_count", kv_cache_tokens_count}, - {"kv_cache_used_cells", kv_cache_used_cells}, - {"slots", slots_data}, }; } @@ -1171,6 +1215,9 @@ struct server_slot { llama_context *ctx = nullptr; llama_context *ctx_dft = nullptr; + // multimodal + mtmd_context *mctx = nullptr; + common_speculative *spec = nullptr; std::vector lora; @@ -1198,14 +1245,15 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; // input prompt tokens - llama_tokens prompt_tokens; + server_tokens prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; + common_chat_msg chat_msg; - llama_tokens cache_tokens; + server_tokens cache_tokens; std::vector generated_token_probs; @@ -1224,6 +1272,7 @@ struct server_slot { llama_token sampled; common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; // stats size_t n_sent_text = 0; // number of sent text character @@ -1236,6 +1285,10 @@ struct server_slot { std::function callback_on_release; + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + void reset() { SLT_DBG(*this, "%s", "\n"); @@ -1249,17 +1302,31 @@ struct server_slot { n_past = 0; n_sent_text = 0; task_type = SERVER_TASK_TYPE_COMPLETION; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; generated_tokens.clear(); generated_token_probs.clear(); + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; } - bool is_non_causal() const { - return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + bool need_embd() 
const { return server_task_type_need_embd(task_type); } + + bool need_logits() const { return server_task_type_need_logits(task_type); } + + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch + // also we cannot split if the pooling would require any past tokens + bool can_split() const { + return !need_embd() || (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); } - bool can_batch_with(server_slot &other_slot) { - return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora); + bool can_batch_with(server_slot &other_slot) const { + return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params &global_params) { @@ -1313,9 +1380,28 @@ struct server_slot { timings.predicted_per_token_ms = t_token_generation / n_decoded; timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + return timings; } + const common_chat_msg &update_chat_msg(std::vector &diffs) { + auto previous_msg = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse(generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); + } + return chat_msg; + } + size_t find_stopping_strings(const std::string &text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; @@ -1329,7 +1415,7 @@ struct server_slot { pos = text.find(word, from_pos); } else { // otherwise, partial stop - pos = find_partial_stop_string(word, text); + pos = string_find_partial_stop(text, word); } if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { @@ -1360,6 +1446,14 @@ struct server_slot { t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation, n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float)n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total); + } } json to_json() const { @@ -1369,9 +1463,8 @@ struct server_slot { {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"non_causal", is_non_causal()}, {"params", params.to_json()}, - {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"prompt", prompt_tokens.detokenize(ctx, true)}, {"next_token", { {"has_next_token", has_next_token}, @@ -1446,29 +1539,30 @@ struct server_queue { std::condition_variable condition_tasks; // callback functions - std::function callback_new_task; + std::function callback_new_task; std::function callback_update_slots; // Add a new task to the end of the queue - int post(server_task task, bool front = false) { + int post(server_task &&task, bool front = false) { std::unique_lock lock(mutex_tasks); GGML_ASSERT(task.id != -1); // if this is cancel task make sure to clean up pending tasks if (task.type == SERVER_TASK_TYPE_CANCEL) { 
cleanup_pending_task(task.id_target); } - QUE_DBG("new task, id = %d, front = %d\n", task.id, front); + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); if (front) { queue_tasks.push_front(std::move(task)); } else { queue_tasks.push_back(std::move(task)); } condition_tasks.notify_one(); - return task.id; + return task_id; } // multi-task version of post() - int post(std::vector &tasks, bool front = false) { + int post(std::vector &&tasks, bool front = false) { std::unique_lock lock(mutex_tasks); for (auto &task : tasks) { if (task.id == -1) { @@ -1490,7 +1584,7 @@ struct server_queue { } // Add a new task, but defer until one slot is available - void defer(server_task task) { + void defer(server_task &&task) { std::unique_lock lock(mutex_tasks); QUE_DBG("defer task, id = %d\n", task.id); queue_tasks_deferred.push_back(std::move(task)); @@ -1505,7 +1599,7 @@ struct server_queue { } // Register function to process a new task - void on_new_task(std::function callback) { callback_new_task = std::move(callback); } + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } // Register the function to be called when all slots data is ready to be processed void on_update_slots(std::function callback) { callback_update_slots = std::move(callback); } @@ -1550,7 +1644,7 @@ struct server_queue { lock.unlock(); break; } - server_task task = queue_tasks.front(); + server_task task = std::move(queue_tasks.front()); queue_tasks.pop_front(); lock.unlock(); @@ -1588,6 +1682,8 @@ struct server_queue { }; struct server_response { + bool running = true; + // for keeping track of all tasks waiting for the result std::unordered_set waiting_task_ids; @@ -1643,7 +1739,13 @@ struct server_response { server_task_result_ptr recv(const std::unordered_set &id_tasks) { while (true) { std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&] { return !queue_results.empty(); }); + condition_results.wait(lock, [&] { + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); for (size_t i = 0; i < queue_results.size(); i++) { if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { @@ -1672,6 +1774,10 @@ struct server_response { } std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } if (cr_res == std::cv_status::timeout) { return nullptr; } @@ -1701,6 +1807,12 @@ struct server_response { } } } + + // terminate the waiting loop + void terminate() { + running = false; + condition_results.notify_all(); + } }; struct server_context { @@ -1713,13 +1825,16 @@ struct server_context { llama_model *model = nullptr; llama_context *ctx = nullptr; + // multimodal + mtmd_context *mctx = nullptr; + const llama_vocab *vocab = nullptr; llama_model *model_dft = nullptr; llama_context_params cparams_dft; - llama_batch batch = {}; + llama_batch batch{}; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1740,8 +1855,11 @@ struct server_context { float slot_prompt_similarity = 0.0f; common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; ~server_context() { + mtmd_free(mctx); + // Clear any sampling context for (server_slot &slot : slots) { common_sampler_free(slot.smpl); @@ -1760,7 +1878,7 @@ struct server_context { 
} bool load_model(const common_params ¶ms) { - SRV_INF("loading model '%s'\n", params.model.c_str()); + SRV_INF("loading model '%s'\n", params.model.path.c_str()); params_base = params; @@ -1770,7 +1888,7 @@ struct server_context { ctx = llama_init.context.get(); if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } @@ -1781,33 +1899,34 @@ struct server_context { add_bos_token = llama_vocab_get_add_bos(vocab); has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; - if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); + if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); auto params_dft = params_base; params_dft.devices = params_base.speculative.devices; - params_dft.hf_file = params_base.speculative.hf_file; - params_dft.hf_repo = params_base.speculative.hf_repo; params_dft.model = params_base.speculative.model; - params_dft.model_url = params_base.speculative.model_url; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; + // force F16 KV cache for the draft model for extra performance + params_dft.cache_type_k = GGML_TYPE_F16; + params_dft.cache_type_v = GGML_TYPE_F16; + llama_init_dft = common_init_from_params(params_dft); model_dft = llama_init_dft.model.get(); if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); return false; } if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", - params_base.speculative.model.c_str(), params_base.model.c_str()); + params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); return false; } @@ -1817,10 +1936,6 @@ struct server_context { cparams_dft = common_context_params_to_llama(params_dft); cparams_dft.n_batch = n_ctx_dft; - // force F16 KV cache for the draft model for extra performance - cparams_dft.type_k = GGML_TYPE_F16; - cparams_dft.type_v = GGML_TYPE_F16; - // the context is not needed - we will create one for each slot llama_init_dft.context.reset(); } @@ -1829,12 +1944,55 @@ struct server_context { try { common_chat_format_example(chat_templates.get(), params.use_jinja); } catch (const std::exception &e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. " "This may cause the model to output suboptimal responses\n", __func__); chat_templates = common_chat_templates_init(model, "chatml"); } + std::string &mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.verbosity = params_base.verbosity > 0 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + + if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); + } + } + return true; } @@ -1850,6 +2008,8 @@ struct server_context { slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -1870,12 +2030,13 @@ struct server_context { SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); slot.params.sampling = params_base.sampling; + slot.params.n_keep = params_base.n_keep; slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); }; slot.reset(); - slots.push_back(slot); + slots.push_back(std::move(slot)); } default_generation_settings_for_props = slots[0].to_json(); @@ -1885,12 +2046,20 @@ struct server_context { // used) { const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } metrics.init(); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? 
mtmd_support_audio(mctx) : false, + /* enable_thinking */ params_base.reasoning_budget != 0, + }; } server_slot *get_slot_by_id(int id) { @@ -1923,7 +2092,7 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); @@ -1943,7 +2112,8 @@ struct server_context { // find the slot that has been least recently used if (ret == nullptr) { - int64_t t_last = ggml_time_us(); + int64_t t_last = -1; + for (server_slot &slot : slots) { // skip the slot if it is not available if (slot.is_processing()) { @@ -1951,7 +2121,7 @@ struct server_context { } // select the current slot if the criteria match - if (slot.t_last_used < t_last) { + if (!ret || slot.t_last_used <= t_last) { t_last = slot.t_last_used; ret = &slot; } @@ -1965,7 +2135,7 @@ struct server_context { return ret; } - bool launch_slot_with_task(server_slot &slot, const server_task &task) { + bool launch_slot_with_task(server_slot &slot, server_task &&task) { slot.reset(); slot.id_task = task.id; slot.index = task.index; @@ -1973,12 +2143,16 @@ struct server_context { slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); - if (!are_lora_equal(task.params.lora, slot.lora)) { + if (!are_lora_equal(slot.params.lora, slot.lora)) { // if lora is changed, we cannot reuse cached tokens slot.cache_tokens.clear(); - slot.lora = task.params.lora; + slot.lora = slot.params.lora; } + if (!slot.prompt_tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { @@ -2022,7 +2196,7 @@ struct server_context { SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); clean_kv_cache = false; } @@ -2076,6 +2250,14 @@ struct server_context { slot.has_next_token = true; } + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + } + // check the limits if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { slot.stop = STOP_TYPE_LIMIT; @@ -2085,16 +2267,6 @@ struct server_context { } if (slot.has_new_line) { - // if we have already seen a new line, we stop after a certain time limit - if (slot.params.t_max_predict_ms > 0 && - (ggml_time_us() - slot.t_start_generation > 1000.0f * slot.params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, - (int)slot.params.t_max_predict_ms); - } - // require that each new line has a whitespace prefix (i.e. 
indentation) of at least slot.params.n_indent if (slot.params.n_indent > 0) { // check the current indentation @@ -2135,6 +2307,16 @@ struct server_context { // check if there is a new line in the generated text if (result.text_to_send.find('\n') != std::string::npos) { slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.params.t_max_predict_ms > 0 && + (ggml_time_us() - slot.t_start_generation > 1000.0f * slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, + (int)slot.params.t_max_predict_ms); + } } // if context shift is disabled, we stop when it reaches the context limit @@ -2237,6 +2419,15 @@ struct server_context { queue_results.send(std::move(res)); } + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + void send_partial_response(server_slot &slot, const completion_token_output &tkn) { auto res = std::make_unique(); @@ -2254,6 +2445,8 @@ struct server_context { res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + slot.update_chat_msg(res->oaicompat_msg_diffs); + // populate res.probs_output if (slot.params.sampling.n_probs > 0) { res->prob_output = tkn; // copy the token probs @@ -2273,10 +2466,10 @@ struct server_context { res->id_slot = slot.id; res->index = slot.index; - res->content = std::move(slot.generated_text); + res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->prompt = slot.prompt_tokens.detokenize(ctx, true); res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; @@ -2293,7 +2486,8 @@ struct server_context { res->oaicompat = slot.params.oaicompat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_chat_format = slot.params.oaicompat_chat_format; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + // populate res.probs_output if (slot.params.sampling.n_probs > 0) { if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { @@ -2402,10 +2596,10 @@ struct server_context { server_task task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(task); + cancel_tasks.push_back(std::move(task)); } // push to beginning of the queue, so it has highest priority - queue_tasks.post(cancel_tasks, true); + queue_tasks.post(std::move(cancel_tasks), true); } // receive the results from task(s) @@ -2486,7 +2680,7 @@ struct server_context { // Functions to process the task // - void process_single_task(server_task task) { + void process_single_task(server_task &&task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_INFILL: @@ -2499,17 +2693,18 @@ struct server_context { if (slot == nullptr) { // if no slot is available, we defer this task for processing later SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); - queue_tasks.defer(task); + queue_tasks.defer(std::move(task)); break; } + if (slot->is_processing()) { // if 
requested slot is unavailable, we defer this task for processing later SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(task); + queue_tasks.defer(std::move(task)); break; } - if (!launch_slot_with_task(*slot, task)) { + if (!launch_slot_with_task(*slot, std::move(task))) { SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); break; } @@ -2553,9 +2748,6 @@ struct server_context { res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); res->t_start = metrics.t_start; - res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); - res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); - res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; res->t_prompt_processing_total = metrics.t_prompt_processing_total; res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; @@ -2575,6 +2767,10 @@ struct server_context { queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_SAVE: { + if (!ensure_no_mtmd(task.id)) { + break; + } + int id_slot = task.slot_action.slot_id; server_slot *slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2584,7 +2780,7 @@ struct server_context { if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(task); + queue_tasks.defer(std::move(task)); break; } @@ -2594,8 +2790,9 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; + const llama_tokens &tokens = slot->cache_tokens.get_text_tokens(); const size_t nwrite = - llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2611,6 +2808,8 @@ struct server_context { queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { + if (!ensure_no_mtmd(task.id)) + break; int id_slot = task.slot_action.slot_id; server_slot *slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2620,7 +2819,7 @@ struct server_context { if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(task); + queue_tasks.defer(std::move(task)); break; } @@ -2629,17 +2828,20 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - slot->cache_tokens.resize(slot->n_ctx); + llama_tokens tokens; + tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), - slot->cache_tokens.size(), &token_count); + size_t nread = + llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.resize(0); + slot->cache_tokens.clear(); // KV may already been invalidated? 
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } - slot->cache_tokens.resize(token_count); + tokens.resize(token_count); + slot->cache_tokens.clear(); + slot->cache_tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -2655,6 +2857,8 @@ struct server_context { queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_ERASE: { + if (!ensure_no_mtmd(task.id)) + break; int id_slot = task.slot_action.slot_id; server_slot *slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2664,13 +2868,13 @@ struct server_context { if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(task); + queue_tasks.defer(std::move(task)); break; } // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); slot->cache_tokens.clear(); auto res = std::make_unique(); @@ -2715,7 +2919,7 @@ struct server_context { server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); task.id = queue_tasks.get_new_id(); - queue_tasks.post(task); + queue_tasks.post(std::move(task)); } // apply context-shift if needed @@ -2730,6 +2934,12 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is + // loaded we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = slot.n_past - n_keep; @@ -2738,15 +2948,19 @@ struct server_context { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, n_keep, n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); - if (slot.params.cache_prompt) { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + // add generated tokens to cache + { + llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + new_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.clear(); + slot.cache_tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -2784,10 +2998,7 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, {slot.id}, true); slot.n_past += 1; - - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(slot.sampled); - } + slot.cache_tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", slot.n_ctx, slot.n_past, (int)slot.cache_tokens.size(), slot.truncated); @@ -2826,19 +3037,19 @@ struct server_context { slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) - if (1) { + /*if (1) 
{ // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], - common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } } else { // all - for (int i = 0; i < (int)prompt_tokens.size(); i++) { + for (int i = 0; i < (int) prompt_tokens.size(); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], - common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - } + }*/ // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { @@ -2850,7 +3061,15 @@ struct server_context { continue; } - if (slot.is_non_causal()) { + // TODO: support memory-less logits computation + if (slot.need_logits() && !llama_get_memory(ctx)) { + slot.release(); + send_error(slot, "the current context does not logits computation. skipping", + ERROR_TYPE_SERVER); + continue; + } + + if (!slot.can_split()) { if (slot.n_prompt_tokens > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", @@ -2885,21 +3104,26 @@ struct server_context { // if input prompt is too big, truncate it if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - llama_tokens new_tokens(prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + const llama_tokens &curr_tokens = slot.prompt_tokens.get_text_tokens(); + llama_tokens new_tokens(curr_tokens.begin(), curr_tokens.begin() + slot.params.n_keep); new_tokens.insert(new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - prompt_tokens.end()); + curr_tokens.end()); - prompt_tokens = std::move(new_tokens); + prompt_tokens.clear(); + prompt_tokens.insert(new_tokens); slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); @@ -2913,13 +3137,18 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache size_t head_p = slot.n_past; // current prompt + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); @@ -2945,11 +3174,12 @@ struct server_context { const int64_t kv_shift = (int64_t)head_p - (int64_t)head_c; - llama_kv_cache_seq_rm(ctx, slot.id, head_p, head_c); - llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, + head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); 
slot.n_past++; } @@ -2962,13 +3192,39 @@ struct server_context { SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); } + } else { + // if we don't cache the prompt, we have to remove the entire KV cache + slot.n_past = 0; + } + + if (slot.n_past > 0 && slot.n_past < (int)slot.cache_tokens.size()) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + if (pos_min == -1) { + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", + slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min); + GGML_ABORT( + "pos_min == -1, but n_past > 0 - should not happen: " + "https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); + } + + const auto n_swa = llama_model_n_swa(model); + if (pos_min > std::max(0, slot.n_past - n_swa)) { + SLT_WRN(slot, + "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = " + "%d\n", + slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, + "forcing full prompt re-processing due to lack of cache data (likely due " + "to SWA, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + slot.n_past = 0; + } } } if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - // we have to evaluate at least 1 token to generate logits. SLT_WRN(slot, - "need to evaluate at least 1 token to generate logits, n_past = %d, " + "need to evaluate at least 1 token for each active slot, n_past = %d, " "n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); @@ -2978,8 +3234,7 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } - // non-causal tasks require to fit the entire prompt in the physical batch - if (slot.is_non_causal()) { + if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2987,9 +3242,9 @@ struct server_context { } // keep only the common part - if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left slot.n_past = 0; @@ -2998,24 +3253,52 @@ struct server_context { SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.resize(slot.n_past); + slot.cache_tokens.keep_first(slot.n_past); + + // check if we should process the image + if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + // process the image + int32_t new_n_past; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + int32_t n_pos = new_n_past - slot.n_past; + + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + // add the image chunk to cache + { + const auto &chunk = slot.prompt_tokens.find_chunk(slot.n_past); + slot.cache_tokens.push_back(chunk.get()); // copy + } + + slot.n_past += n_pos; + slot.n_prompt_tokens_processed += n_pos; + } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { - // without pooling, we want to output the embeddings 
for all the tokens in the batch - const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && - llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, {slot.id}, need_embd); + // embedding requires all tokens in the batch to be output + const bool need_embd = server_task_type_need_embd(slot.task_type); - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); - } + common_batch_add(batch, cur_tok, slot.n_past, {slot.id}, need_embd); + slot.cache_tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; } + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens); @@ -3024,12 +3307,16 @@ struct server_context { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system for (int i = 0; i < slot.n_prompt_tokens; ++i) { - common_sampler_accept(slot.smpl, prompt_tokens[i], false); + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } } // extract the logits only for the last token @@ -3056,14 +3343,48 @@ struct server_context { SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); if (slot_batched) { - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, slot_batched->is_non_causal()); // apply lora, only need to do it once per batch common_set_adapter_lora(ctx, slot_batched->lora); + + llama_set_embeddings(ctx, slot_batched->need_embd()); } + // pad the batch so that batch.n_tokens >= n_slots + // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689 + if (slot_batched->need_embd()) { + const int n_slots = slots.size(); + + if (batch.n_tokens < n_slots) { + std::set seq_ids; + for (int j = 0; j < batch.n_tokens; ++j) { + seq_ids.insert(batch.seq_id[j][0]); + } + + // find unused sequence id + llama_seq_id seq_id = -1; + for (int i = 0; i < n_slots; ++i) { + if (seq_ids.find(i) == seq_ids.end()) { + seq_id = i; + } + } + + const int n_add = n_slots - batch.n_tokens; + + SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id); + + for (int j = 0; j < n_add; ++j) { + common_batch_add(batch, 0, j, {seq_id}, true); + } + + slots[seq_id].cache_tokens.clear(); + llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1); + } + } + + int32_t i_next = 0; + // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { @@ -3072,32 +3393,51 @@ struct server_context { }; const int ret = llama_decode(ctx, batch_view); + metrics.on_decoded(slots); if (ret != 0) { - if (n_batch == 1 || ret < 0) { - // if you get here, it means the KV cache is full - try increasing it via the context size - SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i " - "= %d, n_batch = %d, ret = %d\n", - i, 
n_batch, ret); - for (auto &slot : slots) { - slot.release(); - send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); + { + std::string err; + + if (n_batch == 1 && ret == 1) { + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + err = "Compute error."; + } + + if (!err.empty()) { + SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + for (auto &slot : slots) { + slot.release(); + send_error(slot, err); + } + break; } - break; // break loop of n_batch } // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing " - "it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch " + "= %d, ret = %d\n", i, n_batch, ret); continue; // continue loop of n_batch } + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + for (auto &slot : slots) { if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { continue; // continue loop of slots @@ -3174,6 +3514,11 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + // determine the max draft that fits the current slot state int n_draft_max = slot.params.speculative.n_max; @@ -3201,7 +3546,8 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + const llama_tokens &cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts if (slot.params.speculative.n_min > (int)draft.size()) { @@ -3210,6 +3556,9 @@ struct server_context { continue; } + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + // construct the speculation batch common_batch_clear(slot.batch_spec); common_batch_add(slot.batch_spec, id, slot.n_past, {slot.id}, true); @@ -3228,10 +3577,13 @@ struct server_context { slot.n_past += ids.size(); slot.n_decoded += ids.size(); + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + slot.cache_tokens.push_back(id); - slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 603424b..913a9ce 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -1,24 +1,18 @@ #pragma once +#include "arg.h" // common_remote_get_content #include "base64.hpp" +#include "chat.h" #include "common.h" #include "llama.h" #include "log.h" +#include "mtmd-helper.h" +#include "mtmd.h" -#ifndef 
NDEBUG -// crash the server in debug mode, otherwise send an http 500 error -#define CPPHTTPLIB_NO_EXCEPTIONS 1 -#endif -// increase max payload length to allow use of larger context size -#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 -// #include "httplib.h" - -// Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT -#include "nlohmann/json.hpp" - -#include "chat.h" +#include +#include #include #include #include @@ -48,6 +42,8 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +using raw_buffer = std::vector; + template static T json_value(const json &body, const std::string &key, const T &default_value) { // Fallback null to default value if (body.contains(key) && !body.at(key).is_null()) { @@ -65,6 +61,32 @@ template static T json_value(const json &body, const std::string &k const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger &value) : value(value) {} + server_grammar_trigger(const json &in) { + value.type = (common_grammar_trigger_type)in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token)in.at("token").get(); + } + } + + json to_json() const { + json out{ + {"type", (int)value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int)value.token; + } + return out; + } +}; + // // tokenizer and input processing utils // @@ -248,13 +270,19 @@ static size_t validate_utf8(const std::string &text) { static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_tokens &query, const llama_tokens &doc) { llama_tokens result; + // Get EOS token - use SEP token as fallback if EOS is not available + llama_token eos_token = llama_vocab_eos(vocab); + if (eos_token == LLAMA_TOKEN_NULL) { + eos_token = llama_vocab_sep(vocab); + } + result.reserve(doc.size() + query.size() + 4); result.push_back(llama_vocab_bos(vocab)); result.insert(result.end(), query.begin(), query.end()); - result.push_back(llama_vocab_eos(vocab)); + result.push_back(eos_token); result.push_back(llama_vocab_sep(vocab)); result.insert(result.end(), doc.begin(), doc.end()); - result.push_back(llama_vocab_eos(vocab)); + result.push_back(eos_token); return result; } @@ -367,7 +395,7 @@ static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string &encoded_string) { +static inline raw_buffer base64_decode(const std::string &encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -377,7 +405,7 @@ static inline std::vector base64_decode(const std::string &encoded_stri uint8_t char_array_4[4]; uint8_t char_array_3[3]; - std::vector ret; + raw_buffer ret; while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; @@ -441,30 +469,12 @@ static std::string random_string() { static std::string gen_chatcmplid() { return "chatcmpl-" + random_string(); } +static std::string 
gen_tool_call_id() { return random_string(); } + // // other common utils // -static bool ends_with(const std::string &str, const std::string &suffix) { - return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} - -static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { - if (!text.empty() && !stop.empty()) { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) { - return text.size() - char_index - 1; - } - } - } - } - - return std::string::npos; -} - // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { std::string ret; @@ -491,21 +501,11 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, c return out; } -// static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { -// const std::string str = -// std::string(event) + ": " + -// data.dump(-1, ' ', false, json::error_handler_t::replace) + -// "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). -// -// LOG_DBG("data stream, to_send: %s", str.c_str()); -// -// return sink.write(str.c_str(), str.size()); -// } - // // OAI utils // +// used by /completions endpoint static json oaicompat_completion_params_parse(const json &body) { json llama_params; @@ -550,25 +550,32 @@ static json oaicompat_completion_params_parse(const json &body) { return llama_params; } -static json oaicompat_completion_params_parse(const json &body, /* openai api json semantics */ - bool use_jinja, common_reasoning_format reasoning_format, - const struct common_chat_templates *tmpls) { +struct oaicompat_parser_options { + bool use_jinja = false; + bool prefill_assistant = false; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + common_chat_templates *tmpls = nullptr; + bool allow_image = false; + bool allow_audio = false; + bool enable_thinking = false; +}; + +// used by /chat/completions endpoint +static json oaicompat_chat_params_parse(json &body, /* openai api json semantics */ + const oaicompat_parser_options &opt, std::vector &out_files) { json llama_params; auto tools = json_value(body, "tools", json()); + auto has_tools = tools.is_array() && !tools.empty(); auto stream = json_value(body, "stream", false); + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - if (tools.is_array() && !tools.empty()) { - if (stream) { - throw std::runtime_error("Cannot use tools with stream"); - } - if (!use_jinja) { + if (!opt.use_jinja) { + if (has_tools) { throw std::runtime_error("tools param requires --jinja flag"); } - } - if (!use_jinja) { - if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { - throw std::runtime_error("Unsupported param: tool_choice"); + if (tool_choice != "auto") { + throw std::runtime_error("tool_choice param requires --jinja flag"); } } @@ -600,34 +607,171 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js } } + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json &messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for 
(auto &msg : messages) { + std::string role = json_value(msg, "role", std::string()); + if (role != "assistant" && !msg.contains("content")) { + throw std::runtime_error("All non-assistant messages must contain 'content'"); + } + if (role == "assistant") { + if (!msg.contains("content") && !msg.contains("tool_calls")) { + throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); + } + if (!msg.contains("content")) { + continue; // avoid errors with no content + } + } + json &content = msg.at("content"); + if (content.is_string() || content.is_null()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto &p : content) { + std::string type = json_value(p, "type", std::string()); + if (type == "image_url") { + if (!opt.allow_image) { + throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need " + "to provide the mmproj"); + } + + json image_url = json_value(p, "image_url", json::object()); + std::string url = json_value(image_url, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("image_url"); + + } else if (type == "input_audio") { + if (!opt.allow_audio) { + throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need " + "to provide the mmproj"); + } + + json input_audio = json_value(p, "input_audio", json::object()); + std::string data = json_value(input_audio, "data", std::string()); + std::string format = json_value(input_audio, "format", std::string()); + // while we also support flac, we don't allow it here so we matches the OAI spec + if (format != "wav" && format != "mp3") { + throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + } + auto decoded_data = base64_decode(data); // expected to be base64 encoded + out_files.push_back(decoded_data); + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("input_audio"); + + } else if (type != "text") { + throw std::runtime_error("unsupported content[].type"); + } + } + } + common_chat_templates_inputs inputs; - 
inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); inputs.grammar = grammar; - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - inputs.use_jinja = use_jinja; + inputs.use_jinja = opt.use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); + inputs.reasoning_format = opt.reasoning_format; + inputs.enable_thinking = opt.enable_thinking; + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + if (body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + llama_params["parse_tool_calls"] = true; + } + + // if the assistant message appears at the end of list, we do not add end-of-turn token + // for ex. this can be useful to modify the reasoning process in reasoning models + bool prefill_assistant_message = + !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; + common_chat_msg last_message; + if (prefill_assistant_message) { + last_message = inputs.messages.back(); + inputs.messages.pop_back(); + + /* sanity check, max one assistant message at the end of the list */ + if (!inputs.messages.empty() && inputs.messages.back().role == "assistant") { + throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + } + + /* TODO: test this properly */ + inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; + inputs.add_generation_prompt = true; } // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(tmpls, inputs); + auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); + + /* Append assistant prefilled message */ + if (prefill_assistant_message) { + chat_params.prompt += last_message.content; + } llama_params["chat_format"] = static_cast(chat_params.format); llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; + if (!chat_params.grammar.empty()) { + llama_params["grammar"] = chat_params.grammar; + } llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); for (const auto &trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back(trigger.to_json()); + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); } llama_params["grammar_triggers"] = grammar_triggers; llama_params["preserved_tokens"] = chat_params.preserved_tokens; + llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; for (const auto &stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } @@ -642,6 +786,9 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js // TODO: The response format of this option is not yet 
OAI-compatible, but seems like no one really using it; We may // need to fix it in the future if (json_value(body, "logprobs", false)) { + if (has_tools && stream) { + throw std::runtime_error("logprobs is not supported with tools + stream"); + } llama_params["n_probs"] = json_value(body, "top_logprobs", 20); } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { throw std::runtime_error("top_logprobs requires logprobs to be set to true"); @@ -853,4 +1000,257 @@ static std::vector parse_lora_request(const std::vecto } return lora; -} \ No newline at end of file +} + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + + private: // disallow accessing these members directly, risking out-of-sync + // map a **start** position in tokens to the image chunk + std::unordered_map map_pos_to_media; + + // list of tokens + // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token + // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** + // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // pos 0 1 2 3 4 5 6 7 8 9 + // map_pos_to_media will contain: {5, img0}, {8, img1} + + public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens &) = delete; + server_tokens &operator=(const server_tokens &) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens &&) = default; + server_tokens &operator=(server_tokens &&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token &operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks &mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } + } + + server_tokens(llama_tokens &tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + + // for debugging + std::string str() const { + std::ostringstream oss; + oss << "tokens: "; + for (const auto &t : tokens) { + if (t == LLAMA_TOKEN_NULL) { + oss << " "; + } else { + oss << t << " "; + } + } + oss << "\n"; + oss << "image pos: "; + for (const auto &it : map_pos_to_media) { + oss << it.first << ", "; + } + return oss.str(); + } + + const mtmd::input_chunk_ptr &find_chunk(llama_pos pos) const { + auto it = map_pos_to_media.find(pos); + if (it != map_pos_to_media.end()) { + return it->second; + } else { + throw std::runtime_error("Chunk not found"); + } + } + + void push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); + } + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk *chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const int n_pos = mtmd_input_chunk_get_n_pos(chunk); + llama_pos start_pos = 
tokens.size(); + for (int i = 0; i < n_pos; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_pos_to_media[start_pos] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } + } + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens &inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); + } + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens &get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; + } + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; + } + + size_t size() const { return tokens.size(); } + + bool empty() const { return tokens.empty(); } + + void clear() { tokens.clear(); } + + void keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + llama_token last_token = tokens[n - 1]; + // make sure we never remove tokens in the middle of an image + if (last_token == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end();) { + llama_pos pos = it->first; + if (pos >= (llama_pos)n) { + it = map_pos_to_media.erase(it); + } else { + ++it; + } + } + } + tokens.resize(n); + } + + std::string detokenize(const llama_context *ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto &t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); + } + + size_t get_common_prefix(const server_tokens &b) const { + size_t max_idx = std::min(tokens.size(), b.tokens.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto &ai = tokens[i]; + auto &bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + GGML_ASSERT(has_mtmd); + const auto &a_chunk = find_chunk(i); + const auto &b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); + std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); + std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); + size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); + size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); + if (ai_id == bi_id && a_pos == b_pos) { + GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen + i += a_pos - 1; // will be +1 by the for loop + continue; + } else { + return i; + } + } else if (ai == bi) { + continue; + } else { + return i; + } + } + return max_idx; // all tokens are equal + } + + // make sure all text tokens are within the 
vocab range + bool validate(const struct llama_context *ctx) const { + const llama_model *model = llama_get_model(ctx); + const llama_vocab *vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + auto &t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto &chunk = find_chunk(i); + size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + i += n_pos - 1; // will be +1 by the for loop + } catch (const std::exception &e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; + } + + // encode and decode the image chunk + int32_t process_chunk(llama_context *ctx, mtmd_context *mctx, llama_pos n_past, int32_t seq_id, + llama_pos &n_pos_out) { + auto &chunk = find_chunk(n_past); + const char *name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; + SRV_INF("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past = n_past; + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, chunk.get(), n_past, seq_id, n_batch, + true, // logits last + &new_n_past); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_pos_out = n_past; + return result; + } + n_pos_out = new_n_past; + return 0; + } +}; diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index e3e69d8..bb94a70 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -330,6 +330,6 @@ public void testTemplate() { .setStopStrings("\"\"\"") .setNPredict(nPredict) .setSeed(42); - Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?<|im_end|>\n<|im_start|>assistant\n"); + Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?"); } }
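The rewritten batch-decoding loop in the server.hpp hunk above no longer steps by a fixed n_batch: it advances by i_next, halves n_batch when llama_decode reports a recoverable failure, and restores the configured batch size after each successful call. The standalone sketch below illustrates only that retry/advance pattern; fake_decode, kv_capacity and the other constants are invented stand-ins, and the real loop additionally distinguishes ret == 1, ret == -1 and ret < -1 when classifying errors.

#include <algorithm>
#include <cstdio>

// Toy stand-in for llama_decode: 0 = success, 1 = recoverable "does not fit" failure.
static int fake_decode(int n_tokens, int capacity) {
    return n_tokens <= capacity ? 0 : 1;
}

int main() {
    const int n_tokens_total = 37; // tokens queued in the batch (made up)
    const int n_batch_full   = 16; // configured logical batch size (made up)
    const int kv_capacity    = 8;  // pretend only 8 tokens fit per call right now

    int n_batch = n_batch_full;
    int i_next  = 0;

    for (int i = 0; i < n_tokens_total; i = i_next) {
        const int n_view = std::min(n_batch, n_tokens_total - i);

        if (fake_decode(n_view, kv_capacity) != 0) {
            if (n_batch == 1) {
                std::printf("giving up at i = %d\n", i); // the real loop releases all slots here
                break;
            }
            // recoverable failure: retry the same offset with half the batch size,
            // mirroring the `n_batch /= 2; continue;` path in the server loop
            n_batch /= 2;
            std::printf("retrying i = %d with n_batch = %d\n", i, n_batch);
            continue;
        }

        // success: move the head forward and restore the full batch size,
        // like `i_next = i + n_tokens;` followed by `n_batch = llama_n_batch(ctx);`
        i_next  = i + n_view;
        n_batch = n_batch_full;
        std::printf("decoded [%d, %d)\n", i, i_next);
    }
    return 0;
}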
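The format_rerank change in utils.hpp falls back to the SEP token when the vocabulary defines no EOS token, while keeping the [BOS] query [EOS] [SEP] doc [EOS] layout. A minimal sketch with a toy vocabulary (toy_vocab, build_rerank_prompt and the integer token ids are invented stand-ins, not llama.cpp API) shows the resulting sequence:

#include <cstdio>
#include <vector>

using token = int;
static const token TOKEN_NULL = -1;

// Pretend EOS is missing (as in some rerank models) so the builder
// falls back to SEP, mirroring the fallback added by the patch.
struct toy_vocab {
    token bos = 1;
    token eos = TOKEN_NULL; // not defined for this model
    token sep = 2;
};

static std::vector<token> build_rerank_prompt(const toy_vocab &v,
                                               const std::vector<token> &query,
                                               const std::vector<token> &doc) {
    const token eos = v.eos != TOKEN_NULL ? v.eos : v.sep; // EOS -> SEP fallback
    std::vector<token> out;
    out.reserve(query.size() + doc.size() + 4);
    out.push_back(v.bos);
    out.insert(out.end(), query.begin(), query.end());
    out.push_back(eos);
    out.push_back(v.sep);
    out.insert(out.end(), doc.begin(), doc.end());
    out.push_back(eos);
    return out; // layout: [BOS] query [EOS] [SEP] doc [EOS]
}

int main() {
    toy_vocab v;
    const auto prompt = build_rerank_prompt(v, {10, 11}, {20, 21, 22});
    for (token t : prompt) std::printf("%d ", t); // prints: 1 10 11 2 2 20 21 22 2
    std::printf("\n");
    return 0;
}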
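oaicompat_chat_params_parse accepts image_url entries either as remote URLs (fetched through common_remote_get_content) or as base64 data URLs, which it validates and splits before pushing the decoded bytes into out_files. The helper below is a simplified standalone re-creation of just the data-URL validation step; extract_base64_payload is a hypothetical name, and it splits at the first comma, whereas the real code uses string_split and requires exactly two parts.

#include <cstdio>
#include <stdexcept>
#include <string>

// "data:image/png;base64,AAAA..." -> base64 payload, with the same checks as the patch:
// the header must start with "data:image/" and end with "base64".
static std::string extract_base64_payload(const std::string &url) {
    const auto comma = url.find(',');
    if (comma == std::string::npos) {
        throw std::runtime_error("Invalid image_url.url value");
    }
    const std::string header  = url.substr(0, comma);
    const std::string payload = url.substr(comma + 1);
    if (header.rfind("data:image/", 0) != 0) {
        throw std::runtime_error("Invalid image_url.url format: " + header);
    }
    if (header.size() < 6 || header.compare(header.size() - 6, 6, "base64") != 0) {
        throw std::runtime_error("image_url.url must be base64 encoded");
    }
    return payload; // the real code then runs base64_decode(payload) into a raw_buffer
}

int main() {
    const std::string url = "data:image/png;base64,iVBORw0KGgo=";
    std::printf("payload: %s\n", extract_base64_payload(url).c_str());
    return 0;
}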
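The server_tokens helper added in utils.hpp stores one llama_token per position and marks media positions with LLAMA_TOKEN_NULL, while map_pos_to_media keys each mtmd chunk by its start position; get_common_prefix therefore compares text tokens by value and media chunks by id and length, skipping over a matched chunk in one step. The sketch below models that layout with plain standard-library types to show how a shared image counts toward the cached prefix; toy_tokens, media_ref and TOKEN_NULL are invented stand-ins and no mtmd calls are involved.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

using token = int32_t;
static const token TOKEN_NULL = -1; // placeholder for a media position

struct media_ref {
    std::string id;   // analogous to mtmd_input_chunk_get_id()
    size_t      n_pos; // analogous to mtmd_input_chunk_get_n_pos()
};

struct toy_tokens {
    std::vector<token> tokens;                     // TOKEN_NULL marks media positions
    std::unordered_map<size_t, media_ref> media;   // start position -> media chunk

    void push_text(token t) { tokens.push_back(t); }

    // a media chunk occupying n_pos positions is stored as n_pos placeholder tokens
    void push_media(const std::string &id, size_t n_pos) {
        media[tokens.size()] = {id, n_pos};
        tokens.insert(tokens.end(), n_pos, TOKEN_NULL);
    }

    // text tokens match by value, media chunks by id and length (as in get_common_prefix)
    size_t common_prefix(const toy_tokens &b) const {
        const size_t n = std::min(tokens.size(), b.tokens.size());
        for (size_t i = 0; i < n; ++i) {
            if (tokens[i] == TOKEN_NULL && b.tokens[i] == TOKEN_NULL) {
                const auto a_it = media.find(i);
                const auto b_it = b.media.find(i);
                if (a_it != media.end() && b_it != b.media.end() &&
                    a_it->second.id == b_it->second.id &&
                    a_it->second.n_pos == b_it->second.n_pos) {
                    i += a_it->second.n_pos - 1; // skip the rest of the matching chunk
                    continue;
                }
                return i;
            }
            if (tokens[i] != b.tokens[i]) {
                return i;
            }
        }
        return n;
    }
};

int main() {
    toy_tokens a, b;
    for (token t : {11, 12, 13}) { a.push_text(t); b.push_text(t); }
    a.push_media("img-cat", 3);
    b.push_media("img-cat", 3);
    a.push_text(99);
    b.push_text(42);
    // three shared text tokens plus the three positions of the shared image -> 6
    std::printf("common prefix = %zu\n", a.common_prefix(b));
    return 0;
}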