diff --git a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py
index 15c343a2..6000c1a2 100644
--- a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py
+++ b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py
@@ -190,7 +190,6 @@ def test_guided_decoding(self):
         sampling_params = SAMPLING_PARAMETERS
         guided_decoding_params = {
             "choice": ["Positive", "Negative"],
-            "backend": "outlines",
         }
         sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
         for i in range(len(GUIDED_PROMPTS)):
@@ -245,7 +244,6 @@ def tearDown(self):
     if FLAGS.generate_guided_baseline:
         guided_decoding_params = {
             "choice": ["Positive", "Negative"],
-            "backend": "outlines",
         }
         guided_generation = GuidedDecodingParams(**guided_decoding_params)
         asyncio.run(
diff --git a/ci/L0_backend_vllm/accuracy_test/test.sh b/ci/L0_backend_vllm/accuracy_test/test.sh
index 8a94fff0..f575b7b1 100755
--- a/ci/L0_backend_vllm/accuracy_test/test.sh
+++ b/ci/L0_backend_vllm/accuracy_test/test.sh
@@ -48,17 +48,11 @@ RET=0
 set +e
 # Need to generate baseline first, since running 2 vLLM engines causes
 # memory issues: https://github.com/vllm-project/vllm/issues/2248
-export VLLM_USE_V1=0
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 &
 BASELINE_PID=$!
 wait $BASELINE_PID
 python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 &
 BASELINE_PID=$!
 wait $BASELINE_PID
-
-unset VLLM_USE_V1
-unset VLLM_WORKER_MULTIPROC_METHOD
-
 set -e

 run_server
@@ -88,12 +82,6 @@ set -e
 kill $SERVER_PID
 wait $SERVER_PID

-# Check that warning about V1 Engine appears in log - this warning is expected
-if ! grep -q "Engine in background thread is experimental on VLLM_USE_V1=1. Falling back to V0 Engine." $SERVER_LOG; then
-    echo -e "\n***\n*** ERROR: Expected warning about vLLM falling back to V0 Engine not found in logs.\n***"
-    RET=1
-fi
-
 rm -rf models/

 if [ $RET -eq 1 ]; then
diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
index b8ddeb49..0111056c 100644
--- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -173,13 +173,11 @@ def test_vllm_metrics(self):
     # TODO: Revisit this test due to the removal of best_of
     def test_custom_sampling_params(self):
         # Adding sampling parameters for testing metrics.
-        # Definitions can be found here https://docs.vllm.ai/en/latest/dev/sampling_params.html
-        n, best_of = 2, 4
+        # Definitions can be found here https://docs.vllm.ai/en/latest/api/vllm/sampling_params.html
+        n, temperature = 2, 1
         custom_sampling_parameters = self.sampling_parameters.copy()
-        # Changing "temperature" because "best_of" must be 1 when using greedy
-        # sampling, i.e. "temperature": "0".
         custom_sampling_parameters.update(
-            {"n": str(n), "best_of": str(best_of), "temperature": "1"}
+            {"n": str(n), "temperature": str(temperature)}
         )

         # Test vLLM metrics
diff --git a/ci/L0_backend_vllm/test.sh b/ci/L0_backend_vllm/test.sh
index b4d27357..c3ff6c8e 100755
--- a/ci/L0_backend_vllm/test.sh
+++ b/ci/L0_backend_vllm/test.sh
@@ -28,6 +28,9 @@ RET=0
 SUBTESTS="accuracy_test request_cancellation enabled_stream vllm_backend metrics_test"

+export C_INCLUDE_PATH=/usr/local/cuda/include:$C_INCLUDE_PATH
+export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
+
 python3 -m pip install tritonclient[grpc]

 for TEST in ${SUBTESTS}; do
diff --git a/ci/L0_check_health_vllm/test.sh b/ci/L0_check_health_vllm/test.sh
index 80668bcb..3918d3d8 100755
--- a/ci/L0_check_health_vllm/test.sh
+++ b/ci/L0_check_health_vllm/test.sh
@@ -31,11 +31,12 @@ source ../common/util.sh
 pip3 install pytest==8.1.1
 pip3 install tritonclient[grpc]

+rm -f *.log *.report.xml
 RET=0

 function setup_model_repository {
-    local sample_model_repo_path=${1:-"../../samples/model_repository"}
-    rm -rf models vllm_baseline_output.pkl && mkdir -p models
+    local sample_model_repo_path="../../samples/model_repository"
+    rm -rf models && mkdir -p models
     cp -r $sample_model_repo_path/vllm_model models/vllm_opt
 }
@@ -48,23 +49,24 @@ function enable_health_check {
 }

 VLLM_INSTALL_PATH="/usr/local/lib/python3.12/dist-packages/vllm"
+VLLM_V1_ENGINE_PATH="$VLLM_INSTALL_PATH/v1/engine"

 function mock_vllm_async_llm_engine {
     # backup original file
-    mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup
-    cp $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    mv $VLLM_V1_ENGINE_PATH/async_llm.py $VLLM_V1_ENGINE_PATH/async_llm.py.backup
+    cp $VLLM_V1_ENGINE_PATH/async_llm.py.backup $VLLM_V1_ENGINE_PATH/async_llm.py
     # overwrite the original check_health method
-    echo -e "" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
-    echo -e "    async def check_health(self, check_count=[0]):" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
-    echo -e "        check_count[0] += 1" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
-    echo -e "        if check_count[0] > 1:" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
-    echo -e "            raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    echo -e "" >> $VLLM_V1_ENGINE_PATH/async_llm.py
+    echo -e "    async def check_health(self, check_count=[0]):" >> $VLLM_V1_ENGINE_PATH/async_llm.py
+    echo -e "        check_count[0] += 1" >> $VLLM_V1_ENGINE_PATH/async_llm.py
+    echo -e "        if check_count[0] > 1:" >> $VLLM_V1_ENGINE_PATH/async_llm.py
+    echo -e "            raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_V1_ENGINE_PATH/async_llm.py
 }

 function unmock_vllm_async_llm_engine {
     # restore from backup
-    rm -f $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
-    mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    rm -f $VLLM_V1_ENGINE_PATH/async_llm.py
+    mv $VLLM_V1_ENGINE_PATH/async_llm.py.backup $VLLM_V1_ENGINE_PATH/async_llm.py
 }

 function test_check_health {
@@ -93,8 +95,12 @@ function test_check_health {
 }

 # Test health check unspecified
+# Cold start on SBSA device can take longer than default 120 seconds
+PREV_SERVER_TIMEOUT=$SERVER_TIMEOUT
+SERVER_TIMEOUT=240
 setup_model_repository
 test_check_health "health_check_unspecified" "test_vllm_is_healthy"
+SERVER_TIMEOUT=$PREV_SERVER_TIMEOUT

 # Test health check disabled
 setup_model_repository
diff --git a/src/model.py b/src/model.py
index 7a135dcf..4145b71b 100644
--- a/src/model.py
+++ b/src/model.py
@@ -35,7 +35,6 @@
 from typing import Dict, List

 import numpy as np
-import torch
 import triton_python_backend_utils as pb_utils
 from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -45,7 +44,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.utils import random_uuid

-from utils.metrics import VllmStatLogger
+from utils.metrics import VllmStatLoggerFactory
 from utils.vllm_backend_utils import TritonSamplingParams

 _VLLM_ENGINE_ARGS_FILENAME = "model.json"
@@ -184,12 +183,12 @@ def initialize(self, args):
             and not self._aync_engine_args.disable_log_stats
         )

-        # Starting the vLLM engine and its event thread running the AsyncIO event loop.
-        self._init_engine()
-
         # Setup vLLM metrics
         self._setup_metrics()

+        # Starting the vLLM engine and its event thread running the AsyncIO event loop.
+        self._init_engine()
+
         # Starting the response thread. It allows vLLM to keep making progress while
         # response sender(s) are sending responses to server frontend.
         self._response_queue = queue.Queue()
@@ -258,6 +257,7 @@ async def _run_llm_engine(self):
         async with build_async_engine_client_from_engine_args(
             engine_args=self._aync_engine_args,
             disable_frontend_multiprocessing=self._enable_metrics,
+            stat_loggers=self._vllm_metrics,
         ) as engine:
             # Capture the engine event loop and make it visible to other threads.
             self._event_loop = asyncio.get_running_loop()
@@ -348,7 +348,7 @@ def _setup_lora(self):
         )

     def _setup_metrics(self):
-        self._vllm_metrics = None
+        self._vllm_metrics = []
         # TODO: Do not read metrics directly from the vLLM engine, read from prometheus
         # client to allow the use of ZMQ process when metrics are enabled. See
         # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245
@@ -359,9 +359,8 @@ def _setup_metrics(self):
                 "version": self.args["model_version"],
             }
             # Add vLLM custom metrics
-            vllm_config = self._llm_engine.engine.vllm_config
-            self._vllm_metrics = VllmStatLogger(labels, vllm_config, self.logger)
-            self._llm_engine.add_logger("triton", self._vllm_metrics)
+            factory = VllmStatLoggerFactory(labels, self.logger)
+            self._vllm_metrics.append(factory)
         except pb_utils.TritonModelException as e:
             if "metrics not supported" in str(e):
                 # Metrics are disabled at the server
@@ -785,8 +784,8 @@ def finalize(self):
             self._response_thread = None

         # Shutdown the metrics thread.
-        if self._vllm_metrics is not None:
-            self._vllm_metrics.finalize()
+        for stat_logger_factory in self._vllm_metrics:
+            stat_logger_factory.finalize()

         # When using parallel tensors, the stub process may not shutdown due to
         # unreleased references, so manually run the garbage collector once.
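Note on the `stat_loggers=self._vllm_metrics` change above: in vLLM V1 the engine expects a list of stat-logger *factories* rather than logger instances, and it calls each factory once per engine with `(vllm_config, engine_index)`; the factory keeps track of the loggers it created so `finalize()` can fan out shutdown. The sketch below is a minimal, self-contained illustration of that pattern under those assumptions only; `ToyLogger` and `ToyLoggerFactory` are hypothetical names, not part of this PR or of vLLM.

```python
from typing import Any, List


class ToyLogger:
    """Hypothetical stand-in for a per-engine stat logger."""

    def __init__(self, vllm_config: Any, engine_index: int) -> None:
        self.engine_index = engine_index
        self.closed = False

    def finalize(self) -> None:
        self.closed = True


class ToyLoggerFactory:
    """Callable factory: invoked once per engine with (vllm_config, engine_index)."""

    def __init__(self) -> None:
        self._instances: List[ToyLogger] = []

    def __call__(self, vllm_config: Any, engine_index: int) -> ToyLogger:
        logger = ToyLogger(vllm_config, engine_index)
        self._instances.append(logger)
        return logger

    def finalize(self) -> None:
        # Fan out shutdown to every logger this factory created.
        for logger in self._instances:
            logger.finalize()


if __name__ == "__main__":
    factory = ToyLoggerFactory()
    # Simulate an engine constructing one logger per engine index.
    created = [factory(vllm_config=None, engine_index=i) for i in range(2)]
    factory.finalize()
    assert all(logger.closed for logger in created)
```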
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index ecb044d4..644eb6d9 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -26,13 +26,12 @@

 import queue
 import threading
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

 import triton_python_backend_utils as pb_utils
 from vllm.config import VllmConfig
-from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
-from vllm.engine.metrics import Stats as VllmStats
-from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
+from vllm.v1.metrics.loggers import StatLoggerBase, build_1_2_5_buckets
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats


 class TritonMetrics:
@@ -161,13 +160,35 @@ def __init__(self, labels: List[str], max_model_len: int):
         )


-class VllmStatLogger(VllmStatLoggerBase):
+# Create a partially initialized callable that adapts VllmStatLogger to StatLoggerFactory interface
+class VllmStatLoggerFactory:
+    def __init__(self, labels, log_logger):
+        self._labels = labels
+        self._log_logger = log_logger
+        self._instances_list = []
+
+    def __call__(self, vllm_config, engine_index):
+        stat_logger = VllmStatLogger(
+            self._labels, self._log_logger, vllm_config, engine_index
+        )
+        self._instances_list.append(stat_logger)
+        return stat_logger
+
+    def finalize(self):
+        for stat_logger in self._instances_list:
+            if stat_logger is not None:
+                stat_logger.finalize()
+
+
+class VllmStatLogger(StatLoggerBase):
     """StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider."""

-    def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None:
+    def __init__(
+        self, labels: Dict, log_logger, vllm_config: VllmConfig, engine_index: int
+    ) -> None:
         # Tracked stats over current local logging interval.
         # local_interval not used here. It's for vLLM logs to stdout.
-        super().__init__(local_interval=0, vllm_config=vllm_config)
+        super().__init__(vllm_config=vllm_config, engine_index=engine_index)
         self.metrics = TritonMetrics(
             labels=labels, max_model_len=vllm_config.model_config.max_model_len
         )
@@ -176,12 +197,9 @@ def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None:
         # Starting the metrics thread. It allows vLLM to keep making progress
         # while reporting metrics to triton metrics service.
         self._logger_queue = queue.Queue()
-        self._logger_thread = threading.Thread(target=self.logger_loop)
+        self._logger_thread = threading.Thread(target=self._logger_loop)
         self._logger_thread.start()

-    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
-        pass
-
     def _log_counter(self, counter, data: Union[int, float]) -> None:
         """Convenience function for logging to counter.

@@ -208,7 +226,12 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None
         for datum in data:
             self._logger_queue.put_nowait((histogram, "observe", datum))

-    def log(self, stats: VllmStats) -> None:
+    def record(
+        self,
+        scheduler_stats: Optional[SchedulerStats],
+        iteration_stats: Optional[IterationStats],
+        engine_idx: int = 0,
+    ) -> None:
         """Report stats to Triton metrics server.
         Args:
@@ -217,38 +240,54 @@ def log(self, stats: VllmStats) -> None:
         Returns:
             None
         """
+
+        # Parse finished request stats into lists
+        e2e_latency: List[float] = []
+        num_prompt_tokens: List[int] = []
+        num_generation_tokens: List[int] = []
+        for finished_req in iteration_stats.finished_requests:
+            e2e_latency.append(finished_req.e2e_latency)
+            num_prompt_tokens.append(finished_req.num_prompt_tokens)
+            num_generation_tokens.append(finished_req.num_generation_tokens)
+
         # The list of vLLM metrics reporting to Triton is also documented here.
         # https://github.com/triton-inference-server/vllm_backend/blob/main/README.md#triton-metrics
         counter_metrics = [
-            (self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter),
-            (self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter),
+            (self.metrics.counter_prompt_tokens, iteration_stats.num_prompt_tokens),
+            (
+                self.metrics.counter_generation_tokens,
+                iteration_stats.num_generation_tokens,
+            ),
         ]
         histogram_metrics = [
             (
                 self.metrics.histogram_time_to_first_token,
-                stats.time_to_first_tokens_iter,
+                iteration_stats.time_to_first_tokens_iter,
             ),
             (
                 self.metrics.histogram_time_per_output_token,
-                stats.time_per_output_tokens_iter,
+                iteration_stats.inter_token_latencies_iter,
             ),
-            (self.metrics.histogram_e2e_time_request, stats.time_e2e_requests),
+            (self.metrics.histogram_e2e_time_request, e2e_latency),
             (
                 self.metrics.histogram_num_prompt_tokens_request,
-                stats.num_prompt_tokens_requests,
+                num_prompt_tokens,
             ),
             (
                 self.metrics.histogram_num_generation_tokens_request,
-                stats.num_generation_tokens_requests,
+                num_generation_tokens,
            ),
-            (self.metrics.histogram_n_request, stats.n_requests),
+            (self.metrics.histogram_n_request, iteration_stats.n_params_iter),
         ]

         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
             self._log_histogram(metric, data)

-    def logger_loop(self):
+    def log_engine_initialized(self) -> None:
+        pass
+
+    def _logger_loop(self):
         while True:
             item = self._logger_queue.get()
             # To signal shutdown a None item will be added to the queue.
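The `_logger_queue` / `_logger_loop` plumbing kept by this diff is a simple producer-consumer hand-off: `_log_counter` and `_log_histogram` enqueue `(metric, method_name, value)` tuples so the engine loop never blocks on metric updates, a background thread applies them, and `finalize()` pushes a `None` sentinel to stop the thread. Below is a minimal, self-contained sketch of that pattern under those assumptions only; `ToyCounter` and `QueueLogger` are hypothetical stand-ins, not Triton or vLLM classes.

```python
import queue
import threading


class ToyCounter:
    """Hypothetical stand-in for a Triton counter metric."""

    def __init__(self) -> None:
        self.value = 0.0

    def increment(self, amount: float) -> None:
        self.value += amount


class QueueLogger:
    """Applies (metric, method_name, value) work items on a background thread."""

    def __init__(self) -> None:
        self._queue: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._loop)
        self._thread.start()

    def log(self, metric, method: str, value: float) -> None:
        # Producer side: enqueue without blocking the caller.
        self._queue.put_nowait((metric, method, value))

    def _loop(self) -> None:
        while True:
            item = self._queue.get()
            if item is None:  # shutdown sentinel
                break
            metric, method, value = item
            getattr(metric, method)(value)

    def finalize(self) -> None:
        self._queue.put(None)
        self._thread.join()


if __name__ == "__main__":
    counter = ToyCounter()
    logger = QueueLogger()
    logger.log(counter, "increment", 3)
    logger.finalize()
    print(counter.value)  # 3.0
```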