diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..a113c01c --- /dev/null +++ b/.clang-format @@ -0,0 +1,225 @@ +--- +Language: Cpp +# BasedOnStyle: LLVM +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignArrayOfStructures: None +AlignConsecutiveAssignments: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: true +AlignConsecutiveBitFields: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveDeclarations: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveMacros: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignEscapedNewlines: Right +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 0 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortEnumsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: All +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: MultiLine +AttributeMacros: + - __capability +BinPackArguments: true +BinPackParameters: true +BitFieldColonSpacing: Both +BraceWrapping: + AfterCaseLabel: false + AfterClass: false + AfterControlStatement: Never + AfterEnum: false + AfterExternBlock: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakAfterAttributes: Never +BreakAfterJavaFieldAnnotations: false +BreakArrays: true +BreakBeforeBinaryOperators: None +BreakBeforeConceptDeclarations: Always +BreakBeforeBraces: Attach +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakInheritanceList: BeforeColon +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IfMacros: + - KJ_IF_MAYBE +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + SortPriority: 0 + CaseSensitive: false + - Regex: '.*' + Priority: 1 + SortPriority: 0 + CaseSensitive: false +IncludeIsMainRegex: '(Test)?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseBlocks: false +IndentCaseLabels: false +IndentExternBlock: AfterExternBlock +IndentGotoLabels: true +IndentPPDirectives: None +IndentRequiresClause: true +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: false +InsertNewlineAtEOF: false +InsertTrailingCommas: None +IntegerLiteralSeparator: + Binary: 0 + BinaryMinDigits: 0 + Decimal: 0 + DecimalMinDigits: 0 + Hex: 0 + HexMinDigits: 0 +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +LambdaBodyIndentation: Signature +LineEnding: DeriveLF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCBreakBeforeNestedBlockParam: true +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PackConstructorInitializers: BinPack +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakOpenParenthesis: 0 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyIndentedWhitespace: 0 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Right +PPIndentWidth: -1 +QualifierAlignment: Leave +ReferenceAlignment: Pointer +ReflowComments: true +RemoveBracesLLVM: false +RemoveSemicolon: false +RequiresClausePosition: OwnLine +RequiresExpressionIndentation: OuterScope +SeparateDefinitionBlocks: Leave +ShortNamespaceLines: 1 +SortIncludes: CaseSensitive +SortJavaStaticImport: Before +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceAroundPointerQualifiers: Default +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeParensOptions: + AfterControlStatements: true + AfterForeachMacros: true + AfterFunctionDefinitionName: false + AfterFunctionDeclarationName: false + AfterIfMacros: true + AfterOverloadedOperator: false + AfterRequiresInClause: false + AfterRequiresInExpression: false + BeforeNonEmptyParentheses: false +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: Never +SpacesInConditionalStatement: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Latest +StatementAttributeLikeMacros: + - Q_EMIT +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 8 +UseTab: Never +WhitespaceSensitiveMacros: + - BOOST_PP_STRINGIZE + - CF_SWIFT_NAME + - NS_SWIFT_NAME + - PP_STRINGIZE + - STRINGIZE +... + diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 00000000..952c0cca --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,24 @@ +--- +Checks: > + bugprone-*, + -bugprone-easily-swappable-parameters, + -bugprone-implicit-widening-of-multiplication-result, + -bugprone-misplaced-widening-cast, + -bugprone-narrowing-conversions, + readability-*, + -readability-avoid-unconditional-preprocessor-if, + -readability-function-cognitive-complexity, + -readability-identifier-length, + -readability-implicit-bool-conversion, + -readability-magic-numbers, + -readability-uppercase-literal-suffix, + -readability-simplify-boolean-expr, + clang-analyzer-*, + -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, + performance-*, + portability-*, + misc-*, + -misc-const-correctness, + -misc-non-private-member-variables-in-classes, + -misc-no-recursion, +FormatStyle: none diff --git a/.github/build.bat b/.github/build.bat new file mode 100755 index 00000000..a904405e --- /dev/null +++ b/.github/build.bat @@ -0,0 +1,7 @@ +@echo off + +mkdir build +cmake -Bbuild %* +cmake --build build --config Release + +if errorlevel 1 exit /b %ERRORLEVEL% \ No newline at end of file diff --git a/.github/build.sh b/.github/build.sh new file mode 100755 index 00000000..2842d7e6 --- /dev/null +++ b/.github/build.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +mkdir -p build +cmake -Bbuild $@ || exit 1 +cmake --build build --config Release -j4 || exit 1 diff --git a/.github/build_cuda_linux.sh b/.github/build_cuda_linux.sh new file mode 100755 index 00000000..147c2174 --- /dev/null +++ b/.github/build_cuda_linux.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +# A Cuda 12.1 install script for RHEL8/Rocky8/Manylinux_2.28 + +sudo dnf install -y kernel-devel kernel-headers +sudo dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-8.noarch.rpm +sudo dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo + +# We prefer CUDA 12.1 as it's compatible with 12.2+ +sudo dnf install -y cuda-toolkit-12-1 + +exec .github/build.sh $@ -DGGML_CUDA=1 -DCMAKE_CUDA_COMPILER=/usr/local/cuda-12.1/bin/nvcc \ No newline at end of file diff --git a/.github/dockcross/dockcross-android-arm b/.github/dockcross/dockcross-android-arm new file mode 100755 index 00000000..9cb27365 --- /dev/null +++ b/.github/dockcross/dockcross-android-arm @@ -0,0 +1,278 @@ +#!/usr/bin/env bash + +DEFAULT_DOCKCROSS_IMAGE=dockcross/android-arm:20240418-88c04a4 + +#------------------------------------------------------------------------------ +# Helpers +# +err() { + echo -e >&2 "ERROR: $*\n" +} + +die() { + err "$*" + exit 1 +} + +has() { + # eg. has command update + local kind=$1 + local name=$2 + + type -t $kind:$name | grep -q function +} + +# If OCI_EXE is not already set, search for a container executor (OCI stands for "Open Container Initiative") +if [ -z "$OCI_EXE" ]; then + if which podman >/dev/null 2>/dev/null; then + OCI_EXE=podman + elif which docker >/dev/null 2>/dev/null; then + OCI_EXE=docker + else + die "Cannot find a container executor. Search for docker and podman." + fi +fi + +#------------------------------------------------------------------------------ +# Command handlers +# +command:update-image() { + $OCI_EXE pull $FINAL_IMAGE +} + +help:update-image() { + echo "Pull the latest $FINAL_IMAGE ." +} + +command:update-script() { + if cmp -s <( $OCI_EXE run --rm $FINAL_IMAGE ) $0; then + echo "$0 is up to date" + else + echo -n "Updating $0 ... " + $OCI_EXE run --rm $FINAL_IMAGE > $0 && echo ok + fi +} + +help:update-script() { + echo "Update $0 from $FINAL_IMAGE ." +} + +command:update() { + command:update-image + command:update-script +} + +help:update() { + echo "Pull the latest $FINAL_IMAGE, and then update $0 from that." +} + +command:help() { + if [[ $# != 0 ]]; then + if ! has command $1; then + err \"$1\" is not an dockcross command + command:help + elif ! has help $1; then + err No help found for \"$1\" + else + help:$1 + fi + else + cat >&2 < +ENDHELP + exit 1 + fi +} + +#------------------------------------------------------------------------------ +# Option processing +# +special_update_command='' +while [[ $# != 0 ]]; do + case $1 in + + --) + shift + break + ;; + + --args|-a) + ARG_ARGS="$2" + shift 2 + ;; + + --config|-c) + ARG_CONFIG="$2" + shift 2 + ;; + + --image|-i) + ARG_IMAGE="$2" + shift 2 + ;; + update|update-image|update-script) + special_update_command=$1 + break + ;; + -*) + err Unknown option \"$1\" + command:help + exit + ;; + + *) + break + ;; + + esac +done + +# The precedence for options is: +# 1. command-line arguments +# 2. environment variables +# 3. defaults + +# Source the config file if it exists +DEFAULT_DOCKCROSS_CONFIG=~/.dockcross +FINAL_CONFIG=${ARG_CONFIG-${DOCKCROSS_CONFIG-$DEFAULT_DOCKCROSS_CONFIG}} + +[[ -f "$FINAL_CONFIG" ]] && source "$FINAL_CONFIG" + +# Set the docker image +FINAL_IMAGE=${ARG_IMAGE-${DOCKCROSS_IMAGE-$DEFAULT_DOCKCROSS_IMAGE}} + +# Handle special update command +if [ "$special_update_command" != "" ]; then + case $special_update_command in + + update) + command:update + exit $? + ;; + + update-image) + command:update-image + exit $? + ;; + + update-script) + command:update-script + exit $? + ;; + + esac +fi + +# Set the docker run extra args (if any) +FINAL_ARGS=${ARG_ARGS-${DOCKCROSS_ARGS}} + +# Bash on Ubuntu on Windows +UBUNTU_ON_WINDOWS=$([ -e /proc/version ] && grep -l Microsoft /proc/version || echo "") +# MSYS, Git Bash, etc. +MSYS=$([ -e /proc/version ] && grep -l MINGW /proc/version || echo "") +# CYGWIN +CYGWIN=$([ -e /proc/version ] && grep -l CYGWIN /proc/version || echo "") + +if [ -z "$UBUNTU_ON_WINDOWS" -a -z "$MSYS" -a "$OCI_EXE" != "podman" ]; then + USER_IDS=(-e BUILDER_UID="$( id -u )" -e BUILDER_GID="$( id -g )" -e BUILDER_USER="$( id -un )" -e BUILDER_GROUP="$( id -gn )") +fi + +# Change the PWD when working in Docker on Windows +if [ -n "$UBUNTU_ON_WINDOWS" ]; then + WSL_ROOT="/mnt/" + CFG_FILE=/etc/wsl.conf + if [ -f "$CFG_FILE" ]; then + CFG_CONTENT=$(cat $CFG_FILE | sed -r '/[^=]+=[^=]+/!d' | sed -r 's/\s+=\s/=/g') + eval "$CFG_CONTENT" + if [ -n "$root" ]; then + WSL_ROOT=$root + fi + fi + HOST_PWD=`pwd -P` + HOST_PWD=${HOST_PWD/$WSL_ROOT//} +elif [ -n "$MSYS" ]; then + HOST_PWD=$PWD + HOST_PWD=${HOST_PWD/\//} + HOST_PWD=${HOST_PWD/\//:\/} +elif [ -n "$CYGWIN" ]; then + for f in pwd readlink cygpath ; do + test -n "$(type "${f}" )" || { echo >&2 "Missing functionality (${f}) (in cygwin)." ; exit 1 ; } ; + done ; + HOST_PWD="$( cygpath -w "$( readlink -f "$( pwd ;)" ; )" ; )" ; +else + HOST_PWD=$PWD + [ -L $HOST_PWD ] && HOST_PWD=$(readlink $HOST_PWD) +fi + +# Mount Additional Volumes +if [ -z "$SSH_DIR" ]; then + SSH_DIR="$HOME/.ssh" +fi + +HOST_VOLUMES= +if [ -e "$SSH_DIR" -a -z "$MSYS" ]; then + if test -n "${CYGWIN}" ; then + HOST_VOLUMES+="-v $(cygpath -w ${SSH_DIR} ; ):/home/$(id -un)/.ssh" ; + else + HOST_VOLUMES+="-v $SSH_DIR:/home/$(id -un)/.ssh" ; + fi ; +fi + +#------------------------------------------------------------------------------ +# Now, finally, run the command in a container +# +TTY_ARGS= +tty -s && [ -z "$MSYS" ] && TTY_ARGS=-ti +CONTAINER_NAME=dockcross_$RANDOM +$OCI_EXE run $TTY_ARGS --name $CONTAINER_NAME \ + -v "$HOST_PWD":/work \ + $HOST_VOLUMES \ + "${USER_IDS[@]}" \ + $FINAL_ARGS \ + $FINAL_IMAGE "$@" +run_exit_code=$? + +# Attempt to delete container +rm_output=$($OCI_EXE rm -f $CONTAINER_NAME 2>&1) +rm_exit_code=$? +if [[ $rm_exit_code != 0 ]]; then + if [[ "$CIRCLECI" == "true" ]] && [[ $rm_output == *"Driver btrfs failed to remove"* ]]; then + : # Ignore error because of https://circleci.com/docs/docker-btrfs-error/ + else + echo "$rm_output" + exit $rm_exit_code + fi +fi + +exit $run_exit_code + +################################################################################ +# +# This image is not intended to be run manually. +# +# To create a dockcross helper script for the +# dockcross/android-arm:20240418-88c04a4 image, run: +# +# docker run --rm dockcross/android-arm:20240418-88c04a4 > dockcross-android-arm-20240418-88c04a4 +# chmod +x dockcross-android-arm-20240418-88c04a4 +# +# You may then wish to move the dockcross script to your PATH. +# +################################################################################ diff --git a/.github/dockcross/dockcross-android-arm64 b/.github/dockcross/dockcross-android-arm64 new file mode 100755 index 00000000..50452754 --- /dev/null +++ b/.github/dockcross/dockcross-android-arm64 @@ -0,0 +1,278 @@ +#!/usr/bin/env bash + +DEFAULT_DOCKCROSS_IMAGE=dockcross/android-arm64:20240418-88c04a4 + +#------------------------------------------------------------------------------ +# Helpers +# +err() { + echo -e >&2 "ERROR: $*\n" +} + +die() { + err "$*" + exit 1 +} + +has() { + # eg. has command update + local kind=$1 + local name=$2 + + type -t $kind:$name | grep -q function +} + +# If OCI_EXE is not already set, search for a container executor (OCI stands for "Open Container Initiative") +if [ -z "$OCI_EXE" ]; then + if which podman >/dev/null 2>/dev/null; then + OCI_EXE=podman + elif which docker >/dev/null 2>/dev/null; then + OCI_EXE=docker + else + die "Cannot find a container executor. Search for docker and podman." + fi +fi + +#------------------------------------------------------------------------------ +# Command handlers +# +command:update-image() { + $OCI_EXE pull $FINAL_IMAGE +} + +help:update-image() { + echo "Pull the latest $FINAL_IMAGE ." +} + +command:update-script() { + if cmp -s <( $OCI_EXE run --rm $FINAL_IMAGE ) $0; then + echo "$0 is up to date" + else + echo -n "Updating $0 ... " + $OCI_EXE run --rm $FINAL_IMAGE > $0 && echo ok + fi +} + +help:update-script() { + echo "Update $0 from $FINAL_IMAGE ." +} + +command:update() { + command:update-image + command:update-script +} + +help:update() { + echo "Pull the latest $FINAL_IMAGE, and then update $0 from that." +} + +command:help() { + if [[ $# != 0 ]]; then + if ! has command $1; then + err \"$1\" is not an dockcross command + command:help + elif ! has help $1; then + err No help found for \"$1\" + else + help:$1 + fi + else + cat >&2 < +ENDHELP + exit 1 + fi +} + +#------------------------------------------------------------------------------ +# Option processing +# +special_update_command='' +while [[ $# != 0 ]]; do + case $1 in + + --) + shift + break + ;; + + --args|-a) + ARG_ARGS="$2" + shift 2 + ;; + + --config|-c) + ARG_CONFIG="$2" + shift 2 + ;; + + --image|-i) + ARG_IMAGE="$2" + shift 2 + ;; + update|update-image|update-script) + special_update_command=$1 + break + ;; + -*) + err Unknown option \"$1\" + command:help + exit + ;; + + *) + break + ;; + + esac +done + +# The precedence for options is: +# 1. command-line arguments +# 2. environment variables +# 3. defaults + +# Source the config file if it exists +DEFAULT_DOCKCROSS_CONFIG=~/.dockcross +FINAL_CONFIG=${ARG_CONFIG-${DOCKCROSS_CONFIG-$DEFAULT_DOCKCROSS_CONFIG}} + +[[ -f "$FINAL_CONFIG" ]] && source "$FINAL_CONFIG" + +# Set the docker image +FINAL_IMAGE=${ARG_IMAGE-${DOCKCROSS_IMAGE-$DEFAULT_DOCKCROSS_IMAGE}} + +# Handle special update command +if [ "$special_update_command" != "" ]; then + case $special_update_command in + + update) + command:update + exit $? + ;; + + update-image) + command:update-image + exit $? + ;; + + update-script) + command:update-script + exit $? + ;; + + esac +fi + +# Set the docker run extra args (if any) +FINAL_ARGS=${ARG_ARGS-${DOCKCROSS_ARGS}} + +# Bash on Ubuntu on Windows +UBUNTU_ON_WINDOWS=$([ -e /proc/version ] && grep -l Microsoft /proc/version || echo "") +# MSYS, Git Bash, etc. +MSYS=$([ -e /proc/version ] && grep -l MINGW /proc/version || echo "") +# CYGWIN +CYGWIN=$([ -e /proc/version ] && grep -l CYGWIN /proc/version || echo "") + +if [ -z "$UBUNTU_ON_WINDOWS" -a -z "$MSYS" -a "$OCI_EXE" != "podman" ]; then + USER_IDS=(-e BUILDER_UID="$( id -u )" -e BUILDER_GID="$( id -g )" -e BUILDER_USER="$( id -un )" -e BUILDER_GROUP="$( id -gn )") +fi + +# Change the PWD when working in Docker on Windows +if [ -n "$UBUNTU_ON_WINDOWS" ]; then + WSL_ROOT="/mnt/" + CFG_FILE=/etc/wsl.conf + if [ -f "$CFG_FILE" ]; then + CFG_CONTENT=$(cat $CFG_FILE | sed -r '/[^=]+=[^=]+/!d' | sed -r 's/\s+=\s/=/g') + eval "$CFG_CONTENT" + if [ -n "$root" ]; then + WSL_ROOT=$root + fi + fi + HOST_PWD=`pwd -P` + HOST_PWD=${HOST_PWD/$WSL_ROOT//} +elif [ -n "$MSYS" ]; then + HOST_PWD=$PWD + HOST_PWD=${HOST_PWD/\//} + HOST_PWD=${HOST_PWD/\//:\/} +elif [ -n "$CYGWIN" ]; then + for f in pwd readlink cygpath ; do + test -n "$(type "${f}" )" || { echo >&2 "Missing functionality (${f}) (in cygwin)." ; exit 1 ; } ; + done ; + HOST_PWD="$( cygpath -w "$( readlink -f "$( pwd ;)" ; )" ; )" ; +else + HOST_PWD=$PWD + [ -L $HOST_PWD ] && HOST_PWD=$(readlink $HOST_PWD) +fi + +# Mount Additional Volumes +if [ -z "$SSH_DIR" ]; then + SSH_DIR="$HOME/.ssh" +fi + +HOST_VOLUMES= +if [ -e "$SSH_DIR" -a -z "$MSYS" ]; then + if test -n "${CYGWIN}" ; then + HOST_VOLUMES+="-v $(cygpath -w ${SSH_DIR} ; ):/home/$(id -un)/.ssh" ; + else + HOST_VOLUMES+="-v $SSH_DIR:/home/$(id -un)/.ssh" ; + fi ; +fi + +#------------------------------------------------------------------------------ +# Now, finally, run the command in a container +# +TTY_ARGS= +tty -s && [ -z "$MSYS" ] && TTY_ARGS=-ti +CONTAINER_NAME=dockcross_$RANDOM +$OCI_EXE run $TTY_ARGS --name $CONTAINER_NAME \ + -v "$HOST_PWD":/work \ + $HOST_VOLUMES \ + "${USER_IDS[@]}" \ + $FINAL_ARGS \ + $FINAL_IMAGE "$@" +run_exit_code=$? + +# Attempt to delete container +rm_output=$($OCI_EXE rm -f $CONTAINER_NAME 2>&1) +rm_exit_code=$? +if [[ $rm_exit_code != 0 ]]; then + if [[ "$CIRCLECI" == "true" ]] && [[ $rm_output == *"Driver btrfs failed to remove"* ]]; then + : # Ignore error because of https://circleci.com/docs/docker-btrfs-error/ + else + echo "$rm_output" + exit $rm_exit_code + fi +fi + +exit $run_exit_code + +################################################################################ +# +# This image is not intended to be run manually. +# +# To create a dockcross helper script for the +# dockcross/android-arm64:20240418-88c04a4 image, run: +# +# docker run --rm dockcross/android-arm64:20240418-88c04a4 > dockcross-android-arm64-20240418-88c04a4 +# chmod +x dockcross-android-arm64-20240418-88c04a4 +# +# You may then wish to move the dockcross script to your PATH. +# +################################################################################ diff --git a/.github/dockcross/dockcross-linux-arm64-lts b/.github/dockcross/dockcross-linux-arm64-lts new file mode 100755 index 00000000..6afd72f6 --- /dev/null +++ b/.github/dockcross/dockcross-linux-arm64-lts @@ -0,0 +1,278 @@ +#!/usr/bin/env bash + +DEFAULT_DOCKCROSS_IMAGE=dockcross/linux-arm64-lts:20230601-c2f5366 + +#------------------------------------------------------------------------------ +# Helpers +# +err() { + echo -e >&2 "ERROR: $*\n" +} + +die() { + err "$*" + exit 1 +} + +has() { + # eg. has command update + local kind=$1 + local name=$2 + + type -t $kind:$name | grep -q function +} + +# If OCI_EXE is not already set, search for a container executor (OCI stands for "Open Container Initiative") +if [ -z "$OCI_EXE" ]; then + if which podman >/dev/null 2>/dev/null; then + OCI_EXE=podman + elif which docker >/dev/null 2>/dev/null; then + OCI_EXE=docker + else + die "Cannot find a container executor. Search for docker and podman." + fi +fi + +#------------------------------------------------------------------------------ +# Command handlers +# +command:update-image() { + $OCI_EXE pull $FINAL_IMAGE +} + +help:update-image() { + echo "Pull the latest $FINAL_IMAGE ." +} + +command:update-script() { + if cmp -s <( $OCI_EXE run --rm $FINAL_IMAGE ) $0; then + echo "$0 is up to date" + else + echo -n "Updating $0 ... " + $OCI_EXE run --rm $FINAL_IMAGE > $0 && echo ok + fi +} + +help:update-script() { + echo "Update $0 from $FINAL_IMAGE ." +} + +command:update() { + command:update-image + command:update-script +} + +help:update() { + echo "Pull the latest $FINAL_IMAGE, and then update $0 from that." +} + +command:help() { + if [[ $# != 0 ]]; then + if ! has command $1; then + err \"$1\" is not an dockcross command + command:help + elif ! has help $1; then + err No help found for \"$1\" + else + help:$1 + fi + else + cat >&2 < +ENDHELP + exit 1 + fi +} + +#------------------------------------------------------------------------------ +# Option processing +# +special_update_command='' +while [[ $# != 0 ]]; do + case $1 in + + --) + shift + break + ;; + + --args|-a) + ARG_ARGS="$2" + shift 2 + ;; + + --config|-c) + ARG_CONFIG="$2" + shift 2 + ;; + + --image|-i) + ARG_IMAGE="$2" + shift 2 + ;; + update|update-image|update-script) + special_update_command=$1 + break + ;; + -*) + err Unknown option \"$1\" + command:help + exit + ;; + + *) + break + ;; + + esac +done + +# The precedence for options is: +# 1. command-line arguments +# 2. environment variables +# 3. defaults + +# Source the config file if it exists +DEFAULT_DOCKCROSS_CONFIG=~/.dockcross +FINAL_CONFIG=${ARG_CONFIG-${DOCKCROSS_CONFIG-$DEFAULT_DOCKCROSS_CONFIG}} + +[[ -f "$FINAL_CONFIG" ]] && source "$FINAL_CONFIG" + +# Set the docker image +FINAL_IMAGE=${ARG_IMAGE-${DOCKCROSS_IMAGE-$DEFAULT_DOCKCROSS_IMAGE}} + +# Handle special update command +if [ "$special_update_command" != "" ]; then + case $special_update_command in + + update) + command:update + exit $? + ;; + + update-image) + command:update-image + exit $? + ;; + + update-script) + command:update-script + exit $? + ;; + + esac +fi + +# Set the docker run extra args (if any) +FINAL_ARGS=${ARG_ARGS-${DOCKCROSS_ARGS}} + +# Bash on Ubuntu on Windows +UBUNTU_ON_WINDOWS=$([ -e /proc/version ] && grep -l Microsoft /proc/version || echo "") +# MSYS, Git Bash, etc. +MSYS=$([ -e /proc/version ] && grep -l MINGW /proc/version || echo "") +# CYGWIN +CYGWIN=$([ -e /proc/version ] && grep -l CYGWIN /proc/version || echo "") + +if [ -z "$UBUNTU_ON_WINDOWS" -a -z "$MSYS" -a "$OCI_EXE" != "podman" ]; then + USER_IDS=(-e BUILDER_UID="$( id -u )" -e BUILDER_GID="$( id -g )" -e BUILDER_USER="$( id -un )" -e BUILDER_GROUP="$( id -gn )") +fi + +# Change the PWD when working in Docker on Windows +if [ -n "$UBUNTU_ON_WINDOWS" ]; then + WSL_ROOT="/mnt/" + CFG_FILE=/etc/wsl.conf + if [ -f "$CFG_FILE" ]; then + CFG_CONTENT=$(cat $CFG_FILE | sed -r '/[^=]+=[^=]+/!d' | sed -r 's/\s+=\s/=/g') + eval "$CFG_CONTENT" + if [ -n "$root" ]; then + WSL_ROOT=$root + fi + fi + HOST_PWD=`pwd -P` + HOST_PWD=${HOST_PWD/$WSL_ROOT//} +elif [ -n "$MSYS" ]; then + HOST_PWD=$PWD + HOST_PWD=${HOST_PWD/\//} + HOST_PWD=${HOST_PWD/\//:\/} +elif [ -n "$CYGWIN" ]; then + for f in pwd readlink cygpath ; do + test -n "$(type "${f}" )" || { echo >&2 "Missing functionality (${f}) (in cygwin)." ; exit 1 ; } ; + done ; + HOST_PWD="$( cygpath -w "$( readlink -f "$( pwd ;)" ; )" ; )" ; +else + HOST_PWD=$PWD + [ -L $HOST_PWD ] && HOST_PWD=$(readlink $HOST_PWD) +fi + +# Mount Additional Volumes +if [ -z "$SSH_DIR" ]; then + SSH_DIR="$HOME/.ssh" +fi + +HOST_VOLUMES= +if [ -e "$SSH_DIR" -a -z "$MSYS" ]; then + if test -n "${CYGWIN}" ; then + HOST_VOLUMES+="-v $(cygpath -w ${SSH_DIR} ; ):/home/$(id -un)/.ssh" ; + else + HOST_VOLUMES+="-v $SSH_DIR:/home/$(id -un)/.ssh" ; + fi ; +fi + +#------------------------------------------------------------------------------ +# Now, finally, run the command in a container +# +TTY_ARGS= +tty -s && [ -z "$MSYS" ] && TTY_ARGS=-ti +CONTAINER_NAME=dockcross_$RANDOM +$OCI_EXE run $TTY_ARGS --name $CONTAINER_NAME \ + -v "$HOST_PWD":/work \ + $HOST_VOLUMES \ + "${USER_IDS[@]}" \ + $FINAL_ARGS \ + $FINAL_IMAGE "$@" +run_exit_code=$? + +# Attempt to delete container +rm_output=$($OCI_EXE rm -f $CONTAINER_NAME 2>&1) +rm_exit_code=$? +if [[ $rm_exit_code != 0 ]]; then + if [[ "$CIRCLECI" == "true" ]] && [[ $rm_output == *"Driver btrfs failed to remove"* ]]; then + : # Ignore error because of https://circleci.com/docs/docker-btrfs-error/ + else + echo "$rm_output" + exit $rm_exit_code + fi +fi + +exit $run_exit_code + +################################################################################ +# +# This image is not intended to be run manually. +# +# To create a dockcross helper script for the +# dockcross/linux-arm64-lts:20230601-c2f5366 image, run: +# +# docker run --rm dockcross/linux-arm64-lts:20230601-c2f5366 > dockcross-linux-arm64-lts-20230601-c2f5366 +# chmod +x dockcross-linux-arm64-lts-20230601-c2f5366 +# +# You may then wish to move the dockcross script to your PATH. +# +################################################################################ diff --git a/.github/dockcross/dockcross-manylinux2014-x64 b/.github/dockcross/dockcross-manylinux2014-x64 new file mode 100755 index 00000000..5fc98484 --- /dev/null +++ b/.github/dockcross/dockcross-manylinux2014-x64 @@ -0,0 +1,278 @@ +#!/usr/bin/env bash + +DEFAULT_DOCKCROSS_IMAGE=dockcross/manylinux2014-x64:20230601-c2f5366 + +#------------------------------------------------------------------------------ +# Helpers +# +err() { + echo -e >&2 "ERROR: $*\n" +} + +die() { + err "$*" + exit 1 +} + +has() { + # eg. has command update + local kind=$1 + local name=$2 + + type -t $kind:$name | grep -q function +} + +# If OCI_EXE is not already set, search for a container executor (OCI stands for "Open Container Initiative") +if [ -z "$OCI_EXE" ]; then + if which podman >/dev/null 2>/dev/null; then + OCI_EXE=podman + elif which docker >/dev/null 2>/dev/null; then + OCI_EXE=docker + else + die "Cannot find a container executor. Search for docker and podman." + fi +fi + +#------------------------------------------------------------------------------ +# Command handlers +# +command:update-image() { + $OCI_EXE pull $FINAL_IMAGE +} + +help:update-image() { + echo "Pull the latest $FINAL_IMAGE ." +} + +command:update-script() { + if cmp -s <( $OCI_EXE run --rm $FINAL_IMAGE ) $0; then + echo "$0 is up to date" + else + echo -n "Updating $0 ... " + $OCI_EXE run --rm $FINAL_IMAGE > $0 && echo ok + fi +} + +help:update-script() { + echo "Update $0 from $FINAL_IMAGE ." +} + +command:update() { + command:update-image + command:update-script +} + +help:update() { + echo "Pull the latest $FINAL_IMAGE, and then update $0 from that." +} + +command:help() { + if [[ $# != 0 ]]; then + if ! has command $1; then + err \"$1\" is not an dockcross command + command:help + elif ! has help $1; then + err No help found for \"$1\" + else + help:$1 + fi + else + cat >&2 < +ENDHELP + exit 1 + fi +} + +#------------------------------------------------------------------------------ +# Option processing +# +special_update_command='' +while [[ $# != 0 ]]; do + case $1 in + + --) + shift + break + ;; + + --args|-a) + ARG_ARGS="$2" + shift 2 + ;; + + --config|-c) + ARG_CONFIG="$2" + shift 2 + ;; + + --image|-i) + ARG_IMAGE="$2" + shift 2 + ;; + update|update-image|update-script) + special_update_command=$1 + break + ;; + -*) + err Unknown option \"$1\" + command:help + exit + ;; + + *) + break + ;; + + esac +done + +# The precedence for options is: +# 1. command-line arguments +# 2. environment variables +# 3. defaults + +# Source the config file if it exists +DEFAULT_DOCKCROSS_CONFIG=~/.dockcross +FINAL_CONFIG=${ARG_CONFIG-${DOCKCROSS_CONFIG-$DEFAULT_DOCKCROSS_CONFIG}} + +[[ -f "$FINAL_CONFIG" ]] && source "$FINAL_CONFIG" + +# Set the docker image +FINAL_IMAGE=${ARG_IMAGE-${DOCKCROSS_IMAGE-$DEFAULT_DOCKCROSS_IMAGE}} + +# Handle special update command +if [ "$special_update_command" != "" ]; then + case $special_update_command in + + update) + command:update + exit $? + ;; + + update-image) + command:update-image + exit $? + ;; + + update-script) + command:update-script + exit $? + ;; + + esac +fi + +# Set the docker run extra args (if any) +FINAL_ARGS=${ARG_ARGS-${DOCKCROSS_ARGS}} + +# Bash on Ubuntu on Windows +UBUNTU_ON_WINDOWS=$([ -e /proc/version ] && grep -l Microsoft /proc/version || echo "") +# MSYS, Git Bash, etc. +MSYS=$([ -e /proc/version ] && grep -l MINGW /proc/version || echo "") +# CYGWIN +CYGWIN=$([ -e /proc/version ] && grep -l CYGWIN /proc/version || echo "") + +if [ -z "$UBUNTU_ON_WINDOWS" -a -z "$MSYS" -a "$OCI_EXE" != "podman" ]; then + USER_IDS=(-e BUILDER_UID="$( id -u )" -e BUILDER_GID="$( id -g )" -e BUILDER_USER="$( id -un )" -e BUILDER_GROUP="$( id -gn )") +fi + +# Change the PWD when working in Docker on Windows +if [ -n "$UBUNTU_ON_WINDOWS" ]; then + WSL_ROOT="/mnt/" + CFG_FILE=/etc/wsl.conf + if [ -f "$CFG_FILE" ]; then + CFG_CONTENT=$(cat $CFG_FILE | sed -r '/[^=]+=[^=]+/!d' | sed -r 's/\s+=\s/=/g') + eval "$CFG_CONTENT" + if [ -n "$root" ]; then + WSL_ROOT=$root + fi + fi + HOST_PWD=`pwd -P` + HOST_PWD=${HOST_PWD/$WSL_ROOT//} +elif [ -n "$MSYS" ]; then + HOST_PWD=$PWD + HOST_PWD=${HOST_PWD/\//} + HOST_PWD=${HOST_PWD/\//:\/} +elif [ -n "$CYGWIN" ]; then + for f in pwd readlink cygpath ; do + test -n "$(type "${f}" )" || { echo >&2 "Missing functionality (${f}) (in cygwin)." ; exit 1 ; } ; + done ; + HOST_PWD="$( cygpath -w "$( readlink -f "$( pwd ;)" ; )" ; )" ; +else + HOST_PWD=$PWD + [ -L $HOST_PWD ] && HOST_PWD=$(readlink $HOST_PWD) +fi + +# Mount Additional Volumes +if [ -z "$SSH_DIR" ]; then + SSH_DIR="$HOME/.ssh" +fi + +HOST_VOLUMES= +if [ -e "$SSH_DIR" -a -z "$MSYS" ]; then + if test -n "${CYGWIN}" ; then + HOST_VOLUMES+="-v $(cygpath -w ${SSH_DIR} ; ):/home/$(id -un)/.ssh" ; + else + HOST_VOLUMES+="-v $SSH_DIR:/home/$(id -un)/.ssh" ; + fi ; +fi + +#------------------------------------------------------------------------------ +# Now, finally, run the command in a container +# +TTY_ARGS= +tty -s && [ -z "$MSYS" ] && TTY_ARGS=-ti +CONTAINER_NAME=dockcross_$RANDOM +$OCI_EXE run $TTY_ARGS --name $CONTAINER_NAME \ + -v "$HOST_PWD":/work \ + $HOST_VOLUMES \ + "${USER_IDS[@]}" \ + $FINAL_ARGS \ + $FINAL_IMAGE "$@" +run_exit_code=$? + +# Attempt to delete container +rm_output=$($OCI_EXE rm -f $CONTAINER_NAME 2>&1) +rm_exit_code=$? +if [[ $rm_exit_code != 0 ]]; then + if [[ "$CIRCLECI" == "true" ]] && [[ $rm_output == *"Driver btrfs failed to remove"* ]]; then + : # Ignore error because of https://circleci.com/docs/docker-btrfs-error/ + else + echo "$rm_output" + exit $rm_exit_code + fi +fi + +exit $run_exit_code + +################################################################################ +# +# This image is not intended to be run manually. +# +# To create a dockcross helper script for the +# dockcross/manylinux2014-x64:20230601-c2f5366 image, run: +# +# docker run --rm dockcross/manylinux2014-x64:20230601-c2f5366 > dockcross-manylinux2014-x64-20230601-c2f5366 +# chmod +x dockcross-manylinux2014-x64-20230601-c2f5366 +# +# You may then wish to move the dockcross script to your PATH. +# +################################################################################ diff --git a/.github/dockcross/dockcross-manylinux_2_28-x64 b/.github/dockcross/dockcross-manylinux_2_28-x64 new file mode 100755 index 00000000..c363e9fa --- /dev/null +++ b/.github/dockcross/dockcross-manylinux_2_28-x64 @@ -0,0 +1,278 @@ +#!/usr/bin/env bash + +DEFAULT_DOCKCROSS_IMAGE=dockcross/manylinux_2_28-x64:20240812-60fa1b0 + +#------------------------------------------------------------------------------ +# Helpers +# +err() { + echo -e >&2 "ERROR: $*\n" +} + +die() { + err "$*" + exit 1 +} + +has() { + # eg. has command update + local kind=$1 + local name=$2 + + type -t $kind:$name | grep -q function +} + +# If OCI_EXE is not already set, search for a container executor (OCI stands for "Open Container Initiative") +if [ -z "$OCI_EXE" ]; then + if which podman >/dev/null 2>/dev/null; then + OCI_EXE=podman + elif which docker >/dev/null 2>/dev/null; then + OCI_EXE=docker + else + die "Cannot find a container executor. Search for docker and podman." + fi +fi + +#------------------------------------------------------------------------------ +# Command handlers +# +command:update-image() { + $OCI_EXE pull $FINAL_IMAGE +} + +help:update-image() { + echo "Pull the latest $FINAL_IMAGE ." +} + +command:update-script() { + if cmp -s <( $OCI_EXE run --rm $FINAL_IMAGE ) $0; then + echo "$0 is up to date" + else + echo -n "Updating $0 ... " + $OCI_EXE run --rm $FINAL_IMAGE > $0 && echo ok + fi +} + +help:update-script() { + echo "Update $0 from $FINAL_IMAGE ." +} + +command:update() { + command:update-image + command:update-script +} + +help:update() { + echo "Pull the latest $FINAL_IMAGE, and then update $0 from that." +} + +command:help() { + if [[ $# != 0 ]]; then + if ! has command $1; then + err \"$1\" is not an dockcross command + command:help + elif ! has help $1; then + err No help found for \"$1\" + else + help:$1 + fi + else + cat >&2 < +ENDHELP + exit 1 + fi +} + +#------------------------------------------------------------------------------ +# Option processing +# +special_update_command='' +while [[ $# != 0 ]]; do + case $1 in + + --) + shift + break + ;; + + --args|-a) + ARG_ARGS="$2" + shift 2 + ;; + + --config|-c) + ARG_CONFIG="$2" + shift 2 + ;; + + --image|-i) + ARG_IMAGE="$2" + shift 2 + ;; + update|update-image|update-script) + special_update_command=$1 + break + ;; + -*) + err Unknown option \"$1\" + command:help + exit + ;; + + *) + break + ;; + + esac +done + +# The precedence for options is: +# 1. command-line arguments +# 2. environment variables +# 3. defaults + +# Source the config file if it exists +DEFAULT_DOCKCROSS_CONFIG=~/.dockcross +FINAL_CONFIG=${ARG_CONFIG-${DOCKCROSS_CONFIG-$DEFAULT_DOCKCROSS_CONFIG}} + +[[ -f "$FINAL_CONFIG" ]] && source "$FINAL_CONFIG" + +# Set the docker image +FINAL_IMAGE=${ARG_IMAGE-${DOCKCROSS_IMAGE-$DEFAULT_DOCKCROSS_IMAGE}} + +# Handle special update command +if [ "$special_update_command" != "" ]; then + case $special_update_command in + + update) + command:update + exit $? + ;; + + update-image) + command:update-image + exit $? + ;; + + update-script) + command:update-script + exit $? + ;; + + esac +fi + +# Set the docker run extra args (if any) +FINAL_ARGS=${ARG_ARGS-${DOCKCROSS_ARGS}} + +# Bash on Ubuntu on Windows +UBUNTU_ON_WINDOWS=$([ -e /proc/version ] && grep -l Microsoft /proc/version || echo "") +# MSYS, Git Bash, etc. +MSYS=$([ -e /proc/version ] && grep -l MINGW /proc/version || echo "") +# CYGWIN +CYGWIN=$([ -e /proc/version ] && grep -l CYGWIN /proc/version || echo "") + +if [ -z "$UBUNTU_ON_WINDOWS" -a -z "$MSYS" -a "$OCI_EXE" != "podman" ]; then + USER_IDS=(-e BUILDER_UID="$( id -u )" -e BUILDER_GID="$( id -g )" -e BUILDER_USER="$( id -un )" -e BUILDER_GROUP="$( id -gn )") +fi + +# Change the PWD when working in Docker on Windows +if [ -n "$UBUNTU_ON_WINDOWS" ]; then + WSL_ROOT="/mnt/" + CFG_FILE=/etc/wsl.conf + if [ -f "$CFG_FILE" ]; then + CFG_CONTENT=$(cat $CFG_FILE | sed -r '/[^=]+=[^=]+/!d' | sed -r 's/\s+=\s/=/g') + eval "$CFG_CONTENT" + if [ -n "$root" ]; then + WSL_ROOT=$root + fi + fi + HOST_PWD=`pwd -P` + HOST_PWD=${HOST_PWD/$WSL_ROOT//} +elif [ -n "$MSYS" ]; then + HOST_PWD=$PWD + HOST_PWD=${HOST_PWD/\//} + HOST_PWD=${HOST_PWD/\//:\/} +elif [ -n "$CYGWIN" ]; then + for f in pwd readlink cygpath ; do + test -n "$(type "${f}" )" || { echo >&2 "Missing functionality (${f}) (in cygwin)." ; exit 1 ; } ; + done ; + HOST_PWD="$( cygpath -w "$( readlink -f "$( pwd ;)" ; )" ; )" ; +else + HOST_PWD=$PWD + [ -L $HOST_PWD ] && HOST_PWD=$(readlink $HOST_PWD) +fi + +# Mount Additional Volumes +if [ -z "$SSH_DIR" ]; then + SSH_DIR="$HOME/.ssh" +fi + +HOST_VOLUMES= +if [ -e "$SSH_DIR" -a -z "$MSYS" ]; then + if test -n "${CYGWIN}" ; then + HOST_VOLUMES+="-v $(cygpath -w ${SSH_DIR} ; ):/home/$(id -un)/.ssh" ; + else + HOST_VOLUMES+="-v $SSH_DIR:/home/$(id -un)/.ssh" ; + fi ; +fi + +#------------------------------------------------------------------------------ +# Now, finally, run the command in a container +# +TTY_ARGS= +tty -s && [ -z "$MSYS" ] && TTY_ARGS=-ti +CONTAINER_NAME=dockcross_$RANDOM +$OCI_EXE run $TTY_ARGS --name $CONTAINER_NAME \ + -v "$HOST_PWD":/work \ + $HOST_VOLUMES \ + "${USER_IDS[@]}" \ + $FINAL_ARGS \ + $FINAL_IMAGE "$@" +run_exit_code=$? + +# Attempt to delete container +rm_output=$($OCI_EXE rm -f $CONTAINER_NAME 2>&1) +rm_exit_code=$? +if [[ $rm_exit_code != 0 ]]; then + if [[ "$CIRCLECI" == "true" ]] && [[ $rm_output == *"Driver btrfs failed to remove"* ]]; then + : # Ignore error because of https://circleci.com/docs/docker-btrfs-error/ + else + echo "$rm_output" + exit $rm_exit_code + fi +fi + +exit $run_exit_code + +################################################################################ +# +# This image is not intended to be run manually. +# +# To create a dockcross helper script for the +# dockcross/manylinux_2_28-x64:20240812-60fa1b0 image, run: +# +# docker run --rm dockcross/manylinux_2_28-x64:20240812-60fa1b0 > dockcross-manylinux_2_28-x64-20240812-60fa1b0 +# chmod +x dockcross-manylinux_2_28-x64-20240812-60fa1b0 +# +# You may then wish to move the dockcross script to your PATH. +# +################################################################################ diff --git a/.github/dockcross/update.sh b/.github/dockcross/update.sh new file mode 100755 index 00000000..5898ac80 --- /dev/null +++ b/.github/dockcross/update.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# This script prints the commands to upgrade the docker cross compilation scripts +docker run --rm dockcross/manylinux2014-x64 > ./dockcross-manylinux2014-x64 +docker run --rm dockcross/manylinux_2_28-x64 > ./dockcross-manylinux_2_28-x64 +docker run --rm dockcross/manylinux2014-x86 > ./dockcross-manylinux2014-x86 +docker run --rm dockcross/linux-arm64-lts > ./dockcross-linux-arm64-lts +docker run --rm dockcross/android-arm > ./dockcross-android-arm +docker run --rm dockcross/android-arm64 > ./dockcross-android-arm64 +docker run --rm dockcross/android-x86 > ./dockcross-android-x86 +docker run --rm dockcross/android-x86_64 > ./dockcross-android-x86_64 +chmod +x ./dockcross-* diff --git a/.github/include/unix/jni.h b/.github/include/unix/jni.h new file mode 100644 index 00000000..c85da1bc --- /dev/null +++ b/.github/include/unix/jni.h @@ -0,0 +1,2001 @@ +/* + * Copyright (c) 1996, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * We used part of Netscape's Java Runtime Interface (JRI) as the starting + * point of our design and implementation. + */ + +/****************************************************************************** + * Java Runtime Interface + * Copyright (c) 1996 Netscape Communications Corporation. All rights reserved. + *****************************************************************************/ + +#ifndef _JAVASOFT_JNI_H_ +#define _JAVASOFT_JNI_H_ + +#include +#include + +/* jni_md.h contains the machine-dependent typedefs for jbyte, jint + and jlong */ + +#include "jni_md.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * JNI Types + */ + +#ifndef JNI_TYPES_ALREADY_DEFINED_IN_JNI_MD_H + +typedef unsigned char jboolean; +typedef unsigned short jchar; +typedef short jshort; +typedef float jfloat; +typedef double jdouble; + +typedef jint jsize; + +#ifdef __cplusplus + +class _jobject {}; +class _jclass : public _jobject {}; +class _jthrowable : public _jobject {}; +class _jstring : public _jobject {}; +class _jarray : public _jobject {}; +class _jbooleanArray : public _jarray {}; +class _jbyteArray : public _jarray {}; +class _jcharArray : public _jarray {}; +class _jshortArray : public _jarray {}; +class _jintArray : public _jarray {}; +class _jlongArray : public _jarray {}; +class _jfloatArray : public _jarray {}; +class _jdoubleArray : public _jarray {}; +class _jobjectArray : public _jarray {}; + +typedef _jobject *jobject; +typedef _jclass *jclass; +typedef _jthrowable *jthrowable; +typedef _jstring *jstring; +typedef _jarray *jarray; +typedef _jbooleanArray *jbooleanArray; +typedef _jbyteArray *jbyteArray; +typedef _jcharArray *jcharArray; +typedef _jshortArray *jshortArray; +typedef _jintArray *jintArray; +typedef _jlongArray *jlongArray; +typedef _jfloatArray *jfloatArray; +typedef _jdoubleArray *jdoubleArray; +typedef _jobjectArray *jobjectArray; + +#else + +struct _jobject; + +typedef struct _jobject *jobject; +typedef jobject jclass; +typedef jobject jthrowable; +typedef jobject jstring; +typedef jobject jarray; +typedef jarray jbooleanArray; +typedef jarray jbyteArray; +typedef jarray jcharArray; +typedef jarray jshortArray; +typedef jarray jintArray; +typedef jarray jlongArray; +typedef jarray jfloatArray; +typedef jarray jdoubleArray; +typedef jarray jobjectArray; + +#endif + +typedef jobject jweak; + +typedef union jvalue { + jboolean z; + jbyte b; + jchar c; + jshort s; + jint i; + jlong j; + jfloat f; + jdouble d; + jobject l; +} jvalue; + +struct _jfieldID; +typedef struct _jfieldID *jfieldID; + +struct _jmethodID; +typedef struct _jmethodID *jmethodID; + +/* Return values from jobjectRefType */ +typedef enum _jobjectType { + JNIInvalidRefType = 0, + JNILocalRefType = 1, + JNIGlobalRefType = 2, + JNIWeakGlobalRefType = 3 +} jobjectRefType; + + +#endif /* JNI_TYPES_ALREADY_DEFINED_IN_JNI_MD_H */ + +/* + * jboolean constants + */ + +#define JNI_FALSE 0 +#define JNI_TRUE 1 + +/* + * possible return values for JNI functions. + */ + +#define JNI_OK 0 /* success */ +#define JNI_ERR (-1) /* unknown error */ +#define JNI_EDETACHED (-2) /* thread detached from the VM */ +#define JNI_EVERSION (-3) /* JNI version error */ +#define JNI_ENOMEM (-4) /* not enough memory */ +#define JNI_EEXIST (-5) /* VM already created */ +#define JNI_EINVAL (-6) /* invalid arguments */ + +/* + * used in ReleaseScalarArrayElements + */ + +#define JNI_COMMIT 1 +#define JNI_ABORT 2 + +/* + * used in RegisterNatives to describe native method name, signature, + * and function pointer. + */ + +typedef struct { + char *name; + char *signature; + void *fnPtr; +} JNINativeMethod; + +/* + * JNI Native Method Interface. + */ + +struct JNINativeInterface_; + +struct JNIEnv_; + +#ifdef __cplusplus +typedef JNIEnv_ JNIEnv; +#else +typedef const struct JNINativeInterface_ *JNIEnv; +#endif + +/* + * JNI Invocation Interface. + */ + +struct JNIInvokeInterface_; + +struct JavaVM_; + +#ifdef __cplusplus +typedef JavaVM_ JavaVM; +#else +typedef const struct JNIInvokeInterface_ *JavaVM; +#endif + +struct JNINativeInterface_ { + void *reserved0; + void *reserved1; + void *reserved2; + + void *reserved3; + jint (JNICALL *GetVersion)(JNIEnv *env); + + jclass (JNICALL *DefineClass) + (JNIEnv *env, const char *name, jobject loader, const jbyte *buf, + jsize len); + jclass (JNICALL *FindClass) + (JNIEnv *env, const char *name); + + jmethodID (JNICALL *FromReflectedMethod) + (JNIEnv *env, jobject method); + jfieldID (JNICALL *FromReflectedField) + (JNIEnv *env, jobject field); + + jobject (JNICALL *ToReflectedMethod) + (JNIEnv *env, jclass cls, jmethodID methodID, jboolean isStatic); + + jclass (JNICALL *GetSuperclass) + (JNIEnv *env, jclass sub); + jboolean (JNICALL *IsAssignableFrom) + (JNIEnv *env, jclass sub, jclass sup); + + jobject (JNICALL *ToReflectedField) + (JNIEnv *env, jclass cls, jfieldID fieldID, jboolean isStatic); + + jint (JNICALL *Throw) + (JNIEnv *env, jthrowable obj); + jint (JNICALL *ThrowNew) + (JNIEnv *env, jclass clazz, const char *msg); + jthrowable (JNICALL *ExceptionOccurred) + (JNIEnv *env); + void (JNICALL *ExceptionDescribe) + (JNIEnv *env); + void (JNICALL *ExceptionClear) + (JNIEnv *env); + void (JNICALL *FatalError) + (JNIEnv *env, const char *msg); + + jint (JNICALL *PushLocalFrame) + (JNIEnv *env, jint capacity); + jobject (JNICALL *PopLocalFrame) + (JNIEnv *env, jobject result); + + jobject (JNICALL *NewGlobalRef) + (JNIEnv *env, jobject lobj); + void (JNICALL *DeleteGlobalRef) + (JNIEnv *env, jobject gref); + void (JNICALL *DeleteLocalRef) + (JNIEnv *env, jobject obj); + jboolean (JNICALL *IsSameObject) + (JNIEnv *env, jobject obj1, jobject obj2); + jobject (JNICALL *NewLocalRef) + (JNIEnv *env, jobject ref); + jint (JNICALL *EnsureLocalCapacity) + (JNIEnv *env, jint capacity); + + jobject (JNICALL *AllocObject) + (JNIEnv *env, jclass clazz); + jobject (JNICALL *NewObject) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jobject (JNICALL *NewObjectV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jobject (JNICALL *NewObjectA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jclass (JNICALL *GetObjectClass) + (JNIEnv *env, jobject obj); + jboolean (JNICALL *IsInstanceOf) + (JNIEnv *env, jobject obj, jclass clazz); + + jmethodID (JNICALL *GetMethodID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + + jobject (JNICALL *CallObjectMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jobject (JNICALL *CallObjectMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jobject (JNICALL *CallObjectMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue * args); + + jboolean (JNICALL *CallBooleanMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jboolean (JNICALL *CallBooleanMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jboolean (JNICALL *CallBooleanMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue * args); + + jbyte (JNICALL *CallByteMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jbyte (JNICALL *CallByteMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jbyte (JNICALL *CallByteMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jchar (JNICALL *CallCharMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jchar (JNICALL *CallCharMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jchar (JNICALL *CallCharMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jshort (JNICALL *CallShortMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jshort (JNICALL *CallShortMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jshort (JNICALL *CallShortMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jint (JNICALL *CallIntMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jint (JNICALL *CallIntMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jint (JNICALL *CallIntMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jlong (JNICALL *CallLongMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jlong (JNICALL *CallLongMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jlong (JNICALL *CallLongMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jfloat (JNICALL *CallFloatMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jfloat (JNICALL *CallFloatMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jfloat (JNICALL *CallFloatMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jdouble (JNICALL *CallDoubleMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jdouble (JNICALL *CallDoubleMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jdouble (JNICALL *CallDoubleMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + void (JNICALL *CallVoidMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + void (JNICALL *CallVoidMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + void (JNICALL *CallVoidMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue * args); + + jobject (JNICALL *CallNonvirtualObjectMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jobject (JNICALL *CallNonvirtualObjectMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jobject (JNICALL *CallNonvirtualObjectMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue * args); + + jboolean (JNICALL *CallNonvirtualBooleanMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jboolean (JNICALL *CallNonvirtualBooleanMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jboolean (JNICALL *CallNonvirtualBooleanMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue * args); + + jbyte (JNICALL *CallNonvirtualByteMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jbyte (JNICALL *CallNonvirtualByteMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jbyte (JNICALL *CallNonvirtualByteMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jchar (JNICALL *CallNonvirtualCharMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jchar (JNICALL *CallNonvirtualCharMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jchar (JNICALL *CallNonvirtualCharMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jshort (JNICALL *CallNonvirtualShortMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jshort (JNICALL *CallNonvirtualShortMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jshort (JNICALL *CallNonvirtualShortMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jint (JNICALL *CallNonvirtualIntMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jint (JNICALL *CallNonvirtualIntMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jint (JNICALL *CallNonvirtualIntMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jlong (JNICALL *CallNonvirtualLongMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jlong (JNICALL *CallNonvirtualLongMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jlong (JNICALL *CallNonvirtualLongMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jfloat (JNICALL *CallNonvirtualFloatMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jfloat (JNICALL *CallNonvirtualFloatMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jfloat (JNICALL *CallNonvirtualFloatMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jdouble (JNICALL *CallNonvirtualDoubleMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jdouble (JNICALL *CallNonvirtualDoubleMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jdouble (JNICALL *CallNonvirtualDoubleMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + void (JNICALL *CallNonvirtualVoidMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + void (JNICALL *CallNonvirtualVoidMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + void (JNICALL *CallNonvirtualVoidMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue * args); + + jfieldID (JNICALL *GetFieldID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + + jobject (JNICALL *GetObjectField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jboolean (JNICALL *GetBooleanField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jbyte (JNICALL *GetByteField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jchar (JNICALL *GetCharField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jshort (JNICALL *GetShortField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jint (JNICALL *GetIntField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jlong (JNICALL *GetLongField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jfloat (JNICALL *GetFloatField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jdouble (JNICALL *GetDoubleField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + + void (JNICALL *SetObjectField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jobject val); + void (JNICALL *SetBooleanField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jboolean val); + void (JNICALL *SetByteField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jbyte val); + void (JNICALL *SetCharField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jchar val); + void (JNICALL *SetShortField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jshort val); + void (JNICALL *SetIntField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jint val); + void (JNICALL *SetLongField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jlong val); + void (JNICALL *SetFloatField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jfloat val); + void (JNICALL *SetDoubleField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jdouble val); + + jmethodID (JNICALL *GetStaticMethodID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + + jobject (JNICALL *CallStaticObjectMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jobject (JNICALL *CallStaticObjectMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jobject (JNICALL *CallStaticObjectMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jboolean (JNICALL *CallStaticBooleanMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jboolean (JNICALL *CallStaticBooleanMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jboolean (JNICALL *CallStaticBooleanMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jbyte (JNICALL *CallStaticByteMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jbyte (JNICALL *CallStaticByteMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jbyte (JNICALL *CallStaticByteMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jchar (JNICALL *CallStaticCharMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jchar (JNICALL *CallStaticCharMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jchar (JNICALL *CallStaticCharMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jshort (JNICALL *CallStaticShortMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jshort (JNICALL *CallStaticShortMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jshort (JNICALL *CallStaticShortMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jint (JNICALL *CallStaticIntMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jint (JNICALL *CallStaticIntMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jint (JNICALL *CallStaticIntMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jlong (JNICALL *CallStaticLongMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jlong (JNICALL *CallStaticLongMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jlong (JNICALL *CallStaticLongMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jfloat (JNICALL *CallStaticFloatMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jfloat (JNICALL *CallStaticFloatMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jfloat (JNICALL *CallStaticFloatMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jdouble (JNICALL *CallStaticDoubleMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jdouble (JNICALL *CallStaticDoubleMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jdouble (JNICALL *CallStaticDoubleMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + void (JNICALL *CallStaticVoidMethod) + (JNIEnv *env, jclass cls, jmethodID methodID, ...); + void (JNICALL *CallStaticVoidMethodV) + (JNIEnv *env, jclass cls, jmethodID methodID, va_list args); + void (JNICALL *CallStaticVoidMethodA) + (JNIEnv *env, jclass cls, jmethodID methodID, const jvalue * args); + + jfieldID (JNICALL *GetStaticFieldID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + jobject (JNICALL *GetStaticObjectField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jboolean (JNICALL *GetStaticBooleanField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jbyte (JNICALL *GetStaticByteField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jchar (JNICALL *GetStaticCharField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jshort (JNICALL *GetStaticShortField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jint (JNICALL *GetStaticIntField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jlong (JNICALL *GetStaticLongField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jfloat (JNICALL *GetStaticFloatField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jdouble (JNICALL *GetStaticDoubleField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + + void (JNICALL *SetStaticObjectField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jobject value); + void (JNICALL *SetStaticBooleanField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jboolean value); + void (JNICALL *SetStaticByteField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jbyte value); + void (JNICALL *SetStaticCharField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jchar value); + void (JNICALL *SetStaticShortField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jshort value); + void (JNICALL *SetStaticIntField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jint value); + void (JNICALL *SetStaticLongField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jlong value); + void (JNICALL *SetStaticFloatField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jfloat value); + void (JNICALL *SetStaticDoubleField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jdouble value); + + jstring (JNICALL *NewString) + (JNIEnv *env, const jchar *unicode, jsize len); + jsize (JNICALL *GetStringLength) + (JNIEnv *env, jstring str); + const jchar *(JNICALL *GetStringChars) + (JNIEnv *env, jstring str, jboolean *isCopy); + void (JNICALL *ReleaseStringChars) + (JNIEnv *env, jstring str, const jchar *chars); + + jstring (JNICALL *NewStringUTF) + (JNIEnv *env, const char *utf); + jsize (JNICALL *GetStringUTFLength) + (JNIEnv *env, jstring str); + const char* (JNICALL *GetStringUTFChars) + (JNIEnv *env, jstring str, jboolean *isCopy); + void (JNICALL *ReleaseStringUTFChars) + (JNIEnv *env, jstring str, const char* chars); + + + jsize (JNICALL *GetArrayLength) + (JNIEnv *env, jarray array); + + jobjectArray (JNICALL *NewObjectArray) + (JNIEnv *env, jsize len, jclass clazz, jobject init); + jobject (JNICALL *GetObjectArrayElement) + (JNIEnv *env, jobjectArray array, jsize index); + void (JNICALL *SetObjectArrayElement) + (JNIEnv *env, jobjectArray array, jsize index, jobject val); + + jbooleanArray (JNICALL *NewBooleanArray) + (JNIEnv *env, jsize len); + jbyteArray (JNICALL *NewByteArray) + (JNIEnv *env, jsize len); + jcharArray (JNICALL *NewCharArray) + (JNIEnv *env, jsize len); + jshortArray (JNICALL *NewShortArray) + (JNIEnv *env, jsize len); + jintArray (JNICALL *NewIntArray) + (JNIEnv *env, jsize len); + jlongArray (JNICALL *NewLongArray) + (JNIEnv *env, jsize len); + jfloatArray (JNICALL *NewFloatArray) + (JNIEnv *env, jsize len); + jdoubleArray (JNICALL *NewDoubleArray) + (JNIEnv *env, jsize len); + + jboolean * (JNICALL *GetBooleanArrayElements) + (JNIEnv *env, jbooleanArray array, jboolean *isCopy); + jbyte * (JNICALL *GetByteArrayElements) + (JNIEnv *env, jbyteArray array, jboolean *isCopy); + jchar * (JNICALL *GetCharArrayElements) + (JNIEnv *env, jcharArray array, jboolean *isCopy); + jshort * (JNICALL *GetShortArrayElements) + (JNIEnv *env, jshortArray array, jboolean *isCopy); + jint * (JNICALL *GetIntArrayElements) + (JNIEnv *env, jintArray array, jboolean *isCopy); + jlong * (JNICALL *GetLongArrayElements) + (JNIEnv *env, jlongArray array, jboolean *isCopy); + jfloat * (JNICALL *GetFloatArrayElements) + (JNIEnv *env, jfloatArray array, jboolean *isCopy); + jdouble * (JNICALL *GetDoubleArrayElements) + (JNIEnv *env, jdoubleArray array, jboolean *isCopy); + + void (JNICALL *ReleaseBooleanArrayElements) + (JNIEnv *env, jbooleanArray array, jboolean *elems, jint mode); + void (JNICALL *ReleaseByteArrayElements) + (JNIEnv *env, jbyteArray array, jbyte *elems, jint mode); + void (JNICALL *ReleaseCharArrayElements) + (JNIEnv *env, jcharArray array, jchar *elems, jint mode); + void (JNICALL *ReleaseShortArrayElements) + (JNIEnv *env, jshortArray array, jshort *elems, jint mode); + void (JNICALL *ReleaseIntArrayElements) + (JNIEnv *env, jintArray array, jint *elems, jint mode); + void (JNICALL *ReleaseLongArrayElements) + (JNIEnv *env, jlongArray array, jlong *elems, jint mode); + void (JNICALL *ReleaseFloatArrayElements) + (JNIEnv *env, jfloatArray array, jfloat *elems, jint mode); + void (JNICALL *ReleaseDoubleArrayElements) + (JNIEnv *env, jdoubleArray array, jdouble *elems, jint mode); + + void (JNICALL *GetBooleanArrayRegion) + (JNIEnv *env, jbooleanArray array, jsize start, jsize l, jboolean *buf); + void (JNICALL *GetByteArrayRegion) + (JNIEnv *env, jbyteArray array, jsize start, jsize len, jbyte *buf); + void (JNICALL *GetCharArrayRegion) + (JNIEnv *env, jcharArray array, jsize start, jsize len, jchar *buf); + void (JNICALL *GetShortArrayRegion) + (JNIEnv *env, jshortArray array, jsize start, jsize len, jshort *buf); + void (JNICALL *GetIntArrayRegion) + (JNIEnv *env, jintArray array, jsize start, jsize len, jint *buf); + void (JNICALL *GetLongArrayRegion) + (JNIEnv *env, jlongArray array, jsize start, jsize len, jlong *buf); + void (JNICALL *GetFloatArrayRegion) + (JNIEnv *env, jfloatArray array, jsize start, jsize len, jfloat *buf); + void (JNICALL *GetDoubleArrayRegion) + (JNIEnv *env, jdoubleArray array, jsize start, jsize len, jdouble *buf); + + void (JNICALL *SetBooleanArrayRegion) + (JNIEnv *env, jbooleanArray array, jsize start, jsize l, const jboolean *buf); + void (JNICALL *SetByteArrayRegion) + (JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte *buf); + void (JNICALL *SetCharArrayRegion) + (JNIEnv *env, jcharArray array, jsize start, jsize len, const jchar *buf); + void (JNICALL *SetShortArrayRegion) + (JNIEnv *env, jshortArray array, jsize start, jsize len, const jshort *buf); + void (JNICALL *SetIntArrayRegion) + (JNIEnv *env, jintArray array, jsize start, jsize len, const jint *buf); + void (JNICALL *SetLongArrayRegion) + (JNIEnv *env, jlongArray array, jsize start, jsize len, const jlong *buf); + void (JNICALL *SetFloatArrayRegion) + (JNIEnv *env, jfloatArray array, jsize start, jsize len, const jfloat *buf); + void (JNICALL *SetDoubleArrayRegion) + (JNIEnv *env, jdoubleArray array, jsize start, jsize len, const jdouble *buf); + + jint (JNICALL *RegisterNatives) + (JNIEnv *env, jclass clazz, const JNINativeMethod *methods, + jint nMethods); + jint (JNICALL *UnregisterNatives) + (JNIEnv *env, jclass clazz); + + jint (JNICALL *MonitorEnter) + (JNIEnv *env, jobject obj); + jint (JNICALL *MonitorExit) + (JNIEnv *env, jobject obj); + + jint (JNICALL *GetJavaVM) + (JNIEnv *env, JavaVM **vm); + + void (JNICALL *GetStringRegion) + (JNIEnv *env, jstring str, jsize start, jsize len, jchar *buf); + void (JNICALL *GetStringUTFRegion) + (JNIEnv *env, jstring str, jsize start, jsize len, char *buf); + + void * (JNICALL *GetPrimitiveArrayCritical) + (JNIEnv *env, jarray array, jboolean *isCopy); + void (JNICALL *ReleasePrimitiveArrayCritical) + (JNIEnv *env, jarray array, void *carray, jint mode); + + const jchar * (JNICALL *GetStringCritical) + (JNIEnv *env, jstring string, jboolean *isCopy); + void (JNICALL *ReleaseStringCritical) + (JNIEnv *env, jstring string, const jchar *cstring); + + jweak (JNICALL *NewWeakGlobalRef) + (JNIEnv *env, jobject obj); + void (JNICALL *DeleteWeakGlobalRef) + (JNIEnv *env, jweak ref); + + jboolean (JNICALL *ExceptionCheck) + (JNIEnv *env); + + jobject (JNICALL *NewDirectByteBuffer) + (JNIEnv* env, void* address, jlong capacity); + void* (JNICALL *GetDirectBufferAddress) + (JNIEnv* env, jobject buf); + jlong (JNICALL *GetDirectBufferCapacity) + (JNIEnv* env, jobject buf); + + /* New JNI 1.6 Features */ + + jobjectRefType (JNICALL *GetObjectRefType) + (JNIEnv* env, jobject obj); + + /* Module Features */ + + jobject (JNICALL *GetModule) + (JNIEnv* env, jclass clazz); + + /* Virtual threads */ + + jboolean (JNICALL *IsVirtualThread) + (JNIEnv* env, jobject obj); +}; + +/* + * We use inlined functions for C++ so that programmers can write: + * + * env->FindClass("java/lang/String") + * + * in C++ rather than: + * + * (*env)->FindClass(env, "java/lang/String") + * + * in C. + */ + +struct JNIEnv_ { + const struct JNINativeInterface_ *functions; +#ifdef __cplusplus + + jint GetVersion() { + return functions->GetVersion(this); + } + jclass DefineClass(const char *name, jobject loader, const jbyte *buf, + jsize len) { + return functions->DefineClass(this, name, loader, buf, len); + } + jclass FindClass(const char *name) { + return functions->FindClass(this, name); + } + jmethodID FromReflectedMethod(jobject method) { + return functions->FromReflectedMethod(this,method); + } + jfieldID FromReflectedField(jobject field) { + return functions->FromReflectedField(this,field); + } + + jobject ToReflectedMethod(jclass cls, jmethodID methodID, jboolean isStatic) { + return functions->ToReflectedMethod(this, cls, methodID, isStatic); + } + + jclass GetSuperclass(jclass sub) { + return functions->GetSuperclass(this, sub); + } + jboolean IsAssignableFrom(jclass sub, jclass sup) { + return functions->IsAssignableFrom(this, sub, sup); + } + + jobject ToReflectedField(jclass cls, jfieldID fieldID, jboolean isStatic) { + return functions->ToReflectedField(this,cls,fieldID,isStatic); + } + + jint Throw(jthrowable obj) { + return functions->Throw(this, obj); + } + jint ThrowNew(jclass clazz, const char *msg) { + return functions->ThrowNew(this, clazz, msg); + } + jthrowable ExceptionOccurred() { + return functions->ExceptionOccurred(this); + } + void ExceptionDescribe() { + functions->ExceptionDescribe(this); + } + void ExceptionClear() { + functions->ExceptionClear(this); + } + void FatalError(const char *msg) { + functions->FatalError(this, msg); + } + + jint PushLocalFrame(jint capacity) { + return functions->PushLocalFrame(this,capacity); + } + jobject PopLocalFrame(jobject result) { + return functions->PopLocalFrame(this,result); + } + + jobject NewGlobalRef(jobject lobj) { + return functions->NewGlobalRef(this,lobj); + } + void DeleteGlobalRef(jobject gref) { + functions->DeleteGlobalRef(this,gref); + } + void DeleteLocalRef(jobject obj) { + functions->DeleteLocalRef(this, obj); + } + + jboolean IsSameObject(jobject obj1, jobject obj2) { + return functions->IsSameObject(this,obj1,obj2); + } + + jobject NewLocalRef(jobject ref) { + return functions->NewLocalRef(this,ref); + } + jint EnsureLocalCapacity(jint capacity) { + return functions->EnsureLocalCapacity(this,capacity); + } + + jobject AllocObject(jclass clazz) { + return functions->AllocObject(this,clazz); + } + jobject NewObject(jclass clazz, jmethodID methodID, ...) { + va_list args; + jobject result; + va_start(args, methodID); + result = functions->NewObjectV(this,clazz,methodID,args); + va_end(args); + return result; + } + jobject NewObjectV(jclass clazz, jmethodID methodID, + va_list args) { + return functions->NewObjectV(this,clazz,methodID,args); + } + jobject NewObjectA(jclass clazz, jmethodID methodID, + const jvalue *args) { + return functions->NewObjectA(this,clazz,methodID,args); + } + + jclass GetObjectClass(jobject obj) { + return functions->GetObjectClass(this,obj); + } + jboolean IsInstanceOf(jobject obj, jclass clazz) { + return functions->IsInstanceOf(this,obj,clazz); + } + + jmethodID GetMethodID(jclass clazz, const char *name, + const char *sig) { + return functions->GetMethodID(this,clazz,name,sig); + } + + jobject CallObjectMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jobject result; + va_start(args,methodID); + result = functions->CallObjectMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jobject CallObjectMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallObjectMethodV(this,obj,methodID,args); + } + jobject CallObjectMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallObjectMethodA(this,obj,methodID,args); + } + + jboolean CallBooleanMethod(jobject obj, + jmethodID methodID, ...) { + va_list args; + jboolean result; + va_start(args,methodID); + result = functions->CallBooleanMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jboolean CallBooleanMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallBooleanMethodV(this,obj,methodID,args); + } + jboolean CallBooleanMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallBooleanMethodA(this,obj,methodID, args); + } + + jbyte CallByteMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jbyte result; + va_start(args,methodID); + result = functions->CallByteMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jbyte CallByteMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallByteMethodV(this,obj,methodID,args); + } + jbyte CallByteMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallByteMethodA(this,obj,methodID,args); + } + + jchar CallCharMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jchar result; + va_start(args,methodID); + result = functions->CallCharMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jchar CallCharMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallCharMethodV(this,obj,methodID,args); + } + jchar CallCharMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallCharMethodA(this,obj,methodID,args); + } + + jshort CallShortMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jshort result; + va_start(args,methodID); + result = functions->CallShortMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jshort CallShortMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallShortMethodV(this,obj,methodID,args); + } + jshort CallShortMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallShortMethodA(this,obj,methodID,args); + } + + jint CallIntMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jint result; + va_start(args,methodID); + result = functions->CallIntMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jint CallIntMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallIntMethodV(this,obj,methodID,args); + } + jint CallIntMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallIntMethodA(this,obj,methodID,args); + } + + jlong CallLongMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jlong result; + va_start(args,methodID); + result = functions->CallLongMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jlong CallLongMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallLongMethodV(this,obj,methodID,args); + } + jlong CallLongMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallLongMethodA(this,obj,methodID,args); + } + + jfloat CallFloatMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jfloat result; + va_start(args,methodID); + result = functions->CallFloatMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jfloat CallFloatMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallFloatMethodV(this,obj,methodID,args); + } + jfloat CallFloatMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallFloatMethodA(this,obj,methodID,args); + } + + jdouble CallDoubleMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jdouble result; + va_start(args,methodID); + result = functions->CallDoubleMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jdouble CallDoubleMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallDoubleMethodV(this,obj,methodID,args); + } + jdouble CallDoubleMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallDoubleMethodA(this,obj,methodID,args); + } + + void CallVoidMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + va_start(args,methodID); + functions->CallVoidMethodV(this,obj,methodID,args); + va_end(args); + } + void CallVoidMethodV(jobject obj, jmethodID methodID, + va_list args) { + functions->CallVoidMethodV(this,obj,methodID,args); + } + void CallVoidMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + functions->CallVoidMethodA(this,obj,methodID,args); + } + + jobject CallNonvirtualObjectMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jobject result; + va_start(args,methodID); + result = functions->CallNonvirtualObjectMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jobject CallNonvirtualObjectMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualObjectMethodV(this,obj,clazz, + methodID,args); + } + jobject CallNonvirtualObjectMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualObjectMethodA(this,obj,clazz, + methodID,args); + } + + jboolean CallNonvirtualBooleanMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jboolean result; + va_start(args,methodID); + result = functions->CallNonvirtualBooleanMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jboolean CallNonvirtualBooleanMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualBooleanMethodV(this,obj,clazz, + methodID,args); + } + jboolean CallNonvirtualBooleanMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualBooleanMethodA(this,obj,clazz, + methodID, args); + } + + jbyte CallNonvirtualByteMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jbyte result; + va_start(args,methodID); + result = functions->CallNonvirtualByteMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jbyte CallNonvirtualByteMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualByteMethodV(this,obj,clazz, + methodID,args); + } + jbyte CallNonvirtualByteMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualByteMethodA(this,obj,clazz, + methodID,args); + } + + jchar CallNonvirtualCharMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jchar result; + va_start(args,methodID); + result = functions->CallNonvirtualCharMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jchar CallNonvirtualCharMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualCharMethodV(this,obj,clazz, + methodID,args); + } + jchar CallNonvirtualCharMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualCharMethodA(this,obj,clazz, + methodID,args); + } + + jshort CallNonvirtualShortMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jshort result; + va_start(args,methodID); + result = functions->CallNonvirtualShortMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jshort CallNonvirtualShortMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualShortMethodV(this,obj,clazz, + methodID,args); + } + jshort CallNonvirtualShortMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualShortMethodA(this,obj,clazz, + methodID,args); + } + + jint CallNonvirtualIntMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jint result; + va_start(args,methodID); + result = functions->CallNonvirtualIntMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jint CallNonvirtualIntMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualIntMethodV(this,obj,clazz, + methodID,args); + } + jint CallNonvirtualIntMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualIntMethodA(this,obj,clazz, + methodID,args); + } + + jlong CallNonvirtualLongMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jlong result; + va_start(args,methodID); + result = functions->CallNonvirtualLongMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jlong CallNonvirtualLongMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualLongMethodV(this,obj,clazz, + methodID,args); + } + jlong CallNonvirtualLongMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualLongMethodA(this,obj,clazz, + methodID,args); + } + + jfloat CallNonvirtualFloatMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jfloat result; + va_start(args,methodID); + result = functions->CallNonvirtualFloatMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jfloat CallNonvirtualFloatMethodV(jobject obj, jclass clazz, + jmethodID methodID, + va_list args) { + return functions->CallNonvirtualFloatMethodV(this,obj,clazz, + methodID,args); + } + jfloat CallNonvirtualFloatMethodA(jobject obj, jclass clazz, + jmethodID methodID, + const jvalue * args) { + return functions->CallNonvirtualFloatMethodA(this,obj,clazz, + methodID,args); + } + + jdouble CallNonvirtualDoubleMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jdouble result; + va_start(args,methodID); + result = functions->CallNonvirtualDoubleMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jdouble CallNonvirtualDoubleMethodV(jobject obj, jclass clazz, + jmethodID methodID, + va_list args) { + return functions->CallNonvirtualDoubleMethodV(this,obj,clazz, + methodID,args); + } + jdouble CallNonvirtualDoubleMethodA(jobject obj, jclass clazz, + jmethodID methodID, + const jvalue * args) { + return functions->CallNonvirtualDoubleMethodA(this,obj,clazz, + methodID,args); + } + + void CallNonvirtualVoidMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + va_start(args,methodID); + functions->CallNonvirtualVoidMethodV(this,obj,clazz,methodID,args); + va_end(args); + } + void CallNonvirtualVoidMethodV(jobject obj, jclass clazz, + jmethodID methodID, + va_list args) { + functions->CallNonvirtualVoidMethodV(this,obj,clazz,methodID,args); + } + void CallNonvirtualVoidMethodA(jobject obj, jclass clazz, + jmethodID methodID, + const jvalue * args) { + functions->CallNonvirtualVoidMethodA(this,obj,clazz,methodID,args); + } + + jfieldID GetFieldID(jclass clazz, const char *name, + const char *sig) { + return functions->GetFieldID(this,clazz,name,sig); + } + + jobject GetObjectField(jobject obj, jfieldID fieldID) { + return functions->GetObjectField(this,obj,fieldID); + } + jboolean GetBooleanField(jobject obj, jfieldID fieldID) { + return functions->GetBooleanField(this,obj,fieldID); + } + jbyte GetByteField(jobject obj, jfieldID fieldID) { + return functions->GetByteField(this,obj,fieldID); + } + jchar GetCharField(jobject obj, jfieldID fieldID) { + return functions->GetCharField(this,obj,fieldID); + } + jshort GetShortField(jobject obj, jfieldID fieldID) { + return functions->GetShortField(this,obj,fieldID); + } + jint GetIntField(jobject obj, jfieldID fieldID) { + return functions->GetIntField(this,obj,fieldID); + } + jlong GetLongField(jobject obj, jfieldID fieldID) { + return functions->GetLongField(this,obj,fieldID); + } + jfloat GetFloatField(jobject obj, jfieldID fieldID) { + return functions->GetFloatField(this,obj,fieldID); + } + jdouble GetDoubleField(jobject obj, jfieldID fieldID) { + return functions->GetDoubleField(this,obj,fieldID); + } + + void SetObjectField(jobject obj, jfieldID fieldID, jobject val) { + functions->SetObjectField(this,obj,fieldID,val); + } + void SetBooleanField(jobject obj, jfieldID fieldID, + jboolean val) { + functions->SetBooleanField(this,obj,fieldID,val); + } + void SetByteField(jobject obj, jfieldID fieldID, + jbyte val) { + functions->SetByteField(this,obj,fieldID,val); + } + void SetCharField(jobject obj, jfieldID fieldID, + jchar val) { + functions->SetCharField(this,obj,fieldID,val); + } + void SetShortField(jobject obj, jfieldID fieldID, + jshort val) { + functions->SetShortField(this,obj,fieldID,val); + } + void SetIntField(jobject obj, jfieldID fieldID, + jint val) { + functions->SetIntField(this,obj,fieldID,val); + } + void SetLongField(jobject obj, jfieldID fieldID, + jlong val) { + functions->SetLongField(this,obj,fieldID,val); + } + void SetFloatField(jobject obj, jfieldID fieldID, + jfloat val) { + functions->SetFloatField(this,obj,fieldID,val); + } + void SetDoubleField(jobject obj, jfieldID fieldID, + jdouble val) { + functions->SetDoubleField(this,obj,fieldID,val); + } + + jmethodID GetStaticMethodID(jclass clazz, const char *name, + const char *sig) { + return functions->GetStaticMethodID(this,clazz,name,sig); + } + + jobject CallStaticObjectMethod(jclass clazz, jmethodID methodID, + ...) { + va_list args; + jobject result; + va_start(args,methodID); + result = functions->CallStaticObjectMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jobject CallStaticObjectMethodV(jclass clazz, jmethodID methodID, + va_list args) { + return functions->CallStaticObjectMethodV(this,clazz,methodID,args); + } + jobject CallStaticObjectMethodA(jclass clazz, jmethodID methodID, + const jvalue *args) { + return functions->CallStaticObjectMethodA(this,clazz,methodID,args); + } + + jboolean CallStaticBooleanMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jboolean result; + va_start(args,methodID); + result = functions->CallStaticBooleanMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jboolean CallStaticBooleanMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticBooleanMethodV(this,clazz,methodID,args); + } + jboolean CallStaticBooleanMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticBooleanMethodA(this,clazz,methodID,args); + } + + jbyte CallStaticByteMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jbyte result; + va_start(args,methodID); + result = functions->CallStaticByteMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jbyte CallStaticByteMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticByteMethodV(this,clazz,methodID,args); + } + jbyte CallStaticByteMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticByteMethodA(this,clazz,methodID,args); + } + + jchar CallStaticCharMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jchar result; + va_start(args,methodID); + result = functions->CallStaticCharMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jchar CallStaticCharMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticCharMethodV(this,clazz,methodID,args); + } + jchar CallStaticCharMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticCharMethodA(this,clazz,methodID,args); + } + + jshort CallStaticShortMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jshort result; + va_start(args,methodID); + result = functions->CallStaticShortMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jshort CallStaticShortMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticShortMethodV(this,clazz,methodID,args); + } + jshort CallStaticShortMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticShortMethodA(this,clazz,methodID,args); + } + + jint CallStaticIntMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jint result; + va_start(args,methodID); + result = functions->CallStaticIntMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jint CallStaticIntMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticIntMethodV(this,clazz,methodID,args); + } + jint CallStaticIntMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticIntMethodA(this,clazz,methodID,args); + } + + jlong CallStaticLongMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jlong result; + va_start(args,methodID); + result = functions->CallStaticLongMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jlong CallStaticLongMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticLongMethodV(this,clazz,methodID,args); + } + jlong CallStaticLongMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticLongMethodA(this,clazz,methodID,args); + } + + jfloat CallStaticFloatMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jfloat result; + va_start(args,methodID); + result = functions->CallStaticFloatMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jfloat CallStaticFloatMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticFloatMethodV(this,clazz,methodID,args); + } + jfloat CallStaticFloatMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticFloatMethodA(this,clazz,methodID,args); + } + + jdouble CallStaticDoubleMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jdouble result; + va_start(args,methodID); + result = functions->CallStaticDoubleMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jdouble CallStaticDoubleMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticDoubleMethodV(this,clazz,methodID,args); + } + jdouble CallStaticDoubleMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticDoubleMethodA(this,clazz,methodID,args); + } + + void CallStaticVoidMethod(jclass cls, jmethodID methodID, ...) { + va_list args; + va_start(args,methodID); + functions->CallStaticVoidMethodV(this,cls,methodID,args); + va_end(args); + } + void CallStaticVoidMethodV(jclass cls, jmethodID methodID, + va_list args) { + functions->CallStaticVoidMethodV(this,cls,methodID,args); + } + void CallStaticVoidMethodA(jclass cls, jmethodID methodID, + const jvalue * args) { + functions->CallStaticVoidMethodA(this,cls,methodID,args); + } + + jfieldID GetStaticFieldID(jclass clazz, const char *name, + const char *sig) { + return functions->GetStaticFieldID(this,clazz,name,sig); + } + jobject GetStaticObjectField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticObjectField(this,clazz,fieldID); + } + jboolean GetStaticBooleanField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticBooleanField(this,clazz,fieldID); + } + jbyte GetStaticByteField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticByteField(this,clazz,fieldID); + } + jchar GetStaticCharField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticCharField(this,clazz,fieldID); + } + jshort GetStaticShortField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticShortField(this,clazz,fieldID); + } + jint GetStaticIntField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticIntField(this,clazz,fieldID); + } + jlong GetStaticLongField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticLongField(this,clazz,fieldID); + } + jfloat GetStaticFloatField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticFloatField(this,clazz,fieldID); + } + jdouble GetStaticDoubleField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticDoubleField(this,clazz,fieldID); + } + + void SetStaticObjectField(jclass clazz, jfieldID fieldID, + jobject value) { + functions->SetStaticObjectField(this,clazz,fieldID,value); + } + void SetStaticBooleanField(jclass clazz, jfieldID fieldID, + jboolean value) { + functions->SetStaticBooleanField(this,clazz,fieldID,value); + } + void SetStaticByteField(jclass clazz, jfieldID fieldID, + jbyte value) { + functions->SetStaticByteField(this,clazz,fieldID,value); + } + void SetStaticCharField(jclass clazz, jfieldID fieldID, + jchar value) { + functions->SetStaticCharField(this,clazz,fieldID,value); + } + void SetStaticShortField(jclass clazz, jfieldID fieldID, + jshort value) { + functions->SetStaticShortField(this,clazz,fieldID,value); + } + void SetStaticIntField(jclass clazz, jfieldID fieldID, + jint value) { + functions->SetStaticIntField(this,clazz,fieldID,value); + } + void SetStaticLongField(jclass clazz, jfieldID fieldID, + jlong value) { + functions->SetStaticLongField(this,clazz,fieldID,value); + } + void SetStaticFloatField(jclass clazz, jfieldID fieldID, + jfloat value) { + functions->SetStaticFloatField(this,clazz,fieldID,value); + } + void SetStaticDoubleField(jclass clazz, jfieldID fieldID, + jdouble value) { + functions->SetStaticDoubleField(this,clazz,fieldID,value); + } + + jstring NewString(const jchar *unicode, jsize len) { + return functions->NewString(this,unicode,len); + } + jsize GetStringLength(jstring str) { + return functions->GetStringLength(this,str); + } + const jchar *GetStringChars(jstring str, jboolean *isCopy) { + return functions->GetStringChars(this,str,isCopy); + } + void ReleaseStringChars(jstring str, const jchar *chars) { + functions->ReleaseStringChars(this,str,chars); + } + + jstring NewStringUTF(const char *utf) { + return functions->NewStringUTF(this,utf); + } + jsize GetStringUTFLength(jstring str) { + return functions->GetStringUTFLength(this,str); + } + const char* GetStringUTFChars(jstring str, jboolean *isCopy) { + return functions->GetStringUTFChars(this,str,isCopy); + } + void ReleaseStringUTFChars(jstring str, const char* chars) { + functions->ReleaseStringUTFChars(this,str,chars); + } + + jsize GetArrayLength(jarray array) { + return functions->GetArrayLength(this,array); + } + + jobjectArray NewObjectArray(jsize len, jclass clazz, + jobject init) { + return functions->NewObjectArray(this,len,clazz,init); + } + jobject GetObjectArrayElement(jobjectArray array, jsize index) { + return functions->GetObjectArrayElement(this,array,index); + } + void SetObjectArrayElement(jobjectArray array, jsize index, + jobject val) { + functions->SetObjectArrayElement(this,array,index,val); + } + + jbooleanArray NewBooleanArray(jsize len) { + return functions->NewBooleanArray(this,len); + } + jbyteArray NewByteArray(jsize len) { + return functions->NewByteArray(this,len); + } + jcharArray NewCharArray(jsize len) { + return functions->NewCharArray(this,len); + } + jshortArray NewShortArray(jsize len) { + return functions->NewShortArray(this,len); + } + jintArray NewIntArray(jsize len) { + return functions->NewIntArray(this,len); + } + jlongArray NewLongArray(jsize len) { + return functions->NewLongArray(this,len); + } + jfloatArray NewFloatArray(jsize len) { + return functions->NewFloatArray(this,len); + } + jdoubleArray NewDoubleArray(jsize len) { + return functions->NewDoubleArray(this,len); + } + + jboolean * GetBooleanArrayElements(jbooleanArray array, jboolean *isCopy) { + return functions->GetBooleanArrayElements(this,array,isCopy); + } + jbyte * GetByteArrayElements(jbyteArray array, jboolean *isCopy) { + return functions->GetByteArrayElements(this,array,isCopy); + } + jchar * GetCharArrayElements(jcharArray array, jboolean *isCopy) { + return functions->GetCharArrayElements(this,array,isCopy); + } + jshort * GetShortArrayElements(jshortArray array, jboolean *isCopy) { + return functions->GetShortArrayElements(this,array,isCopy); + } + jint * GetIntArrayElements(jintArray array, jboolean *isCopy) { + return functions->GetIntArrayElements(this,array,isCopy); + } + jlong * GetLongArrayElements(jlongArray array, jboolean *isCopy) { + return functions->GetLongArrayElements(this,array,isCopy); + } + jfloat * GetFloatArrayElements(jfloatArray array, jboolean *isCopy) { + return functions->GetFloatArrayElements(this,array,isCopy); + } + jdouble * GetDoubleArrayElements(jdoubleArray array, jboolean *isCopy) { + return functions->GetDoubleArrayElements(this,array,isCopy); + } + + void ReleaseBooleanArrayElements(jbooleanArray array, + jboolean *elems, + jint mode) { + functions->ReleaseBooleanArrayElements(this,array,elems,mode); + } + void ReleaseByteArrayElements(jbyteArray array, + jbyte *elems, + jint mode) { + functions->ReleaseByteArrayElements(this,array,elems,mode); + } + void ReleaseCharArrayElements(jcharArray array, + jchar *elems, + jint mode) { + functions->ReleaseCharArrayElements(this,array,elems,mode); + } + void ReleaseShortArrayElements(jshortArray array, + jshort *elems, + jint mode) { + functions->ReleaseShortArrayElements(this,array,elems,mode); + } + void ReleaseIntArrayElements(jintArray array, + jint *elems, + jint mode) { + functions->ReleaseIntArrayElements(this,array,elems,mode); + } + void ReleaseLongArrayElements(jlongArray array, + jlong *elems, + jint mode) { + functions->ReleaseLongArrayElements(this,array,elems,mode); + } + void ReleaseFloatArrayElements(jfloatArray array, + jfloat *elems, + jint mode) { + functions->ReleaseFloatArrayElements(this,array,elems,mode); + } + void ReleaseDoubleArrayElements(jdoubleArray array, + jdouble *elems, + jint mode) { + functions->ReleaseDoubleArrayElements(this,array,elems,mode); + } + + void GetBooleanArrayRegion(jbooleanArray array, + jsize start, jsize len, jboolean *buf) { + functions->GetBooleanArrayRegion(this,array,start,len,buf); + } + void GetByteArrayRegion(jbyteArray array, + jsize start, jsize len, jbyte *buf) { + functions->GetByteArrayRegion(this,array,start,len,buf); + } + void GetCharArrayRegion(jcharArray array, + jsize start, jsize len, jchar *buf) { + functions->GetCharArrayRegion(this,array,start,len,buf); + } + void GetShortArrayRegion(jshortArray array, + jsize start, jsize len, jshort *buf) { + functions->GetShortArrayRegion(this,array,start,len,buf); + } + void GetIntArrayRegion(jintArray array, + jsize start, jsize len, jint *buf) { + functions->GetIntArrayRegion(this,array,start,len,buf); + } + void GetLongArrayRegion(jlongArray array, + jsize start, jsize len, jlong *buf) { + functions->GetLongArrayRegion(this,array,start,len,buf); + } + void GetFloatArrayRegion(jfloatArray array, + jsize start, jsize len, jfloat *buf) { + functions->GetFloatArrayRegion(this,array,start,len,buf); + } + void GetDoubleArrayRegion(jdoubleArray array, + jsize start, jsize len, jdouble *buf) { + functions->GetDoubleArrayRegion(this,array,start,len,buf); + } + + void SetBooleanArrayRegion(jbooleanArray array, jsize start, jsize len, + const jboolean *buf) { + functions->SetBooleanArrayRegion(this,array,start,len,buf); + } + void SetByteArrayRegion(jbyteArray array, jsize start, jsize len, + const jbyte *buf) { + functions->SetByteArrayRegion(this,array,start,len,buf); + } + void SetCharArrayRegion(jcharArray array, jsize start, jsize len, + const jchar *buf) { + functions->SetCharArrayRegion(this,array,start,len,buf); + } + void SetShortArrayRegion(jshortArray array, jsize start, jsize len, + const jshort *buf) { + functions->SetShortArrayRegion(this,array,start,len,buf); + } + void SetIntArrayRegion(jintArray array, jsize start, jsize len, + const jint *buf) { + functions->SetIntArrayRegion(this,array,start,len,buf); + } + void SetLongArrayRegion(jlongArray array, jsize start, jsize len, + const jlong *buf) { + functions->SetLongArrayRegion(this,array,start,len,buf); + } + void SetFloatArrayRegion(jfloatArray array, jsize start, jsize len, + const jfloat *buf) { + functions->SetFloatArrayRegion(this,array,start,len,buf); + } + void SetDoubleArrayRegion(jdoubleArray array, jsize start, jsize len, + const jdouble *buf) { + functions->SetDoubleArrayRegion(this,array,start,len,buf); + } + + jint RegisterNatives(jclass clazz, const JNINativeMethod *methods, + jint nMethods) { + return functions->RegisterNatives(this,clazz,methods,nMethods); + } + jint UnregisterNatives(jclass clazz) { + return functions->UnregisterNatives(this,clazz); + } + + jint MonitorEnter(jobject obj) { + return functions->MonitorEnter(this,obj); + } + jint MonitorExit(jobject obj) { + return functions->MonitorExit(this,obj); + } + + jint GetJavaVM(JavaVM **vm) { + return functions->GetJavaVM(this,vm); + } + + void GetStringRegion(jstring str, jsize start, jsize len, jchar *buf) { + functions->GetStringRegion(this,str,start,len,buf); + } + void GetStringUTFRegion(jstring str, jsize start, jsize len, char *buf) { + functions->GetStringUTFRegion(this,str,start,len,buf); + } + + void * GetPrimitiveArrayCritical(jarray array, jboolean *isCopy) { + return functions->GetPrimitiveArrayCritical(this,array,isCopy); + } + void ReleasePrimitiveArrayCritical(jarray array, void *carray, jint mode) { + functions->ReleasePrimitiveArrayCritical(this,array,carray,mode); + } + + const jchar * GetStringCritical(jstring string, jboolean *isCopy) { + return functions->GetStringCritical(this,string,isCopy); + } + void ReleaseStringCritical(jstring string, const jchar *cstring) { + functions->ReleaseStringCritical(this,string,cstring); + } + + jweak NewWeakGlobalRef(jobject obj) { + return functions->NewWeakGlobalRef(this,obj); + } + void DeleteWeakGlobalRef(jweak ref) { + functions->DeleteWeakGlobalRef(this,ref); + } + + jboolean ExceptionCheck() { + return functions->ExceptionCheck(this); + } + + jobject NewDirectByteBuffer(void* address, jlong capacity) { + return functions->NewDirectByteBuffer(this, address, capacity); + } + void* GetDirectBufferAddress(jobject buf) { + return functions->GetDirectBufferAddress(this, buf); + } + jlong GetDirectBufferCapacity(jobject buf) { + return functions->GetDirectBufferCapacity(this, buf); + } + jobjectRefType GetObjectRefType(jobject obj) { + return functions->GetObjectRefType(this, obj); + } + + /* Module Features */ + + jobject GetModule(jclass clazz) { + return functions->GetModule(this, clazz); + } + + /* Virtual threads */ + + jboolean IsVirtualThread(jobject obj) { + return functions->IsVirtualThread(this, obj); + } + +#endif /* __cplusplus */ +}; + +/* + * optionString may be any option accepted by the JVM, or one of the + * following: + * + * -D= Set a system property. + * -verbose[:class|gc|jni] Enable verbose output, comma-separated. E.g. + * "-verbose:class" or "-verbose:gc,class" + * Standard names include: gc, class, and jni. + * All nonstandard (VM-specific) names must begin + * with "X". + * vfprintf extraInfo is a pointer to the vfprintf hook. + * exit extraInfo is a pointer to the exit hook. + * abort extraInfo is a pointer to the abort hook. + */ +typedef struct JavaVMOption { + char *optionString; + void *extraInfo; +} JavaVMOption; + +typedef struct JavaVMInitArgs { + jint version; + + jint nOptions; + JavaVMOption *options; + jboolean ignoreUnrecognized; +} JavaVMInitArgs; + +typedef struct JavaVMAttachArgs { + jint version; + + char *name; + jobject group; +} JavaVMAttachArgs; + +/* These will be VM-specific. */ + +#define JDK1_2 +#define JDK1_4 + +/* End VM-specific. */ + +struct JNIInvokeInterface_ { + void *reserved0; + void *reserved1; + void *reserved2; + + jint (JNICALL *DestroyJavaVM)(JavaVM *vm); + + jint (JNICALL *AttachCurrentThread)(JavaVM *vm, void **penv, void *args); + + jint (JNICALL *DetachCurrentThread)(JavaVM *vm); + + jint (JNICALL *GetEnv)(JavaVM *vm, void **penv, jint version); + + jint (JNICALL *AttachCurrentThreadAsDaemon)(JavaVM *vm, void **penv, void *args); +}; + +struct JavaVM_ { + const struct JNIInvokeInterface_ *functions; +#ifdef __cplusplus + + jint DestroyJavaVM() { + return functions->DestroyJavaVM(this); + } + jint AttachCurrentThread(void **penv, void *args) { + return functions->AttachCurrentThread(this, penv, args); + } + jint DetachCurrentThread() { + return functions->DetachCurrentThread(this); + } + + jint GetEnv(void **penv, jint version) { + return functions->GetEnv(this, penv, version); + } + jint AttachCurrentThreadAsDaemon(void **penv, void *args) { + return functions->AttachCurrentThreadAsDaemon(this, penv, args); + } +#endif +}; + +#ifdef _JNI_IMPLEMENTATION_ +#define _JNI_IMPORT_OR_EXPORT_ JNIEXPORT +#else +#define _JNI_IMPORT_OR_EXPORT_ JNIIMPORT +#endif +_JNI_IMPORT_OR_EXPORT_ jint JNICALL +JNI_GetDefaultJavaVMInitArgs(void *args); + +_JNI_IMPORT_OR_EXPORT_ jint JNICALL +JNI_CreateJavaVM(JavaVM **pvm, void **penv, void *args); + +_JNI_IMPORT_OR_EXPORT_ jint JNICALL +JNI_GetCreatedJavaVMs(JavaVM **, jsize, jsize *); + +/* Defined by native libraries. */ +JNIEXPORT jint JNICALL +JNI_OnLoad(JavaVM *vm, void *reserved); + +JNIEXPORT void JNICALL +JNI_OnUnload(JavaVM *vm, void *reserved); + +#define JNI_VERSION_1_1 0x00010001 +#define JNI_VERSION_1_2 0x00010002 +#define JNI_VERSION_1_4 0x00010004 +#define JNI_VERSION_1_6 0x00010006 +#define JNI_VERSION_1_8 0x00010008 +#define JNI_VERSION_9 0x00090000 +#define JNI_VERSION_10 0x000a0000 +#define JNI_VERSION_19 0x00130000 +#define JNI_VERSION_20 0x00140000 +#define JNI_VERSION_21 0x00150000 + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ + +#endif /* !_JAVASOFT_JNI_H_ */ diff --git a/.github/include/unix/jni_md.h b/.github/include/unix/jni_md.h new file mode 100644 index 00000000..6e352038 --- /dev/null +++ b/.github/include/unix/jni_md.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 1996, 2013, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +#ifndef _JAVASOFT_JNI_MD_H_ +#define _JAVASOFT_JNI_MD_H_ + +#ifndef __has_attribute + #define __has_attribute(x) 0 +#endif +#if (defined(__GNUC__) && ((__GNUC__ > 4) || (__GNUC__ == 4) && (__GNUC_MINOR__ > 2))) || __has_attribute(visibility) + #ifdef ARM + #define JNIEXPORT __attribute__((externally_visible,visibility("default"))) + #define JNIIMPORT __attribute__((externally_visible,visibility("default"))) + #else + #define JNIEXPORT __attribute__((visibility("default"))) + #define JNIIMPORT __attribute__((visibility("default"))) + #endif +#else + #define JNIEXPORT + #define JNIIMPORT +#endif + +#define JNICALL + +typedef int jint; +#ifdef _LP64 +typedef long jlong; +#else +typedef long long jlong; +#endif + +typedef signed char jbyte; + +#endif /* !_JAVASOFT_JNI_MD_H_ */ diff --git a/.github/include/windows/jni.h b/.github/include/windows/jni.h new file mode 100644 index 00000000..c85da1bc --- /dev/null +++ b/.github/include/windows/jni.h @@ -0,0 +1,2001 @@ +/* + * Copyright (c) 1996, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * We used part of Netscape's Java Runtime Interface (JRI) as the starting + * point of our design and implementation. + */ + +/****************************************************************************** + * Java Runtime Interface + * Copyright (c) 1996 Netscape Communications Corporation. All rights reserved. + *****************************************************************************/ + +#ifndef _JAVASOFT_JNI_H_ +#define _JAVASOFT_JNI_H_ + +#include +#include + +/* jni_md.h contains the machine-dependent typedefs for jbyte, jint + and jlong */ + +#include "jni_md.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * JNI Types + */ + +#ifndef JNI_TYPES_ALREADY_DEFINED_IN_JNI_MD_H + +typedef unsigned char jboolean; +typedef unsigned short jchar; +typedef short jshort; +typedef float jfloat; +typedef double jdouble; + +typedef jint jsize; + +#ifdef __cplusplus + +class _jobject {}; +class _jclass : public _jobject {}; +class _jthrowable : public _jobject {}; +class _jstring : public _jobject {}; +class _jarray : public _jobject {}; +class _jbooleanArray : public _jarray {}; +class _jbyteArray : public _jarray {}; +class _jcharArray : public _jarray {}; +class _jshortArray : public _jarray {}; +class _jintArray : public _jarray {}; +class _jlongArray : public _jarray {}; +class _jfloatArray : public _jarray {}; +class _jdoubleArray : public _jarray {}; +class _jobjectArray : public _jarray {}; + +typedef _jobject *jobject; +typedef _jclass *jclass; +typedef _jthrowable *jthrowable; +typedef _jstring *jstring; +typedef _jarray *jarray; +typedef _jbooleanArray *jbooleanArray; +typedef _jbyteArray *jbyteArray; +typedef _jcharArray *jcharArray; +typedef _jshortArray *jshortArray; +typedef _jintArray *jintArray; +typedef _jlongArray *jlongArray; +typedef _jfloatArray *jfloatArray; +typedef _jdoubleArray *jdoubleArray; +typedef _jobjectArray *jobjectArray; + +#else + +struct _jobject; + +typedef struct _jobject *jobject; +typedef jobject jclass; +typedef jobject jthrowable; +typedef jobject jstring; +typedef jobject jarray; +typedef jarray jbooleanArray; +typedef jarray jbyteArray; +typedef jarray jcharArray; +typedef jarray jshortArray; +typedef jarray jintArray; +typedef jarray jlongArray; +typedef jarray jfloatArray; +typedef jarray jdoubleArray; +typedef jarray jobjectArray; + +#endif + +typedef jobject jweak; + +typedef union jvalue { + jboolean z; + jbyte b; + jchar c; + jshort s; + jint i; + jlong j; + jfloat f; + jdouble d; + jobject l; +} jvalue; + +struct _jfieldID; +typedef struct _jfieldID *jfieldID; + +struct _jmethodID; +typedef struct _jmethodID *jmethodID; + +/* Return values from jobjectRefType */ +typedef enum _jobjectType { + JNIInvalidRefType = 0, + JNILocalRefType = 1, + JNIGlobalRefType = 2, + JNIWeakGlobalRefType = 3 +} jobjectRefType; + + +#endif /* JNI_TYPES_ALREADY_DEFINED_IN_JNI_MD_H */ + +/* + * jboolean constants + */ + +#define JNI_FALSE 0 +#define JNI_TRUE 1 + +/* + * possible return values for JNI functions. + */ + +#define JNI_OK 0 /* success */ +#define JNI_ERR (-1) /* unknown error */ +#define JNI_EDETACHED (-2) /* thread detached from the VM */ +#define JNI_EVERSION (-3) /* JNI version error */ +#define JNI_ENOMEM (-4) /* not enough memory */ +#define JNI_EEXIST (-5) /* VM already created */ +#define JNI_EINVAL (-6) /* invalid arguments */ + +/* + * used in ReleaseScalarArrayElements + */ + +#define JNI_COMMIT 1 +#define JNI_ABORT 2 + +/* + * used in RegisterNatives to describe native method name, signature, + * and function pointer. + */ + +typedef struct { + char *name; + char *signature; + void *fnPtr; +} JNINativeMethod; + +/* + * JNI Native Method Interface. + */ + +struct JNINativeInterface_; + +struct JNIEnv_; + +#ifdef __cplusplus +typedef JNIEnv_ JNIEnv; +#else +typedef const struct JNINativeInterface_ *JNIEnv; +#endif + +/* + * JNI Invocation Interface. + */ + +struct JNIInvokeInterface_; + +struct JavaVM_; + +#ifdef __cplusplus +typedef JavaVM_ JavaVM; +#else +typedef const struct JNIInvokeInterface_ *JavaVM; +#endif + +struct JNINativeInterface_ { + void *reserved0; + void *reserved1; + void *reserved2; + + void *reserved3; + jint (JNICALL *GetVersion)(JNIEnv *env); + + jclass (JNICALL *DefineClass) + (JNIEnv *env, const char *name, jobject loader, const jbyte *buf, + jsize len); + jclass (JNICALL *FindClass) + (JNIEnv *env, const char *name); + + jmethodID (JNICALL *FromReflectedMethod) + (JNIEnv *env, jobject method); + jfieldID (JNICALL *FromReflectedField) + (JNIEnv *env, jobject field); + + jobject (JNICALL *ToReflectedMethod) + (JNIEnv *env, jclass cls, jmethodID methodID, jboolean isStatic); + + jclass (JNICALL *GetSuperclass) + (JNIEnv *env, jclass sub); + jboolean (JNICALL *IsAssignableFrom) + (JNIEnv *env, jclass sub, jclass sup); + + jobject (JNICALL *ToReflectedField) + (JNIEnv *env, jclass cls, jfieldID fieldID, jboolean isStatic); + + jint (JNICALL *Throw) + (JNIEnv *env, jthrowable obj); + jint (JNICALL *ThrowNew) + (JNIEnv *env, jclass clazz, const char *msg); + jthrowable (JNICALL *ExceptionOccurred) + (JNIEnv *env); + void (JNICALL *ExceptionDescribe) + (JNIEnv *env); + void (JNICALL *ExceptionClear) + (JNIEnv *env); + void (JNICALL *FatalError) + (JNIEnv *env, const char *msg); + + jint (JNICALL *PushLocalFrame) + (JNIEnv *env, jint capacity); + jobject (JNICALL *PopLocalFrame) + (JNIEnv *env, jobject result); + + jobject (JNICALL *NewGlobalRef) + (JNIEnv *env, jobject lobj); + void (JNICALL *DeleteGlobalRef) + (JNIEnv *env, jobject gref); + void (JNICALL *DeleteLocalRef) + (JNIEnv *env, jobject obj); + jboolean (JNICALL *IsSameObject) + (JNIEnv *env, jobject obj1, jobject obj2); + jobject (JNICALL *NewLocalRef) + (JNIEnv *env, jobject ref); + jint (JNICALL *EnsureLocalCapacity) + (JNIEnv *env, jint capacity); + + jobject (JNICALL *AllocObject) + (JNIEnv *env, jclass clazz); + jobject (JNICALL *NewObject) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jobject (JNICALL *NewObjectV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jobject (JNICALL *NewObjectA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jclass (JNICALL *GetObjectClass) + (JNIEnv *env, jobject obj); + jboolean (JNICALL *IsInstanceOf) + (JNIEnv *env, jobject obj, jclass clazz); + + jmethodID (JNICALL *GetMethodID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + + jobject (JNICALL *CallObjectMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jobject (JNICALL *CallObjectMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jobject (JNICALL *CallObjectMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue * args); + + jboolean (JNICALL *CallBooleanMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jboolean (JNICALL *CallBooleanMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jboolean (JNICALL *CallBooleanMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue * args); + + jbyte (JNICALL *CallByteMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jbyte (JNICALL *CallByteMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jbyte (JNICALL *CallByteMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jchar (JNICALL *CallCharMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jchar (JNICALL *CallCharMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jchar (JNICALL *CallCharMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jshort (JNICALL *CallShortMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jshort (JNICALL *CallShortMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jshort (JNICALL *CallShortMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jint (JNICALL *CallIntMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jint (JNICALL *CallIntMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jint (JNICALL *CallIntMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jlong (JNICALL *CallLongMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jlong (JNICALL *CallLongMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jlong (JNICALL *CallLongMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jfloat (JNICALL *CallFloatMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jfloat (JNICALL *CallFloatMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jfloat (JNICALL *CallFloatMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + jdouble (JNICALL *CallDoubleMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + jdouble (JNICALL *CallDoubleMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + jdouble (JNICALL *CallDoubleMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue *args); + + void (JNICALL *CallVoidMethod) + (JNIEnv *env, jobject obj, jmethodID methodID, ...); + void (JNICALL *CallVoidMethodV) + (JNIEnv *env, jobject obj, jmethodID methodID, va_list args); + void (JNICALL *CallVoidMethodA) + (JNIEnv *env, jobject obj, jmethodID methodID, const jvalue * args); + + jobject (JNICALL *CallNonvirtualObjectMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jobject (JNICALL *CallNonvirtualObjectMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jobject (JNICALL *CallNonvirtualObjectMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue * args); + + jboolean (JNICALL *CallNonvirtualBooleanMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jboolean (JNICALL *CallNonvirtualBooleanMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jboolean (JNICALL *CallNonvirtualBooleanMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue * args); + + jbyte (JNICALL *CallNonvirtualByteMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jbyte (JNICALL *CallNonvirtualByteMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jbyte (JNICALL *CallNonvirtualByteMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jchar (JNICALL *CallNonvirtualCharMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jchar (JNICALL *CallNonvirtualCharMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jchar (JNICALL *CallNonvirtualCharMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jshort (JNICALL *CallNonvirtualShortMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jshort (JNICALL *CallNonvirtualShortMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jshort (JNICALL *CallNonvirtualShortMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jint (JNICALL *CallNonvirtualIntMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jint (JNICALL *CallNonvirtualIntMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jint (JNICALL *CallNonvirtualIntMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jlong (JNICALL *CallNonvirtualLongMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jlong (JNICALL *CallNonvirtualLongMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jlong (JNICALL *CallNonvirtualLongMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jfloat (JNICALL *CallNonvirtualFloatMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jfloat (JNICALL *CallNonvirtualFloatMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jfloat (JNICALL *CallNonvirtualFloatMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + jdouble (JNICALL *CallNonvirtualDoubleMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + jdouble (JNICALL *CallNonvirtualDoubleMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + jdouble (JNICALL *CallNonvirtualDoubleMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue *args); + + void (JNICALL *CallNonvirtualVoidMethod) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, ...); + void (JNICALL *CallNonvirtualVoidMethodV) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + va_list args); + void (JNICALL *CallNonvirtualVoidMethodA) + (JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, + const jvalue * args); + + jfieldID (JNICALL *GetFieldID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + + jobject (JNICALL *GetObjectField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jboolean (JNICALL *GetBooleanField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jbyte (JNICALL *GetByteField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jchar (JNICALL *GetCharField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jshort (JNICALL *GetShortField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jint (JNICALL *GetIntField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jlong (JNICALL *GetLongField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jfloat (JNICALL *GetFloatField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + jdouble (JNICALL *GetDoubleField) + (JNIEnv *env, jobject obj, jfieldID fieldID); + + void (JNICALL *SetObjectField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jobject val); + void (JNICALL *SetBooleanField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jboolean val); + void (JNICALL *SetByteField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jbyte val); + void (JNICALL *SetCharField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jchar val); + void (JNICALL *SetShortField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jshort val); + void (JNICALL *SetIntField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jint val); + void (JNICALL *SetLongField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jlong val); + void (JNICALL *SetFloatField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jfloat val); + void (JNICALL *SetDoubleField) + (JNIEnv *env, jobject obj, jfieldID fieldID, jdouble val); + + jmethodID (JNICALL *GetStaticMethodID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + + jobject (JNICALL *CallStaticObjectMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jobject (JNICALL *CallStaticObjectMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jobject (JNICALL *CallStaticObjectMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jboolean (JNICALL *CallStaticBooleanMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jboolean (JNICALL *CallStaticBooleanMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jboolean (JNICALL *CallStaticBooleanMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jbyte (JNICALL *CallStaticByteMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jbyte (JNICALL *CallStaticByteMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jbyte (JNICALL *CallStaticByteMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jchar (JNICALL *CallStaticCharMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jchar (JNICALL *CallStaticCharMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jchar (JNICALL *CallStaticCharMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jshort (JNICALL *CallStaticShortMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jshort (JNICALL *CallStaticShortMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jshort (JNICALL *CallStaticShortMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jint (JNICALL *CallStaticIntMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jint (JNICALL *CallStaticIntMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jint (JNICALL *CallStaticIntMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jlong (JNICALL *CallStaticLongMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jlong (JNICALL *CallStaticLongMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jlong (JNICALL *CallStaticLongMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jfloat (JNICALL *CallStaticFloatMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jfloat (JNICALL *CallStaticFloatMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jfloat (JNICALL *CallStaticFloatMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + jdouble (JNICALL *CallStaticDoubleMethod) + (JNIEnv *env, jclass clazz, jmethodID methodID, ...); + jdouble (JNICALL *CallStaticDoubleMethodV) + (JNIEnv *env, jclass clazz, jmethodID methodID, va_list args); + jdouble (JNICALL *CallStaticDoubleMethodA) + (JNIEnv *env, jclass clazz, jmethodID methodID, const jvalue *args); + + void (JNICALL *CallStaticVoidMethod) + (JNIEnv *env, jclass cls, jmethodID methodID, ...); + void (JNICALL *CallStaticVoidMethodV) + (JNIEnv *env, jclass cls, jmethodID methodID, va_list args); + void (JNICALL *CallStaticVoidMethodA) + (JNIEnv *env, jclass cls, jmethodID methodID, const jvalue * args); + + jfieldID (JNICALL *GetStaticFieldID) + (JNIEnv *env, jclass clazz, const char *name, const char *sig); + jobject (JNICALL *GetStaticObjectField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jboolean (JNICALL *GetStaticBooleanField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jbyte (JNICALL *GetStaticByteField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jchar (JNICALL *GetStaticCharField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jshort (JNICALL *GetStaticShortField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jint (JNICALL *GetStaticIntField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jlong (JNICALL *GetStaticLongField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jfloat (JNICALL *GetStaticFloatField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + jdouble (JNICALL *GetStaticDoubleField) + (JNIEnv *env, jclass clazz, jfieldID fieldID); + + void (JNICALL *SetStaticObjectField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jobject value); + void (JNICALL *SetStaticBooleanField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jboolean value); + void (JNICALL *SetStaticByteField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jbyte value); + void (JNICALL *SetStaticCharField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jchar value); + void (JNICALL *SetStaticShortField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jshort value); + void (JNICALL *SetStaticIntField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jint value); + void (JNICALL *SetStaticLongField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jlong value); + void (JNICALL *SetStaticFloatField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jfloat value); + void (JNICALL *SetStaticDoubleField) + (JNIEnv *env, jclass clazz, jfieldID fieldID, jdouble value); + + jstring (JNICALL *NewString) + (JNIEnv *env, const jchar *unicode, jsize len); + jsize (JNICALL *GetStringLength) + (JNIEnv *env, jstring str); + const jchar *(JNICALL *GetStringChars) + (JNIEnv *env, jstring str, jboolean *isCopy); + void (JNICALL *ReleaseStringChars) + (JNIEnv *env, jstring str, const jchar *chars); + + jstring (JNICALL *NewStringUTF) + (JNIEnv *env, const char *utf); + jsize (JNICALL *GetStringUTFLength) + (JNIEnv *env, jstring str); + const char* (JNICALL *GetStringUTFChars) + (JNIEnv *env, jstring str, jboolean *isCopy); + void (JNICALL *ReleaseStringUTFChars) + (JNIEnv *env, jstring str, const char* chars); + + + jsize (JNICALL *GetArrayLength) + (JNIEnv *env, jarray array); + + jobjectArray (JNICALL *NewObjectArray) + (JNIEnv *env, jsize len, jclass clazz, jobject init); + jobject (JNICALL *GetObjectArrayElement) + (JNIEnv *env, jobjectArray array, jsize index); + void (JNICALL *SetObjectArrayElement) + (JNIEnv *env, jobjectArray array, jsize index, jobject val); + + jbooleanArray (JNICALL *NewBooleanArray) + (JNIEnv *env, jsize len); + jbyteArray (JNICALL *NewByteArray) + (JNIEnv *env, jsize len); + jcharArray (JNICALL *NewCharArray) + (JNIEnv *env, jsize len); + jshortArray (JNICALL *NewShortArray) + (JNIEnv *env, jsize len); + jintArray (JNICALL *NewIntArray) + (JNIEnv *env, jsize len); + jlongArray (JNICALL *NewLongArray) + (JNIEnv *env, jsize len); + jfloatArray (JNICALL *NewFloatArray) + (JNIEnv *env, jsize len); + jdoubleArray (JNICALL *NewDoubleArray) + (JNIEnv *env, jsize len); + + jboolean * (JNICALL *GetBooleanArrayElements) + (JNIEnv *env, jbooleanArray array, jboolean *isCopy); + jbyte * (JNICALL *GetByteArrayElements) + (JNIEnv *env, jbyteArray array, jboolean *isCopy); + jchar * (JNICALL *GetCharArrayElements) + (JNIEnv *env, jcharArray array, jboolean *isCopy); + jshort * (JNICALL *GetShortArrayElements) + (JNIEnv *env, jshortArray array, jboolean *isCopy); + jint * (JNICALL *GetIntArrayElements) + (JNIEnv *env, jintArray array, jboolean *isCopy); + jlong * (JNICALL *GetLongArrayElements) + (JNIEnv *env, jlongArray array, jboolean *isCopy); + jfloat * (JNICALL *GetFloatArrayElements) + (JNIEnv *env, jfloatArray array, jboolean *isCopy); + jdouble * (JNICALL *GetDoubleArrayElements) + (JNIEnv *env, jdoubleArray array, jboolean *isCopy); + + void (JNICALL *ReleaseBooleanArrayElements) + (JNIEnv *env, jbooleanArray array, jboolean *elems, jint mode); + void (JNICALL *ReleaseByteArrayElements) + (JNIEnv *env, jbyteArray array, jbyte *elems, jint mode); + void (JNICALL *ReleaseCharArrayElements) + (JNIEnv *env, jcharArray array, jchar *elems, jint mode); + void (JNICALL *ReleaseShortArrayElements) + (JNIEnv *env, jshortArray array, jshort *elems, jint mode); + void (JNICALL *ReleaseIntArrayElements) + (JNIEnv *env, jintArray array, jint *elems, jint mode); + void (JNICALL *ReleaseLongArrayElements) + (JNIEnv *env, jlongArray array, jlong *elems, jint mode); + void (JNICALL *ReleaseFloatArrayElements) + (JNIEnv *env, jfloatArray array, jfloat *elems, jint mode); + void (JNICALL *ReleaseDoubleArrayElements) + (JNIEnv *env, jdoubleArray array, jdouble *elems, jint mode); + + void (JNICALL *GetBooleanArrayRegion) + (JNIEnv *env, jbooleanArray array, jsize start, jsize l, jboolean *buf); + void (JNICALL *GetByteArrayRegion) + (JNIEnv *env, jbyteArray array, jsize start, jsize len, jbyte *buf); + void (JNICALL *GetCharArrayRegion) + (JNIEnv *env, jcharArray array, jsize start, jsize len, jchar *buf); + void (JNICALL *GetShortArrayRegion) + (JNIEnv *env, jshortArray array, jsize start, jsize len, jshort *buf); + void (JNICALL *GetIntArrayRegion) + (JNIEnv *env, jintArray array, jsize start, jsize len, jint *buf); + void (JNICALL *GetLongArrayRegion) + (JNIEnv *env, jlongArray array, jsize start, jsize len, jlong *buf); + void (JNICALL *GetFloatArrayRegion) + (JNIEnv *env, jfloatArray array, jsize start, jsize len, jfloat *buf); + void (JNICALL *GetDoubleArrayRegion) + (JNIEnv *env, jdoubleArray array, jsize start, jsize len, jdouble *buf); + + void (JNICALL *SetBooleanArrayRegion) + (JNIEnv *env, jbooleanArray array, jsize start, jsize l, const jboolean *buf); + void (JNICALL *SetByteArrayRegion) + (JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte *buf); + void (JNICALL *SetCharArrayRegion) + (JNIEnv *env, jcharArray array, jsize start, jsize len, const jchar *buf); + void (JNICALL *SetShortArrayRegion) + (JNIEnv *env, jshortArray array, jsize start, jsize len, const jshort *buf); + void (JNICALL *SetIntArrayRegion) + (JNIEnv *env, jintArray array, jsize start, jsize len, const jint *buf); + void (JNICALL *SetLongArrayRegion) + (JNIEnv *env, jlongArray array, jsize start, jsize len, const jlong *buf); + void (JNICALL *SetFloatArrayRegion) + (JNIEnv *env, jfloatArray array, jsize start, jsize len, const jfloat *buf); + void (JNICALL *SetDoubleArrayRegion) + (JNIEnv *env, jdoubleArray array, jsize start, jsize len, const jdouble *buf); + + jint (JNICALL *RegisterNatives) + (JNIEnv *env, jclass clazz, const JNINativeMethod *methods, + jint nMethods); + jint (JNICALL *UnregisterNatives) + (JNIEnv *env, jclass clazz); + + jint (JNICALL *MonitorEnter) + (JNIEnv *env, jobject obj); + jint (JNICALL *MonitorExit) + (JNIEnv *env, jobject obj); + + jint (JNICALL *GetJavaVM) + (JNIEnv *env, JavaVM **vm); + + void (JNICALL *GetStringRegion) + (JNIEnv *env, jstring str, jsize start, jsize len, jchar *buf); + void (JNICALL *GetStringUTFRegion) + (JNIEnv *env, jstring str, jsize start, jsize len, char *buf); + + void * (JNICALL *GetPrimitiveArrayCritical) + (JNIEnv *env, jarray array, jboolean *isCopy); + void (JNICALL *ReleasePrimitiveArrayCritical) + (JNIEnv *env, jarray array, void *carray, jint mode); + + const jchar * (JNICALL *GetStringCritical) + (JNIEnv *env, jstring string, jboolean *isCopy); + void (JNICALL *ReleaseStringCritical) + (JNIEnv *env, jstring string, const jchar *cstring); + + jweak (JNICALL *NewWeakGlobalRef) + (JNIEnv *env, jobject obj); + void (JNICALL *DeleteWeakGlobalRef) + (JNIEnv *env, jweak ref); + + jboolean (JNICALL *ExceptionCheck) + (JNIEnv *env); + + jobject (JNICALL *NewDirectByteBuffer) + (JNIEnv* env, void* address, jlong capacity); + void* (JNICALL *GetDirectBufferAddress) + (JNIEnv* env, jobject buf); + jlong (JNICALL *GetDirectBufferCapacity) + (JNIEnv* env, jobject buf); + + /* New JNI 1.6 Features */ + + jobjectRefType (JNICALL *GetObjectRefType) + (JNIEnv* env, jobject obj); + + /* Module Features */ + + jobject (JNICALL *GetModule) + (JNIEnv* env, jclass clazz); + + /* Virtual threads */ + + jboolean (JNICALL *IsVirtualThread) + (JNIEnv* env, jobject obj); +}; + +/* + * We use inlined functions for C++ so that programmers can write: + * + * env->FindClass("java/lang/String") + * + * in C++ rather than: + * + * (*env)->FindClass(env, "java/lang/String") + * + * in C. + */ + +struct JNIEnv_ { + const struct JNINativeInterface_ *functions; +#ifdef __cplusplus + + jint GetVersion() { + return functions->GetVersion(this); + } + jclass DefineClass(const char *name, jobject loader, const jbyte *buf, + jsize len) { + return functions->DefineClass(this, name, loader, buf, len); + } + jclass FindClass(const char *name) { + return functions->FindClass(this, name); + } + jmethodID FromReflectedMethod(jobject method) { + return functions->FromReflectedMethod(this,method); + } + jfieldID FromReflectedField(jobject field) { + return functions->FromReflectedField(this,field); + } + + jobject ToReflectedMethod(jclass cls, jmethodID methodID, jboolean isStatic) { + return functions->ToReflectedMethod(this, cls, methodID, isStatic); + } + + jclass GetSuperclass(jclass sub) { + return functions->GetSuperclass(this, sub); + } + jboolean IsAssignableFrom(jclass sub, jclass sup) { + return functions->IsAssignableFrom(this, sub, sup); + } + + jobject ToReflectedField(jclass cls, jfieldID fieldID, jboolean isStatic) { + return functions->ToReflectedField(this,cls,fieldID,isStatic); + } + + jint Throw(jthrowable obj) { + return functions->Throw(this, obj); + } + jint ThrowNew(jclass clazz, const char *msg) { + return functions->ThrowNew(this, clazz, msg); + } + jthrowable ExceptionOccurred() { + return functions->ExceptionOccurred(this); + } + void ExceptionDescribe() { + functions->ExceptionDescribe(this); + } + void ExceptionClear() { + functions->ExceptionClear(this); + } + void FatalError(const char *msg) { + functions->FatalError(this, msg); + } + + jint PushLocalFrame(jint capacity) { + return functions->PushLocalFrame(this,capacity); + } + jobject PopLocalFrame(jobject result) { + return functions->PopLocalFrame(this,result); + } + + jobject NewGlobalRef(jobject lobj) { + return functions->NewGlobalRef(this,lobj); + } + void DeleteGlobalRef(jobject gref) { + functions->DeleteGlobalRef(this,gref); + } + void DeleteLocalRef(jobject obj) { + functions->DeleteLocalRef(this, obj); + } + + jboolean IsSameObject(jobject obj1, jobject obj2) { + return functions->IsSameObject(this,obj1,obj2); + } + + jobject NewLocalRef(jobject ref) { + return functions->NewLocalRef(this,ref); + } + jint EnsureLocalCapacity(jint capacity) { + return functions->EnsureLocalCapacity(this,capacity); + } + + jobject AllocObject(jclass clazz) { + return functions->AllocObject(this,clazz); + } + jobject NewObject(jclass clazz, jmethodID methodID, ...) { + va_list args; + jobject result; + va_start(args, methodID); + result = functions->NewObjectV(this,clazz,methodID,args); + va_end(args); + return result; + } + jobject NewObjectV(jclass clazz, jmethodID methodID, + va_list args) { + return functions->NewObjectV(this,clazz,methodID,args); + } + jobject NewObjectA(jclass clazz, jmethodID methodID, + const jvalue *args) { + return functions->NewObjectA(this,clazz,methodID,args); + } + + jclass GetObjectClass(jobject obj) { + return functions->GetObjectClass(this,obj); + } + jboolean IsInstanceOf(jobject obj, jclass clazz) { + return functions->IsInstanceOf(this,obj,clazz); + } + + jmethodID GetMethodID(jclass clazz, const char *name, + const char *sig) { + return functions->GetMethodID(this,clazz,name,sig); + } + + jobject CallObjectMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jobject result; + va_start(args,methodID); + result = functions->CallObjectMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jobject CallObjectMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallObjectMethodV(this,obj,methodID,args); + } + jobject CallObjectMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallObjectMethodA(this,obj,methodID,args); + } + + jboolean CallBooleanMethod(jobject obj, + jmethodID methodID, ...) { + va_list args; + jboolean result; + va_start(args,methodID); + result = functions->CallBooleanMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jboolean CallBooleanMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallBooleanMethodV(this,obj,methodID,args); + } + jboolean CallBooleanMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallBooleanMethodA(this,obj,methodID, args); + } + + jbyte CallByteMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jbyte result; + va_start(args,methodID); + result = functions->CallByteMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jbyte CallByteMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallByteMethodV(this,obj,methodID,args); + } + jbyte CallByteMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallByteMethodA(this,obj,methodID,args); + } + + jchar CallCharMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jchar result; + va_start(args,methodID); + result = functions->CallCharMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jchar CallCharMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallCharMethodV(this,obj,methodID,args); + } + jchar CallCharMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallCharMethodA(this,obj,methodID,args); + } + + jshort CallShortMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jshort result; + va_start(args,methodID); + result = functions->CallShortMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jshort CallShortMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallShortMethodV(this,obj,methodID,args); + } + jshort CallShortMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallShortMethodA(this,obj,methodID,args); + } + + jint CallIntMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jint result; + va_start(args,methodID); + result = functions->CallIntMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jint CallIntMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallIntMethodV(this,obj,methodID,args); + } + jint CallIntMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallIntMethodA(this,obj,methodID,args); + } + + jlong CallLongMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jlong result; + va_start(args,methodID); + result = functions->CallLongMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jlong CallLongMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallLongMethodV(this,obj,methodID,args); + } + jlong CallLongMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallLongMethodA(this,obj,methodID,args); + } + + jfloat CallFloatMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jfloat result; + va_start(args,methodID); + result = functions->CallFloatMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jfloat CallFloatMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallFloatMethodV(this,obj,methodID,args); + } + jfloat CallFloatMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallFloatMethodA(this,obj,methodID,args); + } + + jdouble CallDoubleMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + jdouble result; + va_start(args,methodID); + result = functions->CallDoubleMethodV(this,obj,methodID,args); + va_end(args); + return result; + } + jdouble CallDoubleMethodV(jobject obj, jmethodID methodID, + va_list args) { + return functions->CallDoubleMethodV(this,obj,methodID,args); + } + jdouble CallDoubleMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + return functions->CallDoubleMethodA(this,obj,methodID,args); + } + + void CallVoidMethod(jobject obj, jmethodID methodID, ...) { + va_list args; + va_start(args,methodID); + functions->CallVoidMethodV(this,obj,methodID,args); + va_end(args); + } + void CallVoidMethodV(jobject obj, jmethodID methodID, + va_list args) { + functions->CallVoidMethodV(this,obj,methodID,args); + } + void CallVoidMethodA(jobject obj, jmethodID methodID, + const jvalue * args) { + functions->CallVoidMethodA(this,obj,methodID,args); + } + + jobject CallNonvirtualObjectMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jobject result; + va_start(args,methodID); + result = functions->CallNonvirtualObjectMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jobject CallNonvirtualObjectMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualObjectMethodV(this,obj,clazz, + methodID,args); + } + jobject CallNonvirtualObjectMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualObjectMethodA(this,obj,clazz, + methodID,args); + } + + jboolean CallNonvirtualBooleanMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jboolean result; + va_start(args,methodID); + result = functions->CallNonvirtualBooleanMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jboolean CallNonvirtualBooleanMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualBooleanMethodV(this,obj,clazz, + methodID,args); + } + jboolean CallNonvirtualBooleanMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualBooleanMethodA(this,obj,clazz, + methodID, args); + } + + jbyte CallNonvirtualByteMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jbyte result; + va_start(args,methodID); + result = functions->CallNonvirtualByteMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jbyte CallNonvirtualByteMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualByteMethodV(this,obj,clazz, + methodID,args); + } + jbyte CallNonvirtualByteMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualByteMethodA(this,obj,clazz, + methodID,args); + } + + jchar CallNonvirtualCharMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jchar result; + va_start(args,methodID); + result = functions->CallNonvirtualCharMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jchar CallNonvirtualCharMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualCharMethodV(this,obj,clazz, + methodID,args); + } + jchar CallNonvirtualCharMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualCharMethodA(this,obj,clazz, + methodID,args); + } + + jshort CallNonvirtualShortMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jshort result; + va_start(args,methodID); + result = functions->CallNonvirtualShortMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jshort CallNonvirtualShortMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualShortMethodV(this,obj,clazz, + methodID,args); + } + jshort CallNonvirtualShortMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualShortMethodA(this,obj,clazz, + methodID,args); + } + + jint CallNonvirtualIntMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jint result; + va_start(args,methodID); + result = functions->CallNonvirtualIntMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jint CallNonvirtualIntMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualIntMethodV(this,obj,clazz, + methodID,args); + } + jint CallNonvirtualIntMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualIntMethodA(this,obj,clazz, + methodID,args); + } + + jlong CallNonvirtualLongMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jlong result; + va_start(args,methodID); + result = functions->CallNonvirtualLongMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jlong CallNonvirtualLongMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallNonvirtualLongMethodV(this,obj,clazz, + methodID,args); + } + jlong CallNonvirtualLongMethodA(jobject obj, jclass clazz, + jmethodID methodID, const jvalue * args) { + return functions->CallNonvirtualLongMethodA(this,obj,clazz, + methodID,args); + } + + jfloat CallNonvirtualFloatMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jfloat result; + va_start(args,methodID); + result = functions->CallNonvirtualFloatMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jfloat CallNonvirtualFloatMethodV(jobject obj, jclass clazz, + jmethodID methodID, + va_list args) { + return functions->CallNonvirtualFloatMethodV(this,obj,clazz, + methodID,args); + } + jfloat CallNonvirtualFloatMethodA(jobject obj, jclass clazz, + jmethodID methodID, + const jvalue * args) { + return functions->CallNonvirtualFloatMethodA(this,obj,clazz, + methodID,args); + } + + jdouble CallNonvirtualDoubleMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + jdouble result; + va_start(args,methodID); + result = functions->CallNonvirtualDoubleMethodV(this,obj,clazz, + methodID,args); + va_end(args); + return result; + } + jdouble CallNonvirtualDoubleMethodV(jobject obj, jclass clazz, + jmethodID methodID, + va_list args) { + return functions->CallNonvirtualDoubleMethodV(this,obj,clazz, + methodID,args); + } + jdouble CallNonvirtualDoubleMethodA(jobject obj, jclass clazz, + jmethodID methodID, + const jvalue * args) { + return functions->CallNonvirtualDoubleMethodA(this,obj,clazz, + methodID,args); + } + + void CallNonvirtualVoidMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) { + va_list args; + va_start(args,methodID); + functions->CallNonvirtualVoidMethodV(this,obj,clazz,methodID,args); + va_end(args); + } + void CallNonvirtualVoidMethodV(jobject obj, jclass clazz, + jmethodID methodID, + va_list args) { + functions->CallNonvirtualVoidMethodV(this,obj,clazz,methodID,args); + } + void CallNonvirtualVoidMethodA(jobject obj, jclass clazz, + jmethodID methodID, + const jvalue * args) { + functions->CallNonvirtualVoidMethodA(this,obj,clazz,methodID,args); + } + + jfieldID GetFieldID(jclass clazz, const char *name, + const char *sig) { + return functions->GetFieldID(this,clazz,name,sig); + } + + jobject GetObjectField(jobject obj, jfieldID fieldID) { + return functions->GetObjectField(this,obj,fieldID); + } + jboolean GetBooleanField(jobject obj, jfieldID fieldID) { + return functions->GetBooleanField(this,obj,fieldID); + } + jbyte GetByteField(jobject obj, jfieldID fieldID) { + return functions->GetByteField(this,obj,fieldID); + } + jchar GetCharField(jobject obj, jfieldID fieldID) { + return functions->GetCharField(this,obj,fieldID); + } + jshort GetShortField(jobject obj, jfieldID fieldID) { + return functions->GetShortField(this,obj,fieldID); + } + jint GetIntField(jobject obj, jfieldID fieldID) { + return functions->GetIntField(this,obj,fieldID); + } + jlong GetLongField(jobject obj, jfieldID fieldID) { + return functions->GetLongField(this,obj,fieldID); + } + jfloat GetFloatField(jobject obj, jfieldID fieldID) { + return functions->GetFloatField(this,obj,fieldID); + } + jdouble GetDoubleField(jobject obj, jfieldID fieldID) { + return functions->GetDoubleField(this,obj,fieldID); + } + + void SetObjectField(jobject obj, jfieldID fieldID, jobject val) { + functions->SetObjectField(this,obj,fieldID,val); + } + void SetBooleanField(jobject obj, jfieldID fieldID, + jboolean val) { + functions->SetBooleanField(this,obj,fieldID,val); + } + void SetByteField(jobject obj, jfieldID fieldID, + jbyte val) { + functions->SetByteField(this,obj,fieldID,val); + } + void SetCharField(jobject obj, jfieldID fieldID, + jchar val) { + functions->SetCharField(this,obj,fieldID,val); + } + void SetShortField(jobject obj, jfieldID fieldID, + jshort val) { + functions->SetShortField(this,obj,fieldID,val); + } + void SetIntField(jobject obj, jfieldID fieldID, + jint val) { + functions->SetIntField(this,obj,fieldID,val); + } + void SetLongField(jobject obj, jfieldID fieldID, + jlong val) { + functions->SetLongField(this,obj,fieldID,val); + } + void SetFloatField(jobject obj, jfieldID fieldID, + jfloat val) { + functions->SetFloatField(this,obj,fieldID,val); + } + void SetDoubleField(jobject obj, jfieldID fieldID, + jdouble val) { + functions->SetDoubleField(this,obj,fieldID,val); + } + + jmethodID GetStaticMethodID(jclass clazz, const char *name, + const char *sig) { + return functions->GetStaticMethodID(this,clazz,name,sig); + } + + jobject CallStaticObjectMethod(jclass clazz, jmethodID methodID, + ...) { + va_list args; + jobject result; + va_start(args,methodID); + result = functions->CallStaticObjectMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jobject CallStaticObjectMethodV(jclass clazz, jmethodID methodID, + va_list args) { + return functions->CallStaticObjectMethodV(this,clazz,methodID,args); + } + jobject CallStaticObjectMethodA(jclass clazz, jmethodID methodID, + const jvalue *args) { + return functions->CallStaticObjectMethodA(this,clazz,methodID,args); + } + + jboolean CallStaticBooleanMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jboolean result; + va_start(args,methodID); + result = functions->CallStaticBooleanMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jboolean CallStaticBooleanMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticBooleanMethodV(this,clazz,methodID,args); + } + jboolean CallStaticBooleanMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticBooleanMethodA(this,clazz,methodID,args); + } + + jbyte CallStaticByteMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jbyte result; + va_start(args,methodID); + result = functions->CallStaticByteMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jbyte CallStaticByteMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticByteMethodV(this,clazz,methodID,args); + } + jbyte CallStaticByteMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticByteMethodA(this,clazz,methodID,args); + } + + jchar CallStaticCharMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jchar result; + va_start(args,methodID); + result = functions->CallStaticCharMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jchar CallStaticCharMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticCharMethodV(this,clazz,methodID,args); + } + jchar CallStaticCharMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticCharMethodA(this,clazz,methodID,args); + } + + jshort CallStaticShortMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jshort result; + va_start(args,methodID); + result = functions->CallStaticShortMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jshort CallStaticShortMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticShortMethodV(this,clazz,methodID,args); + } + jshort CallStaticShortMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticShortMethodA(this,clazz,methodID,args); + } + + jint CallStaticIntMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jint result; + va_start(args,methodID); + result = functions->CallStaticIntMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jint CallStaticIntMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticIntMethodV(this,clazz,methodID,args); + } + jint CallStaticIntMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticIntMethodA(this,clazz,methodID,args); + } + + jlong CallStaticLongMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jlong result; + va_start(args,methodID); + result = functions->CallStaticLongMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jlong CallStaticLongMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticLongMethodV(this,clazz,methodID,args); + } + jlong CallStaticLongMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticLongMethodA(this,clazz,methodID,args); + } + + jfloat CallStaticFloatMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jfloat result; + va_start(args,methodID); + result = functions->CallStaticFloatMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jfloat CallStaticFloatMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticFloatMethodV(this,clazz,methodID,args); + } + jfloat CallStaticFloatMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticFloatMethodA(this,clazz,methodID,args); + } + + jdouble CallStaticDoubleMethod(jclass clazz, + jmethodID methodID, ...) { + va_list args; + jdouble result; + va_start(args,methodID); + result = functions->CallStaticDoubleMethodV(this,clazz,methodID,args); + va_end(args); + return result; + } + jdouble CallStaticDoubleMethodV(jclass clazz, + jmethodID methodID, va_list args) { + return functions->CallStaticDoubleMethodV(this,clazz,methodID,args); + } + jdouble CallStaticDoubleMethodA(jclass clazz, + jmethodID methodID, const jvalue *args) { + return functions->CallStaticDoubleMethodA(this,clazz,methodID,args); + } + + void CallStaticVoidMethod(jclass cls, jmethodID methodID, ...) { + va_list args; + va_start(args,methodID); + functions->CallStaticVoidMethodV(this,cls,methodID,args); + va_end(args); + } + void CallStaticVoidMethodV(jclass cls, jmethodID methodID, + va_list args) { + functions->CallStaticVoidMethodV(this,cls,methodID,args); + } + void CallStaticVoidMethodA(jclass cls, jmethodID methodID, + const jvalue * args) { + functions->CallStaticVoidMethodA(this,cls,methodID,args); + } + + jfieldID GetStaticFieldID(jclass clazz, const char *name, + const char *sig) { + return functions->GetStaticFieldID(this,clazz,name,sig); + } + jobject GetStaticObjectField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticObjectField(this,clazz,fieldID); + } + jboolean GetStaticBooleanField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticBooleanField(this,clazz,fieldID); + } + jbyte GetStaticByteField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticByteField(this,clazz,fieldID); + } + jchar GetStaticCharField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticCharField(this,clazz,fieldID); + } + jshort GetStaticShortField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticShortField(this,clazz,fieldID); + } + jint GetStaticIntField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticIntField(this,clazz,fieldID); + } + jlong GetStaticLongField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticLongField(this,clazz,fieldID); + } + jfloat GetStaticFloatField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticFloatField(this,clazz,fieldID); + } + jdouble GetStaticDoubleField(jclass clazz, jfieldID fieldID) { + return functions->GetStaticDoubleField(this,clazz,fieldID); + } + + void SetStaticObjectField(jclass clazz, jfieldID fieldID, + jobject value) { + functions->SetStaticObjectField(this,clazz,fieldID,value); + } + void SetStaticBooleanField(jclass clazz, jfieldID fieldID, + jboolean value) { + functions->SetStaticBooleanField(this,clazz,fieldID,value); + } + void SetStaticByteField(jclass clazz, jfieldID fieldID, + jbyte value) { + functions->SetStaticByteField(this,clazz,fieldID,value); + } + void SetStaticCharField(jclass clazz, jfieldID fieldID, + jchar value) { + functions->SetStaticCharField(this,clazz,fieldID,value); + } + void SetStaticShortField(jclass clazz, jfieldID fieldID, + jshort value) { + functions->SetStaticShortField(this,clazz,fieldID,value); + } + void SetStaticIntField(jclass clazz, jfieldID fieldID, + jint value) { + functions->SetStaticIntField(this,clazz,fieldID,value); + } + void SetStaticLongField(jclass clazz, jfieldID fieldID, + jlong value) { + functions->SetStaticLongField(this,clazz,fieldID,value); + } + void SetStaticFloatField(jclass clazz, jfieldID fieldID, + jfloat value) { + functions->SetStaticFloatField(this,clazz,fieldID,value); + } + void SetStaticDoubleField(jclass clazz, jfieldID fieldID, + jdouble value) { + functions->SetStaticDoubleField(this,clazz,fieldID,value); + } + + jstring NewString(const jchar *unicode, jsize len) { + return functions->NewString(this,unicode,len); + } + jsize GetStringLength(jstring str) { + return functions->GetStringLength(this,str); + } + const jchar *GetStringChars(jstring str, jboolean *isCopy) { + return functions->GetStringChars(this,str,isCopy); + } + void ReleaseStringChars(jstring str, const jchar *chars) { + functions->ReleaseStringChars(this,str,chars); + } + + jstring NewStringUTF(const char *utf) { + return functions->NewStringUTF(this,utf); + } + jsize GetStringUTFLength(jstring str) { + return functions->GetStringUTFLength(this,str); + } + const char* GetStringUTFChars(jstring str, jboolean *isCopy) { + return functions->GetStringUTFChars(this,str,isCopy); + } + void ReleaseStringUTFChars(jstring str, const char* chars) { + functions->ReleaseStringUTFChars(this,str,chars); + } + + jsize GetArrayLength(jarray array) { + return functions->GetArrayLength(this,array); + } + + jobjectArray NewObjectArray(jsize len, jclass clazz, + jobject init) { + return functions->NewObjectArray(this,len,clazz,init); + } + jobject GetObjectArrayElement(jobjectArray array, jsize index) { + return functions->GetObjectArrayElement(this,array,index); + } + void SetObjectArrayElement(jobjectArray array, jsize index, + jobject val) { + functions->SetObjectArrayElement(this,array,index,val); + } + + jbooleanArray NewBooleanArray(jsize len) { + return functions->NewBooleanArray(this,len); + } + jbyteArray NewByteArray(jsize len) { + return functions->NewByteArray(this,len); + } + jcharArray NewCharArray(jsize len) { + return functions->NewCharArray(this,len); + } + jshortArray NewShortArray(jsize len) { + return functions->NewShortArray(this,len); + } + jintArray NewIntArray(jsize len) { + return functions->NewIntArray(this,len); + } + jlongArray NewLongArray(jsize len) { + return functions->NewLongArray(this,len); + } + jfloatArray NewFloatArray(jsize len) { + return functions->NewFloatArray(this,len); + } + jdoubleArray NewDoubleArray(jsize len) { + return functions->NewDoubleArray(this,len); + } + + jboolean * GetBooleanArrayElements(jbooleanArray array, jboolean *isCopy) { + return functions->GetBooleanArrayElements(this,array,isCopy); + } + jbyte * GetByteArrayElements(jbyteArray array, jboolean *isCopy) { + return functions->GetByteArrayElements(this,array,isCopy); + } + jchar * GetCharArrayElements(jcharArray array, jboolean *isCopy) { + return functions->GetCharArrayElements(this,array,isCopy); + } + jshort * GetShortArrayElements(jshortArray array, jboolean *isCopy) { + return functions->GetShortArrayElements(this,array,isCopy); + } + jint * GetIntArrayElements(jintArray array, jboolean *isCopy) { + return functions->GetIntArrayElements(this,array,isCopy); + } + jlong * GetLongArrayElements(jlongArray array, jboolean *isCopy) { + return functions->GetLongArrayElements(this,array,isCopy); + } + jfloat * GetFloatArrayElements(jfloatArray array, jboolean *isCopy) { + return functions->GetFloatArrayElements(this,array,isCopy); + } + jdouble * GetDoubleArrayElements(jdoubleArray array, jboolean *isCopy) { + return functions->GetDoubleArrayElements(this,array,isCopy); + } + + void ReleaseBooleanArrayElements(jbooleanArray array, + jboolean *elems, + jint mode) { + functions->ReleaseBooleanArrayElements(this,array,elems,mode); + } + void ReleaseByteArrayElements(jbyteArray array, + jbyte *elems, + jint mode) { + functions->ReleaseByteArrayElements(this,array,elems,mode); + } + void ReleaseCharArrayElements(jcharArray array, + jchar *elems, + jint mode) { + functions->ReleaseCharArrayElements(this,array,elems,mode); + } + void ReleaseShortArrayElements(jshortArray array, + jshort *elems, + jint mode) { + functions->ReleaseShortArrayElements(this,array,elems,mode); + } + void ReleaseIntArrayElements(jintArray array, + jint *elems, + jint mode) { + functions->ReleaseIntArrayElements(this,array,elems,mode); + } + void ReleaseLongArrayElements(jlongArray array, + jlong *elems, + jint mode) { + functions->ReleaseLongArrayElements(this,array,elems,mode); + } + void ReleaseFloatArrayElements(jfloatArray array, + jfloat *elems, + jint mode) { + functions->ReleaseFloatArrayElements(this,array,elems,mode); + } + void ReleaseDoubleArrayElements(jdoubleArray array, + jdouble *elems, + jint mode) { + functions->ReleaseDoubleArrayElements(this,array,elems,mode); + } + + void GetBooleanArrayRegion(jbooleanArray array, + jsize start, jsize len, jboolean *buf) { + functions->GetBooleanArrayRegion(this,array,start,len,buf); + } + void GetByteArrayRegion(jbyteArray array, + jsize start, jsize len, jbyte *buf) { + functions->GetByteArrayRegion(this,array,start,len,buf); + } + void GetCharArrayRegion(jcharArray array, + jsize start, jsize len, jchar *buf) { + functions->GetCharArrayRegion(this,array,start,len,buf); + } + void GetShortArrayRegion(jshortArray array, + jsize start, jsize len, jshort *buf) { + functions->GetShortArrayRegion(this,array,start,len,buf); + } + void GetIntArrayRegion(jintArray array, + jsize start, jsize len, jint *buf) { + functions->GetIntArrayRegion(this,array,start,len,buf); + } + void GetLongArrayRegion(jlongArray array, + jsize start, jsize len, jlong *buf) { + functions->GetLongArrayRegion(this,array,start,len,buf); + } + void GetFloatArrayRegion(jfloatArray array, + jsize start, jsize len, jfloat *buf) { + functions->GetFloatArrayRegion(this,array,start,len,buf); + } + void GetDoubleArrayRegion(jdoubleArray array, + jsize start, jsize len, jdouble *buf) { + functions->GetDoubleArrayRegion(this,array,start,len,buf); + } + + void SetBooleanArrayRegion(jbooleanArray array, jsize start, jsize len, + const jboolean *buf) { + functions->SetBooleanArrayRegion(this,array,start,len,buf); + } + void SetByteArrayRegion(jbyteArray array, jsize start, jsize len, + const jbyte *buf) { + functions->SetByteArrayRegion(this,array,start,len,buf); + } + void SetCharArrayRegion(jcharArray array, jsize start, jsize len, + const jchar *buf) { + functions->SetCharArrayRegion(this,array,start,len,buf); + } + void SetShortArrayRegion(jshortArray array, jsize start, jsize len, + const jshort *buf) { + functions->SetShortArrayRegion(this,array,start,len,buf); + } + void SetIntArrayRegion(jintArray array, jsize start, jsize len, + const jint *buf) { + functions->SetIntArrayRegion(this,array,start,len,buf); + } + void SetLongArrayRegion(jlongArray array, jsize start, jsize len, + const jlong *buf) { + functions->SetLongArrayRegion(this,array,start,len,buf); + } + void SetFloatArrayRegion(jfloatArray array, jsize start, jsize len, + const jfloat *buf) { + functions->SetFloatArrayRegion(this,array,start,len,buf); + } + void SetDoubleArrayRegion(jdoubleArray array, jsize start, jsize len, + const jdouble *buf) { + functions->SetDoubleArrayRegion(this,array,start,len,buf); + } + + jint RegisterNatives(jclass clazz, const JNINativeMethod *methods, + jint nMethods) { + return functions->RegisterNatives(this,clazz,methods,nMethods); + } + jint UnregisterNatives(jclass clazz) { + return functions->UnregisterNatives(this,clazz); + } + + jint MonitorEnter(jobject obj) { + return functions->MonitorEnter(this,obj); + } + jint MonitorExit(jobject obj) { + return functions->MonitorExit(this,obj); + } + + jint GetJavaVM(JavaVM **vm) { + return functions->GetJavaVM(this,vm); + } + + void GetStringRegion(jstring str, jsize start, jsize len, jchar *buf) { + functions->GetStringRegion(this,str,start,len,buf); + } + void GetStringUTFRegion(jstring str, jsize start, jsize len, char *buf) { + functions->GetStringUTFRegion(this,str,start,len,buf); + } + + void * GetPrimitiveArrayCritical(jarray array, jboolean *isCopy) { + return functions->GetPrimitiveArrayCritical(this,array,isCopy); + } + void ReleasePrimitiveArrayCritical(jarray array, void *carray, jint mode) { + functions->ReleasePrimitiveArrayCritical(this,array,carray,mode); + } + + const jchar * GetStringCritical(jstring string, jboolean *isCopy) { + return functions->GetStringCritical(this,string,isCopy); + } + void ReleaseStringCritical(jstring string, const jchar *cstring) { + functions->ReleaseStringCritical(this,string,cstring); + } + + jweak NewWeakGlobalRef(jobject obj) { + return functions->NewWeakGlobalRef(this,obj); + } + void DeleteWeakGlobalRef(jweak ref) { + functions->DeleteWeakGlobalRef(this,ref); + } + + jboolean ExceptionCheck() { + return functions->ExceptionCheck(this); + } + + jobject NewDirectByteBuffer(void* address, jlong capacity) { + return functions->NewDirectByteBuffer(this, address, capacity); + } + void* GetDirectBufferAddress(jobject buf) { + return functions->GetDirectBufferAddress(this, buf); + } + jlong GetDirectBufferCapacity(jobject buf) { + return functions->GetDirectBufferCapacity(this, buf); + } + jobjectRefType GetObjectRefType(jobject obj) { + return functions->GetObjectRefType(this, obj); + } + + /* Module Features */ + + jobject GetModule(jclass clazz) { + return functions->GetModule(this, clazz); + } + + /* Virtual threads */ + + jboolean IsVirtualThread(jobject obj) { + return functions->IsVirtualThread(this, obj); + } + +#endif /* __cplusplus */ +}; + +/* + * optionString may be any option accepted by the JVM, or one of the + * following: + * + * -D= Set a system property. + * -verbose[:class|gc|jni] Enable verbose output, comma-separated. E.g. + * "-verbose:class" or "-verbose:gc,class" + * Standard names include: gc, class, and jni. + * All nonstandard (VM-specific) names must begin + * with "X". + * vfprintf extraInfo is a pointer to the vfprintf hook. + * exit extraInfo is a pointer to the exit hook. + * abort extraInfo is a pointer to the abort hook. + */ +typedef struct JavaVMOption { + char *optionString; + void *extraInfo; +} JavaVMOption; + +typedef struct JavaVMInitArgs { + jint version; + + jint nOptions; + JavaVMOption *options; + jboolean ignoreUnrecognized; +} JavaVMInitArgs; + +typedef struct JavaVMAttachArgs { + jint version; + + char *name; + jobject group; +} JavaVMAttachArgs; + +/* These will be VM-specific. */ + +#define JDK1_2 +#define JDK1_4 + +/* End VM-specific. */ + +struct JNIInvokeInterface_ { + void *reserved0; + void *reserved1; + void *reserved2; + + jint (JNICALL *DestroyJavaVM)(JavaVM *vm); + + jint (JNICALL *AttachCurrentThread)(JavaVM *vm, void **penv, void *args); + + jint (JNICALL *DetachCurrentThread)(JavaVM *vm); + + jint (JNICALL *GetEnv)(JavaVM *vm, void **penv, jint version); + + jint (JNICALL *AttachCurrentThreadAsDaemon)(JavaVM *vm, void **penv, void *args); +}; + +struct JavaVM_ { + const struct JNIInvokeInterface_ *functions; +#ifdef __cplusplus + + jint DestroyJavaVM() { + return functions->DestroyJavaVM(this); + } + jint AttachCurrentThread(void **penv, void *args) { + return functions->AttachCurrentThread(this, penv, args); + } + jint DetachCurrentThread() { + return functions->DetachCurrentThread(this); + } + + jint GetEnv(void **penv, jint version) { + return functions->GetEnv(this, penv, version); + } + jint AttachCurrentThreadAsDaemon(void **penv, void *args) { + return functions->AttachCurrentThreadAsDaemon(this, penv, args); + } +#endif +}; + +#ifdef _JNI_IMPLEMENTATION_ +#define _JNI_IMPORT_OR_EXPORT_ JNIEXPORT +#else +#define _JNI_IMPORT_OR_EXPORT_ JNIIMPORT +#endif +_JNI_IMPORT_OR_EXPORT_ jint JNICALL +JNI_GetDefaultJavaVMInitArgs(void *args); + +_JNI_IMPORT_OR_EXPORT_ jint JNICALL +JNI_CreateJavaVM(JavaVM **pvm, void **penv, void *args); + +_JNI_IMPORT_OR_EXPORT_ jint JNICALL +JNI_GetCreatedJavaVMs(JavaVM **, jsize, jsize *); + +/* Defined by native libraries. */ +JNIEXPORT jint JNICALL +JNI_OnLoad(JavaVM *vm, void *reserved); + +JNIEXPORT void JNICALL +JNI_OnUnload(JavaVM *vm, void *reserved); + +#define JNI_VERSION_1_1 0x00010001 +#define JNI_VERSION_1_2 0x00010002 +#define JNI_VERSION_1_4 0x00010004 +#define JNI_VERSION_1_6 0x00010006 +#define JNI_VERSION_1_8 0x00010008 +#define JNI_VERSION_9 0x00090000 +#define JNI_VERSION_10 0x000a0000 +#define JNI_VERSION_19 0x00130000 +#define JNI_VERSION_20 0x00140000 +#define JNI_VERSION_21 0x00150000 + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ + +#endif /* !_JAVASOFT_JNI_H_ */ diff --git a/.github/include/windows/jni_md.h b/.github/include/windows/jni_md.h new file mode 100644 index 00000000..6c8d6b9e --- /dev/null +++ b/.github/include/windows/jni_md.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 1996, 1998, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +#ifndef _JAVASOFT_JNI_MD_H_ +#define _JAVASOFT_JNI_MD_H_ + +#define JNIEXPORT __declspec(dllexport) +#define JNIIMPORT __declspec(dllimport) +#define JNICALL __stdcall + +// 'long' is always 32 bit on windows so this matches what jdk expects +typedef long jint; +typedef __int64 jlong; +typedef signed char jbyte; + +#endif /* !_JAVASOFT_JNI_MD_H_ */ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..a15f809d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,105 @@ +--- +name: Continuous Integration +on: + - pull_request + - workflow_dispatch +env: + MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf + MODEL_NAME: codellama-7b.Q2_K.gguf + RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf + RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf +jobs: + + build-and-test-linux: + name: ubuntu-latest + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: "11" + - name: Build libraries + run: | + mvn compile + .github/build.sh -DLLAMA_VERBOSE=ON + - name: Download text generation model + run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: List files in models directory + run: ls -l models/ + - name: Run tests + run: mvn test + - if: failure() + uses: actions/upload-artifact@v4 + with: + name: error-log-linux + path: ${{ github.workspace }}/hs_err_pid*.log + if-no-files-found: warn + + build-and-test-macos: + name: ${{ matrix.target.runner }} + runs-on: ${{ matrix.target.runner }} + strategy: + fail-fast: false + matrix: + target: + - runner: macos-13 + cmake: -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON + - runner: macos-14 + cmake: -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_VERBOSE=ON + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: "11" + - name: Build libraries + run: | + mvn compile + .github/build.sh ${{ matrix.target.cmake }} + - name: Download text generaton model model + run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: List files in models directory + run: ls -l models/ + - name: Run tests + run: mvn test + - if: failure() + uses: actions/upload-artifact@v4 + with: + name: error-log-macos + path: ${{ github.workspace }}/hs_err_pid*.log + if-no-files-found: warn + + build-and-test-windows: + name: windows-2019 + runs-on: windows-2019 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 + with: + distribution: 'zulu' + java-version: '11' + - name: Build libraries + run: | + mvn compile + .github\build.bat -DLLAMA_VERBOSE=ON + - name: Download model + run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME + - name: Download reranking model + run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME + - name: List files in models directory + run: ls -l models/ + - name: Run tests + run: mvn test + - if: failure() + uses: actions/upload-artifact@v4 + with: + name: windows-output + path: | + ${{ github.workspace }}\hs_err_pid*.log + ${{ github.workspace }}/src/main/resources/de/kherud/llama/**/* + if-no-files-found: warn diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index afe7f173..64032028 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -1,93 +1,217 @@ -name: Build JNI bindings +name: Release to Maven Central on: workflow_dispatch: + inputs: + build_only: + description: 'Whether to only build the project and skip releasing it (yes/NO)' + required: false + default: 'no' release: - types: [created] + types: [ created ] +env: + MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf" + MODEL_NAME: "codellama-7b.Q2_K.gguf" + RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf" + RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf" jobs: - build: - name: Build ${{ matrix.target.name }}-${{ matrix.target.arch }} - runs-on: ${{ matrix.target.image }} + +# todo: doesn't work with the newest llama.cpp version +# build-linux-cuda: +# name: Build Linux x86-64 CUDA12 +# runs-on: ubuntu-latest +# steps: +# - uses: actions/checkout@v4 +# - name: Build libraries +# shell: bash +# run: | +# .github/dockcross/dockcross-manylinux_2_28-x64 .github/build_cuda_linux.sh "-DOS_NAME=Linux -DOS_ARCH=x86_64" +# - name: Upload artifacts +# uses: actions/upload-artifact@v4 +# with: +# name: linux-libraries-cuda +# path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/ + + build-linux-docker: + name: Build ${{ matrix.target.os }}-${{ matrix.target.arch }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: target: - { - name: Linux, + os: Linux, arch: x86_64, - image: ubuntu-latest, - cmake: "-DCMAKE_CXX_FLAGS='-march=x86-64'" + image: dockcross-manylinux2014-x64, } - # - { - # name: Linux, - # arch: aarch64, - # image: ubuntu-latest, - # cmake: "-DCMAKE_CXX_FLAGS='-march=armv8-a'" - # } - { - name: Mac, - arch: x86_64, - image: macos-latest, - cmake: "-DCMAKE_OSX_ARCHITECTURES=x86_64" + os: Linux, + arch: aarch64, + image: dockcross-linux-arm64-lts, } - { - name: Mac, + os: Linux-Android, arch: aarch64, - image: macos-latest, - cmake: "-DCMAKE_OSX_ARCHITECTURES=arm64" + image: dockcross-android-arm64, + } + steps: + - uses: actions/checkout@v4 + - name: Build libraries + shell: bash + run: | + .github/dockcross/${{ matrix.target.image }} .github/build.sh "-DOS_NAME=${{ matrix.target.os }} -DOS_ARCH=${{ matrix.target.arch }}" + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.target.os }}-${{ matrix.target.arch }}-libraries + path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ + + + build-macos-native: + name: Build ${{ matrix.target.runner }} + runs-on: ${{ matrix.target.runner }} + strategy: + fail-fast: false + matrix: + target: + - { + runner: macos-13, + cmake: '-DLLAMA_METAL=OFF' } - { - name: Windows, - arch: x86_64, - image: windows-latest, - cmake: "-DCMAKE_GENERATOR_PLATFORM=x64" + runner: macos-14, + cmake: '-DLLAMA_METAL_EMBED_LIBRARY=ON' } - # - { - # name: Windows, - # arch: aarch64, - # image: windows-latest, - # cmake: "-DCMAKE_GENERATOR_PLATFORM=ARM64" - # } steps: - - name: Checkout Repository - uses: actions/checkout@v4 + - uses: actions/checkout@v4 + - name: Build libraries + shell: bash + run: | + mvn compile + .github/build.sh ${{ matrix.target.cmake }} + - name: Upload artifacts + uses: actions/upload-artifact@v4 with: - submodules: recursive - ref: "jni" - - name: CMake (Windows) - if: ${{ matrix.target.name == 'Windows' }} + name: ${{ matrix.target.runner }}-libraries + path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ + + + build-win-native: + name: Build ${{ matrix.target.os }}-${{ matrix.target.arch }} + runs-on: windows-2019 + strategy: + fail-fast: false + matrix: + target: + - { + os: Windows, + arch: x86_64, + cmake: '-G "Visual Studio 16 2019" -A "x64"' + } + - { + os: Windows, + arch: x86, + cmake: '-G "Visual Studio 16 2019" -A "Win32"' + } +# MSVC aarch64 builds no longer work with llama.cpp (requires clang instead) +# - { +# os: Windows, +# arch: aarch64, +# cmake: '-G "Visual Studio 16 2019" -A "ARM64"' +# } +# - { +# os: Windows, +# arch: arm, +# cmake: '-G "Visual Studio 16 2019" -A "ARM"' +# } + steps: + - uses: actions/checkout@v4 + - name: Build libraries shell: cmd run: | - cd scripts && .\build.bat ${{ matrix.target.cmake }} -DOS_NAME=${{ matrix.target.name }} -DOS_ARCH=${{ matrix.target.arch }} - - name: CMake (Unix) - if: ${{ matrix.target.name != 'Windows' }} - shell: bash - run: | - cd scripts && ./build.sh ${{ matrix.target.cmake }} -DOS_NAME=${{ matrix.target.name }} -DOS_ARCH=${{ matrix.target.arch }} - - name: Upload Unix Artifact - if: ${{ matrix.target.name == 'Windows' }} - uses: actions/upload-artifact@v3 + .github\build.bat ${{ matrix.target.cmake }} -DOS_NAME=${{ matrix.target.os }} -DOS_ARCH=${{ matrix.target.arch }} + - name: Upload artifacts + uses: actions/upload-artifact@v4 with: - name: artifacts - path: ${{ github.workspace }}\src\main\resources\de\kherud\llama\ - - name: Upload Windows Artifact - if: ${{ matrix.target.name != 'Windows' }} - uses: actions/upload-artifact@v3 + name: ${{ matrix.target.os }}-${{ matrix.target.arch }}-libraries + path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ + + + test-linux: + name: Test Linux + needs: build-linux-docker + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 with: - name: artifacts + name: Linux-x86_64-libraries path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ + - name: Download text generation model + run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - uses: actions/setup-java@v4 + with: + distribution: 'zulu' + java-version: '11' + - name: Run tests + run: mvn test + +# test-macos: +# name: Test Mac +# needs: build-macos-native +# runs-on: macos-14 +# steps: +# - uses: actions/checkout@v4 +# - uses: actions/download-artifact@v4 +# with: +# name: macos14-libraries +# path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ +# - name: Download model +# run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} +# - uses: actions/setup-java@v4 +# with: +# distribution: 'zulu' +# java-version: '11' +# - name: Run tests +# run: mvn test + + +# test-windows: +# name: Test Windows +# needs: build-win-native +# runs-on: windows-latest +# steps: +# - uses: actions/checkout@v4 +# - uses: actions/download-artifact@v4 +# with: +# name: Windows-x86_64-libraries +# path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ +# - name: Download model +# run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME +# - uses: actions/setup-java@v4 +# with: +# distribution: 'zulu' +# java-version: '11' +# - name: Run tests +# run: mvn test + + publish: - needs: [build] + if: ${{ github.event_name != 'workflow_dispatch' || github.event.inputs.build_only == 'no' }} + needs: [ test-linux,build-macos-native,build-win-native ] #,build-linux-cuda runs-on: ubuntu-latest steps: - - name: Checkout Repository - uses: actions/checkout@v4 - with: - ref: "jni" - - name: Download Artifacts - uses: actions/download-artifact@v3 + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 with: - name: artifacts + pattern: "*-libraries" + merge-multiple: true path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ +# - uses: actions/download-artifact@v4 +# with: +# name: linux-libraries-cuda +# path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/ - name: Set up Maven Central Repository uses: actions/setup-java@v3 with: @@ -99,7 +223,7 @@ jobs: gpg-private-key: ${{ secrets.GPG_SIGNING_KEY }} gpg-passphrase: MAVEN_GPG_PASSPHRASE - name: Publish package - run: mvn --batch-mode -P release deploy + run: mvn --batch-mode -P release -Dmaven.test.skip=true deploy env: MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }} diff --git a/.gitignore b/.gitignore index a6622338..274f8687 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .idea target build +cmake-build-* .DS_Store .directory .vscode @@ -30,9 +31,15 @@ build hs_err_pid* replay_pid* +models/*.gguf src/main/cpp/de_kherud_llama_*.h +src/main/resources_cuda_linux/ src/main/resources/**/*.so src/main/resources/**/*.dylib src/main/resources/**/*.dll src/main/resources/**/*.metal src/test/resources/**/*.gbnf + +**/*.etag +**/*.lastModified +src/main/cpp/llama.cpp/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index fd82610c..00000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "src/main/cpp/llama.cpp"] - path = src/main/cpp/llama.cpp - url = https://github.com/ggerganov/llama.cpp.git diff --git a/CMakeLists.txt b/CMakeLists.txt index b34c98ea..96c62950 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,73 +1,121 @@ -cmake_minimum_required(VERSION 3.12) +cmake_minimum_required(VERSION 3.14) project(jllama CXX) -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_CXX_STANDARD_REQUIRED true) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) +include(FetchContent) + set(BUILD_SHARED_LIBS ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(BUILD_SHARED_LIBS OFF) -find_package(Java REQUIRED) -find_program(JAVA_EXECUTABLE NAMES java) +option(LLAMA_VERBOSE "llama: verbose output" OFF) -# find "jni.h" include directory -find_path(JNI_INCLUDE_DIR NAMES jni.h HINTS ENV JAVA_HOME PATH_SUFFIXES include) -if(NOT JNI_INCLUDE_DIR) - message(FATAL_ERROR "Could not find jni.h") -endif() +#################### json #################### -# find "jni_md.h" include directory if not set -file(GLOB_RECURSE JNI_MD_PATHS RELATIVE "${JNI_INCLUDE_DIR}" "${JNI_INCLUDE_DIR}/**/jni_md.h") -if(NOT JNI_MD_PATHS) - message(FATAL_ERROR "Could not find jni_md.h") -endif() -foreach(PATH IN LISTS JNI_MD_PATHS) - get_filename_component(DIR ${PATH} DIRECTORY) - list(APPEND JNI_MD_INCLUDE_DIRS "${JNI_INCLUDE_DIR}/${DIR}") -endforeach() +FetchContent_Declare( + json + GIT_REPOSITORY https://github.com/nlohmann/json + GIT_TAG v3.11.3 +) +FetchContent_MakeAvailable(json) + +#################### llama.cpp #################### + +set(LLAMA_BUILD_COMMON ON) +FetchContent_Declare( + llama.cpp + GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git + GIT_TAG b4916 +) +FetchContent_MakeAvailable(llama.cpp) + +#################### jllama #################### # find which OS we build for if not set (make sure to run mvn compile first) if(NOT DEFINED OS_NAME) - execute_process( + find_package(Java REQUIRED) + find_program(JAVA_EXECUTABLE NAMES java) + execute_process( COMMAND ${JAVA_EXECUTABLE} -cp ${CMAKE_SOURCE_DIR}/target/classes de.kherud.llama.OSInfo --os OUTPUT_VARIABLE OS_NAME OUTPUT_STRIP_TRAILING_WHITESPACE ) endif() +if(NOT OS_NAME) + message(FATAL_ERROR "Could not determine OS name") +endif() # find which architecture we build for if not set (make sure to run mvn compile first) if(NOT DEFINED OS_ARCH) + find_package(Java REQUIRED) + find_program(JAVA_EXECUTABLE NAMES java) execute_process( COMMAND ${JAVA_EXECUTABLE} -cp ${CMAKE_SOURCE_DIR}/target/classes de.kherud.llama.OSInfo --arch OUTPUT_VARIABLE OS_ARCH OUTPUT_STRIP_TRAILING_WHITESPACE ) endif() +if(NOT OS_ARCH) + message(FATAL_ERROR "Could not determine CPU architecture") +endif() -include_directories( - ${JNI_INCLUDE_DIR} - ${JNI_MD_INCLUDE_DIRS} - src/main/cpp - src/main/cpp/llama.cpp - src/main/cpp/llama.cpp/common -) +if(GGML_CUDA) + set(JLLAMA_DIR ${CMAKE_SOURCE_DIR}/src/main/resources_linux_cuda/de/kherud/llama/${OS_NAME}/${OS_ARCH}) + message(STATUS "GPU (CUDA Linux) build - Installing files to ${JLLAMA_DIR}") +else() + set(JLLAMA_DIR ${CMAKE_SOURCE_DIR}/src/main/resources/de/kherud/llama/${OS_NAME}/${OS_ARCH}) + message(STATUS "CPU build - Installing files to ${JLLAMA_DIR}") +endif() -add_subdirectory( - src/main/cpp/llama.cpp -) +# include jni.h and jni_md.h +if(NOT DEFINED JNI_INCLUDE_DIRS) + if(OS_NAME MATCHES "^Linux" OR OS_NAME STREQUAL "Mac" OR OS_NAME STREQUAL "Darwin") + set(JNI_INCLUDE_DIRS .github/include/unix) + elseif(OS_NAME STREQUAL "Windows") + set(JNI_INCLUDE_DIRS .github/include/windows) + # if we don't have provided headers, try to find them via Java + else() + find_package(Java REQUIRED) + find_program(JAVA_EXECUTABLE NAMES java) + + find_path(JNI_INCLUDE_DIRS NAMES jni.h HINTS ENV JAVA_HOME PATH_SUFFIXES include) -add_library(jllama SHARED src/main/cpp/jllama.cpp) + # find "jni_md.h" include directory if not set + file(GLOB_RECURSE JNI_MD_PATHS RELATIVE "${JNI_INCLUDE_DIRS}" "${JNI_INCLUDE_DIRS}/**/jni_md.h") + foreach(PATH IN LISTS JNI_MD_PATHS) + get_filename_component(DIR ${PATH} DIRECTORY) + list(APPEND JNI_INCLUDE_DIRS "${JNI_INCLUDE_DIRS}/${DIR}") + endforeach() + endif() +endif() +if(NOT JNI_INCLUDE_DIRS) + message(FATAL_ERROR "Could not determine JNI include directories") +endif() + +add_library(jllama SHARED src/main/cpp/jllama.cpp src/main/cpp/server.hpp src/main/cpp/utils.hpp) -target_link_libraries(jllama PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +set_target_properties(jllama PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(jllama PRIVATE src/main/cpp ${JNI_INCLUDE_DIRS}) +target_link_libraries(jllama PRIVATE common llama nlohmann_json) +target_compile_features(jllama PRIVATE cxx_std_11) + +target_compile_definitions(jllama PRIVATE + SERVER_VERBOSE=$ +) if(OS_NAME STREQUAL "Windows") - set_target_properties(jllama llama PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_RELEASE "${CMAKE_SOURCE_DIR}/src/main/resources/de/kherud/llama/${OS_NAME}/${OS_ARCH}" + set_target_properties(jllama llama ggml PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_DEBUG ${JLLAMA_DIR} + RUNTIME_OUTPUT_DIRECTORY_RELEASE ${JLLAMA_DIR} + RUNTIME_OUTPUT_DIRECTORY_RELWITHDEBINFO ${JLLAMA_DIR} ) else() - set_target_properties(jllama llama PROPERTIES - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/src/main/resources/de/kherud/llama/${OS_NAME}/${OS_ARCH}" + set_target_properties(jllama llama ggml PROPERTIES + LIBRARY_OUTPUT_DIRECTORY ${JLLAMA_DIR} ) endif() -message(STATUS "Installing files to ${CMAKE_SOURCE_DIR}/src/main/resources/de/kherud/llama/${OS_NAME}/${OS_ARCH}") +if (LLAMA_METAL AND NOT LLAMA_METAL_EMBED_LIBRARY) + # copy ggml-common.h and ggml-metal.metal to bin directory + configure_file(${llama.cpp_SOURCE_DIR}/ggml-metal.metal ${JLLAMA_DIR}/ggml-metal.metal COPYONLY) +endif() diff --git a/README.md b/README.md index 31141b74..1bc278b1 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,24 @@ ![Java 11+](https://img.shields.io/badge/Java-11%2B-informational) -![llama.cpp b1170](https://img.shields.io/badge/llama.cpp-%23b1204-informational) +![llama.cpp b4916](https://img.shields.io/badge/llama.cpp-%23b4916-informational) # Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) -The main goal of llama.cpp is to run the LLaMA model using 4-bit integer quantization on a MacBook. -This repository provides Java bindings for the C++ library. +Inference of Meta's LLaMA model (and others) in pure C/C++. **You are welcome to contribute** +1. [Quick Start](#quick-start) + 1.1 [No Setup required](#no-setup-required) + 1.2 [Setup required](#setup-required) +2. [Documentation](#documentation) + 2.1 [Example](#example) + 2.2 [Inference](#inference) + 2.3 [Infilling](#infilling) +3. [Android](#importing-in-android) + +> [!NOTE] +> Now with support for Gemma 3 + ## Quick Start Access this library via Maven: @@ -16,123 +27,109 @@ Access this library via Maven: de.kherud llama - - 2.0.0 + 4.1.0 ``` -Here is a short example: - -```java -public class Example { - - public static void main(String... args) throws IOException { - LlamaModel.setLogger((level, message) -> System.out.print(message)); - ModelParameters modelParams = new ModelParameters.Builder() - .setNGpuLayers(43) - .build(); - InferenceParameters inferParams = new InferenceParameters.Builder() - .setTemperature(0.7f) - .setPenalizeNl(true) - .setMirostat(InferenceParameters.MiroStat.V2) - .setAntiPrompt(new String[]{"\n"}) - .build(); - - String modelPath = "/run/media/konstantin/Seagate/models/llama2/llama-2-13b-chat/ggml-model-q4_0.gguf"; - String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + - "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + - "requests immediately and with precision.\n"; - BufferedReader reader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8)); - try (LlamaModel model = new LlamaModel(modelPath, modelParams)) { - System.out.print(system); - String prompt = system; - while (true) { - prompt += "\nUser: "; - System.out.print("\nUser: "); - String input = reader.readLine(); - prompt += input; - System.out.print("Llama: "); - prompt += "\nLlama: "; - for (String output : model.generate(prompt, inferParams)) { - System.out.print(output); - prompt += output; - } - } - } - } -} -``` - -Also have a look at the [examples](src/test/java/examples). +There are multiple [examples](src/test/java/examples). ### No Setup required We support CPU inference for the following platforms out of the box: -- Linux x86-64 -- MacOS x86-64, arm64 (M1) -- Windows x86-64 +- Linux x86-64, aarch64 +- MacOS x86-64, aarch64 (M-series) +- Windows x86-64, x64 If any of these match your platform, you can include the Maven dependency and get started. ### Setup required If none of the above listed platforms matches yours, currently you have to compile the library yourself (also if you -want GPU acceleration, see below). More support is planned soon. +want GPU acceleration). + +This consists of two steps: 1) Compiling the libraries and 2) putting them in the right location. + +##### Library Compilation -Run in the directory of this repository (java-llama.cpp): +First, have a look at [llama.cpp](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) to know which build arguments to use (e.g. for CUDA support). +Any build option of llama.cpp works equivalently for this project. +You then have to run the following commands in the directory of this repository (java-llama.cpp): ```shell -mkdir build -cd build -cmake .. -DBUILD_SHARED_LIBS=ON # add any other arguments for your backend -cmake --build . --config Release +mvn compile # don't forget this line +cmake -B build # add any other arguments for your backend, e.g. -DGGML_CUDA=ON +cmake --build build --config Release ``` -All required files will be put in a resources directory matching your platform, which will appear in the cmake output. For example something like: +> [!TIP] +> Use `-DLLAMA_CURL=ON` to download models via Java code using `ModelParameters#setModelUrl(String)`. + +All compiled libraries will be put in a resources directory matching your platform, which will appear in the cmake output. For example something like: ```shell -- Installing files to /java-llama.cpp/src/main/resources/de/kherud/llama/Linux/x86_64 ``` -This includes: +#### Library Location -- Linux: `libllama.so`, `libjllama.so` -- MacOS: `libllama.dylib`, `libjllama.dylib`, `ggml-metal.metal` -- Windows: `llama.dll`, `jllama.dll` +This project has to load a single shared library `jllama`. -If you then compile your own JAR from this directory, you are ready to go. Otherwise, if you still want to use the library -as a Maven dependency, see below how to set the necessary paths in order for Java to find your compiled libraries. +Note, that the file name varies between operating systems, e.g., `jllama.dll` on Windows, `jllama.so` on Linux, and `jllama.dylib` on macOS. -### Custom llama.cpp Setup (GPU acceleration) +The application will search in the following order in the following locations: -This repository provides default support for CPU based inference. You can compile `llama.cpp` any way you want, however. -In order to use your self-compiled library, set either of the [JVM options](https://www.jetbrains.com/help/idea/tuning-the-ide.html#configure-jvm-options): +- In **de.kherud.llama.lib.path**: Use this option if you want a custom location for your shared libraries, i.e., set VM option `-Dde.kherud.llama.lib.path=/path/to/directory`. +- In **java.library.path**: These are predefined locations for each OS, e.g., `/usr/java/packages/lib:/usr/lib64:/lib64:/lib:/usr/lib` on Linux. + You can find out the locations using `System.out.println(System.getProperty("java.library.path"))`. + Use this option if you want to install the shared libraries as system libraries. +- From the **JAR**: If any of the libraries weren't found yet, the application will try to use a prebuilt shared library. + This of course only works for the [supported platforms](#no-setup-required) . -- `de.kherud.llama.lib.path`, for example `-Dde.kherud.llama.lib.path=/directory/containing/lib` -- `java.library.path`, for example `-Djava.library.path=/directory/containing/lib` +## Documentation -This repository uses [`System#mapLibraryName`](https://docs.oracle.com/javase%2F7%2Fdocs%2Fapi%2F%2F/java/lang/System.html) to determine the name of the shared library for you platform. -If for any reason your library has a different name, you can set it with +### Example -- `de.kherud.llama.lib.name`, for example `-Dde.kherud.llama.lib.name=myname.so` +This is a short example on how to use this library: -For compiling `llama.cpp`, refer to the official [readme](https://github.com/ggerganov/llama.cpp#build) for details. -The library can be built with the `llama.cpp` project: +```java +public class Example { -```shell -mkdir build -cd build -cmake .. -DBUILD_SHARED_LIBS=ON # add any other arguments for your backend -cmake --build . --config Release + public static void main(String... args) throws IOException { + ModelParameters modelParams = new ModelParameters() + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf") + .setGpuLayers(43); + + String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + + "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + + "requests immediately and with precision.\n"; + BufferedReader reader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8)); + try (LlamaModel model = new LlamaModel(modelParams)) { + System.out.print(system); + String prompt = system; + while (true) { + prompt += "\nUser: "; + System.out.print("\nUser: "); + String input = reader.readLine(); + prompt += input; + System.out.print("Llama: "); + prompt += "\nLlama: "; + InferenceParameters inferParams = new InferenceParameters(prompt) + .setTemperature(0.7f) + .setPenalizeNl(true) + .setMiroStat(MiroStat.V2) + .setStopStrings("User:"); + for (LlamaOutput output : model.generate(inferParams)) { + System.out.print(output); + prompt += output; + } + } + } + } +} ``` -Look for the shared library in `build`. - -> [!IMPORTANT] -> If you are running MacOS with Metal, you have to put the file `ggml-metal.metal` from `build/bin` in the same directory as the shared library. - -## Documentation +Also have a look at the other [examples](src/test/java/examples). ### Inference @@ -141,13 +138,15 @@ model to your prompt in order to extend the context. If there is repeated conten cache this, to improve performance. ```java -try (LlamaModel model = new LlamaModel("/path/to/gguf-model")) { +ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf"); +InferenceParameters inferParams = new InferenceParameters("Tell me a joke."); +try (LlamaModel model = new LlamaModel(modelParams)) { // Stream a response and access more information about each output. - for (String output : model.generate("Tell me a joke.")) { + for (LlamaOutput output : model.generate(inferParams)) { System.out.print(output); } // Calculate a whole response before returning it. - String response = model.complete("Tell me another one"); + String response = model.complete(inferParams); // Returns the hidden representation of the context + prompt. float[] embedding = model.embed("Embed this"); } @@ -159,38 +158,101 @@ try (LlamaModel model = new LlamaModel("/path/to/gguf-model")) { > freed when the model is no longer needed. This isn't strictly required, but avoids memory leaks if you use different > models throughout the lifecycle of your application. +### Infilling + +You can simply set `InferenceParameters#setInputPrefix(String)` and `InferenceParameters#setInputSuffix(String)`. + ### Model/Inference Configuration There are two sets of parameters you can configure, `ModelParameters` and `InferenceParameters`. Both provide builder -classes to ease configuration. All non-specified options have sensible defaults. +classes to ease configuration. `ModelParameters` are once needed for loading a model, `InferenceParameters` are needed +for every inference task. All non-specified options have sensible defaults. ```java -ModelParameters modelParams = new ModelParameters.Builder() - .setLoraAdapter("/path/to/lora/adapter") - .setLoraBase("/path/to/lora/base") - .build(); -InferenceParameters inferParams = new InferenceParameters.Builder() - .setGrammar(new File("/path/to/grammar.gbnf")) - .setTemperature(0.8) - .build(); -LlamaModel model = new LlamaModel("/path/to/model.bin", modelParams); -model.generate(prompt, inferParams) +ModelParameters modelParams = new ModelParameters() + .setModel("/path/to/model.gguf") + .addLoraAdapter("/path/to/lora/adapter"); +String grammar = """ + root ::= (expr "=" term "\\n")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]"""; +InferenceParameters inferParams = new InferenceParameters("") + .setGrammar(grammar) + .setTemperature(0.8); +try (LlamaModel model = new LlamaModel(modelParams)) { + model.generate(inferParams); +} ``` ### Logging -Both Java and C++ logging can be configured via the static method `LlamaModel.setLogger`: +Per default, logs are written to stdout. +This can be intercepted via the static method `LlamaModel.setLogger(LogFormat, BiConsumer)`. +There is text- and JSON-based logging. The default is JSON. +Note, that text-based logging will include additional output of the GGML backend, while JSON-based logging +only provides request logs (while still writing GGML messages to stdout). +To only change the log format while still writing to stdout, `null` can be passed for the callback. +Logging can be disabled by passing an empty callback. ```java -// The method accepts a BiConsumer. -LlamaModel.setLogger((level, message) -> System.out.println(level.name() + ": " + message)); -// To completely silence any output, pass a no-op. -LlamaModel.setLogger((level, message) -> {}); - -// Similarly, a progress callback can be set (only the C++ side will call this). -// I think this is only used to report progress loading the model with a value of 0-1. -// It is thus state specific and can be done via the parameters. -new ModelParameters.Builder() - .setProgressCallback(progress -> System.out.println("progress: " + progress)) - .build(); +// Re-direct log messages however you like (e.g. to a logging library) +LlamaModel.setLogger(LogFormat.TEXT, (level, message) -> System.out.println(level.name() + ": " + message)); +// Log to stdout, but change the format +LlamaModel.setLogger(LogFormat.TEXT, null); +// Disable logging by passing a no-op +LlamaModel.setLogger(null, (level, message) -> {}); +``` + +## Importing in Android + +You can use this library in Android project. +1. Add java-llama.cpp as a submodule in your an droid `app` project directory +```shell +git submodule add https://github.com/kherud/java-llama.cpp +``` +2. Declare the library as a source in your build.gradle +```gradle +android { + val jllamaLib = file("java-llama.cpp") + + // Execute "mvn compile" if folder target/ doesn't exist at ./java-llama.cpp/ + if (!file("$jllamaLib/target").exists()) { + exec { + commandLine = listOf("mvn", "compile") + workingDir = file("java-llama.cpp/") + } + } + + ... + defaultConfig { + ... + externalNativeBuild { + cmake { + // Add an flags if needed + cppFlags += "" + arguments += "" + } + } + } + + // Declare c++ sources + externalNativeBuild { + cmake { + path = file("$jllamaLib/CMakeLists.txt") + version = "3.22.1" + } + } + + // Declare java sources + sourceSets { + named("main") { + // Add source directory for java-llama.cpp + java.srcDir("$jllamaLib/src/main/java") + } + } +} +``` +3. Exclude `de.kherud.llama` in proguard-rules.pro +```proguard +keep class de.kherud.llama.** { *; } ``` diff --git a/models/README.md b/models/README.md new file mode 100644 index 00000000..2481356b --- /dev/null +++ b/models/README.md @@ -0,0 +1,3 @@ +# Local Model Directory +This directory contains models which will be automatically downloaded +for use in java-llama.cpp's unit tests. diff --git a/pom.xml b/pom.xml index ab2a6562..67b366ee 100644 --- a/pom.xml +++ b/pom.xml @@ -1,14 +1,16 @@ - 4.0.0 de.kherud llama - 2.0.0 + 4.2.0 jar ${project.groupId}:${project.artifactId} - Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++. + Java Bindings for llama.cpp - A Port of Facebook's LLaMA model + in C/C++. https://github.com/kherud/java-llama.cpp @@ -39,13 +41,14 @@ ossrh - https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ + + https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 5.13.0 - 4.13.1 + 4.13.2 UTF-8 @@ -59,7 +62,7 @@ org.jetbrains annotations - 24.0.1 + 24.1.0 compile @@ -69,14 +72,55 @@ org.apache.maven.plugins maven-compiler-plugin - 3.11.0 - - - -h - src/main/cpp - - + 3.13.0 + + + + gpu + compile + + compile + + + + -h + src/main/cpp + + + ${project.build.outputDirectory}_cuda + + + + + maven-resources-plugin + 3.3.1 + + + + copy-resources + process-classes + + copy-resources + + + + ${project.build.outputDirectory}_cuda + + + + ${basedir}/src/main/resources_linux_cuda/ + + **/*.* + + + + + + + + @@ -136,6 +180,27 @@ + + org.apache.maven.plugins + maven-jar-plugin + 3.4.2 + + + + cuda + package + + jar + + + cuda12-linux-x86-64 + + ${project.build.outputDirectory}_cuda + + + + diff --git a/scripts/build.bat b/scripts/build.bat deleted file mode 100755 index 277fd1de..00000000 --- a/scripts/build.bat +++ /dev/null @@ -1,12 +0,0 @@ -@echo off - -pushd .. -if not exist "build" ( - mkdir build -) -cd build -cmake .. %* -cmake --build . --config Release -popd - -if errorlevel 1 exit /b %ERRORLEVEL% \ No newline at end of file diff --git a/scripts/build.sh b/scripts/build.sh deleted file mode 100755 index 2dbc3f88..00000000 --- a/scripts/build.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -pushd .. -mkdir -p build -cd build -cmake .. $@ || (popd && exit 1) -cmake --build . --config Release || (popd && exit 1) -popd diff --git a/scripts/jni-signature b/scripts/jni-signature deleted file mode 100755 index 8fc1886b..00000000 --- a/scripts/jni-signature +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -if [ -z "$1" ]; then - echo "Usage: $0 target/path/to/YourClass.class" - exit 1 -fi - -CLASS_NAME=$(basename "$1" .class) -CLASS_PATH=$(dirname "$1") - -echo $CLASS_NAME -echo $CLASS_PATH - -javap -s -p -cp "$CLASS_PATH" "$CLASS_NAME" diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index deb49619..11c80ae0 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -1,1340 +1,853 @@ -#include "llama.h" #include "jllama.h" -#include "common.h" -#include "grammar-parser.h" -#include +#include "arg.h" +#include "json-schema-to-grammar.h" +#include "llama.h" +#include "log.h" +#include "nlohmann/json.hpp" +#include "server.hpp" + +#include #include -#include -#include +#include + +// We store some references to Java classes and their fields/methods here to speed up things for later and to fail +// early on if anything can't be found. This happens when the JVM loads the shared library (see `JNI_OnLoad`). +// The references remain valid throughout the whole life of the shared library, on `JNI_OnUnload` they are released. + +namespace { +JavaVM *g_vm = nullptr; // classes -static jclass c_llama_model = 0; -static jclass c_llama_iterator = 0; -static jclass c_model_params = 0; -static jclass c_infer_params = 0; -static jclass c_standard_charsets = 0; -static jclass c_string = 0; -static jclass c_map = 0; -static jclass c_set = 0; -static jclass c_entry = 0; -static jclass c_iterator = 0; -static jclass c_integer = 0; -static jclass c_float = 0; -static jclass c_log_level = 0; -static jclass c_biconsumer = 0; -static jclass c_llama_error = 0; -static jclass c_error_oom = 0; +jclass c_llama_model = nullptr; +jclass c_llama_iterator = nullptr; +jclass c_standard_charsets = nullptr; +jclass c_output = nullptr; +jclass c_string = nullptr; +jclass c_hash_map = nullptr; +jclass c_map = nullptr; +jclass c_set = nullptr; +jclass c_entry = nullptr; +jclass c_iterator = nullptr; +jclass c_integer = nullptr; +jclass c_float = nullptr; +jclass c_biconsumer = nullptr; +jclass c_llama_error = nullptr; +jclass c_log_level = nullptr; +jclass c_log_format = nullptr; +jclass c_error_oom = nullptr; + +// constructors +jmethodID cc_output = nullptr; +jmethodID cc_hash_map = nullptr; +jmethodID cc_integer = nullptr; +jmethodID cc_float = nullptr; // methods -static jmethodID m_get_bytes = 0; -static jmethodID m_entry_set = 0; -static jmethodID m_set_iterator = 0; -static jmethodID m_iterator_has_next = 0; -static jmethodID m_iterator_next = 0; -static jmethodID m_entry_key = 0; -static jmethodID m_entry_value = 0; -static jmethodID m_int_value = 0; -static jmethodID m_float_value = 0; -static jmethodID m_biconsumer_accept = 0; +jmethodID m_get_bytes = nullptr; +jmethodID m_entry_set = nullptr; +jmethodID m_set_iterator = nullptr; +jmethodID m_iterator_has_next = nullptr; +jmethodID m_iterator_next = nullptr; +jmethodID m_entry_key = nullptr; +jmethodID m_entry_value = nullptr; +jmethodID m_map_put = nullptr; +jmethodID m_int_value = nullptr; +jmethodID m_float_value = nullptr; +jmethodID m_biconsumer_accept = nullptr; // fields -static jfieldID f_model_pointer = 0; -// iterator -static jfieldID f_iter_has_next = 0; -static jfieldID f_iter_n_generated = 0; -static jfieldID f_iter_token_index = 0; -// inference parameters -static jfieldID f_n_predict = 0; -static jfieldID f_n_keep = 0; -static jfieldID f_n_probs = 0; -static jfieldID f_logit_bias = 0; -static jfieldID f_top_k = 0; -static jfieldID f_top_p = 0; -static jfieldID f_tfs_z = 0; -static jfieldID f_typical_p = 0; -static jfieldID f_temperature = 0; -static jfieldID f_repeat_penalty = 0; -static jfieldID f_repeat_last_n = 0; -static jfieldID f_frequency_penalty = 0; -static jfieldID f_presence_penalty = 0; -static jfieldID f_penalize_nl = 0; -static jfieldID f_ignore_eos = 0; -static jfieldID f_mirostat = 0; -static jfieldID f_mirostat_tau = 0; -static jfieldID f_mirostat_eta = 0; -static jfieldID f_beam_search = 0; -static jfieldID f_n_beams = 0; -static jfieldID f_grammar = 0; -static jfieldID f_antiprompt = 0; -static jfieldID f_infer_seed = 0; -// model parameters -static jfieldID f_n_threads = 0; -static jfieldID f_model_seed = 0; -static jfieldID f_n_ctx = 0; -static jfieldID f_n_batch = 0; -static jfieldID f_n_gpu_layers = 0; -static jfieldID f_main_gpu = 0; -static jfieldID f_tensor_split = 0; -static jfieldID f_rope_freq_base = 0; -static jfieldID f_rope_freq_scale = 0; -static jfieldID f_low_vram = 0; -static jfieldID f_mul_mat_q = 0; -static jfieldID f_f16_kv = 0; -static jfieldID f_logits_all = 0; -static jfieldID f_vocab_only = 0; -static jfieldID f_use_mmap = 0; -static jfieldID f_use_mlock = 0; -static jfieldID f_embedding = 0; -static jfieldID f_lora_adapter = 0; -static jfieldID f_lora_base = 0; -static jfieldID f_hellaswag = 0; -static jfieldID f_hellaswag_tasks = 0; -static jfieldID f_memory_f16 = 0; -static jfieldID f_mem_test = 0; -static jfieldID f_numa = 0; -static jfieldID f_verbose_prompt = 0; -// log level -static jfieldID f_utf_8 = 0; -static jfieldID f_log_level_debug = 0; -static jfieldID f_log_level_info = 0; -static jfieldID f_log_level_warn = 0; -static jfieldID f_log_level_error = 0; +jfieldID f_model_pointer = nullptr; +jfieldID f_task_id = nullptr; +jfieldID f_utf_8 = nullptr; +jfieldID f_iter_has_next = nullptr; +jfieldID f_log_level_debug = nullptr; +jfieldID f_log_level_info = nullptr; +jfieldID f_log_level_warn = nullptr; +jfieldID f_log_level_error = nullptr; +jfieldID f_log_format_json = nullptr; +jfieldID f_log_format_text = nullptr; + // objects -static jobject o_utf_8 = 0; -static jobject o_log_level_debug = 0; -static jobject o_log_level_info = 0; -static jobject o_log_level_warn = 0; -static jobject o_log_level_error = 0; +jobject o_utf_8 = nullptr; +jobject o_log_level_debug = nullptr; +jobject o_log_level_info = nullptr; +jobject o_log_level_warn = nullptr; +jobject o_log_level_error = nullptr; +jobject o_log_format_json = nullptr; +jobject o_log_format_text = nullptr; +jobject o_log_callback = nullptr; + +/** + * Convert a Java string to a std::string + */ +std::string parse_jstring(JNIEnv *env, jstring java_string) { + auto *const string_bytes = (jbyteArray)env->CallObjectMethod(java_string, m_get_bytes, o_utf_8); + + auto length = (size_t)env->GetArrayLength(string_bytes); + jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr); + + std::string string = std::string((char *)byte_elements, length); + + env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT); + env->DeleteLocalRef(string_bytes); + + return string; +} + +char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length) { + auto *const result = static_cast(malloc(length * sizeof(char *))); + + if (result == nullptr) { + return nullptr; + } + + for (jsize i = 0; i < length; i++) { + auto *const javaString = static_cast(env->GetObjectArrayElement(string_array, i)); + const char *cString = env->GetStringUTFChars(javaString, nullptr); + result[i] = strdup(cString); + env->ReleaseStringUTFChars(javaString, cString); + } + + return result; +} + +void free_string_array(char **array, jsize length) { + if (array != nullptr) { + for (jsize i = 0; i < length; i++) { + free(array[i]); + } + free(array); + } +} -static JavaVM* g_vm = nullptr; -static jobject g_log_callback = nullptr; +/** + * Since Java expects utf16 but std::strings are utf8, we can't directly use `env->NewString` or `env-NewString`, + * but we directly send the bytes and do the conversion in Java. Unfortunately, there isn't a nice/standardized way to + * do this conversion in C++ + */ +jbyteArray parse_jbytes(JNIEnv *env, const std::string &string) { + jsize length = string.size(); // NOLINT(*-narrowing-conversions) + jbyteArray bytes = env->NewByteArray(length); + env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast(string.c_str())); + return bytes; +} -JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) -{ - JNIEnv *env = 0; +/** + * Map a llama.cpp log level to its Java enumeration option. + */ +jobject log_level_to_jobject(ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_ERROR: + return o_log_level_error; + case GGML_LOG_LEVEL_WARN: + return o_log_level_warn; + default: + case GGML_LOG_LEVEL_INFO: + return o_log_level_info; + case GGML_LOG_LEVEL_DEBUG: + return o_log_level_debug; + } +} - if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1)) - { +/** + * Returns the JNIEnv of the current thread. + */ +JNIEnv *get_jni_env() { + JNIEnv *env = nullptr; + if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + throw std::runtime_error("Thread is not attached to the JVM"); + } + return env; +} + +bool log_json; +std::function log_callback; + +/** + * Invoke the log callback if there is any. + */ +void log_callback_trampoline(ggml_log_level level, const char *text, void *user_data) { + if (log_callback != nullptr) { + log_callback(level, text, user_data); + } +} +} // namespace + +/** + * The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`). + * `JNI_OnLoad` must return the JNI version needed by the native library. + * In order to use any of the new JNI functions, a native library must export a `JNI_OnLoad` function that returns + * `JNI_VERSION_1_2`. If the native library does not export a JNI_OnLoad function, the VM assumes that the library + * only requires JNI version `JNI_VERSION_1_1`. If the VM does not recognize the version number returned by + `JNI_OnLoad`, the VM will unload the library and act as if the library was never loaded. + */ +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { + g_vm = vm; + JNIEnv *env = nullptr; + + if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1)) { goto error; } // find classes c_llama_model = env->FindClass("de/kherud/llama/LlamaModel"); - c_llama_iterator = env->FindClass("de/kherud/llama/LlamaModel$LlamaIterator"); - c_infer_params = env->FindClass("de/kherud/llama/InferenceParameters"); - c_model_params = env->FindClass("de/kherud/llama/ModelParameters"); + c_llama_iterator = env->FindClass("de/kherud/llama/LlamaIterator"); c_standard_charsets = env->FindClass("java/nio/charset/StandardCharsets"); + c_output = env->FindClass("de/kherud/llama/LlamaOutput"); c_string = env->FindClass("java/lang/String"); + c_hash_map = env->FindClass("java/util/HashMap"); c_map = env->FindClass("java/util/Map"); c_set = env->FindClass("java/util/Set"); c_entry = env->FindClass("java/util/Map$Entry"); c_iterator = env->FindClass("java/util/Iterator"); c_integer = env->FindClass("java/lang/Integer"); c_float = env->FindClass("java/lang/Float"); - c_log_level = env->FindClass("de/kherud/llama/LogLevel"); c_biconsumer = env->FindClass("java/util/function/BiConsumer"); c_llama_error = env->FindClass("de/kherud/llama/LlamaException"); + c_log_level = env->FindClass("de/kherud/llama/LogLevel"); + c_log_format = env->FindClass("de/kherud/llama/args/LogFormat"); c_error_oom = env->FindClass("java/lang/OutOfMemoryError"); - if (!(c_llama_model && c_llama_iterator && c_infer_params && c_model_params && c_standard_charsets && c_string && c_map && c_set && c_entry && c_iterator && c_integer && c_float && c_log_level && c_biconsumer && c_llama_error && c_error_oom)) - { + if (!(c_llama_model && c_llama_iterator && c_standard_charsets && c_output && c_string && c_hash_map && c_map && + c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level && + c_log_format && c_error_oom)) { goto error; } // create references - c_llama_model = (jclass)env->NewWeakGlobalRef(c_llama_model); - c_llama_iterator = (jclass)env->NewWeakGlobalRef(c_llama_iterator); - c_infer_params = (jclass)env->NewWeakGlobalRef(c_infer_params); - c_model_params = (jclass)env->NewWeakGlobalRef(c_model_params); - c_string = (jclass)env->NewWeakGlobalRef(c_string); - c_map = (jclass)env->NewWeakGlobalRef(c_map); - c_set = (jclass)env->NewWeakGlobalRef(c_set); - c_entry = (jclass)env->NewWeakGlobalRef(c_entry); - c_iterator = (jclass)env->NewWeakGlobalRef(c_iterator); - c_integer = (jclass)env->NewWeakGlobalRef(c_integer); - c_float = (jclass)env->NewWeakGlobalRef(c_float); - c_log_level = (jclass)env->NewWeakGlobalRef(c_log_level); - c_biconsumer = (jclass)env->NewWeakGlobalRef(c_biconsumer); - c_llama_error = (jclass)env->NewWeakGlobalRef(c_llama_error); - c_error_oom = (jclass)env->NewWeakGlobalRef(c_error_oom); + c_llama_model = (jclass)env->NewGlobalRef(c_llama_model); + c_llama_iterator = (jclass)env->NewGlobalRef(c_llama_iterator); + c_output = (jclass)env->NewGlobalRef(c_output); + c_string = (jclass)env->NewGlobalRef(c_string); + c_hash_map = (jclass)env->NewGlobalRef(c_hash_map); + c_map = (jclass)env->NewGlobalRef(c_map); + c_set = (jclass)env->NewGlobalRef(c_set); + c_entry = (jclass)env->NewGlobalRef(c_entry); + c_iterator = (jclass)env->NewGlobalRef(c_iterator); + c_integer = (jclass)env->NewGlobalRef(c_integer); + c_float = (jclass)env->NewGlobalRef(c_float); + c_biconsumer = (jclass)env->NewGlobalRef(c_biconsumer); + c_llama_error = (jclass)env->NewGlobalRef(c_llama_error); + c_log_level = (jclass)env->NewGlobalRef(c_log_level); + c_log_format = (jclass)env->NewGlobalRef(c_log_format); + c_error_oom = (jclass)env->NewGlobalRef(c_error_oom); + + // find constructors + cc_output = env->GetMethodID(c_output, "", "([BLjava/util/Map;Z)V"); + cc_hash_map = env->GetMethodID(c_hash_map, "", "()V"); + cc_integer = env->GetMethodID(c_integer, "", "(I)V"); + cc_float = env->GetMethodID(c_float, "", "(F)V"); + + if (!(cc_output && cc_hash_map && cc_integer && cc_float)) { + goto error; + } // find methods - m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/nio/charset/Charset;)[B"); + m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B"); m_entry_set = env->GetMethodID(c_map, "entrySet", "()Ljava/util/Set;"); m_set_iterator = env->GetMethodID(c_set, "iterator", "()Ljava/util/Iterator;"); m_iterator_has_next = env->GetMethodID(c_iterator, "hasNext", "()Z"); m_iterator_next = env->GetMethodID(c_iterator, "next", "()Ljava/lang/Object;"); m_entry_key = env->GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;"); m_entry_value = env->GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;"); + m_map_put = env->GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); m_int_value = env->GetMethodID(c_integer, "intValue", "()I"); m_float_value = env->GetMethodID(c_float, "floatValue", "()F"); m_biconsumer_accept = env->GetMethodID(c_biconsumer, "accept", "(Ljava/lang/Object;Ljava/lang/Object;)V"); - if (!(m_get_bytes && m_entry_set && m_set_iterator && m_iterator_has_next && m_iterator_next && m_entry_key && m_entry_value && m_int_value && m_float_value && m_biconsumer_accept)) - { + if (!(m_get_bytes && m_entry_set && m_set_iterator && m_iterator_has_next && m_iterator_next && m_entry_key && + m_entry_value && m_map_put && m_int_value && m_float_value && m_biconsumer_accept)) { goto error; } // find fields f_model_pointer = env->GetFieldID(c_llama_model, "ctx", "J"); + f_task_id = env->GetFieldID(c_llama_iterator, "taskId", "I"); + f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;"); f_iter_has_next = env->GetFieldID(c_llama_iterator, "hasNext", "Z"); - f_iter_n_generated = env->GetFieldID(c_llama_iterator, "generatedCount", "J"); - f_iter_token_index = env->GetFieldID(c_llama_iterator, "tokenIndex", "J"); - - f_n_predict = env->GetFieldID(c_infer_params, "nPredict", "I"); - f_n_keep = env->GetFieldID(c_infer_params, "nKeep", "I"); - f_n_probs = env->GetFieldID(c_infer_params, "nProbs", "I"); - f_logit_bias = env->GetFieldID(c_infer_params, "logitBias", "Ljava/util/Map;"); - f_top_k = env->GetFieldID(c_infer_params, "topK", "I"); - f_top_p = env->GetFieldID(c_infer_params, "topP", "F"); - f_tfs_z = env->GetFieldID(c_infer_params, "tfsZ", "F"); - f_typical_p = env->GetFieldID(c_infer_params, "typicalP", "F"); - f_temperature = env->GetFieldID(c_infer_params, "temperature", "F"); - f_repeat_penalty = env->GetFieldID(c_infer_params, "repeatPenalty", "F"); - f_repeat_last_n = env->GetFieldID(c_infer_params, "repeatLastN", "I"); - f_frequency_penalty = env->GetFieldID(c_infer_params, "frequencyPenalty", "F"); - f_presence_penalty = env->GetFieldID(c_infer_params, "presencePenalty", "F"); - f_penalize_nl = env->GetFieldID(c_infer_params, "penalizeNL", "Z"); - f_ignore_eos = env->GetFieldID(c_infer_params, "ignoreEos", "Z"); - f_mirostat = env->GetFieldID(c_infer_params, "mirostat", "I"); - f_mirostat_tau = env->GetFieldID(c_infer_params, "mirostatTau", "F"); - f_mirostat_eta = env->GetFieldID(c_infer_params, "mirostatEta", "F"); - f_beam_search = env->GetFieldID(c_infer_params, "beamSearch", "Z"); - f_n_beams = env->GetFieldID(c_infer_params, "nBeams", "I"); - f_grammar = env->GetFieldID(c_infer_params, "grammar", "Ljava/lang/String;"); - f_antiprompt = env->GetFieldID(c_infer_params, "antiprompt", "[Ljava/lang/String;"); - f_infer_seed = env->GetFieldID(c_infer_params, "seed", "I"); - - f_n_threads = env->GetFieldID(c_model_params, "nThreads", "I"); - f_model_seed = env->GetFieldID(c_model_params, "seed", "I"); - f_n_ctx = env->GetFieldID(c_model_params, "nCtx", "I"); - f_n_batch = env->GetFieldID(c_model_params, "nBatch", "I"); - f_n_gpu_layers = env->GetFieldID(c_model_params, "nGpuLayers", "I"); - f_main_gpu = env->GetFieldID(c_model_params, "mainGpu", "I"); - f_tensor_split = env->GetFieldID(c_model_params, "tensorSplit", "[F"); - f_rope_freq_base = env->GetFieldID(c_model_params, "ropeFreqBase", "F"); - f_rope_freq_scale = env->GetFieldID(c_model_params, "ropeFreqScale", "F"); - f_low_vram = env->GetFieldID(c_model_params, "lowVram", "Z"); - f_mul_mat_q = env->GetFieldID(c_model_params, "mulMatQ", "Z"); - f_f16_kv = env->GetFieldID(c_model_params, "f16Kv", "Z"); - f_logits_all = env->GetFieldID(c_model_params, "logitsAll", "Z"); - f_vocab_only = env->GetFieldID(c_model_params, "vocabOnly", "Z"); - f_use_mmap = env->GetFieldID(c_model_params, "useMmap", "Z"); - f_use_mlock = env->GetFieldID(c_model_params, "useMlock", "Z"); - f_embedding = env->GetFieldID(c_model_params, "embedding", "Z"); - f_lora_adapter = env->GetFieldID(c_model_params, "loraAdapter", "Ljava/lang/String;"); - f_lora_base = env->GetFieldID(c_model_params, "loraBase", "Ljava/lang/String;"); - f_hellaswag = env->GetFieldID(c_model_params, "hellaswag", "Z"); - f_hellaswag_tasks = env->GetFieldID(c_model_params, "hellaswagTasks", "S"); - f_memory_f16 = env->GetFieldID(c_model_params, "memoryF16", "Z"); - f_mem_test = env->GetFieldID(c_model_params, "memTest", "Z"); - f_numa = env->GetFieldID(c_model_params, "numa", "Z"); - f_verbose_prompt = env->GetFieldID(c_model_params, "verbosePrompt", "Z"); - - if (!(f_model_pointer && f_iter_has_next && f_iter_n_generated && f_iter_token_index)) - { - goto error; - } - if (!(f_n_predict && f_n_keep && f_n_probs && f_logit_bias && f_top_k && f_top_p && f_tfs_z && f_typical_p && f_temperature && f_repeat_penalty && f_repeat_last_n && f_frequency_penalty && f_presence_penalty && f_penalize_nl && f_ignore_eos && f_mirostat && f_mirostat_tau && f_mirostat_eta && f_beam_search && f_n_beams && f_grammar && f_antiprompt && f_infer_seed)) - { - goto error; - } - if (!(f_n_threads && f_model_seed && f_n_ctx && f_n_batch && f_n_gpu_layers && f_main_gpu && f_tensor_split && f_rope_freq_base && f_rope_freq_scale && f_low_vram && f_mul_mat_q && f_f16_kv && f_logits_all && f_vocab_only && f_use_mmap && f_use_mlock && f_embedding && f_lora_adapter && f_lora_base && f_hellaswag && f_hellaswag_tasks && f_memory_f16 && f_mem_test && f_numa && f_verbose_prompt)) - { + f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lde/kherud/llama/LogLevel;"); + f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lde/kherud/llama/LogLevel;"); + f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lde/kherud/llama/LogLevel;"); + f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lde/kherud/llama/LogLevel;"); + f_log_format_json = env->GetStaticFieldID(c_log_format, "JSON", "Lde/kherud/llama/args/LogFormat;"); + f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lde/kherud/llama/args/LogFormat;"); + + if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next && f_log_level_debug && f_log_level_info && + f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text)) { goto error; } - f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;"); - f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lde/kherud/llama/LogLevel;"); - f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lde/kherud/llama/LogLevel;"); - f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lde/kherud/llama/LogLevel;"); - f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lde/kherud/llama/LogLevel;"); - - if (!(f_utf_8 && f_log_level_debug && f_log_level_info && f_log_level_warn && f_log_level_error)) - { - goto error; - } - - o_utf_8 = env->GetStaticObjectField(c_standard_charsets, f_utf_8); - o_log_level_debug = env->GetStaticObjectField(c_log_level, f_log_level_debug); + o_utf_8 = env->NewStringUTF("UTF-8"); + o_log_level_debug = env->GetStaticObjectField(c_log_level, f_log_level_debug); o_log_level_info = env->GetStaticObjectField(c_log_level, f_log_level_info); o_log_level_warn = env->GetStaticObjectField(c_log_level, f_log_level_warn); o_log_level_error = env->GetStaticObjectField(c_log_level, f_log_level_error); + o_log_format_json = env->GetStaticObjectField(c_log_format, f_log_format_json); + o_log_format_text = env->GetStaticObjectField(c_log_format, f_log_format_text); + + if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error && + o_log_format_json && o_log_format_text)) { + goto error; + } - if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error)) - { - goto error; - } + o_utf_8 = env->NewGlobalRef(o_utf_8); + o_log_level_debug = env->NewGlobalRef(o_log_level_debug); + o_log_level_info = env->NewGlobalRef(o_log_level_info); + o_log_level_warn = env->NewGlobalRef(o_log_level_warn); + o_log_level_error = env->NewGlobalRef(o_log_level_error); + o_log_format_json = env->NewGlobalRef(o_log_format_json); + o_log_format_text = env->NewGlobalRef(o_log_format_text); - if (env->ExceptionCheck()) - { + if (env->ExceptionCheck()) { env->ExceptionDescribe(); goto error; } + llama_backend_init(); + goto success; error: return JNI_ERR; success: - return JNI_VERSION_1_1; + return JNI_VERSION_1_6; } -JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) -{ - JNIEnv *env = 0; - - if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1)) +/** + * The VM calls `JNI_OnUnload` when the class loader containing the native library is garbage collected. + * This function can be used to perform cleanup operations. Because this function is called in an unknown context + * (such as from a finalizer), the programmer should be conservative on using Java VM services, and refrain from + * arbitrary Java call-backs. + * Note that `JNI_OnLoad` and `JNI_OnUnload` are two functions optionally supplied by JNI libraries, not exported from + * the VM. + */ +JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv *env = nullptr; + + if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_6)) { return; + } - env->DeleteWeakGlobalRef(c_llama_model); - env->DeleteWeakGlobalRef(c_llama_iterator); - env->DeleteWeakGlobalRef(c_infer_params); - env->DeleteWeakGlobalRef(c_model_params); - env->DeleteWeakGlobalRef(c_string); - env->DeleteWeakGlobalRef(c_map); - env->DeleteWeakGlobalRef(c_set); - env->DeleteWeakGlobalRef(c_entry); - env->DeleteWeakGlobalRef(c_iterator); - env->DeleteWeakGlobalRef(c_integer); - env->DeleteWeakGlobalRef(c_float); - env->DeleteWeakGlobalRef(c_log_level); - env->DeleteWeakGlobalRef(c_biconsumer); - env->DeleteWeakGlobalRef(c_llama_error); - env->DeleteWeakGlobalRef(c_error_oom); + env->DeleteGlobalRef(c_llama_model); + env->DeleteGlobalRef(c_llama_iterator); + env->DeleteGlobalRef(c_output); + env->DeleteGlobalRef(c_string); + env->DeleteGlobalRef(c_hash_map); + env->DeleteGlobalRef(c_map); + env->DeleteGlobalRef(c_set); + env->DeleteGlobalRef(c_entry); + env->DeleteGlobalRef(c_iterator); + env->DeleteGlobalRef(c_integer); + env->DeleteGlobalRef(c_float); + env->DeleteGlobalRef(c_biconsumer); + env->DeleteGlobalRef(c_llama_error); + env->DeleteGlobalRef(c_log_level); + env->DeleteGlobalRef(c_log_level); + env->DeleteGlobalRef(c_error_oom); + + env->DeleteGlobalRef(o_utf_8); + env->DeleteGlobalRef(o_log_level_debug); + env->DeleteGlobalRef(o_log_level_info); + env->DeleteGlobalRef(o_log_level_warn); + env->DeleteGlobalRef(o_log_level_error); + env->DeleteGlobalRef(o_log_format_json); + env->DeleteGlobalRef(o_log_format_text); + + if (o_log_callback != nullptr) { + env->DeleteGlobalRef(o_log_callback); + } + + llama_backend_free(); } -static void jllama_log_callback(enum llama_log_level level, const char * text, void * user_data) { - JNIEnv* env; - g_vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_2); +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams) { + common_params params; - jobject java_log_level; - switch (level) { - case LLAMA_LOG_LEVEL_ERROR: java_log_level = o_log_level_error; break; - case LLAMA_LOG_LEVEL_WARN: java_log_level = o_log_level_warn; break; - case LLAMA_LOG_LEVEL_INFO: java_log_level = o_log_level_info; break; - default: java_log_level = o_log_level_debug; break; + const jsize argc = env->GetArrayLength(jparams); + char **argv = parse_string_array(env, jparams, argc); + if (argv == nullptr) { + return; } - jstring java_text = env->NewStringUTF(text); - env->CallVoidMethod(g_log_callback, m_biconsumer_accept, java_log_level, java_text); + const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER); + free_string_array(argv, argc); + if (!parsed_params) { + return; + } - env->DeleteLocalRef(java_log_level); - env->DeleteLocalRef(java_text); -} + SRV_INF("loading model '%s'\n", params.model.c_str()); -static std::string parse_jstring(JNIEnv *env, jstring java_string) -{ - const jbyteArray string_bytes = (jbyteArray)env->CallObjectMethod(java_string, m_get_bytes, o_utf_8); + common_init(); - size_t length = (size_t)env->GetArrayLength(string_bytes); - jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr); + // struct that contains llama context and inference + auto *ctx_server = new server_context(); - std::string string = std::string((char *)byte_elements, length); + llama_numa_init(params.numa); - env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT); - env->DeleteLocalRef(string_bytes); + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, + params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); - return string; -} - -static int parse_jinteger(JNIEnv *env, jobject java_integer) -{ - if (!java_integer) - return 0; - return env->CallIntMethod(java_integer, m_int_value); -} + std::atomic state{SERVER_STATE_LOADING_MODEL}; -static float parse_jfloat(JNIEnv *env, jobject java_float) -{ - if (!java_float) - return 0; - return env->CallFloatMethod(java_float, m_float_value); -} + // Necessary similarity of prompt for slot selection + ctx_server->slot_prompt_similarity = params.slot_prompt_similarity; -// Since Java expects utf16 but std::strings are utf8, we cant directly use `env->NewString` or `env-NewString`, but -// we simply send the bytes directly and do the conversion in Java. Unfortunately, there isn't a nice/standardized way -// to do this conversion in C++ -static jbyteArray parse_jbytes(JNIEnv *env, std::string string) -{ - jsize len = string.size(); - jbyteArray bytes = env->NewByteArray(len); - env->SetByteArrayRegion(bytes, 0, len, (jbyte*)string.c_str()); - return bytes; -} + LOG_INF("%s: loading model\n", __func__); -// completion token output with probabilities -struct completion_token_output -{ - struct token_prob - { - llama_token tok; - float prob; - }; - - std::vector probs; - llama_token tok; -}; - -static size_t common_part(const std::vector &a, const std::vector &b) -{ - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { + // load the model + if (!ctx_server->load_model(params)) { + llama_backend_free(); + env->ThrowNew(c_llama_error, "could not load model from given file path"); + return; } - return i; -} -enum stop_type -{ - STOP_FULL, - STOP_PARTIAL, -}; + ctx_server->init(); + state.store(SERVER_STATE_READY); -static bool ends_with(const std::string &str, const std::string &suffix) -{ - return str.size() >= suffix.size() && - 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} + LOG_INF("%s: model loaded\n", __func__); -static size_t find_partial_stop_string(const std::string &stop, - const std::string &text) -{ - if (!text.empty() && !stop.empty()) - { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) - { - if (stop[char_index] == text_last_char) - { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) - { - return text.size() - char_index - 1; - } - } - } - } - return std::string::npos; -} - -template -static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) -{ - std::string ret; - for (; begin != end; ++begin) - { - ret += llama_token_to_piece(ctx, *begin); - } - return ret; -} + const auto model_meta = ctx_server->model_meta(); -struct jllama_context -{ - bool has_next_token = false; - std::string generated_text; - std::vector generated_token_probs; + if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str()); + auto params_dft = params; - size_t num_prompt_tokens = 0; - size_t num_tokens_predicted = 0; - size_t n_past = 0; - size_t n_remain = 0; + params_dft.devices = params.speculative.devices; + params_dft.hf_file = params.speculative.hf_file; + params_dft.hf_repo = params.speculative.hf_repo; + params_dft.model = params.speculative.model; + params_dft.model_url = params.speculative.model_url; + params_dft.n_ctx = params.speculative.n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative.n_ctx; + params_dft.n_gpu_layers = params.speculative.n_gpu_layers; + params_dft.n_parallel = 1; - std::string prompt; - std::vector embd; - std::vector last_n_tokens; + common_init_result llama_init_dft = common_init_from_params(params_dft); - llama_model *model = nullptr; - llama_context *ctx = nullptr; - gpt_params params; + llama_model *model_dft = llama_init_dft.model.get(); - grammar_parser::parse_state parsed_grammar; - llama_grammar *grammar = nullptr; + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str()); + } - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - std::string stopping_word; - int32_t multibyte_pending = 0; + if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", + params.speculative.model.c_str(), params.model.c_str()); + } - std::mutex mutex; + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - std::unique_lock lock() - { - return std::unique_lock(mutex); - } + ctx_server->cparams_dft = common_context_params_to_llama(params_dft); + ctx_server->cparams_dft.n_batch = n_ctx_dft; - ~jllama_context() - { - if (ctx) - { - llama_free(ctx); - ctx = nullptr; - } - if (model) - { - llama_free_model(model); - model = nullptr; - } - } + // force F16 KV cache for the draft model for extra performance + ctx_server->cparams_dft.type_k = GGML_TYPE_F16; + ctx_server->cparams_dft.type_v = GGML_TYPE_F16; - void rewind() - { - params.antiprompt.clear(); - params.grammar.clear(); - num_prompt_tokens = 0; - num_tokens_predicted = 0; - generated_text = ""; - generated_text.reserve(params.n_ctx); - generated_token_probs.clear(); - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - multibyte_pending = 0; - n_remain = 0; - n_past = 0; - - if (grammar != nullptr) - { - llama_grammar_free(grammar); - grammar = nullptr; - } + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); } - bool loadModel(const gpt_params ¶ms_) - { - params = params_; - std::tie(model, ctx) = llama_init_from_gpt_params(params); - if (model == nullptr) - { - // LOG_ERROR("unable to load model", {{"model", params_.model}}); - return false; - } + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(ctx_server->chat_templates.get()), + common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str()); - last_n_tokens.resize(params.n_ctx); - std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); - return true; - } + // print sample chat example to make it clear which template is used + // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + // common_chat_templates_source(ctx_server->chat_templates.get()), + // common_chat_format_example(*ctx_server->chat_templates.template_default, + // ctx_server->params_base.use_jinja) .c_str()); - std::vector tokenize(std::string prompt, bool add_bos) const - { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - return ::llama_tokenize(ctx, prompt, add_bos); - } + ctx_server->queue_tasks.on_new_task( + std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1)); + ctx_server->queue_tasks.on_update_slots(std::bind(&server_context::update_slots, ctx_server)); - bool loadGrammar() - { - if (!params.grammar.empty()) - { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) - { - // LOG_ERROR("grammar parse error", {{"grammar", params.grammar}}); - return false; - } - grammar_parser::print_grammar(stderr, parsed_grammar); - - { - auto it = params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.logit_bias.end() && it->second == -INFINITY) - { - // LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); - } + std::thread t([ctx_server]() { + JNIEnv *env; + jint res = g_vm->GetEnv((void **)&env, JNI_VERSION_1_6); + if (res == JNI_EDETACHED) { + res = g_vm->AttachCurrentThread((void **)&env, nullptr); + if (res != JNI_OK) { + throw std::runtime_error("Failed to attach thread to JVM"); } - - std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - return true; - } + ctx_server->queue_tasks.start_loop(); + }); + t.detach(); - void loadPrompt() - { - auto prompt_tokens = tokenize(prompt, true); // always add BOS + env->SetLongField(obj, f_model_pointer, reinterpret_cast(ctx_server)); +} - num_prompt_tokens = prompt_tokens.size(); +JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - if (params.n_keep < 0) - { - params.n_keep = (int)num_prompt_tokens; - } - params.n_keep = std::min(params.n_ctx - 4, params.n_keep); - - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)params.n_ctx) - { - const int n_left = (params.n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - // LOG_VERBOSE("input truncated", { - // {"n_ctx", params.n_ctx}, - // {"n_keep", params.n_keep}, - // {"n_left", n_left}, - // {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - // }); - - truncated = true; - prompt_tokens = new_tokens; - } - else - { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } + std::string c_params = parse_jstring(env, jparams); + json data = json::parse(c_params); - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - embd = prompt_tokens; - if (n_past == num_prompt_tokens) - { - // we have to evaluate at least 1 token to generate logits. - n_past--; - } + server_task_type type = SERVER_TASK_TYPE_COMPLETION; - // LOG_VERBOSE("prompt ingested", { - // {"n_past", n_past}, - // {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, - // {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - // }); - - has_next_token = true; + if (data.contains("input_prefix") || data.contains("input_suffix")) { + type = SERVER_TASK_TYPE_INFILL; } - void beginCompletion() - { - // number of tokens to keep when resetting context - n_remain = params.n_predict; - llama_set_rng_seed(ctx, params.seed); - } + auto completion_id = gen_chatcmplid(); + std::vector tasks; - completion_token_output nextToken() - { - completion_token_output result; - result.tok = -1; - - if (embd.size() >= (size_t)params.n_ctx) - { - // Reset context - const int n_left = (params.n_ctx - params.n_keep) / 2; - - std::vector new_tokens(embd.begin(), embd.begin() + params.n_keep); - new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end()); - embd = new_tokens; - n_past = params.n_keep; - truncated = true; - // LOG_VERBOSE("input truncated", { - // {"n_ctx", params.n_ctx}, - // {"n_keep", params.n_keep}, - // {"n_left", n_left}, - // {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - // }); - } + try { + const auto &prompt = data.at("prompt"); - while (n_past < embd.size()) - { - int n_eval = (int)embd.size() - n_past; - if (n_eval > params.n_batch) - { - n_eval = params.n_batch; - } - if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) - { - // LOG_ERROR("failed to eval", { - // {"n_eval", n_eval}, - // {"n_past", n_past}, - // {"n_threads", params.n_threads}, - // {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, - // }); - has_next_token = false; - return result; - } - n_past += n_eval; - } + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); - if (params.n_predict == 0) - { - has_next_token = false; - result.tok = llama_token_eos(ctx); - return result; - } + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(type); - // out of user input, sample next token - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n; - const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.presence_penalty; - const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; - const int32_t n_probs = params.n_probs; - - { - auto *logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); - - // Apply params.logit_bias map - for (const auto &it : params.logit_bias) - { - logits[it.first] += it.second; - } + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) - { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, data); + task.id_selected_slot = json_value(data, "id_slot", -1); - llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; - - // Apply penalties - float nl_logit = logits[llama_token_nl(ctx)]; - auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); - llama_sample_repetition_penalty(ctx, &candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); - if (!penalize_nl) - { - logits[llama_token_nl(ctx)] = nl_logit; - } + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_NONE; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl - if (grammar != nullptr) - { - llama_sample_grammar(ctx, &candidates_p, grammar); - } + tasks.push_back(task); + } + } catch (const std::exception &e) { + const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); + env->ThrowNew(c_llama_error, err.dump().c_str()); + return 0; + } - if (temp <= 0) - { - // Greedy sampling - result.tok = llama_sample_token_greedy(ctx, &candidates_p); - if (n_probs > 0) - { - llama_sample_softmax(ctx, &candidates_p); - } - } - else - { - if (mirostat == 1) - { - static float mirostat_mu = 2.0f * mirostat_tau; - const int mirostat_m = 100; - llama_sample_temperature(ctx, &candidates_p, temp); - result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); - } - else if (mirostat == 2) - { - static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &candidates_p, temp); - result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); - } - else - { - // Temperature sampling - size_t min_keep = std::max(1, n_probs); - llama_sample_top_k(ctx, &candidates_p, top_k, min_keep); - llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); - llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); - llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); - llama_sample_temperature(ctx, &candidates_p, temp); - result.tok = llama_sample_token(ctx, &candidates_p); - } - } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); - if (grammar != nullptr) - { - llama_grammar_accept_token(ctx, grammar, result.tok); - } + const auto task_ids = server_task::get_list_id(tasks); - for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) - { - result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); - } + if (task_ids.size() != 1) { + env->ThrowNew(c_llama_error, "multitasking currently not supported"); + return 0; + } - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(result.tok); - num_tokens_predicted++; - } + return *task_ids.begin(); +} - // add it to the context - embd.push_back(result.tok); - // decrement remaining sampling budget - --n_remain; - - if (!embd.empty() && embd.back() == llama_token_eos(ctx)) - { - // stopping_word = llama_token_to_piece(ctx, embd.back()); - has_next_token = false; - stopped_eos = true; - // LOG_VERBOSE("eos token found", {}); - return result; - } +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env, jobject obj, jint id_task) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + ctx_server->queue_results.remove_waiting_task_id(id_task); +} - has_next_token = params.n_predict == -1 || n_remain != 0; - return result; - } +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - size_t findStoppingStrings(const std::string &text, const size_t last_token_size, - const stop_type type) - { - size_t stop_pos = std::string::npos; - for (const std::string &word : params.antiprompt) - { - size_t pos; - if (type == STOP_FULL) - { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - pos = text.find(word, from_pos); - } - else - { - pos = find_partial_stop_string(word, text); - } - if (pos != std::string::npos && - (stop_pos == std::string::npos || pos < stop_pos)) - { - if (type == STOP_FULL) - { - stopping_word = word; - stopped_word = true; - has_next_token = false; - } - stop_pos = pos; + server_task_result_ptr result = ctx_server->queue_results.recv(id_task); + + if (result->is_error()) { + std::string response = result->to_json()["message"].get(); + ctx_server->queue_results.remove_waiting_task_id(id_task); + env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } + const auto out_res = result->to_json(); + + std::string response = out_res["content"].get(); + if (result->is_stop()) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + if (out_res.contains("completion_probabilities")) { + auto completion_probabilities = out_res["completion_probabilities"]; + for (const auto &entry : completion_probabilities) { + auto probs = entry["probs"]; + for (const auto &tp : probs) { + std::string tok_str = tp["tok_str"]; + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + float prob = tp["prob"]; + jobject jprob = env->NewObject(c_float, cc_float, prob); + env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob); + env->DeleteLocalRef(jtok_str); + env->DeleteLocalRef(jprob); } } - return stop_pos; } + jbyteArray jbytes = parse_jbytes(env, response); + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result->is_stop()); +} - completion_token_output doCompletion() - { - auto token_with_probs = nextToken(); - - const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok); - generated_text += token_text; - - if (params.n_probs > 0) - { - generated_token_probs.push_back(token_with_probs); - } +JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - if (multibyte_pending > 0) - { - multibyte_pending -= token_text.size(); - } - else if (token_text.size() == 1) - { - const char c = token_text[0]; - // 2-byte characters: 110xxxxx 10xxxxxx - if ((c & 0xE0) == 0xC0) - { - multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - } - else if ((c & 0xF0) == 0xE0) - { - multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - } - else if ((c & 0xF8) == 0xF0) - { - multibyte_pending = 3; - } - else - { - multibyte_pending = 0; - } - } + if (!ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, + "model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))"); + return nullptr; + } - if (multibyte_pending > 0 && !has_next_token) - { - has_next_token = true; - n_remain++; - } + const std::string prompt = parse_jstring(env, jprompt); - if (!has_next_token && n_remain == 0) - { - stopped_limit = true; - } + SRV_INF("Calling embedding '%s'\n", prompt.c_str()); - // LOG_VERBOSE("next token", { - // {"token", token_with_probs.tok}, - // {"token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok)}, - // {"has_next_token", has_next_token}, - // {"n_remain", n_remain}, - // {"num_tokens_predicted", num_tokens_predicted}, - // {"stopped_eos", stopped_eos}, - // {"stopped_word", stopped_word}, - // {"stopped_limit", stopped_limit}, - // {"stopping_word", stopping_word}, - // }); - - return token_with_probs; - } + const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); + std::vector tasks; - std::vector getEmbedding() - { - static const int n_embd = llama_n_embd(ctx); - if (!params.embedding) - { - // LOG_WARNING("embedding disabled", { - // {"params.embedding", params.embedding}, - // }); - return std::vector(n_embd, 0.0f); - } - const float *data = llama_get_embeddings(ctx); - std::vector embedding(data, data + n_embd); - return embedding; - } -}; - -static gpt_params parse_model_params(JNIEnv *env, jobject jparams, jstring java_file_path) -{ - gpt_params params; - - params.model = parse_jstring(env, java_file_path); - params.seed = env->GetIntField(jparams, f_model_seed); - params.n_threads = env->GetIntField(jparams, f_n_threads); - params.n_ctx = env->GetIntField(jparams, f_n_ctx); - params.n_batch = env->GetIntField(jparams, f_n_batch); - params.n_gpu_layers = env->GetIntField(jparams, f_n_gpu_layers); - params.main_gpu = env->GetIntField(jparams, f_main_gpu); - params.rope_freq_base = env->GetFloatField(jparams, f_rope_freq_base); - params.rope_freq_scale = env->GetFloatField(jparams, f_rope_freq_scale); - params.hellaswag = env->GetBooleanField(jparams, f_hellaswag); - params.hellaswag_tasks = env->GetShortField(jparams, f_hellaswag_tasks); - params.low_vram = env->GetBooleanField(jparams, f_low_vram); - params.mul_mat_q = env->GetBooleanField(jparams, f_mul_mat_q); - params.memory_f16 = env->GetBooleanField(jparams, f_memory_f16); - params.embedding = env->GetBooleanField(jparams, f_embedding); - params.escape = env->GetIntField(jparams, f_n_predict); - params.use_mmap = env->GetBooleanField(jparams, f_use_mmap); - params.use_mlock = env->GetBooleanField(jparams, f_use_mlock); - params.numa = env->GetBooleanField(jparams, f_numa); - params.verbose_prompt = env->GetBooleanField(jparams, f_verbose_prompt); - - jstring j_lora_adapter = (jstring)env->GetObjectField(jparams, f_lora_adapter); - if (j_lora_adapter != nullptr) - { - params.lora_adapter = parse_jstring(env, j_lora_adapter); - env->DeleteLocalRef(j_lora_adapter); - } - jstring j_lora_base = (jstring)env->GetObjectField(jparams, f_lora_base); - if (j_lora_base != nullptr) - { - params.lora_base = parse_jstring(env, j_lora_base); - env->DeleteLocalRef(j_lora_base); - } - -// jfloatArray j_tensor_split = (jfloatArray)env->GetObjectField(jparams, f_tensor_split); -// if (j_tensor_split != nullptr) -// { -// #ifndef GGML_USE_CUBLAS -// // LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {}); -// #endif -// jsize array_length = env->GetArrayLength(j_tensor_split); -// GGML_ASSERT(array_length <= LLAMA_MAX_DEVICES); -// float *tensor_split = new float[array_length]; -// env->GetFloatArrayRegion(j_tensor_split, 0, array_length, tensor_split); -// for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) -// { -// if (i_device < array_length) -// { -// params.tensor_split[i_device] = tensor_split[i_device]; -// } -// else -// { -// params.tensor_split[i_device] = 0.0f; -// } -// } -// delete[] tensor_split; -// } - // - // #ifndef LLAMA_SUPPORTS_GPU_OFFLOAD - // if (params.n_gpu_layers > 0) { - // // LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " - // // "See main README.md for information on enabling GPU BLAS support", - // // {{"n_gpu_layers", params.n_gpu_layers}}); - // } - // #endif - // - // #ifndef GGML_USE_CUBLAS - // if (params.low_vram) { - // // LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {}); - // } - // if (!params.mul_mat_q) { - // // LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n", {}); - // } - // if (params.main_gpu != 0) { - // // LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {}); - // } - // #endif - // - // // todo: these have to be set in llama_context_params - // // f_logits_all - // // f_vocab_only - // // f_memory_f16 - // // f_f16_kv - - return params; -} + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); -static void parse_inference_params(JNIEnv *env, jllama_context *llama, jstring prompt, jobject params) -{ - llama->prompt = parse_jstring(env, prompt); - llama->params.n_predict = env->GetIntField(params, f_n_predict); - llama->params.n_keep = env->GetIntField(params, f_n_keep); - llama->params.n_probs = env->GetIntField(params, f_n_probs); - llama->params.top_k = env->GetIntField(params, f_top_k); - llama->params.top_p = env->GetFloatField(params, f_top_p); - llama->params.tfs_z = env->GetFloatField(params, f_tfs_z); - llama->params.typical_p = env->GetFloatField(params, f_typical_p); - llama->params.temp = env->GetFloatField(params, f_temperature); - llama->params.repeat_penalty = env->GetFloatField(params, f_repeat_penalty); - llama->params.repeat_last_n = env->GetIntField(params, f_repeat_last_n); - llama->params.frequency_penalty = env->GetFloatField(params, f_frequency_penalty); - llama->params.presence_penalty = env->GetFloatField(params, f_presence_penalty); - llama->params.penalize_nl = env->GetBooleanField(params, f_penalize_nl); - llama->params.mirostat = env->GetIntField(params, f_mirostat); - llama->params.mirostat_tau = env->GetFloatField(params, f_mirostat_tau); - llama->params.mirostat_eta = env->GetFloatField(params, f_mirostat_eta); - llama->params.seed = env->GetIntField(params, f_infer_seed); - - jstring j_grammar = (jstring)env->GetObjectField(params, f_grammar); - if (j_grammar != nullptr) - { - llama->params.grammar = parse_jstring(env, j_grammar); - env->DeleteLocalRef(j_grammar); - } + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = 0; + task.prompt_tokens = std::move(tokens); - llama->params.logit_bias.clear(); - jboolean ignore_eos = env->GetBooleanField(params, f_ignore_eos); - if (ignore_eos) - { - llama->params.logit_bias[llama_token_eos(llama->ctx)] = -INFINITY; - } + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_NONE; - jobject logit_bias = env->GetObjectField(params, f_logit_bias); - if (logit_bias != nullptr) - { - const int n_vocab = llama_n_vocab(llama->ctx); - jobject entry_set = env->CallObjectMethod(logit_bias, m_entry_set); - jobject iterator = env->CallObjectMethod(entry_set, m_set_iterator); - while (env->CallBooleanMethod(iterator, m_iterator_has_next)) - { - jobject entry = env->CallObjectMethod(iterator, m_iterator_next); - jobject key = env->CallObjectMethod(entry, m_entry_key); - jobject value = env->CallObjectMethod(entry, m_entry_value); - - int tok = parse_jinteger(env, key); - float bias = parse_jfloat(env, value); - llama->params.logit_bias[tok] = bias; - - env->DeleteLocalRef(entry); - env->DeleteLocalRef(key); - env->DeleteLocalRef(value); - } - } + tasks.push_back(task); - llama->params.antiprompt.clear(); - jobjectArray antiprompt = (jobjectArray)env->GetObjectField(params, f_antiprompt); - if (antiprompt != nullptr) - { - jsize array_length = env->GetArrayLength(antiprompt); - for (jsize i = 0; i < array_length; i++) - { - jstring java_string = (jstring)env->GetObjectArrayElement(antiprompt, i); - if (java_string != nullptr) - { - std::string string = parse_jstring(env, java_string); - llama->params.antiprompt.push_back(string); - env->DeleteLocalRef(java_string); - } - } - } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); - // LOG_VERBOSE("completion parameters parsed", format_generation_settings(*llama)); -} + std::unordered_set task_ids = server_task::get_list_id(tasks); + const auto id_task = *task_ids.begin(); + json responses = json::array(); -JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getSystemInfo(JNIEnv *env, jobject obj) -{ - const char *sys_info = llama_print_system_info(); - return env->NewStringUTF(sys_info); -} + json error = nullptr; -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring file_path, jobject jparams) -{ - gpt_params params = parse_model_params(env, jparams, file_path); + server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - jllama_context *llama = new jllama_context; - llama_backend_init(false); + json response_str = result->to_json(); + if (result->is_error()) { + std::string response = result->to_json()["message"].get(); + ctx_server->queue_results.remove_waiting_task_id(id_task); + env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } - if (!llama->loadModel(params)) - { - env->ThrowNew(c_llama_error, "could not load model from given file path"); - return; + if (result->is_stop()) { + ctx_server->queue_results.remove_waiting_task_id(id_task); } - // LOG_INFO("build info", {{"build", BUILD_NUMBER}, - // {"commit", BUILD_COMMIT}}); - // LOG_INFO("system info", { - // {"n_threads", params.n_threads}, - // {"total_threads", std::thread::hardware_concurrency()}, - // {"system_info", llama_print_system_info()}, - // }); + const auto out_res = result->to_json(); - env->SetLongField(obj, f_model_pointer, reinterpret_cast(llama)); -} + // Extract "embedding" as a vector of vectors (2D array) + std::vector> embedding = out_res["embedding"].get>>(); -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setupInference(JNIEnv *env, jobject obj, jstring prompt, jobject params) -{ - jlong llama_handle = env->GetLongField(obj, f_model_pointer); - jllama_context *llama = reinterpret_cast(llama_handle); + // Get total number of rows in the embedding + jsize embedding_rows = embedding.size(); - auto lock = llama->lock(); + // Get total number of columns in the first row (assuming all rows are of equal length) + jsize embedding_cols = embedding_rows > 0 ? embedding[0].size() : 0; - llama->rewind(); + SRV_INF("Embedding has %d rows and %d columns\n", embedding_rows, embedding_cols); - llama_reset_timings(llama->ctx); + // Ensure embedding is not empty + if (embedding.empty() || embedding[0].empty()) { + env->ThrowNew(c_error_oom, "embedding array is empty"); + return nullptr; + } - parse_inference_params(env, llama, prompt, params); + // Extract only the first row + const std::vector &first_row = embedding[0]; // Reference to avoid copying - if (!llama->loadGrammar()) - { - env->ThrowNew(c_llama_error, "could not load grammar"); + // Create a new float array in JNI + jfloatArray j_embedding = env->NewFloatArray(embedding_cols); + if (j_embedding == nullptr) { + env->ThrowNew(c_error_oom, "could not allocate embedding"); + return nullptr; } - llama->loadPrompt(); - llama->beginCompletion(); + // Copy the first row into the JNI float array + env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, reinterpret_cast(first_row.data())); + + return j_embedding; } -JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_getNext(JNIEnv *env, jobject obj, jobject iter) -{ - jlong llama_handle = env->GetLongField(obj, f_model_pointer); - jllama_context *llama = reinterpret_cast(llama_handle); - - size_t sent_count = env->GetLongField(iter, f_iter_n_generated); - size_t sent_token_probs_index = env->GetLongField(iter, f_iter_token_index); - - completion_token_output token_with_probs; - while (llama->has_next_token) - { - token_with_probs = llama->doCompletion(); - if (token_with_probs.tok >= 0 && llama->multibyte_pending <= 0) - { - break; - } - } - const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok); - - size_t pos = std::min(sent_count, llama->generated_text.size()); - - const std::string str_test = llama->generated_text.substr(pos); - bool is_stop_full = false; - size_t stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_FULL); - if (stop_pos != std::string::npos) - { - is_stop_full = true; - llama->generated_text.erase( - llama->generated_text.begin() + pos + stop_pos, - llama->generated_text.end()); - pos = std::min(sent_count, llama->generated_text.size()); - } - else - { - is_stop_full = false; - stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL); - } +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, + jobjectArray documents) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - std::string to_send; - if (stop_pos == std::string::npos || (!llama->has_next_token && !is_stop_full && stop_pos > 0)) - { - to_send = llama->generated_text.substr(pos, std::string::npos); - sent_count += to_send.size(); - env->SetLongField(iter, f_iter_n_generated, sent_count); - - std::vector probs_output = {}; - - if (llama->params.n_probs > 0) - { - const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false); - size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size()); - size_t probs_stop_pos = std::min( - sent_token_probs_index + to_send_toks.size(), - llama->generated_token_probs.size()); - if (probs_pos < probs_stop_pos) - { - probs_output = std::vector( - llama->generated_token_probs.begin() + probs_pos, - llama->generated_token_probs.begin() + probs_stop_pos); - } - sent_token_probs_index = probs_stop_pos; - env->SetLongField(iter, f_iter_token_index, sent_token_probs_index); - } - } - else - { - to_send = ""; - } - - if (!llama->has_next_token) - { - env->SetLongField(iter, f_iter_has_next, false); + if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, + "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); + return nullptr; } - return parse_jbytes(env, to_send); -} + const std::string prompt = parse_jstring(env, jprompt); -JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_getFull(JNIEnv *env, jobject obj, jstring prompt, jobject params) -{ - Java_de_kherud_llama_LlamaModel_setupInference(env, obj, prompt, params); + const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true); - jlong llama_handle = env->GetLongField(obj, f_model_pointer); - jllama_context *llama = reinterpret_cast(llama_handle); + json responses = json::array(); - size_t stop_pos = std::string::npos; + std::vector tasks; + const jsize amount_documents = env->GetArrayLength(documents); + auto *document_array = parse_string_array(env, documents, amount_documents); + auto document_vector = std::vector(document_array, document_array + amount_documents); + free_string_array(document_array, amount_documents); - while (llama->has_next_token) - { - const completion_token_output token_with_probs = llama->doCompletion(); - const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama->ctx, token_with_probs.tok); + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true); - stop_pos = llama->findStoppingStrings(llama->generated_text, - token_text.size(), STOP_FULL); + tasks.reserve(tokenized_docs.size()); + for (int i = 0; i < tokenized_docs.size(); i++) { + auto task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); - if (stop_pos == std::string::npos) - { - stop_pos = llama->findStoppingStrings(llama->generated_text, 0, STOP_PARTIAL); - } - if (stop_pos != std::string::npos) - { - llama->generated_text.erase(llama->generated_text.begin() + stop_pos, llama->generated_text.end()); - } + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + std::vector results(task_ids.size()); - auto probs = llama->generated_token_probs; - if (llama->params.n_probs > 0 && llama->stopped_word) - { - const std::vector stop_word_toks = llama_tokenize(llama->ctx, llama->stopping_word, false); - probs = std::vector(llama->generated_token_probs.begin(), llama->generated_token_probs.end() - stop_word_toks.size()); + // Create a new HashMap instance + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + if (o_probabilities == nullptr) { + env->ThrowNew(c_llama_error, "Failed to create HashMap object."); + return nullptr; } - llama_print_timings(llama->ctx); + for (int i = 0; i < (int)task_ids.size(); i++) { + server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); + if (result->is_error()) { + auto response = result->to_json()["message"].get(); + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } - // llama->lock().release(); - // llama->mutex.unlock(); + const auto out_res = result->to_json(); - return parse_jbytes(env, llama->generated_text); -} + if (result->is_stop()) { + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + } -JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring java_prompt) -{ - // auto lock = llama.lock(); - jlong llama_handle = env->GetLongField(obj, f_model_pointer); - jllama_context *llama = reinterpret_cast(llama_handle); - - llama->rewind(); - llama_reset_timings(llama->ctx); - llama->prompt = parse_jstring(env, java_prompt); - llama->params.n_predict = 0; - llama->loadPrompt(); - llama->beginCompletion(); - llama->doCompletion(); - - static const int n_embd = llama_n_embd(llama->ctx); - // if (!llama->params.embedding) - // { - // // LOG_WARNING("embedding disabled", { - // // {"params.embedding", params.embedding}, - // // }); - // return std::vector(n_embd, 0.0f); - // } - const float *data = llama_get_embeddings(llama->ctx); - std::vector embedding(data, data + n_embd); - - jfloatArray java_embedding = env->NewFloatArray(embedding.size()); - if (java_embedding == nullptr) - { - env->ThrowNew(c_error_oom, "could not allocate embedding"); - return nullptr; + int index = out_res["index"].get(); + float score = out_res["score"].get(); + std::string tok_str = document_vector[index]; + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + jobject jprob = env->NewObject(c_float, cc_float, score); + env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob); + env->DeleteLocalRef(jtok_str); + env->DeleteLocalRef(jprob); } + jbyteArray jbytes = parse_jbytes(env, prompt); + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true); +} + +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + + std::string c_params = parse_jstring(env, jparams); + json data = json::parse(c_params); - env->SetFloatArrayRegion(java_embedding, 0, embedding.size(), reinterpret_cast(embedding.data())); + json templateData = + oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, + ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get()); + std::string tok_str = templateData.at("prompt"); + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); - return java_embedding; + return jtok_str; } -JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring java_prompt) -{ - jlong llama_handle = env->GetLongField(obj, f_model_pointer); - jllama_context *llama = reinterpret_cast(llama_handle); +JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - // auto lock = llama->lock(); + const std::string c_prompt = parse_jstring(env, jprompt); - std::string prompt = parse_jstring(env, java_prompt); - std::vector tokens = llama->tokenize(prompt, false); + llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true); + jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions) - jintArray java_tokens = env->NewIntArray(tokens.size()); - if (java_tokens == nullptr) - { - env->ThrowNew(c_error_oom, "could not allocate tokens"); + jintArray java_tokens = env->NewIntArray(token_size); + if (java_tokens == nullptr) { + env->ThrowNew(c_error_oom, "could not allocate token memory"); return nullptr; } - env->SetIntArrayRegion(java_tokens, 0, tokens.size(), reinterpret_cast(tokens.data())); + env->SetIntArrayRegion(java_tokens, 0, token_size, reinterpret_cast(tokens.data())); - // lock.release(); return java_tokens; } -JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *env, jobject obj, jintArray java_tokens) -{ - jlong llama_handle = env->GetLongField(obj, f_model_pointer); - jllama_context *llama = reinterpret_cast(llama_handle); - // auto lock = llama.lock(); +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *env, jobject obj, + jintArray java_tokens) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) jsize length = env->GetArrayLength(java_tokens); jint *elements = env->GetIntArrayElements(java_tokens, nullptr); std::vector tokens(elements, elements + length); - std::string text = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend()); + std::string text = tokens_to_str(ctx_server->ctx, tokens.cbegin(), tokens.cend()); env->ReleaseIntArrayElements(java_tokens, elements, 0); - return env->NewString((jchar *)text.data(), text.size()); + return parse_jbytes(env, text); } -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv * env, jclass clazz, jobject callback) { - env->GetJavaVM(&g_vm); +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobject obj) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + ctx_server->queue_tasks.terminate(); + // delete ctx_server; +} + +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *env, jobject obj, jint id_task) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + std::unordered_set id_tasks = {id_task}; + ctx_server->cancel_tasks(id_tasks); + ctx_server->queue_results.remove_waiting_task_id(id_task); +} - if (g_log_callback != nullptr) - { - env->DeleteGlobalRef(g_log_callback); - } +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jclass clazz, jobject log_format, + jobject jcallback) { + if (o_log_callback != nullptr) { + env->DeleteGlobalRef(o_log_callback); + } + + log_json = env->IsSameObject(log_format, o_log_format_json); + + if (jcallback == nullptr) { + log_callback = nullptr; + llama_log_set(nullptr, nullptr); + } else { + o_log_callback = env->NewGlobalRef(jcallback); + log_callback = [](enum ggml_log_level level, const char *text, void *user_data) { + JNIEnv *env = get_jni_env(); + jstring message = env->NewStringUTF(text); + jobject log_level = log_level_to_jobject(level); + env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message); + env->DeleteLocalRef(message); + }; + if (!log_json) { + llama_log_set(log_callback_trampoline, nullptr); + } + } +} - if (callback == nullptr) { - llama_log_set(nullptr, nullptr); - } else { - g_log_callback = env->NewGlobalRef(callback); - llama_log_set(jllama_log_callback, nullptr); - } +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *env, jclass clazz, + jstring j_schema) { + const std::string c_schema = parse_jstring(env, j_schema); + nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema); + const std::string c_grammar = json_schema_to_grammar(c_schema_json); + return parse_jbytes(env, c_grammar); } diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index ddd432b2..dc17fa83 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -12,72 +12,91 @@ extern "C" { * Method: embed * Signature: (Ljava/lang/String;)[F */ -JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed - (JNIEnv *, jobject, jstring); +JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *, jobject, jstring); /* * Class: de_kherud_llama_LlamaModel * Method: encode * Signature: (Ljava/lang/String;)[I */ -JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode - (JNIEnv *, jobject, jstring); +JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *, jobject, jstring); /* * Class: de_kherud_llama_LlamaModel - * Method: decode - * Signature: ([I)Ljava/lang/String; + * Method: setLogger + * Signature: (Lde/kherud/llama/args/LogFormat;Ljava/util/function/BiConsumer;)V */ -JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes - (JNIEnv *, jobject, jintArray); +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *, jclass, jobject, jobject); /* * Class: de_kherud_llama_LlamaModel - * Method: setLogger - * Signature: (Ljava/util/function/BiConsumer;)V + * Method: requestCompletion + * Signature: (Ljava/lang/String;)I */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger - (JNIEnv *, jclass, jobject); +JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *, jobject, jstring); /* * Class: de_kherud_llama_LlamaModel - * Method: loadModel - * Signature: (Ljava/lang/String;Lde/kherud/llama/ModelParameters;)V + * Method: receiveCompletion + * Signature: (I)Lde/kherud/llama/LlamaOutput; */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel - (JNIEnv *, jobject, jstring, jobject); +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *, jobject, jint); /* * Class: de_kherud_llama_LlamaModel - * Method: setupInference - * Signature: (Ljava/lang/String;Lde/kherud/llama/InferenceParameters;)V + * Method: cancelCompletion + * Signature: (I)V */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setupInference - (JNIEnv *, jobject, jstring, jobject); +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *, jobject, jint); /* * Class: de_kherud_llama_LlamaModel - * Method: getFull - * Signature: (Ljava/lang/String;Lde/kherud/llama/InferenceParameters;)[B + * Method: decodeBytes + * Signature: ([I)[B */ -JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_getFull - (JNIEnv *, jobject, jstring, jobject); +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *, jobject, jintArray); /* * Class: de_kherud_llama_LlamaModel - * Method: getNext - * Signature: (Lde/kherud/llama/LlamaModel/LlamaIterator;)[B + * Method: loadModel + * Signature: ([Ljava/lang/String;)V */ -JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_getNext - (JNIEnv *, jobject, jobject); +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *, jobject, jobjectArray); /* * Class: de_kherud_llama_LlamaModel * Method: delete * Signature: ()V */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete - (JNIEnv *, jobject); +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *, jobject); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: releaseTask + * Signature: (I)V + */ +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *, jobject, jint); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: jsonSchemaToGrammarBytes + * Signature: (Ljava/lang/String;)[B + */ +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *, jclass, jstring); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: rerank + * Signature: (Ljava/lang/String;[Ljava/lang/String;)Lde/kherud/llama/LlamaOutput; + */ +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: applyTemplate + * Signature: (Ljava/lang/String;)Ljava/lang/String;; + */ +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring); #ifdef __cplusplus } diff --git a/src/main/cpp/llama.cpp b/src/main/cpp/llama.cpp deleted file mode 160000 index 8781013e..00000000 --- a/src/main/cpp/llama.cpp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8781013ef654270cbead3e0011e33a6d690fb168 diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp new file mode 100644 index 00000000..9686f2af --- /dev/null +++ b/src/main/cpp/server.hpp @@ -0,0 +1,3271 @@ +#include "utils.hpp" + +#include "json-schema-to-grammar.h" +#include "sampling.h" +#include "speculative.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it + // with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +struct slot_params { + bool stream = true; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = + 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + bool ignore_eos = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + json to_json() const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto &sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + auto grammar_triggers = json::array(); + for (const auto &trigger : sampling.grammar_triggers) { + grammar_triggers.push_back(trigger.to_json()); + } + + return json{ + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_format)}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + server_task_type type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + llama_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl(const llama_context *ctx, const common_params ¶ms_base, + const json &data) { + const llama_model *model = llama_get_model(ctx); + const llama_vocab *vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still + // override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: + // implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = + json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = + json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) { + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: + // https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = + json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception &e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); + } else { + params.oaicompat_chat_format = defaults.oaicompat_chat_format; + } + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto &t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, + /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to + // preserve on a model without said tokens. + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto &t : *grammar_triggers) { + auto ct = common_grammar_trigger::from_json(t); + if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto &word = ct.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), + params.sampling.preserved_tokens.end(), + (llama_token)token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = (llama_token)token; + params.sampling.grammar_triggers.push_back(trigger); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + params.sampling.grammar_triggers.push_back(ct); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto &logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto &el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto &stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto &word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()) { + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector &tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + json to_json() const { + return { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { return -1; } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: + return "eos"; + case STOP_TYPE_WORD: + return "word"; + case STOP_TYPE_LIMIT: + return "limit"; + default: + return "none"; + } +} + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto &p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json{ + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)}, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector &probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto &p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json{ + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)}, + {post_sampling_probs ? "top_probs" : "top_logprobs", p.to_json(post_sampling_probs)}, + }); + } + return out; + } + + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + + static std::vector str_to_bytes(const std::string &str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + virtual int get_index() override { return index; } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json{ + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens{} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = + completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json{ + {"choices", json::array({json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + }})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json{{"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}}}, + {"id", oaicompat_cmpl_id}}; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + SRV_DBG("Parsing chat message: %s\n", content.c_str()); + msg = common_chat_parse(content, oaicompat_chat_format); + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } else { + msg.content = content; + } + + json message{ + {"role", "assistant"}, + }; + if (!msg.reasoning_content.empty()) { + message["reasoning_content"] = msg.reasoning_content; + } + if (msg.content.empty() && !msg.tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = msg.content; + } + if (!msg.tool_calls.empty()) { + auto tool_calls = json::array(); + for (const auto &tc : msg.tool_calls) { + tool_calls.push_back({ + {"type", "function"}, + {"function", + { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id}, + }); + } + message["tool_calls"] = tool_calls; + } + + json choice{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", message}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json{{"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json{{"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}}}, + {"id", oaicompat_cmpl_id}}; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choice = json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}; + + json ret = json{ + {"choices", json::array({choice})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", + json{ + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return ret; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { return index; } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json{ + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = + completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json{{"choices", json::array({json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + }})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id}}; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 0; + std::time_t t = std::time(0); + json choices; + + if (first) { + if (content.empty()) { + choices = json::array( + {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = + json{{"choices", + json::array( + {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + }}); + } + + GGML_ASSERT(choices.size() >= 1); + + if (prob_output.probs.size() > 0) { + choices[0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + json ret = json{{"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}}; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return std::vector({ret}); + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { return index; } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING ? to_json_oaicompat() : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json{ + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json{ + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { return index; } + + virtual json to_json() override { + return json{ + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string &message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json{ + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { return true; } + + virtual json to_json() override { return format_error_response(err_msg, err_type); } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override { + return json{ + {"idle", n_idle_slots}, + {"processing", n_processing_slots}, + {"deferred", n_tasks_deferred}, + {"t_start", t_start}, + + {"n_prompt_tokens_processed_total", n_prompt_tokens_processed_total}, + {"t_tokens_generation_total", t_tokens_generation_total}, + {"n_tokens_predicted_total", n_tokens_predicted_total}, + {"t_prompt_processing_total", t_prompt_processing_total}, + + {"n_prompt_tokens_processed", n_prompt_tokens_processed}, + {"t_prompt_processing", t_prompt_processing}, + {"n_tokens_predicted", n_tokens_predicted}, + {"t_tokens_generation", t_tokens_generation}, + + {"n_decode_total", n_decode_total}, + {"n_busy_slots_total", n_busy_slots_total}, + + {"kv_cache_tokens_count", kv_cache_tokens_count}, + {"kv_cache_used_cells", kv_cache_used_cells}, + + {"slots", slots_data}, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json{ + {"id_slot", id_slot}, {"filename", filename}, {"n_saved", n_tokens}, + {"n_written", n_bytes}, {"timings", {{"save_ms", t_ms}}}, + }; + } else { + return json{ + {"id_slot", id_slot}, + {"filename", filename}, + {"n_restored", n_tokens}, + {"n_read", n_bytes}, + {"timings", {{"restore_ms", t_ms}}}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json{ + {"id_slot", id_slot}, + {"n_erased", n_erased}, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { return json{{"success", true}}; } +}; + +struct server_slot { + int id; + int id_task = -1; + + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context *ctx = nullptr; + llama_context *ctx_dft = nullptr; + + common_speculative *spec = nullptr; + + std::vector lora; + + // the index relative to completion multi-task request + size_t index = 0; + + struct slot_params params; + + slot_state state = SLOT_STATE_IDLE; + + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated + int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_processed = 0; + + // input prompt tokens + llama_tokens prompt_tokens; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + + llama_tokens cache_tokens; + + std::vector generated_token_probs; + + bool has_next_token = true; + bool has_new_line = false; + bool truncated = false; + stop_type stop; + + std::string stopping_word; + + // sampling + json json_schema; + + struct common_sampler *smpl = nullptr; + + llama_token sampled; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + // stats + size_t n_sent_text = 0; // number of sent text character + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + std::function callback_on_release; + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens = 0; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + + generated_tokens.clear(); + generated_token_probs.clear(); + } + + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + + bool can_batch_with(server_slot &other_slot) { + return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params &global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (params.n_predict != -1) { + n_remaining = params.n_predict - n_decoded; + } else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool is_processing() const { return state != SLOT_STATE_IDLE; } + + bool can_speculate() const { return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; } + + void add_token(const completion_token_output &token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); + return; + } + generated_token_probs.push_back(token); + } + + void release() { + if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + state = SLOT_STATE_IDLE; + callback_on_release(id); + } + } + + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + return timings; + } + + size_t find_stopping_strings(const std::string &text, const size_t last_token_size, bool is_full_stop) { + size_t stop_pos = std::string::npos; + + for (const std::string &word : params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = find_partial_stop_string(word, text); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation, + n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, + n_prompt_tokens_processed + n_decoded); + } + + json to_json() const { + return json{ + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, + {"params", params.to_json()}, + {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + }}, + }; + } +}; + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { t_start = ggml_time_us(); } + + void on_prompt_eval(const server_slot &slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + } + + void on_prediction(const server_slot &slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector &slots) { + n_decode_total++; + for (const auto &slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; + +struct server_queue { + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + + // Add a new task to the end of the queue + int post(server_task task, bool front = false) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d, front = %d\n", task.id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task.id; + } + + // multi-task version of post() + int post(std::vector &tasks, bool front = false) { + std::unique_lock lock(mutex_tasks); + for (auto &task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int)tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; + } + + // Add a new task, but defer until one slot is available + void defer(server_task task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); + } + + // Get the next id for creating a new task + int get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; + } + + // Register function to process a new task + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback) { callback_update_slots = std::move(callback); } + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); + } + + // end the start_loop routine + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); + } + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = queue_tasks.front(); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); }); + } + } + } + } + + private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task &task) { return task.id_target == id_target; }; + queue_tasks.erase(std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), queue_tasks.end()); + queue_tasks_deferred.erase(std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); + } +}; + +struct server_response { + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, + (int)waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); + } + + void add_waiting_tasks(const std::vector &tasks) { + std::unique_lock lock(mutex_results); + + for (const auto &task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, + (int)waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } + } + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, + (int)waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase(std::remove_if(queue_results.begin(), queue_results.end(), + [id_task](const server_task_result_ptr &res) { return res->id == id_task; }), + queue_results.end()); + } + + void remove_waiting_task_ids(const std::unordered_set &id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto &id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, + (int)waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } + } + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set &id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&] { return !queue_results.empty(); }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set &id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int)queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here + } + + // single-task version of recv() + server_task_result_ptr recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); + } + + // Send a new result to a waiting id_task + void send(server_task_result_ptr &&result) { + SRV_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto &id_task : waiting_task_ids) { + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } + } +}; + +struct server_context { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model *model = nullptr; + llama_context *ctx = nullptr; + + const llama_vocab *vocab = nullptr; + + llama_model *model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch = {}; + + bool clean_kv_cache = true; + bool add_bos_token = true; + bool has_eos_token = false; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + json default_generation_settings_for_props; + + server_queue queue_tasks; + server_response queue_results; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + + ~server_context() { + // Clear any sampling context + for (server_slot &slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); + } + + bool load_model(const common_params ¶ms) { + SRV_INF("loading model '%s'\n", params.model.c_str()); + + params_base = params; + + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); + return false; + } + + vocab = llama_model_get_vocab(model); + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_vocab_get_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + + if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.hf_file = params_base.speculative.hf_file; + params_dft.hf_repo = params_base.speculative.hf_repo; + params_dft.model = params_base.speculative.model; + params_dft.model_url = params_base.speculative.model_url; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel + : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", + params_base.speculative.model.c_str(), params_base.model.c_str()); + + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // force F16 KV cache for the draft model for extra performance + cparams_dft.type_k = GGML_TYPE_F16; + cparams_dft.type_v = GGML_TYPE_F16; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja); + } catch (const std::exception &e) { + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. " + "This may cause the model to output suboptimal responses\n", + __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + return true; + } + + void init() { + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; + + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.n_predict = params_base.n_predict; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + } + + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); + + slot.params.sampling = params_base.sampling; + + slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); }; + + slot.reset(); + + slots.push_back(slot); + } + + default_generation_settings_for_props = slots[0].to_json(); + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not + // used) + { + const int32_t n_batch = llama_n_batch(ctx); + + // only a single seq_id per token is needed + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + } + + metrics.init(); + } + + server_slot *get_slot_by_id(int id) { + for (server_slot &slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; + } + + server_slot *get_available_slot(const server_task &task) { + server_slot *ret = nullptr; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; + float similarity = 0; + + for (server_slot &slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { + continue; + } + + // length of the Longest Common Subsequence between the current slot's prompt and the input prompt + int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); + + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + + // select the current slot if the criteria match + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot &slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); + } + } + + return ret; + } + + bool launch_slot_with_task(server_slot &slot, const server_task &task) { + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); + + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = task.params.lora; + } + + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { + // Might be better to reject the request with a 400 ? + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, + slot.n_predict); + slot.params.n_predict = slot.n_predict; + } + + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); + } + + { + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, slot.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + + return true; + } + + void kv_cache_clear() { + SRV_DBG("%s", "clearing KV cache\n"); + + // clear the entire KV cache + llama_kv_cache_clear(ctx); + clean_kv_cache = false; + } + + bool process_token(completion_token_output &result, server_slot &slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + slot.generated_text += token_str; + if (slot.params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } + slot.has_next_token = true; + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } else { + result.text_to_send = ""; + } + + slot.add_token(result); + if (slot.params.stream) { + send_partial_response(slot, result); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + } + + if (slot.has_new_line) { + // if we have already seen a new line, we stop after a certain time limit + if (slot.params.t_max_predict_ms > 0 && + (ggml_time_us() - slot.t_start_generation > 1000.0f * slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, + (int)slot.params.t_max_predict_ms); + } + + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && + (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, + n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + } + + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_past >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, + "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = " + "%d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. " + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, + result.tok, token_str.c_str()); + + return slot.has_next_token; // continue + } + + void populate_token_probs(const server_slot &slot, completion_token_output &result, bool post_sampling, + bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto *cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back( + {cur_p->data[i].id, common_token_to_piece(ctx, cur_p->data[i].id, special), cur_p->data[i].p}); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({cur[i].id, common_token_to_piece(ctx, cur[i].id, special), cur[i].p}); + } + } + } + + void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.id_task, error, type); + } + + void send_error(const int id_task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + + queue_results.send(std::move(res)); + } + + void send_partial_response(server_slot &slot, const completion_token_output &tkn) { + auto res = std::make_unique(); + + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = {tkn.tok}; + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot &slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; + + res->index = slot.index; + res->content = std::move(slot.generated_text); + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->response_fields = std::move(slot.params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_chat_format = slot.params.oaicompat_chat_format; + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector(slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot &slot, const llama_batch &batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], + batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({embd, embd + n_embd}); + } + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot &slot, const llama_batch &batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], + batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to create new task(s) and receive result(s) + // + + void cancel_tasks(const std::unordered_set &id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto &id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(task); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(cancel_tasks, true); + } + + // receive the results from task(s) + void receive_multi_results(const std::unordered_set &id_tasks, + const std::function &)> &result_handler, + const std::function &error_handler, + const std::function &is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + i--; // retry + continue; + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr || + dynamic_cast(result.get()) != nullptr || + dynamic_cast(result.get()) != nullptr); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); + } + result_handler(results); + } + + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream(const std::unordered_set &id_tasks, + const std::function &result_handler, + const std::function &error_handler, + const std::function &is_connection_closed) { + size_t n_finished = 0; + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + continue; // retry + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr || + dynamic_cast(result.get()) != nullptr); + if (!result_handler(result)) { + cancel_tasks(id_tasks); + break; + } + + if (result->is_stop()) { + if (++n_finished == id_tasks.size()) { + break; + } + } + } + } + + // + // Functions to process the task + // + + void process_single_task(server_task task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: { + const int id_slot = task.id_selected_slot; + + server_slot *slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + + if (!launch_slot_with_task(*slot, task)) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: { + // release slot linked with the task id + for (auto &slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot &slot : slots) { + json slot_data = slot.to_json(); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; + + res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); + res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: { + int id_slot = task.slot_action.slot_id; + server_slot *slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const size_t nwrite = + llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: { + int id_slot = task.slot_action.slot_id; + server_slot *slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), + slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", + ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: { + int id_slot = task.slot_action.slot_id; + server_slot *slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); + slot->cache_tokens.clear(); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto &slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { + kv_cache_clear(); + } + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(task); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot &slot : slots) { + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } + + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, + n_discard); + + llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + + if (slot.params.cache_prompt) { + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } + + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + } + + slot.n_past -= n_discard; + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot *slot_batched = nullptr; + + auto accept_special_token = [&](server_slot &slot, llama_token token) { + return params_base.special || + slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + }; + + // frist, add sampled tokens from any ongoing sequences + for (auto &slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.n_past, {slot.id}, true); + + slot.n_past += 1; + + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(slot.sampled); + } + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int)slot.cache_tokens.size(), slot.truncated); + } + + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); + + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto &slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } + + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + auto &prompt_tokens = slot.prompt_tokens; + + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; + + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; + + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, + slot.params.n_keep, slot.n_prompt_tokens); + + // print prompt tokens (for debugging) + if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], + common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int)prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], + common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } + + // empty prompt passed -> release the slot and send empty response + if (prompt_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.release(); + slot.print_timings(); + send_final_response(slot); + continue; + } + + if (slot.is_non_causal()) { + if (slot.n_prompt_tokens > n_ubatch) { + slot.release(); + send_error(slot, "input is too large to process. increase the physical batch size", + ERROR_TYPE_SERVER); + continue; + } + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", + ERROR_TYPE_SERVER); + continue; + } + } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, + "the request exceeds the available context size. try increasing the " + "context size or enable context shift", + ERROR_TYPE_INVALID_REQUEST); + continue; + } + } + if (slot.params.n_keep < 0) { + slot.params.n_keep = slot.n_prompt_tokens; + } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { + const int n_left = slot.n_ctx - slot.params.n_keep; + + const int n_block_size = n_left / 2; + const int erased_blocks = + (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + + llama_tokens new_tokens(prompt_tokens.begin(), + prompt_tokens.begin() + slot.params.n_keep); + + new_tokens.insert(new_tokens.end(), + prompt_tokens.begin() + slot.params.n_keep + + erased_blocks * n_block_size, + prompt_tokens.end()); + + prompt_tokens = std::move(new_tokens); + + slot.truncated = true; + slot.n_prompt_tokens = prompt_tokens.size(); + + SLT_WRN(slot, + "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", + slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); + + GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + } + + if (slot.params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt + + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", + params_base.n_cache_reuse, slot.n_past); + + while (head_c < slot.cache_tokens.size() && head_p < prompt_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t)params_base.n_cache_reuse) { + SLT_INF(slot, + "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> " + "[%zu, %zu)\n", + n_match, head_c, head_c + n_match, head_p, head_p + n_match); + // for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], + // common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // } + + const int64_t kv_shift = (int64_t)head_p - (int64_t)head_c; + + llama_kv_cache_seq_rm(ctx, slot.id, head_p, head_c); + llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); + } + } + } + + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + // we have to evaluate at least 1 token to generate logits. + SLT_WRN(slot, + "need to evaluate at least 1 token to generate logits, n_past = %d, " + "n_prompt_tokens = %d\n", + slot.n_past, slot.n_prompt_tokens); + + slot.n_past--; + } + + slot.n_prompt_tokens_processed = 0; + } + + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.is_non_causal()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + continue; + } + } + + // keep only the common part + if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); + + // there is no common part left + slot.n_past = 0; + } + + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + + // remove the non-common part from the cache + slot.cache_tokens.resize(slot.n_past); + + // add prompt tokens for processing in the current batch + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && + llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; + + common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, {slot.id}, need_embd); + + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + } + + slot.n_prompt_tokens_processed++; + slot.n_past++; + } + + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", + slot.n_past, batch.n_tokens, (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + + // entire prompt has been processed + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + common_sampler_accept(slot.smpl, prompt_tokens[i], false); + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + } + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; + } + + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + } + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, batch.token + i, nullptr, batch.pos + i, + batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + metrics.on_decoded(slots); + + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i " + "= %d, n_batch = %d, ret = %d\n", + i, n_batch, ret); + for (auto &slot : slots) { + slot.release(); + send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); + } + break; // break loop of n_batch + } + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + i -= n_batch; + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing " + "it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", + i, n_batch, ret); + + continue; // continue loop of n_batch + } + + for (auto &slot : slots) { + if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + continue; // continue loop of slots + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + continue; + } + } + + // do speculative decoding + for (auto &slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", + n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int)draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); + + continue; + } + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add(slot.batch_spec, id, slot.n_past, {slot.id}, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, {slot.id}, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = + common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int)ids.size() - 1, (int)draft.size(), + slot.n_past); + } + } + + SRV_DBG("%s", "run slots completed\n"); + } + + json model_meta() const { + return json{ + {"vocab_type", llama_vocab_type(vocab)}, {"n_vocab", llama_vocab_n_tokens(vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, {"n_embd", llama_model_n_embd(model)}, + {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)}, + }; + } +}; diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp new file mode 100644 index 00000000..603424b4 --- /dev/null +++ b/src/main/cpp/utils.hpp @@ -0,0 +1,856 @@ +#pragma once + +#include "base64.hpp" +#include "common.h" +#include "llama.h" +#include "log.h" + +#ifndef NDEBUG +// crash the server in debug mode, otherwise send an http 500 error +#define CPPHTTPLIB_NO_EXCEPTIONS 1 +#endif +// increase max payload length to allow use of larger context size +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +// #include "httplib.h" + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "nlohmann/json.hpp" + +#include "chat.h" + +#include +#include +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) \ + LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) \ + LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) \ + LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) \ + LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +template static T json_value(const json &body, const std::string &key, const T &default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), + json(default_value).type_name()); + return default_value; + } + } else { + return default_value; + } +} + +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +// +// tokenizer and input processing utils +// + +static bool json_is_array_of_numbers(const json &data) { + if (data.is_array()) { + for (const auto &e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; + } + return false; +} + +// is array having BOTH numbers & strings? +static bool json_is_array_of_mixed_numbers_strings(const json &data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto &e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; +} + +// get value by path(key1 / key2) +static json json_get_nested_values(const std::vector &paths, const json &js) { + json result = json::object(); + + for (const std::string &path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string &k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } + } + return result; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_prompt, bool add_special, + bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto &p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } + + return prompt_tokens; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(const llama_vocab *vocab, const json &json_prompt, + bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto &p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error( + "element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } + } + } else { + throw std::runtime_error( + "\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string &text) { + size_t len = text.size(); + if (len == 0) + return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) + return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) + return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) + return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// +// template utils +// + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_tokens &query, const llama_tokens &doc) { + llama_tokens result; + + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_vocab_bos(vocab)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(llama_vocab_eos(vocab)); + + return result; +} + +// format infill task +static llama_tokens format_infill(const llama_vocab *vocab, const json &input_prefix, const json &input_suffix, + const json &input_extra, const int n_batch, const int n_predict, const int n_ctx, + const bool spm_infill, const llama_tokens &tokens_prompt) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto &chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, + 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3 * (n_batch / 4)); + const int n_suffix_take = + std::min(tokens_suffix.size(), std::max(0, (n_batch / 4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, + (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch)-2 * n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int)extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } + +static inline std::vector base64_decode(const std::string &encoded_string) { + int i = 0; + int j = 0; + int in_ = 0; + + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + std::vector ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; + in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) { + ret.push_back(char_array_3[i]); + } + + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } + + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } + + char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} + +// +// random string / id +// + +static std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + std::random_device rd; + std::mt19937 generator(rd()); + + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; + } + + return result; +} + +static std::string gen_chatcmplid() { return "chatcmpl-" + random_string(); } + +// +// other common utils +// + +static bool ends_with(const std::string &str, const std::string &suffix) { + return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { + const char text_last_char = text.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { + const std::string current_partial = stop.substr(0, char_index + 1); + if (ends_with(text, current_partial)) { + return text.size() - char_index - 1; + } + } + } + } + + return std::string::npos; +} + +// TODO: reuse llama_detokenize +template static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); + } + + return ret; +} + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); + + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + + return out; +} + +// static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { +// const std::string str = +// std::string(event) + ": " + +// data.dump(-1, ' ', false, json::error_handler_t::replace) + +// "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). +// +// LOG_DBG("data stream, to_send: %s", str.c_str()); +// +// return sink.write(str.c_str(), str.size()); +// } + +// +// OAI utils +// + +static json oaicompat_completion_params_parse(const json &body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "echo" field + if (json_value(body, "echo", false)) { + throw std::runtime_error("Only no echo is supported"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params{"best_of", "suffix"}; + for (const auto ¶m : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } + } + + // Copy remaining properties to llama_params + for (const auto &item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json oaicompat_completion_params_parse(const json &body, /* openai api json semantics */ + bool use_jinja, common_reasoning_format reasoning_format, + const struct common_chat_templates *tmpls) { + json llama_params; + + auto tools = json_value(body, "tools", json()); + auto stream = json_value(body, "stream", false); + + if (tools.is_array() && !tools.empty()) { + if (stream) { + throw std::runtime_error("Cannot use tools with stream"); + } + if (!use_jinja) { + throw std::runtime_error("tools param requires --jinja flag"); + } + } + if (!use_jinja) { + if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { + throw std::runtime_error("Unsupported param: tool_choice"); + } + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + + // Handle "response_format" field + if (body.contains("response_format")) { + json response_format = json_value(body, "response_format", json::object()); + std::string response_type = json_value(response_format, "type", std::string()); + if (response_type == "json_object") { + json_schema = json_value(response_format, "schema", json::object()); + } else if (response_type == "json_schema") { + auto schema_wrapper = json_value(response_format, "json_schema", json::object()); + json_schema = json_value(schema_wrapper, "schema", json::object()); + } else if (!response_type.empty() && response_type != "text") { + throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + + response_type); + } + } + + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.use_jinja = use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + + // Apply chat template to the list of messages + auto chat_params = common_chat_templates_apply(tmpls, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto &trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back(trigger.to_json()); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto &stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "logprobs" field + // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may + // need to fix it in the future + if (json_value(body, "logprobs", false)) { + llama_params["n_probs"] = json_value(body, "top_logprobs", 20); + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { + throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + } + + // Copy remaining properties to llama_params + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. + // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp + for (const auto &item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json format_embeddings_response_oaicompat(const json &request, const json &embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto &elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto &vec = json_value(elem, "embedding", json::array()).get>(); + const char *data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = {{"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"}}; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, {"index", i++}, {"object", "embedding"}}; + } + data.push_back(embedding_obj); + + n_tokens += json_value(elem, "tokens_evaluated", 0); + } + + json res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{{"prompt_tokens", n_tokens}, {"total_tokens", n_tokens}}}, + {"data", data}}; + + return res; +} + +static json format_response_rerank(const json &request, const json &ranks, bool is_tei_format, + std::vector &texts) { + json res; + if (is_tei_format) { + // TEI response format + res = json::array(); + bool return_text = json_value(request, "return_text", false); + for (const auto &rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {"score", json_value(rank, "score", 0.0)}, + }; + if (return_text) { + elem["text"] = std::move(texts[index]); + } + res.push_back(elem); + } + } else { + // Jina response format + json results = json::array(); + int32_t n_tokens = 0; + for (const auto &rank : ranks) { + results.push_back(json{ + {"index", json_value(rank, "index", 0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); + } + + res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{{"prompt_tokens", n_tokens}, {"total_tokens", n_tokens}}}, + {"results", results}}; + } + + return res; +} + +static bool is_valid_utf8(const std::string &str) { + const unsigned char *bytes = reinterpret_cast(str.data()); + const unsigned char *end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } + } + + return true; +} + +static json format_tokenizer_response(const json &tokens) { return json{{"tokens", tokens}}; } + +static json format_detokenized_response(const std::string &content) { return json{{"content", content}}; } + +static json format_logit_bias(const std::vector &logit_bias) { + json data = json::array(); + for (const auto &lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +static std::string safe_json_to_str(const json &data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + +static std::vector get_token_probabilities(llama_context *ctx, int idx) { + std::vector cur; + const auto *logits = llama_get_logits_ith(ctx, idx); + + const llama_model *model = llama_get_model(ctx); + const llama_vocab *vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + // sort tokens by logits + std::sort(cur.begin(), cur.end(), + [](const llama_token_data &a, const llama_token_data &b) { return a.logit > b.logit; }); + + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } + + return cur; +} + +static bool are_lora_equal(const std::vector &l1, + const std::vector &l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request(const std::vector &lora_base, + const json &data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto &entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto &entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +} \ No newline at end of file diff --git a/src/main/java/de/kherud/llama/CliParameters.java b/src/main/java/de/kherud/llama/CliParameters.java new file mode 100644 index 00000000..4142628e --- /dev/null +++ b/src/main/java/de/kherud/llama/CliParameters.java @@ -0,0 +1,40 @@ +package de.kherud.llama; + +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +abstract class CliParameters { + + final Map parameters = new HashMap<>(); + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + for (String key : parameters.keySet()) { + String value = parameters.get(key); + builder.append(key).append(" "); + if (value != null) { + builder.append(value).append(" "); + } + } + return builder.toString(); + } + + public String[] toArray() { + List result = new ArrayList<>(); + result.add(""); // c args contain the program name as the first argument, so we add an empty entry + for (String key : parameters.keySet()) { + result.add(key); + String value = parameters.get(key); + if (value != null) { + result.add(value); + } + } + return result.toArray(new String[0]); + } + +} diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index 6235c82e..41f74cc9 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -1,307 +1,546 @@ package de.kherud.llama; -import java.io.BufferedReader; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.util.Collections; +import java.util.Collection; +import java.util.List; import java.util.Map; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; +import de.kherud.llama.args.MiroStat; +import de.kherud.llama.args.Sampler; /** - * Parameters used throughout inference of a {@link LlamaModel}, e.g., {@link LlamaModel#generate(String)} and - * {@link LlamaModel#complete(String)}. + * Parameters used throughout inference of a {@link LlamaModel}, e.g., {@link LlamaModel#generate(InferenceParameters)} + * and + * {@link LlamaModel#complete(InferenceParameters)}. */ -public final class InferenceParameters { - - public final int nPredict; // new tokens to predict - public final int nKeep; // number of tokens to keep from initial prompt - public final int nProbs; // if greater than 0, output the probabilities of top nProbs tokens. - @Nullable - public final Map logitBias; // logit bias for specific tokens - public final int topK; // <= 0 to use vocab size - public final float topP; // 1.0 = disabled - public final float tfsZ; // 1.0 = disabled - public final float typicalP; // 1.0 = disabled - public final float temperature; // 1.0 = disabled - public final float repeatPenalty; // 1.0 = disabled - public final int repeatLastN; // last n tokens to penalize (0 = disable penalty, -1 = context size) - public final float frequencyPenalty; // 0.0 = disabled - public final float presencePenalty; // 0.0 = disabled - public final boolean penalizeNL; // 0.0 = disabled - public final boolean ignoreEos; - public final int mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - public final float mirostatTau; // target entropy - public final float mirostatEta; // learning rate - public final boolean beamSearch; - public final int nBeams; - @Nullable - public final String grammar; // optional BNF-like grammar to constrain sampling - @Nullable - public final String[] antiprompt; // string upon seeing which more user input is prompted - public final int seed; - - /** - * Private constructor to build immutable parameters object. Called via {@link Builder}. - */ - private InferenceParameters( - int nPredict, - int nKeep, - int nProbs, - @Nullable Map logitBias, - int topK, - float topP, - float tfsZ, - float typicalP, - float temperature, - float repeatPenalty, - int repeatLastN, - float frequencyPenalty, - float presencePenalty, - boolean penalizeNL, - boolean ignoreEos, - MiroStat mirostat, - float mirostatTau, - float mirostatEta, - boolean beamSearch, - int nBeams, - @Nullable String grammar, - @Nullable String[] antiprompt, - int seed - ) { - this.nPredict = nPredict; - this.nKeep = nKeep; - this.nProbs = nProbs; - this.logitBias = logitBias; - this.topK = topK; - this.topP = topP; - this.tfsZ = tfsZ; - this.typicalP = typicalP; - this.temperature = temperature; - this.repeatPenalty = repeatPenalty; - this.repeatLastN = repeatLastN; - this.frequencyPenalty = frequencyPenalty; - this.presencePenalty = presencePenalty; - this.penalizeNL = penalizeNL; - this.ignoreEos = ignoreEos; - this.mirostat = mirostat.level; - this.mirostatTau = mirostatTau; - this.mirostatEta = mirostatEta; - this.beamSearch = beamSearch; - this.nBeams = nBeams; - this.grammar = grammar; - this.antiprompt = antiprompt; - this.seed = seed; - } - - /** - * The builder class used for creating new {@link InferenceParameters} of a {@link LlamaModel}. - */ - public static class Builder { - - private int nPredict = -1; // new tokens to predict - private int nKeep = 0; // number of tokens to keep from initial prompt - private int nProbs = 0; // if greater than 0, output the probabilities of top nProbs tokens. - - // sampling parameters - private Map logitBias = null; // logit bias for specific tokens - private int topK = 40; // <= 0 to use vocab size - private float topP = 0.95f; // 1.0 = disabled - private float tfsZ = 1.00f; // 1.0 = disabled - private float typicalP = 1.00f; // 1.0 = disabled - private float temperature = 0.80f; // 1.0 = disabled - private float repeatPenalty = 1.10f; // 1.0 = disabled - private int repeatLastN = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - private float frequencyPenalty = 0.00f; // 0.0 = disabled - private float presencePenalty = 0.00f; // 0.0 = disabled - private boolean penalizeNl = false; // consider newlines as a repeatable token - private boolean ignoreEos = false; - private MiroStat mirostat = MiroStat.Disabled; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - private float mirostatTau = 5.00f; // target entropy - private float mirostatEta = 0.10f; // learning rate - private boolean beamSearch = false; - private int nBeams = 2; - - private String grammar = null; // optional BNF-like grammar to constrain sampling - private String[] antiPrompt = null; // string upon seeing which more user input is prompted - - private int seed = 42; - - /** - * Constructs the immutable {@link InferenceParameters} objects with the configured options. - * Note, that all options not configured have sensible defaults. - * - * @return an immutable parameters object - */ - public InferenceParameters build() { - return new InferenceParameters( - nPredict, - nKeep, - nProbs, - logitBias, - topK, - topP, - tfsZ, - typicalP, - temperature, - repeatPenalty, - repeatLastN, - frequencyPenalty, - presencePenalty, - penalizeNl, - ignoreEos, - mirostat, - mirostatTau, - mirostatEta, - beamSearch, - nBeams, - grammar, - antiPrompt, - seed - ); - } +@SuppressWarnings("unused") +public final class InferenceParameters extends JsonParameters { + + private static final String PARAM_PROMPT = "prompt"; + private static final String PARAM_INPUT_PREFIX = "input_prefix"; + private static final String PARAM_INPUT_SUFFIX = "input_suffix"; + private static final String PARAM_CACHE_PROMPT = "cache_prompt"; + private static final String PARAM_N_PREDICT = "n_predict"; + private static final String PARAM_TOP_K = "top_k"; + private static final String PARAM_TOP_P = "top_p"; + private static final String PARAM_MIN_P = "min_p"; + private static final String PARAM_TFS_Z = "tfs_z"; + private static final String PARAM_TYPICAL_P = "typical_p"; + private static final String PARAM_TEMPERATURE = "temperature"; + private static final String PARAM_DYNATEMP_RANGE = "dynatemp_range"; + private static final String PARAM_DYNATEMP_EXPONENT = "dynatemp_exponent"; + private static final String PARAM_REPEAT_LAST_N = "repeat_last_n"; + private static final String PARAM_REPEAT_PENALTY = "repeat_penalty"; + private static final String PARAM_FREQUENCY_PENALTY = "frequency_penalty"; + private static final String PARAM_PRESENCE_PENALTY = "presence_penalty"; + private static final String PARAM_MIROSTAT = "mirostat"; + private static final String PARAM_MIROSTAT_TAU = "mirostat_tau"; + private static final String PARAM_MIROSTAT_ETA = "mirostat_eta"; + private static final String PARAM_PENALIZE_NL = "penalize_nl"; + private static final String PARAM_N_KEEP = "n_keep"; + private static final String PARAM_SEED = "seed"; + private static final String PARAM_N_PROBS = "n_probs"; + private static final String PARAM_MIN_KEEP = "min_keep"; + private static final String PARAM_GRAMMAR = "grammar"; + private static final String PARAM_PENALTY_PROMPT = "penalty_prompt"; + private static final String PARAM_IGNORE_EOS = "ignore_eos"; + private static final String PARAM_LOGIT_BIAS = "logit_bias"; + private static final String PARAM_STOP = "stop"; + private static final String PARAM_SAMPLERS = "samplers"; + private static final String PARAM_STREAM = "stream"; + private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template"; + private static final String PARAM_USE_JINJA = "use_jinja"; + private static final String PARAM_MESSAGES = "messages"; + + public InferenceParameters(String prompt) { + // we always need a prompt + setPrompt(prompt); + } - public Builder setNPredict(int nPredict) { - this.nPredict = nPredict; - return this; - } + /** + * Set the prompt to start generation with (default: empty) + */ + public InferenceParameters setPrompt(String prompt) { + parameters.put(PARAM_PROMPT, toJsonString(prompt)); + return this; + } - public Builder setNKeep(int nKeep) { - this.nKeep = nKeep; - return this; - } + /** + * Set a prefix for infilling (default: empty) + */ + public InferenceParameters setInputPrefix(String inputPrefix) { + parameters.put(PARAM_INPUT_PREFIX, toJsonString(inputPrefix)); + return this; + } - public Builder setNProbs(int nProbs) { - this.nProbs = nProbs; - return this; - } + /** + * Set a suffix for infilling (default: empty) + */ + public InferenceParameters setInputSuffix(String inputSuffix) { + parameters.put(PARAM_INPUT_SUFFIX, toJsonString(inputSuffix)); + return this; + } - public Builder setLogitBias(@NotNull Map logitBias) { - this.logitBias = Collections.unmodifiableMap(logitBias); - return this; - } + /** + * Whether to remember the prompt to avoid reprocessing it + */ + public InferenceParameters setCachePrompt(boolean cachePrompt) { + parameters.put(PARAM_CACHE_PROMPT, String.valueOf(cachePrompt)); + return this; + } - public Builder setTopK(int topK) { - this.topK = topK; - return this; - } + /** + * Set the number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled) + */ + public InferenceParameters setNPredict(int nPredict) { + parameters.put(PARAM_N_PREDICT, String.valueOf(nPredict)); + return this; + } - public Builder setTopP(float topP) { - this.topP = topP; - return this; - } + /** + * Set top-k sampling (default: 40, 0 = disabled) + */ + public InferenceParameters setTopK(int topK) { + parameters.put(PARAM_TOP_K, String.valueOf(topK)); + return this; + } - public Builder setTfsZ(float tfsZ) { - this.tfsZ = tfsZ; - return this; - } + /** + * Set top-p sampling (default: 0.9, 1.0 = disabled) + */ + public InferenceParameters setTopP(float topP) { + parameters.put(PARAM_TOP_P, String.valueOf(topP)); + return this; + } - public Builder setTypicalP(float typicalP) { - this.typicalP = typicalP; - return this; - } + /** + * Set min-p sampling (default: 0.1, 0.0 = disabled) + */ + public InferenceParameters setMinP(float minP) { + parameters.put(PARAM_MIN_P, String.valueOf(minP)); + return this; + } - public Builder setTemperature(float temperature) { - this.temperature = temperature; - return this; - } + /** + * Set tail free sampling, parameter z (default: 1.0, 1.0 = disabled) + */ + public InferenceParameters setTfsZ(float tfsZ) { + parameters.put(PARAM_TFS_Z, String.valueOf(tfsZ)); + return this; + } - public Builder setRepeatPenalty(float repeatPenalty) { - this.repeatPenalty = repeatPenalty; - return this; - } + /** + * Set locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) + */ + public InferenceParameters setTypicalP(float typicalP) { + parameters.put(PARAM_TYPICAL_P, String.valueOf(typicalP)); + return this; + } - public Builder setRepeatLastN(int repeatLastN) { - this.repeatLastN = repeatLastN; - return this; - } + /** + * Set the temperature (default: 0.8) + */ + public InferenceParameters setTemperature(float temperature) { + parameters.put(PARAM_TEMPERATURE, String.valueOf(temperature)); + return this; + } - public Builder setFrequencyPenalty(float frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - return this; - } + /** + * Set the dynamic temperature range (default: 0.0, 0.0 = disabled) + */ + public InferenceParameters setDynamicTemperatureRange(float dynatempRange) { + parameters.put(PARAM_DYNATEMP_RANGE, String.valueOf(dynatempRange)); + return this; + } - public Builder setPresencePenalty(float presencePenalty) { - this.presencePenalty = presencePenalty; - return this; - } + /** + * Set the dynamic temperature exponent (default: 1.0) + */ + public InferenceParameters setDynamicTemperatureExponent(float dynatempExponent) { + parameters.put(PARAM_DYNATEMP_EXPONENT, String.valueOf(dynatempExponent)); + return this; + } - public Builder setPenalizeNl(boolean penalizeNl) { - this.penalizeNl = penalizeNl; - return this; - } + /** + * Set the last n tokens to consider for penalties (default: 64, 0 = disabled, -1 = ctx_size) + */ + public InferenceParameters setRepeatLastN(int repeatLastN) { + parameters.put(PARAM_REPEAT_LAST_N, String.valueOf(repeatLastN)); + return this; + } - public Builder setIgnoreEos(boolean ignoreEos) { - this.ignoreEos = ignoreEos; - return this; - } + /** + * Set the penalty of repeated sequences of tokens (default: 1.0, 1.0 = disabled) + */ + public InferenceParameters setRepeatPenalty(float repeatPenalty) { + parameters.put(PARAM_REPEAT_PENALTY, String.valueOf(repeatPenalty)); + return this; + } - public Builder setMirostat(MiroStat mode) { - this.mirostat = mode; - return this; - } + /** + * Set the repetition alpha frequency penalty (default: 0.0, 0.0 = disabled) + */ + public InferenceParameters setFrequencyPenalty(float frequencyPenalty) { + parameters.put(PARAM_FREQUENCY_PENALTY, String.valueOf(frequencyPenalty)); + return this; + } - public Builder setMirostatTau(float mirostatTau) { - this.mirostatTau = mirostatTau; - return this; - } + /** + * Set the repetition alpha presence penalty (default: 0.0, 0.0 = disabled) + */ + public InferenceParameters setPresencePenalty(float presencePenalty) { + parameters.put(PARAM_PRESENCE_PENALTY, String.valueOf(presencePenalty)); + return this; + } + + /** + * Set MiroStat sampling strategies. + */ + public InferenceParameters setMiroStat(MiroStat mirostat) { + parameters.put(PARAM_MIROSTAT, String.valueOf(mirostat.ordinal())); + return this; + } + + /** + * Set the MiroStat target entropy, parameter tau (default: 5.0) + */ + public InferenceParameters setMiroStatTau(float mirostatTau) { + parameters.put(PARAM_MIROSTAT_TAU, String.valueOf(mirostatTau)); + return this; + } - public Builder setMirostatEta(float mirostatEta) { - this.mirostatEta = mirostatEta; - return this; + /** + * Set the MiroStat learning rate, parameter eta (default: 0.1) + */ + public InferenceParameters setMiroStatEta(float mirostatEta) { + parameters.put(PARAM_MIROSTAT_ETA, String.valueOf(mirostatEta)); + return this; + } + + /** + * Whether to penalize newline tokens + */ + public InferenceParameters setPenalizeNl(boolean penalizeNl) { + parameters.put(PARAM_PENALIZE_NL, String.valueOf(penalizeNl)); + return this; + } + + /** + * Set the number of tokens to keep from the initial prompt (default: 0, -1 = all) + */ + public InferenceParameters setNKeep(int nKeep) { + parameters.put(PARAM_N_KEEP, String.valueOf(nKeep)); + return this; + } + + /** + * Set the RNG seed (default: -1, use random seed for < 0) + */ + public InferenceParameters setSeed(int seed) { + parameters.put(PARAM_SEED, String.valueOf(seed)); + return this; + } + + /** + * Set the amount top tokens probabilities to output if greater than 0. + */ + public InferenceParameters setNProbs(int nProbs) { + parameters.put(PARAM_N_PROBS, String.valueOf(nProbs)); + return this; + } + + /** + * Set the amount of tokens the samplers should return at least (0 = disabled) + */ + public InferenceParameters setMinKeep(int minKeep) { + parameters.put(PARAM_MIN_KEEP, String.valueOf(minKeep)); + return this; + } + + /** + * Set BNF-like grammar to constrain generations (see samples in grammars/ dir) + */ + public InferenceParameters setGrammar(String grammar) { + parameters.put(PARAM_GRAMMAR, toJsonString(grammar)); + return this; + } + + /** + * Override which part of the prompt is penalized for repetition. + * E.g. if original prompt is "Alice: Hello!" and penaltyPrompt is "Hello!", only the latter will be penalized if + * repeated. See pull request 3727 for more details. + */ + public InferenceParameters setPenaltyPrompt(String penaltyPrompt) { + parameters.put(PARAM_PENALTY_PROMPT, toJsonString(penaltyPrompt)); + return this; + } + + /** + * Override which tokens to penalize for repetition. + * E.g. if original prompt is "Alice: Hello!" and penaltyPrompt corresponds to the token ids of "Hello!", only the + * latter will be penalized if repeated. + * See pull request 3727 for more details. + */ + public InferenceParameters setPenaltyPrompt(int[] tokens) { + if (tokens.length > 0) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + for (int i = 0; i < tokens.length; i++) { + builder.append(tokens[i]); + if (i < tokens.length - 1) { + builder.append(", "); + } + } + builder.append("]"); + parameters.put(PARAM_PENALTY_PROMPT, builder.toString()); } + return this; + } - public Builder setBeamSearch(boolean beamSearch) { - this.beamSearch = beamSearch; - return this; + /** + * Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf) + */ + public InferenceParameters setIgnoreEos(boolean ignoreEos) { + parameters.put(PARAM_IGNORE_EOS, String.valueOf(ignoreEos)); + return this; + } + + /** + * Modify the likelihood of tokens appearing in the completion by their id. E.g., Map.of(15043, 1f) + * to increase the likelihood of token ' Hello', or a negative value to decrease it. + * Note, this method overrides any previous calls to + *
    + *
  • {@link #setTokenBias(Map)}
  • + *
  • {@link #disableTokens(Collection)}
  • + *
  • {@link #disableTokenIds(Collection)}}
  • + *
+ */ + public InferenceParameters setTokenIdBias(Map logitBias) { + if (!logitBias.isEmpty()) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + int i = 0; + for (Map.Entry entry : logitBias.entrySet()) { + Integer key = entry.getKey(); + Float value = entry.getValue(); + builder.append("[") + .append(key) + .append(", ") + .append(value) + .append("]"); + if (i++ < logitBias.size() - 1) { + builder.append(", "); + } + } + builder.append("]"); + parameters.put(PARAM_LOGIT_BIAS, builder.toString()); } + return this; + } - public Builder setNBeams(int nBeams) { - this.nBeams = nBeams; - return this; + /** + * Set tokens to disable, this corresponds to {@link #setTokenIdBias(Map)} with a value of + * {@link Float#NEGATIVE_INFINITY}. + * Note, this method overrides any previous calls to + *
    + *
  • {@link #setTokenIdBias(Map)}
  • + *
  • {@link #setTokenBias(Map)}
  • + *
  • {@link #disableTokens(Collection)}
  • + *
+ */ + public InferenceParameters disableTokenIds(Collection tokenIds) { + if (!tokenIds.isEmpty()) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + int i = 0; + for (Integer token : tokenIds) { + builder.append("[") + .append(token) + .append(", ") + .append(false) + .append("]"); + if (i++ < tokenIds.size() - 1) { + builder.append(", "); + } + } + builder.append("]"); + parameters.put(PARAM_LOGIT_BIAS, builder.toString()); } + return this; + } - // default charset usage for Java backwards compatibility - @SuppressWarnings("ImplicitDefaultCharsetUsage") - public Builder setGrammar(@NotNull File file) throws IOException { - StringBuilder grammarBuilder = new StringBuilder(); - try (BufferedReader br = new BufferedReader(new FileReader(file))) { - String currentLine; - while ((currentLine = br.readLine()) != null) { - grammarBuilder.append(currentLine).append("\n"); + /** + * Modify the likelihood of tokens appearing in the completion by their id. E.g., Map.of(" Hello", 1f) + * to increase the likelihood of token id 15043, or a negative value to decrease it. + * Note, this method overrides any previous calls to + *
    + *
  • {@link #setTokenIdBias(Map)}
  • + *
  • {@link #disableTokens(Collection)}
  • + *
  • {@link #disableTokenIds(Collection)}}
  • + *
+ */ + public InferenceParameters setTokenBias(Map logitBias) { + if (!logitBias.isEmpty()) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + int i = 0; + for (Map.Entry entry : logitBias.entrySet()) { + String key = entry.getKey(); + Float value = entry.getValue(); + builder.append("[") + .append(toJsonString(key)) + .append(", ") + .append(value) + .append("]"); + if (i++ < logitBias.size() - 1) { + builder.append(", "); } } - return setGrammar(grammarBuilder.toString()); + builder.append("]"); + parameters.put(PARAM_LOGIT_BIAS, builder.toString()); } + return this; + } - public Builder setGrammar(@Nullable String grammar) { - this.grammar = grammar; - return this; + /** + * Set tokens to disable, this corresponds to {@link #setTokenBias(Map)} with a value of + * {@link Float#NEGATIVE_INFINITY}. + * Note, this method overrides any previous calls to + *
    + *
  • {@link #setTokenBias(Map)}
  • + *
  • {@link #setTokenIdBias(Map)}
  • + *
  • {@link #disableTokenIds(Collection)}
  • + *
+ */ + public InferenceParameters disableTokens(Collection tokens) { + if (!tokens.isEmpty()) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + int i = 0; + for (String token : tokens) { + builder.append("[") + .append(toJsonString(token)) + .append(", ") + .append(false) + .append("]"); + if (i++ < tokens.size() - 1) { + builder.append(", "); + } + } + builder.append("]"); + parameters.put(PARAM_LOGIT_BIAS, builder.toString()); } + return this; + } - public Builder setAntiPrompt(@NotNull String[] antiPrompt) { - this.antiPrompt = antiPrompt; - return this; + /** + * Set strings upon seeing which token generation is stopped + */ + public InferenceParameters setStopStrings(String... stopStrings) { + if (stopStrings.length > 0) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + for (int i = 0; i < stopStrings.length; i++) { + builder.append(toJsonString(stopStrings[i])); + if (i < stopStrings.length - 1) { + builder.append(", "); + } + } + builder.append("]"); + parameters.put(PARAM_STOP, builder.toString()); } + return this; + } - public Builder setSeed(int seed) { - this.seed = seed; - return this; + /** + * Set which samplers to use for token generation in the given order + */ + public InferenceParameters setSamplers(Sampler... samplers) { + if (samplers.length > 0) { + StringBuilder builder = new StringBuilder(); + builder.append("["); + for (int i = 0; i < samplers.length; i++) { + switch (samplers[i]) { + case TOP_K: + builder.append("\"top_k\""); + break; + case TOP_P: + builder.append("\"top_p\""); + break; + case MIN_P: + builder.append("\"min_p\""); + break; + case TEMPERATURE: + builder.append("\"temperature\""); + break; + } + if (i < samplers.length - 1) { + builder.append(", "); + } + } + builder.append("]"); + parameters.put(PARAM_SAMPLERS, builder.toString()); } + return this; } - public enum MiroStat { + /** + * Set whether generate should apply a chat template (default: false) + */ + public InferenceParameters setUseChatTemplate(boolean useChatTemplate) { + parameters.put(PARAM_USE_JINJA, String.valueOf(useChatTemplate)); + return this; + } + + /** + * Set the messages for chat-based inference. + * - Allows **only one** system message. + * - Allows **one or more** user/assistant messages. + */ + public InferenceParameters setMessages(String systemMessage, List> messages) { + StringBuilder messagesBuilder = new StringBuilder(); + messagesBuilder.append("["); + + // Add system message (if provided) + if (systemMessage != null && !systemMessage.isEmpty()) { + messagesBuilder.append("{\"role\": \"system\", \"content\": ") + .append(toJsonString(systemMessage)) + .append("}"); + if (!messages.isEmpty()) { + messagesBuilder.append(", "); + } + } + + // Add user/assistant messages + for (int i = 0; i < messages.size(); i++) { + Pair message = messages.get(i); + String role = message.getKey(); + String content = message.getValue(); + + if (!role.equals("user") && !role.equals("assistant")) { + throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'."); + } + + messagesBuilder.append("{\"role\":") + .append(toJsonString(role)) + .append(", \"content\": ") + .append(toJsonString(content)) + .append("}"); + + if (i < messages.size() - 1) { + messagesBuilder.append(", "); + } + } - Disabled(0), - V1(1), - V2(2); + messagesBuilder.append("]"); - private final int level; + // Convert ArrayNode to a JSON string and store it in parameters + parameters.put(PARAM_MESSAGES, messagesBuilder.toString()); + return this; + } - MiroStat(int level) { - this.level = level; - } + InferenceParameters setStream(boolean stream) { + parameters.put(PARAM_STREAM, String.valueOf(stream)); + return this; } + } diff --git a/src/main/java/de/kherud/llama/JsonParameters.java b/src/main/java/de/kherud/llama/JsonParameters.java new file mode 100644 index 00000000..e9916976 --- /dev/null +++ b/src/main/java/de/kherud/llama/JsonParameters.java @@ -0,0 +1,95 @@ +package de.kherud.llama; + +import java.util.HashMap; +import java.util.Map; + +/** + * The Java library re-uses most of the llama.cpp server code, which mostly works with JSONs. Thus, the complexity and + * maintainability is much lower if we work with JSONs. This class provides a simple abstraction to easily create + * JSON object strings by filling a Map<String, String> with key value pairs. + */ +abstract class JsonParameters { + + // We save parameters directly as a String map here, to re-use as much as possible of the (json-based) C++ code. + // The JNI code for a proper Java-typed data object is comparatively too complex and hard to maintain. + final Map parameters = new HashMap<>(); + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("{\n"); + int i = 0; + for (Map.Entry entry : parameters.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + builder.append("\t\"") + .append(key) + .append("\": ") + .append(value); + if (i++ < parameters.size() - 1) { + builder.append(","); + } + builder.append("\n"); + } + builder.append("}"); + return builder.toString(); + } + + // taken from org.json.JSONObject#quote(String, Writer) + String toJsonString(String text) { + if (text == null) return null; + StringBuilder builder = new StringBuilder((text.length()) + 2); + + char b; + char c = 0; + String hhhh; + int i; + int len = text.length(); + + builder.append('"'); + for (i = 0; i < len; i += 1) { + b = c; + c = text.charAt(i); + switch (c) { + case '\\': + case '"': + builder.append('\\'); + builder.append(c); + break; + case '/': + if (b == '<') { + builder.append('\\'); + } + builder.append(c); + break; + case '\b': + builder.append("\\b"); + break; + case '\t': + builder.append("\\t"); + break; + case '\n': + builder.append("\\n"); + break; + case '\f': + builder.append("\\f"); + break; + case '\r': + builder.append("\\r"); + break; + default: + if (c < ' ' || (c >= '\u0080' && c < '\u00a0') || (c >= '\u2000' && c < '\u2100')) { + builder.append("\\u"); + hhhh = Integer.toHexString(c); + builder.append("0000", 0, 4 - hhhh.length()); + builder.append(hhhh); + } + else { + builder.append(c); + } + } + } + builder.append('"'); + return builder.toString(); + } +} diff --git a/src/main/java/de/kherud/llama/LlamaException.java b/src/main/java/de/kherud/llama/LlamaException.java index c2b5762c..84d4ee7c 100644 --- a/src/main/java/de/kherud/llama/LlamaException.java +++ b/src/main/java/de/kherud/llama/LlamaException.java @@ -1,6 +1,6 @@ package de.kherud.llama; -public class LlamaException extends RuntimeException { +class LlamaException extends RuntimeException { public LlamaException(String message) { super(message); diff --git a/src/main/java/de/kherud/llama/LlamaIterable.java b/src/main/java/de/kherud/llama/LlamaIterable.java new file mode 100644 index 00000000..7e6dff89 --- /dev/null +++ b/src/main/java/de/kherud/llama/LlamaIterable.java @@ -0,0 +1,15 @@ +package de.kherud.llama; + +import org.jetbrains.annotations.NotNull; + +/** + * An iterable used by {@link LlamaModel#generate(InferenceParameters)} that specifically returns a {@link LlamaIterator}. + */ +@FunctionalInterface +public interface LlamaIterable extends Iterable { + + @NotNull + @Override + LlamaIterator iterator(); + +} diff --git a/src/main/java/de/kherud/llama/LlamaIterator.java b/src/main/java/de/kherud/llama/LlamaIterator.java new file mode 100644 index 00000000..cb1c5c2c --- /dev/null +++ b/src/main/java/de/kherud/llama/LlamaIterator.java @@ -0,0 +1,51 @@ +package de.kherud.llama; + +import java.lang.annotation.Native; +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * This iterator is used by {@link LlamaModel#generate(InferenceParameters)}. In addition to implementing {@link Iterator}, + * it allows to cancel ongoing inference (see {@link #cancel()}). + */ +public final class LlamaIterator implements Iterator { + + private final LlamaModel model; + private final int taskId; + + @Native + @SuppressWarnings("FieldMayBeFinal") + private boolean hasNext = true; + + LlamaIterator(LlamaModel model, InferenceParameters parameters) { + this.model = model; + parameters.setStream(true); + taskId = model.requestCompletion(parameters.toString()); + } + + @Override + public boolean hasNext() { + return hasNext; + } + + @Override + public LlamaOutput next() { + if (!hasNext) { + throw new NoSuchElementException(); + } + LlamaOutput output = model.receiveCompletion(taskId); + hasNext = !output.stop; + if (output.stop) { + model.releaseTask(taskId); + } + return output; + } + + /** + * Cancel the ongoing generation process. + */ + public void cancel() { + model.cancelCompletion(taskId); + hasNext = false; + } +} diff --git a/src/main/java/de/kherud/llama/LlamaLoader.java b/src/main/java/de/kherud/llama/LlamaLoader.java index 4bdcf98c..58692522 100644 --- a/src/main/java/de/kherud/llama/LlamaLoader.java +++ b/src/main/java/de/kherud/llama/LlamaLoader.java @@ -26,7 +26,6 @@ import java.nio.file.StandardCopyOption; import java.util.LinkedList; import java.util.List; -import java.util.UUID; import java.util.stream.Stream; import org.jetbrains.annotations.Nullable; @@ -63,7 +62,6 @@ static synchronized void initialize() throws UnsatisfiedLinkError { System.err.println("'ggml-metal.metal' not found"); } } - loadNativeLibrary("llama"); loadNativeLibrary("jllama"); extracted = true; } @@ -74,7 +72,8 @@ static synchronized void initialize() throws UnsatisfiedLinkError { private static void cleanup() { try (Stream dirList = Files.list(getTempDir().toPath())) { dirList.filter(LlamaLoader::shouldCleanPath).forEach(LlamaLoader::cleanPath); - } catch (IOException e) { + } + catch (IOException e) { System.err.println("Failed to open directory: " + e.getMessage()); } } @@ -87,7 +86,8 @@ private static boolean shouldCleanPath(Path path) { private static void cleanPath(Path path) { try { Files.delete(path); - } catch (Exception e) { + } + catch (Exception e) { System.err.println("Failed to delete old native lib: " + e.getMessage()); } } @@ -95,36 +95,31 @@ private static void cleanPath(Path path) { private static void loadNativeLibrary(String name) { List triedPaths = new LinkedList<>(); - // Try loading library from de.kherud.llama.lib.path library path - String nativeLibName = System.getProperty("de.kherud.llama.lib.name"); - if (nativeLibName == null) { - nativeLibName = System.mapLibraryName(name); - } - + String nativeLibName = System.mapLibraryName(name); String nativeLibPath = System.getProperty("de.kherud.llama.lib.path"); if (nativeLibPath != null) { Path path = Paths.get(nativeLibPath, nativeLibName); if (loadNativeLibrary(path)) { return; - } else { + } + else { triedPaths.add(nativeLibPath); } } - // Load the os-dependent library from the jar file - nativeLibPath = getNativeResourcePath(); - if (hasNativeLib(nativeLibPath, nativeLibName)) { - // temporary library folder - String tempFolder = getTempDir().getAbsolutePath(); - // Try extracting the library from jar - if (extractAndLoadLibraryFile(nativeLibPath, nativeLibName, tempFolder)) { + if (OSInfo.isAndroid()) { + try { + // loadLibrary can load directly from packed apk file automatically + // if java-llama.cpp is added as code source + System.loadLibrary(name); return; - } else { - triedPaths.add(nativeLibPath); + } + catch (UnsatisfiedLinkError e) { + triedPaths.add("Directly from .apk/lib"); } } - // As a last resort try from java.library.path + // Try to load the library from java.library.path String javaLibraryPath = System.getProperty("java.library.path", ""); for (String ldPath : javaLibraryPath.split(File.pathSeparator)) { if (ldPath.isEmpty()) { @@ -133,11 +128,26 @@ private static void loadNativeLibrary(String name) { Path path = Paths.get(ldPath, nativeLibName); if (loadNativeLibrary(path)) { return; - } else { + } + else { triedPaths.add(ldPath); } } + // As a last resort try load the os-dependent library from the jar file + nativeLibPath = getNativeResourcePath(); + if (hasNativeLib(nativeLibPath, nativeLibName)) { + // temporary library folder + String tempFolder = getTempDir().getAbsolutePath(); + // Try extracting the library from jar + if (extractAndLoadLibraryFile(nativeLibPath, nativeLibName, tempFolder)) { + return; + } + else { + triedPaths.add(nativeLibPath); + } + } + throw new UnsatisfiedLinkError( String.format( "No native library found for os.name=%s, os.arch=%s, paths=[%s]", @@ -154,7 +164,7 @@ private static void loadNativeLibrary(String name) { * @param path path of the native library * @return true for successfully loading, otherwise false */ - private static boolean loadNativeLibrary(Path path) { + public static boolean loadNativeLibrary(Path path) { if (!Files.exists(path)) { return false; } @@ -162,7 +172,8 @@ private static boolean loadNativeLibrary(Path path) { try { System.load(absolutePath); return true; - } catch (UnsatisfiedLinkError e) { + } + catch (UnsatisfiedLinkError e) { System.err.println(e.getMessage()); System.err.println("Failed to load native library: " + absolutePath + ". osinfo: " + OSInfo.getNativeLibFolderPathForCurrentOS()); return false; @@ -172,17 +183,8 @@ private static boolean loadNativeLibrary(Path path) { @Nullable private static Path extractFile(String sourceDirectory, String fileName, String targetDirectory, boolean addUuid) { String nativeLibraryFilePath = sourceDirectory + "/" + fileName; - // Include architecture name in temporary filename in order to avoid conflicts - // when multiple JVMs with different architectures running at the same time - String extractedLibFileName; - if (addUuid) { - String uuid = UUID.randomUUID().toString(); - extractedLibFileName = uuid + "-" + fileName; - } else { - extractedLibFileName = fileName; - } - Path extractedFilePath = Paths.get(targetDirectory, extractedLibFileName); + Path extractedFilePath = Paths.get(targetDirectory, fileName); try { // Extract a native library file into the target directory @@ -191,7 +193,8 @@ private static Path extractFile(String sourceDirectory, String fileName, String return null; } Files.copy(reader, extractedFilePath, StandardCopyOption.REPLACE_EXISTING); - } finally { + } + finally { // Delete the extracted lib file on JVM exit. extractedFilePath.toFile().deleteOnExit(); } @@ -211,7 +214,8 @@ private static Path extractFile(String sourceDirectory, String fileName, String System.out.println("Extracted '" + fileName + "' to '" + extractedFilePath + "'"); return extractedFilePath; - } catch (IOException e) { + } + catch (IOException e) { System.err.println(e.getMessage()); return null; } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 9bdddb87..eab36202 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -1,13 +1,15 @@ package de.kherud.llama; +import de.kherud.llama.args.LogFormat; +import org.jetbrains.annotations.Nullable; + import java.lang.annotation.Native; import java.nio.charset.StandardCharsets; -import java.util.Iterator; -import java.util.NoSuchElementException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.function.BiConsumer; -import org.jetbrains.annotations.Nullable; - /** * This class is a wrapper around the llama.cpp functionality. * Upon being created, it natively allocates memory for the model context. @@ -15,9 +17,9 @@ *

* The main functionality of this class is: *

    - *
  • Streaming answers (and probabilities) via {@link #generate(String)}
  • - *
  • Creating whole responses to prompts via {@link #complete(String)}
  • - *
  • Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters.Builder#setEmbedding(boolean)}
  • + *
  • Streaming answers (and probabilities) via {@link #generate(InferenceParameters)}
  • + *
  • Creating whole responses to prompts via {@link #complete(InferenceParameters)}
  • + *
  • Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#enableEmbedding()}
  • *
  • Accessing the tokenizer via {@link #encode(String)} and {@link #decode(int[])}
  • *
*/ @@ -27,87 +29,59 @@ public class LlamaModel implements AutoCloseable { LlamaLoader.initialize(); } - private static final ModelParameters defaultModelParams = new ModelParameters.Builder().build(); - private static final InferenceParameters defaultInferenceParams = new InferenceParameters.Builder().build(); - @Native private long ctx; /** - * Load a gguf llama.cpp model from a given file path with default {@link ModelParameters}. - * - * @param filePath a file path pointing to the model - * @throws LlamaException if no model could be loaded from the given file path - */ - public LlamaModel(String filePath) { - this(filePath, defaultModelParams); - } - - /** - * Load a gguf llama.cpp model from a given file path with custom {@link ModelParameters}. + * Load with the given {@link ModelParameters}. Make sure to either set + *
    + *
  • {@link ModelParameters#setModel(String)}
  • + *
  • {@link ModelParameters#setModelUrl(String)}
  • + *
  • {@link ModelParameters#setHfRepo(String)}, {@link ModelParameters#setHfFile(String)}
  • + *
* - * @param filePath a file path pointing to the model - * @param parameters the set of previously configured options + * @param parameters the set of options * @throws LlamaException if no model could be loaded from the given file path */ - public LlamaModel(String filePath, ModelParameters parameters) { - loadModel(filePath, parameters); - } - - /** - * Generate and return a whole answer with default parameters. Note, that the prompt isn't preprocessed in any - * way, nothing like "User: ", "###Instruction", etc. is added. - * - * @param prompt the LLM prompt - * @return an LLM response - */ - public String complete(String prompt) { - return complete(prompt, defaultInferenceParams); + public LlamaModel(ModelParameters parameters) { + loadModel(parameters.toArray()); } /** * Generate and return a whole answer with custom parameters. Note, that the prompt isn't preprocessed in any * way, nothing like "User: ", "###Instruction", etc. is added. * - * @param prompt the LLM prompt * @return an LLM response */ - public String complete(String prompt, InferenceParameters parameters) { - byte[] bytes = getFull(prompt, parameters); - return new String(bytes, StandardCharsets.UTF_8); - } - - /** - * Generate and stream outputs with default inference parameters. Note, that the prompt isn't preprocessed in any - * way, nothing like "User: ", "###Instruction", etc. is added. - * - * @param prompt the LLM prompt - * @return iterable LLM outputs - */ - public Iterable generate(String prompt) { - return generate(prompt, defaultInferenceParams); + public String complete(InferenceParameters parameters) { + parameters.setStream(false); + int taskId = requestCompletion(parameters.toString()); + LlamaOutput output = receiveCompletion(taskId); + return output.text; } /** * Generate and stream outputs with custom inference parameters. Note, that the prompt isn't preprocessed in any * way, nothing like "User: ", "###Instruction", etc. is added. * - * @param prompt the LLM prompt * @return iterable LLM outputs */ - public Iterable generate(String prompt, InferenceParameters parameters) { - return () -> new LlamaIterator(prompt, parameters); + public LlamaIterable generate(InferenceParameters parameters) { + return () -> new LlamaIterator(this, parameters); } - + + + /** * Get the embedding of a string. Note, that the prompt isn't preprocessed in any way, nothing like * "User: ", "###Instruction", etc. is added. * * @param prompt the string to embed * @return an embedding float array - * @throws IllegalStateException if embedding mode was not activated (see {@link ModelParameters.Builder#setEmbedding(boolean)}) + * @throws IllegalStateException if embedding mode was not activated (see {@link ModelParameters#enableEmbedding()}) */ - public native float[] embed(String prompt); + public native float[] embed(String prompt); + /** * Tokenize a prompt given the native tokenizer @@ -123,58 +97,75 @@ public Iterable generate(String prompt, InferenceParameters parameters) * @param tokens an array of tokens * @return the token ids decoded to a string */ - public String decode(int[] tokens) { + public String decode(int[] tokens) { byte[] bytes = decodeBytes(tokens); return new String(bytes, StandardCharsets.UTF_8); } /** - * Sets a callback for both Java and C++ log messages. Can be set to {@code null} to disable logging. + * Sets a callback for native llama.cpp log messages. + * Per default, log messages are written in JSON to stdout. Note, that in text mode the callback will be also + * invoked with log messages of the GGML backend, while JSON mode can only access request log messages. + * In JSON mode, GGML messages will still be written to stdout. + * To only change the log format but keep logging to stdout, the given callback can be null. + * To disable logging, pass an empty callback, i.e., (level, msg) -> {}. * + * @param format the log format to use * @param callback a method to call for log messages */ - public static native void setLogger(@Nullable BiConsumer callback); + public static native void setLogger(LogFormat format, @Nullable BiConsumer callback); @Override public void close() { delete(); } - private native void loadModel(String filePath, ModelParameters parameters) throws LlamaException; - private native void setupInference(String prompt, InferenceParameters parameters); - private native byte[] getFull(String prompt, InferenceParameters parameters); - private native byte[] getNext(LlamaIterator iterator); - private native byte[] decodeBytes(int[] tokens); + // don't overload native methods since the C++ function names get nasty + native int requestCompletion(String params) throws LlamaException; + + native LlamaOutput receiveCompletion(int taskId) throws LlamaException; + + native void cancelCompletion(int taskId); + + native byte[] decodeBytes(int[] tokens); + + private native void loadModel(String... parameters) throws LlamaException; + private native void delete(); + + native void releaseTask(int taskId); - // fields are modified by native code and thus should not be final - @SuppressWarnings("FieldMayBeFinal") - private final class LlamaIterator implements Iterator { - - @Native - private boolean hasNext = true; - @Native - private long generatedCount = 0; - @Native - private long tokenIndex = 0; - - private LlamaIterator(String prompt, InferenceParameters parameters) { - setupInference(prompt, parameters); - } - - @Override - public boolean hasNext() { - return hasNext; - } - - @Override - public String next() { - if (!hasNext) { - throw new NoSuchElementException(); - } - byte[] bytes = getNext(this); - return new String(bytes, StandardCharsets.UTF_8); - } + private static native byte[] jsonSchemaToGrammarBytes(String schema); + + public static String jsonSchemaToGrammar(String schema) { + return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8); } - + + public List> rerank(boolean reRank, String query, String ... documents) { + LlamaOutput output = rerank(query, documents); + + Map scoredDocumentMap = output.probabilities; + + List> rankedDocuments = new ArrayList<>(); + + if (reRank) { + // Sort in descending order based on Float values + scoredDocumentMap.entrySet() + .stream() + .sorted((a, b) -> Float.compare(b.getValue(), a.getValue())) // Descending order + .forEach(entry -> rankedDocuments.add(new Pair<>(entry.getKey(), entry.getValue()))); + } else { + // Copy without sorting + scoredDocumentMap.forEach((key, value) -> rankedDocuments.add(new Pair<>(key, value))); + } + + return rankedDocuments; + } + + public native LlamaOutput rerank(String query, String... documents); + + public String applyTemplate(InferenceParameters parameters) { + return applyTemplate(parameters.toString()); + } + public native String applyTemplate(String parametersJson); } diff --git a/src/main/java/de/kherud/llama/LlamaOutput.java b/src/main/java/de/kherud/llama/LlamaOutput.java new file mode 100644 index 00000000..365b335e --- /dev/null +++ b/src/main/java/de/kherud/llama/LlamaOutput.java @@ -0,0 +1,39 @@ +package de.kherud.llama; + +import org.jetbrains.annotations.NotNull; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * An output of the LLM providing access to the generated text and the associated probabilities. You have to configure + * {@link InferenceParameters#setNProbs(int)} in order for probabilities to be returned. + */ +public final class LlamaOutput { + + /** + * The last bit of generated text that is representable as text (i.e., cannot be individual utf-8 multibyte code + * points). + */ + @NotNull + public final String text; + + /** + * Note, that you have to configure {@link InferenceParameters#setNProbs(int)} in order for probabilities to be returned. + */ + @NotNull + public final Map probabilities; + + final boolean stop; + + LlamaOutput(byte[] generated, @NotNull Map probabilities, boolean stop) { + this.text = new String(generated, StandardCharsets.UTF_8); + this.probabilities = probabilities; + this.stop = stop; + } + + @Override + public String toString() { + return text; + } +} diff --git a/src/main/java/de/kherud/llama/LogLevel.java b/src/main/java/de/kherud/llama/LogLevel.java index 25520f0e..b55c0898 100644 --- a/src/main/java/de/kherud/llama/LogLevel.java +++ b/src/main/java/de/kherud/llama/LogLevel.java @@ -5,24 +5,9 @@ */ public enum LogLevel { - DEBUG(-1), - INFO(4), - WARN(3), - ERROR(2); - - private final int code; - - LogLevel(int code) { - this.code = code; - } - - /** - * Returns the native log level code of this option - * - * @return the native code - */ - int getCode() { - return code; - } + DEBUG, + INFO, + WARN, + ERROR } diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index 9b91a134..7999295d 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -1,339 +1,964 @@ package de.kherud.llama; -import org.jetbrains.annotations.Nullable; +import de.kherud.llama.args.*; -/** +/*** * Parameters used for initializing a {@link LlamaModel}. */ -public final class ModelParameters { - - public final int nThreads; - - public final int seed; - public final int nCtx; // text context - public final int nBatch; // prompt processing batch size - public final int nGpuLayers; // number of layers to store in VRAM - public final int mainGpu; // the GPU that is used for scratch and small tensors - public final float[] tensorSplit; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) - public final float ropeFreqBase; // RoPE base frequency - public final float ropeFreqScale; // RoPE frequency scaling factor - // public final llama_progress_callback progress_callback; -// public final Pointer progress_callback_user_data; - public final boolean lowVram; // if true, reduce VRAM usage at the cost of performance - public final boolean mulMatQ; // if true, use experimental mul_mat_q kernels - public final boolean f16Kv; // use fp16 for KV cache - public final boolean logitsAll; // the llama_eval() call computes all logits, not just the last one - public final boolean vocabOnly; // only load the vocabulary, no weights - public final boolean useMmap; // use mmap if possible - public final boolean useMlock; // force system to keep model in RAM - public final boolean embedding; // embedding mode only - @Nullable - public final String loraAdapter; // lora adapter path - @Nullable - public final String loraBase; // base model path for the lora adapter - public final boolean hellaswag; // compute HellaSwag score over random tasks from datafile supplied in prompt - public final short hellaswagTasks; // number of tasks to use when computing the HellaSwag score - public final boolean memoryF16; // use f16 instead of f32 for memory kv - public final boolean memTest; // compute maximum memory usage - public final boolean numa; // attempt optimizations that help on some NUMA systems - public final boolean verbosePrompt; // log prompt tokens before generation - - /** - * Private constructor to build immutable parameters object. Called via {@link Builder}. - */ - private ModelParameters( - int nThreads, - int seed, - int nCtx, - int nBatch, - int nGpuLayers, - int mainGpu, - float[] tensorSplit, - float ropeFreqBase, - float ropeFreqScale, - boolean lowVram, - boolean mulMatQ, - boolean f16Kv, - boolean logitsAll, - boolean vocabOnly, - boolean useMmap, - boolean useMlock, - boolean embedding, - @Nullable String loraAdapter, - @Nullable String loraBase, - boolean hellaswag, - short hellaswagTasks, - boolean memoryF16, - boolean memTest, - boolean numa, - boolean verbosePrompt - ) { - this.seed = seed; - this.nCtx = nCtx; - this.nBatch = nBatch; - this.nGpuLayers = nGpuLayers; - this.mainGpu = mainGpu; - this.tensorSplit = tensorSplit; - this.ropeFreqBase = ropeFreqBase; - this.ropeFreqScale = ropeFreqScale; - this.lowVram = lowVram; - this.mulMatQ = mulMatQ; - this.f16Kv = f16Kv; - this.logitsAll = logitsAll; - this.vocabOnly = vocabOnly; - this.useMmap = useMmap; - this.useMlock = useMlock; - this.embedding = embedding; - this.nThreads = nThreads; - this.loraAdapter = loraAdapter; - this.loraBase = loraBase; - this.hellaswag = hellaswag; - this.hellaswagTasks = hellaswagTasks; - this.memoryF16 = memoryF16; - this.memTest = memTest; - this.numa = numa; - this.verbosePrompt = verbosePrompt; - } - - /** - * The builder class used for creating new {@link ModelParameters} of a {@link LlamaModel}. - */ - public static class Builder { - - private int nThreads = Runtime.getRuntime().availableProcessors(); - public int seed = -1; - public int nCtx = 512; // text context - public int nBatch = 512; // prompt processing batch size - public int nGpuLayers = -1; // number of layers to store in VRAM - public int mainGpu = 0; // the GPU that is used for scratch and small tensors - public float[] tensorSplit = null; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) - public float ropeFreqBase = 10000.0f; // RoPE base frequency - public float ropeFreqScale = 1.0f; // RoPE frequency scaling factor - // public llama_progress_callback progress_callback; - // public Pointer progress_callback_user_data; - public boolean lowVram = false; // if true, reduce VRAM usage at the cost of performance - public boolean mulMatQ = true; // if true, use experimental mul_mat_q kernels - public boolean f16Kv; // use fp16 for KV cache - public boolean logitsAll; // the llama_eval() call computes all logits, not just the last one - public boolean vocabOnly = false; // only load the vocabulary, no weights - public boolean useMmap = true; // use mmap if possible - public boolean useMlock = false; // force system to keep model in RAM - public boolean embedding = false; // embedding mode only - private String loraAdapter = null; // lora adapter path - private String loraBase = null; // base model path for the lora adapter - - private boolean hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt - private short hellaswagTasks = 400; // number of tasks to use when computing the HellaSwag score - - private boolean memoryF16 = true; // use f16 instead of f32 for memory kv - private boolean memTest = false; // compute maximum memory usage - private boolean numa = false; // attempt optimizations that help on some NUMA systems - private boolean verbosePrompt = false; // print prompt tokens before generation - - /** - * Constructs the immutable {@link ModelParameters} objects with the configured options. - * Note, that all options not configured have sensible defaults. - * - * @return an immutable parameters object - */ - public ModelParameters build() { - return new ModelParameters( - nThreads, - seed, - nCtx, - nBatch, - nGpuLayers, - mainGpu, - tensorSplit, - ropeFreqBase, - ropeFreqScale, - lowVram, - mulMatQ, - f16Kv, - logitsAll, - vocabOnly, - useMmap, - useMlock, - embedding, - loraAdapter, - loraBase, - hellaswag, - hellaswagTasks, - memoryF16, - memTest, - numa, - verbosePrompt - ); - } - - public Builder setNThreads(int nThreads) { - this.nThreads = nThreads; - return this; - } - - public Builder setLoraAdapter(@Nullable String loraAdapter) { - this.loraAdapter = loraAdapter; - return this; - } - - public Builder setLoraBase(@Nullable String loraBase) { - this.loraBase = loraBase; - return this; - } - - public Builder setHellaswag(boolean hellaswag) { - this.hellaswag = hellaswag; - return this; - } - - public Builder setHellaswagTasks(short hellaswagTasks) { - this.hellaswagTasks = hellaswagTasks; - return this; - } - - public Builder setMemoryF16(boolean memoryF16) { - this.memoryF16 = memoryF16; - return this; - } - - public Builder setMemTest(boolean memTest) { - this.memTest = memTest; - return this; - } - - public Builder setNuma(boolean numa) { - this.numa = numa; - return this; - } - - public Builder setVerbosePrompt(boolean verbosePrompt) { - this.verbosePrompt = verbosePrompt; - return this; - } - - /** - * Set a callback that will be used to report progress loading the model with a float value of 0-1. - * - * @param progressCallback the function to call ony any progress - * @return this builder object - */ -// public Builder setProgressCallback(@Nullable Consumer progressCallback) { -// // Similarly to setting the logger, we don't allow passing any user data to the progress callback, since -// // the JVM might move the object around in the memory, thus invalidating any pointers. -// if (progressCallback == null) { -// ctxParams.setProgress_callback(null); -// } else { -// ctxParams.setProgress_callback((progress, ctx) -> progressCallback.accept(progress)); -// } -// return this; -// } - - public Builder setSeed(int seed) { - this.seed = seed; - return this; - } - - public Builder setNCtx(int nCtx) { - this.nCtx = nCtx; - return this; - } - - public Builder setNBbatch(int nBatch) { - this.nBatch = nBatch; - return this; - } - - public Builder setNGpuLayers(int nGpuLayers) { - this.nGpuLayers = nGpuLayers; - return this; - } - - public Builder setMainGpu(int mainGpu) { - this.mainGpu = mainGpu; - return this; - } - - public Builder setTensorSplit(float[] tensorSplit) { - this.tensorSplit = tensorSplit; - return this; - } - - public Builder setRopeFreqBase(float ropeFreqBase) { - this.ropeFreqBase = ropeFreqBase; - return this; - } - - public Builder setRopeFreqScale(float ropeFreqScale) { - this.ropeFreqScale = ropeFreqScale; - return this; - } - -// public Builder setProgressCallback(LlamaLibrary.llama_progress_callback progress_callback) { -// ctxParams.setProgress_callback(progress_callback); -// return this; -// } - -// public Builder setProgressCallbackUserData(Pointer progress_callback_user_data) { -// ctxParams.setProgress_callback_user_data(progress_callback_user_data); -// return this; -// } - - public Builder setLowVram(boolean lowVram) { - this.lowVram = lowVram; - return this; - } - - public Builder setMulMatQ(boolean mulMatQ) { - this.mulMatQ = mulMatQ; - return this; - } - - /** - * use fp16 for KV cache - */ - public Builder setF16Kv(boolean f16Kv) { - this.f16Kv = f16Kv; - return this; - } - - /** - * the llama_eval() call computes all logits, not just the last one - */ - public Builder setLogitsAll(boolean logitsAll) { - this.logitsAll = logitsAll; - return this; - } - - /** - * only load the vocabulary, no weights - */ - public Builder setVocabOnly(boolean vocabOnly) { - this.vocabOnly = vocabOnly; - return this; - } - - /** - * use mmap if possible - */ - public Builder setUseMmap(boolean useMmap) { - this.useMmap = useMmap; - return this; - } - - /** - * force system to keep model in RAM - */ - public Builder setUseMLock(boolean useMlock) { - this.useMlock = useMlock; - return this; - } - - /** - * embedding mode only - */ - public Builder setEmbedding(boolean embedding) { - this.embedding = embedding; - return this; - } - } +@SuppressWarnings("unused") +public final class ModelParameters extends CliParameters { + + /** + * Set the number of threads to use during generation (default: -1). + */ + public ModelParameters setThreads(int nThreads) { + parameters.put("--threads", String.valueOf(nThreads)); + return this; + } + + /** + * Set the number of threads to use during batch and prompt processing (default: same as --threads). + */ + public ModelParameters setThreadsBatch(int nThreads) { + parameters.put("--threads-batch", String.valueOf(nThreads)); + return this; + } + + /** + * Set the CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: ""). + */ + public ModelParameters setCpuMask(String mask) { + parameters.put("--cpu-mask", mask); + return this; + } + + /** + * Set the range of CPUs for affinity. Complements --cpu-mask. + */ + public ModelParameters setCpuRange(String range) { + parameters.put("--cpu-range", range); + return this; + } + + /** + * Use strict CPU placement (default: 0). + */ + public ModelParameters setCpuStrict(int strictCpu) { + parameters.put("--cpu-strict", String.valueOf(strictCpu)); + return this; + } + + /** + * Set process/thread priority: 0-normal, 1-medium, 2-high, 3-realtime (default: 0). + */ + public ModelParameters setPriority(int priority) { + if (priority < 0 || priority > 3) { + throw new IllegalArgumentException("Invalid value for priority"); + } + parameters.put("--prio", String.valueOf(priority)); + return this; + } + + /** + * Set the polling level to wait for work (0 - no polling, default: 0). + */ + public ModelParameters setPoll(int poll) { + parameters.put("--poll", String.valueOf(poll)); + return this; + } + + /** + * Set the CPU affinity mask for batch processing: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask). + */ + public ModelParameters setCpuMaskBatch(String mask) { + parameters.put("--cpu-mask-batch", mask); + return this; + } + + /** + * Set the ranges of CPUs for batch affinity. Complements --cpu-mask-batch. + */ + public ModelParameters setCpuRangeBatch(String range) { + parameters.put("--cpu-range-batch", range); + return this; + } + + /** + * Use strict CPU placement for batch processing (default: same as --cpu-strict). + */ + public ModelParameters setCpuStrictBatch(int strictCpuBatch) { + parameters.put("--cpu-strict-batch", String.valueOf(strictCpuBatch)); + return this; + } + + /** + * Set process/thread priority for batch processing: 0-normal, 1-medium, 2-high, 3-realtime (default: 0). + */ + public ModelParameters setPriorityBatch(int priorityBatch) { + if (priorityBatch < 0 || priorityBatch > 3) { + throw new IllegalArgumentException("Invalid value for priority batch"); + } + parameters.put("--prio-batch", String.valueOf(priorityBatch)); + return this; + } + + /** + * Set the polling level for batch processing (default: same as --poll). + */ + public ModelParameters setPollBatch(int pollBatch) { + parameters.put("--poll-batch", String.valueOf(pollBatch)); + return this; + } + + /** + * Set the size of the prompt context (default: 0, 0 = loaded from model). + */ + public ModelParameters setCtxSize(int ctxSize) { + parameters.put("--ctx-size", String.valueOf(ctxSize)); + return this; + } + + /** + * Set the number of tokens to predict (default: -1 = infinity, -2 = until context filled). + */ + public ModelParameters setPredict(int nPredict) { + parameters.put("--predict", String.valueOf(nPredict)); + return this; + } + + /** + * Set the logical maximum batch size (default: 0). + */ + public ModelParameters setBatchSize(int batchSize) { + parameters.put("--batch-size", String.valueOf(batchSize)); + return this; + } + + /** + * Set the physical maximum batch size (default: 0). + */ + public ModelParameters setUbatchSize(int ubatchSize) { + parameters.put("--ubatch-size", String.valueOf(ubatchSize)); + return this; + } + + /** + * Set the number of tokens to keep from the initial prompt (default: -1 = all). + */ + public ModelParameters setKeep(int keep) { + parameters.put("--keep", String.valueOf(keep)); + return this; + } + + /** + * Disable context shift on infinite text generation (default: enabled). + */ + public ModelParameters disableContextShift() { + parameters.put("--no-context-shift", null); + return this; + } + + /** + * Enable Flash Attention (default: disabled). + */ + public ModelParameters enableFlashAttn() { + parameters.put("--flash-attn", null); + return this; + } + + /** + * Disable internal libllama performance timings (default: false). + */ + public ModelParameters disablePerf() { + parameters.put("--no-perf", null); + return this; + } + + /** + * Process escape sequences (default: true). + */ + public ModelParameters enableEscape() { + parameters.put("--escape", null); + return this; + } + + /** + * Do not process escape sequences (default: false). + */ + public ModelParameters disableEscape() { + parameters.put("--no-escape", null); + return this; + } + + /** + * Enable special tokens output (default: true). + */ + public ModelParameters enableSpecial() { + parameters.put("--special", null); + return this; + } + + /** + * Skip warming up the model with an empty run (default: false). + */ + public ModelParameters skipWarmup() { + parameters.put("--no-warmup", null); + return this; + } + + /** + * Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. + * (default: disabled) + */ + public ModelParameters setSpmInfill() { + parameters.put("--spm-infill", null); + return this; + } + + /** + * Set samplers that will be used for generation in the order, separated by ';' (default: all). + */ + public ModelParameters setSamplers(Sampler... samplers) { + if (samplers.length > 0) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < samplers.length; i++) { + Sampler sampler = samplers[i]; + builder.append(sampler.name().toLowerCase()); + if (i < samplers.length - 1) { + builder.append(";"); + } + } + parameters.put("--samplers", builder.toString()); + } + return this; + } + + /** + * Set RNG seed (default: -1, use random seed). + */ + public ModelParameters setSeed(long seed) { + parameters.put("--seed", String.valueOf(seed)); + return this; + } + + /** + * Ignore end of stream token and continue generating (implies --logit-bias EOS-inf). + */ + public ModelParameters ignoreEos() { + parameters.put("--ignore-eos", null); + return this; + } + + /** + * Set temperature for sampling (default: 0.8). + */ + public ModelParameters setTemp(float temp) { + parameters.put("--temp", String.valueOf(temp)); + return this; + } + + /** + * Set top-k sampling (default: 40, 0 = disabled). + */ + public ModelParameters setTopK(int topK) { + parameters.put("--top-k", String.valueOf(topK)); + return this; + } + + /** + * Set top-p sampling (default: 0.95, 1.0 = disabled). + */ + public ModelParameters setTopP(float topP) { + parameters.put("--top-p", String.valueOf(topP)); + return this; + } + + /** + * Set min-p sampling (default: 0.05, 0.0 = disabled). + */ + public ModelParameters setMinP(float minP) { + parameters.put("--min-p", String.valueOf(minP)); + return this; + } + + /** + * Set xtc probability (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setXtcProbability(float xtcProbability) { + parameters.put("--xtc-probability", String.valueOf(xtcProbability)); + return this; + } + + /** + * Set xtc threshold (default: 0.1, 1.0 = disabled). + */ + public ModelParameters setXtcThreshold(float xtcThreshold) { + parameters.put("--xtc-threshold", String.valueOf(xtcThreshold)); + return this; + } + + /** + * Set locally typical sampling parameter p (default: 1.0, 1.0 = disabled). + */ + public ModelParameters setTypical(float typP) { + parameters.put("--typical", String.valueOf(typP)); + return this; + } + + /** + * Set last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size). + */ + public ModelParameters setRepeatLastN(int repeatLastN) { + if (repeatLastN < -1) { + throw new RuntimeException("Invalid repeat-last-n value"); + } + parameters.put("--repeat-last-n", String.valueOf(repeatLastN)); + return this; + } + + /** + * Set penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled). + */ + public ModelParameters setRepeatPenalty(float repeatPenalty) { + parameters.put("--repeat-penalty", String.valueOf(repeatPenalty)); + return this; + } + + /** + * Set repeat alpha presence penalty (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setPresencePenalty(float presencePenalty) { + parameters.put("--presence-penalty", String.valueOf(presencePenalty)); + return this; + } + + /** + * Set repeat alpha frequency penalty (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setFrequencyPenalty(float frequencyPenalty) { + parameters.put("--frequency-penalty", String.valueOf(frequencyPenalty)); + return this; + } + + /** + * Set DRY sampling multiplier (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setDryMultiplier(float dryMultiplier) { + parameters.put("--dry-multiplier", String.valueOf(dryMultiplier)); + return this; + } + + /** + * Set DRY sampling base value (default: 1.75). + */ + public ModelParameters setDryBase(float dryBase) { + parameters.put("--dry-base", String.valueOf(dryBase)); + return this; + } + + /** + * Set allowed length for DRY sampling (default: 2). + */ + public ModelParameters setDryAllowedLength(int dryAllowedLength) { + parameters.put("--dry-allowed-length", String.valueOf(dryAllowedLength)); + return this; + } + + /** + * Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size). + */ + public ModelParameters setDryPenaltyLastN(int dryPenaltyLastN) { + if (dryPenaltyLastN < -1) { + throw new RuntimeException("Invalid dry-penalty-last-n value"); + } + parameters.put("--dry-penalty-last-n", String.valueOf(dryPenaltyLastN)); + return this; + } + + /** + * Add sequence breaker for DRY sampling, clearing out default breakers (default: none). + */ + public ModelParameters setDrySequenceBreaker(String drySequenceBreaker) { + parameters.put("--dry-sequence-breaker", drySequenceBreaker); + return this; + } + + /** + * Set dynamic temperature range (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setDynatempRange(float dynatempRange) { + parameters.put("--dynatemp-range", String.valueOf(dynatempRange)); + return this; + } + + /** + * Set dynamic temperature exponent (default: 1.0). + */ + public ModelParameters setDynatempExponent(float dynatempExponent) { + parameters.put("--dynatemp-exp", String.valueOf(dynatempExponent)); + return this; + } + + /** + * Use Mirostat sampling (default: PLACEHOLDER, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). + */ + public ModelParameters setMirostat(MiroStat mirostat) { + parameters.put("--mirostat", String.valueOf(mirostat.ordinal())); + return this; + } + + /** + * Set Mirostat learning rate, parameter eta (default: 0.1). + */ + public ModelParameters setMirostatLR(float mirostatLR) { + parameters.put("--mirostat-lr", String.valueOf(mirostatLR)); + return this; + } + + /** + * Set Mirostat target entropy, parameter tau (default: 5.0). + */ + public ModelParameters setMirostatEnt(float mirostatEnt) { + parameters.put("--mirostat-ent", String.valueOf(mirostatEnt)); + return this; + } + + /** + * Modify the likelihood of token appearing in the completion. + */ + public ModelParameters setLogitBias(String tokenIdAndBias) { + parameters.put("--logit-bias", tokenIdAndBias); + return this; + } + + /** + * Set BNF-like grammar to constrain generations (default: empty). + */ + public ModelParameters setGrammar(String grammar) { + parameters.put("--grammar", grammar); + return this; + } + + /** + * Specify the file to read grammar from. + */ + public ModelParameters setGrammarFile(String fileName) { + parameters.put("--grammar-file", fileName); + return this; + } + + /** + * Specify the JSON schema to constrain generations (default: empty). + */ + public ModelParameters setJsonSchema(String schema) { + parameters.put("--json-schema", schema); + return this; + } + + /** + * Set pooling type for embeddings (default: model default if unspecified). + */ + public ModelParameters setPoolingType(PoolingType type) { + parameters.put("--pooling", type.getArgValue()); + return this; + } + + /** + * Set RoPE frequency scaling method (default: linear unless specified by the model). + */ + public ModelParameters setRopeScaling(RopeScalingType type) { + parameters.put("--rope-scaling", type.getArgValue()); + return this; + } + + /** + * Set RoPE context scaling factor, expands context by a factor of N. + */ + public ModelParameters setRopeScale(float ropeScale) { + parameters.put("--rope-scale", String.valueOf(ropeScale)); + return this; + } + + /** + * Set RoPE base frequency, used by NTK-aware scaling (default: loaded from model). + */ + public ModelParameters setRopeFreqBase(float ropeFreqBase) { + parameters.put("--rope-freq-base", String.valueOf(ropeFreqBase)); + return this; + } + + /** + * Set RoPE frequency scaling factor, expands context by a factor of 1/N. + */ + public ModelParameters setRopeFreqScale(float ropeFreqScale) { + parameters.put("--rope-freq-scale", String.valueOf(ropeFreqScale)); + return this; + } + + /** + * Set YaRN: original context size of model (default: model training context size). + */ + public ModelParameters setYarnOrigCtx(int yarnOrigCtx) { + parameters.put("--yarn-orig-ctx", String.valueOf(yarnOrigCtx)); + return this; + } + + /** + * Set YaRN: extrapolation mix factor (default: 0.0 = full interpolation). + */ + public ModelParameters setYarnExtFactor(float yarnExtFactor) { + parameters.put("--yarn-ext-factor", String.valueOf(yarnExtFactor)); + return this; + } + + /** + * Set YaRN: scale sqrt(t) or attention magnitude (default: 1.0). + */ + public ModelParameters setYarnAttnFactor(float yarnAttnFactor) { + parameters.put("--yarn-attn-factor", String.valueOf(yarnAttnFactor)); + return this; + } + + /** + * Set YaRN: high correction dim or alpha (default: 1.0). + */ + public ModelParameters setYarnBetaSlow(float yarnBetaSlow) { + parameters.put("--yarn-beta-slow", String.valueOf(yarnBetaSlow)); + return this; + } + + /** + * Set YaRN: low correction dim or beta (default: 32.0). + */ + public ModelParameters setYarnBetaFast(float yarnBetaFast) { + parameters.put("--yarn-beta-fast", String.valueOf(yarnBetaFast)); + return this; + } + + /** + * Set group-attention factor (default: 1). + */ + public ModelParameters setGrpAttnN(int grpAttnN) { + parameters.put("--grp-attn-n", String.valueOf(grpAttnN)); + return this; + } + + /** + * Set group-attention width (default: 512). + */ + public ModelParameters setGrpAttnW(int grpAttnW) { + parameters.put("--grp-attn-w", String.valueOf(grpAttnW)); + return this; + } + + /** + * Enable verbose printing of the KV cache. + */ + public ModelParameters enableDumpKvCache() { + parameters.put("--dump-kv-cache", null); + return this; + } + + /** + * Disable KV offload. + */ + public ModelParameters disableKvOffload() { + parameters.put("--no-kv-offload", null); + return this; + } + + /** + * Set KV cache data type for K (allowed values: F16). + */ + public ModelParameters setCacheTypeK(CacheType type) { + parameters.put("--cache-type-k", type.name().toLowerCase()); + return this; + } + + /** + * Set KV cache data type for V (allowed values: F16). + */ + public ModelParameters setCacheTypeV(CacheType type) { + parameters.put("--cache-type-v", type.name().toLowerCase()); + return this; + } + + /** + * Set KV cache defragmentation threshold (default: 0.1, < 0 - disabled). + */ + public ModelParameters setDefragThold(float defragThold) { + parameters.put("--defrag-thold", String.valueOf(defragThold)); + return this; + } + + /** + * Set the number of parallel sequences to decode (default: 1). + */ + public ModelParameters setParallel(int nParallel) { + parameters.put("--parallel", String.valueOf(nParallel)); + return this; + } + + /** + * Enable continuous batching (a.k.a dynamic batching) (default: disabled). + */ + public ModelParameters enableContBatching() { + parameters.put("--cont-batching", null); + return this; + } + + /** + * Disable continuous batching. + */ + public ModelParameters disableContBatching() { + parameters.put("--no-cont-batching", null); + return this; + } + + /** + * Force system to keep model in RAM rather than swapping or compressing. + */ + public ModelParameters enableMlock() { + parameters.put("--mlock", null); + return this; + } + + /** + * Do not memory-map model (slower load but may reduce pageouts if not using mlock). + */ + public ModelParameters disableMmap() { + parameters.put("--no-mmap", null); + return this; + } + + /** + * Set NUMA optimization type for system. + */ + public ModelParameters setNuma(NumaStrategy numaStrategy) { + parameters.put("--numa", numaStrategy.name().toLowerCase()); + return this; + } + + /** + * Set comma-separated list of devices to use for offloading <dev1,dev2,..> (none = don't offload). + */ + public ModelParameters setDevices(String devices) { + parameters.put("--device", devices); + return this; + } + + /** + * Set the number of layers to store in VRAM. + */ + public ModelParameters setGpuLayers(int gpuLayers) { + parameters.put("--gpu-layers", String.valueOf(gpuLayers)); + return this; + } + + /** + * Set how to split the model across multiple GPUs (none, layer, row). + */ + public ModelParameters setSplitMode(GpuSplitMode splitMode) { + parameters.put("--split-mode", splitMode.name().toLowerCase()); + return this; + } + + /** + * Set fraction of the model to offload to each GPU, comma-separated list of proportions N0,N1,N2,.... + */ + public ModelParameters setTensorSplit(String tensorSplit) { + parameters.put("--tensor-split", tensorSplit); + return this; + } + + /** + * Set the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row). + */ + public ModelParameters setMainGpu(int mainGpu) { + parameters.put("--main-gpu", String.valueOf(mainGpu)); + return this; + } + + /** + * Enable checking model tensor data for invalid values. + */ + public ModelParameters enableCheckTensors() { + parameters.put("--check-tensors", null); + return this; + } + + /** + * Override model metadata by key. This option can be specified multiple times. + */ + public ModelParameters setOverrideKv(String keyValue) { + parameters.put("--override-kv", keyValue); + return this; + } + + /** + * Add a LoRA adapter (can be repeated to use multiple adapters). + */ + public ModelParameters addLoraAdapter(String fname) { + parameters.put("--lora", fname); + return this; + } + + /** + * Add a LoRA adapter with user-defined scaling (can be repeated to use multiple adapters). + */ + public ModelParameters addLoraScaledAdapter(String fname, float scale) { + parameters.put("--lora-scaled", fname + "," + scale); + return this; + } + + /** + * Add a control vector (this argument can be repeated to add multiple control vectors). + */ + public ModelParameters addControlVector(String fname) { + parameters.put("--control-vector", fname); + return this; + } + + /** + * Add a control vector with user-defined scaling (can be repeated to add multiple scaled control vectors). + */ + public ModelParameters addControlVectorScaled(String fname, float scale) { + parameters.put("--control-vector-scaled", fname + "," + scale); + return this; + } + + /** + * Set the layer range to apply the control vector(s) to (start and end inclusive). + */ + public ModelParameters setControlVectorLayerRange(int start, int end) { + parameters.put("--control-vector-layer-range", start + "," + end); + return this; + } + + /** + * Set the model path from which to load the base model. + */ + public ModelParameters setModel(String model) { + parameters.put("--model", model); + return this; + } + + /** + * Set the model download URL (default: unused). + */ + public ModelParameters setModelUrl(String modelUrl) { + parameters.put("--model-url", modelUrl); + return this; + } + + /** + * Set the Hugging Face model repository (default: unused). + */ + public ModelParameters setHfRepo(String hfRepo) { + parameters.put("--hf-repo", hfRepo); + return this; + } + + /** + * Set the Hugging Face model file (default: unused). + */ + public ModelParameters setHfFile(String hfFile) { + parameters.put("--hf-file", hfFile); + return this; + } + + /** + * Set the Hugging Face model repository for the vocoder model (default: unused). + */ + public ModelParameters setHfRepoV(String hfRepoV) { + parameters.put("--hf-repo-v", hfRepoV); + return this; + } + + /** + * Set the Hugging Face model file for the vocoder model (default: unused). + */ + public ModelParameters setHfFileV(String hfFileV) { + parameters.put("--hf-file-v", hfFileV); + return this; + } + + /** + * Set the Hugging Face access token (default: value from HF_TOKEN environment variable). + */ + public ModelParameters setHfToken(String hfToken) { + parameters.put("--hf-token", hfToken); + return this; + } + + /** + * Enable embedding use case; use only with dedicated embedding models. + */ + public ModelParameters enableEmbedding() { + parameters.put("--embedding", null); + return this; + } + + /** + * Enable reranking endpoint on server. + */ + public ModelParameters enableReranking() { + parameters.put("--reranking", null); + return this; + } + + /** + * Set minimum chunk size to attempt reusing from the cache via KV shifting. + */ + public ModelParameters setCacheReuse(int cacheReuse) { + parameters.put("--cache-reuse", String.valueOf(cacheReuse)); + return this; + } + + /** + * Set the path to save the slot kv cache. + */ + public ModelParameters setSlotSavePath(String slotSavePath) { + parameters.put("--slot-save-path", slotSavePath); + return this; + } + + /** + * Set custom jinja chat template. + */ + public ModelParameters setChatTemplate(String chatTemplate) { + parameters.put("--chat-template", chatTemplate); + return this; + } + + /** + * Set how much the prompt of a request must match the prompt of a slot in order to use that slot. + */ + public ModelParameters setSlotPromptSimilarity(float similarity) { + parameters.put("--slot-prompt-similarity", String.valueOf(similarity)); + return this; + } + + /** + * Load LoRA adapters without applying them (apply later via POST /lora-adapters). + */ + public ModelParameters setLoraInitWithoutApply() { + parameters.put("--lora-init-without-apply", null); + return this; + } + + /** + * Disable logging. + */ + public ModelParameters disableLog() { + parameters.put("--log-disable", null); + return this; + } + + /** + * Set the log file path. + */ + public ModelParameters setLogFile(String logFile) { + parameters.put("--log-file", logFile); + return this; + } + + /** + * Set verbosity level to infinity (log all messages, useful for debugging). + */ + public ModelParameters setVerbose() { + parameters.put("--verbose", null); + return this; + } + + /** + * Set the verbosity threshold (messages with a higher verbosity will be ignored). + */ + public ModelParameters setLogVerbosity(int verbosity) { + parameters.put("--log-verbosity", String.valueOf(verbosity)); + return this; + } + + /** + * Enable prefix in log messages. + */ + public ModelParameters enableLogPrefix() { + parameters.put("--log-prefix", null); + return this; + } + + /** + * Enable timestamps in log messages. + */ + public ModelParameters enableLogTimestamps() { + parameters.put("--log-timestamps", null); + return this; + } + + /** + * Set the number of tokens to draft for speculative decoding. + */ + public ModelParameters setDraftMax(int draftMax) { + parameters.put("--draft-max", String.valueOf(draftMax)); + return this; + } + + /** + * Set the minimum number of draft tokens to use for speculative decoding. + */ + public ModelParameters setDraftMin(int draftMin) { + parameters.put("--draft-min", String.valueOf(draftMin)); + return this; + } + + /** + * Set the minimum speculative decoding probability for greedy decoding. + */ + public ModelParameters setDraftPMin(float draftPMin) { + parameters.put("--draft-p-min", String.valueOf(draftPMin)); + return this; + } + + /** + * Set the size of the prompt context for the draft model. + */ + public ModelParameters setCtxSizeDraft(int ctxSizeDraft) { + parameters.put("--ctx-size-draft", String.valueOf(ctxSizeDraft)); + return this; + } + + /** + * Set the comma-separated list of devices to use for offloading the draft model. + */ + public ModelParameters setDeviceDraft(String deviceDraft) { + parameters.put("--device-draft", deviceDraft); + return this; + } + + /** + * Set the number of layers to store in VRAM for the draft model. + */ + public ModelParameters setGpuLayersDraft(int gpuLayersDraft) { + parameters.put("--gpu-layers-draft", String.valueOf(gpuLayersDraft)); + return this; + } + + /** + * Set the draft model for speculative decoding. + */ + public ModelParameters setModelDraft(String modelDraft) { + parameters.put("--model-draft", modelDraft); + return this; + } + + /** + * Enable jinja for templating + */ + public ModelParameters enableJinja() { + parameters.put("--jinja", null); + return this; + } + } + + diff --git a/src/main/java/de/kherud/llama/OSInfo.java b/src/main/java/de/kherud/llama/OSInfo.java index 740bdca5..9354ec2f 100644 --- a/src/main/java/de/kherud/llama/OSInfo.java +++ b/src/main/java/de/kherud/llama/OSInfo.java @@ -31,234 +31,256 @@ */ @SuppressWarnings("UseOfSystemOutOrSystemErr") class OSInfo { - private static final ProcessRunner processRunner = new ProcessRunner(); - private static final HashMap archMapping = new HashMap<>(); + public static final String X86 = "x86"; + public static final String X64 = "x64"; + public static final String X86_64 = "x86_64"; + public static final String IA64_32 = "ia64_32"; + public static final String IA64 = "ia64"; + public static final String PPC = "ppc"; + public static final String PPC64 = "ppc64"; + private static final ProcessRunner processRunner = new ProcessRunner(); + private static final HashMap archMapping = new HashMap<>(); - public static final String X86 = "x86"; - public static final String X86_64 = "x86_64"; - public static final String IA64_32 = "ia64_32"; - public static final String IA64 = "ia64"; - public static final String PPC = "ppc"; - public static final String PPC64 = "ppc64"; + static { + // x86 mappings + archMapping.put(X86, X86); + archMapping.put("i386", X86); + archMapping.put("i486", X86); + archMapping.put("i586", X86); + archMapping.put("i686", X86); + archMapping.put("pentium", X86); - static { - // x86 mappings - archMapping.put(X86, X86); - archMapping.put("i386", X86); - archMapping.put("i486", X86); - archMapping.put("i586", X86); - archMapping.put("i686", X86); - archMapping.put("pentium", X86); + // x86_64 mappings + archMapping.put(X86_64, X86_64); + archMapping.put("amd64", X86_64); + archMapping.put("em64t", X86_64); + archMapping.put("universal", X86_64); // Needed for openjdk7 in Mac - // x86_64 mappings - archMapping.put(X86_64, X86_64); - archMapping.put("amd64", X86_64); - archMapping.put("em64t", X86_64); - archMapping.put("universal", X86_64); // Needed for openjdk7 in Mac + // Itanium 64-bit mappings + archMapping.put(IA64, IA64); + archMapping.put("ia64w", IA64); - // Itanium 64-bit mappings - archMapping.put(IA64, IA64); - archMapping.put("ia64w", IA64); + // Itanium 32-bit mappings, usually an HP-UX construct + archMapping.put(IA64_32, IA64_32); + archMapping.put("ia64n", IA64_32); - // Itanium 32-bit mappings, usually an HP-UX construct - archMapping.put(IA64_32, IA64_32); - archMapping.put("ia64n", IA64_32); + // PowerPC mappings + archMapping.put(PPC, PPC); + archMapping.put("power", PPC); + archMapping.put("powerpc", PPC); + archMapping.put("power_pc", PPC); + archMapping.put("power_rs", PPC); - // PowerPC mappings - archMapping.put(PPC, PPC); - archMapping.put("power", PPC); - archMapping.put("powerpc", PPC); - archMapping.put("power_pc", PPC); - archMapping.put("power_rs", PPC); + // TODO: PowerPC 64bit mappings + archMapping.put(PPC64, PPC64); + archMapping.put("power64", PPC64); + archMapping.put("powerpc64", PPC64); + archMapping.put("power_pc64", PPC64); + archMapping.put("power_rs64", PPC64); + archMapping.put("ppc64el", PPC64); + archMapping.put("ppc64le", PPC64); + + // TODO: Adding X64 support + archMapping.put(X64, X64); + } - // TODO: PowerPC 64bit mappings - archMapping.put(PPC64, PPC64); - archMapping.put("power64", PPC64); - archMapping.put("powerpc64", PPC64); - archMapping.put("power_pc64", PPC64); - archMapping.put("power_rs64", PPC64); - archMapping.put("ppc64el", PPC64); - archMapping.put("ppc64le", PPC64); - } + public static void main(String[] args) { + if (args.length >= 1) { + if ("--os".equals(args[0])) { + System.out.print(getOSName()); + return; + } + else if ("--arch".equals(args[0])) { + System.out.print(getArchName()); + return; + } + } - public static void main(String[] args) { - if (args.length >= 1) { - if ("--os".equals(args[0])) { - System.out.print(getOSName()); - return; - } else if ("--arch".equals(args[0])) { - System.out.print(getArchName()); - return; - } - } + System.out.print(getNativeLibFolderPathForCurrentOS()); + } - System.out.print(getNativeLibFolderPathForCurrentOS()); - } + static String getNativeLibFolderPathForCurrentOS() { + return getOSName() + "/" + getArchName(); + } - static String getNativeLibFolderPathForCurrentOS() { - return getOSName() + "/" + getArchName(); - } + static String getOSName() { + return translateOSNameToFolderName(System.getProperty("os.name")); + } - static String getOSName() { - return translateOSNameToFolderName(System.getProperty("os.name")); - } + static boolean isAndroid() { + return isAndroidRuntime() || isAndroidTermux(); + } - static boolean isAndroid() { - return isAndroidRuntime() || isAndroidTermux(); - } + static boolean isAndroidRuntime() { + return System.getProperty("java.runtime.name", "").toLowerCase().contains("android"); + } - static boolean isAndroidRuntime() { - return System.getProperty("java.runtime.name", "").toLowerCase().contains("android"); - } + static boolean isAndroidTermux() { + try { + return processRunner.runAndWaitFor("uname -o").toLowerCase().contains("android"); + } + catch (Exception ignored) { + return false; + } + } - static boolean isAndroidTermux() { - try { - return processRunner.runAndWaitFor("uname -o").toLowerCase().contains("android"); - } catch (Exception ignored) { - return false; - } - } + static boolean isMusl() { + Path mapFilesDir = Paths.get("/proc/self/map_files"); + try (Stream dirStream = Files.list(mapFilesDir)) { + return dirStream + .map( + path -> { + try { + return path.toRealPath().toString(); + } + catch (IOException e) { + return ""; + } + }) + .anyMatch(s -> s.toLowerCase().contains("musl")); + } + catch (Exception ignored) { + // fall back to checking for alpine linux in the event we're using an older kernel which + // may not fail the above check + return isAlpineLinux(); + } + } - static boolean isMusl() { - Path mapFilesDir = Paths.get("/proc/self/map_files"); - try (Stream dirStream = Files.list(mapFilesDir)) { - return dirStream - .map( - path -> { - try { - return path.toRealPath().toString(); - } catch (IOException e) { - return ""; - } - }) - .anyMatch(s -> s.toLowerCase().contains("musl")); - } catch (Exception ignored) { - // fall back to checking for alpine linux in the event we're using an older kernel which - // may not fail the above check - return isAlpineLinux(); - } - } + static boolean isAlpineLinux() { + try (Stream osLines = Files.lines(Paths.get("/etc/os-release"))) { + return osLines.anyMatch(l -> l.startsWith("ID") && l.contains("alpine")); + } + catch (Exception ignored2) { + } + return false; + } - static boolean isAlpineLinux() { - try (Stream osLines = Files.lines(Paths.get("/etc/os-release"))) { - return osLines.anyMatch(l -> l.startsWith("ID") && l.contains("alpine")); - } catch (Exception ignored2) { - } - return false; - } + static String getHardwareName() { + try { + return processRunner.runAndWaitFor("uname -m"); + } + catch (Throwable e) { + System.err.println("Error while running uname -m: " + e.getMessage()); + return "unknown"; + } + } - static String getHardwareName() { - try { - return processRunner.runAndWaitFor("uname -m"); - } catch (Throwable e) { - System.err.println("Error while running uname -m: " + e.getMessage()); - return "unknown"; - } - } + static String resolveArmArchType() { + if (System.getProperty("os.name").contains("Linux")) { + String armType = getHardwareName(); + // armType (uname -m) can be armv5t, armv5te, armv5tej, armv5tejl, armv6, armv7, armv7l, + // aarch64, i686 - static String resolveArmArchType() { - if (System.getProperty("os.name").contains("Linux")) { - String armType = getHardwareName(); - // armType (uname -m) can be armv5t, armv5te, armv5tej, armv5tejl, armv6, armv7, armv7l, - // aarch64, i686 + // for Android, we fold everything that is not aarch64 into arm + if (isAndroid()) { + if (armType.startsWith("aarch64")) { + // Use arm64 + return "aarch64"; + } + else { + return "arm"; + } + } - // for Android, we fold everything that is not aarch64 into arm - if (isAndroid()) { - if (armType.startsWith("aarch64")) { - // Use arm64 - return "aarch64"; - } else { - return "arm"; - } - } + if (armType.startsWith("armv6")) { + // Raspberry PI + return "armv6"; + } + else if (armType.startsWith("armv7")) { + // Generic + return "armv7"; + } + else if (armType.startsWith("armv5")) { + // Use armv5, soft-float ABI + return "arm"; + } + else if (armType.startsWith("aarch64")) { + // Use arm64 + return "aarch64"; + } - if (armType.startsWith("armv6")) { - // Raspberry PI - return "armv6"; - } else if (armType.startsWith("armv7")) { - // Generic - return "armv7"; - } else if (armType.startsWith("armv5")) { - // Use armv5, soft-float ABI - return "arm"; - } else if (armType.startsWith("aarch64")) { - // Use arm64 - return "aarch64"; - } + // Java 1.8 introduces a system property to determine armel or armhf + // https://bugs.openjdk.org/browse/JDK-8005545 + String abi = System.getProperty("sun.arch.abi"); + if (abi != null && abi.startsWith("gnueabihf")) { + return "armv7"; + } - // Java 1.8 introduces a system property to determine armel or armhf - // http://bugs.java.com/bugdatabase/view_bug.do?bug_id=8005545 - String abi = System.getProperty("sun.arch.abi"); - if (abi != null && abi.startsWith("gnueabihf")) { - return "armv7"; - } + // For java7, we still need to run some shell commands to determine ABI of JVM + String javaHome = System.getProperty("java.home"); + try { + // determine if first JVM found uses ARM hard-float ABI + int exitCode = Runtime.getRuntime().exec("which readelf").waitFor(); + if (exitCode == 0) { + String[] cmdarray = { + "/bin/sh", + "-c", + "find '" + + javaHome + + "' -name 'libjvm.so' | head -1 | xargs readelf -A | " + + "grep 'Tag_ABI_VFP_args: VFP registers'" + }; + exitCode = Runtime.getRuntime().exec(cmdarray).waitFor(); + if (exitCode == 0) { + return "armv7"; + } + } + else { + System.err.println( + "WARNING! readelf not found. Cannot check if running on an armhf system, armel architecture will be presumed."); + } + } + catch (IOException | InterruptedException e) { + // ignored: fall back to "arm" arch (soft-float ABI) + } + } + // Use armv5, soft-float ABI + return "arm"; + } - // For java7, we still need to run some shell commands to determine ABI of JVM - String javaHome = System.getProperty("java.home"); - try { - // determine if first JVM found uses ARM hard-float ABI - int exitCode = Runtime.getRuntime().exec("which readelf").waitFor(); - if (exitCode == 0) { - String[] cmdarray = { - "/bin/sh", - "-c", - "find '" - + javaHome - + "' -name 'libjvm.so' | head -1 | xargs readelf -A | " - + "grep 'Tag_ABI_VFP_args: VFP registers'" - }; - exitCode = Runtime.getRuntime().exec(cmdarray).waitFor(); - if (exitCode == 0) { - return "armv7"; - } - } else { - System.err.println( - "WARNING! readelf not found. Cannot check if running on an armhf system, armel architecture will be presumed."); - } - } catch (IOException | InterruptedException e) { - // ignored: fall back to "arm" arch (soft-float ABI) - } - } - // Use armv5, soft-float ABI - return "arm"; - } + static String getArchName() { + String override = System.getProperty("de.kherud.llama.osinfo.architecture"); + if (override != null) { + return override; + } - static String getArchName() { - String override = System.getProperty("de.kherud.llama.osinfo.architecture"); - if (override != null) { - return override; - } + String osArch = System.getProperty("os.arch"); - String osArch = System.getProperty("os.arch"); + if (osArch.startsWith("arm")) { + osArch = resolveArmArchType(); + } + else { + String lc = osArch.toLowerCase(Locale.US); + if (archMapping.containsKey(lc)) return archMapping.get(lc); + } + return translateArchNameToFolderName(osArch); + } - if (osArch.startsWith("arm")) { - osArch = resolveArmArchType(); - } else { - String lc = osArch.toLowerCase(Locale.US); - if (archMapping.containsKey(lc)) return archMapping.get(lc); - } - return translateArchNameToFolderName(osArch); - } + static String translateOSNameToFolderName(String osName) { + if (osName.contains("Windows")) { + return "Windows"; + } + else if (osName.contains("Mac") || osName.contains("Darwin")) { + return "Mac"; + } + else if (osName.contains("AIX")) { + return "AIX"; + } + else if (isMusl()) { + return "Linux-Musl"; + } + else if (isAndroid()) { + return "Linux-Android"; + } + else if (osName.contains("Linux")) { + return "Linux"; + } + else { + return osName.replaceAll("\\W", ""); + } + } - static String translateOSNameToFolderName(String osName) { - if (osName.contains("Windows")) { - return "Windows"; - } else if (osName.contains("Mac") || osName.contains("Darwin")) { - return "Mac"; - } else if (osName.contains("AIX")) { - return "AIX"; - } else if (isMusl()) { - return "Linux-Musl"; - } else if (isAndroid()) { - return "Linux-Android"; - } else if (osName.contains("Linux")) { - return "Linux"; - } else { - return osName.replaceAll("\\W", ""); - } - } - - static String translateArchNameToFolderName(String archName) { - return archName.replaceAll("\\W", ""); - } + static String translateArchNameToFolderName(String archName) { + return archName.replaceAll("\\W", ""); + } } diff --git a/src/main/java/de/kherud/llama/Pair.java b/src/main/java/de/kherud/llama/Pair.java new file mode 100644 index 00000000..48ac648b --- /dev/null +++ b/src/main/java/de/kherud/llama/Pair.java @@ -0,0 +1,48 @@ +package de.kherud.llama; + +import java.util.Objects; + +public class Pair { + + private final K key; + private final V value; + + public Pair(K key, V value) { + this.key = key; + this.value = value; + } + + public K getKey() { + return key; + } + + public V getValue() { + return value; + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Pair other = (Pair) obj; + return Objects.equals(key, other.key) && Objects.equals(value, other.value); + } + + @Override + public String toString() { + return "Pair [key=" + key + ", value=" + value + "]"; + } + + + + +} diff --git a/src/main/java/de/kherud/llama/ProcessRunner.java b/src/main/java/de/kherud/llama/ProcessRunner.java index 6a1fd8dd..24e63498 100644 --- a/src/main/java/de/kherud/llama/ProcessRunner.java +++ b/src/main/java/de/kherud/llama/ProcessRunner.java @@ -21,7 +21,7 @@ String runAndWaitFor(String command, long timeout, TimeUnit unit) return getProcessOutput(p); } - static String getProcessOutput(Process process) throws IOException { + private static String getProcessOutput(Process process) throws IOException { try (InputStream in = process.getInputStream()) { int readLen; ByteArrayOutputStream b = new ByteArrayOutputStream(); diff --git a/src/main/java/de/kherud/llama/args/CacheType.java b/src/main/java/de/kherud/llama/args/CacheType.java new file mode 100644 index 00000000..8404ed75 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/CacheType.java @@ -0,0 +1,15 @@ +package de.kherud.llama.args; + +public enum CacheType { + + F32, + F16, + BF16, + Q8_0, + Q4_0, + Q4_1, + IQ4_NL, + Q5_0, + Q5_1 + +} diff --git a/src/main/java/de/kherud/llama/args/GpuSplitMode.java b/src/main/java/de/kherud/llama/args/GpuSplitMode.java new file mode 100644 index 00000000..0c0cd934 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/GpuSplitMode.java @@ -0,0 +1,8 @@ +package de.kherud.llama.args; + +public enum GpuSplitMode { + + NONE, + LAYER, + ROW +} diff --git a/src/main/java/de/kherud/llama/args/LogFormat.java b/src/main/java/de/kherud/llama/args/LogFormat.java new file mode 100644 index 00000000..8a5b46e8 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/LogFormat.java @@ -0,0 +1,11 @@ +package de.kherud.llama.args; + +/** + * The log output format (defaults to JSON for all server-based outputs). + */ +public enum LogFormat { + + JSON, + TEXT + +} diff --git a/src/main/java/de/kherud/llama/args/MiroStat.java b/src/main/java/de/kherud/llama/args/MiroStat.java new file mode 100644 index 00000000..5268d9bc --- /dev/null +++ b/src/main/java/de/kherud/llama/args/MiroStat.java @@ -0,0 +1,8 @@ +package de.kherud.llama.args; + +public enum MiroStat { + + DISABLED, + V1, + V2 +} diff --git a/src/main/java/de/kherud/llama/args/NumaStrategy.java b/src/main/java/de/kherud/llama/args/NumaStrategy.java new file mode 100644 index 00000000..fa7a61b0 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/NumaStrategy.java @@ -0,0 +1,8 @@ +package de.kherud.llama.args; + +public enum NumaStrategy { + + DISTRIBUTE, + ISOLATE, + NUMACTL +} diff --git a/src/main/java/de/kherud/llama/args/PoolingType.java b/src/main/java/de/kherud/llama/args/PoolingType.java new file mode 100644 index 00000000..c0379c85 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/PoolingType.java @@ -0,0 +1,21 @@ +package de.kherud.llama.args; + +public enum PoolingType { + + UNSPECIFIED("unspecified"), + NONE("none"), + MEAN("mean"), + CLS("cls"), + LAST("last"), + RANK("rank"); + + private final String argValue; + + PoolingType(String value) { + this.argValue = value; + } + + public String getArgValue() { + return argValue; + } +} \ No newline at end of file diff --git a/src/main/java/de/kherud/llama/args/RopeScalingType.java b/src/main/java/de/kherud/llama/args/RopeScalingType.java new file mode 100644 index 00000000..138d05be --- /dev/null +++ b/src/main/java/de/kherud/llama/args/RopeScalingType.java @@ -0,0 +1,21 @@ +package de.kherud.llama.args; + +public enum RopeScalingType { + + UNSPECIFIED("unspecified"), + NONE("none"), + LINEAR("linear"), + YARN2("yarn"), + LONGROPE("longrope"), + MAX_VALUE("maxvalue"); + + private final String argValue; + + RopeScalingType(String value) { + this.argValue = value; + } + + public String getArgValue() { + return argValue; + } +} \ No newline at end of file diff --git a/src/main/java/de/kherud/llama/args/Sampler.java b/src/main/java/de/kherud/llama/args/Sampler.java new file mode 100644 index 00000000..564a2e6f --- /dev/null +++ b/src/main/java/de/kherud/llama/args/Sampler.java @@ -0,0 +1,15 @@ +package de.kherud.llama.args; + +public enum Sampler { + + DRY, + TOP_K, + TOP_P, + TYP_P, + MIN_P, + TEMPERATURE, + XTC, + INFILL, + PENALTIES + +} diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java new file mode 100644 index 00000000..e3e69d8c --- /dev/null +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -0,0 +1,335 @@ +package de.kherud.llama; + +import java.io.*; +import java.util.*; +import java.util.regex.Pattern; + +import de.kherud.llama.args.LogFormat; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class LlamaModelTest { + + private static final String prefix = "def remove_non_ascii(s: str) -> str:\n \"\"\" "; + private static final String suffix = "\n return result\n"; + private static final int nPredict = 10; + + private static LlamaModel model; + + @BeforeClass + public static void setup() { +// LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg)); + model = new LlamaModel( + new ModelParameters() + .setCtxSize(128) + .setModel("models/codellama-7b.Q2_K.gguf") + //.setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") + .setGpuLayers(43) + .enableEmbedding().enableLogTimestamps().enableLogPrefix() + ); + } + + @AfterClass + public static void tearDown() { + if (model != null) { + model.close(); + } + } + + @Test + public void testGenerateAnswer() { + Map logitBias = new HashMap<>(); + logitBias.put(2, 2.0f); + InferenceParameters params = new InferenceParameters(prefix) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setTokenIdBias(logitBias); + + int generated = 0; + for (LlamaOutput ignored : model.generate(params)) { + generated++; + } + // todo: currently, after generating nPredict tokens, there is an additional empty output + Assert.assertTrue(generated > 0 && generated <= nPredict + 1); + } + + @Test + public void testGenerateInfill() { + Map logitBias = new HashMap<>(); + logitBias.put(2, 2.0f); + InferenceParameters params = new InferenceParameters("") + .setInputPrefix(prefix) + .setInputSuffix(suffix ) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setTokenIdBias(logitBias) + .setSeed(42); + + int generated = 0; + for (LlamaOutput ignored : model.generate(params)) { + generated++; + } + Assert.assertTrue(generated > 0 && generated <= nPredict + 1); + } + + @Test + public void testGenerateGrammar() { + InferenceParameters params = new InferenceParameters("") + .setGrammar("root ::= (\"a\" | \"b\")+") + .setNPredict(nPredict); + StringBuilder sb = new StringBuilder(); + for (LlamaOutput output : model.generate(params)) { + sb.append(output); + } + String output = sb.toString(); + + Assert.assertTrue(output.matches("[ab]+")); + int generated = model.encode(output).length; + Assert.assertTrue(generated > 0 && generated <= nPredict + 1); + } + + @Test + public void testCompleteAnswer() { + Map logitBias = new HashMap<>(); + logitBias.put(2, 2.0f); + InferenceParameters params = new InferenceParameters(prefix) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setTokenIdBias(logitBias) + .setSeed(42); + + String output = model.complete(params); + Assert.assertFalse(output.isEmpty()); + } + + @Test + public void testCompleteInfillCustom() { + Map logitBias = new HashMap<>(); + logitBias.put(2, 2.0f); + InferenceParameters params = new InferenceParameters("") + .setInputPrefix(prefix) + .setInputSuffix(suffix) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setTokenIdBias(logitBias) + .setSeed(42); + + String output = model.complete(params); + Assert.assertFalse(output.isEmpty()); + } + + @Test + public void testCompleteGrammar() { + InferenceParameters params = new InferenceParameters("") + .setGrammar("root ::= (\"a\" | \"b\")+") + .setNPredict(nPredict); + String output = model.complete(params); + Assert.assertTrue(output + " doesn't match [ab]+", output.matches("[ab]+")); + int generated = model.encode(output).length; + Assert.assertTrue("generated count is: " + generated, generated > 0 && generated <= nPredict + 1); + + } + + @Test + public void testCancelGenerating() { + InferenceParameters params = new InferenceParameters(prefix).setNPredict(nPredict); + + int generated = 0; + LlamaIterator iterator = model.generate(params).iterator(); + while (iterator.hasNext()) { + iterator.next(); + generated++; + if (generated == 5) { + iterator.cancel(); + } + } + Assert.assertEquals(5, generated); + } + + @Test + public void testEmbedding() { + float[] embedding = model.embed(prefix); + Assert.assertEquals(4096, embedding.length); + } + + + @Ignore + /** + * To run this test download the model from here https://huggingface.co/mradermacher/jina-reranker-v1-tiny-en-GGUF/tree/main + * remove .enableEmbedding() from model setup and add .enableReRanking() and then enable the test. + */ + public void testReRanking() { + + String query = "Machine learning is"; + String [] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." + }; + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3] ); + + System.out.println(llamaOutput); + } + + @Test + public void testTokenization() { + String prompt = "Hello, world!"; + int[] encoded = model.encode(prompt); + String decoded = model.decode(encoded); + // the llama tokenizer adds a space before the prompt + Assert.assertEquals(" " +prompt, decoded); + } + + @Ignore + public void testLogText() { + List messages = new ArrayList<>(); + LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> messages.add(new LogMessage(level, msg))); + + InferenceParameters params = new InferenceParameters(prefix) + .setNPredict(nPredict) + .setSeed(42); + model.complete(params); + + Assert.assertFalse(messages.isEmpty()); + + Pattern jsonPattern = Pattern.compile("^\\s*[\\[{].*[}\\]]\\s*$"); + for (LogMessage message : messages) { + Assert.assertNotNull(message.level); + Assert.assertFalse(jsonPattern.matcher(message.text).matches()); + } + } + + @Ignore + public void testLogJSON() { + List messages = new ArrayList<>(); + LlamaModel.setLogger(LogFormat.JSON, (level, msg) -> messages.add(new LogMessage(level, msg))); + + InferenceParameters params = new InferenceParameters(prefix) + .setNPredict(nPredict) + .setSeed(42); + model.complete(params); + + Assert.assertFalse(messages.isEmpty()); + + Pattern jsonPattern = Pattern.compile("^\\s*[\\[{].*[}\\]]\\s*$"); + for (LogMessage message : messages) { + Assert.assertNotNull(message.level); + Assert.assertTrue(jsonPattern.matcher(message.text).matches()); + } + } + + @Ignore + @Test + public void testLogStdout() { + // Unfortunately, `printf` can't be easily re-directed to Java. This test only works manually, thus. + InferenceParameters params = new InferenceParameters(prefix) + .setNPredict(nPredict) + .setSeed(42); + + System.out.println("########## Log Text ##########"); + LlamaModel.setLogger(LogFormat.TEXT, null); + model.complete(params); + + System.out.println("########## Log JSON ##########"); + LlamaModel.setLogger(LogFormat.JSON, null); + model.complete(params); + + System.out.println("########## Log None ##########"); + LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> {}); + model.complete(params); + + System.out.println("##############################"); + } + + private String completeAndReadStdOut() { + PrintStream stdOut = System.out; + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + @SuppressWarnings("ImplicitDefaultCharsetUsage") PrintStream printStream = new PrintStream(outputStream); + System.setOut(printStream); + + try { + InferenceParameters params = new InferenceParameters(prefix) + .setNPredict(nPredict) + .setSeed(42); + model.complete(params); + } finally { + System.out.flush(); + System.setOut(stdOut); + printStream.close(); + } + + return outputStream.toString(); + } + + private List splitLines(String text) { + List lines = new ArrayList<>(); + + Scanner scanner = new Scanner(text); + while (scanner.hasNextLine()) { + String line = scanner.nextLine(); + lines.add(line); + } + scanner.close(); + + return lines; + } + + private static final class LogMessage { + private final LogLevel level; + private final String text; + + private LogMessage(LogLevel level, String text) { + this.level = level; + this.text = text; + } + } + + @Test + public void testJsonSchemaToGrammar() { + String schema = "{\n" + + " \"properties\": {\n" + + " \"a\": {\"type\": \"string\"},\n" + + " \"b\": {\"type\": \"string\"},\n" + + " \"c\": {\"type\": \"string\"}\n" + + " },\n" + + " \"additionalProperties\": false\n" + + "}"; + + String expectedGrammar = "a-kv ::= \"\\\"a\\\"\" space \":\" space string\n" + + "a-rest ::= ( \",\" space b-kv )? b-rest\n" + + "b-kv ::= \"\\\"b\\\"\" space \":\" space string\n" + + "b-rest ::= ( \",\" space c-kv )?\n" + + "c-kv ::= \"\\\"c\\\"\" space \":\" space string\n" + + "char ::= [^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})\n" + + "root ::= \"{\" space (a-kv a-rest | b-kv b-rest | c-kv )? \"}\" space\n" + + "space ::= | \" \" | \"\\n\"{1,2} [ \\t]{0,20}\n" + + "string ::= \"\\\"\" char* \"\\\"\" space\n"; + + String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema); + Assert.assertEquals(expectedGrammar, actualGrammar); + } + + @Test + public void testTemplate() { + + List> userMessages = new ArrayList<>(); + userMessages.add(new Pair<>("user", "What is the best book?")); + userMessages.add(new Pair<>("assistant", "It depends on your interests. Do you like fiction or non-fiction?")); + + InferenceParameters params = new InferenceParameters("A book recommendation system.") + .setMessages("Book", userMessages) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setSeed(42); + Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?<|im_end|>\n<|im_start|>assistant\n"); + } +} diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java new file mode 100644 index 00000000..60d32bde --- /dev/null +++ b/src/test/java/de/kherud/llama/RerankingModelTest.java @@ -0,0 +1,83 @@ +package de.kherud.llama; + +import java.util.List; +import java.util.Map; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class RerankingModelTest { + + private static LlamaModel model; + + String query = "Machine learning is"; + String[] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." }; + + @BeforeClass + public static void setup() { + model = new LlamaModel( + new ModelParameters().setCtxSize(128).setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf") + .setGpuLayers(43).enableReranking().enableLogTimestamps().enableLogPrefix()); + } + + @AfterClass + public static void tearDown() { + if (model != null) { + model.close(); + } + } + + @Test + public void testReRanking() { + + + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], + TEST_DOCUMENTS[3]); + + Map rankedDocumentsMap = llamaOutput.probabilities; + Assert.assertTrue(rankedDocumentsMap.size()==TEST_DOCUMENTS.length); + + // Finding the most and least relevant documents + String mostRelevantDoc = null; + String leastRelevantDoc = null; + float maxScore = Float.MIN_VALUE; + float minScore = Float.MAX_VALUE; + + for (Map.Entry entry : rankedDocumentsMap.entrySet()) { + if (entry.getValue() > maxScore) { + maxScore = entry.getValue(); + mostRelevantDoc = entry.getKey(); + } + if (entry.getValue() < minScore) { + minScore = entry.getValue(); + leastRelevantDoc = entry.getKey(); + } + } + + // Assertions + Assert.assertTrue(maxScore > minScore); + Assert.assertEquals("Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", mostRelevantDoc); + Assert.assertEquals("Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.", leastRelevantDoc); + + + } + + @Test + public void testSortedReRanking() { + List> rankedDocuments = model.rerank(true, query, TEST_DOCUMENTS); + Assert.assertEquals(rankedDocuments.size(), TEST_DOCUMENTS.length); + + // Check the ranking order: each score should be >= the next one + for (int i = 0; i < rankedDocuments.size() - 1; i++) { + float currentScore = rankedDocuments.get(i).getValue(); + float nextScore = rankedDocuments.get(i + 1).getValue(); + Assert.assertTrue("Ranking order incorrect at index " + i, currentScore >= nextScore); + } + } +} diff --git a/src/test/java/examples/GrammarExample.java b/src/test/java/examples/GrammarExample.java index 212a052b..d90de206 100644 --- a/src/test/java/examples/GrammarExample.java +++ b/src/test/java/examples/GrammarExample.java @@ -1,5 +1,8 @@ package examples; +import de.kherud.llama.LlamaOutput; +import de.kherud.llama.ModelParameters; + import de.kherud.llama.InferenceParameters; import de.kherud.llama.LlamaModel; @@ -9,14 +12,14 @@ public static void main(String... args) { String grammar = "root ::= (expr \"=\" term \"\\n\")+\n" + "expr ::= term ([-+*/] term)*\n" + "term ::= [0-9]"; - InferenceParameters params = new InferenceParameters.Builder() - .setGrammar(grammar) - .build(); - - String filePath = "/run/media/konstantin/Seagate/models/llama2/llama-2-13b-chat/gguf-model-q4_0.bin"; - LlamaModel model = new LlamaModel(filePath); - for (String output : model.generate("", params)) { - System.out.print(output); + ModelParameters modelParams = new ModelParameters() + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf"); + InferenceParameters inferParams = new InferenceParameters("") + .setGrammar(grammar); + try (LlamaModel model = new LlamaModel(modelParams)) { + for (LlamaOutput output : model.generate(inferParams)) { + System.out.print(output); + } } } diff --git a/src/test/java/examples/InfillExample.java b/src/test/java/examples/InfillExample.java new file mode 100644 index 00000000..e13ecb7c --- /dev/null +++ b/src/test/java/examples/InfillExample.java @@ -0,0 +1,28 @@ +package examples; + +import de.kherud.llama.InferenceParameters; +import de.kherud.llama.LlamaModel; +import de.kherud.llama.LlamaOutput; +import de.kherud.llama.ModelParameters; + +public class InfillExample { + + public static void main(String... args) { + ModelParameters modelParams = new ModelParameters() + .setModel("models/codellama-7b.Q2_K.gguf") + .setGpuLayers(43); + + String prefix = "def remove_non_ascii(s: str) -> str:\n \"\"\" "; + String suffix = "\n return result\n"; + try (LlamaModel model = new LlamaModel(modelParams)) { + System.out.print(prefix); + InferenceParameters inferParams = new InferenceParameters("") + .setInputPrefix(prefix) + .setInputSuffix(suffix); + for (LlamaOutput output : model.generate(inferParams)) { + System.out.print(output); + } + System.out.print(suffix); + } + } +} diff --git a/src/test/java/examples/MainExample.java b/src/test/java/examples/MainExample.java index 772f9d7c..2b5150a5 100644 --- a/src/test/java/examples/MainExample.java +++ b/src/test/java/examples/MainExample.java @@ -7,28 +7,24 @@ import de.kherud.llama.InferenceParameters; import de.kherud.llama.LlamaModel; +import de.kherud.llama.LlamaOutput; import de.kherud.llama.ModelParameters; +import de.kherud.llama.args.MiroStat; +@SuppressWarnings("InfiniteLoopStatement") public class MainExample { public static void main(String... args) throws IOException { - LlamaModel.setLogger((level, message) -> System.out.print(message)); - ModelParameters modelParams = new ModelParameters.Builder() - .setNGpuLayers(43) - .build(); - InferenceParameters inferParams = new InferenceParameters.Builder() - .setTemperature(0.7f) - .setPenalizeNl(true) - .setMirostat(InferenceParameters.MiroStat.V2) - .setAntiPrompt(new String[]{"\n"}) - .build(); - - String modelPath = "/run/media/konstantin/Seagate/models/llama2/llama-2-13b-chat/ggml-model-q4_0.gguf"; + ModelParameters modelParams = new ModelParameters() + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf") + .setGpuLayers(43); String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + - "requests immediately and with precision.\n"; + "requests immediately and with precision.\n\n" + + "User: Hello Llama\n" + + "Llama: Hello. How may I help you today?"; BufferedReader reader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8)); - try (LlamaModel model = new LlamaModel(modelPath, modelParams)) { + try (LlamaModel model = new LlamaModel(modelParams)) { System.out.print(system); String prompt = system; while (true) { @@ -38,7 +34,12 @@ public static void main(String... args) throws IOException { prompt += input; System.out.print("Llama: "); prompt += "\nLlama: "; - for (String output : model.generate(prompt, inferParams)) { + InferenceParameters inferParams = new InferenceParameters(prompt) + .setTemperature(0.7f) + .setPenalizeNl(true) + .setMiroStat(MiroStat.V2) + .setStopStrings("User:"); + for (LlamaOutput output : model.generate(inferParams)) { System.out.print(output); prompt += output; }