Skip to content

Conversation

Wanli-Jiang
Copy link
Collaborator

@Wanli-Jiang Wanli-Jiang commented Sep 23, 2025

(img=512x512, ISL=256, concurrency=16)

OSL 64 64 64 4096 4096 4096
backend TRTLLM-prev TRTLLM-opt TRTLLM-opt + torch.compile TRTLLM-prev TRTLLM-opt TRTLLM-opt + torch.compile
Total input tokens 32768 32768 32768 32768 32768 32768
Total generated tokens 8192 8192 8192 524288 524288 524288
Request throughput (req/s) 6.34 10.4 11 0.37 0.35 0.36
Output token throughput (tok/s) 405.95 665.81 703.71 1501.59 1451.83 1458.7
Total Token throughput (tok/s) 2029.75 3329.06 3518.53 1595.44 1542.56 1549.87
User throughput (tok/s) 27 42.59 44.41 94.14 90.78 91.21
Mean TTFT (ms) 1886.5 634.86 637.03 604.43 755.69 627.75
Median TTFT (ms) 1846.9 685.85 629.17 340.62 728.89 642.02
P99 TTFT (ms) 2840.42 1065.85 1055.93 2801.31 1171.86 970.82
Mean TPOT (ms) 8.63 14.2 12.92 10.48 10.84 10.82
Median TPOT (ms) 9.48 13.97 12.61 10.53 10.92 10.91
P99 TPOT (ms) 23.41 21.03 20.45 10.58 11.14 11.08
Mean ITL (ms) 8.5 13.98 12.71 10.48 10.84 10.81
Median ITL (ms) 0.1 8.45 8.44 10.23 10.71 10.69
P99 ITL (ms) 313.12 167.56 157.78 13.35 14.02 13.83

Summary by CodeRabbit

  • New Features

    • Optimized image processing with a streamlined SigLip-based path for faster, more efficient vision encoding.
    • Batch-oriented multimodal encoding for improved throughput.
    • Exposed device paths for multimodal buffers to simplify integration.
    • Optional runtime acceleration via compile-time toggle.
  • Performance

    • Reduced data transfers and memory overhead during image preprocessing.
    • Lower latency for multimodal (language/vision/speech) pipelines.
  • Chores

    • Skips loading obsolete vision head weights, reducing model load time and footprint.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@Wanli-Jiang
Copy link
Collaborator Author

/bot run --disable-fail-fast

Copy link
Contributor

coderabbitai bot commented Sep 23, 2025

📝 Walkthrough

Walkthrough

Introduces runtime torch.compile toggle, new optimized paths for SigLIP image embedding and vision encoder, dynamic and full image preprocessing utilities, public NoOp layer and InputMode enum, bindings/injections into encoder and processor, batch-oriented encoding in forward, and weight-loading skip for removed vision head, all within modeling_phi4mm.py.

Changes

Cohort / File(s) Summary
Runtime compile toggle & constants
tensorrt_llm/_torch/models/modeling_phi4mm.py
Added _is_torch_compile(), applied @torch.compile to encoder and model forward; introduced _BASE_RESOLUTION and _MASK_RESOLUTION.
Vision encoder optimization
tensorrt_llm/_torch/models/modeling_phi4mm.py
Added optimized_vision_encoder_forward; fused global/sub-image features; integrated into text embeddings; replaced vision head with NoOp.
SigLIP embedding optimization
tensorrt_llm/_torch/models/modeling_phi4mm.py
Added optimized_siglip_embedding_forward with precomputed coords/positional embeddings; bound to internal SigLIP module.
Image preprocessing utilities
tensorrt_llm/_torch/models/modeling_phi4mm.py
Added optimized_dynamic_preprocess, optimized_preprocess, and _reshape_attention_masks; produce resized tensors, attention masks, and embeddings.
Encoder class updates
tensorrt_llm/_torch/models/modeling_phi4mm.py
HFPhi4MultimodalEncoder: bound optimized forwards; applied nvtx_range on _encoding_batch_request; removed _replace_special_token_ids; compiled forward.
Input processor updates
tensorrt_llm/_torch/models/modeling_phi4mm.py
Phi4MMInputProcessor: bound dynamic_preprocess and preprocess to optimized variants; integrated device/dtype handling; updated multimodal token ID handling.
Model forward and weights
tensorrt_llm/_torch/models/modeling_phi4mm.py
Phi4MMForCausalLM: added multimodal_data_device_paths property; compiled forward; switched to batch-oriented encoding; load_weights skips image head keys.
Public symbols
tensorrt_llm/_torch/models/modeling_phi4mm.py
Added NoOp module and InputMode enum (LANGUAGE, VISION, SPEECH, VISION_SPEECH).

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant App as Caller
  participant LM as Phi4MMForCausalLM.forward
  participant Enc as HFPhi4MultimodalEncoder.forward
  participant Proc as ImageProcessor (optimized_*)
  participant Vis as Vision Encoder (optimized)
  participant Txt as Text Embeddings/LM Head

  App->>LM: input_ids, images, audios, masks
  note right of LM: @torch.compile (runtime toggle)
  LM->>Enc: batch inputs (ids, image tensors/masks, audio)
  Enc->>Proc: optimized_preprocess / dynamic_preprocess
  Proc-->>Enc: input_image_embeds, image_attention_mask, sizes
  Enc->>Vis: optimized_vision_encoder_forward(...)
  Vis-->>Enc: fused image features (SigLIP path)
  Enc->>Txt: integrate image/audio features with text
  Txt-->>LM: logits
  LM-->>App: logits/output
Loading
sequenceDiagram
  autonumber
  participant Loader as load_weights
  participant Store as State Dict
  participant Model as Phi4MMForCausalLM
  note over Loader,Model: Skip replaced vision head weights
  Loader->>Store: iterate keys
  alt key startswith "model.embed_tokens_extend.image_embed.img_processor.head."
    Loader-->>Model: skip load
  else
    Loader->>Model: load weight
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 25.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ⚠️ Warning The pull request description does not follow the required template because the title field is missing, the Description section is empty, and the Test Coverage section is not provided. It still contains placeholders and commented guidance instead of actual content. As a result, the description does not clearly explain what the changes are or how they are tested. Please update the PR to include a properly formatted title following the template, fill in the Description section with a concise explanation of the issue and solution, and list the relevant tests in the Test Coverage section.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title follows the repository convention by including a valid JIRA ticket, type tag, and a concise summary of the primary change—optimizing the phi4-mm image modality inference—making it clear and focused on the main enhancement.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (6)
tensorrt_llm/_torch/models/modeling_phi4mm.py (6)

111-116: Avoid hyphens in module names when using importlib.

Hyphens in spec_from_file_location module name can cause oddities. Use an identifier.

-        spec = importlib.util.spec_from_file_location(
-            "Phi-4-multimodal-instruct.hf_modeling_phi4mm",
-            modeling_phi4mm_path)
+        spec = importlib.util.spec_from_file_location(
+            "phi4mm_hf_modeling",
+            modeling_phi4mm_path)

491-498: Nit: avoid list concatenation inside torch.cat.

-        torch.cat([_global_image] + [_im], dim=0)
+        torch.cat([_global_image, _im], dim=0)
-        torch.cat([_global_mask] + [_mask],
+        torch.cat([_global_mask, _mask],
                   dim=0) ...

528-529: Silence unused args in NoOp.forward.

-    def forward(self, *args, **kwargs):
+    def forward(self, *_args, **_kwargs):
         return None

615-623: Use boolean dtype and correct device for image attention masks.

Allocating masks with embed dtype is misleading and wastes memory.

-            batched_image_attn_mask = torch.zeros(
-                (total_b, max_p, h_i_attn, w_i_attn),
-                dtype=input_image_embeds_list[0].dtype,
-                device=input_image_embeds_list[0].device)
+            batched_image_attn_mask = torch.zeros(
+                (total_b, max_p, h_i_attn, w_i_attn),
+                dtype=torch.bool,
+                device=input_image_embeds_list[0].device)
...
-                else:
-                    batched_image_attn_mask[b_offset:b_offset + b, :p] = 1
+                else:
+                    batched_image_attn_mask[b_offset:b_offset + b, :p] = True

Also applies to: 627-633


671-675: Same as above for audio attention masks.

-            batched_audio_attn_mask = torch.zeros(
-                (total_b, max_p),
-                dtype=input_audio_embeds_list[0].dtype,
-                device=input_audio_embeds_list[0].device)
+            batched_audio_attn_mask = torch.zeros(
+                (total_b, max_p),
+                dtype=torch.bool,
+                device=input_audio_embeds_list[0].device)
...
-                else:
-                    batched_audio_attn_mask[b_offset:b_offset + b, :p] = 1
+                else:
+                    batched_audio_attn_mask[b_offset:b_offset + b, :p] = True

Also applies to: 679-685


187-237: Minor: avoid CPU sync via Python int conversions in hot path.

int(tensor) forces D2H sync. If feasible, keep in tensor form until the final reshape/index step.

Consider computing useful_height/width with tensor ops and slicing via tensor indices where possible.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5792464 and 2bb6a97.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_phi4mm.py (15 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/modeling_phi4mm.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/modeling_phi4mm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/modeling_phi4mm.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_phi4mm.py (4)
tensorrt_llm/_utils.py (1)
  • nvtx_range (857-876)
tensorrt_llm/_torch/models/modeling_multimodal_utils.py (2)
  • find_closest_aspect_ratio (367-381)
  • dynamic_preprocess (384-435)
tensorrt_llm/inputs/multimodal.py (1)
  • MultimodalParams (196-520)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (2)
  • multimodal_data_device_paths (978-983)
  • multimodal_data_device_paths (1030-1041)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/models/modeling_phi4mm.py

375-376: Avoid specifying long messages outside the exception class

(TRY003)


440-442: Avoid specifying long messages outside the exception class

(TRY003)


491-491: Consider [_global_image, _im] instead of concatenation

Replace with [_global_image, _im]

(RUF005)


495-495: Consider [_global_mask, _mask] instead of concatenation

Replace with [_global_mask, _mask]

(RUF005)


528-528: Unused method argument: args

(ARG002)


528-528: Unused method argument: kwargs

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/models/modeling_phi4mm.py (2)

743-747: Confirm torch.compile decorator order with inference_mode.

@torch.compile wrapping @torch.inference_mode() can affect graph capture. Verify perf/correctness for variable crop counts.

Would you like me to benchmark both orders and report compile success rates?


955-960: Head replaced with NoOp — confirm no call sites expect a Tensor.

Repo search for "img_processor.head(" returned no matches; if any callers invoke img_processor.head(...) returning None will crash them — ensure no callers exist or change the NoOp to return a Tensor (e.g., an identity tensor). File: tensorrt_llm/_torch/models/modeling_phi4mm.py lines 955–960.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19651 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19651 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14786 completed with status: 'SUCCESS'

@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/phi4mm-opt branch from 2bb6a97 to a88d73d Compare September 23, 2025 07:42
Copy link
Collaborator

@yechank-nvidia yechank-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think naming with 'optimized' is little bit awkward. Can you remove them all?

@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/phi4mm-opt branch from a88d73d to 6e6f246 Compare September 24, 2025 08:16
@amukkara
Copy link
Collaborator

@Wanli-Jiang Can you add a summary of perf improvement in PR description? Just the TLLM speedup compared to previous implementation.

@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/phi4mm-opt branch from 6e6f246 to bf956f1 Compare September 25, 2025 01:09
@Wanli-Jiang
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19859 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19859 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14944 completed with status: 'FAILURE'

@Wanli-Jiang
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19898 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19898 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14976 completed with status: 'SUCCESS'

@Wanli-Jiang Wanli-Jiang merged commit 22b45ff into NVIDIA:main Sep 25, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants