Skip to content

Conversation

Fan-Yunfan
Copy link
Contributor

@Fan-Yunfan Fan-Yunfan commented Jul 25, 2025

Problem

Qwen3 (dense) models have no attn bias,and in the convert_hf_qwen method in TensorRT-LLM/tensorrt_llm/models/qwen/convert.py, the calculation of qkv_bias does not account for the case where attn_bias is None.

A None object has no shape attribute and cannot undergo split or concat operations, leading to an error in older TensorRT-LLM versions (e.g., v0.15.0).

The following PR fixes how to ensure compatibility with the Qwen3-to-engine conversion process in older versions of TensorRT-LLM (e.g., v0.15.0).

Current Implementation

assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
assert (k_bias.shape[0] % (mapping.tp_size * head_size)) == 0
assert (v_bias.shape[0] % (mapping.tp_size * head_size)) == 0

wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
wv = split(v_weight, mapping.tp_size, mapping.tp_rank)

bq = split(q_bias, mapping.tp_size, mapping.tp_rank)
bk = split(k_bias, mapping.tp_size, mapping.tp_rank)
bv = split(v_bias, mapping.tp_size, mapping.tp_rank)

qkv_w = torch.concat((wq, wk, wv))
qkv_b = torch.concat((bq, bk, bv))

Solution

Add a check for attn_bias being None before performing related calculations, and apply conditional logic to the final result qkv_b. If any of bq, bk, or bv is None, set qkv_b to None as well.

assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
if k_bias is not None and v_bias is not None:
    assert (k_bias.shape[0] % (mapping.tp_size * head_size)) == 0
    assert (v_bias.shape[0] % (mapping.tp_size * head_size)) == 0

wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
wv = split(v_weight, mapping.tp_size, mapping.tp_rank)

qkv_w = torch.concat((wq, wk, wv))

if q_bias is not None and k_bias is not None and v_bias is not None:
    bq = split(q_bias, mapping.tp_size, mapping.tp_rank)
    bk = split(k_bias, mapping.tp_size, mapping.tp_rank)
    bv = split(v_bias, mapping.tp_size, mapping.tp_rank)
    qkv_b = torch.concat((bq, bk, bv))
else:
    qkv_b = None

Related Issues:

#6295

Additional Notes:

On top of the main branch, after fixing the aforementioned qkv_bias issue, an additional fix for the head_dim problem in TensorRT-LLM/tensorrt_llm/models/qwen/config.py is required to enable a functional workflow for converting Qwen3 models into engines.

The fix for this issue can be found in another contributor's PR: #5913.

To ensure the successful conversion of Qwen3 models, this PR also incorporates the aforementioned head_dim fix. This is explicitly stated here for clarity.

Summary by CodeRabbit

Summary by CodeRabbit

  • Bug Fixes
    • Improved handling of model configuration to support explicit head dimension settings when available.
    • Enhanced robustness of tensor parallel processing by conditionally handling key and value biases, preventing errors when some biases are missing.

Copy link
Contributor

coderabbitai bot commented Jul 25, 2025

📝 Walkthrough

"""

Walkthrough

The changes update the handling of the head_dim attribute in the QWen configuration loader to prefer an explicit attribute if present, and refine bias splitting logic in the QWen model converter to conditionally process biases only when all are available, preventing errors with missing biases.

Changes

File(s) Change Summary
QWen configuration update tensorrt_llm/models/qwen/config.py
Updated QWenConfig.from_hugging_face to use getattr for head_dim, falling back to calculation.
QWen model conversion adjustment tensorrt_llm/models/qwen/convert.py
Made bias splitting and assertions conditional on presence of all relevant biases in convert_hf_qwen.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~7 minutes

Possibly related issues

Suggested reviewers

  • lucifer1004
  • kaiyux
    """

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.


📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d157f8f and a0d0511.

📒 Files selected for processing (2)
  • tensorrt_llm/models/qwen/config.py (1 hunks)
  • tensorrt_llm/models/qwen/convert.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/models/qwen/config.py
  • tensorrt_llm/models/qwen/convert.py
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@coderabbitai coderabbitai bot requested review from kaiyux and lucifer1004 July 25, 2025 00:16
@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Jul 25, 2025
@Fan-Yunfan
Copy link
Contributor Author

/bot run

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
tensorrt_llm/models/qwen/convert.py (1)

541-557: Good fix for handling None biases, but clean up whitespace.

The conditional logic correctly addresses the issue where Qwen3 models may have None attention biases. The implementation properly:

  • Only asserts divisibility when biases exist (lines 541-543)
  • Only performs bias concatenation when all biases are present (lines 551-555)
  • Sets qkv_b to None when any bias is missing (line 557)

Fix the trailing whitespace on line 548:

-                
+
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b3243b7 and d7d0ed8.

📒 Files selected for processing (1)
  • tensorrt_llm/models/qwen/convert.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/models/qwen/convert.py

539-539: Trailing whitespace

(W291)


539-540: SyntaxError: Expected an expression


540-540: Blank line contains whitespace

(W293)


548-548: Blank line contains whitespace

(W293)

@Fan-Yunfan Fan-Yunfan force-pushed the main branch 2 times, most recently from 66ada37 to d157f8f Compare July 25, 2025 01:09
@Fan-Yunfan
Copy link
Contributor Author

/bot run

@gkswns0531
Copy link
Contributor

Would you mind kindly checking if this is not overlapping with the functionality below?

Qwen3 Dense Support: #5650
fix_dim: #5913

@Fan-Yunfan
Copy link
Contributor Author

Fan-Yunfan commented Jul 25, 2025

Would you mind kindly checking if this is not overlapping with the functionality below?

Qwen3 Dense Support: #5650 fix_dim: #5913

Qwen3 Dense Support: #5650
fix_dim: #5913

Thank you for your comment.
Due to certain reasons, our TensorRT-LLM version does not yet support the 1.00rc release, and we are currently using version v0.15.0. In fact, I initially based my modifications on your two PRs: Qwen3 Dense Support: #5650 and fix_dim: #5913.

However, I found that even after incorporating your modifications, the engine conversion still fails in v0.15.0 due to the issue I mentioned earlier—where q_bias, k_bias, and v_bias are None. Since the code lacked the necessary checks, it resulted in syntax errors. After applying my additional fixes, the engine conversion process now completes successfully.

Therefore, my PRs mentioned earlier are not about adding support for converting Qwen3 engines but rather about fixing bugs that occur when converting Qwen3 models in older versions of TensorRT-LLM (e.g., v0.15.0).

@Fan-Yunfan
Copy link
Contributor Author

Would you mind kindly checking if this is not overlapping with the functionality below?

Qwen3 Dense Support: #5650 fix_dim: #5913

My modifications include your PR fix_dim: #5913, and I've provided an explanation in the PR comments regarding this inclusion.

The reason I incorporated your PR's changes is that my modifications must be built upon both of your PRs: Qwen3 Dense Support: #5650 and fix_dim: #5913. I noticed that your fix_dim: #5913 hasn't been merged yet, and I was concerned that if my PR were merged while yours remained unmerged, converting Qwen3 in v0.15.0 would still fail. Therefore, I proactively included the content from your PR.

Would you like me to separate your PR's changes from mine? If needed, I can resubmit my modifications to exclude your PR.

@gkswns0531
Copy link
Contributor

@fyf2016
Ah, I see — thank you for the explanation. There’s no need to exclude it. If your PR does get merged, I’d appreciate it if you could kindly ping me. I’ll go ahead and close my fix PR. Thank you very much!

@Fan-Yunfan
Copy link
Contributor Author

Fan-Yunfan commented Jul 25, 2025

@fyf2016 Ah, I see — thank you for the explanation. There’s no need to exclude it. If your PR does get merged, I’d appreciate it if you could kindly ping me. I’ll go ahead and close my fix PR. Thank you very much!

Got it ~ Thank you very much for your contribution to the Qwen3 engine conversion—it really helps me a lot.

thanks

@byshiue
Copy link
Collaborator

byshiue commented Jul 28, 2025

/bot run

Copy link
Collaborator

@byshiue byshiue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13120 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13120 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9816 completed with status: 'FAILURE'

@Fan-Yunfan
Copy link
Contributor Author

/bot run

@byshiue
Copy link
Collaborator

byshiue commented Jul 28, 2025

@fyf2016 Please run pre-commit install && pre-commit run -a to fix the issue of code format.

@Fan-Yunfan
Copy link
Contributor Author

/bot run

1 similar comment
@byshiue
Copy link
Collaborator

byshiue commented Jul 28, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13153 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13153 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9847 completed with status: 'FAILURE'

@byshiue
Copy link
Collaborator

byshiue commented Jul 28, 2025

@fyf2016 Please don't update the code frequently if there is not code conflict. If you don't update, we can reuse the status of last CI and only run the fail jobs on next CI.

@byshiue
Copy link
Collaborator

byshiue commented Jul 28, 2025

/bot run

@Fan-Yunfan
Copy link
Contributor Author

Fan-Yunfan commented Jul 28, 2025

@fyf2016 Please don't update the code frequently if there is not code conflict. If you don't update, we can reuse the status of last CI and only run the fail jobs on next CI.

Thank you for your patient reminder, I understand~

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13183 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13183 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9875 completed with status: 'FAILURE'

@byshiue
Copy link
Collaborator

byshiue commented Jul 29, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13282 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13282 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9921 completed with status: 'SUCCESS'

@byshiue byshiue merged commit 1a8e28d into NVIDIA:main Jul 29, 2025
2 checks passed
@Fan-Yunfan
Copy link
Contributor Author

Fan-Yunfan commented Jul 30, 2025

@gkswns0531 Hi brother~, our PR(The fix_dim section incorporates your PR #5913) has been successfully merged into the master branch~ Thank you for your contribution to fixing the dim issue: fix_dim: #5913 . It has truly helped me and others who encountered the same problem.

image

Additionally, a special thanks to the community staff member @byshiue for your patient guidance and support, especially when I was unfamiliar with the process.

image

lancelly pushed a commit to lancelly/TensorRT-LLM that referenced this pull request Aug 6, 2025
…rt engine (NVIDIA#6344)

Signed-off-by: fanyunfan <[email protected]>
Co-authored-by: fanyunfan <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
jain-ria pushed a commit to jain-ria/TensorRT-LLM that referenced this pull request Aug 7, 2025
solrex pushed a commit to solrex/TensorRT-LLM that referenced this pull request Sep 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Community want to contribute PRs initiated from Community

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants