Merged
65 changes: 7 additions & 58 deletions tests/unittest/llmapi/run_llm_with_postproc.py
@@ -6,64 +6,11 @@

import click

from tensorrt_llm.executor import GenerationResultBase
from tensorrt_llm.executor.postproc_worker import PostprocArgs, PostprocParams
from tensorrt_llm.executor.postproc_worker import PostprocParams
from tensorrt_llm.llmapi import LLM, KvCacheConfig, SamplingParams
from tensorrt_llm.llmapi.utils import print_colored
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
DeltaMessage)


def perform_faked_oai_postprocess(rsp: GenerationResultBase,
args: PostprocArgs):
first_iteration = len(rsp.outputs[0].token_ids) == 1
num_choices = 1
finish_reason_sent = [False] * num_choices
role = "assistant"
model = "LLaMA"

def yield_first_chat(idx: int, role: str = None, content: str = None):
choice_data = ChatCompletionResponseStreamChoice(index=idx,
delta=DeltaMessage(
role=role,
content=content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(choices=[choice_data], model=model)

data = chunk.model_dump_json(exclude_unset=True)
return data

res = []
if first_iteration:
for i in range(num_choices):
res.append(f"data: {yield_first_chat(i, role=role)} \n\n")
first_iteration = False

for output in rsp.outputs:
i = output.index

if finish_reason_sent[i]:
continue

delta_text = output.text_diff
delta_message = DeltaMessage(content=delta_text)

choice = ChatCompletionResponseStreamChoice(index=i,
delta=delta_message,
finish_reason=None)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(choices=[choice], model=model)
data = chunk.model_dump_json(exclude_unset=True)
res.append(f"data: {data}\n\n")

if rsp._done:
res.append(f"data: [DONE]\n\n")

return res
from tensorrt_llm.serve.postprocess_handlers import (ChatPostprocArgs,
chat_stream_post_processor)


@click.command()
@@ -98,9 +45,11 @@ def main(model_dir: str, tp_size: int, engine_dir: Optional[str], n: int,
n=n,
best_of=best_of,
top_k=top_k)
postproc_args = ChatPostprocArgs(role="assistant",
model="TinyLlama-1.1B-Chat-v1.0")
postproc_params = PostprocParams(
post_processor=perform_faked_oai_postprocess,
postproc_args=PostprocArgs(),
post_processor=chat_stream_post_processor,
postproc_args=postproc_args,
)

prompt = "A B C D E F"
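
Taken together, the changes to this script drop the hand-rolled perform_faked_oai_postprocess and rely on the postprocess handlers shipped in tensorrt_llm.serve. The condensed sketch below shows the resulting wiring, stitched together from the two files touched in this PR; the model path, the max_tokens value, and the omission of KvCacheConfig are placeholder choices for illustration, not values required by the change.

from tensorrt_llm.executor.postproc_worker import PostprocParams
from tensorrt_llm.llmapi import LLM, SamplingParams
from tensorrt_llm.serve.postprocess_handlers import (ChatPostprocArgs,
                                                     chat_stream_post_processor)

model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder model path

# Configure the library post-processor instead of a hand-written fake:
# ChatPostprocArgs carries the chat metadata, PostprocParams binds it to
# chat_stream_post_processor for the postprocess workers.
postproc_args = ChatPostprocArgs(tokenizer=model_dir,
                                 role="assistant",
                                 model=model_dir)
postproc_params = PostprocParams(post_processor=chat_stream_post_processor,
                                 postproc_args=postproc_args)

llm = LLM(model=model_dir,
          num_postprocess_workers=2,            # post-processing runs in parallel workers
          postprocess_tokenizer_dir=model_dir)  # tokenizer those workers detokenize with

sampling_params = SamplingParams(max_tokens=6)
for output in llm.generate_async("A B C D E F",
                                 sampling_params=sampling_params,
                                 _postproc_params=postproc_params,
                                 streaming=True):
    # Each streamed step exposes the already post-processed chunks, not raw text.
    print(output.outputs[0]._postprocess_result)

Because chat_stream_post_processor produces OpenAI-style chat-completion stream chunks (mirroring what the removed faked post-processor emitted), _postprocess_result carries "data: ..." SSE lines rather than plain text, which is why the test updated below concatenates the delta contents before comparing against the golden string.
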
37 changes: 19 additions & 18 deletions tests/unittest/llmapi/test_llm.py
@@ -2143,16 +2143,18 @@ def test_llm_with_postprocess_parallel():
def run_llm_with_postprocess_parallel_and_result_handler(
streaming, backend, tp_size: int = 1):
# avoid import error when running in CI
from tensorrt_llm.executor.postproc_worker import (PostprocArgs,
PostprocParams)
from tensorrt_llm.executor.postproc_worker import PostprocParams
from tensorrt_llm.serve.postprocess_handlers import (
ChatPostprocArgs, chat_stream_post_processor)

from .run_llm_with_postproc import perform_faked_oai_postprocess
from .run_llm_with_postproc import get_concatenated_content

sampling_params = SamplingParams(max_tokens=6)
post_proc_args = PostprocArgs(tokenizer=llama_model_path)
post_proc_params = PostprocParams(
post_processor=perform_faked_oai_postprocess,
postproc_args=post_proc_args)
post_proc_args = ChatPostprocArgs(tokenizer=llama_model_path,
role="assistant",
model=llama_model_path)
post_proc_params = PostprocParams(post_processor=chat_stream_post_processor,
postproc_args=post_proc_args)
kwargs = {}
if backend not in ["pytorch", "autodeploy"]:
kwargs["fast_build"] = True
@@ -2163,17 +2165,16 @@ def run_llm_with_postprocess_parallel_and_result_handler(
num_postprocess_workers=2,
postprocess_tokenizer_dir=llama_model_path,
**kwargs)
golden_result = "DEFGHI"
for i, output in enumerate(
llm.generate_async(prompts[0],
sampling_params=sampling_params,
_postproc_params=post_proc_params,
streaming=streaming)):
if i < len(golden_result) - 1:
assert golden_result[i] in output.outputs[0]._postprocess_result[-1]
else:
assert golden_result[i] in output.outputs[0]._postprocess_result[
-2] # EOS
golden_result = "D E F G H I"
outputs = []
for output in llm.generate_async(prompts[0],
sampling_params=sampling_params,
_postproc_params=post_proc_params,
streaming=streaming):
outputs.append(output.outputs[0]._postprocess_result)
actual_result = get_concatenated_content(outputs)
assert actual_result == golden_result, \
f"Expected: {golden_result}, Actual: {actual_result}"


@pytest.mark.parametrize("streaming", [True, False])
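
The updated test delegates the comparison to a get_concatenated_content helper imported from run_llm_with_postproc.py. That helper's body is not part of this diff, so the sketch below is only one plausible reading of it, assuming each per-step _postprocess_result is a list of OpenAI-style SSE lines ("data: {...}" chunks ending with "data: [DONE]"), which is the format the removed faked post-processor used to emit.

import json
from typing import List


def get_concatenated_content(outputs: List[List[str]]) -> str:
    """Hypothetical helper: join the delta contents carried by streamed chat chunks."""
    content = []
    for step_chunks in outputs:          # one entry per streamed generation step
        for chunk in step_chunks:        # one entry per SSE line in that step
            payload = chunk.removeprefix("data: ").strip()
            if not payload or payload == "[DONE]":
                continue                 # skip the terminal sentinel chunk
            delta = json.loads(payload)["choices"][0].get("delta", {})
            if delta.get("content"):
                content.append(delta["content"])
    return "".join(content)

Under that assumption, concatenating the delta.content fields produced for the prompt "A B C D E F" yields the golden string "D E F G H I" that the new assertion compares against.
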
1 change: 0 additions & 1 deletion tests/unittest/llmapi/test_llm_multi_gpu.py
@@ -322,7 +322,6 @@ def test_llm_multi_node_pytorch(nworkers: int):

@skip_single_gpu
def test_llm_multi_node_with_postproc():
pytest.skip("https://nvbugspro.nvidia.com/bug/5327706")
nworkers = 2
test_case_file = os.path.join(os.path.dirname(__file__),
"run_llm_with_postproc.py")