diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..737725bb --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,48 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +FROM nvcr.io/nvidia/tritonserver:24.03-py3 + +ARG USERNAME=triton-server + +RUN apt-get update \ + && apt-get install -y sudo + +RUN pip3 install transformers torch + +# Create the user +RUN apt-get update \ + && apt-get install -y sudo \ + && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ + && chmod 0440 /etc/sudoers.d/$USERNAME + +RUN pip3 install pre-commit ipdb + +RUN mkhomedir_helper triton-server + +RUN apt-get install -y cmake rapidjson-dev + +USER ${USERNAME} diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..e1b8bd10 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,26 @@ +{ + "name": "Python Backend", + + "build": { + "dockerfile": "Dockerfile" + }, + "customizations": { + "vscode": { + "extensions": [ + "ms-python.vscode-pylance", + "ms-python.python", + "ms-vscode.cpptools-extension-pack", + "ms-vscode.cmake-tools", + "github.vscode-pull-request-github" + ] + } + }, + "postCreateCommand": "sudo chown -R triton-server:triton-server ~/.cache", + + "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined", "--gpus=all", "--shm-size=2g", "--ulimit", "stack=67108864" ], + "mounts": [ + "source=${localEnv:HOME}/.ssh,target=/home/triton-server/.ssh,type=bind,consistency=cached", + "source=${localEnv:HOME}/.cache/huggingface,target=/home/triton-server/.cache/huggingface,type=bind,consistency=cached" + ], + "remoteUser": "triton-server" +} diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index ab4bd951..4fa18732 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -31,8 +31,8 @@ on: jobs: pre-commit: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.0 + - uses: actions/checkout@v5.0.0 + - uses: actions/setup-python@v6.0.0 + - uses: pre-commit/action@v3.0.1 diff --git a/.gitignore b/.gitignore index bf7e1686..419005f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ /build -/.vscode *.so builddir @@ -139,3 +138,6 @@ dmypy.json # pytype static type analyzer .pytype/ +# vscode +.vscode/settings.json +.vscode/c_cpp_properties.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 298baab6..3c76a6ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -25,7 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. repos: -- repo: https://github.com/timothycrosley/isort +- repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: - id: isort @@ -36,7 +36,7 @@ repos: - id: black types_or: [python, cython] - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 7.3.0 hooks: - id: flake8 args: [--max-line-length=88, --select=C,E,F,W,B,B950, --extend-ignore = E203,E501] @@ -57,7 +57,7 @@ repos: # More details about these pre-commit hooks here: # https://pre-commit.com/hooks.html - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v6.0.0 hooks: - id: check-case-conflict - id: check-executables-have-shebangs diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 00000000..597a746d --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,85 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Configure", + "type": "shell", + "command": "cmake", + "args": [ + "-DCMAKE_INSTALL_PREFIX:STRING=/opt/tritonserver/", + "-DTRITON_COMMON_REPO_TAG:STRING=main", + "-DTRITON_BACKEND_REPO_TAG:STRING=main", + "-DTRITON_CORE_REPO_TAG:STRING=main", + "-DTRITON_ENABLE_GPU:STRING=ON", + "-DTRITON_ENABLE_NVTX:STRING=ON", + "-DCMAKE_INSTALL_PREFIX:STRING=${workspaceFolder}/build/install", + "-DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=TRUE", + "-DCMAKE_BUILD_TYPE:STRING=Debug", + "-DCMAKE_C_COMPILER:FILEPATH=/usr/bin/gcc", + "-DCMAKE_CXX_COMPILER:FILEPATH=/usr/bin/g++", + "-S${workspaceFolder}", + "-B${workspaceFolder}/build", + "-G", + "Unix Makefiles" + ], + "problemMatcher": [] + }, + { + "label": "Build", + "type": "shell", + "command": "cmake", + "args": [ + "--build", + "/${workspaceFolder}/build", + "--config", + "Debug", + "--target", + "all", + "-j", + "18", + "--" + ] + }, + { + "label": "Install", + "type": "shell", + "command": "cmake", + "args": [ + "--build", + "${workspaceFolder}/build", + "--config", + "Debug", + "--target", + "install", + "-j", + "18", + "--" + ] + }, + { + "label": "Move", + "type": "shell", + "command": "sudo", + "args": [ + "cp", + "-r", + "${workspaceFolder}/build/install/backends/python/*", + "/opt/tritonserver/backends/python" + ] + }, + { + "label": "Build Python Backend", + "dependsOrder": "sequence", + "dependsOn": [ + "Configure", + "Build", + "Install", + "Move" + ], + "group": { + "kind": "build", + "isDefault": true + } + } + ] +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 93a7ae60..f5c5b293 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -24,10 +24,13 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -cmake_minimum_required(VERSION 3.17) +cmake_minimum_required(VERSION 3.31.8) project(tritonpythonbackend LANGUAGES C CXX) +# Use C++17 standard as Triton's minimum required. +set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which features are requested to build this target.") + # # Options # @@ -38,6 +41,13 @@ option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON) option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." OFF) +# FIXME: CI needs to enable the GPU flag. Python for window currently does not +# support GPU tensors. For simplicity, we will override this option here. +if(WIN32) + set(TRITON_ENABLE_GPU OFF CACHE BOOL "GPU disabled" FORCE) +endif() + +set(TRITON_REPO_ORGANIZATION "https://github.com/triton-inference-server" CACHE STRING "Git repository to pull from") set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") @@ -60,17 +70,17 @@ include(ExternalProject) FetchContent_Declare( repo-common - GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_REPOSITORY ${TRITON_REPO_ORGANIZATION}/common.git GIT_TAG ${TRITON_COMMON_REPO_TAG} ) FetchContent_Declare( repo-core - GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_REPOSITORY ${TRITON_REPO_ORGANIZATION}/core.git GIT_TAG ${TRITON_CORE_REPO_TAG} ) FetchContent_Declare( repo-backend - GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_REPOSITORY ${TRITON_REPO_ORGANIZATION}/backend.git GIT_TAG ${TRITON_BACKEND_REPO_TAG} ) FetchContent_MakeAvailable(repo-common repo-core repo-backend) @@ -78,10 +88,21 @@ FetchContent_MakeAvailable(repo-common repo-core repo-backend) FetchContent_Declare( pybind11 GIT_REPOSITORY "https://github.com/pybind/pybind11" - # COMMIT ID for v2.10.0 - GIT_TAG "aa304c9c7d725ffb9d10af08a3b34cb372307020" + # COMMIT ID for v2.12.0 + GIT_TAG "3e9dfa2866941655c56877882565e7577de6fc7b" GIT_SHALLOW ON ) + +# RHEL base container has multiple version of Python installed. By default +# it seems like pybind will pickup v3.6, so we specifically assign it to +# search for 3.12 here. +set(RHEL_BUILD OFF) +if(LINUX) + file(STRINGS "/etc/os-release" DISTRO_ID_LIKE REGEX "ID_LIKE") + if(${DISTRO_ID_LIKE} MATCHES "rhel|centos") + set(RHEL_BUILD ON) + endif(${DISTRO_ID_LIKE} MATCHES "rhel|centos") +endif(LINUX) FetchContent_MakeAvailable(pybind11) # @@ -93,15 +114,20 @@ FetchContent_Declare( GIT_TAG "v0.8" GIT_SHALLOW ON ) +# Option must be set off so WIN32 build does not break +set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) +set(BUILD_MOCK OFF) FetchContent_MakeAvailable(dlpack) # # Boost # +set(TRITON_BOOST_URL "https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz" CACHE STRING "Boost source code URL") + ExternalProject_Add( boostorg - URL https://boostorg.jfrog.io/artifactory/main/release/1.79.0/source/boost_1_79_0.tar.gz - URL_HASH SHA256=273f1be93238a068aba4f9735a4a2b003019af067b9c183ed227780b8f36062c + URL ${TRITON_BOOST_URL} + URL_HASH SHA256=4b2136f98bdd1f5857f1c3dea9ac2018effe65286cf251534b6ae20cc45e1847 PREFIX "boost-src" CONFIGURE_COMMAND ${CMAKE_COMMAND} -E copy_directory /boost/ ${CMAKE_BINARY_DIR}/boost @@ -116,7 +142,7 @@ set(boostorg_INCLUDE_DIRS "${CMAKE_BINARY_DIR}/boost/") if(${TRITON_ENABLE_GPU}) find_package(CUDAToolkit REQUIRED) message(STATUS "Using CUDA ${CUDA_VERSION}") - set(CUDA_NVCC_FLAGS -std=c++11) + set(CUDA_NVCC_FLAGS -std=c++${TRITON_MIN_CXX_STANDARD}) elseif() message(WARNING "TRITON_ENABLE_GPU is OFF, GPU Tensor support will be disabled") endif() # TRITON_ENABLE_GPU @@ -126,17 +152,24 @@ if(${TRITON_ENABLE_NVTX}) endif() # TRITON_ENABLE_NVTX find_package(ZLIB REQUIRED) -find_package(Threads REQUIRED) + +if(NOT WIN32) + find_package(Threads REQUIRED) +endif() include_directories(${CMAKE_BINARY_DIR}) configure_file(src/libtriton_python.ldscript libtriton_python.ldscript COPYONLY) set( COMMON_SRCS + src/correlation_id.cc + src/correlation_id.h src/infer_response.cc src/infer_response.h src/infer_request.cc src/infer_request.h + src/infer_trace.cc + src/infer_trace.h src/message_queue.h src/ipc_message.cc src/ipc_message.h @@ -171,21 +204,21 @@ set( ) set( - PYTHON_BACKEND_SRCS - src/python_be.cc - src/python_be.h - src/pb_env.cc - src/pb_env.h - src/pb_metric_reporter.cc - src/pb_metric_reporter.h - src/memory_manager.cc - src/memory_manager.h - src/request_executor.cc - src/request_executor.h - src/stub_launcher.h - src/stub_launcher.cc - src/infer_payload.h - src/infer_payload.cc + PYTHON_BACKEND_SRCS + src/python_be.cc + src/python_be.h + src/pb_env.cc + src/pb_env.h + src/pb_metric_reporter.cc + src/pb_metric_reporter.h + src/memory_manager.cc + src/memory_manager.h + src/request_executor.cc + src/request_executor.h + src/stub_launcher.h + src/stub_launcher.cc + src/infer_payload.h + src/infer_payload.cc ) list(APPEND @@ -206,8 +239,14 @@ set( src/response_sender.h src/pb_stub.h src/pb_stub.cc + src/pb_stub_log.h + src/pb_stub_log.cc src/pb_response_iterator.h src/pb_response_iterator.cc + src/pb_cancel.cc + src/pb_cancel.h + src/pb_bls_cancel.cc + src/pb_bls_cancel.h ) list(APPEND @@ -229,53 +268,104 @@ add_library( TritonPythonBackend::triton-python-backend ALIAS triton-python-backend ) -target_compile_features(triton-python-backend PRIVATE cxx_std_11) +target_compile_features(triton-python-backend PRIVATE cxx_std_${TRITON_MIN_CXX_STANDARD}) target_compile_options( triton-python-backend PRIVATE $<$,$,$>: - -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> + -Wall -Wextra -Wno-unused-parameter -Wno-type-limits> + $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc /Zc:preprocessor> ) -target_compile_features(triton-python-backend-stub PRIVATE cxx_std_11) +target_compile_features(triton-python-backend-stub PRIVATE cxx_std_${TRITON_MIN_CXX_STANDARD}) target_compile_options( triton-python-backend-stub PRIVATE $<$,$,$>: - -fvisibility=hidden -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> + -fvisibility=hidden -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> + $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc /Zc:preprocessor> ) target_compile_definitions(triton-python-backend-stub PRIVATE TRITON_PB_STUB) -target_link_libraries( - triton-python-backend - PRIVATE +# RHEL assets are not released in a container environment nor do the current +# Python lib versions in the manylinux base container match those currently +# available for RHEL8 package managers. Therefore, we package the correct +# python libs in the backend folder and adjust the stub executable to look +# in its own folder at runtime. +if(RHEL_BUILD) + set_target_properties( + triton-python-backend-stub + PROPERTIES + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + INSTALL_RPATH "$\{ORIGIN\}" + ) +endif(RHEL_BUILD) + + +# For WIN32 do not link Threads and DL_LIBS +if(WIN32) + target_link_libraries( + triton-python-backend + PRIVATE + dlpack + triton-backend-utils # from repo-backend + -lrt # shared memory + triton-core-serverstub # from repo-core + ZLIB::ZLIB + -larchive + ) + + target_link_libraries( + triton-python-backend-stub + PRIVATE + dlpack + triton-backend-utils # from repo-backend + pybind11::embed + -lrt # shared memory + -larchive # libarchive + ) +else() + target_link_libraries( + triton-python-backend + PRIVATE + dlpack + Threads::Threads + triton-backend-utils # from repo-backend + ${CMAKE_DL_LIBS} # dlopen and dlclose + -lrt # shared memory + triton-core-serverstub # from repo-core + ZLIB::ZLIB + -larchive + ) + + target_link_libraries( + triton-python-backend-stub + PRIVATE dlpack Threads::Threads - triton-backend-utils # from repo-backend - ${CMAKE_DL_LIBS} # dlopen and dlclose - -lrt # shared memory - triton-core-serverstub # from repo-core - ZLIB::ZLIB - -larchive -) - -target_link_libraries( - triton-python-backend-stub - PRIVATE - dlpack - Threads::Threads - triton-backend-utils # from repo-backend - ${CMAKE_DL_LIBS} # dlopen and dlclose - pybind11::embed - -lrt # shared memory - -larchive # libarchive -) + triton-backend-utils # from repo-backend + ${CMAKE_DL_LIBS} # dlopen and dlclose + pybind11::embed + -lrt # shared memory + -larchive # libarchive + ) +endif() -set_target_properties( - triton-python-backend PROPERTIES - POSITION_INDEPENDENT_CODE ON - OUTPUT_NAME triton_python - LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_python.ldscript - LINK_FLAGS "-Wl,--version-script libtriton_python.ldscript" -) +if(WIN32) + set_target_properties( + triton-python-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_python + ) +else() + set_target_properties( + triton-python-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_python + LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_python.ldscript + LINK_FLAGS "-Wl,--version-script libtriton_python.ldscript" + ) +endif() add_subdirectory(./src/shm_monitor) @@ -307,13 +397,6 @@ install( ${INSTALL_CONFIGDIR} ) -install( - DIRECTORY - src/resources/platform_handlers - DESTINATION - ${CMAKE_INSTALL_PREFIX}/backends/python -) - install( FILES src/resources/triton_python_backend_utils.py diff --git a/README.md b/README.md index 6c445d86..dd5e877a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ + +# Preprocessing Using Python Backend Example This example shows how to preprocess your inputs using Python backend before it is passed to the TensorRT model for inference. This ensemble model includes an image preprocessing model (preprocess) and a TensorRT model (resnet50_trt) to do inference. **1. Converting PyTorch Model to ONNX format:** Run onnx_exporter.py to convert ResNet50 PyTorch model to ONNX format. Width and height dims are fixed at 224 but dynamic axes arguments for dynamic batching are used. Commands from the 2. and 3. subsections shall be executed within this Docker container. - $ docker run -it --gpus=all -v $(pwd):/workspace nvcr.io/nvidia/pytorch:xx.yy-py3 bash - $ pip install numpy pillow torchvision - $ python onnx_exporter.py --save model.onnx + docker run -it --gpus=all -v $(pwd):/workspace nvcr.io/nvidia/pytorch:xx.yy-py3 bash + pip install numpy pillow torchvision + python onnx_exporter.py --save model.onnx **2. Create the model repository:** - $ mkdir -p model_repository/ensemble_python_resnet50/1 - $ mkdir -p model_repository/preprocess/1 - $ mkdir -p model_repository/resnet50_trt/1 + mkdir -p model_repository/ensemble_python_resnet50/1 + mkdir -p model_repository/preprocess/1 + mkdir -p model_repository/resnet50_trt/1 # Copy the Python model - $ cp model.py model_repository/preprocess/1 + cp model.py model_repository/preprocess/1 **3. Build a TensorRT engine for the ONNX model** Set the arguments for enabling fp16 precision --fp16. To enable dynamic shapes use --minShapes, --optShapes, and maxShapes with --explicitBatch: - $ trtexec --onnx=model.onnx --saveEngine=./model_repository/resnet50_trt/1/model.plan --explicitBatch --minShapes=input:1x3x224x224 --optShapes=input:1x3x224x224 --maxShapes=input:256x3x224x224 --fp16 + trtexec --onnx=model.onnx --saveEngine=./model_repository/resnet50_trt/1/model.plan --explicitBatch --minShapes=input:1x3x224x224 --optShapes=input:1x3x224x224 --maxShapes=input:256x3x224x224 --fp16 **4. Run the command below to start the server container:** Under python_backend/examples/preprocessing, run this command to start the server docker container: - $ docker run --gpus=all -it --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v/$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:xx.yy-py3 bash - $ pip install numpy pillow torchvision - $ tritonserver --model-repository=/models + docker run --gpus=all -it --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v/$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:xx.yy-py3 bash + pip install numpy pillow torchvision + tritonserver --model-repository=/models **5. Start the client to test:** Under python_backend/examples/preprocessing, run the commands below to start the client Docker container: - $ wget https://raw.githubusercontent.com/triton-inference-server/server/main/qa/images/mug.jpg -O "mug.jpg" - $ docker run --rm --net=host -v $(pwd):/workspace/ nvcr.io/nvidia/tritonserver:xx.yy-py3-sdk python client.py --image mug.jpg - $ The result of classification is:COFFEE MUG + wget https://raw.githubusercontent.com/triton-inference-server/server/main/qa/images/mug.jpg -O "mug.jpg" + docker run --rm --net=host -v $(pwd):/workspace/ nvcr.io/nvidia/tritonserver:xx.yy-py3-sdk python client.py --image mug.jpg + The result of classification is:COFFEE MUG Here, since we input an image of "mug" and the inference result is "COFFEE MUG" which is correct. diff --git a/examples/preprocessing/client.py b/examples/preprocessing/client.py index 202d411a..1ac107af 100644 --- a/examples/preprocessing/client.py +++ b/examples/preprocessing/client.py @@ -29,7 +29,7 @@ import sys import numpy as np -import tritongrpcclient +import tritonclient.grpc as tritongrpcclient def load_image(img_path: str): diff --git a/inferentia/README.md b/inferentia/README.md index 6a90740d..fb0de4f7 100644 --- a/inferentia/README.md +++ b/inferentia/README.md @@ -34,7 +34,7 @@ and the [Neuron Runtime](https://awsdocs-neuron.readthedocs-hosted.com/en/latest ## Table of Contents -- [Using Triton with Inferentia](#using-triton-with-inferentia) +- [Using Triton with Inferentia 1](#using-triton-with-inferentia-1) - [Table of Contents](#table-of-contents) - [Inferentia setup](#inferentia-setup) - [Setting up the Inferentia model](#setting-up-the-inferentia-model) @@ -60,18 +60,18 @@ or simply clone with https. Clone this repo with Github to home repo `/home/ubuntu`. ``` - $chmod 777 /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh - $sudo /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh + chmod 777 /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh + sudo /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh ``` Then, start the Triton instance with: ``` - $docker run --device /dev/neuron0 -v /home/ubuntu/python_backend:/home/ubuntu/python_backend -v /lib/udev:/mylib/udev --shm-size=1g --ulimit memlock=-1 -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti nvcr.io/nvidia/tritonserver:-py3 + docker run --device /dev/neuron0 -v /home/ubuntu/python_backend:/home/ubuntu/python_backend -v /lib/udev:/mylib/udev --shm-size=1g --ulimit memlock=-1 -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti nvcr.io/nvidia/tritonserver:-py3 ``` Note 1: The user would need to list any neuron device to run during container initialization. For example, to use 4 neuron devices on an instance, the user would need to run with: ``` - $docker run --device /dev/neuron0 --device /dev/neuron1 --device /dev/neuron2 --device /dev/neuron3 ...` + docker run --device /dev/neuron0 --device /dev/neuron1 --device /dev/neuron2 --device /dev/neuron3 ...` ``` Note 2: `/mylib/udev` is used for Neuron parameter passing. @@ -81,7 +81,7 @@ Note 3: For Triton container version xx.yy, please refer to After starting the Triton container, go into the `python_backend` folder and run the setup script. ``` - $source /home/ubuntu/python_backend/inferentia/scripts/setup.sh + source /home/ubuntu/python_backend/inferentia/scripts/setup.sh ``` This script will: 1. Install necessary dependencies @@ -118,7 +118,7 @@ triton python model directory. An example invocation for the `gen_triton_model.py` for PyTorch model can look like: ``` - $python3 inferentia/scripts/gen_triton_model.py --model_type pytorch --triton_input INPUT__0,INT64,4x384 INPUT__1,INT64,4x384 INPUT__2,INT64,4x384 --triton_output OUTPUT__0,INT64,4x384 OUTPUT__1,INT64,4x384 --compiled_model /home/ubuntu/bert_large_mlperf_neuron_hack_bs1_dynamic.pt --neuron_core_range 0:3 --triton_model_dir bert-large-mlperf-bs1x4 + python3 inferentia/scripts/gen_triton_model.py --model_type pytorch --triton_input INPUT__0,INT64,4x384 INPUT__1,INT64,4x384 INPUT__2,INT64,4x384 --triton_output OUTPUT__0,INT64,4x384 OUTPUT__1,INT64,4x384 --compiled_model /home/ubuntu/bert_large_mlperf_neuron_hack_bs1_dynamic.pt --neuron_core_range 0:3 --triton_model_dir bert-large-mlperf-bs1x4 ``` In order for the script to treat the compiled model as TorchScript @@ -161,7 +161,7 @@ script to generate triton python model directory. An example invocation for the `gen_triton_model.py` for TensorFlow model can look like: ``` - $python3 gen_triton_model.py --model_type tensorflow --compiled_model /home/ubuntu/inferentia-poc-2.0/scripts-rn50-tf-native/resnet50_mlperf_opt_fp16_compiled_b5_nc1/1 --neuron_core_range 0:3 --triton_model_dir rn50-1neuroncores-bs1x1 + python3 gen_triton_model.py --model_type tensorflow --compiled_model /home/ubuntu/inferentia-poc-2.0/scripts-rn50-tf-native/resnet50_mlperf_opt_fp16_compiled_b5_nc1/1 --neuron_core_range 0:3 --triton_model_dir rn50-1neuroncores-bs1x1 ``` NOTE: Unlike TorchScript model, TensorFlow SavedModel stores sufficient @@ -215,7 +215,7 @@ a valid torchscript file or tensorflow savedmodel. Now, the server can be launched with the model as below: ``` - $tritonserver --model-repository + tritonserver --model-repository ``` Note: @@ -255,7 +255,7 @@ contains the necessary files to set up testing with a simple add_sub model. The requires an instance with more than 8 inferentia cores to run, eg:`inf1.6xlarge`. start the test, run ``` - $source /python_backend/inferentia/qa/setup_test_enviroment_and_test.sh + source /python_backend/inferentia/qa/setup_test_enviroment_and_test.sh ``` where `` is usually `/home/ubuntu`/. This script will pull the [server repo](https://github.com/triton-inference-server/server) @@ -265,7 +265,7 @@ Triton Server and Triton SDK. Note: If you would need to change some of the tests in the server repo, you would need to run ``` - $export TRITON_SERVER_REPO_TAG= + export TRITON_SERVER_REPO_TAG= ``` before running the script. @@ -273,8 +273,8 @@ before running the script. ## pytorch-neuronx and tensorflow-neuronx 1. Similar to the steps for inf1, change the argument to the pre-container and on-container setup scripts to include the `-inf2` or `-trn1`flags e.g., ``` - $chmod 777 /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh - $sudo /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh -inf2 + chmod 777 /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh + sudo /home/ubuntu/python_backend/inferentia/scripts/setup-pre-container.sh -inf2 ``` 2. On the container, followed by the `docker run` command, you can pass similar argument to the setup.sh script For Pytorch: diff --git a/src/correlation_id.cc b/src/correlation_id.cc new file mode 100644 index 00000000..d7b19eea --- /dev/null +++ b/src/correlation_id.cc @@ -0,0 +1,120 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "correlation_id.h" + +namespace triton { namespace backend { namespace python { + +CorrelationId::CorrelationId() + : id_string_(""), id_uint_(0), id_type_(CorrelationIdDataType::UINT64) +{ +} + +CorrelationId::CorrelationId(const std::string& id_string) + : id_string_(id_string), id_uint_(0), + id_type_(CorrelationIdDataType::STRING) +{ +} + +CorrelationId::CorrelationId(uint64_t id_uint) + : id_string_(""), id_uint_(id_uint), id_type_(CorrelationIdDataType::UINT64) +{ +} + +CorrelationId::CorrelationId(const CorrelationId& rhs) +{ + id_uint_ = rhs.id_uint_; + id_type_ = rhs.id_type_; + id_string_ = rhs.id_string_; +} + +CorrelationId::CorrelationId(std::unique_ptr& correlation_id_shm) +{ + id_uint_ = correlation_id_shm->id_uint_; + id_type_ = correlation_id_shm->id_type_; + id_string_ = correlation_id_shm->id_string_; +} + +CorrelationId& +CorrelationId::operator=(const CorrelationId& rhs) +{ + id_uint_ = rhs.id_uint_; + id_type_ = rhs.id_type_; + id_string_ = rhs.id_string_; + return *this; +} + +void +CorrelationId::SaveToSharedMemory( + std::unique_ptr& shm_pool) +{ + AllocatedSharedMemory correlation_id_shm = + shm_pool->Construct(); + correlation_id_shm_ptr_ = correlation_id_shm.data_.get(); + + std::unique_ptr id_string_shm = + PbString::Create(shm_pool, id_string_); + + correlation_id_shm_ptr_->id_uint = id_uint_; + correlation_id_shm_ptr_->id_string_shm_handle = id_string_shm->ShmHandle(); + correlation_id_shm_ptr_->id_type = id_type_; + + // Save the references to shared memory. + correlation_id_shm_ = std::move(correlation_id_shm); + id_string_shm_ = std::move(id_string_shm); + shm_handle_ = correlation_id_shm_.handle_; +} + +std::unique_ptr +CorrelationId::LoadFromSharedMemory( + std::unique_ptr& shm_pool, + bi::managed_external_buffer::handle_t handle) +{ + AllocatedSharedMemory correlation_id_shm = + shm_pool->Load(handle); + CorrelationIdShm* correlation_id_shm_ptr = correlation_id_shm.data_.get(); + + std::unique_ptr id_string_shm = PbString::LoadFromSharedMemory( + shm_pool, correlation_id_shm_ptr->id_string_shm_handle); + + return std::unique_ptr( + new CorrelationId(correlation_id_shm, id_string_shm)); +} + +CorrelationId::CorrelationId( + AllocatedSharedMemory& correlation_id_shm, + std::unique_ptr& id_string_shm) + : correlation_id_shm_(std::move(correlation_id_shm)), + id_string_shm_(std::move(id_string_shm)) +{ + correlation_id_shm_ptr_ = correlation_id_shm_.data_.get(); + shm_handle_ = correlation_id_shm_.handle_; + id_string_ = id_string_shm_->String(); + id_uint_ = correlation_id_shm_ptr_->id_uint; + id_type_ = correlation_id_shm_ptr_->id_type; +} + +}}}; // namespace triton::backend::python diff --git a/src/correlation_id.h b/src/correlation_id.h new file mode 100644 index 00000000..63185d9f --- /dev/null +++ b/src/correlation_id.h @@ -0,0 +1,93 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include + +#include "pb_string.h" +#include "pb_utils.h" + +namespace triton { namespace backend { namespace python { + +enum class CorrelationIdDataType { UINT64, STRING }; + +struct CorrelationIdShm { + bi::managed_external_buffer::handle_t id_string_shm_handle; + uint64_t id_uint; + CorrelationIdDataType id_type; +}; + +class CorrelationId { + public: + CorrelationId(); + CorrelationId(const std::string& id_string); + CorrelationId(uint64_t id_uint); + CorrelationId(const CorrelationId& rhs); + CorrelationId(std::unique_ptr& correlation_id_shm); + CorrelationId& operator=(const CorrelationId& rhs); + + /// Save CorrelationId object to shared memory. + /// \param shm_pool Shared memory pool to save the CorrelationId object. + void SaveToSharedMemory(std::unique_ptr& shm_pool); + + /// Create a CorrelationId object from shared memory. + /// \param shm_pool Shared memory pool + /// \param handle Shared memory handle of the CorrelationId. + /// \return Returns the CorrelationId in the specified handle + /// location. + static std::unique_ptr LoadFromSharedMemory( + std::unique_ptr& shm_pool, + bi::managed_external_buffer::handle_t handle); + + // Function that help determine exact type of Correlation Id + CorrelationIdDataType Type() const { return id_type_; } + + // Get the value of the CorrelationId based on the type + const std::string& StringValue() const { return id_string_; } + uint64_t UnsignedIntValue() const { return id_uint_; } + + bi::managed_external_buffer::handle_t ShmHandle() { return shm_handle_; } + + private: + // The private constructor for creating a CorrelationId object from shared + // memory. + CorrelationId( + AllocatedSharedMemory& correlation_id_shm, + std::unique_ptr& id_string_shm); + + std::string id_string_; + uint64_t id_uint_; + CorrelationIdDataType id_type_; + + // Shared Memory Data Structures + AllocatedSharedMemory correlation_id_shm_; + CorrelationIdShm* correlation_id_shm_ptr_; + bi::managed_external_buffer::handle_t shm_handle_; + std::unique_ptr id_string_shm_; +}; + +}}}; // namespace triton::backend::python diff --git a/src/infer_payload.cc b/src/infer_payload.cc index 762201e8..6baad307 100644 --- a/src/infer_payload.cc +++ b/src/infer_payload.cc @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -31,7 +31,8 @@ namespace triton { namespace backend { namespace python { InferPayload::InferPayload( const bool is_decoupled, std::function)> callback) - : is_decoupled_(is_decoupled), is_promise_set_(false), callback_(callback) + : is_decoupled_(is_decoupled), is_promise_set_(false), callback_(callback), + request_address_(reinterpret_cast(nullptr)) { promise_.reset(new std::promise>()); } @@ -91,4 +92,31 @@ InferPayload::ResponseAllocUserp() return response_alloc_userp_; } +void +InferPayload::SetRequestAddress(intptr_t request_address) +{ + std::unique_lock lock(request_address_mutex_); + request_address_ = request_address; +} + +void +InferPayload::SetRequestCancellationFunc( + const std::function& request_cancel_func) +{ + request_cancel_func_ = request_cancel_func; +} + +void +InferPayload::SafeCancelRequest() +{ + std::unique_lock lock(request_address_mutex_); + if (request_address_ == 0L) { + return; + } + + if (request_cancel_func_) { + request_cancel_func_(request_address_); + } +} + }}} // namespace triton::backend::python diff --git a/src/infer_payload.h b/src/infer_payload.h index 662e8922..8e4aa7d3 100644 --- a/src/infer_payload.h +++ b/src/infer_payload.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -62,6 +62,10 @@ class InferPayload : public std::enable_shared_from_this { void SetResponseAllocUserp( const ResponseAllocatorUserp& response_alloc_userp); std::shared_ptr ResponseAllocUserp(); + void SetRequestAddress(intptr_t request_address); + void SetRequestCancellationFunc( + const std::function& request_cancel_func); + void SafeCancelRequest(); private: std::unique_ptr>> promise_; @@ -70,6 +74,9 @@ class InferPayload : public std::enable_shared_from_this { bool is_promise_set_; std::function)> callback_; std::shared_ptr response_alloc_userp_; + std::mutex request_address_mutex_; + intptr_t request_address_; + std::function request_cancel_func_; }; }}} // namespace triton::backend::python diff --git a/src/infer_request.cc b/src/infer_request.cc index 3ecde9e8..e5733662 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -38,18 +38,19 @@ namespace triton { namespace backend { namespace python { InferRequest::InferRequest( - const std::string& request_id, uint64_t correlation_id, + const std::string& request_id, const CorrelationId& correlation_id, const std::vector>& inputs, const std::set& requested_output_names, const std::string& model_name, const int64_t model_version, - const std::string& parameters, const uint32_t flags, const int32_t timeout, + const std::string& parameters, const uint32_t flags, const uint64_t timeout, const intptr_t response_factory_address, const intptr_t request_address, - const PreferredMemory& preferred_memory) + const PreferredMemory& preferred_memory, const InferenceTrace& trace) : request_id_(request_id), correlation_id_(correlation_id), inputs_(inputs), requested_output_names_(requested_output_names), model_name_(model_name), model_version_(model_version), parameters_(parameters), flags_(flags), timeout_(timeout), response_factory_address_(response_factory_address), - request_address_(request_address), preferred_memory_(preferred_memory) + request_address_(request_address), preferred_memory_(preferred_memory), + trace_(trace), request_release_flags_(TRITONSERVER_REQUEST_RELEASE_ALL) { for (auto& input : inputs) { if (!input) { @@ -67,12 +68,13 @@ InferRequest::InferRequest( } } - inputs_ = inputs; - requested_output_names_ = requested_output_names; #ifdef TRITON_PB_STUB + pb_cancel_ = + std::make_shared(response_factory_address_, request_address_); response_sender_ = std::make_shared( - request_address_, response_factory_address_, - Stub::GetOrCreateInstance()->SharedMemory()); + request_address_, response_factory_address_, nullptr /* is_decoupled */, + RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(), + pb_cancel_); #endif } @@ -94,8 +96,8 @@ InferRequest::RequestId() return request_id_; } -uint64_t -InferRequest::CorrelationId() +CorrelationId& +InferRequest::GetCorrelationId() { return correlation_id_; } @@ -142,7 +144,7 @@ InferRequest::ShmHandle() return shm_handle_; } -int32_t +uint64_t InferRequest::Timeout() { return timeout_; @@ -166,6 +168,26 @@ InferRequest::GetPreferredMemory() return preferred_memory_; } +InferenceTrace& +InferRequest::GetTrace() +{ + return trace_; +} + +uint32_t +InferRequest::ReleaseFlags() +{ + request_release_flags_ = infer_request_shm_ptr_->request_release_flags; + return request_release_flags_; +} + +void +InferRequest::SetReleaseFlags(const uint32_t& flags) +{ + request_release_flags_ = flags; + infer_request_shm_ptr_->request_release_flags = request_release_flags_; +} + void InferRequest::SaveToSharedMemory(std::unique_ptr& shm_pool) { @@ -173,14 +195,10 @@ InferRequest::SaveToSharedMemory(std::unique_ptr& shm_pool) sizeof(InferRequestShm) + (RequestedOutputNames().size() * sizeof(bi::managed_external_buffer::handle_t)) + - (Inputs().size() * sizeof(bi::managed_external_buffer::handle_t)) + - PbString::ShmStructSize(ModelName()) + - PbString::ShmStructSize(RequestId()) + - PbString::ShmStructSize(Parameters())); + (Inputs().size() * sizeof(bi::managed_external_buffer::handle_t))); infer_request_shm_ptr_ = reinterpret_cast(infer_request_shm.data_.get()); - infer_request_shm_ptr_->correlation_id = CorrelationId(); infer_request_shm_ptr_->input_count = Inputs().size(); infer_request_shm_ptr_->model_version = model_version_; infer_request_shm_ptr_->requested_output_count = @@ -191,6 +209,7 @@ InferRequest::SaveToSharedMemory(std::unique_ptr& shm_pool) infer_request_shm_ptr_->is_decoupled = is_decoupled_; infer_request_shm_ptr_->timeout = timeout_; infer_request_shm_ptr_->preferred_memory = preferred_memory_; + infer_request_shm_ptr_->request_release_flags = request_release_flags_; output_names_handle_shm_ptr_ = reinterpret_cast( @@ -221,30 +240,24 @@ InferRequest::SaveToSharedMemory(std::unique_ptr& shm_pool) i++; } - size_t model_name_offset = - sizeof(InferRequestShm) + - (RequestedOutputNames().size() * - sizeof(bi::managed_external_buffer::handle_t)) + - (Inputs().size() * sizeof(bi::managed_external_buffer::handle_t)); - - std::unique_ptr model_name_shm = PbString::Create( - ModelName(), - reinterpret_cast(infer_request_shm_ptr_) + model_name_offset, - infer_request_shm.handle_ + model_name_offset); - - size_t request_id_offset = - model_name_offset + PbString::ShmStructSize(ModelName()); - std::unique_ptr request_id_shm = PbString::Create( - RequestId(), - reinterpret_cast(infer_request_shm_ptr_) + request_id_offset, - infer_request_shm.handle_ + request_id_offset); - - size_t parameters_offset = - request_id_offset + PbString::ShmStructSize(RequestId()); - std::unique_ptr parameters_shm = PbString::Create( - Parameters(), - reinterpret_cast(infer_request_shm_ptr_) + parameters_offset, - infer_request_shm.handle_ + parameters_offset); + correlation_id_.SaveToSharedMemory(shm_pool); + infer_request_shm_ptr_->correlation_id_shm_handle = + correlation_id_.ShmHandle(); + + std::unique_ptr model_name_shm = + PbString::Create(shm_pool, ModelName()); + infer_request_shm_ptr_->model_name_shm_handle = model_name_shm->ShmHandle(); + + std::unique_ptr request_id_shm = + PbString::Create(shm_pool, RequestId()); + infer_request_shm_ptr_->request_id_shm_handle = request_id_shm->ShmHandle(); + + std::unique_ptr parameters_shm = + PbString::Create(shm_pool, Parameters()); + infer_request_shm_ptr_->parameters_shm_handle = parameters_shm->ShmHandle(); + + trace_.SaveToSharedMemory(shm_pool); + infer_request_shm_ptr_->trace_shm_handle = trace_.ShmHandle(); // Save the references to shared memory. infer_request_shm_ = std::move(infer_request_shm); @@ -258,7 +271,8 @@ InferRequest::SaveToSharedMemory(std::unique_ptr& shm_pool) std::unique_ptr InferRequest::LoadFromSharedMemory( std::unique_ptr& shm_pool, - bi::managed_external_buffer::handle_t request_handle, bool open_cuda_handle) + bi::managed_external_buffer::handle_t request_handle, bool open_cuda_handle, + bool const* is_model_decoupled) { AllocatedSharedMemory infer_request_shm = shm_pool->Load(request_handle); @@ -296,38 +310,37 @@ InferRequest::LoadFromSharedMemory( input_tensors.emplace_back(std::move(input_tensor)); } - size_t model_name_offset = - sizeof(InferRequestShm) + - (requested_output_count * sizeof(bi::managed_external_buffer::handle_t)) + - (infer_request_shm_ptr->input_count * - sizeof(bi::managed_external_buffer::handle_t)); + std::unique_ptr correlation_id_shm = + CorrelationId::LoadFromSharedMemory( + shm_pool, infer_request_shm_ptr->correlation_id_shm_handle); - std::unique_ptr model_name_shm = PbString::LoadFromSharedMemory( - request_handle + model_name_offset, - reinterpret_cast(infer_request_shm_ptr) + model_name_offset); + std::unique_ptr infer_trace_shm = + InferenceTrace::LoadFromSharedMemory( + shm_pool, infer_request_shm_ptr->trace_shm_handle); - size_t request_id_offset = model_name_offset + model_name_shm->Size(); + std::unique_ptr model_name_shm = PbString::LoadFromSharedMemory( + shm_pool, infer_request_shm_ptr->model_name_shm_handle); std::unique_ptr request_id_shm = PbString::LoadFromSharedMemory( - request_handle + request_id_offset, - reinterpret_cast(infer_request_shm_ptr) + request_id_offset); - - size_t parameters_offset = request_id_offset + request_id_shm->Size(); + shm_pool, infer_request_shm_ptr->request_id_shm_handle); std::unique_ptr parameters_shm = PbString::LoadFromSharedMemory( - request_handle + request_id_offset, - reinterpret_cast(infer_request_shm_ptr) + parameters_offset); + shm_pool, infer_request_shm_ptr->parameters_shm_handle); return std::unique_ptr(new InferRequest( - infer_request_shm, request_id_shm, requested_output_names_shm, - model_name_shm, input_tensors, parameters_shm)); + infer_request_shm, request_id_shm, correlation_id_shm, + requested_output_names_shm, model_name_shm, input_tensors, parameters_shm, + infer_trace_shm, is_model_decoupled)); } InferRequest::InferRequest( AllocatedSharedMemory& infer_request_shm, std::unique_ptr& request_id_shm, + std::unique_ptr& correlation_id_shm, std::vector>& requested_output_names_shm, std::unique_ptr& model_name_shm, std::vector>& input_tensors, - std::unique_ptr& parameters_shm) + std::unique_ptr& parameters_shm, + std::unique_ptr& infer_trace_shm, + bool const* is_model_decoupled) : infer_request_shm_(std::move(infer_request_shm)), request_id_shm_(std::move(request_id_shm)), requested_output_names_shm_(std::move(requested_output_names_shm)), @@ -356,57 +369,54 @@ InferRequest::InferRequest( requested_output_names.emplace(pb_string->String()); } + correlation_id_ = CorrelationId(correlation_id_shm); request_id_ = request_id_shm_->String(); parameters_ = parameters_shm_->String(); requested_output_names_ = std::move(requested_output_names); model_name_ = model_name_shm_->String(); flags_ = infer_request_shm_ptr_->flags; model_version_ = infer_request_shm_ptr_->model_version; - correlation_id_ = infer_request_shm_ptr_->correlation_id; request_address_ = infer_request_shm_ptr_->address; response_factory_address_ = infer_request_shm_ptr_->response_factory_address; is_decoupled_ = infer_request_shm_ptr_->is_decoupled; timeout_ = infer_request_shm_ptr_->timeout; preferred_memory_ = infer_request_shm_ptr_->preferred_memory; + trace_ = InferenceTrace(infer_trace_shm); + request_release_flags_ = infer_request_shm_ptr_->request_release_flags; #ifdef TRITON_PB_STUB + pb_cancel_ = + std::make_shared(response_factory_address_, request_address_); response_sender_ = std::make_shared( - request_address_, response_factory_address_, - Stub::GetOrCreateInstance()->SharedMemory()); + request_address_, response_factory_address_, is_model_decoupled, + RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(), + pb_cancel_); #endif } -#ifndef TRITON_PB_STUB -TRITONSERVER_Error* -InferRequest::DeleteResponseFactory() +#ifdef TRITON_PB_STUB +bool +InferRequest::IsCancelled() { - TRITONBACKEND_ResponseFactory* response_factory = - reinterpret_cast( - response_factory_address_); - TRITONSERVER_Error* error = - TRITONBACKEND_ResponseFactoryDelete(response_factory); - - return error; + return pb_cancel_->IsCancelled(); } -#endif -#ifdef TRITON_PB_STUB std::shared_ptr InferRequest::GetResponseSender() { - std::unique_ptr& stub = Stub::GetOrCreateInstance(); - if (!stub->IsDecoupled()) { - throw PythonBackendException( - "'get_response_sender' function must be called only when the model is " - "using the decoupled transaction policy."); - } - return response_sender_; } std::shared_ptr InferRequest::Exec(const bool is_decoupled) { + // Release the GIL. This avoids a potential deadlock situation in the parent + // process, where every thread in the thread pool is indirectly waiting for a + // function in the stub process that acquires the GIL. Meanwhile, the current + // thread, which holds the GIL, is also waiting for the parent side to have + // the next available thread to pick up the job during resource contention. + py::gil_scoped_release release; + // BLS should not be used in "initialize" or "finalize" function. std::unique_ptr& stub = Stub::GetOrCreateInstance(); if (!stub->IsInitialized() || stub->IsFinalizing()) { @@ -430,7 +440,6 @@ InferRequest::Exec(const bool is_decoupled) }); try { - py::gil_scoped_release release; ipc_message = IPCMessage::Create(shm_pool, true /* inline_response */); bool has_exception = false; PythonBackendException pb_exception(std::string{}); @@ -475,7 +484,7 @@ InferRequest::Exec(const bool is_decoupled) { bi::scoped_lock lock{ *(ipc_message->ResponseMutex())}; - stub->SendIPCMessage(ipc_message); + stub->SendIPCUtilsMessage(ipc_message); ipc_message->ResponseCondition()->wait(lock); } @@ -579,7 +588,7 @@ InferRequest::Exec(const bool is_decoupled) if (!output_tensor->IsCPU()) { uint64_t memory_release_id = output_tensor->Memory()->MemoryReleaseId(); output_tensor->Memory()->SetMemoryReleaseCallback( - [&memory_manager_message_queue, memory_release_id]() { + [&memory_manager_message_queue, memory_release_id, &shm_pool]() { memory_manager_message_queue->Push(memory_release_id); }); } diff --git a/src/infer_request.h b/src/infer_request.h index 7eb2fd88..f368d692 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -29,11 +29,14 @@ #include #include +#include "correlation_id.h" #include "infer_response.h" +#include "infer_trace.h" #include "pb_preferred_memory.h" #include "pb_tensor.h" #ifdef TRITON_PB_STUB +#include "pb_cancel.h" #include "response_sender.h" #endif @@ -45,7 +48,6 @@ class Stub; // Inference Request // struct InferRequestShm { - uint64_t correlation_id; uint32_t input_count; uint32_t requested_output_count; int64_t model_version; @@ -53,41 +55,53 @@ struct InferRequestShm { intptr_t address; intptr_t response_factory_address; bool is_decoupled; - int32_t timeout; + uint64_t timeout; PreferredMemory preferred_memory; + bi::managed_external_buffer::handle_t trace_shm_handle; + uint32_t request_release_flags; + bi::managed_external_buffer::handle_t correlation_id_shm_handle; + bi::managed_external_buffer::handle_t model_name_shm_handle; + bi::managed_external_buffer::handle_t request_id_shm_handle; + bi::managed_external_buffer::handle_t parameters_shm_handle; }; class InferRequest { public: InferRequest( - const std::string& request_id, uint64_t correlation_id, + const std::string& request_id, const CorrelationId& correlation_id, const std::vector>& inputs, const std::set& requested_output_names, const std::string& model_name, const int64_t model_version, const std::string& parameters, const uint32_t flags = 0, - const int32_t timeout = 0, const intptr_t response_factory_address = 0, + const uint64_t timeout = 0, const intptr_t response_factory_address = 0, const intptr_t request_address = 0, const PreferredMemory& preferred_memory = - PreferredMemory(PreferredMemory::DEFAULT, 0)); + PreferredMemory(PreferredMemory::kDefault, 0), + const InferenceTrace& trace = InferenceTrace()); const std::vector>& Inputs(); const std::string& RequestId(); const std::string& Parameters(); - uint64_t CorrelationId(); + CorrelationId& GetCorrelationId(); const std::string& ModelName(); int64_t ModelVersion(); uint32_t Flags(); void SetFlags(uint32_t flags); const std::set& RequestedOutputNames(); bi::managed_external_buffer::handle_t ShmHandle(); - int32_t Timeout(); + uint64_t Timeout(); bool IsDecoupled(); void SetIsDecoupled(const bool is_decoupled); PreferredMemory& GetPreferredMemory(); + InferenceTrace& GetTrace(); + uint32_t ReleaseFlags(); + void SetReleaseFlags(const uint32_t& flags); + intptr_t GetResponseFactoryAddress() { return response_factory_address_; } #ifdef TRITON_PB_STUB std::shared_ptr Exec(const bool is_decoupled); std::shared_ptr GetResponseSender(); + bool IsCancelled(); #endif /// Save an Inference Request to shared memory. @@ -105,7 +119,7 @@ class InferRequest { static std::unique_ptr LoadFromSharedMemory( std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t request_handle, - bool open_cuda_handle); + bool open_cuda_handle, bool const* is_model_decoupled); /// Disallow copying the inference request object. DISALLOW_COPY_AND_ASSIGN(InferRequest); @@ -113,32 +127,33 @@ class InferRequest { intptr_t RequestAddress(); ~InferRequest() {} -#ifndef TRITON_PB_STUB - TRITONSERVER_Error* DeleteResponseFactory(); -#endif - private: InferRequest( AllocatedSharedMemory& infer_request_shm, std::unique_ptr& request_id_shm, + std::unique_ptr& correlation_id, std::vector>& requested_output_names_shm, std::unique_ptr& model_name_shm, std::vector>& input_tensors, - std::unique_ptr& parameters_shm); + std::unique_ptr& parameters_shm, + std::unique_ptr& infer_trace_shm, + bool const* is_model_decoupled); std::string request_id_; - uint64_t correlation_id_; + CorrelationId correlation_id_; std::vector> inputs_; std::set requested_output_names_; std::string model_name_; int64_t model_version_; std::string parameters_; uint32_t flags_; - int32_t timeout_; + uint64_t timeout_; intptr_t response_factory_address_; intptr_t request_address_; bool is_decoupled_; PreferredMemory preferred_memory_; + InferenceTrace trace_; + uint32_t request_release_flags_; // Shared Memory Data Structures AllocatedSharedMemory infer_request_shm_; @@ -153,6 +168,7 @@ class InferRequest { std::unique_ptr parameters_shm_; #ifdef TRITON_PB_STUB + std::shared_ptr pb_cancel_; std::shared_ptr response_sender_; #endif }; diff --git a/src/infer_response.cc b/src/infer_response.cc index afadc324..382756d4 100644 --- a/src/infer_response.cc +++ b/src/infer_response.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -39,8 +39,10 @@ namespace triton { namespace backend { namespace python { InferResponse::InferResponse( const std::vector>& output_tensors, - std::shared_ptr error, const bool is_last_response, void* id) - : error_(error), is_last_response_(is_last_response), id_(id) + std::shared_ptr error, std::string parameters, + const bool is_last_response, void* id) + : error_(error), is_last_response_(is_last_response), id_(id), + parameters_(std::move(parameters)) { for (auto& output : output_tensors) { if (!output) { @@ -58,6 +60,12 @@ InferResponse::OutputTensors() return output_tensors_; } +const std::string& +InferResponse::Parameters() const +{ + return parameters_; +} + bool InferResponse::HasError() { @@ -83,6 +91,7 @@ InferResponse::SaveToSharedMemory( response_shm_ptr->is_error_set = false; shm_handle_ = response_shm_.handle_; response_shm_ptr->is_last_response = is_last_response_; + response_shm_ptr->id = id_; // Only save the output tensors to shared memory when the inference response // doesn't have error. @@ -105,7 +114,9 @@ InferResponse::SaveToSharedMemory( tensor_handle_shm_ptr[j] = output_tensor->ShmHandle(); j++; } - response_shm_ptr->id = id_; + + parameters_shm_ = PbString::Create(shm_pool, parameters_); + response_shm_ptr->parameters = parameters_shm_->ShmHandle(); } } @@ -143,6 +154,8 @@ InferResponse::LoadFromSharedMemory( std::shared_ptr pb_error; std::vector> output_tensors; + std::shared_ptr parameters_shm; + std::string parameters; // If the error field is set, do not load output tensors from shared memory. if (response_shm_ptr->has_error && response_shm_ptr->is_error_set) { @@ -154,26 +167,35 @@ InferResponse::LoadFromSharedMemory( bi::managed_external_buffer::handle_t* tensor_handle_shm = reinterpret_cast( response_shm.data_.get() + sizeof(ResponseShm)); + { #ifdef TRITON_PB_STUB - // Need to acquire the GIL to avoid hangs. - py::gil_scoped_acquire acquire; + // PbTensor::LoadFromSharedMemory() will construct Python objects if + // called from pb_stub, which requires holding the GIL. + py::gil_scoped_acquire acquire; #endif - for (size_t idx = 0; idx < requested_output_count; ++idx) { - std::shared_ptr pb_tensor = PbTensor::LoadFromSharedMemory( - shm_pool, tensor_handle_shm[idx], open_cuda_handle); - output_tensors.emplace_back(std::move(pb_tensor)); + for (size_t idx = 0; idx < requested_output_count; ++idx) { + std::shared_ptr pb_tensor = PbTensor::LoadFromSharedMemory( + shm_pool, tensor_handle_shm[idx], open_cuda_handle); + output_tensors.emplace_back(std::move(pb_tensor)); + } } + + parameters_shm = std::move( + PbString::LoadFromSharedMemory(shm_pool, response_shm_ptr->parameters)); + parameters = parameters_shm->String(); } return std::unique_ptr(new InferResponse( response_shm, output_tensors, pb_error, - response_shm_ptr->is_last_response, response_shm_ptr->id)); + response_shm_ptr->is_last_response, response_shm_ptr->id, parameters_shm, + parameters)); } InferResponse::InferResponse( AllocatedSharedMemory& response_shm, std::vector>& output_tensors, - std::shared_ptr& pb_error, const bool is_last_response, void* id) + std::shared_ptr& pb_error, const bool is_last_response, void* id, + std::shared_ptr& parameters_shm, std::string& parameters) { response_shm_ = std::move(response_shm); output_tensors_ = std::move(output_tensors); @@ -181,6 +203,8 @@ InferResponse::InferResponse( shm_handle_ = response_shm_.handle_; id_ = id; is_last_response_ = is_last_response; + parameters_shm_ = std::move(parameters_shm); + parameters_ = std::move(parameters); } std::shared_ptr& @@ -211,6 +235,10 @@ InferResponse::Send( std::vector, void*>>& output_buffers, const std::set& requested_output_names) { +#ifdef TRITON_ENABLE_GPU + static bool log_warning = true; +#endif // TRITON_ENABLE_GPU + std::shared_ptr response_error = WrapTritonErrorInSharedPtr(nullptr); std::unique_ptr response_error_handling; @@ -243,8 +271,8 @@ InferResponse::Send( }); if (HasError()) { - *response_error = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, Error()->Message().c_str()); + *response_error = + TRITONSERVER_ErrorNew(Error()->Code(), Error()->Message().c_str()); return; } @@ -270,11 +298,12 @@ InferResponse::Send( static_cast(output_tensor->TritonDtype()), output_tensor->Dims().data(), output_tensor->Dims().size())); - void* buffer; + void* triton_output_buffer; SET_ERROR_AND_RETURN( - response_error, TRITONBACKEND_OutputBuffer( - response_output, &buffer, output_tensor->ByteSize(), - &actual_memory_type, &actual_memory_type_id)); + response_error, + TRITONBACKEND_OutputBuffer( + response_output, &triton_output_buffer, output_tensor->ByteSize(), + &actual_memory_type, &actual_memory_type_id)); bool cuda_used = false; TRITONSERVER_BufferAttributes* output_buffer_attributes; @@ -286,6 +315,40 @@ InferResponse::Send( if (src_memory_type == TRITONSERVER_MEMORY_GPU && actual_memory_type == TRITONSERVER_MEMORY_GPU) { #ifdef TRITON_ENABLE_GPU + // Check if the triton-provided output buffer is using CUDA shared memory + // pool. If not, try to allocate a new buffer from the pool. + void* buffer = triton_output_buffer; + BackendMemory* backend_memory; + std::unique_ptr lbackend_memory; + std::unique_ptr& cuda_pool = + shm_pool->GetCUDAMemoryPoolManager(); + if (cuda_pool->UseCudaSharedPool(src_memory_type_id)) { + try { + if (!IsUsingCUDAPool( + cuda_pool, actual_memory_type_id, triton_output_buffer)) { + THROW_IF_TRITON_ERROR(BackendMemory::Create( + reinterpret_cast( + shm_pool->GetCUDAMemoryPoolManager() + ->TritonMemoryManager()), + BackendMemory::AllocationType::GPU_POOL, actual_memory_type_id, + output_tensor->ByteSize(), &backend_memory)); + lbackend_memory.reset(backend_memory); + buffer = lbackend_memory->MemoryPtr(); + } + } + catch (const PythonBackendException& pb_exception) { + if (log_warning) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string("Failed to allocate memory from CUDA memory pool " + "for output tensor: ") + + pb_exception.what() + + std::string(", will use CUDA IPC for GPU output transfer.")) + .c_str()); + } + log_warning = false; + } + } cudaIpcMemHandle_t* cuda_ipc_mem_handle_p; SET_ERROR_AND_RETURN( response_error, @@ -309,8 +372,13 @@ InferResponse::Send( output_tensor->ByteSize(), reinterpret_cast(buffer), true /* copy_gpu */)); } + + if (lbackend_memory != nullptr) { + output_buffer->SetBackendMemory(std::move(lbackend_memory)); + } gpu_buffer_helper.AddBuffer(output_buffer->ShmHandle()); - output_buffers.push_back({std::move(output_buffer), buffer}); + output_buffers.push_back( + {std::move(output_buffer), triton_output_buffer}); #endif } @@ -325,7 +393,8 @@ InferResponse::Send( output_tensor->ByteSize(), nullptr /* data ptr */)); gpu_buffer_helper.AddBuffer(output_buffer->ShmHandle()); - output_buffers.push_back({std::move(output_buffer), buffer}); + output_buffers.push_back( + {std::move(output_buffer), triton_output_buffer}); } if (src_memory_type != TRITONSERVER_MEMORY_GPU) { @@ -334,13 +403,46 @@ InferResponse::Send( CopyBuffer( "Failed to copy the output tensor to buffer.", src_memory_type, src_memory_type_id, actual_memory_type, actual_memory_type_id, - output_tensor->ByteSize(), output_tensor->DataPtr(), buffer, - reinterpret_cast(cuda_stream), &cuda_used)); + output_tensor->ByteSize(), output_tensor->DataPtr(), + triton_output_buffer, reinterpret_cast(cuda_stream), + &cuda_used)); } cuda_copy |= cuda_used; } + if (!parameters_.empty()) { + triton::common::TritonJson::Value param; + THROW_IF_TRITON_ERROR( + param.Parse(parameters_.c_str(), parameters_.length())); + std::vector param_keys; + THROW_IF_TRITON_ERROR(param.Members(¶m_keys)); + for (const auto& key : param_keys) { + triton::common::TritonJson::Value value; + if (!param.Find(key.c_str(), &value)) { + throw PythonBackendException("Unexpected missing key on parameters"); + } + if (value.IsString()) { + std::string string_value; + THROW_IF_TRITON_ERROR(value.AsString(&string_value)); + THROW_IF_TRITON_ERROR(TRITONBACKEND_ResponseSetStringParameter( + response, key.c_str(), string_value.c_str())); + } else if (value.IsInt()) { + int64_t int_value = 0; + THROW_IF_TRITON_ERROR(value.AsInt(&int_value)); + THROW_IF_TRITON_ERROR(TRITONBACKEND_ResponseSetIntParameter( + response, key.c_str(), int_value)); + } else if (value.IsBool()) { + bool bool_value = false; + THROW_IF_TRITON_ERROR(value.AsBool(&bool_value)); + THROW_IF_TRITON_ERROR(TRITONBACKEND_ResponseSetBoolParameter( + response, key.c_str(), bool_value)); + } else { + throw PythonBackendException("Unsupported value type on parameters"); + } + } + } + #ifdef TRITON_ENABLE_GPU if (cuda_copy) { cudaStreamSynchronize(reinterpret_cast(cuda_stream)); diff --git a/src/infer_response.h b/src/infer_response.h index bdf31bb4..ab8eb68a 100644 --- a/src/infer_response.h +++ b/src/infer_response.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -38,6 +38,7 @@ namespace triton { namespace backend { namespace python { struct ResponseShm { uint32_t outputs_size; + bi::managed_external_buffer::handle_t parameters; bi::managed_external_buffer::handle_t error; bool has_error; // Indicates whether this error has a message or not. @@ -72,9 +73,10 @@ class InferResponse { public: InferResponse( const std::vector>& output_tensors, - std::shared_ptr error = nullptr, + std::shared_ptr error = nullptr, std::string parameters = "", const bool is_last_response = true, void* id = nullptr); std::vector>& OutputTensors(); + const std::string& Parameters() const; // JSON serializable unless empty void SaveToSharedMemory( std::unique_ptr& shm_pool, bool copy_gpu = true); static std::unique_ptr LoadFromSharedMemory( @@ -116,8 +118,8 @@ class InferResponse { InferResponse( AllocatedSharedMemory& response_shm, std::vector>& output_tensors, - std::shared_ptr& pb_error, const bool is_last_response, - void* id); + std::shared_ptr& pb_error, const bool is_last_response, void* id, + std::shared_ptr& parameters_shm, std::string& parameters); std::vector> output_tensors_; std::shared_ptr error_; @@ -128,6 +130,9 @@ class InferResponse { bool is_last_response_; // Representing the request id that the response was created from. void* id_; + + std::shared_ptr parameters_shm_; + std::string parameters_; }; }}} // namespace triton::backend::python diff --git a/src/infer_trace.cc b/src/infer_trace.cc new file mode 100644 index 00000000..50645dcc --- /dev/null +++ b/src/infer_trace.cc @@ -0,0 +1,101 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "infer_trace.h" + +namespace triton { namespace backend { namespace python { + +InferenceTrace::InferenceTrace(const InferenceTrace& rhs) +{ + triton_trace_ = rhs.triton_trace_; + trace_context_ = rhs.trace_context_; +} + +InferenceTrace& +InferenceTrace::operator=(const InferenceTrace& rhs) +{ + triton_trace_ = rhs.triton_trace_; + trace_context_ = rhs.trace_context_; + return *this; +} + +InferenceTrace::InferenceTrace(std::unique_ptr& trace_shm) +{ + triton_trace_ = trace_shm->triton_trace_; + trace_context_ = trace_shm->trace_context_; +} + +void +InferenceTrace::SaveToSharedMemory( + std::unique_ptr& shm_pool) +{ + AllocatedSharedMemory infer_trace_shm = + shm_pool->Construct(); + infer_trace_shm_ptr_ = infer_trace_shm.data_.get(); + + infer_trace_shm_ptr_->triton_trace = triton_trace_; + + std::unique_ptr trace_context_shm = + PbString::Create(shm_pool, trace_context_); + + infer_trace_shm_ptr_->trace_context_shm_handle = + trace_context_shm->ShmHandle(); + + // Save the references to shared memory. + trace_context_shm_ = std::move(trace_context_shm); + infer_trace_shm_ = std::move(infer_trace_shm); + shm_handle_ = infer_trace_shm_.handle_; +} + +std::unique_ptr +InferenceTrace::LoadFromSharedMemory( + std::unique_ptr& shm_pool, + bi::managed_external_buffer::handle_t handle) +{ + AllocatedSharedMemory infer_trace_shm = + shm_pool->Load(handle); + InferenceTraceShm* infer_trace_shm_ptr = infer_trace_shm.data_.get(); + + std::unique_ptr trace_context_shm = PbString::LoadFromSharedMemory( + shm_pool, infer_trace_shm_ptr->trace_context_shm_handle); + + return std::unique_ptr( + new InferenceTrace(infer_trace_shm, trace_context_shm)); +} + +InferenceTrace::InferenceTrace( + AllocatedSharedMemory& infer_trace_shm, + std::unique_ptr& trace_context_shm) + : infer_trace_shm_(std::move(infer_trace_shm)), + trace_context_shm_(std::move(trace_context_shm)) +{ + infer_trace_shm_ptr_ = infer_trace_shm_.data_.get(); + shm_handle_ = infer_trace_shm_.handle_; + triton_trace_ = infer_trace_shm_ptr_->triton_trace; + trace_context_ = trace_context_shm_->String(); +} + +}}}; // namespace triton::backend::python diff --git a/src/infer_trace.h b/src/infer_trace.h new file mode 100644 index 00000000..aac9137f --- /dev/null +++ b/src/infer_trace.h @@ -0,0 +1,90 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include + +#include "pb_string.h" +#include "pb_utils.h" + +namespace triton { namespace backend { namespace python { + +struct InferenceTraceShm { + bi::managed_external_buffer::handle_t trace_context_shm_handle; + // The address of the 'TRITONSERVER_InferTrace' object. + void* triton_trace; +}; + +// +// Inference Trace +// +class InferenceTrace { + public: + InferenceTrace(void* triton_trace, const std::string& ctxt) + : triton_trace_(triton_trace), trace_context_(ctxt) + { + } + InferenceTrace() : triton_trace_(nullptr), trace_context_("") {} + InferenceTrace(const InferenceTrace& rhs); + InferenceTrace(std::unique_ptr& trace_shm); + InferenceTrace& operator=(const InferenceTrace& rhs); + /// Save InferenceTrace object to shared memory. + /// \param shm_pool Shared memory pool to save the InferenceTrace object. + void SaveToSharedMemory(std::unique_ptr& shm_pool); + + /// Create a InferenceTrace object from shared memory. + /// \param shm_pool Shared memory pool + /// \param handle Shared memory handle of the InferenceTrace. + /// \return Returns the InferenceTrace in the specified handle + /// location. + static std::unique_ptr LoadFromSharedMemory( + std::unique_ptr& shm_pool, + bi::managed_external_buffer::handle_t handle); + + void* TritonTrace() { return triton_trace_; } + const std::string& Context() const { return trace_context_; } + + bi::managed_external_buffer::handle_t ShmHandle() { return shm_handle_; } + + private: + // The private constructor for creating a InferenceTrace object from shared + // memory. + InferenceTrace( + AllocatedSharedMemory& infer_trace_shm, + std::unique_ptr& trace_context_shm); + + void* triton_trace_; + std::string trace_context_; + + // Shared Memory Data Structures + AllocatedSharedMemory infer_trace_shm_; + InferenceTraceShm* infer_trace_shm_ptr_; + bi::managed_external_buffer::handle_t shm_handle_; + std::unique_ptr trace_context_shm_; +}; + +}}}; // namespace triton::backend::python diff --git a/src/ipc_message.cc b/src/ipc_message.cc index ea1dc5b0..2fa13ba3 100644 --- a/src/ipc_message.cc +++ b/src/ipc_message.cc @@ -56,6 +56,21 @@ IPCMessage::Create( new IPCMessage(ipc_message_shm, response_mutex_shm, response_cond_shm)); } +std::unique_ptr +IPCMessage::Create( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& message_handle) +{ + return std::unique_ptr( + new IPCMessage(ipc_message_shm, message_handle)); +} + +AllocatedSharedMemory& +IPCMessage::GetAllocatedSharedMemory() +{ + return ipc_message_shm_; +} + std::unique_ptr IPCMessage::LoadFromSharedMemory( std::unique_ptr& shm_pool, @@ -133,4 +148,12 @@ IPCMessage::IPCMessage( ipc_message_handle_ = ipc_message_shm_.handle_; } +IPCMessage::IPCMessage( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& handle) +{ + ipc_message_handle_ = handle; + ipc_message_shm_ptr_ = ipc_message_shm; +} + }}}; // namespace triton::backend::python diff --git a/src/ipc_message.h b/src/ipc_message.h index 7040f2b4..c0fab3a3 100644 --- a/src/ipc_message.h +++ b/src/ipc_message.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -41,18 +41,21 @@ typedef enum PYTHONSTUB_commandtype_enum { PYTHONSTUB_ExecuteResponse, PYTHONSTUB_InitializeRequest, PYTHONSTUB_InitializeResponse, + PYTHONSTUB_CUDAPoolInitializeRequest, PYTHONSTUB_FinalizeRequest, PYTHONSTUB_FinalizeResponse, PYTHONSTUB_LoadGPUBuffers, PYTHONSTUB_InferExecRequest, PYTHONSTUB_InferStreamExecRequest, PYTHONSTUB_InferExecResponse, + PYTHONSTUB_InferStreamExecResponse, PYTHONSTUB_ResponseSend, PYTHONSTUB_ResponseClose, PYTHONSTUB_AutoCompleteRequest, PYTHONSTUB_AutoCompleteResponse, PYTHONSTUB_LogRequest, - PYTHONSTUB_CleanupRequest, + PYTHONSTUB_BLSDecoupledInferPayloadCleanup, + PYTHONSTUB_DecoupledResponseFactoryCleanup, PYTHONSTUB_MetricFamilyRequestNew, PYTHONSTUB_MetricFamilyRequestDelete, PYTHONSTUB_MetricRequestNew, @@ -60,9 +63,12 @@ typedef enum PYTHONSTUB_commandtype_enum { PYTHONSTUB_MetricRequestValue, PYTHONSTUB_MetricRequestIncrement, PYTHONSTUB_MetricRequestSet, + PYTHONSTUB_MetricRequestObserve, PYTHONSTUB_LoadModelRequest, PYTHONSTUB_UnloadModelRequest, - PYTHONSTUB_ModelReadinessRequest + PYTHONSTUB_ModelReadinessRequest, + PYTHONSTUB_IsRequestCancelled, + PYTHONSTUB_CancelBLSInferRequest } PYTHONSTUB_CommandType; /// @@ -92,6 +98,10 @@ class IPCMessage { static std::unique_ptr Create( const std::unique_ptr& shm_pool, bool inline_response); + + static std::unique_ptr Create( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& message_handle); static std::unique_ptr LoadFromSharedMemory( std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t message_handle); @@ -103,6 +113,7 @@ class IPCMessage { bi::interprocess_mutex* ResponseMutex(); bi::managed_external_buffer::handle_t& Args(); bi::managed_external_buffer::handle_t ShmHandle(); + AllocatedSharedMemory& GetAllocatedSharedMemory(); private: AllocatedSharedMemory ipc_message_shm_; @@ -124,6 +135,10 @@ class IPCMessage { AllocatedSharedMemory& ipc_message_shm, AllocatedSharedMemory& response_mutex_shm, AllocatedSharedMemory& response_cond_shm); + + IPCMessage( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& handle); }; }}}; // namespace triton::backend::python diff --git a/src/memory_manager.cc b/src/memory_manager.cc index 23ac99be..716dee9e 100644 --- a/src/memory_manager.cc +++ b/src/memory_manager.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -33,29 +33,23 @@ namespace triton { namespace backend { namespace python { #ifdef TRITON_ENABLE_GPU -GPUMemoryRecord::GPUMemoryRecord(void* ptr) +BackendMemoryRecord::BackendMemoryRecord( + std::unique_ptr backend_memory) + : backend_memory_(std::move(backend_memory)) { - ptr_ = ptr; release_callback_ = [](void* ptr) { - cudaError_t err = cudaFree(ptr); - if (err != cudaSuccess) { - LOG_MESSAGE( - TRITONSERVER_LOG_ERROR, - (std::string("Failed to free the allocated cuda memory. error: ") + - cudaGetErrorString(err)) - .c_str()); - } + // Do nothing. The backend_memory_ will be destroyed in the destructor. }; } void* -GPUMemoryRecord::MemoryId() +BackendMemoryRecord::MemoryId() { - return ptr_; + return reinterpret_cast(backend_memory_->MemoryPtr()); } const std::function& -GPUMemoryRecord::ReleaseCallback() +BackendMemoryRecord::ReleaseCallback() { return release_callback_; } @@ -101,6 +95,7 @@ MemoryManager::QueueMonitorThread() // Call the release callback. it->second->ReleaseCallback()(it->second->MemoryId()); + // it->second.reset(); records_.erase(it); } } diff --git a/src/memory_manager.h b/src/memory_manager.h index 3ea6cc12..5b7e35f5 100644 --- a/src/memory_manager.h +++ b/src/memory_manager.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -33,6 +33,7 @@ #include "message_queue.h" #include "triton/backend/backend_common.h" +#include "triton/backend/backend_memory.h" #include "triton/core/tritonserver.h" #ifdef TRITON_ENABLE_GPU @@ -46,17 +47,19 @@ class MemoryRecord { public: virtual const std::function& ReleaseCallback() = 0; virtual void* MemoryId() = 0; + virtual ~MemoryRecord() = default; }; #ifdef TRITON_ENABLE_GPU -class GPUMemoryRecord : public MemoryRecord { +class BackendMemoryRecord : public MemoryRecord { public: - GPUMemoryRecord(void* ptr); + BackendMemoryRecord(std::unique_ptr backend_memory); const std::function& ReleaseCallback() override; void* MemoryId() override; + ~BackendMemoryRecord() { backend_memory_.reset(); } private: - void* ptr_; + std::unique_ptr backend_memory_; std::function release_callback_; }; #endif diff --git a/src/message_queue.h b/src/message_queue.h index e9c47afd..06661c66 100644 --- a/src/message_queue.h +++ b/src/message_queue.h @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -32,14 +32,19 @@ #include #include +#include "pb_utils.h" #include "shm_manager.h" +#ifdef TRITON_PB_STUB +#include "pb_stub_log.h" +#endif namespace triton { namespace backend { namespace python { namespace bi = boost::interprocess; /// Struct holding the representation of a message queue inside the shared /// memory. -/// \param size Total size of the message queue. +/// \param size Total size of the message queue. Considered invalid after +/// MessageQueue::LoadFromSharedMemory. Check DLIS-8378 for additional details. /// \param mutex Handle of the mutex variable protecting index. /// \param index Used element index. /// \param sem_empty Semaphore object counting the number of empty buffer slots. @@ -110,7 +115,22 @@ class MessageQueue { { bi::scoped_lock lock{*MutexMutable()}; - Buffer()[Head()] = message; + int head_idx = Head(); + // Additional check to avoid out of bounds read/write. Check DLIS-8378 for + // additional details. + if (head_idx < 0 || static_cast(head_idx) >= Size()) { + std::string error_msg = + "internal error: message queue head index out of bounds. Expects " + "positive integer less than the size of message queue " + + std::to_string(Size()) + " but got " + std::to_string(head_idx); +#ifdef TRITON_PB_STUB + LOG_ERROR << error_msg; +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str()); +#endif + return; + } + Buffer()[head_idx] = message; HeadIncrement(); } SemFullMutable()->post(); @@ -145,7 +165,22 @@ class MessageQueue { } success = true; - Buffer()[Head()] = message; + int head_idx = Head(); + // Additional check to avoid out of bounds read/write. Check DLIS-8378 for + // additional details. + if (head_idx < 0 || static_cast(head_idx) >= Size()) { + std::string error_msg = + "internal error: message queue head index out of bounds. Expects " + "positive integer less than the size of message queue " + + std::to_string(Size()) + " but got " + std::to_string(head_idx); +#ifdef TRITON_PB_STUB + LOG_ERROR << error_msg; +#else + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str()); +#endif + return; + } + Buffer()[head_idx] = message; HeadIncrement(); } SemFullMutable()->post(); @@ -244,7 +279,7 @@ class MessageQueue { } private: - std::size_t& Size() { return mq_shm_ptr_->size; } + uint32_t Size() { return size_; } const bi::interprocess_mutex& Mutex() { return mq_shm_ptr_->mutex; } bi::interprocess_mutex* MutexMutable() { return &(mq_shm_ptr_->mutex); } int& Head() { return mq_shm_ptr_->head; } @@ -273,6 +308,7 @@ class MessageQueue { MessageQueueShm* mq_shm_ptr_; T* mq_buffer_shm_ptr_; bi::managed_external_buffer::handle_t mq_handle_; + uint32_t size_; /// Create/load a Message queue. /// \param mq_shm Message queue representation in shared memory. @@ -284,6 +320,7 @@ class MessageQueue { mq_buffer_shm_ptr_ = mq_buffer_shm_.data_.get(); mq_shm_ptr_ = mq_shm_.data_.get(); mq_handle_ = mq_shm_.handle_; + size_ = mq_shm_ptr_->size; } }; }}} // namespace triton::backend::python diff --git a/src/metric.cc b/src/metric.cc index f67c55bf..4c055910 100644 --- a/src/metric.cc +++ b/src/metric.cc @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -32,9 +32,12 @@ namespace triton { namespace backend { namespace python { -Metric::Metric(const std::string& labels, void* metric_family_address) - : labels_(labels), operation_value_(0), metric_address_(nullptr), - metric_family_address_(metric_family_address), is_cleared_(false) +Metric::Metric( + const std::string& labels, std::optional> buckets, + void* metric_family_address) + : labels_(labels), buckets_(buckets), operation_value_(0), + metric_address_(nullptr), metric_family_address_(metric_family_address), + is_cleared_(false) { #ifdef TRITON_PB_STUB SendCreateMetricRequest(); @@ -62,6 +65,20 @@ Metric::SaveToSharedMemory(std::unique_ptr& shm_pool) custom_metric_shm_ptr_->metric_family_address = metric_family_address_; custom_metric_shm_ptr_->metric_address = metric_address_; + // Histogram specific case + if (buckets_.has_value()) { + auto buckets_size = buckets_.value().size() * sizeof(double); + std::unique_ptr buckets_shm = PbMemory::Create( + shm_pool, TRITONSERVER_MemoryType::TRITONSERVER_MEMORY_CPU, 0, + buckets_size, reinterpret_cast(buckets_.value().data()), + false /* copy_gpu */); + custom_metric_shm_ptr_->buckets_shm_handle = buckets_shm->ShmHandle(); + buckets_shm_ = std::move(buckets_shm); + } else { + custom_metric_shm_ptr_->buckets_shm_handle = 0; + buckets_shm_ = nullptr; + } + // Save the references to shared memory. custom_metric_shm_ = std::move(custom_metric_shm); labels_shm_ = std::move(labels_shm); @@ -80,17 +97,40 @@ Metric::LoadFromSharedMemory( std::unique_ptr labels_shm = PbString::LoadFromSharedMemory( shm_pool, custom_metric_shm_ptr->labels_shm_handle); - return std::unique_ptr(new Metric(custom_metric_shm, labels_shm)); + std::unique_ptr buckets_shm = nullptr; + if (custom_metric_shm_ptr->buckets_shm_handle != 0) { + buckets_shm = PbMemory::LoadFromSharedMemory( + shm_pool, custom_metric_shm_ptr->buckets_shm_handle, + false /* open_cuda_handle */); + } + + return std::unique_ptr( + new Metric(custom_metric_shm, labels_shm, buckets_shm)); } Metric::Metric( AllocatedSharedMemory& custom_metric_shm, - std::unique_ptr& labels_shm) + std::unique_ptr& labels_shm, + std::unique_ptr& buckets_shm) : custom_metric_shm_(std::move(custom_metric_shm)), - labels_shm_(std::move(labels_shm)) + labels_shm_(std::move(labels_shm)), buckets_shm_(std::move(buckets_shm)) { custom_metric_shm_ptr_ = custom_metric_shm_.data_.get(); + + // FIXME: This constructor is called during each + // set/increment/observe/get_value call. It only needs the pointers. labels_ = labels_shm_->String(); + if (buckets_shm_ != nullptr) { // Histogram + size_t bucket_size = buckets_shm_->ByteSize() / sizeof(double); + std::vector buckets; + buckets.reserve(bucket_size); + for (size_t i = 0; i < bucket_size; ++i) { + buckets.emplace_back( + reinterpret_cast(buckets_shm_->DataPtr())[i]); + } + buckets_ = std::move(buckets); + } + operation_value_ = custom_metric_shm_ptr_->operation_value; metric_family_address_ = custom_metric_shm_ptr_->metric_family_address; metric_address_ = custom_metric_shm_ptr_->metric_address; @@ -127,6 +167,7 @@ Metric::SendCreateMetricRequest() void Metric::SendIncrementRequest(const double& value) { + py::gil_scoped_release release; try { CheckIfCleared(); std::unique_ptr& stub = Stub::GetOrCreateInstance(); @@ -161,6 +202,25 @@ Metric::SendSetValueRequest(const double& value) } } +void +Metric::SendObserveRequest(const double& value) +{ + py::gil_scoped_release release; + try { + CheckIfCleared(); + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + operation_value_ = value; + SaveToSharedMemory(stub->ShmPool()); + AllocatedSharedMemory custom_metrics_shm; + stub->SendMessage( + custom_metrics_shm, PYTHONSTUB_MetricRequestObserve, shm_handle_); + } + catch (const PythonBackendException& pb_exception) { + throw PythonBackendException( + "Failed to observe metric value: " + std::string(pb_exception.what())); + } +} + double Metric::SendGetValueRequest() { @@ -222,14 +282,35 @@ Metric::InitializeTritonMetric() { std::vector labels_params; ParseLabels(labels_params, labels_); + TRITONSERVER_MetricKind kind; + THROW_IF_TRITON_ERROR(TRITONSERVER_GetMetricFamilyKind( + reinterpret_cast(metric_family_address_), + &kind)); + TRITONSERVER_MetricArgs* args = nullptr; + switch (kind) { + case TRITONSERVER_METRIC_KIND_COUNTER: + case TRITONSERVER_METRIC_KIND_GAUGE: + break; + case TRITONSERVER_METRIC_KIND_HISTOGRAM: { + const std::vector& buckets = buckets_.value(); + THROW_IF_TRITON_ERROR(TRITONSERVER_MetricArgsNew(&args)); + THROW_IF_TRITON_ERROR(TRITONSERVER_MetricArgsSetHistogram( + args, buckets.data(), buckets.size())); + break; + } + default: + break; + } + TRITONSERVER_Metric* triton_metric = nullptr; - THROW_IF_TRITON_ERROR(TRITONSERVER_MetricNew( + THROW_IF_TRITON_ERROR(TRITONSERVER_MetricNewWithArgs( &triton_metric, reinterpret_cast(metric_family_address_), - labels_params.data(), labels_params.size())); + labels_params.data(), labels_params.size(), args)); for (const auto label : labels_params) { TRITONSERVER_ParameterDelete(const_cast(label)); } + THROW_IF_TRITON_ERROR(TRITONSERVER_MetricArgsDelete(args)); return reinterpret_cast(triton_metric); } @@ -262,6 +343,8 @@ Metric::HandleMetricOperation( Increment(operation_value_); } else if (command_type == PYTHONSTUB_MetricRequestSet) { SetValue(operation_value_); + } else if (command_type == PYTHONSTUB_MetricRequestObserve) { + Observe(operation_value_); } else { throw PythonBackendException("Unknown metric operation"); } @@ -281,6 +364,13 @@ Metric::SetValue(const double& value) THROW_IF_TRITON_ERROR(TRITONSERVER_MetricSet(triton_metric, value)); } +void +Metric::Observe(const double& value) +{ + auto triton_metric = reinterpret_cast(metric_address_); + THROW_IF_TRITON_ERROR(TRITONSERVER_MetricObserve(triton_metric, value)); +} + double Metric::GetValue() { diff --git a/src/metric.h b/src/metric.h index 197e8ce9..cd54ca54 100644 --- a/src/metric.h +++ b/src/metric.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -26,9 +26,11 @@ #pragma once +#include #include #include "ipc_message.h" +#include "pb_memory.h" #include "pb_string.h" #include "pb_utils.h" @@ -47,6 +49,8 @@ namespace triton { namespace backend { namespace python { struct MetricShm { // The shared memory handle of the labels in PbString format. bi::managed_external_buffer::handle_t labels_shm_handle; + // The shared memory handle of the buckets in PbMemory format. + bi::managed_external_buffer::handle_t buckets_shm_handle; // The value used for incrementing or setting the metric. double operation_value; // The address of the TRITONSERVER_Metric object. @@ -58,7 +62,10 @@ struct MetricShm { class Metric { public: - Metric(const std::string& labels, void* metric_family_address); + Metric( + const std::string& labels, + std::optional> buckets, + void* metric_family_address); ~Metric(); @@ -97,6 +104,10 @@ class Metric { /// \param value The value to set the metric to. void SendSetValueRequest(const double& value); + /// Send the request to the parent process to observe the value to the metric. + /// \param value The value to set the metric to. + void SendObserveRequest(const double& value); + /// Send the request to the parent process to get the value of the metric. /// \return Returns the value of the metric. double SendGetValueRequest(); @@ -132,6 +143,10 @@ class Metric { /// \param value The value to set the metric to. void SetValue(const double& value); + /// Use Triton C API to sample the observation to the metric. + /// \param value The value to sample observation to the metric. + void Observe(const double& value); + /// Use Triton C API to get the value of the metric. double GetValue(); @@ -146,10 +161,14 @@ class Metric { // The private constructor for creating a Metric object from shared memory. Metric( AllocatedSharedMemory& custom_metric_shm, - std::unique_ptr& labels_shm); + std::unique_ptr& labels_shm, + std::unique_ptr& buckets); // The labels of the metric, which is the identifier of the metric. std::string labels_; + // Monotonically increasing values representing bucket boundaries for creating + // histogram metric. + std::optional> buckets_; // The value used for incrementing or setting the metric. double operation_value_; // The address of the TRITONSERVER_Metric object. @@ -168,6 +187,7 @@ class Metric { MetricShm* custom_metric_shm_ptr_; bi::managed_external_buffer::handle_t shm_handle_; std::unique_ptr labels_shm_; + std::unique_ptr buckets_shm_; }; }}}; // namespace triton::backend::python diff --git a/src/metric_family.cc b/src/metric_family.cc index fb0fb93a..222a0e23 100644 --- a/src/metric_family.cc +++ b/src/metric_family.cc @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -166,19 +166,39 @@ MetricFamily::SendCreateMetricFamilyRequest() } std::shared_ptr -MetricFamily::CreateMetric(const py::object& labels) +MetricFamily::CreateMetric(const py::object& labels, const py::object& buckets) { if (!labels.is_none()) { if (!py::isinstance(labels)) { throw PythonBackendException( - "Failed to create metric. Labels must be a " - "dictionary."); + "Failed to create metric. Labels must be a dictionary."); } } py::module json = py::module_::import("json"); std::string labels_str = std::string(py::str(json.attr("dumps")(labels))); - auto metric = std::make_shared(labels_str, metric_family_address_); + + std::optional> buckets_vec; + if (!buckets.is_none()) { + if (!py::isinstance(buckets)) { + throw PythonBackendException( + "Failed to create metric. Buckets must be a list."); + } + if (kind_ == kCounter || kind_ == kGauge) { + throw PythonBackendException( + "Failed to create metric. Unexpected buckets found."); + } + buckets_vec = buckets.cast>(); + } else { + if (kind_ == kHistogram) { + throw PythonBackendException( + "Failed to create metric. Missing required buckets."); + } + buckets_vec = std::nullopt; + } + + auto metric = + std::make_shared(labels_str, buckets_vec, metric_family_address_); { std::lock_guard lock(metric_map_mu_); metric_map_.insert({metric->MetricAddress(), metric}); @@ -201,10 +221,12 @@ TRITONSERVER_MetricKind MetricFamily::ToTritonServerMetricKind(const MetricKind& kind) { switch (kind) { - case COUNTER: + case kCounter: return TRITONSERVER_METRIC_KIND_COUNTER; - case GAUGE: + case kGauge: return TRITONSERVER_METRIC_KIND_GAUGE; + case kHistogram: + return TRITONSERVER_METRIC_KIND_HISTOGRAM; default: throw PythonBackendException("Unknown metric kind"); } diff --git a/src/metric_family.h b/src/metric_family.h index 04374a68..2b5f86ab 100644 --- a/src/metric_family.h +++ b/src/metric_family.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -97,8 +97,11 @@ class MetricFamily { /// Create a metric from the metric family and store it in the metric map. /// \param labels The labels of the metric. + /// \param buckets Monotonically increasing values representing bucket + /// boundaries for creating histogram metric. /// \return Returns the shared pointer to the created metric. - std::shared_ptr CreateMetric(const py::object& labels); + std::shared_ptr CreateMetric( + const py::object& labels, const py::object& buckets); #else /// Initialize the TRITONSERVER_MetricFamily object. /// \return Returns the address of the TRITONSERVER_MetricFamily object. @@ -128,8 +131,8 @@ class MetricFamily { std::string name_; // The description of the metric family. std::string description_; - // The metric kind of the metric family. Currently only supports GAUGE and - // COUNTER. + // The metric kind of the metric family. Currently only supports GAUGE, + // COUNTER and HISTOGRAM. MetricKind kind_; // The address of the TRITONSERVER_MetricFamily object. void* metric_family_address_; diff --git a/src/pb_bls_cancel.cc b/src/pb_bls_cancel.cc new file mode 100644 index 00000000..4341c037 --- /dev/null +++ b/src/pb_bls_cancel.cc @@ -0,0 +1,93 @@ +// Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "pb_bls_cancel.h" + +#include "pb_stub.h" +#include "pb_stub_log.h" + +namespace triton { namespace backend { namespace python { + +void +PbBLSCancel::SaveToSharedMemory(std::unique_ptr& shm_pool) +{ + cancel_shm_ = shm_pool->Construct(); + new (&(cancel_shm_.data_->mu)) bi::interprocess_mutex; + new (&(cancel_shm_.data_->cv)) bi::interprocess_condition; + cancel_shm_.data_->waiting_on_stub = false; + cancel_shm_.data_->infer_payload_id = infer_playload_id_; + cancel_shm_.data_->is_cancelled = is_cancelled_; +} + +bi::managed_external_buffer::handle_t +PbBLSCancel::ShmHandle() +{ + return cancel_shm_.handle_; +} + +CancelBLSRequestMessage* +PbBLSCancel::ShmPayload() +{ + return cancel_shm_.data_.get(); +} + +void +PbBLSCancel::Cancel() +{ + // Release the GIL. Python objects are not accessed during the check. + py::gil_scoped_release gil_release; + + std::unique_lock lk(mu_); + // The cancelled flag can only move from false to true, not the other way, so + // it is checked on each query until cancelled and then implicitly cached. + if (is_cancelled_) { + return; + } + if (!updating_) { + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + if (!stub->StubToParentServiceActive()) { + LOG_ERROR << "Cannot communicate with parent service"; + return; + } + + stub->EnqueueCancelBLSRequest(this); + updating_ = true; + } + cv_.wait(lk, [this] { return !updating_; }); +} + +void +PbBLSCancel::ReportIsCancelled(bool is_cancelled) +{ + { + std::lock_guard lk(mu_); + is_cancelled_ = is_cancelled; + updating_ = false; + } + cv_.notify_all(); +} + +}}} // namespace triton::backend::python diff --git a/src/pb_bls_cancel.h b/src/pb_bls_cancel.h new file mode 100644 index 00000000..7fdd3fbf --- /dev/null +++ b/src/pb_bls_cancel.h @@ -0,0 +1,63 @@ +// Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include + +#include "pb_utils.h" + +namespace triton { namespace backend { namespace python { + +class PbBLSCancel { + public: + PbBLSCancel(void* infer_playload_id) + : updating_(false), infer_playload_id_(infer_playload_id), + is_cancelled_(false) + { + } + DISALLOW_COPY_AND_ASSIGN(PbBLSCancel); + + void SaveToSharedMemory(std::unique_ptr& shm_pool); + bi::managed_external_buffer::handle_t ShmHandle(); + CancelBLSRequestMessage* ShmPayload(); + + void Cancel(); + void ReportIsCancelled(bool is_cancelled); + + private: + AllocatedSharedMemory cancel_shm_; + + std::mutex mu_; + std::condition_variable cv_; + bool updating_; + + void* infer_playload_id_; + bool is_cancelled_; +}; + +}}}; // namespace triton::backend::python diff --git a/src/pb_cancel.cc b/src/pb_cancel.cc new file mode 100644 index 00000000..da9daf98 --- /dev/null +++ b/src/pb_cancel.cc @@ -0,0 +1,94 @@ +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "pb_cancel.h" + +#include "pb_stub.h" +#include "pb_stub_log.h" + +namespace triton { namespace backend { namespace python { + +void +PbCancel::SaveToSharedMemory(std::unique_ptr& shm_pool) +{ + cancel_shm_ = shm_pool->Construct(); + new (&(cancel_shm_.data_->mu)) bi::interprocess_mutex; + new (&(cancel_shm_.data_->cv)) bi::interprocess_condition; + cancel_shm_.data_->waiting_on_stub = false; + cancel_shm_.data_->response_factory_address = response_factory_address_; + cancel_shm_.data_->request_address = request_address_; + cancel_shm_.data_->is_cancelled = is_cancelled_; +} + +bi::managed_external_buffer::handle_t +PbCancel::ShmHandle() +{ + return cancel_shm_.handle_; +} + +IsCancelledMessage* +PbCancel::ShmPayload() +{ + return cancel_shm_.data_.get(); +} + +bool +PbCancel::IsCancelled() +{ + // Release the GIL. Python objects are not accessed during the check. + py::gil_scoped_release gil_release; + + std::unique_lock lk(mu_); + // The cancelled flag can only move from false to true, not the other way, so + // it is checked on each query until cancelled and then implicitly cached. + if (is_cancelled_) { + return is_cancelled_; + } + if (!updating_) { + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + if (!stub->StubToParentServiceActive()) { + LOG_ERROR << "Cannot communicate with parent service"; + return false; + } + stub->EnqueueIsCancelled(this); + updating_ = true; + } + cv_.wait(lk, [this] { return !updating_; }); + return is_cancelled_; +} + +void +PbCancel::ReportIsCancelled(bool is_cancelled) +{ + { + std::lock_guard lk(mu_); + is_cancelled_ = is_cancelled; + updating_ = false; + } + cv_.notify_all(); +} + +}}} // namespace triton::backend::python diff --git a/src/pb_cancel.h b/src/pb_cancel.h new file mode 100644 index 00000000..3ebf07b5 --- /dev/null +++ b/src/pb_cancel.h @@ -0,0 +1,64 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include + +#include "pb_utils.h" + +namespace triton { namespace backend { namespace python { + +class PbCancel { + public: + PbCancel(intptr_t response_factory_address, intptr_t request_address) + : updating_(false), response_factory_address_(response_factory_address), + request_address_(request_address), is_cancelled_(false) + { + } + DISALLOW_COPY_AND_ASSIGN(PbCancel); + + void SaveToSharedMemory(std::unique_ptr& shm_pool); + bi::managed_external_buffer::handle_t ShmHandle(); + IsCancelledMessage* ShmPayload(); + + bool IsCancelled(); + void ReportIsCancelled(bool is_cancelled); + + private: + AllocatedSharedMemory cancel_shm_; + + std::mutex mu_; + std::condition_variable cv_; + bool updating_; + + intptr_t response_factory_address_; + intptr_t request_address_; + bool is_cancelled_; +}; + +}}}; // namespace triton::backend::python diff --git a/src/pb_env.cc b/src/pb_env.cc index 0b6eb9ec..d9643a62 100644 --- a/src/pb_env.cc +++ b/src/pb_env.cc @@ -26,9 +26,11 @@ #include "pb_env.h" +#ifndef _WIN32 #include #include #include +#endif #include #include @@ -40,6 +42,29 @@ namespace triton { namespace backend { namespace python { +bool +FileExists(std::string& path) +{ + struct stat buffer; + return stat(path.c_str(), &buffer) == 0; +} + +void +LastModifiedTime(const std::string& path, time_t* last_modified_time) +{ + struct stat result; + if (stat(path.c_str(), &result) == 0) { + *last_modified_time = result.st_mtime; + } else { + throw PythonBackendException(std::string( + "LastModifiedTime() failed as file \'" + path + + std::string("\' does not exists."))); + } +} + +// FIXME: [DLIS-5969]: Develop platforom-agnostic functions +// to support custom python environments. +#ifndef _WIN32 void CopySingleArchiveEntry(archive* input_archive, archive* output_archive) { @@ -73,7 +98,6 @@ CopySingleArchiveEntry(archive* input_archive, archive* output_archive) } } - void ExtractTarFile(std::string& archive_path, std::string& dst_path) { @@ -153,27 +177,6 @@ ExtractTarFile(std::string& archive_path, std::string& dst_path) } } -bool -FileExists(std::string& path) -{ - struct stat buffer; - return stat(path.c_str(), &buffer) == 0; -} - -void -LastModifiedTime(const std::string& path, time_t* last_modified_time) -{ - struct stat result; - if (stat(path.c_str(), &result) == 0) { - *last_modified_time = result.st_mtime; - } else { - throw PythonBackendException(std::string( - "LastModifiedTime() failed as file \'" + path + - std::string("\' does not exists."))); - } -} - - void RecursiveDirectoryDelete(const char* dir) { @@ -326,5 +329,6 @@ EnvironmentManager::~EnvironmentManager() { RecursiveDirectoryDelete(base_path_); } +#endif }}} // namespace triton::backend::python diff --git a/src/pb_env.h b/src/pb_env.h index 09890ee8..04e01fa3 100644 --- a/src/pb_env.h +++ b/src/pb_env.h @@ -30,6 +30,11 @@ #include #include +#ifdef WIN32 +#include +#undef PATH_MAX +#define PATH_MAX MAX_PATH +#endif namespace triton { namespace backend { namespace python { void ExtractTarFile(std::string& archive_path, std::string& dst_path); @@ -39,6 +44,7 @@ bool FileExists(std::string& path); // // A class that manages Python environments // +#ifndef _WIN32 class EnvironmentManager { std::map> env_map_; char base_path_[PATH_MAX + 1]; @@ -52,5 +58,6 @@ class EnvironmentManager { std::string ExtractIfNotExtracted(std::string env_path); ~EnvironmentManager(); }; +#endif }}} // namespace triton::backend::python diff --git a/src/pb_error.cc b/src/pb_error.cc index e190af42..0e5d0bd4 100644 --- a/src/pb_error.cc +++ b/src/pb_error.cc @@ -1,4 +1,4 @@ -// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -27,6 +27,13 @@ #include "pb_error.h" namespace triton { namespace backend { namespace python { + +TRITONSERVER_Error_Code +PbError::Code() +{ + return code_; +} + const std::string& PbError::Message() { @@ -43,7 +50,10 @@ void PbError::SaveToSharedMemory(std::unique_ptr& shm_pool) { message_shm_ = PbString::Create(shm_pool, message_); - shm_handle_ = message_shm_->ShmHandle(); + error_shm_ = shm_pool->Construct(); + error_shm_.data_->code = code_; + error_shm_.data_->message_shm_handle = message_shm_->ShmHandle(); + shm_handle_ = error_shm_.handle_; } std::shared_ptr @@ -51,14 +61,25 @@ PbError::LoadFromSharedMemory( std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t shm_handle) { - std::unique_ptr message_shm = - PbString::LoadFromSharedMemory(shm_pool, shm_handle); - return std::shared_ptr(new PbError(message_shm)); + AllocatedSharedMemory error_shm = + shm_pool->Load(shm_handle); + std::unique_ptr message_shm = PbString::LoadFromSharedMemory( + shm_pool, error_shm.data_->message_shm_handle); + + TRITONSERVER_Error_Code code = error_shm.data_->code; + std::string message = message_shm->String(); + + return std::shared_ptr(new PbError( + std::move(message_shm), std::move(error_shm), code, std::move(message))); } -PbError::PbError(std::unique_ptr& message_shm) +PbError::PbError( + std::shared_ptr&& message_shm, + AllocatedSharedMemory&& error_shm, TRITONSERVER_Error_Code code, + std::string&& message) + : message_shm_(std::move(message_shm)), error_shm_(std::move(error_shm)), + code_(code), message_(std::move(message)) { - message_shm_ = std::move(message_shm); - message_ = message_shm_->String(); } + }}} // namespace triton::backend::python diff --git a/src/pb_error.h b/src/pb_error.h index b80546b2..6001459a 100644 --- a/src/pb_error.h +++ b/src/pb_error.h @@ -1,4 +1,4 @@ -// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -32,21 +32,44 @@ #include "pb_utils.h" namespace triton { namespace backend { namespace python { + +struct PbErrorShm { + TRITONSERVER_Error_Code code; + bi::managed_external_buffer::handle_t message_shm_handle; +}; + class PbError { public: - PbError(const std::string& message) : message_(message) {} + PbError( + const std::string& message, + TRITONSERVER_Error_Code code = TRITONSERVER_ERROR_INTERNAL) + : code_(code), message_(message) + { + } + DISALLOW_COPY_AND_ASSIGN(PbError); + + TRITONSERVER_Error_Code Code(); const std::string& Message(); + void SaveToSharedMemory(std::unique_ptr& shm_pool); bi::managed_external_buffer::handle_t ShmHandle(); + static std::shared_ptr LoadFromSharedMemory( std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t handle); - DISALLOW_COPY_AND_ASSIGN(PbError); private: - PbError(std::unique_ptr& pb_error); - std::string message_; + PbError( + std::shared_ptr&& message_shm, + AllocatedSharedMemory&& error_shm, + TRITONSERVER_Error_Code code, std::string&& message); + std::shared_ptr message_shm_; + AllocatedSharedMemory error_shm_; bi::managed_external_buffer::handle_t shm_handle_; + + TRITONSERVER_Error_Code code_; + std::string message_; }; + }}}; // namespace triton::backend::python diff --git a/src/pb_memory.cc b/src/pb_memory.cc index c18bf912..5b678f1a 100644 --- a/src/pb_memory.cc +++ b/src/pb_memory.cc @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -26,6 +26,8 @@ #include "pb_memory.h" +#include + namespace triton { namespace backend { namespace python { std::unique_ptr @@ -35,7 +37,6 @@ PbMemory::Create( uint64_t byte_size, char* data, bool copy_gpu) { size_t requested_byte_size = sizeof(MemoryShm); - if (memory_type == TRITONSERVER_MEMORY_GPU) { #ifdef TRITON_ENABLE_GPU requested_byte_size += sizeof(cudaIpcMemHandle_t); @@ -46,9 +47,10 @@ PbMemory::Create( AllocatedSharedMemory memory_shm = shm_pool->Construct(requested_byte_size); + PbMemory::FillShmData( - memory_type, memory_type_id, byte_size, data, memory_shm.data_.get(), - memory_shm.handle_, copy_gpu); + shm_pool->GetCUDAMemoryPoolManager(), memory_type, memory_type_id, + byte_size, data, memory_shm.data_.get(), memory_shm.handle_, copy_gpu); if (memory_type == TRITONSERVER_MEMORY_CPU) { data = memory_shm.data_.get() + sizeof(MemoryShm); @@ -83,12 +85,14 @@ PbMemory::Create( std::unique_ptr PbMemory::Create( + std::unique_ptr& shm_pool, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, uint64_t byte_size, char* data, char* data_shm, bi::managed_external_buffer::handle_t handle, bool copy_gpu) { PbMemory::FillShmData( - memory_type, memory_type_id, byte_size, data, data_shm, handle, copy_gpu); + shm_pool->GetCUDAMemoryPoolManager(), memory_type, memory_type_id, + byte_size, data, data_shm, handle, copy_gpu); if (memory_type == TRITONSERVER_MEMORY_CPU) { data = data_shm + sizeof(MemoryShm); @@ -176,14 +180,15 @@ PbMemory::CopyBuffer( void PbMemory::FillShmData( + std::unique_ptr& cuda_pool, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, uint64_t byte_size, char* data, char* data_shm, bi::managed_external_buffer::handle_t handle, bool copy_gpu) { char* memory_data_shm = data_shm + sizeof(MemoryShm); MemoryShm* memory_shm_ptr = reinterpret_cast(data_shm); - memory_shm_ptr->is_cuda_handle_set = copy_gpu; memory_shm_ptr->memory_release_id = 0; + bool use_cuda_shared_pool = false; if (memory_type == TRITONSERVER_MEMORY_GPU) { #ifdef TRITON_ENABLE_GPU @@ -193,8 +198,15 @@ PbMemory::FillShmData( THROW_IF_CUDA_ERROR(cudaIpcGetMemHandle( reinterpret_cast(memory_data_shm), data)); } + if (cuda_pool->UseCudaSharedPool(memory_type_id) && + IsUsingCUDAPool(cuda_pool, memory_type_id, data)) { + use_cuda_shared_pool = true; + memory_shm_ptr->cuda_pool_offset = + data - + reinterpret_cast(cuda_pool->CUDAPoolAddress(memory_type_id)); + } } -#endif +#endif // TRITON_ENABLE_GPU } else { if (data != nullptr) { std::copy(data, data + byte_size, memory_data_shm); @@ -204,45 +216,69 @@ PbMemory::FillShmData( memory_shm_ptr->byte_size = byte_size; memory_shm_ptr->memory_type_id = memory_type_id; memory_shm_ptr->memory_type = memory_type; + memory_shm_ptr->use_cuda_shared_pool = use_cuda_shared_pool; } std::unique_ptr PbMemory::LoadFromSharedMemory( + std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t handle, char* data_shm, bool open_cuda_handle) { MemoryShm* memory_shm_ptr = reinterpret_cast(data_shm); char* memory_data_shm = data_shm + sizeof(MemoryShm); - char* data_ptr = nullptr; bool opened_cuda_ipc_handle = false; if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU && open_cuda_handle) { #ifdef TRITON_ENABLE_GPU - cudaIpcMemHandle_t* cuda_handle = - reinterpret_cast(memory_data_shm); + if (memory_shm_ptr->use_cuda_shared_pool) { + // When CUDA shared memory pool is used, the stub will retrieve the + // data pointer using the offset. + data_ptr = + (reinterpret_cast( + shm_pool->GetCUDAMemoryPoolManager()->CUDAPoolAddress( + memory_shm_ptr->memory_type_id)) + + memory_shm_ptr->cuda_pool_offset); + } else { + cudaIpcMemHandle_t* cuda_handle = + reinterpret_cast(memory_data_shm); - // The pointer opened by the cudaIpcOpenMemHandle will refer to the base - // address. We need to manually correct the offset. - void* data_ptr_base; - CUDAHandler& cuda_handler = CUDAHandler::getInstance(); - cuda_handler.OpenCudaHandle( - memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); + // The pointer opened by the cudaIpcOpenMemHandle will refer to the base + // address. We need to manually correct the offset. + void* data_ptr_base; + CUDAHandler& cuda_handler = CUDAHandler::getInstance(); + cuda_handler.OpenCudaHandle( + memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); - data_ptr = - (reinterpret_cast(data_ptr_base) + - memory_shm_ptr->gpu_pointer_offset); - opened_cuda_ipc_handle = true; -#endif + data_ptr = + (reinterpret_cast(data_ptr_base) + + memory_shm_ptr->gpu_pointer_offset); + opened_cuda_ipc_handle = true; + } + +#endif // TRITON_ENABLE_GPU } else { data_ptr = memory_data_shm; } + + // This check only validates CPU shared memory access. + if (memory_shm_ptr->memory_type != TRITONSERVER_MEMORY_GPU && + (data_ptr + memory_shm_ptr->byte_size > + (char*)shm_pool->GetBaseAddress() + shm_pool->GetCurrentCapacity())) { + std::ostringstream oss; + oss << "0x" << std::hex + << (reinterpret_cast(data_ptr) + memory_shm_ptr->byte_size); + throw PythonBackendException( + std::string("Attempted to access out of bounds memory address ") + + oss.str()); + } + return std::unique_ptr(new PbMemory( data_shm, data_ptr, handle, opened_cuda_ipc_handle /* opened_cuda_ipc_handle */)); } - std::unique_ptr PbMemory::LoadFromSharedMemory( std::unique_ptr& shm_pool, @@ -258,26 +294,48 @@ PbMemory::LoadFromSharedMemory( if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU) { if (memory_shm_ptr->byte_size > 0 && open_cuda_handle) { #ifdef TRITON_ENABLE_GPU - cudaIpcMemHandle_t* cuda_handle = - reinterpret_cast(memory_data_shm); - - // The pointer opened by the cudaIpcOpenMemHandle will refer to the base - // address. We need to manually correct the offset. - - void* data_ptr_base; - CUDAHandler& cuda_handler = CUDAHandler::getInstance(); - cuda_handler.OpenCudaHandle( - memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); - - data_ptr = - (reinterpret_cast(data_ptr_base) + - memory_shm_ptr->gpu_pointer_offset); - opened_cuda_ipc_handle = true; + if (memory_shm_ptr->use_cuda_shared_pool) { + // When CUDA shared memory pool is used, the stub will retrieve the + // data pointer using the offset. + data_ptr = + (reinterpret_cast( + shm_pool->GetCUDAMemoryPoolManager()->CUDAPoolAddress( + memory_shm_ptr->memory_type_id)) + + memory_shm_ptr->cuda_pool_offset); + } else { + cudaIpcMemHandle_t* cuda_handle = + reinterpret_cast(memory_data_shm); + + // The pointer opened by the cudaIpcOpenMemHandle will refer to the base + // address. We need to manually correct the offset. + void* data_ptr_base; + CUDAHandler& cuda_handler = CUDAHandler::getInstance(); + cuda_handler.OpenCudaHandle( + memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); + + data_ptr = + (reinterpret_cast(data_ptr_base) + + memory_shm_ptr->gpu_pointer_offset); + opened_cuda_ipc_handle = true; + } #endif } } else { data_ptr = memory_data_shm; } + + // This check only validates CPU shared memory access. + if (memory_shm_ptr->memory_type != TRITONSERVER_MEMORY_GPU && + (data_ptr + memory_shm_ptr->byte_size > + (char*)shm_pool->GetBaseAddress() + shm_pool->GetCurrentCapacity())) { + std::ostringstream oss; + oss << "0x" << std::hex + << (reinterpret_cast(data_ptr) + memory_shm_ptr->byte_size); + throw PythonBackendException( + std::string("Attempted to access out of bounds memory address ") + + oss.str()); + } + return std::unique_ptr(new PbMemory( memory_shm, data_ptr, opened_cuda_ipc_handle /* opened_cuda_ipc_handle */)); @@ -403,6 +461,18 @@ PbMemory::SetCudaIpcHandle(cudaIpcMemHandle_t* cuda_ipc_handle) { *(reinterpret_cast(ShmData())) = *(cuda_ipc_handle); } + +void +PbMemory::UpdateCUDAOffset(std::unique_ptr& cuda_pool) +{ + if (cuda_pool->UseCudaSharedPool(MemoryTypeId()) && + IsUsingCUDAPool(cuda_pool, MemoryTypeId(), DataPtr())) { + memory_shm_ptr_->cuda_pool_offset = + DataPtr() - + reinterpret_cast(cuda_pool->CUDAPoolAddress(MemoryTypeId())); + memory_shm_ptr_->use_cuda_shared_pool = true; + } +} #endif PbMemory::~PbMemory() diff --git a/src/pb_memory.h b/src/pb_memory.h index e7986014..ad79daed 100644 --- a/src/pb_memory.h +++ b/src/pb_memory.h @@ -1,4 +1,4 @@ -// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -42,13 +42,18 @@ namespace triton { namespace backend { namespace python { // struct MemoryShm { // If the memory type is a GPU pointer, the offset of the GPU pointer from the - // base address. For CPU memory type this field contains garbage data. + // base address. For CPU memory type this field contains garbage data. This + // field will only be used when the memory is not allocated from the CUDA + // shared memory pool. uint64_t gpu_pointer_offset; + bool use_cuda_shared_pool; + // The offset of the memory from the base address of the CUDA shared memory + // pool. + uint64_t cuda_pool_offset; TRITONSERVER_MemoryType memory_type; int64_t memory_type_id; uint64_t byte_size; - bool is_cuda_handle_set; uint64_t memory_release_id; }; @@ -60,6 +65,7 @@ class PbMemory { uint64_t byte_size, char* data, bool copy_gpu = true); static std::unique_ptr Create( + std::unique_ptr& shm_pool, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, uint64_t byte_size, char* data, char* data_shm, bi::managed_external_buffer::handle_t handle, bool copy_gpu = true); @@ -72,6 +78,8 @@ class PbMemory { #ifdef TRITON_ENABLE_GPU void SetCudaIpcHandle(cudaIpcMemHandle_t* cuda_ipc_handle); + + void UpdateCUDAOffset(std::unique_ptr& cuda_pool); #endif // Copy the destination buffer to the source buffer. @@ -83,6 +91,7 @@ class PbMemory { bi::managed_external_buffer::handle_t memory_handle, bool open_cuda_handle); static std::unique_ptr LoadFromSharedMemory( + std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t handle, char* data_shm, bool open_cuda_handle); static uint64_t ShmStructSize( @@ -117,8 +126,25 @@ class PbMemory { void SetMemoryReleaseCallback(std::function release_callback); + bool UseCUDASharedPool() const + { + return memory_shm_ptr_->use_cuda_shared_pool; + } + ~PbMemory(); +#ifndef TRITON_PB_STUB + void SetBackendMemory(std::unique_ptr&& backend_memory) + { + backend_memory_ = std::move(backend_memory); + }; + + std::unique_ptr GetBackendMemory() + { + return std::move(backend_memory_); + }; +#endif + private: AllocatedSharedMemory memory_shm_; MemoryShm* memory_shm_ptr_; @@ -150,6 +176,7 @@ class PbMemory { #endif static void FillShmData( + std::unique_ptr& cuda_pool, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, uint64_t byte_size, char* data, char* data_shm, bi::managed_external_buffer::handle_t handle, bool copy_gpu = true); diff --git a/src/pb_preferred_memory.h b/src/pb_preferred_memory.h index 55f4db89..c28f1b87 100644 --- a/src/pb_preferred_memory.h +++ b/src/pb_preferred_memory.h @@ -30,10 +30,10 @@ namespace triton { namespace backend { namespace python { class PreferredMemory { public: - enum MemoryType { GPU, CPU, DEFAULT }; + enum MemoryType { kGPU, kCPU, kDefault }; PreferredMemory() - : preferred_memory_type_(MemoryType::DEFAULT), preferred_device_id_(0) + : preferred_memory_type_(MemoryType::kDefault), preferred_device_id_(0) { } diff --git a/src/pb_response_iterator.cc b/src/pb_response_iterator.cc index 9561df68..536d4232 100644 --- a/src/pb_response_iterator.cc +++ b/src/pb_response_iterator.cc @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -40,6 +40,7 @@ ResponseIterator::ResponseIterator( : id_(response->Id()), is_finished_(false), is_cleared_(false), idx_(0) { response_buffer_.push(response); + pb_bls_cancel_ = std::make_shared(response->Id()); } ResponseIterator::~ResponseIterator() @@ -100,7 +101,7 @@ ResponseIterator::Next() } } -py::iterator +void ResponseIterator::Iter() { if (is_finished_) { @@ -111,8 +112,6 @@ ResponseIterator::Iter() idx_ = 0; } } - - return py::cast(*this); } void @@ -135,7 +134,7 @@ void ResponseIterator::Clear() { std::unique_ptr& stub = Stub::GetOrCreateInstance(); - stub->EnqueueCleanupId(id_); + stub->EnqueueCleanupId(id_, PYTHONSTUB_BLSDecoupledInferPayloadCleanup); { std::lock_guard lock{mu_}; response_buffer_.push(DUMMY_MESSAGE); @@ -161,4 +160,12 @@ ResponseIterator::GetExistingResponses() return responses; } +void +ResponseIterator::Cancel() +{ + if (!is_finished_) { + pb_bls_cancel_->Cancel(); + } +} + }}} // namespace triton::backend::python diff --git a/src/pb_response_iterator.h b/src/pb_response_iterator.h index 1122a216..cb26d6a3 100644 --- a/src/pb_response_iterator.h +++ b/src/pb_response_iterator.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -29,6 +29,7 @@ #include #include "infer_response.h" +#include "pb_bls_cancel.h" namespace triton { namespace backend { namespace python { @@ -38,11 +39,12 @@ class ResponseIterator { ~ResponseIterator(); std::shared_ptr Next(); - py::iterator Iter(); + void Iter(); void EnqueueResponse(std::shared_ptr infer_response); void* Id(); void Clear(); std::vector> GetExistingResponses(); + void Cancel(); private: std::vector> responses_; @@ -53,6 +55,7 @@ class ResponseIterator { bool is_finished_; bool is_cleared_; size_t idx_; + std::shared_ptr pb_bls_cancel_; }; }}} // namespace triton::backend::python diff --git a/src/pb_stub.cc b/src/pb_stub.cc index eb561dec..56048d78 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -28,7 +28,6 @@ #include #include -#include #include #include @@ -43,18 +42,27 @@ #include #include +#include "correlation_id.h" #include "model_loader.h" #include "pb_error.h" #include "pb_map.h" #include "pb_preferred_memory.h" #include "pb_response_iterator.h" #include "pb_string.h" +#include "pb_stub_log.h" #include "pb_utils.h" #include "response_sender.h" #include "scoped_defer.h" #include "shm_manager.h" #include "triton/common/nvtx.h" +#ifdef _WIN32 +#include // SIGINT & SIGTERM +#include +#else +#include +#endif + #ifdef TRITON_ENABLE_GPU #include #endif // TRITON_ENABLE_GPU @@ -76,15 +84,66 @@ SignalHandler(int signum) // Skip the SIGINT and SIGTERM } +template +PYTYPE +PyDefaultArgumentToMutableType(const py::object& argument) +{ + // The default argument on Python functions always reference the same copy, + // meaning if the default argument is changed by the function, then it is + // changed for all subsequent calls to the function. Thus, default arguments + // should be limited to basic types (i.e. None). This helper function returns + // an empty expected type, if the argument is None (i.e. default initialized). + // If the argument is neither None nor expected type, an exception is thrown. + if (py::isinstance(argument)) { + return PYTYPE(); + } + if (py::isinstance(argument)) { + return argument; + } + throw PythonBackendException( + std::string("Expect ") + typeid(PYTYPE).name() + ", got " + + std::string(py::str(argument.get_type()))); +} + +std::string +PyParametersToJSON(const py::dict& parameters) +{ + for (const auto& pair : parameters) { + if (!py::isinstance(pair.first)) { + throw PythonBackendException( + "Expect parameters keys to have type str, found type " + + std::string(py::str(pair.first.get_type()))); + } + if (!py::isinstance(pair.second) && + !py::isinstance(pair.second) && + !py::isinstance(pair.second)) { + throw PythonBackendException( + "Expect parameters values to have type bool/int/str, found type " + + std::string(py::str(pair.second.get_type()))); + } + } + py::module_ py_json = py::module_::import("json"); + std::string parameters_str = py::str(py_json.attr("dumps")(parameters)); + return parameters_str; +} + +void +AsyncEventFutureDoneCallback(const py::object& py_future) +{ + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + stub->BackgroundFutureDone(py_future); +} + void Stub::Instantiate( int64_t shm_growth_size, int64_t shm_default_size, const std::string& shm_region_name, const std::string& model_path, const std::string& model_version, const std::string& triton_install_path, bi::managed_external_buffer::handle_t ipc_control_handle, - const std::string& name, const std::string& platform) + const std::string& name, const std::string& python_runtime_model) { - model_context_.Init(model_path, platform, triton_install_path, model_version); + model_context_.Init( + model_path, python_runtime_model, triton_install_path, model_version); name_ = name; health_mutex_ = nullptr; initialized_ = false; @@ -126,6 +185,7 @@ Stub::Instantiate( // interfere with the shared library resolution of other executable and // binaries. if (ipc_control_->uses_env) { +#ifndef _WIN32 char* ld_library_path = std::getenv("LD_LIBRARY_PATH"); if (ld_library_path != nullptr) { @@ -151,6 +211,11 @@ Stub::Instantiate( "When using an execution environment, LD_LIBRARY_PATH variable " "cannot be empty."); } +#else + throw PythonBackendException( + "Custom execution environments are not currently supported on " + "Windows."); +#endif } } catch (const PythonBackendException& pb_exception) { @@ -341,15 +406,19 @@ Stub::RunCommand() shm_pool_->Load(ipc_message->Args()); RequestBatch* request_batch_shm_ptr = reinterpret_cast(request_batch.data_.get()); - if (!ipc_control_->decoupled) { - ProcessRequests(request_batch_shm_ptr); - } else { - ProcessRequestsDecoupled(request_batch_shm_ptr); - } + ProcessRequests(request_batch_shm_ptr); } break; case PYTHONSTUB_CommandType::PYTHONSTUB_FinalizeRequest: ipc_message->Command() = PYTHONSTUB_FinalizeResponse; + // Clean up response_iterator_map_ before sending sending message back to + // the parent process to make sure that the clean up message can be + // processed before the message queue is destroyed. + { + std::lock_guard lock(response_iterator_map_mu_); + std::unordered_map>().swap( + response_iterator_map_); + } SendIPCMessage(ipc_message); return true; // Terminate the stub process case PYTHONSTUB_CommandType::PYTHONSTUB_LoadGPUBuffers: @@ -490,6 +559,9 @@ Stub::Initialize(bi::managed_external_buffer::handle_t map_handle) c_python_backend_utils.attr("InferenceResponse")); c_python_backend_utils.attr("shared_memory") = py::cast(shm_pool_.get()); + async_event_loop_ = py::none(); + background_futures_ = py::set(); + py::object TritonPythonModel = sys.attr("TritonPythonModel"); deserialize_bytes_ = python_backend_utils.attr("deserialize_bytes_tensor"); serialize_bytes_ = python_backend_utils.attr("serialize_byte_tensor"); @@ -526,18 +598,6 @@ Stub::Initialize(bi::managed_external_buffer::handle_t map_handle) initialized_ = true; } -void -Stub::ProcessResponse(InferResponse* response) -{ - response->SaveToSharedMemory(shm_pool_, false /* copy_gpu */); - - for (auto& output_tensor : response->OutputTensors()) { - if (!output_tensor->IsCPU()) { - gpu_tensors_.push_back(output_tensor); - } - } -} - void Stub::LoadGPUBuffers(std::unique_ptr& ipc_message) { @@ -603,7 +663,8 @@ Stub::LoadRequestsFromSharedMemory(RequestBatch* request_batch_shm_ptr) for (size_t i = 0; i < batch_size; i++) { std::shared_ptr infer_request = InferRequest::LoadFromSharedMemory( - shm_pool_, request_shm_handle[i], true /* open_cuda_handle */); + shm_pool_, request_shm_handle[i], true /* open_cuda_handle */, + &ipc_control_->decoupled /* is_model_decoupled */); py_request_list.append(infer_request); } @@ -611,31 +672,24 @@ Stub::LoadRequestsFromSharedMemory(RequestBatch* request_batch_shm_ptr) } void -Stub::ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr) +Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) { py::list py_request_list = LoadRequestsFromSharedMemory(request_batch_shm_ptr); - std::unique_ptr execute_response = - IPCMessage::Create(shm_pool_, false /* Inline response */); - execute_response->Command() = PYTHONSTUB_ExecuteResponse; + std::unique_ptr execute_response; - AllocatedSharedMemory response_batch = - shm_pool_->Construct(); - ResponseBatch* response_batch_shm_ptr = - reinterpret_cast(response_batch.data_.get()); - execute_response->Args() = response_batch.handle_; + std::optional> response_batch; bool has_exception = false; std::string error_string; std::unique_ptr error_string_shm; + std::string err_message; ScopedDefer execute_finalize([this] { stub_message_queue_->Pop(); }); ScopedDefer _( [this, &execute_response] { SendIPCMessage(execute_response); }); - + py::object execute_return; + py::object coroutine_return; try { - response_batch_shm_ptr->has_error = false; - response_batch_shm_ptr->is_error_set = false; - if (!py::hasattr(model_instance_, "execute")) { std::string message = "Python model " + model_context_.PythonModelPath() + " does not implement `execute` method."; @@ -645,13 +699,24 @@ Stub::ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr) { NVTX_RANGE(nvtx_, "PyExecute " + name_); - py::object execute_return = - model_instance_.attr("execute")(py_request_list); - if (!py::isinstance(execute_return)) { - throw PythonBackendException( - "Python model '" + name_ + - "' is using the decoupled mode and the execute function must " - "return None."); + execute_return = model_instance_.attr("execute")(py_request_list); + + bool is_coroutine = py::module::import("asyncio") + .attr("iscoroutine")(execute_return) + .cast(); + if (is_coroutine) { + if (IsDecoupled()) { + // Do not wait for async decoupled execute to return. + RunCoroutine(execute_return, true /* in_background */); + } else { + coroutine_return = + RunCoroutine(execute_return, false /* in_background */); + ProcessReturnedResponses( + py_request_list, coroutine_return, response_batch); + } + } else { + ProcessReturnedResponses( + py_request_list, execute_return, response_batch); } } } @@ -665,152 +730,249 @@ Stub::ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr) } if (has_exception) { - std::string err_message = - std::string( - "Failed to process the request(s) for model '" + name_ + - "', message: ") + - error_string; - LOG_INFO << err_message.c_str(); + err_message = std::string( + "Failed to process the request(s) for model '" + name_ + + "', message: ") + + error_string; + LOG_ERROR << err_message.c_str(); + if (!response_batch) { + response_batch = shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + } + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + + // The backend will clean up the response factory if there is an error in + // the response batch. For decoupled mode, it is necessary to handle cases + // where the response sender should have already cleaned up, ensuring the + // backend does not delete the response factory again during error handling. + if (IsDecoupled()) { + for (py::handle py_request : py_request_list) { + InferRequest* request = py_request.cast(); + if (request->GetResponseSender()->IsClosed()) { + response_batch_shm_ptr->is_response_factory_deleted = true; + } + } + } + response_batch_shm_ptr->has_error = true; - error_string_shm = PbString::Create(shm_pool_, error_string); + error_string_shm = PbString::Create(shm_pool_, err_message); response_batch_shm_ptr->error = error_string_shm->ShmHandle(); response_batch_shm_ptr->is_error_set = true; + response_batch_shm_ptr->batch_size = 0; + // Once the error is sent to the backend, the backend is supposed to close + // all response factories if not already closed, so closing all response + // senders if not already closed to prevent the model from sending more + // responses after the factories are closed. + for (py::handle py_request : py_request_list) { + InferRequest* request = py_request.cast(); + request->GetResponseSender()->Close(); + } + } else { + if (!response_batch) { + response_batch = shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->batch_size = 0; + } + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->has_error = false; + response_batch_shm_ptr->is_error_set = false; } + + execute_response = IPCMessage::Create( + reinterpret_cast(response_batch.value().data_.get()), + response_batch.value().handle_); + execute_response->Args() = + response_batch.value().handle_ + sizeof(IPCMessageShm); + execute_response->InlineResponse() = false; + execute_response->Command() = PYTHONSTUB_ExecuteResponse; + _.Complete(); + execute_finalize.Complete(); } void -Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) +Stub::ProcessResponse(InferResponse* response) { - std::unique_ptr execute_response = - IPCMessage::Create(shm_pool_, false /* Inline response */); - execute_response->Command() = PYTHONSTUB_ExecuteResponse; - - AllocatedSharedMemory response_batch = shm_pool_->Construct( - request_batch_shm_ptr->batch_size * - sizeof(bi::managed_external_buffer::handle_t) + - sizeof(ResponseBatch)); - ResponseBatch* response_batch_shm_ptr = - reinterpret_cast(response_batch.data_.get()); - - std::unique_ptr error_string_shm; - py::list inference_responses; - - bi::managed_external_buffer::handle_t* responses_shm_handle = - reinterpret_cast( - response_batch.data_.get() + sizeof(ResponseBatch)); - - py::list responses; - - // Notifying the stub should be after responses. - ScopedDefer execute_finalize([this] { stub_message_queue_->Pop(); }); - ScopedDefer _( - [this, &execute_response] { SendIPCMessage(execute_response); }); - - execute_response->Args() = response_batch.handle_; - - bool has_exception = false; - std::string error_string; - try { - response_batch_shm_ptr->has_error = false; - response_batch_shm_ptr->is_error_set = false; - - uint32_t batch_size = request_batch_shm_ptr->batch_size; - - if (batch_size == 0) { - return; - } - - py::list py_request_list = - LoadRequestsFromSharedMemory(request_batch_shm_ptr); + response->SaveToSharedMemory(shm_pool_, false /* copy_gpu */); - if (!py::hasattr(model_instance_, "execute")) { - std::string message = "Python model " + model_context_.PythonModelPath() + - " does not implement `execute` method."; - throw PythonBackendException(message); + for (auto& output_tensor : response->OutputTensors()) { + if (!output_tensor->IsCPU()) { + gpu_tensors_.push_back(output_tensor); } + } +} - py::object request_list = py_request_list; - py::module asyncio = py::module::import("asyncio"); +void +Stub::ProcessReturnedResponses( + py::list py_requests, py::object py_responses_obj, + std::optional>& response_batch) +{ + // Return if there is nothing to process. + if (py::isinstance(py_responses_obj)) { + return; + } + // Only non-decoupled may return responses. + if (IsDecoupled()) { + throw PythonBackendException( + "Python model '" + name_ + + "' is using the decoupled mode and the execute function must return " + "None."); + } + // Check responses is a list. + if (!py::isinstance(py_responses_obj)) { + throw PythonBackendException( + "Expected a list in the execute return, found type '" + + std::string(py::str(py_responses_obj.get_type())) + "'."); + } + py::list py_responses = py_responses_obj; + // Responses and requests length must match. + size_t requests_size = py::len(py_requests); + size_t responses_size = py::len(py_responses); + if (requests_size != responses_size) { + throw PythonBackendException( + "Number of InferenceResponse objects do not match the number of " + "InferenceRequest objects. InferenceRequest(s) size is:" + + std::to_string(requests_size) + ", and InferenceResponse(s) size is:" + + std::to_string(responses_size) + "\n"); + } - // Execute Response - py::object execute_return; - py::object responses_obj; - bool is_coroutine; + for (size_t i = 0; i < responses_size; i++) { + if (!py::isinstance(py_responses[i])) { + InferRequest* request = py_requests[i].cast(); + // Response must be None if rescheduled. + if (request->ReleaseFlags() == TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) { + throw PythonBackendException( + "Expected a None object in the execute function return list for " + "reschduled request, found type '" + + std::string(py::str(py_responses[i].get_type())) + "'."); + } + // Send the response. + if (!py::isinstance(py_responses[i])) { + throw PythonBackendException( + "Expected an 'InferenceResponse' object in the execute function " + "return list, found type '" + + std::string(py::str(py_responses[i].get_type())) + "'."); + } - { - NVTX_RANGE(nvtx_, "PyExecute " + name_); - execute_return = model_instance_.attr("execute")(request_list); - is_coroutine = asyncio.attr("iscoroutine")(execute_return).cast(); + InferResponse* response = py_responses[i].cast(); + try { + request->GetResponseSender()->UpdateStateAndCounters( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL); + } + catch (const PythonBackendException& pb_exception) { + // Handle the exception here to catch the error when there's a response + // returned from `execute()`. + if (request->GetResponseSender()->IsClosed()) { + response_batch = std::move(shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm))); + ResponseBatch* response_batch_shm_ptr = + reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->batch_size = 0; + response_batch_shm_ptr->is_response_factory_deleted = true; + } + throw pb_exception; + } } + } + // Return all the created responses using response_batch. The reason + // that both of the paths are available is that sending the responses + // using response_batch is faster than using `response_sender`. + response_batch = std::move(shm_pool_->Construct( + sizeof(IPCMessageShm) + + requests_size * sizeof(bi::managed_external_buffer::handle_t) + + sizeof(ResponseBatch))); + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); - if (is_coroutine) { - responses_obj = asyncio.attr("run")(execute_return); + bi::managed_external_buffer::handle_t* responses_shm_handle = + reinterpret_cast( + response_batch.value().data_.get() + sizeof(ResponseBatch) + + sizeof(IPCMessageShm)); + for (size_t i = 0; i < responses_size; i++) { + // Check the return type of execute function. + InferRequest* infer_request = py_requests[i].cast(); + InferResponse* infer_response = py_responses[i].cast(); + if (!py::isinstance(py_responses[i])) { + infer_response->PruneOutputTensors(infer_request->RequestedOutputNames()); + ProcessResponse(infer_response); + responses_shm_handle[i] = infer_response->ShmHandle(); } else { - responses_obj = execute_return; + responses_shm_handle[i] = 0; } + } + response_batch_shm_ptr->batch_size = requests_size; +} - // Check the return type of execute function. - if (!py::isinstance(responses_obj)) { - std::string str = py::str(execute_return.get_type()); - throw PythonBackendException( - std::string("Expected a list in the execute return, found type '") + - str + "'."); - } +py::object +Stub::GetAsyncEventLoop() +{ + if (py::isinstance(async_event_loop_)) { + // Create the event loop if not already. + py::module asyncio = py::module_::import("asyncio"); + async_event_loop_ = asyncio.attr("new_event_loop")(); + asyncio.attr("set_event_loop")(async_event_loop_); + py::object py_thread = + py::module_::import("threading") + .attr("Thread")( + "target"_a = async_event_loop_.attr("run_forever"), + "daemon"_a = true); + py_thread.attr("start")(); + } + return async_event_loop_; +} - responses = responses_obj; - size_t response_size = py::len(responses); - - // If the number of request objects do not match the number of - // response objects throw an error. - if (response_size != batch_size) { - std::string err = - "Number of InferenceResponse objects do not match the number " - "of " - "InferenceRequest objects. InferenceRequest(s) size is:" + - std::to_string(batch_size) + ", and InferenceResponse(s) size is:" + - std::to_string(response_size) + "\n"; - throw PythonBackendException(err); +py::object +Stub::RunCoroutine(py::object coroutine, bool in_background) +{ + py::object loop = GetAsyncEventLoop(); + py::object py_future = py::module_::import("asyncio").attr( + "run_coroutine_threadsafe")(coroutine, loop); + if (in_background) { + py_future.attr("add_done_callback")( + py::module_::import("c_python_backend_utils") + .attr("async_event_future_done_callback")); + background_futures_.attr("add")(py_future); + return py::none(); + } + return py_future.attr("result")(); +} + +void +Stub::BackgroundFutureDone(const py::object& py_future) +{ + ScopedDefer _([this, &py_future] { + // Remove future from background + try { + background_futures_.attr("remove")(py_future); } - for (auto& response : responses) { - // Check the return type of execute function. - if (!py::isinstance(response)) { - std::string str = py::str(response.get_type()); - throw PythonBackendException( - std::string("Expected an 'InferenceResponse' object in the execute " - "function return list, found type '") + - str + "'."); - } + catch (const py::error_already_set& error) { + LOG_ERROR << "Cannot remove future from background; " << error.what(); } - response_batch_shm_ptr->batch_size = response_size; - - for (size_t i = 0; i < batch_size; i++) { - InferResponse* infer_response = responses[i].cast(); - InferRequest* infer_request = py_request_list[i].cast(); - infer_response->PruneOutputTensors(infer_request->RequestedOutputNames()); - - ProcessResponse(infer_response); - responses_shm_handle[i] = infer_response->ShmHandle(); + }); + // TODO: Why using `py_future.result()` with error hangs on exit? + try { + py::object exception = py_future.attr("exception")(); + if (!py::isinstance(exception)) { + std::string err_msg = ""; + py::object traceback = py::module_::import("traceback") + .attr("TracebackException") + .attr("from_exception")(exception) + .attr("format")(); + for (py::handle line : traceback) { + err_msg += py::str(line); + } + LOG_ERROR << err_msg; } } catch (const PythonBackendException& pb_exception) { - has_exception = true; - error_string = pb_exception.what(); + LOG_ERROR << pb_exception.what(); } catch (const py::error_already_set& error) { - has_exception = true; - error_string = error.what(); - } - - if (has_exception) { - std::string err_message = - std::string( - "Failed to process the request(s) for model '" + name_ + - "', message: ") + - error_string; - error_string_shm = PbString::Create(shm_pool_, error_string); - response_batch_shm_ptr->has_error = true; - response_batch_shm_ptr->is_error_set = true; - response_batch_shm_ptr->error = error_string_shm->ShmHandle(); + LOG_ERROR << error.what(); } } @@ -825,13 +987,19 @@ void Stub::Finalize() { finalizing_ = true; - // Call finalize if exists. - if (initialized_ && py::hasattr(model_instance_, "finalize")) { - try { - model_instance_.attr("finalize")(); + if (initialized_) { + // Stop async event loop if created. + if (!py::isinstance(async_event_loop_)) { + async_event_loop_.attr("stop")(); } - catch (const py::error_already_set& e) { - LOG_INFO << e.what(); + // Call finalize if exists. + if (py::hasattr(model_instance_, "finalize")) { + try { + model_instance_.attr("finalize")(); + } + catch (const py::error_already_set& e) { + LOG_INFO << e.what(); + } } } #ifdef TRITON_ENABLE_GPU @@ -870,11 +1038,31 @@ Stub::SendIPCUtilsMessage(std::unique_ptr& ipc_message) Stub::~Stub() { - { +#ifdef TRITON_ENABLE_GPU + try { + if (shm_pool_ != nullptr) { + CUDAHandler& cuda_api = CUDAHandler::getInstance(); + for (auto& m : + shm_pool_->GetCUDAMemoryPoolManager()->CUDAPoolAddressMap()) { + if (m.second != nullptr) { + cuda_api.CloseCudaHandle(m.first, m.second); + } + } + } + } + catch (const PythonBackendException& pb_exception) { + std::cerr << "Error when closing CUDA handle: " << pb_exception.what(); + } +#endif + + // Ensure the interpreter is active before trying to clean up. + if (Py_IsInitialized()) { py::gil_scoped_acquire acquire; - model_instance_ = py::none(); + py::object async_event_loop_local(std::move(async_event_loop_)); + py::object background_futures_local(std::move(background_futures_)); + py::object model_instance_local(std::move(model_instance_)); } - stub_instance_.reset(); + stub_message_queue_.reset(); parent_message_queue_.reset(); stub_to_parent_mq_.reset(); @@ -943,8 +1131,18 @@ Stub::ServiceStubToParentRequests() stub_to_parent_buffer_.pop(); if (utils_msg_payload->command_type == PYTHONSTUB_LogRequest) { SendLogMessage(utils_msg_payload); - } else if (utils_msg_payload->command_type == PYTHONSTUB_CleanupRequest) { - SendCleanupId(utils_msg_payload); + } else if ( + (utils_msg_payload->command_type == + PYTHONSTUB_BLSDecoupledInferPayloadCleanup) || + (utils_msg_payload->command_type == + PYTHONSTUB_DecoupledResponseFactoryCleanup)) { + SendCleanupId(utils_msg_payload, utils_msg_payload->command_type); + } else if ( + utils_msg_payload->command_type == PYTHONSTUB_IsRequestCancelled) { + SendIsCancelled(utils_msg_payload); + } else if ( + utils_msg_payload->command_type == PYTHONSTUB_CancelBLSInferRequest) { + SendCancelBLSRequest(utils_msg_payload); } else { std::cerr << "Error when sending message via stub_to_parent message " "buffer - unknown command\n"; @@ -987,17 +1185,19 @@ Stub::SendLogMessage(std::unique_ptr& utils_msg_payload) } void -Stub::SendCleanupId(std::unique_ptr& utils_msg_payload) +Stub::SendCleanupId( + std::unique_ptr& utils_msg_payload, + const PYTHONSTUB_CommandType& command_type) { void* id = utils_msg_payload->utils_message_ptr; - { + if (command_type == PYTHONSTUB_BLSDecoupledInferPayloadCleanup) { std::lock_guard lock(response_iterator_map_mu_); response_iterator_map_.erase(id); } std::unique_ptr ipc_message = IPCMessage::Create(shm_pool_, true /* inline_response */); - ipc_message->Command() = PYTHONSTUB_CleanupRequest; + ipc_message->Command() = command_type; AllocatedSharedMemory cleanup_request_message = shm_pool_->Construct( sizeof(CleanupMessage) + @@ -1019,15 +1219,93 @@ Stub::SendCleanupId(std::unique_ptr& utils_msg_payload) } void -Stub::EnqueueCleanupId(void* id) +Stub::EnqueueCleanupId(void* id, const PYTHONSTUB_CommandType& command_type) { if (id != nullptr) { std::unique_ptr utils_msg_payload = - std::make_unique(PYTHONSTUB_CleanupRequest, id); + std::make_unique(command_type, id); EnqueueUtilsMessage(std::move(utils_msg_payload)); } } +void +Stub::SendCancelBLSRequest( + std::unique_ptr& utils_msg_payload) +{ + PbBLSCancel* pb_bls_cancel = + reinterpret_cast(utils_msg_payload->utils_message_ptr); + pb_bls_cancel->SaveToSharedMemory(shm_pool_); + + CancelBLSRequestMessage* message_payload = pb_bls_cancel->ShmPayload(); + std::unique_ptr ipc_message = + IPCMessage::Create(shm_pool_, false /* inline_response */); + ipc_message->Command() = utils_msg_payload->command_type; + ipc_message->Args() = pb_bls_cancel->ShmHandle(); + + bool is_cancelled = false; + { + bi::scoped_lock lk(message_payload->mu); + + SendIPCUtilsMessage(ipc_message); + while (!message_payload->waiting_on_stub) { + message_payload->cv.wait(lk); + } + + is_cancelled = message_payload->is_cancelled; + message_payload->waiting_on_stub = false; + message_payload->cv.notify_all(); + } + pb_bls_cancel->ReportIsCancelled(is_cancelled); +} + +void +Stub::EnqueueCancelBLSRequest(PbBLSCancel* pb_bls_cancel) +{ + std::unique_ptr utils_msg_payload = + std::make_unique( + PYTHONSTUB_CancelBLSInferRequest, + reinterpret_cast(pb_bls_cancel)); + EnqueueUtilsMessage(std::move(utils_msg_payload)); +} + +void +Stub::EnqueueIsCancelled(PbCancel* pb_cancel) +{ + std::unique_ptr utils_msg_payload = + std::make_unique( + PYTHONSTUB_IsRequestCancelled, reinterpret_cast(pb_cancel)); + EnqueueUtilsMessage(std::move(utils_msg_payload)); +} + +void +Stub::SendIsCancelled(std::unique_ptr& utils_msg_payload) +{ + PbCancel* pb_cancel = + reinterpret_cast(utils_msg_payload->utils_message_ptr); + pb_cancel->SaveToSharedMemory(shm_pool_); + + IsCancelledMessage* message_payload = pb_cancel->ShmPayload(); + std::unique_ptr ipc_message = + IPCMessage::Create(shm_pool_, false /* inline_response */); + ipc_message->Command() = utils_msg_payload->command_type; + ipc_message->Args() = pb_cancel->ShmHandle(); + + bool is_cancelled = false; + { + bi::scoped_lock lk(message_payload->mu); + + SendIPCUtilsMessage(ipc_message); + while (!message_payload->waiting_on_stub) { + message_payload->cv.wait(lk); + } + + is_cancelled = message_payload->is_cancelled; + message_payload->waiting_on_stub = false; + message_payload->cv.notify_all(); + } + pb_cancel->ReportIsCancelled(is_cancelled); +} + bool Stub::StubToParentServiceActive() { @@ -1062,86 +1340,18 @@ Stub::ParentToStubMQMonitor() break; } - std::unique_ptr ipc_message; - ResponseBatch* response_batch = nullptr; - bi::managed_external_buffer::handle_t* response_handle = nullptr; - std::unique_ptr infer_response; - bool responses_is_set = false; - PythonBackendException pb_exception(std::string{}); - - try { - ipc_message = IPCMessage::LoadFromSharedMemory(shm_pool_, handle); - AllocatedSharedMemory response_batch_shm = - shm_pool_->Load(ipc_message->Args()); - response_batch = - reinterpret_cast(response_batch_shm.data_.get()); - response_handle = - reinterpret_cast( - response_batch_shm.data_.get() + sizeof(ResponseBatch)); - responses_is_set = true; - - if (response_batch->has_error) { - if (response_batch->is_error_set) { - std::unique_ptr pb_string = - PbString::LoadFromSharedMemory(shm_pool_, response_batch->error); - infer_response = std::make_unique( - std::vector>{}, - std::make_shared(pb_string->String())); - } else { - infer_response = std::make_unique( - std::vector>{}, - std::make_shared( - "An error occurred while performing BLS request.")); - } - } - - if (responses_is_set) { - infer_response = InferResponse::LoadFromSharedMemory( - shm_pool_, *response_handle, true /* open cuda handle */); - - for (auto& output_tensor : infer_response->OutputTensors()) { - if (!output_tensor->IsCPU()) { - uint64_t memory_release_id = - output_tensor->Memory()->MemoryReleaseId(); - output_tensor->Memory()->SetMemoryReleaseCallback( - [this, memory_release_id]() { - this->MemoryManagerQueue()->Push(memory_release_id); - }); - } - } - } else { - infer_response = std::make_unique( - std::vector>{}, - std::make_shared( - "An error occurred while performing BLS request.")); - } - } - catch (const PythonBackendException& pb_exception) { - infer_response = std::make_unique( - std::vector>{}, - std::make_shared(pb_exception.what())); - } - - { - std::lock_guard lock(response_iterator_map_mu_); - if (response_iterator_map_.find(infer_response->Id()) != - response_iterator_map_.end()) { - response_iterator_map_[infer_response->Id()]->EnqueueResponse( - std::move(infer_response)); - } else { - auto response_iterator = - std::make_shared(std::move(infer_response)); - response_iterator_map_.insert( - std::pair>( - response_iterator->Id(), response_iterator)); - } - } - - { - bi::scoped_lock lock{ - *(ipc_message->ResponseMutex())}; - response_batch->waiting_on_stub = true; - ipc_message->ResponseCondition()->notify_all(); + std::unique_ptr ipc_message = + IPCMessage::LoadFromSharedMemory(shm_pool_, handle); + + switch (ipc_message->Command()) { + case PYTHONSTUB_CommandType::PYTHONSTUB_CUDAPoolInitializeRequest: { + GetCUDAMemoryPoolAddress(ipc_message); + } break; + case PYTHONSTUB_CommandType::PYTHONSTUB_InferStreamExecResponse: { + ProcessBLSResponseDecoupled(ipc_message); + } break; + default: + break; } } } @@ -1225,130 +1435,196 @@ Stub::GetProxyStream(const int& device_id) #endif } -std::unique_ptr Logger::log_instance_; - -std::unique_ptr& -Logger::GetOrCreateInstance() -{ - if (Logger::log_instance_.get() == nullptr) { - Logger::log_instance_ = std::make_unique(); - } - - return Logger::log_instance_; -} - -// Bound function, called from the python client void -Logger::Log(const std::string& message, LogLevel level) +Stub::GetCUDAMemoryPoolAddress(std::unique_ptr& ipc_message) { - std::unique_ptr& stub = Stub::GetOrCreateInstance(); - py::object frame = py::module_::import("inspect").attr("currentframe"); - py::object caller_frame = frame(); - py::object info = py::module_::import("inspect").attr("getframeinfo"); - py::object caller_info = info(caller_frame); - py::object filename_python = caller_info.attr("filename"); - std::string filename = filename_python.cast(); - py::object lineno = caller_info.attr("lineno"); - uint32_t line = lineno.cast(); - - if (!stub->StubToParentServiceActive()) { - Logger::GetOrCreateInstance()->Log(filename, line, level, message); - } else { - std::unique_ptr log_msg(new PbLog(filename, line, message, level)); - stub->EnqueueLogRequest(log_msg); +#ifdef TRITON_ENABLE_GPU + bool has_exception = false; + std::string error_string; + std::unique_ptr error_string_shm; + + CUDAMemPoolMessage* cuda_pool_message_ptr = nullptr; + try { + AllocatedSharedMemory cuda_handle_shm = + shm_pool_->Load(ipc_message->Args()); + cuda_pool_message_ptr = cuda_handle_shm.data_.get(); + + CUDAHandler& cuda_api = CUDAHandler::getInstance(); + void* cuda_pool_address; + cuda_api.OpenCudaHandle( + cuda_pool_message_ptr->device_id, &cuda_pool_message_ptr->cuda_handle, + &cuda_pool_address); + shm_pool_->GetCUDAMemoryPoolManager()->SetCUDAPoolAddress( + cuda_pool_message_ptr->device_id, cuda_pool_address); + } + catch (const PythonBackendException& pb_exception) { + has_exception = true; + error_string = pb_exception.what(); + shm_pool_->GetCUDAMemoryPoolManager()->SetCUDAPoolAddress( + cuda_pool_message_ptr->device_id, nullptr); } -} -// Called internally (.e.g. LOG_ERROR << "Error"; ) -void -Logger::Log( - const std::string& filename, uint32_t lineno, LogLevel level, - const std::string& message) -{ - // If the log monitor service is not active yet, format - // and pass messages to cerr - if (!BackendLoggingActive()) { - std::string path(filename); - size_t pos = path.rfind('/'); - if (pos != std::string::npos) { - path = path.substr(pos + 1, std::string::npos); + if (has_exception) { + LOG_INFO << "Failed to initialize CUDA shared memory pool in Python stub: " + << error_string; + cuda_pool_message_ptr->has_error = true; + cuda_pool_message_ptr->is_error_set = false; + + LOG_IF_EXCEPTION( + error_string_shm = PbString::Create(shm_pool_, error_string)); + if (error_string_shm != nullptr) { + cuda_pool_message_ptr->is_error_set = true; + cuda_pool_message_ptr->error = error_string_shm->ShmHandle(); } - std::stringstream ss; - struct timeval tv; - gettimeofday(&tv, NULL); - struct tm tm_time; - gmtime_r(((time_t*)&(tv.tv_sec)), &tm_time); - ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2) - << (tm_time.tm_mon + 1) << std::setw(2) << tm_time.tm_mday << " " - << std::setw(2) << tm_time.tm_hour << ':' << std::setw(2) - << tm_time.tm_min << ':' << std::setw(2) << tm_time.tm_sec << "." - << std::setw(6) << tv.tv_usec << ' ' << static_cast(getpid()) - << ' ' << path << ':' << lineno << "] "; - std::cerr << ss.str() << " " << message << std::endl; - } else { - // Ensure we do not create a stub instance before it has initialized - std::unique_ptr& stub = Stub::GetOrCreateInstance(); - std::unique_ptr log_msg(new PbLog(filename, lineno, message, level)); - stub->EnqueueLogRequest(log_msg); } -} - -void -Logger::LogInfo(const std::string& message) -{ - Logger::Log(message, LogLevel::INFO); -} -void -Logger::LogWarn(const std::string& message) -{ - Logger::Log(message, LogLevel::WARNING); + { + bi::scoped_lock lock{ + *(ipc_message->ResponseMutex())}; + cuda_pool_message_ptr->waiting_on_stub = true; + ipc_message->ResponseCondition()->notify_all(); + while (cuda_pool_message_ptr->waiting_on_stub) { + ipc_message->ResponseCondition()->wait(lock); + } + } +#endif } void -Logger::LogError(const std::string& message) +Stub::ProcessBLSResponseDecoupled(std::unique_ptr& ipc_message) { - Logger::Log(message, LogLevel::ERROR); -} + ResponseBatch* response_batch = nullptr; + bi::managed_external_buffer::handle_t* response_handle = nullptr; + std::unique_ptr infer_response; + bool responses_is_set = false; + PythonBackendException pb_exception(std::string{}); -void -Logger::LogVerbose(const std::string& message) -{ - Logger::Log(message, LogLevel::VERBOSE); -} + try { + AllocatedSharedMemory response_batch_shm = + shm_pool_->Load(ipc_message->Args()); + response_batch = + reinterpret_cast(response_batch_shm.data_.get()); + response_handle = reinterpret_cast( + response_batch_shm.data_.get() + sizeof(ResponseBatch)); + responses_is_set = true; + + if (response_batch->has_error) { + if (response_batch->is_error_set) { + std::unique_ptr pb_string = + PbString::LoadFromSharedMemory(shm_pool_, response_batch->error); + infer_response = std::make_unique( + std::vector>{}, + std::make_shared(pb_string->String())); + } else { + infer_response = std::make_unique( + std::vector>{}, + std::make_shared( + "An error occurred while performing BLS request.")); + } + } -const std::string -Logger::LeadingLogChar(const LogLevel& level) -{ - switch (level) { - case LogLevel::WARNING: - return "W"; - case LogLevel::ERROR: - return "E"; - case LogLevel::INFO: - case LogLevel::VERBOSE: - default: - return "I"; + if (responses_is_set) { + infer_response = InferResponse::LoadFromSharedMemory( + shm_pool_, *response_handle, true /* open cuda handle */); + + for (auto& output_tensor : infer_response->OutputTensors()) { + if (!output_tensor->IsCPU()) { + uint64_t memory_release_id = + output_tensor->Memory()->MemoryReleaseId(); + output_tensor->Memory()->SetMemoryReleaseCallback( + [this, memory_release_id]() { + this->MemoryManagerQueue()->Push(memory_release_id); + }); + } + } + } else { + infer_response = std::make_unique( + std::vector>{}, + std::make_shared( + "An error occurred while performing BLS request.")); + } + } + catch (const PythonBackendException& pb_exception) { + infer_response = std::make_unique( + std::vector>{}, + std::make_shared(pb_exception.what())); } -} -void -Logger::SetBackendLoggingActive(bool status) -{ - backend_logging_active_ = status; -} + { + std::lock_guard lock(response_iterator_map_mu_); + if (response_iterator_map_.find(infer_response->Id()) != + response_iterator_map_.end()) { + response_iterator_map_[infer_response->Id()]->EnqueueResponse( + std::move(infer_response)); + } else { + auto response_iterator = + std::make_shared(std::move(infer_response)); + response_iterator_map_.insert( + std::pair>( + response_iterator->Id(), response_iterator)); + } + } -bool -Logger::BackendLoggingActive() -{ - return backend_logging_active_; + { + bi::scoped_lock lock{ + *(ipc_message->ResponseMutex())}; + response_batch->waiting_on_stub = true; + ipc_message->ResponseCondition()->notify_all(); + } } PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) { - py::class_>(module, "TritonError") - .def(py::init()) - .def("message", &PbError::Message); + py::class_> triton_error( + module, "TritonError"); + py::enum_(triton_error, "__ErrorCode") + .value("UNKNOWN", TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN) + .value("INTERNAL", TRITONSERVER_Error_Code::TRITONSERVER_ERROR_INTERNAL) + .value("NOT_FOUND", TRITONSERVER_Error_Code::TRITONSERVER_ERROR_NOT_FOUND) + .value( + "INVALID_ARG", + TRITONSERVER_Error_Code::TRITONSERVER_ERROR_INVALID_ARG) + .value( + "UNAVAILABLE", + TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNAVAILABLE) + .value( + "UNSUPPORTED", + TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNSUPPORTED) + .value( + "ALREADY_EXISTS", + TRITONSERVER_Error_Code::TRITONSERVER_ERROR_ALREADY_EXISTS) + .value("CANCELLED", TRITONSERVER_Error_Code::TRITONSERVER_ERROR_CANCELLED) + .export_values(); + triton_error.def_property_readonly_static( + "UNKNOWN", + [](py::object /* self */) { return TRITONSERVER_ERROR_UNKNOWN; }); + triton_error.def_property_readonly_static( + "INTERNAL", + [](py::object /* self */) { return TRITONSERVER_ERROR_INTERNAL; }); + triton_error.def_property_readonly_static( + "NOT_FOUND", + [](py::object /* self */) { return TRITONSERVER_ERROR_NOT_FOUND; }); + triton_error.def_property_readonly_static( + "INVALID_ARG", + [](py::object /* self */) { return TRITONSERVER_ERROR_INVALID_ARG; }); + triton_error.def_property_readonly_static( + "UNAVAILABLE", + [](py::object /* self */) { return TRITONSERVER_ERROR_UNAVAILABLE; }); + triton_error.def_property_readonly_static( + "UNSUPPORTED", + [](py::object /* self */) { return TRITONSERVER_ERROR_UNSUPPORTED; }); + triton_error.def_property_readonly_static( + "ALREADY_EXISTS", + [](py::object /* self */) { return TRITONSERVER_ERROR_ALREADY_EXISTS; }); + triton_error.def_property_readonly_static( + "CANCELLED", + [](py::object /* self */) { return TRITONSERVER_ERROR_CANCELLED; }); + triton_error.def( + py::init(), + py::arg("message").none(false), + py::arg("code").none(false) = TRITONSERVER_ERROR_INTERNAL); + triton_error.def("code", &PbError::Code); + triton_error.def("message", &PbError::Message); py::class_>( module, "PreferredMemory") @@ -1358,31 +1634,58 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::arg("preferred_device_id").none(false) = 0); py::enum_(module, "MemoryType") - .value("TRITONSERVER_MEMORY_GPU", PreferredMemory::MemoryType::GPU) - .value("TRITONSERVER_MEMORY_CPU", PreferredMemory::MemoryType::CPU) + .value("TRITONSERVER_MEMORY_GPU", PreferredMemory::MemoryType::kGPU) + .value("TRITONSERVER_MEMORY_CPU", PreferredMemory::MemoryType::kCPU) .export_values(); + py::class_>( + module, "InferenceTrace") + .def("get_context", [](InferenceTrace& self) -> py::object { + auto context = self.Context(); + if (context != "") { + return py::str(context); + } + return py::none(); + }); + py::class_>( module, "InferenceRequest") .def( - py::init([](const std::string& request_id, uint64_t correlation_id, - const std::vector>& inputs, - const std::vector& requested_output_names, - const std::string& model_name, - const int64_t model_version, const uint32_t flags, - const int32_t timeout, - const PreferredMemory& preferred_memory) { - std::set requested_outputs; - for (auto& requested_output_name : requested_output_names) { - requested_outputs.emplace(requested_output_name); - } - // FIXME: InferenceRequest parameters are not supported in BLS now. - return std::make_shared( - request_id, correlation_id, inputs, requested_outputs, - model_name, model_version, "" /*parameters*/, flags, timeout, - 0 /*response_factory_address*/, 0 /*request_address*/, - preferred_memory); - }), + py::init( + [](const std::string& request_id, + const py::object& correlation_id, + const std::vector>& inputs, + const std::vector& requested_output_names, + const std::string& model_name, const int64_t model_version, + const uint32_t flags, const uint64_t timeout, + const PreferredMemory& preferred_memory, + const InferenceTrace& trace, const py::object& parameters_) { + py::dict parameters = + PyDefaultArgumentToMutableType(parameters_); + std::set requested_outputs; + for (auto& requested_output_name : requested_output_names) { + requested_outputs.emplace(requested_output_name); + } + std::string parameters_str = PyParametersToJSON(parameters); + + CorrelationId correlation_id_obj; + if (py::isinstance(correlation_id)) { + correlation_id_obj = + CorrelationId(py::cast(correlation_id)); + } else if (py::isinstance(correlation_id)) { + correlation_id_obj = + CorrelationId(py::cast(correlation_id)); + } else { + throw PythonBackendException( + "Correlation ID must be integer or string"); + } + + return std::make_shared( + request_id, correlation_id_obj, inputs, requested_outputs, + model_name, model_version, parameters_str, flags, timeout, + 0 /*response_factory_address*/, 0 /*request_address*/, + preferred_memory, trace); + }), py::arg("request_id").none(false) = "", py::arg("correlation_id").none(false) = 0, py::arg("inputs").none(false), @@ -1391,16 +1694,28 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::arg("model_version").none(false) = -1, py::arg("flags").none(false) = 0, py::arg("timeout").none(false) = 0, py::arg("preferred_memory").none(false) = - PreferredMemory(PreferredMemory::DEFAULT, 0)) + PreferredMemory(PreferredMemory::kDefault, 0), + py::arg("trace").none(false) = InferenceTrace(), + py::arg("parameters").none(true) = py::none()) .def( "inputs", &InferRequest::Inputs, py::return_value_policy::reference_internal) .def("request_id", &InferRequest::RequestId) - .def("correlation_id", &InferRequest::CorrelationId) + .def( + "correlation_id", + [](InferRequest& self) -> py::object { + CorrelationId correlation_id = self.GetCorrelationId(); + if (correlation_id.Type() == CorrelationIdDataType::STRING) { + return py::cast(correlation_id.StringValue()); + } else { + return py::cast(correlation_id.UnsignedIntValue()); + } + }) .def("flags", &InferRequest::Flags) .def("set_flags", &InferRequest::SetFlags) .def("timeout", &InferRequest::Timeout) .def("parameters", &InferRequest::Parameters) + .def("trace", &InferRequest::GetTrace) .def( "exec", [](std::shared_ptr& infer_request, @@ -1424,11 +1739,6 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) [](std::shared_ptr& infer_request, const bool decoupled) { std::unique_ptr& stub = Stub::GetOrCreateInstance(); - if (stub->IsDecoupled()) { - throw PythonBackendException( - "Async BLS request execution is not support in the decoupled " - "API."); - } py::object loop = py::module_::import("asyncio").attr("get_running_loop")(); py::cpp_function callback = [&stub, infer_request, decoupled]() { @@ -1452,7 +1762,10 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) .def( "requested_output_names", &InferRequest::RequestedOutputNames, py::return_value_policy::reference_internal) - .def("get_response_sender", &InferRequest::GetResponseSender); + .def("get_response_sender", &InferRequest::GetResponseSender) + .def("is_cancelled", &InferRequest::IsCancelled) + .def("set_release_flags", &InferRequest::SetReleaseFlags), + py::arg("flags").none(false); py::class_>(module, "Tensor") .def(py::init(&PbTensor::FromNumpy)) @@ -1475,39 +1788,55 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::class_>( module, "InferenceResponse") .def( - py::init< - const std::vector>&, - std::shared_ptr>(), + py::init( + [](const std::vector>& output_tensors, + const std::shared_ptr& error, + const py::object& parameters_) { + py::dict parameters = + PyDefaultArgumentToMutableType(parameters_); + std::string parameters_str = PyParametersToJSON(parameters); + return std::make_shared( + output_tensors, error, parameters_str /* parameters */); + }), py::arg("output_tensors") = py::list(), - py::arg("error") = static_cast>(nullptr)) + py::arg("error") = static_cast>(nullptr), + py::arg("parameters") = py::none()) .def( "output_tensors", &InferResponse::OutputTensors, py::return_value_policy::reference) .def("has_error", &InferResponse::HasError) - .def("error", &InferResponse::Error); + .def("error", &InferResponse::Error) + .def("parameters", &InferResponse::Parameters); py::class_>( module, "InferenceResponseSender") .def( "send", &ResponseSender::Send, py::arg("response") = nullptr, - py::arg("flags") = 0); + py::arg("flags") = 0) + .def("is_cancelled", &ResponseSender::IsCancelled); py::class_>( module, "ResponseIterator") .def(py::init&>()) - .def("__iter__", &ResponseIterator::Iter, py::keep_alive<0, 1>()) - .def("__next__", &ResponseIterator::Next); + .def( + "__iter__", + [](ResponseIterator& it) -> ResponseIterator& { + it.Iter(); + return it; + }) + .def("__next__", &ResponseIterator::Next) + .def("cancel", &ResponseIterator::Cancel); py::class_ logger(module, "Logger"); py::enum_(logger, "LogLevel") - .value("INFO", LogLevel::INFO) - .value("WARNING", LogLevel::WARNING) - .value("ERROR", LogLevel::ERROR) - .value("VERBOSE", LogLevel::VERBOSE) + .value("INFO", LogLevel::kInfo) + .value("WARNING", LogLevel::kWarning) + .value("ERROR", LogLevel::kError) + .value("VERBOSE", LogLevel::kVerbose) .export_values(); logger.def_static( "log", py::overload_cast(&Logger::Log), - py::arg("message"), py::arg("level") = LogLevel::INFO); + py::arg("message"), py::arg("level") = LogLevel::kInfo); logger.def_static("log_info", &Logger::LogInfo, py::arg("message")); logger.def_static("log_warn", &Logger::LogWarn, py::arg("message")); logger.def_static("log_error", &Logger::LogError, py::arg("message")); @@ -1516,11 +1845,13 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::class_>(module, "Metric") .def("increment", &Metric::SendIncrementRequest) .def("set", &Metric::SendSetValueRequest) + .def("observe", &Metric::SendObserveRequest) .def("value", &Metric::SendGetValueRequest); py::enum_(module, "MetricKind") - .value("COUNTER", MetricKind::COUNTER) - .value("GAUGE", MetricKind::GAUGE) + .value("COUNTER", MetricKind::kCounter) + .value("GAUGE", MetricKind::kGauge) + .value("HISTOGRAM", MetricKind::kHistogram) .export_values(); py::class_>( @@ -1531,9 +1862,11 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) py::arg("kind").none(false)) .def( "Metric", &MetricFamily::CreateMetric, - py::arg("labels").none(true) = py::none()); - module.attr("MetricFamily").attr("COUNTER") = MetricKind::COUNTER; - module.attr("MetricFamily").attr("GAUGE") = MetricKind::GAUGE; + py::arg("labels").none(true) = py::none(), + py::arg("buckets").none(true) = py::none()); + module.attr("MetricFamily").attr("COUNTER") = MetricKind::kCounter; + module.attr("MetricFamily").attr("GAUGE") = MetricKind::kGauge; + module.attr("MetricFamily").attr("HISTOGRAM") = MetricKind::kHistogram; module.def( "load_model", &LoadModel, py::arg("model_name").none(false), @@ -1546,6 +1879,12 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) "is_model_ready", &IsModelReady, py::arg("model_name").none(false), py::arg("model_version").none(false) = ""); + // This function is not part of the public API for Python backend. This is + // only used for internal callbacks. + module.def( + "async_event_future_done_callback", &AsyncEventFutureDoneCallback, + py::arg("py_future").none(false)); + // This class is not part of the public API for Python backend. This is only // used for internal testing purposes. py::class_(module, "SharedMemory") @@ -1558,64 +1897,38 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) void ModelContext::Init( - const std::string& model_path, const std::string& platform, + const std::string& model_path, const std::string& runtime_modeldir, const std::string& triton_install_path, const std::string& model_version) { - bool python_model_found = false; - std::string platform_model_path; - - if (platform != "NONE") { - platform_model_path = - triton_install_path + "/platform_handlers/" + platform + "/model.py"; - // Check if model file exists in the path. - struct stat buffer; - if (stat(platform_model_path.c_str(), &buffer) == 0) { - // Use the Platform model for serving the model. - python_model_found = true; - type_ = ModelType::PLATFORM; - python_model_path_ = platform_model_path; - // Trimming the model name from the model path, the platform model - // will populate the expected default model file name into model_path_. - model_dir_ = model_path.substr(0, model_path.find_last_of("\\/")); - } else { - LOG_WARN << "Unable to find model(handler) \'" << platform_model_path - << "\' for platform field \'" << platform << "\'"; - } - } - - if (!python_model_found) { + const char os_slash = std::filesystem::path::preferred_separator; + type_ = ModelType::kDefault; + if (runtime_modeldir != "DEFAULT") { + // For python based backends, existence of `model.py` in the corresponding + // backend folder happens on the core side, so we can omit this check here. + python_model_path_ = runtime_modeldir + os_slash + "model.py"; + type_ = ModelType::kBackend; + } else { python_model_path_ = model_path; // Check if model file exists in this path. struct stat buffer; - if (stat(python_model_path_.c_str(), &buffer) == 0) { - python_model_found = true; - type_ = ModelType::DEFAULT; - } - // Initializing here for consistency with platform model case. - model_dir_ = model_path.substr(0, model_path.find_last_of("\\/")); - } - - if (!python_model_found) { - if (platform != "NONE") { - throw PythonBackendException( - ("Python model file not found in neither \'" + platform_model_path + - "\' nor \'" + model_path + "\'")); - } else { + if (stat(python_model_path_.c_str(), &buffer) != 0) { throw PythonBackendException( ("Python model file not found in \'" + model_path + "\'")); } } + model_dir_ = model_path.substr(0, model_path.find_last_of(os_slash)); python_backend_folder_ = triton_install_path; model_version_ = model_version; - platform_ = platform; + runtime_modeldir_ = runtime_modeldir; } void ModelContext::StubSetup(py::module& sys) { + const char os_slash = std::filesystem::path::preferred_separator; std::string model_name = - python_model_path_.substr(python_model_path_.find_last_of("/") + 1); + python_model_path_.substr(python_model_path_.find_last_of(os_slash) + 1); // Model name without the .py extension auto dotpy_pos = model_name.find_last_of(".py"); @@ -1628,25 +1941,42 @@ ModelContext::StubSetup(py::module& sys) // returned by 'find_last_of'. Need to manually adjust the position. std::string model_name_trimmed = model_name.substr(0, dotpy_pos - 2); - if (type_ == ModelType::DEFAULT) { + if (type_ == ModelType::kDefault) { std::string model_path_parent = - python_model_path_.substr(0, python_model_path_.find_last_of("/")); + python_model_path_.substr(0, python_model_path_.find_last_of(os_slash)); std::string model_path_parent_parent = - model_path_parent.substr(0, model_path_parent.find_last_of("/")); + model_path_parent.substr(0, model_path_parent.find_last_of(os_slash)); sys.attr("path").attr("append")(model_path_parent); sys.attr("path").attr("append")(model_path_parent_parent); sys.attr("path").attr("append")(python_backend_folder_); sys = py::module_::import( (std::string(model_version_) + "." + model_name_trimmed).c_str()); } else { - std::string platform_model_dir( - python_backend_folder_ + "/platform_handlers/" + platform_ + "/"); - sys.attr("path").attr("append")(platform_model_dir); + std::string model_path_parent = + python_model_path_.substr(0, python_model_path_.find_last_of(os_slash)); + std::string backend_model_dir(model_path_parent); + sys.attr("path").attr("append")(backend_model_dir); sys.attr("path").attr("append")(python_backend_folder_); sys = py::module_::import(model_name_trimmed.c_str()); } } +#ifdef _WIN32 +bool +ParentProcessActive(DWORD parent_id) +{ + HANDLE parent = OpenProcess(PROCESS_ALL_ACCESS, FALSE, parent_id); + DWORD exit_code; + GetExitCodeProcess(parent, &exit_code); + return (exit_code == STILL_ACTIVE); +} +#else +bool +ParentProcessActive(pid_t parent_id) +{ + return (kill(parent_id, 0) == 0); +} +#endif extern "C" { @@ -1671,8 +2001,9 @@ main(int argc, char** argv) // Find the package name from model path. size_t prev = 0, pos = 0; + const char os_slash = std::filesystem::path::preferred_separator; do { - pos = model_path.find("/", prev); + pos = model_path.find(os_slash, prev); if (pos == std::string::npos) pos = model_path.length(); std::string token = model_path.substr(prev, pos - prev); @@ -1690,25 +2021,29 @@ main(int argc, char** argv) int64_t shm_growth_size = std::stol(argv[4]); std::string triton_install_path = argv[6]; std::string name = argv[8]; - std::string platform = argv[9]; + std::string runtime_modeldir = argv[9]; std::unique_ptr& stub = Stub::GetOrCreateInstance(); try { stub->Instantiate( shm_growth_size, shm_default_size, shm_region_name, model_path, model_version, argv[6] /* triton install path */, - std::stoi(argv[7]) /* IPCControl handle */, name, platform); + std::stoi(argv[7]) /* IPCControl handle */, name, runtime_modeldir); } catch (const PythonBackendException& pb_exception) { LOG_INFO << "Failed to preinitialize Python stub: " << pb_exception.what(); logger.reset(); + stub.reset(); exit(1); } // Start the Python Interpreter py::scoped_interpreter guard{}; +#ifdef _WIN32 + DWORD parent_pid = (DWORD)std::stoul(argv[5]); +#else pid_t parent_pid = std::stoi(argv[5]); - +#endif std::atomic background_thread_running = {true}; std::thread background_thread = std::thread([&parent_pid, &background_thread_running, &stub, &logger] { @@ -1727,7 +2062,7 @@ main(int argc, char** argv) stub->UpdateHealth(); - if (kill(parent_pid, 0) != 0) { + if (!ParentProcessActive(parent_pid)) { // When unhealthy, we should stop attempting to send // messages to the backend ASAP. if (stub->StubToParentServiceActive()) { diff --git a/src/pb_stub.h b/src/pb_stub.h index 6d047d29..942ecd98 100644 --- a/src/pb_stub.h +++ b/src/pb_stub.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -30,28 +30,15 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - #include "infer_request.h" #include "infer_response.h" #include "ipc_message.h" #include "message_queue.h" #include "metric.h" #include "metric_family.h" +#include "pb_cancel.h" #include "pb_log.h" #include "pb_response_iterator.h" -#include "pb_utils.h" namespace bi = boost::interprocess; @@ -64,104 +51,6 @@ using cudaStream_t = void*; namespace triton { namespace backend { namespace python { -#define LOG_IF_EXCEPTION(X) \ - do { \ - try { \ - (X); \ - } \ - catch (const PythonBackendException& pb_exception) { \ - LOG_INFO << pb_exception.what(); \ - } \ - } while (false) - -#define LOG_EXCEPTION(E) \ - do { \ - LOG_INFO << E.what(); \ - } while (false) - -/// Macros that use current filename and line number. -#define LOG_INFO LOG_FL(__FILE__, __LINE__, LogLevel::INFO) -#define LOG_WARN LOG_FL(__FILE__, __LINE__, LogLevel::WARNING) -#define LOG_ERROR LOG_FL(__FILE__, __LINE__, LogLevel::ERROR) -#define LOG_VERBOSE LOG_FL(__FILE__, __LINE__, LogLevel::VERBOSE) - -class Logger { - public: - Logger() { backend_logging_active_ = false; }; - ~Logger() { log_instance_.reset(); }; - /// Python client log function - static void Log(const std::string& message, LogLevel level = LogLevel::INFO); - - /// Python client log info function - static void LogInfo(const std::string& message); - - /// Python client warning function - static void LogWarn(const std::string& message); - - /// Python client log error function - static void LogError(const std::string& message); - - /// Python client log verbose function - static void LogVerbose(const std::string& message); - - /// Internal log function - void Log( - const std::string& filename, uint32_t lineno, LogLevel level, - const std::string& message); - - /// Log format helper function - const std::string LeadingLogChar(const LogLevel& level); - - /// Set PYBE Logging Status - void SetBackendLoggingActive(bool status); - - /// Get PYBE Logging Status - bool BackendLoggingActive(); - - /// Singleton Getter Function - static std::unique_ptr& GetOrCreateInstance(); - - DISALLOW_COPY_AND_ASSIGN(Logger); - - /// Flush the log. - void Flush() { std::cerr << std::flush; } - - private: - static std::unique_ptr log_instance_; - bool backend_logging_active_; -}; - -class LogMessage { - public: - /// Create a log message, stripping the path down to the filename only - LogMessage(const char* file, int line, LogLevel level) : level_(level) - { - std::string path(file); - size_t pos = path.rfind('/'); - if (pos != std::string::npos) { - path = path.substr(pos + 1, std::string::npos); - } - file_ = path; - line_ = static_cast(line); - } - /// Log message to console or send to backend (see Logger::Log for details) - ~LogMessage() - { - Logger::GetOrCreateInstance()->Log(file_, line_, level_, stream_.str()); - } - - std::stringstream& stream() { return stream_; } - - private: - std::stringstream stream_; - std::string file_; - uint32_t line_; - LogLevel level_; -}; - -#define LOG_FL(FN, LN, LVL) LogMessage((char*)(FN), LN, LVL).stream() - - class ModelContext { public: // Scans and establishes path for serving the python model. @@ -179,9 +68,15 @@ class ModelContext { std::string model_dir_; std::string model_version_; std::string python_backend_folder_; - std::string platform_; - - enum ModelType { DEFAULT, PLATFORM }; + std::string runtime_modeldir_; + + // Triton supports python-based backends, + // i.e. backends that provide common `model.py`, that can be re-used + // between different models. `ModelType` helps to differentiate + // between models running with c++ python backend (ModelType::kDefault) + // and models running with python-based backend (ModelType::kBackend) + // at the time of ModelContext::StubSetup to properly set up paths. + enum ModelType { kDefault, kBackend }; ModelType type_; }; @@ -209,7 +104,8 @@ class Stub { const std::string& shm_region_name, const std::string& model_path, const std::string& model_version, const std::string& triton_install_path, bi::managed_external_buffer::handle_t ipc_control_handle, - const std::string& model_instance_name, const std::string& platform); + const std::string& model_instance_name, + const std::string& runtime_modeldir); /// Get the health of the stub process. bool& Health(); @@ -255,7 +151,17 @@ class Stub { /// Execute a batch of requests. void ProcessRequests(RequestBatch* request_batch_shm_ptr); - void ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr); + void ProcessReturnedResponses( + py::list py_requests, py::object py_responses_obj, + std::optional>& response_batch); + + void ProcessResponse(InferResponse* response); + + py::object GetAsyncEventLoop(); + + py::object RunCoroutine(py::object coroutine, bool in_background); + + void BackgroundFutureDone(const py::object& py_future); /// Get the memory manager message queue std::unique_ptr>& MemoryManagerQueue(); @@ -263,8 +169,10 @@ class Stub { /// Get the shared memory pool std::unique_ptr& ShmPool() { return shm_pool_; } - void ProcessResponse(InferResponse* response); + void ProcessBLSResponseDecoupled(std::unique_ptr& ipc_message); + void LoadGPUBuffers(std::unique_ptr& ipc_message); + bool IsDecoupled(); ~Stub(); @@ -303,10 +211,28 @@ class Stub { std::shared_ptr infer_response); /// Send the id to the python backend for object cleanup - void SendCleanupId(std::unique_ptr& utils_msg_payload); + void SendCleanupId( + std::unique_ptr& utils_msg_payload, + const PYTHONSTUB_CommandType& command_type); - /// Add cleanup id to queue - void EnqueueCleanupId(void* id); + /// Add cleanup id to queue. This is used for cleaning up the infer_payload + /// and the response factory for BLS decoupled response. + void EnqueueCleanupId(void* id, const PYTHONSTUB_CommandType& command_type); + + /// Send the id to the python backend for request address retrieval and + /// cancellation + void SendCancelBLSRequest( + std::unique_ptr& utils_msg_payload); + + /// Add infer payload id to queue. This is used for retrieving the request + /// address from the infer_payload + void EnqueueCancelBLSRequest(PbBLSCancel* pb_bls_cancel); + + /// Add request cancellation query to queue + void EnqueueIsCancelled(PbCancel* pb_cancel); + + /// Send request cancellation query to python backend + void SendIsCancelled(std::unique_ptr& utils_msg_payload); /// Is the stub initialized bool IsInitialized(); @@ -336,6 +262,9 @@ class Stub { /// for provided device cudaStream_t GetProxyStream(const int& device_id); + /// Get the CUDA memory pool address from the parent process. + void GetCUDAMemoryPoolAddress(std::unique_ptr& ipc_message); + private: bi::interprocess_mutex* stub_mutex_; bi::interprocess_condition* stub_cond_; @@ -349,6 +278,8 @@ class Stub { py::object model_instance_; py::object deserialize_bytes_; py::object serialize_bytes_; + py::object async_event_loop_; + py::object background_futures_; std::unique_ptr> stub_message_queue_; std::unique_ptr> diff --git a/src/pb_stub_log.cc b/src/pb_stub_log.cc new file mode 100644 index 00000000..d0b1ff97 --- /dev/null +++ b/src/pb_stub_log.cc @@ -0,0 +1,170 @@ +// Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "pb_stub_log.h" + +#include + +#include "pb_stub.h" + + +namespace py = pybind11; + +namespace triton { namespace backend { namespace python { + +std::unique_ptr Logger::log_instance_; + +std::unique_ptr& +Logger::GetOrCreateInstance() +{ + if (Logger::log_instance_.get() == nullptr) { + Logger::log_instance_ = std::make_unique(); + } + + return Logger::log_instance_; +} + +// Bound function, called from the python client +void +Logger::Log(const std::string& message, LogLevel level) +{ + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + py::object frame = py::module_::import("inspect").attr("currentframe"); + py::object caller_frame = frame(); + py::object info = py::module_::import("inspect").attr("getframeinfo"); + py::object caller_info = info(caller_frame); + py::object filename_python = caller_info.attr("filename"); + std::string filename = filename_python.cast(); + py::object lineno = caller_info.attr("lineno"); + uint32_t line = lineno.cast(); + + if (!stub->StubToParentServiceActive()) { + Logger::GetOrCreateInstance()->Log(filename, line, level, message); + } else { + std::unique_ptr log_msg(new PbLog(filename, line, message, level)); + stub->EnqueueLogRequest(log_msg); + } +} + +// Called internally (.e.g. LOG_ERROR << "Error"; ) +void +Logger::Log( + const std::string& filename, uint32_t lineno, LogLevel level, + const std::string& message) +{ + // If the log monitor service is not active yet, format + // and pass messages to cerr + if (!BackendLoggingActive()) { + std::string path(filename); + size_t pos = path.rfind(std::filesystem::path::preferred_separator); + if (pos != std::string::npos) { + path = path.substr(pos + 1, std::string::npos); + } +#ifdef _WIN32 + std::stringstream ss; + SYSTEMTIME system_time; + GetSystemTime(&system_time); + ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2) + << system_time.wMonth << std::setw(2) << system_time.wDay << ' ' + << std::setw(2) << system_time.wHour << ':' << std::setw(2) + << system_time.wMinute << ':' << std::setw(2) << system_time.wSecond + << '.' << std::setw(6) << system_time.wMilliseconds * 1000 << ' ' + << static_cast(GetCurrentProcessId()) << ' ' << path << ':' + << lineno << "] "; +#else + std::stringstream ss; + struct timeval tv; + gettimeofday(&tv, NULL); + struct tm tm_time; + gmtime_r(((time_t*)&(tv.tv_sec)), &tm_time); + ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2) + << (tm_time.tm_mon + 1) << std::setw(2) << tm_time.tm_mday << " " + << std::setw(2) << tm_time.tm_hour << ':' << std::setw(2) + << tm_time.tm_min << ':' << std::setw(2) << tm_time.tm_sec << "." + << std::setw(6) << tv.tv_usec << ' ' << static_cast(getpid()) + << ' ' << path << ':' << lineno << "] "; + std::cerr << ss.str() << " " << message << std::endl; +#endif + } else { + // Ensure we do not create a stub instance before it has initialized + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + std::unique_ptr log_msg(new PbLog(filename, lineno, message, level)); + stub->EnqueueLogRequest(log_msg); + } +} + +void +Logger::LogInfo(const std::string& message) +{ + Logger::Log(message, LogLevel::kInfo); +} + +void +Logger::LogWarn(const std::string& message) +{ + Logger::Log(message, LogLevel::kWarning); +} + +void +Logger::LogError(const std::string& message) +{ + Logger::Log(message, LogLevel::kError); +} + +void +Logger::LogVerbose(const std::string& message) +{ + Logger::Log(message, LogLevel::kVerbose); +} + +const std::string +Logger::LeadingLogChar(const LogLevel& level) +{ + switch (level) { + case LogLevel::kWarning: + return "W"; + case LogLevel::kError: + return "E"; + case LogLevel::kInfo: + case LogLevel::kVerbose: + default: + return "I"; + } +} + +void +Logger::SetBackendLoggingActive(bool status) +{ + backend_logging_active_ = status; +} + +bool +Logger::BackendLoggingActive() +{ + return backend_logging_active_; +} + +}}} // namespace triton::backend::python diff --git a/src/pb_stub_log.h b/src/pb_stub_log.h new file mode 100644 index 00000000..df67eba8 --- /dev/null +++ b/src/pb_stub_log.h @@ -0,0 +1,134 @@ +// Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include + +#include "pb_utils.h" + +namespace triton { namespace backend { namespace python { + +#define LOG_IF_EXCEPTION(X) \ + do { \ + try { \ + (X); \ + } \ + catch (const PythonBackendException& pb_exception) { \ + LOG_INFO << pb_exception.what(); \ + } \ + } while (false) + +#define LOG_EXCEPTION(E) \ + do { \ + LOG_INFO << E.what(); \ + } while (false) + +/// Macros that use current filename and line number. +#define LOG_INFO LOG_FL(__FILE__, __LINE__, LogLevel::kInfo) +#define LOG_WARN LOG_FL(__FILE__, __LINE__, LogLevel::kWarning) +#define LOG_ERROR LOG_FL(__FILE__, __LINE__, LogLevel::kError) +#define LOG_VERBOSE LOG_FL(__FILE__, __LINE__, LogLevel::kVerbose) + +class Logger { + public: + Logger() { backend_logging_active_ = false; }; + ~Logger() { log_instance_.reset(); }; + /// Python client log function + static void Log(const std::string& message, LogLevel level = LogLevel::kInfo); + + /// Python client log info function + static void LogInfo(const std::string& message); + + /// Python client warning function + static void LogWarn(const std::string& message); + + /// Python client log error function + static void LogError(const std::string& message); + + /// Python client log verbose function + static void LogVerbose(const std::string& message); + + /// Internal log function + void Log( + const std::string& filename, uint32_t lineno, LogLevel level, + const std::string& message); + + /// Log format helper function + const std::string LeadingLogChar(const LogLevel& level); + + /// Set PYBE Logging Status + void SetBackendLoggingActive(bool status); + + /// Get PYBE Logging Status + bool BackendLoggingActive(); + + /// Singleton Getter Function + static std::unique_ptr& GetOrCreateInstance(); + + DISALLOW_COPY_AND_ASSIGN(Logger); + + /// Flush the log. + void Flush() { std::cerr << std::flush; } + + private: + static std::unique_ptr log_instance_; + bool backend_logging_active_; +}; + +class LogMessage { + public: + /// Create a log message, stripping the path down to the filename only + LogMessage(const char* file, int line, LogLevel level) : level_(level) + { + std::string path(file); + const char os_slash = std::filesystem::path::preferred_separator; + size_t pos = path.rfind(os_slash); + if (pos != std::string::npos) { + path = path.substr(pos + 1, std::string::npos); + } + file_ = path; + line_ = static_cast(line); + } + /// Log message to console or send to backend (see Logger::Log for details) + ~LogMessage() + { + Logger::GetOrCreateInstance()->Log(file_, line_, level_, stream_.str()); + } + + std::stringstream& stream() { return stream_; } + + private: + std::stringstream stream_; + std::string file_; + uint32_t line_; + LogLevel level_; +}; + +#define LOG_FL(FN, LN, LVL) LogMessage((char*)(FN), LN, LVL).stream() + +}}} // namespace triton::backend::python diff --git a/src/pb_stub_utils.cc b/src/pb_stub_utils.cc index c9ffd661..9e05feae 100644 --- a/src/pb_stub_utils.cc +++ b/src/pb_stub_utils.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -168,6 +168,8 @@ triton_to_pybind_dtype(TRITONSERVER_DataType data_type) dtype_numpy = py::dtype(py::format_descriptor::format()); break; case TRITONSERVER_TYPE_BF16: + // NOTE: Currently skipping this call via `if (BF16)` check, but may + // want to better handle this or set some default/invalid dtype. throw PythonBackendException("TYPE_BF16 not currently supported."); case TRITONSERVER_TYPE_INVALID: throw PythonBackendException("Dtype is invalid."); @@ -240,6 +242,10 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype) case TRITONSERVER_TYPE_BYTES: throw PythonBackendException( "TYPE_BYTES tensors cannot be converted to DLPack."); + case TRITONSERVER_TYPE_BF16: + dl_code = DLDataTypeCode::kDLBfloat; + dt_size = 16; + break; default: throw PythonBackendException( @@ -301,6 +307,15 @@ dlpack_to_triton_type(const DLDataType& data_type) } } + if (data_type.code == DLDataTypeCode::kDLBfloat) { + if (data_type.bits != 16) { + throw PythonBackendException( + "Expected BF16 tensor to have 16 bits, but had: " + + std::to_string(data_type.bits)); + } + return TRITONSERVER_TYPE_BF16; + } + return TRITONSERVER_TYPE_INVALID; } }}} // namespace triton::backend::python diff --git a/src/pb_tensor.cc b/src/pb_tensor.cc index 4011faad..26e77586 100644 --- a/src/pb_tensor.cc +++ b/src/pb_tensor.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -35,10 +35,76 @@ namespace py = pybind11; #endif #include "pb_tensor.h" +// WAR for undefined ssize_t on Windows: https://stackoverflow.com/a/35368387 +#if defined(_MSC_VER) +#include +typedef SSIZE_T ssize_t; +#endif + +#include +#include +#include namespace triton { namespace backend { namespace python { #ifdef TRITON_PB_STUB +py::array +deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size) +{ + if (data_size == 0) { + py::module numpy = py::module::import("numpy"); + return numpy.attr("empty")(0, py::dtype("object")); + } + + // First pass: count the number of strings and calculate total size + size_t offset = 0; + size_t num_strings = 0; + size_t total_string_size = 0; + + while (offset < data_size) { + if (offset + 4 > data_size) { + throw PythonBackendException( + "Invalid bytes tensor data: incomplete length field"); + } + + // Read 4-byte length (little-endian) + uint32_t length = *reinterpret_cast(data + offset); + offset += 4; + + if (offset + length > data_size) { + throw PythonBackendException( + "Invalid bytes tensor data: string extends beyond buffer"); + } + + num_strings++; + total_string_size += length; + offset += length; + } + + // Create numpy array of objects using pybind11's numpy module + py::module numpy = py::module::import("numpy"); + py::array result = numpy.attr("empty")(num_strings, py::dtype("object")); + auto result_ptr = static_cast(result.request().ptr); + + // Second pass: extract strings + offset = 0; + size_t string_index = 0; + + while (offset < data_size) { + uint32_t length = *reinterpret_cast(data + offset); + offset += 4; + + // Create Python bytes object using pybind11 + py::bytes bytes_obj(reinterpret_cast(data + offset), length); + Py_INCREF(bytes_obj.ptr()); // Increment reference count + result_ptr[string_index] = bytes_obj.ptr(); + string_index++; + offset += length; + } + + return result; +} + PbTensor::PbTensor(const std::string& name, py::array& numpy_array) : name_(name) { @@ -147,19 +213,17 @@ PbTensor::PbTensor( #ifdef TRITON_PB_STUB if (memory_type_ == TRITONSERVER_MEMORY_CPU || memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) { - if (dtype != TRITONSERVER_TYPE_BYTES) { + if (dtype == TRITONSERVER_TYPE_BF16) { + // No native numpy representation for BF16. DLPack should be used instead. + numpy_array_ = py::none(); + } else if (dtype != TRITONSERVER_TYPE_BYTES) { py::object numpy_array = py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_); numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_)); } else { - py::object numpy_array = py::array( - triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size}, - (void*)memory_ptr_); - py::module triton_pb_utils = - py::module::import("triton_python_backend_utils"); - numpy_array_ = - triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array) - .attr("reshape")(dims); + py::object numpy_array = deserialize_bytes_tensor_cpp( + static_cast(memory_ptr_), byte_size_); + numpy_array_ = numpy_array.attr("reshape")(dims_); } } else { numpy_array_ = py::none(); @@ -226,6 +290,7 @@ delete_unused_dltensor(PyObject* dlp) } } + std::shared_ptr PbTensor::FromNumpy(const std::string& name, py::array& numpy_array) { @@ -433,12 +498,14 @@ PbTensor::FromDLPackCapsule( int64_t calculated_stride{1}; bool is_contiguous_c_order = true; for (size_t i = 1; i < dims.size(); i++) { - if (strides[ndim - i] != calculated_stride) { - is_contiguous_c_order = false; - break; - } + if (dims[ndim - i] != 1) { + if (strides[ndim - i] != calculated_stride) { + is_contiguous_c_order = false; + break; + } - calculated_stride *= dims[ndim - i]; + calculated_stride *= dims[ndim - i]; + } } if (!is_contiguous_c_order) { @@ -493,6 +560,14 @@ PbTensor::~PbTensor() noexcept(false) { pb_memory_.reset(); DeleteDLPack(); + +#ifdef TRITON_PB_STUB + { + py::gil_scoped_acquire acquire; + py::array numpy_array_local(std::move(numpy_array_)); + py::array numpy_array_serialized_local(std::move(numpy_array_serialized_)); + } +#endif } const std::string& @@ -505,12 +580,18 @@ PbTensor::Name() const const py::array* PbTensor::AsNumpy() const { - if (IsCPU()) { - return &numpy_array_; - } else { + if (!IsCPU()) { throw PythonBackendException( "Tensor is stored in GPU and cannot be converted to NumPy."); } + + if (dtype_ == TRITONSERVER_TYPE_BF16) { + throw PythonBackendException( + "Tensor dtype is BF16 and cannot be converted to NumPy. Use " + "to_dlpack() and from_dlpack() instead."); + } + + return &numpy_array_; } #endif // TRITON_PB_STUB @@ -553,7 +634,7 @@ PbTensor::SaveToSharedMemory( if (!pb_memory_) { pb_memory_ = PbMemory::Create( - memory_type_, memory_type_id_, byte_size_, + shm_pool, memory_type_, memory_type_id_, byte_size_, reinterpret_cast(memory_ptr_), reinterpret_cast(tensor_shm_ptr_) + pb_memory_offset, shm_handle_ + pb_memory_offset, copy_gpu); @@ -583,7 +664,7 @@ PbTensor::LoadFromSharedMemory( if (tensor_shm_ptr->memory == 0) { std::size_t pb_memory_offset = name_offset + name_shm->Size(); pb_memory = PbMemory::LoadFromSharedMemory( - pb_memory_offset, tensor_shm.data_.get() + pb_memory_offset, + shm_pool, pb_memory_offset, tensor_shm.data_.get() + pb_memory_offset, open_cuda_handle); } else { pb_memory = PbMemory::LoadFromSharedMemory( @@ -636,19 +717,17 @@ PbTensor::PbTensor( #ifdef TRITON_PB_STUB if (memory_type_ == TRITONSERVER_MEMORY_CPU || memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) { - if (dtype_ != TRITONSERVER_TYPE_BYTES) { + if (dtype_ == TRITONSERVER_TYPE_BF16) { + // No native numpy representation for BF16. DLPack should be used instead. + numpy_array_ = py::none(); + } else if (dtype_ != TRITONSERVER_TYPE_BYTES) { py::object numpy_array = py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_); numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_)); } else { - py::object numpy_array = py::array( - triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size_}, - (void*)memory_ptr_); - py::module triton_pb_utils = - py::module::import("triton_python_backend_utils"); - numpy_array_ = - triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array) - .attr("reshape")(dims_); + py::object numpy_array = deserialize_bytes_tensor_cpp( + static_cast(memory_ptr_), byte_size_); + numpy_array_ = numpy_array.attr("reshape")(dims_); } } else { numpy_array_ = py::none(); diff --git a/src/pb_tensor.h b/src/pb_tensor.h index b9c0d593..4f97b643 100644 --- a/src/pb_tensor.h +++ b/src/pb_tensor.h @@ -99,8 +99,7 @@ class PbTensor { int64_t memory_type_id, void* memory_ptr, uint64_t byte_size, DLManagedTensor* dl_managed_tensor = nullptr); - /// This constructor is used when - /// loading the tensor from shared memory. + /// This constructor is used when loading the tensor from shared memory. /// \param tensor_shm The name of the tensor /// \param dims_shm Tensor dimensions /// \param pb_string Triton dtype diff --git a/src/pb_utils.cc b/src/pb_utils.cc index 089f4cf0..79b45ec2 100644 --- a/src/pb_utils.cc +++ b/src/pb_utils.cc @@ -26,27 +26,23 @@ #include "pb_utils.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include + +#include + +#ifdef _WIN32 +#include + +#include +#else +#include #include +#endif -#include -#include -#include -#include -#include -#include +#ifndef _WIN32 +extern char** environ; +#endif -#include "scoped_defer.h" #ifdef TRITON_ENABLE_GPU #include @@ -59,42 +55,43 @@ namespace triton { namespace backend { namespace python { CUDAHandler::CUDAHandler() { - dl_open_handle_ = dlopen("libcuda.so", RTLD_LAZY); + dl_open_handle_ = LoadSharedObject("libcuda.so"); // If libcuda.so is successfully opened, it must be able to find // "cuPointerGetAttribute", "cuGetErrorString", and // "cuDevicePrimaryCtxGetState" symbols. if (dl_open_handle_ != nullptr) { - void* cu_pointer_get_attribute_fn = - dlsym(dl_open_handle_, "cuPointerGetAttribute"); + void* cu_pointer_get_attribute_fn = LocateSymbol("cuPointerGetAttribute"); if (cu_pointer_get_attribute_fn == nullptr) { throw PythonBackendException( - std::string("Failed to dlsym 'cuPointerGetAttribute'. Error: ") + - dlerror()); + std::string("Failed to locate 'cuPointerGetAttribute'. Error: ") + + LocateSymbolError()); } *((void**)&cu_pointer_get_attribute_fn_) = cu_pointer_get_attribute_fn; - void* cu_get_error_string_fn = dlsym(dl_open_handle_, "cuGetErrorString"); + void* cu_get_error_string_fn = LocateSymbol("cuGetErrorString"); if (cu_get_error_string_fn == nullptr) { throw PythonBackendException( - std::string("Failed to dlsym 'cuGetErrorString'. Error: ") + - dlerror()); + std::string("Failed to locate 'cuGetErrorString'. Error: ") + + LocateSymbolError()); } *((void**)&cu_get_error_string_fn_) = cu_get_error_string_fn; - void* cu_init_fn = dlsym(dl_open_handle_, "cuInit"); + void* cu_init_fn = LocateSymbol("cuInit"); if (cu_init_fn == nullptr) { throw PythonBackendException( - std::string("Failed to dlsym 'cuInit'. Error: ") + dlerror()); + std::string("Failed to locate 'cuInit'. Error: ") + + LocateSymbolError()); } *((void**)&cu_init_fn_) = cu_init_fn; void* cu_device_primary_ctx_get_state_fn = - dlsym(dl_open_handle_, "cuDevicePrimaryCtxGetState"); + LocateSymbol("cuDevicePrimaryCtxGetState"); if (cu_device_primary_ctx_get_state_fn == nullptr) { throw PythonBackendException( - std::string("Failed to dlsym 'cuDevicePrimaryCtxGetState'. Error: ") + - dlerror()); + std::string( + "Failed to locate 'cuDevicePrimaryCtxGetState'. Error: ") + + LocateSymbolError()); } *((void**)&cu_device_primary_ctx_get_state_fn_) = cu_device_primary_ctx_get_state_fn; @@ -105,10 +102,7 @@ CUDAHandler::CUDAHandler() const char* error_string; (*cu_get_error_string_fn_)(cuda_err, &error_string); error_str_ = std::string("failed to call cuInit: ") + error_string; - int status = dlclose(dl_open_handle_); - if (status != 0) { - throw PythonBackendException("Failed to close the libcuda handle."); - } + CloseLibrary(); dl_open_handle_ = nullptr; } } @@ -215,13 +209,58 @@ CUDAHandler::MaybeSetDevice(int device) CUDAHandler::~CUDAHandler() noexcept(false) { if (dl_open_handle_ != nullptr) { - int status = dlclose(dl_open_handle_); - if (status != 0) { - throw PythonBackendException("Failed to close the libcuda handle."); - } + CloseLibrary(); } } +void* +CUDAHandler::LoadSharedObject(const char* filename) +{ +#ifdef _WIN32 + // NOTE: 'nvcuda.dll' is a placeholder library. Apparently, this should be the + // equivalent library for Windows, but need to verify. + return LoadLibraryA("nvcuda.dll"); +#else + return dlopen("libcuda.so", RTLD_LAZY); +#endif +} + +void* +CUDAHandler::LocateSymbol(const char* symbol) +{ +#ifdef _WIN32 + return GetProcAddress(static_cast(dl_open_handle_), symbol); +#else + return dlsym(dl_open_handle_, symbol); +#endif +} + + +std::string +CUDAHandler::LocateSymbolError() +{ +#ifdef _WIN32 + return std::to_string(GetLastError()); +#else + return dlerror(); +#endif +} + +void +CUDAHandler::CloseLibrary() +{ + bool successful = true; +#ifdef _WIN32 + successful = (FreeLibrary(static_cast(dl_open_handle_)) != 0); +#else + successful = (dlclose(dl_open_handle_) == 0); +#endif + if (!successful) { + throw PythonBackendException("Failed to close the cuda library handle."); + } +} + + ScopedSetDevice::ScopedSetDevice(int device) { device_ = device; @@ -239,7 +278,32 @@ ScopedSetDevice::~ScopedSetDevice() cuda_handler.MaybeSetDevice(current_device_); } } -#endif + +bool +IsUsingCUDAPool( + std::unique_ptr& cuda_pool, int64_t memory_type_id, + void* data) +{ + CUDAHandler& cuda_api = CUDAHandler::getInstance(); + CUdeviceptr cuda_pool_address = 0; + cuda_api.PointerGetAttribute( + &cuda_pool_address, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + reinterpret_cast(data)); + + return ( + cuda_pool->CUDAPoolAddress(memory_type_id) == + reinterpret_cast(cuda_pool_address)); +} + +#endif // TRITON_ENABLE_GPU + +// FIXME: [DLIS-6078]: We should not need this function. However, some paths are +// being retrieved from core that are not platform-agnostic. +void +SanitizePath(std::string& path) +{ + std::replace(path.begin(), path.end(), '/', '\\'); +} #ifndef TRITON_PB_STUB std::shared_ptr @@ -258,5 +322,119 @@ WrapTritonErrorInSharedPtr(TRITONSERVER_Error* error) *response_error = error; return response_error; } +#endif // NOT TRITON_PB_STUB + +bool +IsValidIdentifier(const std::string& input) +{ + // Check for invalid characters + if (input.empty() || + input.find_first_of(INVALID_CHARS) != std::string::npos) { + return false; + } + + return true; +} + +bool +IsExecutableFile(const std::string& filepath) +{ + struct stat file_stat; + if (stat(filepath.c_str(), &file_stat) != 0) { + return false; + } + + // Check if it's a regular file and executable by owner + return S_ISREG(file_stat.st_mode) && (file_stat.st_mode & S_IXUSR); +} + +std::string +GenerateUUID() +{ + static boost::uuids::random_generator generator; + boost::uuids::uuid uuid = generator(); + return boost::uuids::to_string(uuid); +} + +// Helper function to get environment variables for Python virtual environments +std::map +ParseActivationScript(const std::string& activate_path) +{ + std::map env_vars; + + // Read the current environment as baseline +#ifndef _WIN32 + if (environ != nullptr) { + for (char** env = environ; *env != nullptr; env++) { + std::string env_str(*env); + size_t eq_pos = env_str.find('='); + if (eq_pos != std::string::npos) { + std::string key = env_str.substr(0, eq_pos); + std::string value = env_str.substr(eq_pos + 1); + env_vars[key] = value; + } + } + } #endif + + // Extract virtual environment root from activation script path + std::string venv_path = activate_path; + size_t bin_activate_pos = venv_path.find("/bin/activate"); + if (bin_activate_pos != std::string::npos) { + venv_path = venv_path.substr(0, bin_activate_pos); + } + + // Set standard virtual environment variables + env_vars["VIRTUAL_ENV"] = venv_path; + env_vars["VIRTUAL_ENV_PROMPT"] = "(" + venv_path + ")"; + + // Update PATH to include the virtual environment's bin directory + std::string new_path = venv_path + "/bin"; + if (env_vars.find("PATH") != env_vars.end()) { + new_path += ":" + env_vars["PATH"]; + } + env_vars["PATH"] = new_path; + + // Update LD_LIBRARY_PATH to include the virtual environment's lib directory + std::string new_lib_path = venv_path + "/lib"; + if (env_vars.find("LD_LIBRARY_PATH") != env_vars.end()) { + new_lib_path += ":" + env_vars["LD_LIBRARY_PATH"]; + } + env_vars["LD_LIBRARY_PATH"] = new_lib_path; + + // Remove PYTHONHOME if it exists + env_vars.erase("PYTHONHOME"); + + return env_vars; +} + +// Helper function to prepare environment array for execve +std::pair, std::vector> +PrepareEnvironment( + const std::map& env_vars, + const std::string& additional_lib_path) +{ + std::vector env_strings; + std::vector env_array; + + for (const auto& [key, value] : env_vars) { + std::string env_string; + if (key == "LD_LIBRARY_PATH" && !additional_lib_path.empty()) { + // Prepend the additional library path + env_string = key + "=" + additional_lib_path + ":" + value; + } else { + env_string = key + "=" + value; + } + env_strings.push_back(env_string); + } + + // Convert to char* array + for (auto& env_str : env_strings) { + env_array.push_back(const_cast(env_str.c_str())); + } + env_array.push_back(nullptr); + + return std::make_pair(std::move(env_strings), std::move(env_array)); +} + }}} // namespace triton::backend::python diff --git a/src/pb_utils.h b/src/pb_utils.h index 1d651f3f..fa315210 100644 --- a/src/pb_utils.h +++ b/src/pb_utils.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -29,15 +29,19 @@ #ifdef TRITON_ENABLE_GPU #include #endif // TRITON_ENABLE_GPU -#include #include #include +#include +#include +#include #include +#include #include #include #include #include +#include #include #include "pb_exception.h" @@ -71,7 +75,7 @@ namespace bi = boost::interprocess; TRITONSERVER_ErrorMessage(pb2_exception.what())); \ } \ } \ - while (false) + } while (false) #define THROW_IF_TRITON_ERROR(X) \ do { \ @@ -165,11 +169,14 @@ struct ResponseBatch : SendMessageBase { bool is_error_set; uint32_t response_size; + + // Indicates whether the response factory has been deleted or not. + bool is_response_factory_deleted = false; }; -enum LogLevel { INFO = 0, WARNING, ERROR, VERBOSE }; +enum LogLevel { kInfo = 0, kWarning, kError, kVerbose }; -enum MetricKind { COUNTER, GAUGE }; +enum MetricKind { kCounter = 0, kGauge, kHistogram }; struct LogSendMessage : SendMessageBase { bi::managed_external_buffer::handle_t filename; @@ -182,6 +189,17 @@ struct CleanupMessage : SendMessageBase { void* id; }; +struct CancelBLSRequestMessage : SendMessageBase { + void* infer_payload_id; + bool is_cancelled; +}; + +struct IsCancelledMessage : SendMessageBase { + intptr_t response_factory_address; + intptr_t request_address; + bool is_cancelled; +}; + struct CustomMetricsMessage : SendMessageBase { bi::managed_external_buffer::handle_t message; bool has_error; @@ -235,7 +253,22 @@ struct RequestBatch { bi::managed_external_buffer::handle_t gpu_buffers_handle; }; +struct MemoryReleaseMessage { + std::mutex mu; + std::condition_variable cv; + uint64_t id; + bool waiting_on_stub; +}; + #ifdef TRITON_ENABLE_GPU +struct CUDAMemPoolMessage : SendMessageBase { + cudaIpcMemHandle_t cuda_handle; + int32_t device_id; + bi::managed_external_buffer::handle_t error; + bool has_error; + bool is_error_set; +}; + class CUDAHandler { public: static CUDAHandler& getInstance() @@ -273,6 +306,10 @@ class CUDAHandler { int64_t memory_type_id, cudaIpcMemHandle_t* cuda_mem_handle, void** data_ptr); void CloseCudaHandle(int64_t memory_type_id, void* data_ptr); + void* LoadSharedObject(const char* filename); + void* LocateSymbol(const char* symbol); + std::string LocateSymbolError(); + void CloseLibrary(); /// Set the device only if the primary context has already been created for /// this device. Inspired from PyTorch's MaybeSetDevice. @@ -295,11 +332,39 @@ class ScopedSetDevice { int current_device_; }; +// Check if the data is allocated from the pool by the base address. +bool IsUsingCUDAPool( + std::unique_ptr& cuda_pool, int64_t memory_type_id, + void* data); + #endif // TRITON_ENABLE_GPU +// FIXME: [DLIS-6078]: We should not need this function. However, some paths are +// being retrieved from core that are not platform-agnostic. +void SanitizePath(std::string& path); + +// Invalid characters that are not allowed in user input +constexpr const char* INVALID_CHARS = ";|&$`<>()[]{}\\\"'*?~#!"; + +// Validate that an identifier (model name, region name, etc.) +bool IsValidIdentifier(const std::string& input); + +// Check if a file exists and is executable +bool IsExecutableFile(const std::string& filepath); + #ifndef TRITON_PB_STUB std::shared_ptr WrapTritonErrorInSharedPtr( TRITONSERVER_Error* error); #endif +std::string GenerateUUID(); + +// Environment handling utilities for Python activation scripts +std::map ParseActivationScript( + const std::string& activate_path); + +std::pair, std::vector> PrepareEnvironment( + const std::map& env_vars, + const std::string& additional_lib_path = ""); + }}} // namespace triton::backend::python diff --git a/src/python_be.cc b/src/python_be.cc index 6f25e024..c152e035 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -25,6 +25,9 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "python_be.h" +#include + +#include "correlation_id.h" #include "gpu_buffers.h" #include "infer_payload.h" #include "model_loader.h" @@ -150,107 +153,6 @@ ModelInstanceState::SetErrorForResponseSendMessage( } } -void -ModelInstanceState::SendMessageAndReceiveResponse( - bi::managed_external_buffer::handle_t message, - bi::managed_external_buffer::handle_t& response, bool& restart, - std::shared_ptr>& responses, - TRITONBACKEND_Request** requests, const uint32_t request_count) -{ - auto error = SendMessageToStub(message); - if (error != nullptr) { - restart = true; - RespondErrorToAllRequests( - TRITONSERVER_ErrorMessage(error), responses, requests, request_count); - - return; - } - - bi::managed_external_buffer::handle_t response_message; - error = Stub()->ReceiveMessageFromStub(response_message); - if (error != nullptr) { - restart = true; - RespondErrorToAllRequests( - TRITONSERVER_ErrorMessage(error), responses, requests, request_count); - - return; - } - - response = response_message; -} - -TRITONSERVER_Error* -ModelInstanceState::SendMessageToStub( - bi::managed_external_buffer::handle_t message) -{ - bool success = false; - while (!success) { - uint64_t timeout_miliseconds = 1000; - { - boost::posix_time::ptime timeout = - boost::get_system_time() + - boost::posix_time::milliseconds(timeout_miliseconds); - - bi::scoped_lock lock( - *(Stub()->HealthMutex()), timeout); - - // Check if lock has been acquired. - if (lock) { - Stub()->IpcControl()->stub_health = false; - } else { - // If it failed to obtain the lock, it means that the stub has been - // stuck or exited while holding the health mutex lock. - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, "Failed to obtain the health mutex."); - } - } - - Stub()->StubMessageQueue()->Push( - message, timeout_miliseconds /* duration ms */, success); - - if (!success && !IsStubProcessAlive()) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, "Stub process is not healthy."); - } - } - - return nullptr; // success -} - -void -ModelInstanceState::RespondErrorToAllRequests( - const char* message, - std::shared_ptr>& responses, - TRITONBACKEND_Request** requests, const uint32_t request_count) -{ - for (uint32_t r = 0; r < request_count; ++r) { - if ((*responses)[r] == nullptr) - continue; - - std::string err_message = - std::string( - "Failed to process the request(s) for model instance '" + Name() + - "', message: ") + - message; - - TRITONSERVER_Error* err = - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, err_message.c_str()); - LOG_IF_ERROR( - TRITONBACKEND_ResponseSend( - (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), - "failed sending response"); - - (*responses)[r] = nullptr; - TRITONSERVER_ErrorDelete(err); - } -} - -void -ModelInstanceState::WaitForBLSRequestsToFinish() -{ - futures_.clear(); -} - bool ModelInstanceState::IsStubProcessAlive() { @@ -271,12 +173,12 @@ ModelInstanceState::IsStubProcessAlive() TRITONSERVER_Error* ModelInstanceState::SaveRequestsToSharedMemory( TRITONBACKEND_Request** requests, const uint32_t request_count, - std::vector>& pb_inference_requests, + std::vector>& pb_infer_requests, AllocatedSharedMemory& request_batch, std::shared_ptr>& responses) { // Clear any existing items in the requests vector - pb_inference_requests.clear(); + pb_infer_requests.clear(); ModelState* model_state = reinterpret_cast(Model()); RETURN_IF_EXCEPTION( @@ -340,6 +242,9 @@ ModelInstanceState::SaveRequestsToSharedMemory( } else if (type == TRITONSERVER_PARAMETER_STRING) { std::string string = reinterpret_cast(vvalue); RETURN_IF_ERROR(parameters_json.AddString(name, string)); + } else if (type == TRITONSERVER_PARAMETER_DOUBLE) { + RETURN_IF_ERROR(parameters_json.AddDouble( + name, *(reinterpret_cast(vvalue)))); } else { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -357,33 +262,63 @@ ModelInstanceState::SaveRequestsToSharedMemory( const char* id; RETURN_IF_ERROR(TRITONBACKEND_RequestId(request, &id)); - uint64_t correlation_id; - RETURN_IF_ERROR( - TRITONBACKEND_RequestCorrelationId(request, &correlation_id)); + uint64_t correlation_id_uint = 0; + CorrelationId correlation_id; + + auto error = + TRITONBACKEND_RequestCorrelationId(request, &correlation_id_uint); + if (error != nullptr) { + TRITONSERVER_ErrorDelete(error); + const char* correlation_id_string = ""; + RETURN_IF_ERROR(TRITONBACKEND_RequestCorrelationIdString( + request, &correlation_id_string)); + correlation_id = CorrelationId(std::string(correlation_id_string)); + } else { + correlation_id = CorrelationId(correlation_id_uint); + } uint32_t flags; RETURN_IF_ERROR(TRITONBACKEND_RequestFlags(request, &flags)); - std::unique_ptr infer_request; - if (model_state->IsDecoupled()) { - TRITONBACKEND_ResponseFactory* factory_ptr; - RETURN_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request)); - infer_request = std::make_unique( - id, correlation_id, pb_input_tensors, requested_output_names, - model_state->Name(), model_state->Version(), parameters_string, flags, - 0 /* BLS request timeout*/, reinterpret_cast(factory_ptr), - reinterpret_cast(request)); - } else { - infer_request = std::make_unique( - id, correlation_id, pb_input_tensors, requested_output_names, - model_state->Name(), model_state->Version(), parameters_string, flags, - 0 /* BLS request timeout*/, 0 /* response_factory_address */, - reinterpret_cast(request)); + // Do not return if error in this case, because Triton core + // will return an error if tracing is disabled (see PYBE PR#295). + // For the same reason, we do not log the error message, otherwise + // when Triton is compiled without tracing, it'll constantly log + // this error. + TRITONSERVER_InferenceTrace* triton_trace; + auto err = TRITONBACKEND_RequestTrace(request, &triton_trace); + if (err != nullptr) { + triton_trace = nullptr; + TRITONSERVER_ErrorDelete(err); + } + const char* val = nullptr; + if (triton_trace != nullptr) { + LOG_IF_ERROR( + TRITONSERVER_InferenceTraceContext(triton_trace, &val), + "failed to retrieve trace context"); } + std::string context = (val != nullptr) ? std::string(val) : ""; + InferenceTrace trace = + InferenceTrace(reinterpret_cast(triton_trace), context); + + uint64_t request_timeout; + RETURN_IF_ERROR(TRITONBACKEND_InferenceRequestTimeoutMicroseconds( + request, &request_timeout)); + + std::unique_ptr infer_request; + TRITONBACKEND_ResponseFactory* factory_ptr = nullptr; + RETURN_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request)); + + infer_request = std::make_unique( + id, correlation_id, pb_input_tensors, requested_output_names, + model_state->Name(), model_state->Version(), parameters_string, flags, + request_timeout, reinterpret_cast(factory_ptr), + reinterpret_cast(request), + PreferredMemory(PreferredMemory::kDefault, 0), trace); RETURN_IF_EXCEPTION(infer_request->SaveToSharedMemory(Stub()->ShmPool())); requests_shm[r] = infer_request->ShmHandle(); - pb_inference_requests.emplace_back(std::move(infer_request)); + pb_infer_requests.emplace_back(std::move(infer_request)); } return nullptr; // success @@ -404,11 +339,6 @@ ModelInstanceState::LaunchStubProcess() thread_pool_ = std::make_unique( model_state->StateForBackend()->thread_pool_size); - if (model_state->IsDecoupled()) { - decoupled_thread_ = true; - decoupled_monitor_ = - std::thread(&ModelInstanceState::DecoupledMessageQueueMonitor, this); - } request_executor_ = std::make_unique( Stub()->ShmPool(), model_state->TritonServer()); @@ -509,8 +439,21 @@ ModelInstanceState::GetInputTensor( RETURN_IF_ERROR(backend::ReadInputTensor( request, input_name, input_buffer, &byte_size)); } + + if (input_dtype == TRITONSERVER_TYPE_BYTES) { + const char* content = reinterpret_cast(input_tensor->DataPtr()); + size_t content_byte_size = input_tensor->ByteSize(); + int64_t request_element_cnt = 0; + RETURN_IF_ERROR( + GetElementCount(input_tensor->Dims(), &request_element_cnt)); + RETURN_IF_ERROR(ValidateStringBuffer( + content, content_byte_size, request_element_cnt, input_name, + nullptr /* str_list */)); + } } else { #ifdef TRITON_ENABLE_GPU + // Attempt to use the cuda shared memory pool for GPU tensor. + ShareCUDAMemoryPool(src_memory_type_id); // Retrieving GPU input tensors const void* buffer = nullptr; @@ -519,6 +462,8 @@ ModelInstanceState::GetInputTensor( // collector is used in the non-decoupled mode. if (collector) { + // The ProcessTensor function will try to allocate the buffer in the CUDA + // pool first. RETURN_IF_ERROR(collector->ProcessTensor( input_name, nullptr, 0, alloc_perference, reinterpret_cast(&buffer), &input_byte_size, @@ -558,10 +503,22 @@ ModelInstanceState::GetInputTensor( Stub()->ShmPool(), true /* copy_gpu */)); } } else { + // Try to use the cuda shared memory pool first. void* dev_ptr; - RETURN_IF_CUDA_ERROR( - cudaMalloc(&dev_ptr, input_byte_size), TRITONSERVER_ERROR_INTERNAL, - std::string("Failed to allocated CUDA memory")); + BackendMemory* backend_memory; + std::unique_ptr lbackend_memory; + RETURN_IF_ERROR(BackendMemory::Create( + reinterpret_cast( + Stub() + ->ShmPool() + ->GetCUDAMemoryPoolManager() + ->TritonMemoryManager()), + {BackendMemory::AllocationType::GPU_POOL, + BackendMemory::AllocationType::GPU}, + src_memory_type_id, input_byte_size, &backend_memory)); + + dev_ptr = backend_memory->MemoryPtr(); + lbackend_memory.reset(backend_memory); size_t byte_size = input_byte_size; @@ -584,14 +541,11 @@ ModelInstanceState::GetInputTensor( const_cast(dev_ptr), input_byte_size, nullptr /* DLManagedTensor */); + input_tensor->SetMemory(std::move( + PbMemory::Create(Stub()->ShmPool(), std::move(lbackend_memory)))); + RETURN_IF_EXCEPTION(input_tensor->SaveToSharedMemory( Stub()->ShmPool(), true /* copy_gpu */)); - - std::unique_ptr gpu_memory_record = - std::make_unique(input_tensor->Memory()->DataPtr()); - uint64_t memory_release_id = - Stub()->GetMemoryManager()->AddRecord(std::move(gpu_memory_record)); - input_tensor->Memory()->SetMemoryReleaseId(memory_release_id); } #else return TRITONSERVER_ErrorNew( @@ -642,7 +596,8 @@ ModelInstanceState::ExecuteBLSRequest( reinterpret_cast( request_batch.data_.get() + sizeof(RequestBatch)); infer_request = InferRequest::LoadFromSharedMemory( - Stub()->ShmPool(), *request_handle, false /* open_cuda_handle */); + Stub()->ShmPool(), *request_handle, false /* open_cuda_handle */, + nullptr /* is_model_decoupled */); // If the BLS inputs are in GPU an additional round trip between the // stub process and the main process is required. The reason is that we @@ -652,6 +607,8 @@ ModelInstanceState::ExecuteBLSRequest( for (auto& input_tensor : infer_request->Inputs()) { if (!input_tensor->IsCPU()) { #ifdef TRITON_ENABLE_GPU + // Attempt to use the cuda shared memory pool for GPU tensor. + ShareCUDAMemoryPool(input_tensor->MemoryTypeId()); BackendMemory* backend_memory; std::unique_ptr lbackend_memory; has_gpu_tensor = true; @@ -708,7 +665,8 @@ ModelInstanceState::ExecuteBLSRequest( if (is_decoupled && (infer_response->Id() != nullptr)) { // Need to manage the lifetime of InferPayload object for bls // decoupled responses. - infer_payload_[reinterpret_cast(&infer_payload)] = + std::lock_guard lock(infer_payload_mu_); + infer_payload_[reinterpret_cast(infer_payload.get())] = infer_payload; } @@ -744,48 +702,6 @@ ModelInstanceState::ExecuteBLSRequest( } } -void -ModelInstanceState::DecoupledMessageQueueMonitor() -{ - while (decoupled_thread_) { - bi::managed_external_buffer::handle_t handle = - Stub()->ParentMessageQueue()->Pop(); - if (handle == DUMMY_MESSAGE) { - break; - } - std::unique_ptr message = - IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), handle); - - // Need to notify the model instance thread that the execute response has - // been received. - if (message->Command() == PYTHONSTUB_ExecuteResponse) { - std::lock_guard guard{mu_}; - received_message_ = std::move(message); - cv_.notify_one(); - } else if (message->Command() == PYTHONSTUB_ResponseSend) { - std::shared_ptr response_send_message = std::move(message); - std::packaged_task task([this, response_send_message] { - ResponseSendDecoupled(response_send_message); - }); - std::future future = - boost::asio::post(*thread_pool_, std::move(task)); - futures_.emplace_back(std::move(future)); - } else if ( - message->Command() == PYTHONSTUB_InferExecRequest || - message->Command() == PYTHONSTUB_InferStreamExecRequest) { - std::shared_ptr bls_execute = std::move(message); - std::packaged_task task([this, bls_execute] { - ExecuteBLSRequest( - bls_execute, - (bls_execute->Command() == PYTHONSTUB_InferStreamExecRequest)); - }); - std::future future = - boost::asio::post(*thread_pool_, std::move(task)); - futures_.emplace_back(std::move(future)); - } - } -} - void ModelInstanceState::StubToParentMQMonitor() { @@ -803,8 +719,13 @@ ModelInstanceState::StubToParentMQMonitor() ProcessLogRequest(message); break; } - case PYTHONSTUB_CleanupRequest: { - ProcessBLSCleanupRequest(message); + case PYTHONSTUB_BLSDecoupledInferPayloadCleanup: + case PYTHONSTUB_DecoupledResponseFactoryCleanup: { + ProcessCleanupRequest(message); + break; + } + case PYTHONSTUB_IsRequestCancelled: { + ProcessIsRequestCancelled(message); break; } case PYTHONSTUB_MetricFamilyRequestNew: @@ -816,7 +737,8 @@ ModelInstanceState::StubToParentMQMonitor() case PYTHONSTUB_MetricRequestDelete: case PYTHONSTUB_MetricRequestValue: case PYTHONSTUB_MetricRequestIncrement: - case PYTHONSTUB_MetricRequestSet: { + case PYTHONSTUB_MetricRequestSet: + case PYTHONSTUB_MetricRequestObserve: { ProcessMetricRequest(message); break; } @@ -826,6 +748,29 @@ ModelInstanceState::StubToParentMQMonitor() ProcessModelControlRequest(message); break; } + case PYTHONSTUB_ResponseSend: { + std::shared_ptr response_send_message = std::move(message); + std::packaged_task task([this, response_send_message] { + ResponseSendDecoupled(response_send_message); + }); + boost::asio::post(*thread_pool_, std::move(task)); + break; + } + case PYTHONSTUB_InferExecRequest: + case PYTHONSTUB_InferStreamExecRequest: { + std::shared_ptr bls_execute = std::move(message); + std::packaged_task task([this, bls_execute] { + ExecuteBLSRequest( + bls_execute, + (bls_execute->Command() == PYTHONSTUB_InferStreamExecRequest)); + }); + boost::asio::post(*thread_pool_, std::move(task)); + break; + } + case PYTHONSTUB_CancelBLSInferRequest: { + ProcessCancelBLSRequest(message); + break; + } default: { LOG_MESSAGE( TRITONSERVER_LOG_ERROR, "Unexpected message type received."); @@ -850,25 +795,25 @@ ModelInstanceState::ProcessLogRequest( LogLevel level = pb_log_message->Level(); switch (level) { - case LogLevel::INFO: { + case LogLevel::kInfo: { TRITONSERVER_LogMessage( TRITONSERVER_LOG_INFO, (filename.c_str()), line, (log_message.c_str())); break; } - case LogLevel::WARNING: { + case LogLevel::kWarning: { TRITONSERVER_LogMessage( TRITONSERVER_LOG_WARN, (filename.c_str()), line, (log_message.c_str())); break; } - case LogLevel::ERROR: { + case LogLevel::kError: { TRITONSERVER_LogMessage( TRITONSERVER_LOG_ERROR, (filename.c_str()), line, (log_message.c_str())); break; } - case LogLevel::VERBOSE: { + case LogLevel::kVerbose: { TRITONSERVER_LogMessage( TRITONSERVER_LOG_VERBOSE, (filename.c_str()), line, (log_message.c_str())); @@ -890,16 +835,24 @@ ModelInstanceState::ProcessLogRequest( } void -ModelInstanceState::ProcessBLSCleanupRequest( +ModelInstanceState::ProcessCleanupRequest( const std::unique_ptr& message) { AllocatedSharedMemory cleanup_request_message = Stub()->ShmPool()->Load(message->Args()); CleanupMessage* cleanup_message_ptr = reinterpret_cast(cleanup_request_message.data_.get()); - - void* id = cleanup_message_ptr->id; - infer_payload_.erase(id); + intptr_t id = reinterpret_cast(cleanup_message_ptr->id); + if (message->Command() == PYTHONSTUB_BLSDecoupledInferPayloadCleanup) { + // Remove the InferPayload object from the map. + std::lock_guard lock(infer_payload_mu_); + infer_payload_.erase(id); + } else if (message->Command() == PYTHONSTUB_DecoupledResponseFactoryCleanup) { + // Delete response factory + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + response_factory(reinterpret_cast(id)); + } { bi::scoped_lock lock{*(message->ResponseMutex())}; @@ -908,6 +861,74 @@ ModelInstanceState::ProcessBLSCleanupRequest( } } +void +ModelInstanceState::ProcessCancelBLSRequest( + const std::unique_ptr& message) +{ + AllocatedSharedMemory message_shm = + Stub()->ShmPool()->Load(message->Args()); + CancelBLSRequestMessage* message_payload = + reinterpret_cast(message_shm.data_.get()); + + { + bi::scoped_lock lk{message_payload->mu}; + + intptr_t id = reinterpret_cast(message_payload->infer_payload_id); + try { + { + std::lock_guard lock(infer_payload_mu_); + if (infer_payload_.find(id) != infer_payload_.end()) { + infer_payload_[id]->SafeCancelRequest(); + } + } + message_payload->is_cancelled = true; + } + catch (const PythonBackendException& pb_exception) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, pb_exception.what()); + } + + message_payload->waiting_on_stub = true; + message_payload->cv.notify_all(); + while (message_payload->waiting_on_stub) { + message_payload->cv.wait(lk); + } + } +} + +void +ModelInstanceState::ProcessIsRequestCancelled( + const std::unique_ptr& message) +{ + AllocatedSharedMemory message_shm = + Stub()->ShmPool()->Load(message->Args()); + IsCancelledMessage* message_payload = + reinterpret_cast(message_shm.data_.get()); + + { + bi::scoped_lock lk{message_payload->mu}; + + if (message_payload->response_factory_address != 0) { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + message_payload->response_factory_address); + TRITONBACKEND_ResponseFactoryIsCancelled( + response_factory, &message_payload->is_cancelled); + } else if (message_payload->request_address != 0) { + TRITONBACKEND_Request* request = reinterpret_cast( + message_payload->request_address); + TRITONBACKEND_RequestIsCancelled(request, &message_payload->is_cancelled); + } else { + throw PythonBackendException("Cannot determine request cancellation"); + } + + message_payload->waiting_on_stub = true; + message_payload->cv.notify_all(); + while (message_payload->waiting_on_stub) { + message_payload->cv.wait(lk); + } + } +} + template void ModelInstanceState::ProcessMessage( @@ -994,6 +1015,7 @@ ModelInstanceState::ProcessMetricRequest( } case PYTHONSTUB_MetricRequestIncrement: case PYTHONSTUB_MetricRequestSet: + case PYTHONSTUB_MetricRequestObserve: case PYTHONSTUB_MetricRequestValue: { metric->HandleMetricOperation(metrics_message_ptr, command); break; @@ -1044,36 +1066,141 @@ ModelInstanceState::ProcessModelControlRequest( }); } -void -ModelInstanceState::StartMonitor() +TRITONSERVER_Error* +ModelInstanceState::SendMessageToStub( + bi::managed_external_buffer::handle_t message) { - stub_to_parent_thread_ = true; - stub_to_parent_queue_monitor_ = - std::thread(&ModelInstanceState::StubToParentMQMonitor, this); + bool success = false; + while (!success) { + uint64_t timeout_miliseconds = 1000; + { + boost::posix_time::ptime timeout = + boost::get_system_time() + + boost::posix_time::milliseconds(timeout_miliseconds); + + bi::scoped_lock lock( + *(Stub()->HealthMutex()), timeout); + + // Check if lock has been acquired. + if (lock) { + Stub()->IpcControl()->stub_health = false; + } else { + // If it failed to obtain the lock, it means that the stub has been + // stuck or exited while holding the health mutex lock. + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Failed to obtain the health mutex."); + } + } + + Stub()->StubMessageQueue()->Push( + message, timeout_miliseconds /* duration ms */, success); + + if (!success && !IsStubProcessAlive()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Stub process is not healthy."); + } + } + + return nullptr; // success } void -ModelInstanceState::TerminateMonitor() +ModelInstanceState::SendMessageAndReceiveResponse( + bi::managed_external_buffer::handle_t message, + bi::managed_external_buffer::handle_t& response, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count) { - if (stub_to_parent_thread_) { - stub_to_parent_thread_ = false; - // Push a dummy message to signal the thread to terminate. - Stub()->StubToParentMessageQueue()->Push(DUMMY_MESSAGE); - stub_to_parent_queue_monitor_.join(); + auto error = SendMessageToStub(message); + if (error != nullptr) { + RespondErrorToAllRequests( + TRITONSERVER_ErrorMessage(error), responses, requests, request_count); + + return; + } + + bi::managed_external_buffer::handle_t response_message; + error = Stub()->ReceiveMessageFromStub(response_message); + if (error != nullptr) { + RespondErrorToAllRequests( + TRITONSERVER_ErrorMessage(error), responses, requests, request_count); + + return; } + + response = response_message; } void -ModelInstanceState::ResponseSendDecoupled( - std::shared_ptr response_send_message) +ModelInstanceState::RespondErrorToAllRequests( + const char* message, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count) { - AllocatedSharedMemory send_message = - Stub()->ShmPool()->Load( - response_send_message->Args()); - - ResponseSendMessage* send_message_payload = + for (uint32_t r = 0; r < request_count; ++r) { + if ((*responses)[r] == nullptr) + continue; + + std::string err_message = + std::string( + "Failed to process the request(s) for model instance '" + Name() + + "', message: ") + + message; + + TRITONSERVER_Error* err = + TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, err_message.c_str()); + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed sending response"); + + (*responses)[r] = nullptr; + TRITONSERVER_ErrorDelete(err); + } +} + + +void +ModelInstanceState::StartMonitor() +{ + stub_to_parent_thread_ = true; + stub_to_parent_queue_monitor_ = + std::thread(&ModelInstanceState::StubToParentMQMonitor, this); +} + +void +ModelInstanceState::TerminateMonitor() +{ + if (stub_to_parent_thread_) { + stub_to_parent_thread_ = false; + // Push a dummy message to signal the thread to terminate. + Stub()->StubToParentMessageQueue()->Push(DUMMY_MESSAGE); + stub_to_parent_queue_monitor_.join(); + } +} + +void +ModelInstanceState::ResponseSendDecoupled( + std::shared_ptr response_send_message) +{ + AllocatedSharedMemory send_message = + Stub()->ShmPool()->Load( + response_send_message->Args()); + + ResponseSendMessage* send_message_payload = reinterpret_cast(send_message.data_.get()); std::unique_ptr error_message; + ScopedDefer response_factory_deleter([send_message_payload] { + if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + send_message_payload->response_factory_address); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory(reinterpret_cast( + response_factory)); + } + }); ScopedDefer _([send_message_payload] { { bi::scoped_lock guard{send_message_payload->mu}; @@ -1090,8 +1217,10 @@ ModelInstanceState::ResponseSendDecoupled( reinterpret_cast( send_message_payload->response_factory_address); if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { - std::lock_guard guard{closed_requests_mutex_}; - closed_requests_.push_back(send_message_payload->request_address); + { + std::lock_guard guard{closed_requests_mutex_}; + closed_requests_.push_back(send_message_payload->request_address); + } } if (send_message_payload->response != 0) { @@ -1109,14 +1238,17 @@ ModelInstanceState::ResponseSendDecoupled( error_message); std::vector, void*>> gpu_output_buffers; - std::unique_ptr< - TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> - response_factory_ptr; GPUBuffersHelper gpu_buffer_helper; - if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { - response_factory_ptr.reset( - reinterpret_cast(response_factory)); + +#ifdef TRITON_ENABLE_GPU + for (auto& output_tensor : infer_response->OutputTensors()) { + if (!output_tensor->IsCPU()) { + // Attempt to use the cuda shared memory pool for GPU tensor. + ShareCUDAMemoryPool(output_tensor->MemoryTypeId()); + } } +#endif // TRITON_ENABLE_GPU + infer_response->Send( response, CudaStream(), requires_deferred_callback, send_message_payload->flags, Stub()->ShmPool(), gpu_buffer_helper, @@ -1140,23 +1272,52 @@ ModelInstanceState::ResponseSendDecoupled( bool cuda_copy = false; for (auto& output_buffer_pair : gpu_output_buffers) { auto& pb_memory = output_buffer_pair.first; + void* pointer = output_buffer_pair.second; + bool cuda_used; - if (pb_memory->MemoryType() == TRITONSERVER_MEMORY_CPU) { - bool cuda_used; - void* pointer = output_buffer_pair.second; - - CopyBuffer( - "Failed to copy the output tensor to buffer.", - TRITONSERVER_MEMORY_CPU, 0, TRITONSERVER_MEMORY_CPU, 0, - pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, - CudaStream(), &cuda_used); - cuda_copy |= cuda_used; - } + try { + if (pb_memory->MemoryType() == TRITONSERVER_MEMORY_CPU) { + THROW_IF_TRITON_ERROR(CopyBuffer( + "Failed to copy the CPU output tensor to buffer.", + TRITONSERVER_MEMORY_CPU, 0, TRITONSERVER_MEMORY_CPU, 0, + pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, + CudaStream(), &cuda_used)); + cuda_copy |= cuda_used; + } else if ( + (pb_memory->MemoryType() == TRITONSERVER_MEMORY_GPU) && + pb_memory->UseCUDASharedPool() && + (pb_memory->DataPtr() != pointer)) { + // If the data pointer from pb_memory is not the same as the + // pointer, it means that the Triton-provided buffer is not used + // during tensor transfer. Instead, an intermediate buffer that uses + // CUDA shared memory pool is used. In this case, we need to copy + // the data from the intermediate buffer back to the Triton-provided + // buffer. + THROW_IF_TRITON_ERROR(CopyBuffer( + "Failed to copy the GPU output tensor to buffer.", + TRITONSERVER_MEMORY_GPU, pb_memory->MemoryTypeId(), + TRITONSERVER_MEMORY_GPU, pb_memory->MemoryTypeId(), + pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, + CudaStream(), &cuda_used)); + cuda_copy |= cuda_used; + } #ifdef TRITON_ENABLE_GPU - if (cuda_copy) { - cudaStreamSynchronize(stream_); - } + if (cuda_copy) { + cudaStreamSynchronize(stream_); + } #endif // TRITON_ENABLE_GPU + } + catch (const PythonBackendException& pb_exception) { + TRITONSERVER_Error* error = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string( + "Failed to copy output tensor to Triton-provided buffer: ") + + pb_exception.what()) + .c_str()); + SetErrorForResponseSendMessage( + send_message_payload, WrapTritonErrorInSharedPtr(error), + error_message); + } } } } else { @@ -1164,20 +1325,13 @@ ModelInstanceState::ResponseSendDecoupled( response_factory, send_message_payload->flags); SetErrorForResponseSendMessage( send_message_payload, WrapTritonErrorInSharedPtr(error), error_message); - - if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { - std::unique_ptr< - TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> - response_factory(reinterpret_cast( - send_message_payload->response_factory_address)); - } } } TRITONSERVER_Error* -ModelInstanceState::ProcessRequestsDecoupled( +ModelInstanceState::ProcessRequests( TRITONBACKEND_Request** requests, const uint32_t request_count, - std::vector>& pb_inference_requests, + std::vector>& pb_infer_requests, PbMetricReporter& reporter) { NVTX_RANGE(nvtx_, "ProcessRequests " + Name()); @@ -1203,8 +1357,7 @@ ModelInstanceState::ProcessRequestsDecoupled( std::shared_ptr> responses; RETURN_IF_ERROR(SaveRequestsToSharedMemory( - requests, request_count, pb_inference_requests, request_batch, - responses)); + requests, request_count, pb_infer_requests, request_batch, responses)); uint64_t compute_start_ns = 0; SET_TIMESTAMP(compute_start_ns); @@ -1216,30 +1369,48 @@ ModelInstanceState::ProcessRequestsDecoupled( IPCMessage::Create(Stub()->ShmPool(), false /*inline_response*/)); ipc_message->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_ExecuteRequest; ipc_message->Args() = request_batch.handle_; - received_message_ = nullptr; - ScopedDefer _([this] { + + ScopedDefer execute_finalize([this] { // Push a dummy message to signal the thread to terminate. Stub()->StubMessageQueue()->Push(DUMMY_MESSAGE); }); + std::unique_ptr response; { - std::unique_lock guard{mu_}; Stub()->StubMessageQueue()->Push(ipc_message->ShmHandle()); - cv_.wait(guard, [this] { return received_message_ != nullptr; }); + bi::managed_external_buffer::handle_t response_message; + RETURN_IF_ERROR(Stub()->ReceiveMessageFromStub(response_message)); + response = + IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message); } - - AllocatedSharedMemory response_batch = - Stub()->ShmPool()->Load(received_message_->Args()); + char* ipc_message_shm = + reinterpret_cast(response->GetAllocatedSharedMemory().data_.get()); + ResponseBatch* response_batch_shm_ptr = + reinterpret_cast(ipc_message_shm + sizeof(IPCMessageShm)); uint64_t compute_end_ns = 0; SET_TIMESTAMP(compute_end_ns); reporter.SetComputeEndNs(compute_end_ns); - reporter.SetBatchStatistics(request_count); + reporter.SetBatchStatistics(total_batch_size); - if (response_batch.data_->has_error) { - if (response_batch.data_->is_error_set) { + if (response_batch_shm_ptr->has_error) { + // Clean up the response factory if an error occurred. The + // `is_response_factory_deleted` flag indicates whether the response factory + // has been deleted for some corner cases. + if (!response_batch_shm_ptr->is_response_factory_deleted) { + for (uint32_t r = 0; r < request_count; r++) { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + pb_infer_requests[r]->GetResponseFactoryAddress()); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory(reinterpret_cast( + response_factory)); + } + } + if (response_batch_shm_ptr->is_error_set) { auto error = PbString::LoadFromSharedMemory( - Stub()->ShmPool(), response_batch.data_->error); + Stub()->ShmPool(), response_batch_shm_ptr->error); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, error->String().c_str()); } @@ -1248,210 +1419,100 @@ ModelInstanceState::ProcessRequestsDecoupled( TRITONSERVER_ERROR_INTERNAL, "Failed to process the requests."); } - return nullptr; // success -} - -void -ModelInstanceState::ProcessRequests( - TRITONBACKEND_Request** requests, const uint32_t request_count, - bool& restart) -{ - NVTX_RANGE(nvtx_, "ProcessRequests " + Name()); - ModelState* model_state = reinterpret_cast(Model()); - std::string name = model_state->Name(); - - LOG_MESSAGE( - TRITONSERVER_LOG_VERBOSE, - (std::string("model ") + model_state->Name() + ", instance " + Name() + - ", executing " + std::to_string(request_count) + " requests") - .c_str()); - - uint64_t exec_start_ns = 0; - SET_TIMESTAMP(exec_start_ns); - - // We take the responsibility of the responses. - std::shared_ptr> responses( - new std::vector()); - responses->reserve(request_count); - PbMetricReporter reporter( - TritonModelInstance(), requests, request_count, responses); - reporter.SetExecStartNs(exec_start_ns); - - for (size_t i = 0; i < request_count; i++) { - TRITONBACKEND_Response* response; - auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); - if (err == nullptr) { - responses->emplace_back(response); - } else { - responses->emplace_back(nullptr); - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); - TRITONSERVER_ErrorDelete(err); - } - } - - size_t total_batch_size = 0; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - CheckIncomingRequests(requests, request_count, total_batch_size)); - - // No request to process - if (total_batch_size == 0) { - return; - } - - // Wait for all the pending BLS requests to be completed. - ScopedDefer bls_defer([this] { WaitForBLSRequestsToFinish(); }); - std::vector> pb_inference_requests; - AllocatedSharedMemory request_batch; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - SaveRequestsToSharedMemory( - requests, request_count, pb_inference_requests, request_batch, - responses)); - - std::shared_ptr ipc_message = - IPCMessage::Create(Stub()->ShmPool(), false /*inline_response*/); - ipc_message->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_ExecuteRequest; - ipc_message->Args() = request_batch.handle_; - - uint64_t compute_start_ns = 0; - SET_TIMESTAMP(compute_start_ns); - reporter.SetComputeStartNs(compute_start_ns); - - // This means that the stub process has exited and Python - // backend failed to restart the stub process. - if (Stub()->StubPid() == 0) { - const char* error_message = "The stub process has exited unexpectedly."; - RespondErrorToAllRequests( - error_message, responses, requests, request_count); - return; - } - - bi::managed_external_buffer::handle_t response_message; - { - NVTX_RANGE(nvtx_, "StubProcessing " + Name()); - SendMessageAndReceiveResponse( - ipc_message->ShmHandle(), response_message, restart, responses, - requests, request_count); - } - - ScopedDefer execute_finalize([this, &restart] { - // Push a dummy message to the message queue so that - // the stub process is notified that it can release - // the object stored in shared memory. - NVTX_RANGE(nvtx_, "RequestExecuteFinalize " + Name()); - if (!restart) - // Push a dummy message to signal the thread to terminate. - Stub()->StubMessageQueue()->Push(DUMMY_MESSAGE); - }); - if (restart) { - return; - } - - RESPOND_ALL_AND_RETURN_IF_EXCEPTION( - responses, request_count, - ipc_message = IPCMessage::LoadFromSharedMemory( - Stub()->ShmPool(), response_message)); - - // If the stub command is no longer PYTHONSTUB_InferExecRequest, it indicates - // that inference request execution has finished and there are no more BLS - // requests to execute. Otherwise, the Python backend will continuously - // execute BLS requests pushed to the message queue. - while (ipc_message->Command() == - PYTHONSTUB_CommandType::PYTHONSTUB_InferExecRequest || - ipc_message->Command() == - PYTHONSTUB_CommandType::PYTHONSTUB_InferStreamExecRequest) { - std::packaged_task task([this, ipc_message] { - ExecuteBLSRequest( - ipc_message, - (ipc_message->Command() == - PYTHONSTUB_CommandType::PYTHONSTUB_InferStreamExecRequest)); - }); - std::future future = - boost::asio::post(*thread_pool_, std::move(task)); - futures_.emplace_back(std::move(future)); - - auto error = Stub()->ReceiveMessageFromStub(response_message); - if (error != nullptr) { - restart = true; - RespondErrorToAllRequests( - TRITONSERVER_ErrorMessage(error), responses, requests, request_count); - return; - } - - RESPOND_ALL_AND_RETURN_IF_EXCEPTION( - responses, request_count, - ipc_message = IPCMessage::LoadFromSharedMemory( - Stub()->ShmPool(), response_message)); - } - - uint64_t compute_end_ns = 0; - SET_TIMESTAMP(compute_end_ns); - reporter.SetComputeEndNs(compute_end_ns); - - // Parsing the request response - AllocatedSharedMemory response_batch; - RESPOND_ALL_AND_RETURN_IF_EXCEPTION( - responses, request_count, - response_batch = Stub()->ShmPool()->Load(ipc_message->Args())); - - ResponseBatch* response_batch_shm_ptr = - reinterpret_cast(response_batch.data_.get()); - - // If inference fails, release all the requests and send an error response. - // If inference fails at this stage, it usually indicates a bug in the model - // code - if (response_batch_shm_ptr->has_error) { - if (response_batch_shm_ptr->is_error_set) { - std::unique_ptr error_message_shm; - RESPOND_ALL_AND_RETURN_IF_EXCEPTION( - responses, request_count, - error_message_shm = PbString::LoadFromSharedMemory( - Stub()->ShmPool(), response_batch_shm_ptr->error)); - RespondErrorToAllRequests( - error_message_shm->String().c_str(), responses, requests, - request_count); - } else { - const char* error_message = - "Failed to fetch the error in response batch."; - RespondErrorToAllRequests( - error_message, responses, requests, request_count); + if (response_batch_shm_ptr->batch_size > 0) { + bi::managed_external_buffer::handle_t* response_shm_handle = + reinterpret_cast( + ipc_message_shm + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + + std::shared_ptr> responses( + new std::vector()); + responses->reserve(request_count); + for (size_t i = 0; i < request_count; i++) { + // It is possible to have multiple responses batched together in a single + // response batch shm, where some of the responses are None due to the + // usage of response sender, so only create a TRITONBACKEND_Response + // object for the valid responses. + if (response_shm_handle[i] == 0) { + responses->emplace_back(nullptr); + } else { + TRITONBACKEND_Response* response; + auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); + if (err == nullptr) { + responses->emplace_back(response); + } else { + responses->emplace_back(nullptr); + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); + TRITONSERVER_ErrorDelete(err); + } + } } - return; - } - bi::managed_external_buffer::handle_t* response_shm_handle = - reinterpret_cast( - response_batch.data_.get() + sizeof(ResponseBatch)); + std::vector requires_deferred_callback; - // If the output provided by the model is in GPU, we will pass the list of - // buffers provided by Triton to the stub process. - bool has_gpu_output = false; - std::vector requires_deferred_callback; + bool has_gpu_output = false; + std::vector> shm_responses; + std::vector, void*>>> + gpu_output_buffers(request_count); + GPUBuffersHelper gpu_buffer_helper; - std::vector> shm_responses; - std::vector, void*>>> - gpu_output_buffers(request_count); - GPUBuffersHelper gpu_buffer_helper; + for (uint32_t r = 0; r < request_count; ++r) { + NVTX_RANGE(nvtx_, "LoadingResponse " + Name()); + requires_deferred_callback.push_back(false); + if (response_shm_handle[r] == 0) { + continue; + } + TRITONBACKEND_Response* response = (*responses)[r]; + TRITONBACKEND_Request* request = requests[r]; + uint32_t requested_output_count = 0; - for (uint32_t r = 0; r < request_count; ++r) { - NVTX_RANGE(nvtx_, "LoadingResponse " + Name()); - TRITONBACKEND_Response* response = (*responses)[r]; - TRITONBACKEND_Request* request = requests[r]; - uint32_t requested_output_count = 0; - requires_deferred_callback.push_back(false); + shm_responses.emplace_back(nullptr); + std::unique_ptr& infer_response = shm_responses.back(); + try { + if (pb_infer_requests[r]->ReleaseFlags() == + TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) { + // For rescheduled requests, we do not need to send a response. + LOG_IF_ERROR( + TRITONBACKEND_ResponseDelete((*responses)[r]), + "failed to delete response"); + (*responses)[r] = nullptr; + continue; + } + { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + pb_infer_requests[r]->GetResponseFactoryAddress()); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory( + reinterpret_cast( + response_factory)); + } + infer_response = InferResponse::LoadFromSharedMemory( + Stub()->ShmPool(), response_shm_handle[r], + false /* open_cuda_handle */); + if (infer_response->HasError()) { + TRITONSERVER_Error* err = TRITONSERVER_ErrorNew( + infer_response->Error()->Code(), + infer_response->Error()->Message().c_str()); - shm_responses.emplace_back(nullptr); - std::unique_ptr& infer_response = shm_responses.back(); - try { - infer_response = InferResponse::LoadFromSharedMemory( - Stub()->ShmPool(), response_shm_handle[r], - false /* open_cuda_handle */); - if (infer_response->HasError()) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed sending response"); + TRITONSERVER_ErrorDelete(err); + (*responses)[r] = nullptr; + + // Reset the release flags for the request. + pb_infer_requests[r]->SetReleaseFlags( + TRITONSERVER_REQUEST_RELEASE_ALL); + + // If has_error is true, we do not look at the response tensors. + continue; + } + } + catch (const PythonBackendException& pb_exception) { TRITONSERVER_Error* err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - infer_response->Error()->Message().c_str()); - + TRITONSERVER_ERROR_INTERNAL, pb_exception.what()); LOG_IF_ERROR( TRITONBACKEND_ResponseSend( (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), @@ -1459,106 +1520,118 @@ ModelInstanceState::ProcessRequests( TRITONSERVER_ErrorDelete(err); (*responses)[r] = nullptr; - // If has_error is true, we do not look at the response tensors. + // Reset the release flags for the request. + pb_infer_requests[r]->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL); + continue; } - } - catch (const PythonBackendException& pb_exception) { - TRITONSERVER_Error* err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, pb_exception.what()); - LOG_IF_ERROR( - TRITONBACKEND_ResponseSend( - (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), - "failed sending response"); - TRITONSERVER_ErrorDelete(err); - (*responses)[r] = nullptr; - continue; - } - GUARDED_RESPOND_IF_ERROR( - responses, r, - TRITONBACKEND_RequestOutputCount(request, &requested_output_count)); - - std::set requested_output_names; - for (size_t j = 0; j < requested_output_count; ++j) { - const char* output_name; GUARDED_RESPOND_IF_ERROR( responses, r, - TRITONBACKEND_RequestOutputName(request, j, &output_name)); - requested_output_names.insert(output_name); - } + TRITONBACKEND_RequestOutputCount(request, &requested_output_count)); + std::set requested_output_names; + for (size_t j = 0; j < requested_output_count; ++j) { + const char* output_name; + GUARDED_RESPOND_IF_ERROR( + responses, r, + TRITONBACKEND_RequestOutputName(request, j, &output_name)); + requested_output_names.insert(output_name); + } - bool require_deferred_callback = false; + bool require_deferred_callback = false; - gpu_output_buffers[r] = - std::vector, void*>>{}; - infer_response->Send( - response, CudaStream(), require_deferred_callback, - TRITONSERVER_RESPONSE_COMPLETE_FINAL, Stub()->ShmPool(), - gpu_buffer_helper, gpu_output_buffers[r], requested_output_names); +#ifdef TRITON_ENABLE_GPU + for (auto& output_tensor : infer_response->OutputTensors()) { + if (output_tensor->MemoryType() == TRITONSERVER_MEMORY_GPU) { + // Attempt to use the cuda shared memory pool for GPU tensor. + ShareCUDAMemoryPool(output_tensor->MemoryTypeId()); + } + } +#endif // TRITON_ENABLE_GPU + + gpu_output_buffers[r] = + std::vector, void*>>{}; + infer_response->Send( + response, CudaStream(), require_deferred_callback, + TRITONSERVER_RESPONSE_COMPLETE_FINAL, Stub()->ShmPool(), + gpu_buffer_helper, gpu_output_buffers[r], requested_output_names); - requires_deferred_callback[r] = require_deferred_callback; + requires_deferred_callback[r] = require_deferred_callback; - if (requires_deferred_callback[r]) { - has_gpu_output = true; + if (requires_deferred_callback[r]) { + has_gpu_output = true; + } } - } - // Finalize the execute. - execute_finalize.Complete(); - - // If the output tensor is in GPU, there will be a second round trip - // required for filling the GPU buffers provided by the main process. - if (has_gpu_output) { - ipc_message->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_LoadGPUBuffers; - gpu_buffer_helper.Complete(Stub()->ShmPool()); - ipc_message->Args() = gpu_buffer_helper.ShmHandle(); - SendMessageAndReceiveResponse( - ipc_message->ShmHandle(), response_message, restart, responses, - requests, 0); - - bool cuda_copy = false; - - uint32_t response_index = 0; - for (auto& gpu_output_buffer : gpu_output_buffers) { - for (auto& buffer_memory_pair : gpu_output_buffer) { - auto& pb_memory = buffer_memory_pair.first; - if (pb_memory->MemoryType() == TRITONSERVER_MEMORY_CPU) { - bool cuda_used = false; + execute_finalize.Complete(); + + // If the output tensor is in GPU, there will be a second round trip + // required for filling the GPU buffers provided by the main process. + if (has_gpu_output) { + ipc_message->Command() = + PYTHONSTUB_CommandType::PYTHONSTUB_LoadGPUBuffers; + gpu_buffer_helper.Complete(Stub()->ShmPool()); + ipc_message->Args() = gpu_buffer_helper.ShmHandle(); + bi::managed_external_buffer::handle_t response_message; + SendMessageAndReceiveResponse( + ipc_message->ShmHandle(), response_message, responses, requests, 0); + + bool cuda_copy = false; + + uint32_t response_index = 0; + for (auto& gpu_output_buffer : gpu_output_buffers) { + for (auto& buffer_memory_pair : gpu_output_buffer) { + auto& pb_memory = buffer_memory_pair.first; void* pointer = buffer_memory_pair.second; + bool cuda_used = false; - GUARDED_RESPOND_IF_ERROR( - responses, response_index, - CopyBuffer( - "Failed to copy the output tensor to buffer.", - TRITONSERVER_MEMORY_CPU, 0, TRITONSERVER_MEMORY_CPU, 0, - pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, - CudaStream(), &cuda_used)); - cuda_copy |= cuda_used; + if (pb_memory->MemoryType() == TRITONSERVER_MEMORY_CPU) { + GUARDED_RESPOND_IF_ERROR( + responses, response_index, + CopyBuffer( + "Failed to copy the output tensor to buffer.", + TRITONSERVER_MEMORY_CPU, 0, TRITONSERVER_MEMORY_CPU, 0, + pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, + CudaStream(), &cuda_used)); + cuda_copy |= cuda_used; + } else if ( + (pb_memory->MemoryType() == TRITONSERVER_MEMORY_GPU) && + pb_memory->UseCUDASharedPool() && + (pb_memory->DataPtr() != pointer)) { + // If the data pointer from pb_memory is not the same as the + // pointer, it means that the Triton-provided buffer is not used + // during tensor transfer. Instead, an intermediate buffer that uses + // CUDA shared memory pool is used. In this case, we need to copy + // the data from the intermediate buffer back to the Triton-provided + // buffer. + GUARDED_RESPOND_IF_ERROR( + responses, response_index, + CopyBuffer( + "Failed to copy the output tensor to buffer.", + TRITONSERVER_MEMORY_GPU, pb_memory->MemoryTypeId(), + TRITONSERVER_MEMORY_GPU, pb_memory->MemoryTypeId(), + pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, + CudaStream(), &cuda_used)); + cuda_copy |= cuda_used; + } } - } - response_index++; + response_index++; #ifdef TRITON_ENABLE_GPU - if (cuda_copy) { - cudaStreamSynchronize(stream_); - } + if (cuda_copy) { + cudaStreamSynchronize(stream_); + } #endif // TRITON_ENABLE_GPU + } } - } - bls_defer.Complete(); - for (uint32_t r = 0; r < request_count; ++r) { - if (requires_deferred_callback[r]) { - shm_responses[r]->DeferredSendCallback(); + for (uint32_t r = 0; r < request_count; ++r) { + if (requires_deferred_callback[r]) { + shm_responses[r]->DeferredSendCallback(); + } } } - uint64_t exec_end_ns = 0; - SET_TIMESTAMP(exec_end_ns); - reporter.SetExecEndNs(exec_end_ns); - reporter.SetBatchStatistics(total_batch_size); - - return; + return nullptr; // success } void @@ -1589,16 +1662,36 @@ ModelInstanceState::PrepareResponseHandle( std::unique_ptr* infer_response, bi::managed_external_buffer::handle_t* response_handle) { +#ifdef TRITON_ENABLE_GPU + for (auto& output_tensor : (*infer_response)->OutputTensors()) { + if (!output_tensor->IsCPU()) { + // Attempt to use the cuda shared memory pool for GPU tensor. + ShareCUDAMemoryPool(output_tensor->MemoryTypeId()); + // It's possible that the CUDA memory pool offset isn't set correctly, + // even if the BLS output is using CUDA memory. This can occur when the + // CUDA memory pool hasn't been shared with the stub process at the time + // the BLS output is allocated during the ResponseAlloc callback. In such + // cases, we need to adjust the CUDA pool offset accordingly. + if (!output_tensor->Memory()->UseCUDASharedPool()) { + output_tensor->Memory()->UpdateCUDAOffset( + Stub()->ShmPool()->GetCUDAMemoryPoolManager()); + } + } + } +#endif // TRITON_ENABLE_GPU + (*infer_response)->SaveToSharedMemory(Stub()->ShmPool()); + for (auto& output_tensor : (*infer_response)->OutputTensors()) { - // For GPU tensors we need to store the memory release id in - // memory manager. if (!output_tensor->IsCPU()) { #ifdef TRITON_ENABLE_GPU - std::unique_ptr gpu_memory_record = - std::make_unique(output_tensor->Memory()->DataPtr()); + std::unique_ptr memory_record; + // Need to transfer the ownership of the BackendMemory to the + // MemoryManager so that the lifetime of the BackendMemory is managed. + memory_record = std::make_unique( + output_tensor->Memory()->GetBackendMemory()); uint64_t memory_release_id = - Stub()->GetMemoryManager()->AddRecord(std::move(gpu_memory_record)); + Stub()->GetMemoryManager()->AddRecord(std::move(memory_record)); output_tensor->Memory()->SetMemoryReleaseId(memory_release_id); #endif } @@ -1622,6 +1715,7 @@ ModelInstanceState::SendBLSDecoupledResponse( ipc_message = IPCMessage::Create(Stub()->ShmPool(), true /* inline_response */); ipc_message->Args() = response_batch_shm.handle_; + ipc_message->Command() = PYTHONSTUB_InferStreamExecResponse; PrepareResponseBatch( &response_batch, response_batch_shm, &ipc_message, &response_handle); is_response_batch_set = true; @@ -1654,17 +1748,28 @@ ModelInstanceState::SendBLSDecoupledResponse( } } +void +ModelInstanceState::ShareCUDAMemoryPool(const int32_t device_id) +{ +#ifdef TRITON_ENABLE_GPU + try { + Stub()->ShareCUDAMemoryPool(Model()->TritonMemoryManager(), device_id); + } + catch (const PythonBackendException& ex) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string("Failed to share CUDA memory pool with stub process: ") + + ex.what() + ". Will use CUDA IPC.") + .c_str()); + } +#endif // TRITON_ENABLE_GPU +} + ModelInstanceState::~ModelInstanceState() { - ModelState* model_state = reinterpret_cast(Model()); Stub()->UpdateHealth(); if (Stub()->IsHealthy()) { - if (model_state->IsDecoupled()) { - futures_.clear(); - // Push a dummy message to signal the thread to terminate. - Stub()->ParentMessageQueue()->Push(DUMMY_MESSAGE); - decoupled_monitor_.join(); - } + // Wait for all the pending tasks to finish. thread_pool_->wait(); } // Terminate stub first to allow any last messages to be received by the back @@ -1672,7 +1777,6 @@ ModelInstanceState::~ModelInstanceState() Stub()->TerminateStub(); TerminateMonitor(); Stub()->ClearQueues(); - received_message_.reset(); Stub().reset(); } @@ -1723,11 +1827,12 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) python_execution_env_ = ""; force_cpu_only_input_tensors_ = true; decoupled_ = false; - platform_ = ""; void* bstate; THROW_IF_BACKEND_MODEL_ERROR(TRITONBACKEND_BackendState(backend, &bstate)); backend_state_ = reinterpret_cast(bstate); + + runtime_modeldir_ = backend_state_->runtime_modeldir; triton::common::TritonJson::Value params; common::TritonJson::Value model_config; if (model_config_.Find("parameters", ¶ms)) { @@ -1764,14 +1869,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } } - triton::common::TritonJson::Value platform; - if (model_config_.Find("platform", &platform)) { - auto error = platform.AsString(&platform_); - if (error != nullptr) { - throw BackendModelException(error); - } - } - // Skip the FORCE_CPU_ONLY_INPUT_TENSORS variable if it doesn't exits. std::string force_cpu_only_input_tensor; error = nullptr; @@ -1847,10 +1944,33 @@ ModelState::ValidateModelConfig() return nullptr; } +TRITONSERVER_Error* +ModelState::SetModelConfig() +{ + BackendModel::SetModelConfig(); + // `Update model_transaction_policy` if setting was set + // with `set_model_transaction_policy` + triton::common::TritonJson::Value model_transaction_policy; + bool is_decoupled = false; + if (ModelConfig().Find( + "model_transaction_policy", &model_transaction_policy)) { + triton::common::TritonJson::Value decoupled; + if (model_transaction_policy.Find("decoupled", &decoupled)) { + auto error = decoupled.AsBool(&is_decoupled); + if (error != nullptr) { + throw BackendModelException(error); + } + SetDecoupled(is_decoupled); + } + } + + return nullptr; +} + extern "C" { -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) { const char* cname; @@ -1894,14 +2014,16 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) std::unique_ptr backend_state(new BackendState()); triton::common::TritonJson::Value cmdline; - backend_state->shm_default_byte_size = 64 * 1024 * 1024; // 64 MBs - backend_state->shm_growth_byte_size = 64 * 1024 * 1024; // 64 MBs + backend_state->shm_default_byte_size = 1 * 1024 * 1024; // 1 MB + backend_state->shm_growth_byte_size = 1 * 1024 * 1024; // 1 MB backend_state->stub_timeout_seconds = 30; backend_state->shm_message_queue_size = 1000; - backend_state->number_of_instance_inits = 0; backend_state->thread_pool_size = 32; + // Initialize shared memory region prefix to include backend's name + // to avoid collision between python backend and python-based backends. backend_state->shared_memory_region_prefix = - "triton_python_backend_shm_region_"; + "triton_" + name + "_backend_shm_region_"; + std::string default_backend_dir_string; if (backend_config.Find("cmdline", &cmdline)) { triton::common::TritonJson::Value shm_growth_size; @@ -1929,8 +2051,8 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) RETURN_IF_ERROR(shm_default_size.AsString(&shm_default_byte_size)); try { backend_state->shm_default_byte_size = std::stol(shm_default_byte_size); - // Shared memory default byte size can't be less than 4 MBs. - if (backend_state->shm_default_byte_size < 4 * 1024 * 1024) { + // Shared memory default byte size can't be less than 1 MB. + if (backend_state->shm_default_byte_size < 1 * 1024 * 1024) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, (std::string("shm-default-byte-size") + @@ -2011,6 +2133,12 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, ia.what()); } } + + triton::common::TritonJson::Value default_backend_dir; + if (cmdline.Find("backend-directory", &default_backend_dir)) { + RETURN_IF_ERROR( + default_backend_dir.AsString(&default_backend_dir_string)); + } } LOG_MESSAGE( @@ -2024,12 +2152,65 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) .c_str()); // Use BackendArtifacts to determine the location of Python files - const char* location; + const char* clocation; TRITONBACKEND_ArtifactType artifact_type; RETURN_IF_ERROR( - TRITONBACKEND_BackendArtifacts(backend, &artifact_type, &location)); - backend_state->python_lib = location; + TRITONBACKEND_BackendArtifacts(backend, &artifact_type, &clocation)); + + const char os_slash = std::filesystem::path::preferred_separator; + std::string location(clocation); +#ifdef _WIN32 + const std::string stub_executable_name = "triton_python_backend_stub.exe"; + SanitizePath(location); + SanitizePath(default_backend_dir_string); +#else + const std::string stub_executable_name = "triton_python_backend_stub"; +#endif + // Check if `triton_python_backend_stub` and `triton_python_backend_utils.py` + // are located under `location`. + std::string default_python_backend_dir = + default_backend_dir_string + os_slash + "python"; + std::string backend_stub_path = location + os_slash + stub_executable_name; + std::string backend_utils = + location + os_slash + "triton_python_backend_utils.py"; + // Both, stub and utils should be in the same location + if (FileExists(backend_stub_path) && FileExists(backend_utils)) { + backend_state->python_lib = location; + // If `location` is default location of a python backend, + // then we are using default python backend. + if (default_python_backend_dir == location) { + backend_state->runtime_modeldir = ""; + } else { + // If `location` is not default location of a python backend, + // then we are using a python backend based backend and model.py stored + // in the received location. + backend_state->runtime_modeldir = location; + } + } else { + // If stub and utils are not found in received `location`, + // then we are using a python backend based backend and stub and utils are + // stored in the default python backend location. + if (!default_backend_dir_string.empty()) { + std::string backend_stub_path = default_backend_dir_string + os_slash + + "python" + os_slash + + stub_executable_name; + if (!FileExists(backend_stub_path)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, + (stub_executable_name + " is not found. Searched paths: " + + default_backend_dir_string + os_slash + "python and " + location) + .c_str()); + } + } + backend_state->runtime_modeldir = location; + backend_state->python_lib = + default_backend_dir_string + os_slash + "python"; + } +// FIXME [DLIS-5969]: Enable for Windows when custom execution environments +// are supported. +#ifndef _WIN32 backend_state->env_manager = std::make_unique(); +#endif RETURN_IF_ERROR(TRITONBACKEND_BackendSetState( backend, reinterpret_cast(backend_state.get()))); @@ -2038,7 +2219,7 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) return nullptr; } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) { LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "TRITONBACKEND_Finalize: Start"); @@ -2050,7 +2231,7 @@ TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) return nullptr; // success } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) { const char* cname; @@ -2077,7 +2258,7 @@ TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) return nullptr; } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) { void* vstate; @@ -2093,7 +2274,7 @@ TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) return nullptr; } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) { const char* cname; @@ -2136,7 +2317,7 @@ TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) return nullptr; } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute( TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, const uint32_t request_count) @@ -2149,29 +2330,10 @@ TRITONBACKEND_ModelInstanceExecute( // If restart is equal to true, it indicates that the stub process is // unhealthy and needs a restart. - bool restart = false; - ModelState* model_state = - reinterpret_cast(instance_state->Model()); - if (!model_state->IsDecoupled()) { - instance_state->ProcessRequests(requests, request_count, restart); - - if (restart) { - LOG_MESSAGE( - TRITONSERVER_LOG_ERROR, - "Stub process is unhealthy and it will be restarted."); - instance_state->TerminateMonitor(); - instance_state->Stub()->KillStubProcess(); - TRITONSERVER_Error* err = instance_state->Stub()->Setup(); - if (err == nullptr) { - instance_state->StartMonitor(); - } - LOG_IF_ERROR(err, "Failed to restart the stub process."); - err = instance_state->Stub()->Launch(); - LOG_IF_ERROR(err, "Failed to restart the stub process."); - } - } else { - std::vector> infer_requests; + // TODO: Implement restart on decoupled + std::vector> infer_requests; + { uint64_t exec_start_ns = 0; SET_TIMESTAMP(exec_start_ns); @@ -2180,7 +2342,7 @@ TRITONBACKEND_ModelInstanceExecute( nullptr); reporter.SetExecStartNs(exec_start_ns); - error = instance_state->ProcessRequestsDecoupled( + error = instance_state->ProcessRequests( requests, request_count, infer_requests, reporter); uint64_t exec_end_ns = 0; @@ -2207,24 +2369,41 @@ TRITONBACKEND_ModelInstanceExecute( } } - // We should only delete the response factory for the requests that have - // not been closed. for (auto& infer_request : infer_requests) { - if (!instance_state->ExistsInClosedRequests( - infer_request->RequestAddress())) { - LOG_IF_ERROR( - infer_request->DeleteResponseFactory(), - "Failed to delete the response factory."); - } + // Reset the release flags for all the requests. + infer_request->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL); } } } + // The InferRequest object might not be created if an error occurs. Explicitly + // update the release flags here based on the number of InferRequest objects. + std::vector request_release_flags( + request_count, TRITONSERVER_REQUEST_RELEASE_ALL); + for (size_t i = 0; i < infer_requests.size(); ++i) { + request_release_flags[i] = infer_requests[i]->ReleaseFlags(); + } + for (uint32_t r = 0; r < request_count; ++r) { TRITONBACKEND_Request* request = requests[r]; - LOG_IF_ERROR( - TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), - "failed releasing request"); + try { + THROW_IF_TRITON_ERROR( + TRITONBACKEND_RequestRelease(request, request_release_flags[r])); + } + catch (const PythonBackendException& pb_exception) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string("Failed to release request: ") + pb_exception.what()) + .c_str()); + if (request_release_flags[r] == TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) { + // If error occurs during request rescheduling, release the request with + // `TRITONSERVER_REQUEST_RELEASE_ALL` flag. + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease( + request, TRITONSERVER_REQUEST_RELEASE_ALL), + "Failed to release request."); + } + } } LOG_MESSAGE( @@ -2237,7 +2416,7 @@ TRITONBACKEND_ModelInstanceExecute( return nullptr; } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) { void* vstate; @@ -2254,7 +2433,7 @@ TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) return nullptr; } -TRITONSERVER_Error* +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_GetBackendAttribute( TRITONBACKEND_Backend* backend, TRITONBACKEND_BackendAttribute* backend_attributes) @@ -2274,6 +2453,11 @@ TRITONBACKEND_GetBackendAttribute( backend_attributes, TRITONSERVER_INSTANCEGROUPKIND_CPU, 0, nullptr, 0)); #endif + // This backend can safely handle parallel calls to + // TRITONBACKEND_ModelInstanceInitialize (thread-safe). + RETURN_IF_ERROR(TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading( + backend_attributes, true)); + return nullptr; } diff --git a/src/python_be.h b/src/python_be.h index 825c45de..6082c50b 100644 --- a/src/python_be.h +++ b/src/python_be.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -26,12 +26,8 @@ #pragma once -#include #include #include -#include -#include -#include #include #include @@ -84,6 +80,14 @@ #include "triton/core/tritonbackend.h" #include "triton/core/tritonserver.h" +#ifdef _WIN32 +#define NOMINMAX +#include +#else +#include +#include +#endif + #define LOG_IF_EXCEPTION(X) \ do { \ try { \ @@ -217,7 +221,13 @@ struct BackendState { std::atomic number_of_instance_inits; std::string shared_memory_region_prefix; int64_t thread_pool_size; + +// FIXME [DLIS-5969]: Enable for Windows when custom execution environments +// are supported. +#ifndef _WIN32 std::unique_ptr env_manager; +#endif + std::string runtime_modeldir; }; class ModelState : public BackendModel { @@ -237,8 +247,11 @@ class ModelState : public BackendModel { // Is decoupled API being used. bool IsDecoupled() { return decoupled_; } - // Returns the value in the platform field - std::string Platform() { return platform_; } + // Set decoupled mode + void SetDecoupled(bool decoupled) { decoupled_ = decoupled; } + + // Returns the value in the `runtime_modeldir_` field + std::string RuntimeModelDir() { return runtime_modeldir_; } // Launch auto-complete stub process. TRITONSERVER_Error* LaunchAutoCompleteStubProcess(); @@ -246,6 +259,10 @@ class ModelState : public BackendModel { // Validate Model Configuration TRITONSERVER_Error* ValidateModelConfig(); + // Overrides `BackendModel::SetModelConfig` to also + // set `ModelState::decoupled_` + TRITONSERVER_Error* SetModelConfig(); + // Auto-complete stub std::unique_ptr& Stub() { return auto_complete_stub_; } @@ -255,7 +272,7 @@ class ModelState : public BackendModel { std::string python_execution_env_; bool force_cpu_only_input_tensors_; bool decoupled_; - std::string platform_; + std::string runtime_modeldir_; std::unique_ptr auto_complete_stub_; }; @@ -270,15 +287,13 @@ class ModelInstanceState : public BackendModelInstance { std::thread stub_to_parent_queue_monitor_; bool stub_to_parent_thread_; - // Decoupled monitor thread - std::thread decoupled_monitor_; - bool decoupled_thread_; std::mutex mu_; std::condition_variable cv_; std::unique_ptr received_message_; std::vector> futures_; std::unique_ptr thread_pool_; - std::unordered_map> infer_payload_; + std::unordered_map> infer_payload_; + std::mutex infer_payload_mu_; std::unique_ptr request_executor_; public: @@ -291,28 +306,12 @@ class ModelInstanceState : public BackendModelInstance { // Launch stub process. TRITONSERVER_Error* LaunchStubProcess(); - TRITONSERVER_Error* SendMessageToStub(off_t message); void ResponseSendDecoupled(std::shared_ptr response_send_message); - // Checks whether the stub process is live - bool IsStubProcessAlive(); - - // Get a message from the stub process - void SendMessageAndReceiveResponse( - off_t message, off_t& response, bool& restart, - std::shared_ptr>& responses, - TRITONBACKEND_Request** requests, const uint32_t request_count); - - // Responds to all the requests with an error message. - void RespondErrorToAllRequests( - const char* message, - std::shared_ptr>& responses, - TRITONBACKEND_Request** requests, const uint32_t request_count); - - // In the decoupled mode, the parent message queue is monitored only by this - // function during the execute phase. No other thread should pop any message - // from the message queue in the decoupled mode. - void DecoupledMessageQueueMonitor(); + // The parent message queue is monitored only by this function during the + // execute phase. No other thread should pop any message from the message + // queue. + void MessageQueueMonitor(); // This function is executed on a separate thread and monitors the queue for // message sent from stub to parent process. @@ -327,13 +326,8 @@ class ModelInstanceState : public BackendModelInstance { TRITONBACKEND_Request* request, std::shared_ptr>& responses); - // Process all the requests obtained from Triton. - void ProcessRequests( - TRITONBACKEND_Request** requests, const uint32_t request_count, - bool& restart); - // Process all the requests in the decoupled mode. - TRITONSERVER_Error* ProcessRequestsDecoupled( + TRITONSERVER_Error* ProcessRequests( TRITONBACKEND_Request** requests, const uint32_t request_count, std::vector>& pb_infer_requests, PbMetricReporter& pb_metric_reporter); @@ -347,9 +341,6 @@ class ModelInstanceState : public BackendModelInstance { // Cleanup BLS responses void CleanupBLSResponses(); - // Wait for BLS requests to complete - void WaitForBLSRequestsToFinish(); - // Check the incoming requests for errors TRITONSERVER_Error* CheckIncomingRequests( TRITONBACKEND_Request** requests, const uint32_t request_count, @@ -367,6 +358,24 @@ class ModelInstanceState : public BackendModelInstance { AllocatedSharedMemory& request_batch, std::shared_ptr>& responses); + void SendMessageAndReceiveResponse( + bi::managed_external_buffer::handle_t message, + bi::managed_external_buffer::handle_t& response, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count); + + void RespondErrorToAllRequests( + const char* message, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count); + + // void SendMessageToStub(bi::managed_external_buffer::handle_t message); + TRITONSERVER_Error* SendMessageToStub( + bi::managed_external_buffer::handle_t message); + + // Checks whether the stub process is live + bool IsStubProcessAlive(); + // Model instance stub std::unique_ptr& Stub() { return model_instance_stub_; } @@ -391,8 +400,14 @@ class ModelInstanceState : public BackendModelInstance { std::unique_ptr* infer_response, bi::managed_external_buffer::handle_t* response_handle); - // Process the bls decoupled cleanup request - void ProcessBLSCleanupRequest(const std::unique_ptr& message); + // Process the decoupled cleanup request for InferPayload and ResponseFactory + void ProcessCleanupRequest(const std::unique_ptr& message); + + // Process cancelling a BLS request + void ProcessCancelBLSRequest(const std::unique_ptr& message); + + // Process request cancellation query + void ProcessIsRequestCancelled(const std::unique_ptr& message); // Process a message. The function 'request_handler' is invoked // to handle the request. T should be either 'MetricFamily', 'Metric' or @@ -411,5 +426,8 @@ class ModelInstanceState : public BackendModelInstance { // Process a model control request void ProcessModelControlRequest(const std::unique_ptr& message); + + // Attempt to share CUDA memory pool with the stub process + void ShareCUDAMemoryPool(const int32_t device_id); }; }}} // namespace triton::backend::python diff --git a/src/request_executor.cc b/src/request_executor.cc index 2590ee37..716d3c56 100644 --- a/src/request_executor.cc +++ b/src/request_executor.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -28,6 +28,7 @@ #include +#include "correlation_id.h" #include "pb_utils.h" #include "scoped_defer.h" #include "triton/backend/backend_common.h" @@ -48,10 +49,10 @@ MemoryTypeToTritonMemoryType( const PreferredMemory::MemoryType& memory_type) { switch (memory_type) { - case PreferredMemory::MemoryType::CPU: + case PreferredMemory::MemoryType::kCPU: *triton_memory_type = TRITONSERVER_MEMORY_CPU; break; - case PreferredMemory::MemoryType::GPU: + case PreferredMemory::MemoryType::kGPU: *triton_memory_type = TRITONSERVER_MEMORY_GPU; break; @@ -68,9 +69,15 @@ InferRequestComplete( TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) { if (request != nullptr) { + RequestCompletionUserp* completion_userp = + reinterpret_cast(userp); + completion_userp->infer_payload->SetRequestAddress(0L); + LOG_IF_ERROR( TRITONSERVER_InferenceRequestDelete(request), "Failed to delete inference request."); + + delete completion_userp; } } @@ -83,10 +90,17 @@ InferResponseComplete( std::unique_ptr infer_response; std::vector> output_tensors; std::shared_ptr pb_error; + std::string parameters_string; + TRITONSERVER_Error_Code error_code = TRITONSERVER_ERROR_INTERNAL; if (response != nullptr) { try { - THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceResponseError(response)); + TRITONSERVER_Error* server_error = + TRITONSERVER_InferenceResponseError(response); + if (server_error != nullptr) { + error_code = TRITONSERVER_ErrorCode(server_error); + } + THROW_IF_TRITON_ERROR(server_error); uint32_t output_count; THROW_IF_TRITON_ERROR( @@ -109,7 +123,6 @@ InferResponseComplete( std::string sname = cname; std::vector dims_vector{shape, shape + dim_count}; - // userp is only set for the CPU tensors if (memory_type != TRITONSERVER_MEMORY_GPU) { if (byte_size != 0) { std::shared_ptr pb_tensor = std::make_shared( @@ -129,12 +142,49 @@ InferResponseComplete( nullptr /* DLManagedTensor */)); } } else { - output_tensors.push_back(std::make_shared( + std::shared_ptr pb_tensor = std::make_shared( sname, dims_vector, datatype, memory_type, memory_type_id, const_cast(base), byte_size, - nullptr /* DLManagedTensor */)); + nullptr /* DLManagedTensor */); + + std::unique_ptr pb_memory( + reinterpret_cast(userp)); + pb_tensor->SetMemory(std::move(pb_memory)); + output_tensors.push_back(pb_tensor); + } + } + + triton::common::TritonJson::Value parameters_json( + triton::common::TritonJson::ValueType::OBJECT); + uint32_t parameter_count; + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceResponseParameterCount( + response, ¶meter_count)); + + for (size_t i = 0; i < parameter_count; i++) { + const char* name; + TRITONSERVER_ParameterType type; + const void* vvalue; + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceResponseParameter( + response, i, &name, &type, &vvalue)); + if (type == TRITONSERVER_PARAMETER_INT) { + THROW_IF_TRITON_ERROR(parameters_json.AddInt( + name, *(reinterpret_cast(vvalue)))); + } else if (type == TRITONSERVER_PARAMETER_BOOL) { + THROW_IF_TRITON_ERROR(parameters_json.AddBool( + name, *(reinterpret_cast(vvalue)))); + } else if (type == TRITONSERVER_PARAMETER_STRING) { + std::string string = reinterpret_cast(vvalue); + THROW_IF_TRITON_ERROR(parameters_json.AddString(name, string)); + } else { + throw PythonBackendException( + (std::string("Unsupported parameter type for parameter '") + + name + "'.")); } } + + triton::common::TritonJson::WriteBuffer buffer; + THROW_IF_TRITON_ERROR(parameters_json.Write(&buffer)); + parameters_string = buffer.Contents(); } catch (const PythonBackendException& pb_exception) { if (response != nullptr) { @@ -144,24 +194,25 @@ InferResponseComplete( response = nullptr; } - pb_error = std::make_shared(pb_exception.what()); + pb_error = std::make_shared(pb_exception.what(), error_code); output_tensors.clear(); } if (!infer_payload->IsDecoupled()) { infer_response = std::make_unique( - output_tensors, pb_error, true /* is_last_response */); + output_tensors, pb_error, parameters_string, + true /* is_last_response */); } else { if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) { // Not the last response. infer_response = std::make_unique( - output_tensors, pb_error, false /* is_last_response */, - userp /* id */); + output_tensors, pb_error, parameters_string, + false /* is_last_response */, userp /* id */); } else { // The last response. infer_response = std::make_unique( - output_tensors, pb_error, true /* is_last_response */, - userp /* id */); + output_tensors, pb_error, parameters_string, + true /* is_last_response */, userp /* id */); } } @@ -173,11 +224,13 @@ InferResponseComplete( (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) { // An empty response may be the last response for decoupled models. infer_response = std::make_unique( - output_tensors, pb_error, true /* is_last_response */, userp /* id */); + output_tensors, pb_error, "" /* parameters */, + true /* is_last_response */, userp /* id */); } else { pb_error = std::make_shared("Unexpected empty response."); infer_response = std::make_unique( - output_tensors, pb_error, true /* is_last_response */, userp /* id */); + output_tensors, pb_error, "" /* parameters */, + true /* is_last_response */, userp /* id */); } infer_payload->SetValue(std::move(infer_response)); @@ -198,7 +251,7 @@ ResponseAlloc( ScopedDefer _([&shm_pool] { shm_pool.release(); }); if (p->preferred_memory.PreferredMemoryType() == - PreferredMemory::MemoryType::DEFAULT) { + PreferredMemory::MemoryType::kDefault) { *actual_memory_type = preferred_memory_type; *actual_memory_type_id = preferred_memory_type_id; } else { @@ -241,24 +294,27 @@ ResponseAlloc( } break; #ifdef TRITON_ENABLE_GPU case TRITONSERVER_MEMORY_GPU: { - auto err = cudaSetDevice(*actual_memory_type_id); - if ((err != cudaSuccess) && (err != cudaErrorNoDevice) && - (err != cudaErrorInsufficientDriver)) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - std::string( - "unable to set current CUDA device: " + - std::string(cudaGetErrorString(err))) - .c_str()); - } + BackendMemory* backend_memory; + std::unique_ptr lbackend_memory; + try { + THROW_IF_TRITON_ERROR(BackendMemory::Create( + reinterpret_cast( + shm_pool->GetCUDAMemoryPoolManager()->TritonMemoryManager()), + {BackendMemory::AllocationType::GPU_POOL, + BackendMemory::AllocationType::GPU}, + *actual_memory_type_id, byte_size, &backend_memory)); + lbackend_memory.reset(backend_memory); - err = cudaMalloc(buffer, byte_size); - if (err != cudaSuccess) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - std::string( - "cudaMalloc failed: " + std::string(cudaGetErrorString(err))) - .c_str()); + std::unique_ptr pb_memory = PbMemory::Create( + shm_pool, std::move(lbackend_memory), true /* copy_gpu */); + *buffer = pb_memory->DataPtr(); + *buffer_userp = reinterpret_cast(pb_memory.get()); + pb_memory.release(); + } + catch (const PythonBackendException& pb_exception) { + TRITONSERVER_Error* err = + CreateTritonErrorFromException(pb_exception); + return err; } break; } @@ -269,6 +325,18 @@ ResponseAlloc( return nullptr; // Success } +void +InferRequestCancel(intptr_t request_address) +{ + if (request_address == 0L) { + return; + } + + TRITONSERVER_InferenceRequest* irequest = + reinterpret_cast(request_address); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestCancel(irequest)); +} + TRITONSERVER_Error* OutputBufferQuery( TRITONSERVER_ResponseAllocator* allocator, void* userp, @@ -311,6 +379,7 @@ RequestExecutor::Infer( bool is_ready = false; const char* model_name = infer_request->ModelName().c_str(); TRITONSERVER_InferenceRequest* irequest = nullptr; + RequestCompletionUserp* completion_userp = nullptr; try { int64_t model_version = infer_request->ModelVersion(); @@ -347,8 +416,14 @@ RequestExecutor::Infer( THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetId( irequest, infer_request->RequestId().c_str())); - THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetCorrelationId( - irequest, infer_request->CorrelationId())); + if (infer_request->GetCorrelationId().Type() == + CorrelationIdDataType::UINT64) { + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetCorrelationId( + irequest, infer_request->GetCorrelationId().UnsignedIntValue())); + } else { + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetCorrelationIdString( + irequest, infer_request->GetCorrelationId().StringValue().c_str())); + } THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetFlags( irequest, infer_request->Flags())); @@ -356,8 +431,48 @@ RequestExecutor::Infer( THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetTimeoutMicroseconds( irequest, infer_request->Timeout())); + completion_userp = new RequestCompletionUserp(infer_payload); THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetReleaseCallback( - irequest, InferRequestComplete, nullptr /* request_release_userp */)); + irequest, InferRequestComplete, + reinterpret_cast(completion_userp))); + + TRITONSERVER_InferenceTrace* trace = nullptr; + if (infer_request->GetTrace().TritonTrace() != nullptr) { + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceTraceSpawnChildTrace( + reinterpret_cast( + infer_request->GetTrace().TritonTrace()), + &trace)); + } + + const std::string& param_str = infer_request->Parameters(); + triton::common::TritonJson::Value param; + THROW_IF_TRITON_ERROR(param.Parse(param_str.c_str(), param_str.length())); + std::vector param_keys; + THROW_IF_TRITON_ERROR(param.Members(¶m_keys)); + for (const auto& key : param_keys) { + triton::common::TritonJson::Value value; + if (!param.Find(key.c_str(), &value)) { + throw PythonBackendException("Unexpected missing key on parameters"); + } + if (value.IsString()) { + std::string string_value; + THROW_IF_TRITON_ERROR(value.AsString(&string_value)); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetStringParameter( + irequest, key.c_str(), string_value.c_str())); + } else if (value.IsInt()) { + int64_t int_value = 0; + THROW_IF_TRITON_ERROR(value.AsInt(&int_value)); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetIntParameter( + irequest, key.c_str(), int_value)); + } else if (value.IsBool()) { + bool bool_value = false; + THROW_IF_TRITON_ERROR(value.AsBool(&bool_value)); + THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetBoolParameter( + irequest, key.c_str(), bool_value)); + } else { + throw PythonBackendException("Unsupported value type on parameters"); + } + } for (auto& infer_input : infer_request->Inputs()) { THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestAddInput( @@ -388,11 +503,21 @@ RequestExecutor::Infer( reinterpret_cast(infer_payload->ResponseAllocUserp().get()), InferResponseComplete, reinterpret_cast(infer_payload.get()))); - THROW_IF_TRITON_ERROR(TRITONSERVER_ServerInferAsync( - server_, irequest, nullptr /* trace */)); + // Store the inference request address submitted to the Triton server for + // retrieval + infer_payload->SetRequestAddress(reinterpret_cast(irequest)); + infer_payload->SetRequestCancellationFunc(InferRequestCancel); + + THROW_IF_TRITON_ERROR( + TRITONSERVER_ServerInferAsync(server_, irequest, trace)); } } catch (const PythonBackendException& pb_exception) { + infer_payload->SetRequestAddress(0L); + if (completion_userp != nullptr) { + delete completion_userp; + } + LOG_IF_ERROR( TRITONSERVER_InferenceRequestDelete(irequest), "Failed to delete inference request."); diff --git a/src/request_executor.h b/src/request_executor.h index 1c5eb1fa..07562d6a 100644 --- a/src/request_executor.h +++ b/src/request_executor.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -37,6 +37,12 @@ namespace triton { namespace backend { namespace python { TRITONSERVER_Error* CreateTritonErrorFromException( const PythonBackendException& pb_exception); +struct RequestCompletionUserp { + std::shared_ptr infer_payload; + RequestCompletionUserp(std::shared_ptr& infer_payload) + : infer_payload(infer_payload){}; +}; + class RequestExecutor { TRITONSERVER_ResponseAllocator* response_allocator_ = nullptr; TRITONSERVER_Server* server_; diff --git a/src/resources/platform_handlers/tensorflow_savedmodel/README.md b/src/resources/platform_handlers/tensorflow_savedmodel/README.md deleted file mode 100644 index 23199e7b..00000000 --- a/src/resources/platform_handlers/tensorflow_savedmodel/README.md +++ /dev/null @@ -1,87 +0,0 @@ - - -# Serving Tensorflow SavedModels using Python Backend \[Experimental\] - -*NOTE*: This feature is subject to change and removal, and should not -be used in production. - -Starting from 23.07, we are adding experimental support for loading -and serving of models in [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) -format via Python backend. The `model.savedmodel` can be provided within -the triton server model repository without `model.py` and backend will -automatically use a pre-built python model (`model.py`)[model.py] to load -and serve provided TF SavedModel. The handler can [auto-complete](../../../../README.md#auto_complete_config) -the missing model configuration. - -The model repository structure can look like: - -``` -model_repository/ -`-- resnet_v1_50_savedmodel - |-- 1 - | `-- model.savedmodel - | |-- saved_model.pb - | `-- variables - |-- config.pbtxt - `-- resnet50_labels.txt -``` - -In order to use this feature, make sure that [TensorFlow pip package](https://pypi.org/project/tensorflow/2.13.0/) -is available in the same Python environment. - -``` -pip install tensorfow==2.13.0 -``` - -Alternatively, you can create a -[Python Execution Environment](#using-custom-python-execution-environments) -with the TensorFlow dependency. - -By default, Triton will use the [TensorFlow backend](https://github.com/triton-inference-server/tensorflow_backend) -to load and serve the saved model. In order to use the Python backend with -TensorFlow SavedModel, [model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md) -should explicitly provide the following settings: - -``` -backend: "python" -platform: "tensorflow_savedmodel" -``` - -It has been observed that certain DLFW like TensorFlow do not release the entire -memory allocated for loading a model back to the system when the model gets -unloaded. This can be problematic when working with a large number of models and -dynamically loading/unloading them. Using Python backend for TF SavedModel serving -will allow the models to be loaded in a separate process, which ensures that entire -memory allocated within the process would be released to the system upon a model -unload. - -Following are few known limitations of this feature: -- GPU execution is not supported. -- List of requests received in model [`execute`](../../../../README.md#execute) function are -not run in a single batch but one after the other. diff --git a/src/resources/platform_handlers/tensorflow_savedmodel/model.py b/src/resources/platform_handlers/tensorflow_savedmodel/model.py deleted file mode 100644 index 24b95472..00000000 --- a/src/resources/platform_handlers/tensorflow_savedmodel/model.py +++ /dev/null @@ -1,536 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import json -import os - -try: - import tensorflow as tf - from tensorflow.core.framework import types_pb2 - from tensorflow.python.client import session - from tensorflow.python.saved_model import loader, signature_constants - from tensorflow.python.tools import saved_model_utils -except ModuleNotFoundError as error: - raise RuntimeError( - "Missing/Incomplete tensorflow package installation..." - ) from error - -# triton_python_backend_utils is available in every Triton Python model. You -# need to use this module to create inference requests and responses. It also -# contains some utility functions for extracting information from model_config -# and converting Triton input/output types to numpy types. -import triton_python_backend_utils as pb_utils - -TF_STRING_TO_TRITON = { - "DT_BOOL": "TYPE_BOOL", - "DT_UINT8": "TYPE_UINT8", - "DT_UINT16": "TYPE_UINT16", - "DT_UINT32": "TYPE_UINT32", - "DT_UINT64": "TYPE_UINT64", - "DT_INT8": "TYPE_INT8", - "DT_INT16": "TYPE_INT16", - "DT_INT32": "TYPE_INT32", - "DT_INT64": "TYPE_INT64", - "DT_HALF": "TYPE_FP16", - "DT_FLOAT": "TYPE_FP32", - "DT_DOUBLE": "TYPE_FP64", - "DT_STRING": "TYPE_STRING", -} - -_DEFAULT_ARTIFACT_NAME = "model.savedmodel" - - -def _get_savedmodel_path(config): - artifact_name = config["default_model_filename"] - if not artifact_name: - artifact_name = _DEFAULT_ARTIFACT_NAME - - savedmodel_path = os.path.join(pb_utils.get_model_dir(), artifact_name) - if not os.path.exists(savedmodel_path): - raise pb_utils.TritonModelException( - f"No savedmodel dir found in " + savedmodel_path - ) - - return savedmodel_path - - -def _parse_signature_def(config): - if config["parameters"]: - if "TF_SIGNATURE_DEF" in config["parameters"].keys(): - return config["parameters"]["TF_SIGNATURE_DEF"]["string_value"] - return None - - -def _parse_graph_tag(config): - if config["parameters"]: - if "TF_GRAPH_TAG" in config["parameters"].keys(): - return config["parameters"]["TF_GRAPH_TAG"]["string_value"] - return None - - -def _parse_num_intra_threads(config): - if config["parameters"]: - if "TF_NUM_INTRA_THREADS" in config["parameters"].keys(): - return int(config["parameters"]["TF_NUM_INTRA_THREADS"]["string_value"]) - return None - - -def _parse_num_inter_threads(config): - if config["parameters"]: - if "TF_NUM_INTER_THREADS" in config["parameters"].keys(): - return int(config["parameters"]["TF_NUM_INTER_THREADS"]["string_value"]) - return None - - -def _get_truth_value(string_value): - val = string_value.casefold() - if val == "yes" or val == "1" or val == "on" or val == "true": - return True - else: - return False - - -def _parse_use_per_session_thread(config): - if config["parameters"]: - if "USE_PER_SESSION_THREAD" in config["parameters"].keys(): - val = config["parameters"]["USE_PER_SESSION_THREAD"]["string_value"] - return _get_truth_value(val) - return False - - -def _get_signature_def(savedmodel_path, config): - tag_sets = saved_model_utils.get_saved_model_tag_sets(savedmodel_path) - graph_tag = _parse_graph_tag(config) - if graph_tag is None: - if "serve" in tag_sets[0]: - graph_tag = "serve" - else: - graph_tag = tag_sets[0][0] - - meta_graph_def = saved_model_utils.get_meta_graph_def(savedmodel_path, graph_tag) - signature_def_map = meta_graph_def.signature_def - signature_def_k = _parse_signature_def(config) - if signature_def_k is None: - serving_default = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if serving_default in signature_def_map.keys(): - signature_def_k = serving_default - else: - signature_def_k = signature_def_map.keys()[0] - - if signature_def_k not in signature_def_map.keys(): - raise pb_utils.TritonModelException( - f" The model does not include the signature_def '" + signature_def_k + "'" - ) - - return graph_tag, signature_def_map[signature_def_k] - - -def _has_batch_dim(tensor_info): - if tensor_info.tensor_shape.unknown_rank: - return True - elif tensor_info.tensor_shape.dim[0].size == -1: - return True - else: - return False - - -def _get_batching_hint_from_signature(signature_def): - for input_info in signature_def.inputs.values(): - if not _has_batch_dim(input_info): - return False - - for output_info in signature_def.outputs.values(): - if not _has_batch_dim(output_info): - return False - - return True - - -def _convert_proto_to_dict_tensor(name, tensor_proto, batching_enabled): - tensor_dict = {} - tensor_dict["name"] = name - dtype_dict = {value: key for (key, value) in types_pb2.DataType.items()} - tensor_dict["data_type"] = TF_STRING_TO_TRITON[dtype_dict[tensor_proto.dtype]] - if tensor_proto.tensor_shape.unknown_rank: - # FIXME: Fix the handling of unknown rank - dims = [-1] - else: - dims = [dim.size for dim in tensor_proto.tensor_shape.dim] - if batching_enabled: - tensor_dict["dims"] = dims[1:] - else: - tensor_dict["dims"] = dims - - return tensor_dict - - -def _validate_datatype(tf_dtype, triton_datatype, tensor_name): - dtype_dict = {value: key for (key, value) in types_pb2.DataType.items()} - if triton_datatype != TF_STRING_TO_TRITON[dtype_dict[tf_dtype]]: - raise pb_utils.TritonModelException( - f" Mismatch between datatype for tensor '" - + tensor_name - + "', expected '" - + TF_STRING_TO_TRITON[dtype_dict[tf_dtype]] - + "', got '" - + triton_datatype - ) - - -def _validate_dims(tf_shape, triton_dims, batching_enabled, tensor_name): - if tf_shape.unknown_rank: - return - - index = 0 - offset = 1 if batching_enabled else 0 - if len(tf_shape.dim) != (offset + len(triton_dims)): - raise pb_utils.TritonModelException( - f" Mismatch in the number of dimension with the model for tensor '" - + tensor_name - + "', expected " - + str(len(tf_shape.dim) - offset) - + ", got " - + str(len(triton_dims)) - ) - - for dim in tf_shape.dim: - if index == 0 and batching_enabled: - if dim.size != -1: - raise pb_utils.TritonModelException( - f" The first dimension of a batching model should be dynamic, " - "however, got shape of first dimension in model for tensor '" - + tensor_name - + "' as " - + str(dim.size) - ) - else: - if dim.size != triton_dims[index - offset]: - raise pb_utils.TritonModelException( - f" Mismatch in " - + str(index - offset) - + "th dimension for tensor '" - + tensor_name - + "', expected " - + str(dim.size) - + ", got " - + str(triton_dims[index - offset]) - ) - index = index + 1 - - -def _validate_model_config(model_config, signature_def): - signature_supports_batching = _get_batching_hint_from_signature(signature_def) - if (not signature_supports_batching) and (model_config["max_batch_size"] != 0): - raise pb_utils.TritonModelException( - f" The model signature does not support batching, yet model config" - " has max_batch_size set to '" + str(model_config["max_batch_size"]) + "'" - ) - - batching_enabled = model_config["max_batch_size"] != 0 - - if model_config["platform"] != "tensorflow_savedmodel": - raise pb_utils.TritonModelException( - f"[INTERNAL]: The platform field for using this model should be set to" - " 'tensorflow_savedmodel' in model config, got '" - + model_config["platform"] - + "'" - ) - if model_config["batch_input"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + model_config["platform"] - + "' does not support model with batch_input" - ) - if model_config["batch_output"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + model_config["platform"] - + "' does not support model with batch_output" - ) - - # Validate input tensors - input_tensor_info = signature_def.inputs - config_input_names = [input["name"] for input in model_config["input"]] - for input_name in input_tensor_info.keys(): - if input_name not in config_input_names: - raise pb_utils.TritonModelException( - f" Missing input tensor configuration for tensor '" + input_name + "'" - ) - for input in model_config["input"]: - config_input_name = input["name"] - if config_input_name not in input_tensor_info.keys(): - supported_names = "" - for valid_name in input_tensor_info.keys(): - supported_names = supported_names + ";" + valid_name - raise pb_utils.TritonModelException( - f" No input tensor with name '" - + config_input_name - + "', only supported input names are " - + supported_names - ) - _validate_datatype( - input_tensor_info[config_input_name].dtype, - input["data_type"], - config_input_name, - ) - _validate_dims( - input_tensor_info[config_input_name].tensor_shape, - input["dims"], - batching_enabled, - config_input_name, - ) - - # Validate output tensors - output_tensor_info = signature_def.outputs - for output in model_config["output"]: - config_output_name = output["name"] - if config_output_name not in output_tensor_info.keys(): - supported_names = "" - for valid_name in output_tensor_info.keys(): - supported_names = supported_names + ";" + valid_name - raise pb_utils.TritonModelException( - f" No output tensor with name '" - + config_output_name - + "', only supported output names are " - + supported_names - ) - - _validate_datatype( - output_tensor_info[config_output_name].dtype, - output["data_type"], - config_output_name, - ) - _validate_dims( - output_tensor_info[config_output_name].tensor_shape, - output["dims"], - batching_enabled, - config_output_name, - ) - - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. - """ - - @staticmethod - def auto_complete_config(auto_complete_model_config): - config = auto_complete_model_config.as_dict() - - if config["platform"] != "tensorflow_savedmodel": - raise pb_utils.TritonModelException( - f"[INTERNAL]: The platform field for using this model should be set to" - " 'tensorflow_savedmodel' in model config, got '" - + config["platform"] - + "'" - ) - if config["batch_input"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + config["platform"] - + "' does not support model with batch_input" - ) - if config["batch_output"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + config["platform"] - + "' does not support model with batch_output" - ) - - savedmodel_path = _get_savedmodel_path(config) - - if savedmodel_path is None: - raise pb_utils.TritonModelException( - f"[INTERNAL]: The path to the framework model should be" " provided" - ) - - batching_enabled = False - if config["max_batch_size"] != 0: - batching_enabled = True - - _, signature_def = _get_signature_def(savedmodel_path, config) - - input_tensor_info = signature_def.inputs - output_tensor_info = signature_def.outputs - - batching_hint = False - if not batching_enabled: - batching_hint = _get_batching_hint_from_signature(signature_def) - - # FIXME: Currently the presence of dynamic batch dimension is - # being treated as sufficient proof for enabling batching. - # Need to visit the tensors that are already provided in config - # to confirm the hint - batching_enabled = batching_hint - - config_input_names = [input["name"] for input in config["input"]] - config_output_names = [output["name"] for output in config["output"]] - - # TODO: Add auto-completion of partial tensor specification. - for input_name in input_tensor_info.keys(): - if input_name not in config_input_names: - auto_complete_model_config.add_input( - _convert_proto_to_dict_tensor( - input_name, input_tensor_info[input_name], batching_enabled - ) - ) - - for output_name in output_tensor_info.keys(): - if output_name not in config_output_names: - auto_complete_model_config.add_output( - _convert_proto_to_dict_tensor( - output_name, output_tensor_info[output_name], batching_enabled - ) - ) - - if batching_enabled: - if config["max_batch_size"] == 0: - auto_complete_model_config.set_max_batch_size(4) - auto_complete_model_config.set_dynamic_batching() - - return auto_complete_model_config - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # You must parse model_config. JSON string is not parsed here - self.model_config = model_config = json.loads(args["model_config"]) - - savedmodel_path = _get_savedmodel_path(model_config) - - self.model_name = args["model_name"] - self.logger = pb_utils.Logger - self.logger.log_info("Initializing model for " + self.model_name) - - if args["model_instance_kind"] != "CPU": - self.logger.log_warn( - "GPU instances are not supported by this backend. Falling back to KIND_CPU for " - + self.model_name - ) - - tag_set, signature_def = _get_signature_def(savedmodel_path, model_config) - _validate_model_config(model_config, signature_def) - - self.signature_def = signature_def - self.input_tensor_info = self.signature_def.inputs - output_tensor_info = self.signature_def.outputs - - # Get the input output names from model config - self.input_names = [input["name"] for input in model_config["input"]] - self.output_names = [output["name"] for output in model_config["output"]] - - # Get the output tensor names - self.output_tensor_names = [ - output_tensor_info[output_name].name for output_name in self.output_names - ] - - # load the session model - # FIXME Add more configuration options for the model. - sess_config = tf.compat.v1.ConfigProto( - inter_op_parallelism_threads=_parse_num_inter_threads(model_config), - intra_op_parallelism_threads=_parse_num_intra_threads(model_config), - use_per_session_threads=_parse_use_per_session_thread(model_config), - ) - self.tf_session = session.Session(graph=tf.Graph(), config=sess_config) - loader.load(self.tf_session, [tag_set], savedmodel_path) - - # Hoding the input dict for caching input tensor data for - # better inference performance - self.input_feed_dict = {} - - def execute(self, requests): - """`execute` MUST be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference request is made - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse - - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - # FIXME: Instead of iterating through each request, run - # the inference as a single batch. - for request in requests: - # Prepare the input feed for the model. - for input_name in self.input_names: - self.input_feed_dict[ - self.input_tensor_info[input_name].name - ] = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy() - - # FIXME: Add GPU Tensor handling. DLpack should be utilized - # for better performance - outputs = self.tf_session.run( - self.output_tensor_names, feed_dict=self.input_feed_dict - ) - - # Create output tensors. You need pb_utils.Tensor - # objects to create pb_utils.InferenceResponse. - output_tensors = [] - for i, output in enumerate(outputs): - output_tensors.append(pb_utils.Tensor(self.output_names[i], output)) - - inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors - ) - responses.append(inference_response) - - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - Implementing `finalize` function is OPTIONAL. This function allows - the model to perform any necessary clean ups before exit. - """ - if self.tf_session is not None: - self.tf_session.close - self.logger.log_info("Removed model instance for " + self.model_name) diff --git a/src/resources/triton_python_backend_utils.py b/src/resources/triton_python_backend_utils.py index 560a3198..de332cf7 100644 --- a/src/resources/triton_python_backend_utils.py +++ b/src/resources/triton_python_backend_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -381,12 +381,12 @@ def add_input(self, input): Raises ------ ValueError - If input contains property other than 'name', 'data_type' - and 'dims' or any of the properties are not set, or if an - input with the same name already exists in the configuration - but has different data_type or dims property + If input contains property other than 'name', 'data_type', + 'dims', 'optional' or any of the non-optional properties + are not set, or if an input with the same name already exists + in the configuration but has different data_type or dims property """ - valid_properties = ["name", "data_type", "dims"] + valid_properties = ["name", "data_type", "dims", "optional"] for current_property in input: if current_property not in valid_properties: raise ValueError( @@ -394,7 +394,7 @@ def add_input(self, input): + input["name"] + "' in auto-complete-config function for model '" + self._model_config["name"] - + "' contains property other than 'name', 'data_type' and 'dims'." + + "' contains property other than 'name', 'data_type', 'dims' and 'optional'." ) if "name" not in input: @@ -447,9 +447,26 @@ def add_input(self, input): + " but the model configuration specifies dims " + str(current_input["dims"]) ) + elif ( + "optional" in current_input + and "optional" in input + and current_input["optional"] != input["optional"] + ): + raise ValueError( + "model '" + + self._model_config["name"] + + "', tensor '" + + input["name"] + + "': the model expects optional " + + str(input["optional"]) + + " but the model configuration specifies optional " + + str(current_input["optional"]) + ) else: current_input["data_type"] = input["data_type"] current_input["dims"] = input["dims"] + if "optional" in input: + current_input["optional"] = input["optional"] return self._model_config["input"].append(input) @@ -538,7 +555,56 @@ def add_output(self, output): self._model_config["output"].append(output) + def set_model_transaction_policy(self, transaction_policy_dict): + """ + Set model transaction policy for the model. + Parameters + ---------- + transaction_policy_dict : dict + The dict, containing all properties to be set as a part + of `model_transaction_policy` field. + Raises + ------ + ValueError + If transaction_policy_dict contains property other + than 'decoupled', or if `model_transaction_policy` already exists + in the configuration, but has different `decoupled` property. + """ + valid_properties = ["decoupled"] + for current_property in transaction_policy_dict.keys(): + if current_property not in valid_properties: + raise ValueError( + "model transaction property in auto-complete-config " + + "function for model '" + + self._model_config["name"] + + "' contains property other than 'decoupled'." + ) + + if "model_transaction_policy" not in self._model_config: + self._model_config["model_transaction_policy"] = {} + + if "decoupled" in transaction_policy_dict.keys(): + if ( + "decoupled" in self._model_config["model_transaction_policy"] + and self._model_config["model_transaction_policy"]["decoupled"] + != transaction_policy_dict["decoupled"] + ): + raise ValueError( + "trying to change decoupled property in auto-complete-config " + + "for model '" + + self._model_config["name"] + + "', which is already set to '" + + str(self._model_config["model_transaction_policy"]["decoupled"]) + + "'." + ) + + self._model_config["model_transaction_policy"][ + "decoupled" + ] = transaction_policy_dict["decoupled"] + TRITONSERVER_REQUEST_FLAG_SEQUENCE_START = 1 TRITONSERVER_REQUEST_FLAG_SEQUENCE_END = 2 TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1 +TRITONSERVER_REQUEST_RELEASE_ALL = 1 +TRITONSERVER_REQUEST_RELEASE_RESCHEDULE = 2 diff --git a/src/response_sender.cc b/src/response_sender.cc index a74459f6..ef3b09dd 100644 --- a/src/response_sender.cc +++ b/src/response_sender.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -35,39 +35,98 @@ namespace triton { namespace backend { namespace python { +void +CheckResponseSenderArguments( + const std::shared_ptr& response, const uint32_t flags) +{ + // Check the correctness of the provided flags. + if (flags != TRITONSERVER_RESPONSE_COMPLETE_FINAL && flags != 0) { + throw PythonBackendException( + "Unable to send response. Unsupported flag provided."); + } + + if (flags == 0 && response == nullptr) { + throw PythonBackendException( + "Inference Response object must be provided when the response flags is " + "set to zero."); + } +} + ResponseSender::ResponseSender( intptr_t request_address, intptr_t response_factory_address, - std::unique_ptr& shm_pool) + bool const* is_decoupled, + const std::set& requested_output_names, + std::unique_ptr& shm_pool, + const std::shared_ptr& pb_cancel) : request_address_(request_address), - response_factory_address_(response_factory_address), shm_pool_(shm_pool), - closed_(false) + response_factory_address_(response_factory_address), + is_decoupled_(is_decoupled), + requested_output_names_(requested_output_names), shm_pool_(shm_pool), + pb_cancel_(pb_cancel), closed_(false), number_of_response_sent_(0), + response_factory_deleted_(false) { } +ResponseSender::~ResponseSender() +{ + DeleteResponseFactory(); +} void -ResponseSender::Send( - std::shared_ptr infer_response, const uint32_t flags) +ResponseSender::UpdateStateAndCounters( + InferResponse* response, const uint32_t flags) { + if (is_decoupled_ == nullptr) { + // TODO: Can a model access the response sender on a BLS infer request? + throw PythonBackendException( + "Unable to send response. Response sender has no reference to the " + "decoupled state of the model."); + } + bool is_decoupled = *is_decoupled_; + + std::lock_guard lk(mu_); + + if (!is_decoupled) { + if (response != nullptr && number_of_response_sent_ > 0) { + throw PythonBackendException( + "Unable to send response. Non-decoupled model cannot send more than " + "one response."); + } + if (response == nullptr && flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL && + number_of_response_sent_ == 0) { + throw PythonBackendException( + "Unable to send response. Non-decoupled model cannot send complete " + "final before sending a response."); + } + } + if (closed_) { throw PythonBackendException( "Unable to send response. Response sender has been closed."); } if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + response_factory_deleted_.exchange(true); closed_ = true; } + number_of_response_sent_++; +} - // Check the correctness of the provided flags. - if (flags != TRITONSERVER_RESPONSE_COMPLETE_FINAL && flags != 0) { - throw PythonBackendException( - "Unable to send response. Unsupported flag provided."); - } +void +ResponseSender::Send( + std::shared_ptr infer_response, const uint32_t flags) +{ + // Release the GIL. This avoids a potential deadlock situation in the parent + // process, where every thread in the thread pool is indirectly waiting for a + // function in the stub process that acquires the GIL. Meanwhile, the current + // thread, which holds the GIL, is also waiting for the parent side to have + // the next available thread to pick up the job during resource contention. + py::gil_scoped_release release; - if (flags == 0 && infer_response == nullptr) { - throw PythonBackendException( - "Inference Response object must be provided when the response flags is " - "set to zero."); + CheckResponseSenderArguments(infer_response, flags); + UpdateStateAndCounters(infer_response.get(), flags); + if (infer_response) { + infer_response->PruneOutputTensors(requested_output_names_); } std::unique_ptr& stub = Stub::GetOrCreateInstance(); @@ -114,7 +173,11 @@ ResponseSender::Send( { bi::scoped_lock guard{send_message_payload->mu}; - stub->SendIPCMessage(ipc_message); + // The server will destruct the response factory if the final flag is set. + if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + response_factory_deleted_.exchange(true); + } + stub->SendIPCUtilsMessage(ipc_message); while (!send_message_payload->is_stub_turn) { send_message_payload->cv.wait(guard); } @@ -132,9 +195,26 @@ ResponseSender::Send( } if (has_gpu_output) { + ScopedDefer _([send_message_payload] { + bi::scoped_lock guard{send_message_payload->mu}; + send_message_payload->is_stub_turn = false; + send_message_payload->cv.notify_one(); + while (!send_message_payload->is_stub_turn) { + // Wait for the stub process to send the response and populate error + // message if any. + send_message_payload->cv.wait(guard); + } + }); + AllocatedSharedMemory gpu_buffers_handle = shm_pool_->Load( send_message_payload->gpu_buffers_handle); + if (!gpu_buffers_handle.data_->success) { + std::unique_ptr error = PbString::LoadFromSharedMemory( + shm_pool_, gpu_buffers_handle.data_->error); + throw PythonBackendException( + "Failed to load GPU buffers: " + error->String()); + } AllocatedSharedMemory gpu_buffers_handle_shm = @@ -142,12 +222,11 @@ ResponseSender::Send( gpu_buffers_handle.data_->buffers); uint64_t gpu_buffer_count = gpu_buffers_handle.data_->buffer_count; if (gpu_tensors.size() != gpu_buffer_count) { - LOG_ERROR - << (std::string( - "GPU buffers size does not match the provided buffers: ") + - std::to_string(gpu_tensors.size()) + - " != " + std::to_string(gpu_buffer_count)); - return; + throw PythonBackendException( + std::string( + "GPU buffers size does not match the provided buffers: ") + + std::to_string(gpu_tensors.size()) + + " != " + std::to_string(gpu_buffer_count)); } std::vector> dst_buffers; @@ -160,17 +239,6 @@ ResponseSender::Send( std::shared_ptr& src_buffer = gpu_tensors[i]; PbMemory::CopyBuffer(dst_buffers[i], src_buffer->Memory()); } - - { - bi::scoped_lock guard{send_message_payload->mu}; - send_message_payload->is_stub_turn = false; - send_message_payload->cv.notify_one(); - while (!send_message_payload->is_stub_turn) { - // Wait for the stub process to send the response and populate error - // message if any. - send_message_payload->cv.wait(guard); - } - } } if (send_message_payload->has_error) { @@ -184,4 +252,38 @@ ResponseSender::Send( } } } + +bool +ResponseSender::IsCancelled() +{ + return pb_cancel_->IsCancelled(); +} + +bool +ResponseSender::IsClosed() +{ + std::lock_guard lk(mu_); + return closed_; +} + +void +ResponseSender::Close() +{ + std::lock_guard lk(mu_); + closed_ = true; + response_factory_deleted_.exchange(true); +} + +void +ResponseSender::DeleteResponseFactory() +{ + bool already_deleted = response_factory_deleted_.exchange(true); + if (!already_deleted) { + std::unique_ptr& stub = Stub::GetOrCreateInstance(); + stub->EnqueueCleanupId( + reinterpret_cast(response_factory_address_), + PYTHONSTUB_DecoupledResponseFactoryCleanup); + } +} + }}} // namespace triton::backend::python diff --git a/src/response_sender.h b/src/response_sender.h index 114f22c0..a696f9eb 100644 --- a/src/response_sender.h +++ b/src/response_sender.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -26,7 +26,11 @@ #pragma once +#include +#include + #include "infer_response.h" +#include "pb_cancel.h" #include "shm_manager.h" namespace triton { namespace backend { namespace python { @@ -35,13 +39,34 @@ class ResponseSender { public: ResponseSender( intptr_t request_address, intptr_t response_factory_address, - std::unique_ptr& shm_pool); + bool const* is_decoupled, + const std::set& requested_output_names, + std::unique_ptr& shm_pool, + const std::shared_ptr& pb_cancel); + intptr_t ResponseFactory() { return response_factory_address_; } + ~ResponseSender(); void Send(std::shared_ptr response, const uint32_t flags); + bool IsCancelled(); + void UpdateStateAndCounters(InferResponse* response, const uint32_t flags); + + // Can be useful at stopping the model from sending any more responses. + void Close(); + bool IsClosed(); private: + void DeleteResponseFactory(); + intptr_t request_address_; intptr_t response_factory_address_; + bool const* is_decoupled_; + std::set requested_output_names_; std::unique_ptr& shm_pool_; + std::shared_ptr pb_cancel_; + + std::mutex mu_; bool closed_; + size_t number_of_response_sent_; + + std::atomic response_factory_deleted_; }; }}} // namespace triton::backend::python diff --git a/src/shm_manager.cc b/src/shm_manager.cc index b5499f88..134cee6f 100644 --- a/src/shm_manager.cc +++ b/src/shm_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -33,6 +33,53 @@ namespace triton { namespace backend { namespace python { +void +CUDAMemoryPoolManager::SetCUDAPoolAddress( + const int32_t device_id, void* cuda_pool_address) +{ + std::lock_guard lock(mu_); + cuda_pool_address_map_[device_id] = cuda_pool_address; +} + +void* +CUDAMemoryPoolManager::CUDAPoolAddress(const int32_t device_id) +{ + if (cuda_pool_address_map_.find(device_id) != cuda_pool_address_map_.end()) { + return cuda_pool_address_map_[device_id]; + } else { + throw PythonBackendException( + "CUDA pool address for device " + std::to_string(device_id) + + " is not set."); + } +} + +void +CUDAMemoryPoolManager::SetTritonMemoryManager(void* triton_memory_manager) +{ + triton_memory_manager_ = triton_memory_manager; +} + +void* +CUDAMemoryPoolManager::TritonMemoryManager() +{ + return triton_memory_manager_; +} + +bool +CUDAMemoryPoolManager::UseCudaSharedPool(const int32_t device_id) +{ + return (cuda_pool_address_map_.find(device_id) != + cuda_pool_address_map_.end()) && + (cuda_pool_address_map_[device_id] != nullptr) && + (triton_memory_manager_ != nullptr); +} + +std::unordered_map& +CUDAMemoryPoolManager::CUDAPoolAddressMap() +{ + return cuda_pool_address_map_; +} + SharedMemoryManager::SharedMemoryManager( const std::string& shm_region_name, size_t shm_size, size_t shm_growth_bytes, bool create) @@ -40,6 +87,7 @@ SharedMemoryManager::SharedMemoryManager( shm_region_name_ = shm_region_name; create_ = create; shm_growth_bytes_ = shm_growth_bytes; + cuda_memory_pool_manager_ = std::make_unique(); try { if (create) { @@ -76,7 +124,7 @@ SharedMemoryManager::SharedMemoryManager( "' to requested size (" + std::to_string(shm_size) + " bytes). If you are running Triton inside docker, use '--shm-size' " "flag to control the shared memory region size. Each Python backend " - "model instance requires at least 64MBs of shared memory. Error: " + + "model instance requires at least 1 MB of shared memory. Error: " + ex.what()); // Remove the shared memory region if there was an error. bi::shared_memory_object::remove(shm_region_name.c_str()); @@ -99,6 +147,7 @@ SharedMemoryManager::SharedMemoryManager(const std::string& shm_region_name) shm_region_name_ = shm_region_name; create_ = false; shm_growth_bytes_ = 1024; + cuda_memory_pool_manager_ = std::make_unique(); shm_obj_ = std::make_unique( bi::open_only, shm_region_name.c_str(), bi::read_write); @@ -139,8 +188,8 @@ SharedMemoryManager::GrowIfNeeded(uint64_t byte_size) } catch (bi::interprocess_exception& ex) { std::string error_message = - ("Failed to increase the shared memory pool size for key '" + - shm_region_name_ + "' to " + std::to_string(*total_size_) + + ("Failed to increase the shared memory pool size to " + + std::to_string(*total_size_) + " bytes. If you are running Triton inside docker, use '--shm-size' " "flag to control the shared memory region size. Error: " + ex.what()); diff --git a/src/shm_manager.h b/src/shm_manager.h index bd462403..8517faf3 100644 --- a/src/shm_manager.h +++ b/src/shm_manager.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -26,16 +26,16 @@ #pragma once -#include - #include #include #include #include #include #include +#include #include #include +#include #include #include "pb_exception.h" @@ -43,6 +43,35 @@ namespace triton { namespace backend { namespace python { namespace bi = boost::interprocess; +static constexpr bi::managed_external_buffer::handle_t kShmControlRegionHandle{ + 1}; + +class CUDAMemoryPoolManager { + public: + CUDAMemoryPoolManager() : triton_memory_manager_(nullptr) {} + + void SetCUDAPoolAddress(const int32_t device_id, void* cuda_pool_address); + + void* CUDAPoolAddress(const int32_t device_id); + + void SetTritonMemoryManager(void* triton_memory_manager); + + void* TritonMemoryManager(); + + bool UseCudaSharedPool(const int32_t device_id); + + // Return cuda pool address map + std::unordered_map& CUDAPoolAddressMap(); + + private: + // The base address of the Triton CUDA memory pool + std::unordered_map cuda_pool_address_map_; + // The mutex to protect the cuda_pool_address_map_ + std::mutex mu_; + // TRITONBACKEND_MemoryManager + void* triton_memory_manager_; +}; + template struct AllocatedSharedMemory { AllocatedSharedMemory() = default; @@ -64,9 +93,9 @@ struct AllocatedSharedMemory { // info is placed in the beginning and the actual object is placed after that // (i.e. 4 plus the aligned address is not 16-bytes aligned). The aligned memory // is required by semaphore otherwise it may lead to SIGBUS error on ARM. -struct AllocatedShmOwnership { +struct alignas(16) AllocatedShmOwnership { uint32_t ref_count_; -} __attribute__((aligned(16))); +}; class SharedMemoryManager { public: @@ -140,6 +169,10 @@ class SharedMemoryManager { void Deallocate(bi::managed_external_buffer::handle_t handle) { + // Do not delete the control region, to avoid undefined behavior. + if (handle == kShmControlRegionHandle) { + return; + } bi::scoped_lock guard{*shm_mutex_}; GrowIfNeeded(0); void* ptr = managed_buffer_->get_address_from_handle(handle); @@ -148,6 +181,10 @@ class SharedMemoryManager { void DeallocateUnsafe(bi::managed_external_buffer::handle_t handle) { + // Do not delete the control region, to avoid undefined behavior. + if (handle == kShmControlRegionHandle) { + return; + } void* ptr = managed_buffer_->get_address_from_handle(handle); managed_buffer_->deallocate(ptr); } @@ -157,6 +194,14 @@ class SharedMemoryManager { void SetDeleteRegion(bool delete_region); + std::unique_ptr& GetCUDAMemoryPoolManager() + { + return cuda_memory_pool_manager_; + } + + uint64_t GetCurrentCapacity() { return current_capacity_; } + void* GetBaseAddress() { return managed_buffer_->get_address(); } + ~SharedMemoryManager() noexcept(false); private: @@ -171,6 +216,7 @@ class SharedMemoryManager { uint64_t* total_size_; bool create_; bool delete_region_; + std::unique_ptr cuda_memory_pool_manager_; template AllocatedSharedMemory WrapObjectInUniquePtr( diff --git a/src/shm_monitor/CMakeLists.txt b/src/shm_monitor/CMakeLists.txt index 0f7d4b86..2ae8bd45 100644 --- a/src/shm_monitor/CMakeLists.txt +++ b/src/shm_monitor/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -24,7 +24,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -cmake_minimum_required (VERSION 3.18) +cmake_minimum_required (VERSION 3.31.8) pybind11_add_module( triton-shm-monitor diff --git a/src/stub_launcher.cc b/src/stub_launcher.cc index de4dd46c..32f5d1bd 100644 --- a/src/stub_launcher.cc +++ b/src/stub_launcher.cc @@ -26,23 +26,30 @@ #include "stub_launcher.h" +#include + +#include "pb_utils.h" #include "python_be.h" +#ifdef _WIN32 +#include // getpid() +#endif + +extern char** environ; + namespace triton { namespace backend { namespace python { StubLauncher::StubLauncher(const std::string stub_process_kind) - : parent_pid_(0), stub_pid_(0), is_initialized_(false), + : parent_pid_(0), is_initialized_(false), stub_process_kind_(stub_process_kind), model_instance_name_(""), device_id_(0), kind_("") - { } StubLauncher::StubLauncher( const std::string stub_process_kind, const std::string model_instance_name, const int32_t device_id, const std::string kind) - : parent_pid_(0), stub_pid_(0), is_initialized_(false), - stub_process_kind_(stub_process_kind), + : is_initialized_(false), stub_process_kind_(stub_process_kind), model_instance_name_(model_instance_name), device_id_(device_id), kind_(kind) { @@ -62,22 +69,27 @@ StubLauncher::Initialize(ModelState* model_state) model_state->ModelConfig().Write(&model_config_buffer_); is_decoupled_ = model_state->IsDecoupled(); model_repository_path_ = model_state->RepositoryPath(); - platform_ = model_state->Platform(); - if (platform_.empty()) { - platform_ = "NONE"; + runtime_modeldir_ = model_state->RuntimeModelDir(); + if (runtime_modeldir_.empty()) { + runtime_modeldir_ = "DEFAULT"; } +#ifdef _WIN32 + ZeroMemory(&startup_info_, sizeof(startup_info_)); + startup_info_.cb = sizeof(startup_info_); + ZeroMemory(&stub_pid_, sizeof(stub_pid_)); +#else + stub_pid_ = 0; +#endif - // Atomically increase and read the stub process count to avoid shared memory - // region name collision - int num_init = ++model_state->StateForBackend()->number_of_instance_inits; shm_region_name_ = model_state->StateForBackend()->shared_memory_region_prefix + - std::to_string(num_init); + GenerateUUID(); model_version_ = model_state->Version(); std::stringstream ss; - ss << model_repository_path_ << "/" << model_version_ << "/"; + const char os_slash = std::filesystem::path::preferred_separator; + ss << model_repository_path_ << os_slash << model_version_ << os_slash; std::string artifact_name; RETURN_IF_ERROR(model_state->ModelConfig().MemberAsString( "default_model_filename", &artifact_name)); @@ -90,31 +102,20 @@ StubLauncher::Initialize(ModelState* model_state) model_path_ = ss.str(); - // Path to the extracted Python env - std::string python_execution_env = ""; + // FIXME [DLIS-5969]: Enable for Windows when custom execution environments + // are supported. if (python_execution_env_ != "") { - try { - python_execution_env = - model_state->StateForBackend()->env_manager->ExtractIfNotExtracted( - python_execution_env_); - } - catch (PythonBackendException& pb_exception) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, pb_exception.what()); - } - - path_to_activate_ = python_execution_env + "/bin/activate"; - path_to_libpython_ = python_execution_env + "/lib"; - if (python_execution_env.length() > 0 && !FileExists(path_to_activate_)) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("Path " + path_to_activate_ + - " does not exist. The Python environment should contain an " - "'activate' script.") - .c_str()); - } +#ifndef _WIN32 + RETURN_IF_ERROR(GetPythonEnvironment(model_state)); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Custom execution environments are not currently supported on " + "Windows."); +#endif } + parent_pid_ = getpid(); return nullptr; @@ -196,6 +197,11 @@ StubLauncher::Setup() return nullptr; } +// FIXME: This should be merged with the Unix launch function once Windows +// CI and functionality are demonstrably stable. The goal of keeping the +// functions separate is to help debug Windows-specific issues without worrying +// about the impact to our Unix builds. +#ifdef _WIN32 TRITONSERVER_Error* StubLauncher::Launch() { @@ -206,62 +212,173 @@ StubLauncher::Launch() stub_name = model_instance_name_; } - const char* stub_args[4]; - stub_args[0] = "bash"; - stub_args[1] = "-c"; - stub_args[3] = nullptr; // Last argument must be nullptr + const char os_slash = std::filesystem::path::preferred_separator; + + const std::string stub_executable_name = "triton_python_backend_stub.exe"; + SanitizePath(model_path_); + SanitizePath(model_repository_path_); // Default Python backend stub - std::string python_backend_stub = python_lib_ + "/triton_python_backend_stub"; + std::string python_backend_stub = + python_lib_ + os_slash + stub_executable_name; + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Stub path ") + python_backend_stub).c_str()); // Path to alternative Python backend stub std::string model_python_backend_stub = - std::string(model_repository_path_) + "/triton_python_backend_stub"; + std::string(model_repository_path_) + os_slash + stub_executable_name; + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Alt path ") + python_backend_stub).c_str()); + // Check if file exists + // TODO: Integrate win32 and pb_env if (FileExists(model_python_backend_stub)) { python_backend_stub = model_python_backend_stub; } - std::string bash_argument; + std::string launch_command; - // This shared memory variable indicates whether the stub process should - // revert the LD_LIBRARY_PATH changes to avoid shared library issues in - // executables and libraries. - ipc_control_->uses_env = false; - if (python_execution_env_ != "") { - std::stringstream ss; + std::stringstream ss; + ss << python_backend_stub << " " << model_path_ << " " << shm_region_name_ + << " " << shm_default_byte_size_ << " " << shm_growth_byte_size_ << " " + << parent_pid_ << " " << python_lib_ << " " << ipc_control_handle_ << " " + << stub_name << " " << runtime_modeldir_; + launch_command = ss.str(); - // Need to properly set the LD_LIBRARY_PATH so that Python environments - // using different python versions load properly. - ss << "source " << path_to_activate_ - << " && exec env LD_LIBRARY_PATH=" << path_to_libpython_ - << ":$LD_LIBRARY_PATH " << python_backend_stub << " " << model_path_ - << " " << shm_region_name_ << " " << shm_default_byte_size_ << " " - << shm_growth_byte_size_ << " " << parent_pid_ << " " << python_lib_ - << " " << ipc_control_handle_ << " " << stub_name << " " << platform_; - ipc_control_->uses_env = true; - bash_argument = ss.str(); - } else { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Starting Python backend stub: ") + launch_command).c_str()); + + LPSTR launch_command_lpstr = const_cast(launch_command.c_str()); + // Start the child process. Unlike fork(), the remainder of this + // function exists in the context of the parent, only. + if (!CreateProcess( + NULL, // No module name (use command line) + launch_command_lpstr, // Command line + NULL, // Process handle not inheritable + NULL, // Thread handle not inheritable + FALSE, // Set handle inheritance to FALSE + 0, // No creation flags + NULL, // Use parent's environment block + NULL, // Use parent's starting directory + &startup_info_, // Pointer to STARTUPINFO structure + &stub_pid_) // Pointer to PROCESS_INFORMATION structure + ) { std::stringstream ss; - ss << " exec " << python_backend_stub << " " << model_path_ << " " - << shm_region_name_ << " " << shm_default_byte_size_ << " " - << shm_growth_byte_size_ << " " << parent_pid_ << " " << python_lib_ - << " " << ipc_control_handle_ << " " << stub_name << " " << platform_; - bash_argument = ss.str(); + ss << "Failed to run python backend stub. Errno = " << errno << '\n' + << "Python backend stub path: " << python_backend_stub << '\n' + << "Shared Memory Region Name: " << shm_region_name_ << '\n' + << "Shared Memory Default Byte Size: " << shm_default_byte_size_ << '\n' + << "Shared Memory Growth Byte Size: " << shm_growth_byte_size_ << '\n'; + // Print the error message directly because the underlying mutexes in + // LOG_MESSAGE() could be forked when it is locked by other thread(s). + std::cerr << '\n' << ss.str() << '\n'; + _Exit(1); } + ScopedDefer _([&] { + // Push a dummy message to the message queue so that the stub + // process is notified that it can release the object stored in + // shared memory. + if (stub_message_queue_) { + stub_message_queue_->Push(DUMMY_MESSAGE); + } + + // If the model is not initialized, wait for the stub process to exit. + if (!is_initialized_) { + stub_message_queue_.reset(); + parent_message_queue_.reset(); + memory_manager_.reset(); + WaitForStubProcess(); + } + }); + + // The stub process would send two messages to the parent process during the + // initialization. + // 1. When the stub process's health monitoring thread has started. + // 2. When the initialization is fully completed and the Python model is + // loaded. + // + // The reason it is broken into two steps is that creation of the health + // monitoring thread may take longer which can make the server process think + // that the stub process is unhealthy and return early. Waiting with a longer + // timeout prevents this issue. + const uint64_t initialization_timeout_ms = 10000; // 10 sec LOG_MESSAGE( TRITONSERVER_LOG_VERBOSE, - (std::string("Starting Python backend stub: ") + bash_argument).c_str()); + "Waiting for the stub health monitoring thread to start"); - stub_args[2] = bash_argument.c_str(); + bi::managed_external_buffer::handle_t message; + auto err = ReceiveMessageFromStub(message, initialization_timeout_ms); + if (err != nullptr) { + KillStubProcess(); + } - int stub_status_code = - system((python_backend_stub + "> /dev/null 2>&1").c_str()); + if (stub_process_kind_ == "AUTOCOMPLETE_STUB") { + if (err != nullptr) { + throw BackendModelException(err); + } + try { + AutocompleteStubProcess(); + } + catch (const PythonBackendException& ex) { + // Need to kill the stub process first + KillStubProcess(); + throw BackendModelException( + TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, ex.what())); + } + } else if (stub_process_kind_ == "MODEL_INSTANCE_STUB") { + RETURN_IF_ERROR(err); + RETURN_IF_ERROR(ModelInstanceStubProcess()); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("Unknown stub_process_kind: ") + stub_process_kind_) + .c_str()); + } + + is_initialized_ = true; + + return nullptr; +} +#else +TRITONSERVER_Error* +StubLauncher::Launch() +{ + std::string stub_name; + if (stub_process_kind_ == "AUTOCOMPLETE_STUB") { + stub_name = model_name_; + } else { + stub_name = model_instance_name_; + } + + if (!IsValidIdentifier(stub_name)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Invalid stub name: contains invalid characters"); + } + + if (!IsValidIdentifier(shm_region_name_)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Invalid shared memory region name: contains invalid characters"); + } + + // Default Python backend stub + std::string python_backend_stub = python_lib_ + "/triton_python_backend_stub"; + + // Path to alternative Python backend stub + std::string model_python_backend_stub = + std::string(model_repository_path_) + "/triton_python_backend_stub"; + + if (FileExists(model_python_backend_stub)) { + python_backend_stub = model_python_backend_stub; + } - // If running stub process without any arguments returns any status code, - // other than 1, it can indicate a permission issue as a result of - // downloading the stub process from a cloud object storage service. - if (WEXITSTATUS(stub_status_code) != 1) { + if (!IsExecutableFile(python_backend_stub)) { // Give the execute permission for the triton_python_backend_stub to the // owner. int error = chmod(python_backend_stub.c_str(), S_IXUSR); @@ -276,86 +393,227 @@ StubLauncher::Launch() } } - pid_t pid = fork(); - if (pid < 0) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "Failed to fork the stub process for auto-complete."); - } - if (pid == 0) { - // Replace this child process with the new stub process. - execvp("bash", (char**)stub_args); - // execvp() never return if succeeded. Otherwise, an error has occurred. - std::stringstream ss; - ss << "Failed to run python backend stub. Errno = " << errno << '\n' - << "Python backend stub path: " << python_backend_stub << '\n' - << "Shared Memory Region Name: " << shm_region_name_ << '\n' - << "Shared Memory Default Byte Size: " << shm_default_byte_size_ << '\n' - << "Shared Memory Growth Byte Size: " << shm_growth_byte_size_ << '\n'; - // Print the error message directly because the underlying mutexes in - // LOG_MESSAGE() could be forked when it is locked by other thread(s). - std::cerr << '\n' << ss.str() << '\n'; - // Terminate the child execution immediately to avoid any issues. - _Exit(1); - } else { - ScopedDefer _([&] { - // Push a dummy message to the message queue so that the stub - // process is notified that it can release the object stored in - // shared memory. - stub_message_queue_->Push(DUMMY_MESSAGE); + // Prepare arguments for execution + std::vector arg_strings; + std::vector exec_args; - // If the model is not initialized, wait for the stub process to exit. - if (!is_initialized_) { - int status; - stub_message_queue_.reset(); - parent_message_queue_.reset(); - memory_manager_.reset(); - waitpid(stub_pid_, &status, 0); - } - }); - - stub_pid_ = pid; - - // The stub process would send two messages to the parent process during the - // initialization. - // 1. When the stub process's health monitoring thread has started. - // 2. When the initialization is fully completed and the Python model is - // loaded. - // - // The reason it is broken into two steps is that creation of the health - // monitoring thread may take longer which can make the server process think - // that the stub process is unhealthy and return early. Waiting until the - // health thread is spawn would make sure would prevent this issue. - parent_message_queue_->Pop(); - - if (stub_process_kind_ == "AUTOCOMPLETE_STUB") { - try { - AutocompleteStubProcess(); - } - catch (const PythonBackendException& ex) { - // Need to kill the stub process first - kill(stub_pid_, SIGKILL); - int status; - waitpid(stub_pid_, &status, 0); - stub_pid_ = 0; - throw BackendModelException( - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, ex.what())); - } - } else if (stub_process_kind_ == "MODEL_INSTANCE_STUB") { - RETURN_IF_ERROR(ModelInstanceStubProcess()); + // This shared memory variable indicates whether the stub process should + // revert the LD_LIBRARY_PATH changes to avoid shared library issues in + // executables and libraries. + ipc_control_->uses_env = false; + + if (python_execution_env_ != "") { + ipc_control_->uses_env = true; + + // Parse environment variables from activation script + std::map env_vars = + ParseActivationScript(path_to_activate_); + + // Prepare environment with additional library path + auto [env_strings, custom_env] = + PrepareEnvironment(env_vars, path_to_libpython_); + + // Set up arguments for direct execution + arg_strings.push_back(python_backend_stub); + arg_strings.push_back(model_path_); + arg_strings.push_back(shm_region_name_); + arg_strings.push_back(std::to_string(shm_default_byte_size_)); + arg_strings.push_back(std::to_string(shm_growth_byte_size_)); + arg_strings.push_back(std::to_string(parent_pid_)); + arg_strings.push_back(python_lib_); + arg_strings.push_back(std::to_string(ipc_control_handle_)); + arg_strings.push_back(stub_name); + arg_strings.push_back(runtime_modeldir_); + + // Convert strings to char* array for exec + for (const auto& arg : arg_strings) { + exec_args.push_back(arg.c_str()); + } + exec_args.push_back(nullptr); // exec requires null termination + + // Log the command being executed + std::ostringstream log_cmd; + for (size_t i = 0; i < arg_strings.size(); ++i) { + if (i > 0) + log_cmd << " "; + log_cmd << "'" << arg_strings[i] << "'"; + } + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("Starting Python backend stub with custom environment: ") + + log_cmd.str()) + .c_str()); + + pid_t pid = fork(); + if (pid < 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Failed to fork the stub process for auto-complete."); + } + if (pid == 0) { + // Replace this child process with the new stub process using custom + // environment + execve( + python_backend_stub.c_str(), const_cast(exec_args.data()), + custom_env.data()); + // execve() never returns if succeeded. Otherwise, an error has occurred. + std::stringstream ss; + ss << "Failed to run python backend stub with custom environment. Errno " + "= " + << errno << '\n' + << "Python backend stub path: " << python_backend_stub << '\n' + << "Activation script: " << path_to_activate_ << '\n' + << "Library path: " << path_to_libpython_ << '\n'; + std::cerr << '\n' << ss.str() << '\n'; + _Exit(1); } else { + stub_pid_ = pid; + } + + } else { + arg_strings.push_back(python_backend_stub); + arg_strings.push_back(model_path_); + arg_strings.push_back(shm_region_name_); + arg_strings.push_back(std::to_string(shm_default_byte_size_)); + arg_strings.push_back(std::to_string(shm_growth_byte_size_)); + arg_strings.push_back(std::to_string(parent_pid_)); + arg_strings.push_back(python_lib_); + arg_strings.push_back(std::to_string(ipc_control_handle_)); + arg_strings.push_back(stub_name); + arg_strings.push_back(runtime_modeldir_); + + // Convert strings to char* array for exec + for (const auto& arg : arg_strings) { + exec_args.push_back(arg.c_str()); + } + exec_args.push_back(nullptr); // exec requires null termination + + // Log the command being executed + std::ostringstream log_cmd; + for (size_t i = 0; i < arg_strings.size(); ++i) { + if (i > 0) + log_cmd << " "; + log_cmd << "'" << arg_strings[i] << "'"; + } + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("Starting Python backend stub: ") + log_cmd.str()) + .c_str()); + + pid_t pid = fork(); + if (pid < 0) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, - (std::string("Unknown stub_process_kind: ") + stub_process_kind_) - .c_str()); + "Failed to fork the stub process for auto-complete."); } + if (pid == 0) { + // Replace this child process with the new stub process. + execv(python_backend_stub.c_str(), const_cast(exec_args.data())); + // execv() never returns if succeeded. Otherwise, an error has occurred. + std::stringstream ss; + ss << "Failed to run python backend stub. Errno = " << errno << '\n' + << "Python backend stub path: " << python_backend_stub << '\n'; + std::cerr << '\n' << ss.str() << '\n'; + _Exit(1); + } else { + stub_pid_ = pid; + } + } - is_initialized_ = true; + ScopedDefer _([&] { + // Push a dummy message to the message queue so that the stub + // process is notified that it can release the object stored in + // shared memory. + if (stub_message_queue_) { + stub_message_queue_->Push(DUMMY_MESSAGE); + } + + // If the model is not initialized, wait for the stub process to exit. + if (!is_initialized_) { + stub_message_queue_.reset(); + parent_message_queue_.reset(); + memory_manager_.reset(); + WaitForStubProcess(); + } + }); + + // The stub process would send two messages to the parent process during the + // initialization. + // 1. When the stub process's health monitoring thread has started. + // 2. When the initialization is fully completed and the Python model is + // loaded. + // + // The reason it is broken into two steps is that creation of the health + // monitoring thread may take longer which can make the server process think + // that the stub process is unhealthy and return early. Waiting with a + // longer timeout prevents this issue. + const uint64_t initialization_timeout_ms = 10000; // 10 sec + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + "Waiting for the stub health monitoring thread to start"); + + bi::managed_external_buffer::handle_t message; + auto err = ReceiveMessageFromStub(message, initialization_timeout_ms); + if (err != nullptr) { + KillStubProcess(); } + if (stub_process_kind_ == "AUTOCOMPLETE_STUB") { + if (err != nullptr) { + throw BackendModelException(err); + } + try { + AutocompleteStubProcess(); + } + catch (const PythonBackendException& ex) { + // Need to kill the stub process first + KillStubProcess(); + throw BackendModelException( + TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, ex.what())); + } + } else if (stub_process_kind_ == "MODEL_INSTANCE_STUB") { + RETURN_IF_ERROR(err); + RETURN_IF_ERROR(ModelInstanceStubProcess()); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("Unknown stub_process_kind: ") + stub_process_kind_) + .c_str()); + } + + is_initialized_ = true; + return nullptr; } +TRITONSERVER_Error* +StubLauncher::GetPythonEnvironment(ModelState* model_state) +{ + std::string python_execution_env = ""; + try { + python_execution_env = + model_state->StateForBackend()->env_manager->ExtractIfNotExtracted( + python_execution_env_); + } + catch (PythonBackendException& pb_exception) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, pb_exception.what()); + } + + path_to_activate_ = python_execution_env + "/bin/activate"; + path_to_libpython_ = python_execution_env + "/lib"; + if (python_execution_env.length() > 0 && !FileExists(path_to_activate_)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + ("Path " + path_to_activate_ + + " does not exist. The Python environment should contain an " + "'activate' script.") + .c_str()); + } + return nullptr; +} +#endif + void StubLauncher::AutocompleteStubProcess() { @@ -435,8 +693,13 @@ StubLauncher::ModelInstanceStubProcess() initialize_message->Args() = initialize_map_handle; stub_message_queue_->Push(initialize_message->ShmHandle()); + const uint64_t initialization_timeout_ms = 5000; // 5 sec + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + "Waiting for the stub process initialization response"); + bi::managed_external_buffer::handle_t message; - RETURN_IF_ERROR(ReceiveMessageFromStub(message)); + RETURN_IF_ERROR(ReceiveMessageFromStub(message, initialization_timeout_ms)); std::unique_ptr initialize_response_message = IPCMessage::LoadFromSharedMemory(shm_pool_, message); @@ -472,6 +735,18 @@ StubLauncher::ModelInstanceStubProcess() return nullptr; } +bool +StubLauncher::StubActive() +{ +#ifdef _WIN32 + DWORD ec; + GetExitCodeProcess(stub_pid_.hProcess, &ec); + return (ec == STILL_ACTIVE); +#else + return (stub_pid_ != 0); +#endif +} + void StubLauncher::UpdateHealth() { @@ -482,9 +757,13 @@ StubLauncher::UpdateHealth() ipc_control_->stub_health = false; } - // Sleep 1 second so that the child process has a chance to change the - // health variable +// Sleep 1 second so that the child process has a chance to change the +// health variable +#ifdef _WIN32 + Sleep(1); +#else sleep(1); +#endif { bi::scoped_lock lock(*health_mutex_); @@ -514,11 +793,11 @@ StubLauncher::TerminateStub() force_kill = true; } - int status; if (force_kill) { - kill(stub_pid_, SIGKILL); + KillStubProcess(); + } else { + WaitForStubProcess(); } - waitpid(stub_pid_, &status, 0); } // First destroy the IPCControl. This makes sure that IPCControl is @@ -539,19 +818,25 @@ StubLauncher::ClearQueues() void StubLauncher::KillStubProcess() { +#ifdef _WIN32 + unsigned int exit_code; + TerminateProcess(stub_pid_.hProcess, exit_code); + CloseHandle(stub_pid_.hProcess); + CloseHandle(stub_pid_.hThread); +#else kill(stub_pid_, SIGKILL); - int status; - waitpid(stub_pid_, &status, 0); + WaitForStubProcess(); stub_pid_ = 0; +#endif } TRITONSERVER_Error* StubLauncher::ReceiveMessageFromStub( - bi::managed_external_buffer::handle_t& message) + bi::managed_external_buffer::handle_t& message, + uint64_t timeout_miliseconds) { bool success = false; while (!success) { - uint64_t timeout_miliseconds = 1000; { boost::posix_time::ptime timeout = boost::get_system_time() + @@ -598,4 +883,124 @@ StubLauncher::ReceiveMessageFromStub( return nullptr; // success } -}}}; // namespace triton::backend::python + +void +StubLauncher::WaitForStubProcess() +{ +#ifdef _WIN32 + WaitForSingleObject(stub_pid_.hProcess, INFINITE); + CloseHandle(stub_pid_.hProcess); + CloseHandle(stub_pid_.hThread); +#else + int status; + if (stub_pid_ != 0) { + // Added this check to ensure server doesn't hang waiting after stub + // process has already be killed and cannot be waited on + waitpid(stub_pid_, &status, 0); + } +#endif +} + +#ifdef TRITON_ENABLE_GPU +void +StubLauncher::ShareCUDAMemoryPool( + TRITONBACKEND_MemoryManager* triton_mem_manager, const int32_t device_id) +{ + std::lock_guard lock(cuda_shm_pool_mutex_); + if ((tried_sharing_cuda_pool_map_.find(device_id) != + tried_sharing_cuda_pool_map_.end()) && + tried_sharing_cuda_pool_map_[device_id]) { + return; + } + + std::unique_ptr ipc_message = + IPCMessage::Create(shm_pool_, true /* inline_response */); + CUDAMemPoolMessage* cuda_pool_message_ptr = nullptr; + PythonBackendException pb_exception(std::string{}); + + try { + // Create a dummy BackendMemory object to get the start address of the CUDA + // memory pool. + BackendMemory* backend_memory; + std::unique_ptr lbackend_memory; + + THROW_IF_TRITON_ERROR(BackendMemory::Create( + triton_mem_manager, BackendMemory::AllocationType::GPU_POOL, device_id, + 1 /* byte size*/, &backend_memory)); + lbackend_memory.reset(backend_memory); + + CUDAHandler& cuda_api = CUDAHandler::getInstance(); + CUdeviceptr cuda_pool_address = 0; + cuda_api.PointerGetAttribute( + &cuda_pool_address, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + reinterpret_cast(lbackend_memory->MemoryPtr())); + + shm_pool_->GetCUDAMemoryPoolManager()->SetCUDAPoolAddress( + device_id, reinterpret_cast(cuda_pool_address)); + shm_pool_->GetCUDAMemoryPoolManager()->SetTritonMemoryManager( + reinterpret_cast(triton_mem_manager)); + + // Get the memory handle from the CUDA memory pool. + AllocatedSharedMemory cuda_pool_message = + shm_pool_->Construct(); + cuda_pool_message_ptr = cuda_pool_message.data_.get(); + { + ScopedSetDevice scoped_set_device(device_id); + THROW_IF_CUDA_ERROR(cudaIpcGetMemHandle( + reinterpret_cast( + &cuda_pool_message_ptr->cuda_handle), + reinterpret_cast(shm_pool_->GetCUDAMemoryPoolManager() + ->CUDAPoolAddress(device_id)))); + } + + ipc_message->Command() = PYTHONSTUB_CUDAPoolInitializeRequest; + ipc_message->Args() = cuda_pool_message.handle_; + + cuda_pool_message_ptr->device_id = device_id; + cuda_pool_message_ptr->has_error = false; + cuda_pool_message_ptr->is_error_set = false; + cuda_pool_message_ptr->waiting_on_stub = false; + + { + bi::scoped_lock lock{ + *(ipc_message->ResponseMutex())}; + parent_to_stub_mq_->Push(ipc_message->ShmHandle()); + while (!cuda_pool_message_ptr->waiting_on_stub) { + ipc_message->ResponseCondition()->wait(lock); + } + } + + if (cuda_pool_message_ptr->has_error) { + if (cuda_pool_message_ptr->is_error_set) { + std::unique_ptr error_message = + PbString::LoadFromSharedMemory( + shm_pool_, cuda_pool_message_ptr->error); + throw PythonBackendException(error_message->String()); + } else { + throw PythonBackendException( + "Failed to share CUDA memory pool with stub process: " + + model_name_); + } + } + } + catch (const PythonBackendException& exception) { + shm_pool_->GetCUDAMemoryPoolManager()->SetCUDAPoolAddress( + device_id, nullptr); + pb_exception = exception; + } + + { + bi::scoped_lock lock{ + *(ipc_message->ResponseMutex())}; + cuda_pool_message_ptr->waiting_on_stub = false; + ipc_message->ResponseCondition()->notify_all(); + } + + tried_sharing_cuda_pool_map_[device_id] = true; + + if (pb_exception.what() != std::string{""}) { + throw pb_exception; + } +} +#endif // TRITON_ENABLE_GPU +}}}; // namespace triton::backend::python diff --git a/src/stub_launcher.h b/src/stub_launcher.h index 89f35422..58cdcc61 100644 --- a/src/stub_launcher.h +++ b/src/stub_launcher.h @@ -26,8 +26,6 @@ #pragma once -#include - #include #include #include @@ -79,8 +77,8 @@ class StubLauncher { // Model instance stub process TRITONSERVER_Error* ModelInstanceStubProcess(); - // Stub PID - pid_t StubPid() { return stub_pid_; } + // Check if Stub PID is active + bool StubActive(); // Health mutex bi::interprocess_mutex* HealthMutex() { return health_mutex_; } @@ -149,19 +147,39 @@ class StubLauncher { // Get a message from the stub process TRITONSERVER_Error* ReceiveMessageFromStub( - bi::managed_external_buffer::handle_t& message); + bi::managed_external_buffer::handle_t& message, + uint64_t timeout_miliseconds = 1000); + + // Wait for stub process + void WaitForStubProcess(); + +#ifndef _WIN32 + // FIXME [DLIS-5969]: Enable for Windows when custom execution environments + // are supported. + TRITONSERVER_Error* GetPythonEnvironment(ModelState* model_state); +#endif +#ifdef TRITON_ENABLE_GPU + // Share CUDA memory pool with stub process + void ShareCUDAMemoryPool( + TRITONBACKEND_MemoryManager* triton_mem_manager, const int32_t device_id); +#endif // TRITON_ENABLE_GPU private: +#ifdef _WIN32 + STARTUPINFO startup_info_; + DWORD parent_pid_; + PROCESS_INFORMATION stub_pid_; +#else pid_t parent_pid_; pid_t stub_pid_; - +#endif bool is_initialized_; bool is_decoupled_; bool is_healthy_; std::string shm_region_name_; std::string model_repository_path_; std::string model_path_; - std::string platform_; + std::string runtime_modeldir_; const std::string stub_process_kind_; std::string model_name_; const std::string model_instance_name_; @@ -196,5 +214,9 @@ class StubLauncher { ipc_control_; bi::managed_external_buffer::handle_t ipc_control_handle_; std::unique_ptr shm_pool_; +#ifdef TRITON_ENABLE_GPU + std::mutex cuda_shm_pool_mutex_; + std::unordered_map tried_sharing_cuda_pool_map_; +#endif // TRITON_ENABLE_GPU }; }}} // namespace triton::backend::python