Add multimodal_chat sample #43049
Conversation
Should the word "sample" be in the file name given it's already in the examples folder?
There's already examples/pytorch/text-generation/run_generation.py. Should they be merged?
I agree that having sample in the filename is a bit redundant since this already lives under examples/. Renaming it to something like multimodal_chat.py would make it more consistent with other examples. As for merging with run_generation.py, I’d lean toward keeping this separate. This script demonstrates a multimodal chat workflow (vision + text), which feels conceptually different from generic text generation. Keeping it standalone helps with clarity and discoverability.
```python
for prompt in None, "What's the dog's color?":
    if prompt:
        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
    inputs = processor.apply_chat_template(
```
The image is recomputed on every turn. Can you suggest a plan, and the appropriate processor methods to call, to avoid that?
The image does indeed seem to be reprocessed on every iteration. A cleaner approach would be to preprocess the image once before entering the chat loop, cache the resulting vision features or image tokens, and then reuse them while only updating the text inputs each turn. This would avoid unnecessary computation and also make it clearer to readers that the image context is static while the conversation evolves.
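A minimal sketch of that idea, assuming the `processor`, `model`, and image-bearing `messages` objects from this script (and a Qwen2-VL-style processor for the `image_grid_thw` key), might hoist the multimodal preprocessing out of the chat loop:

```python
# Build the multimodal inputs once, before the chat loop: the image is decoded,
# resized, and turned into patches here and never again.
base_inputs = processor.apply_chat_template(
    messages,                    # first-turn conversation containing the image
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# Keep the vision tensors for reuse; "image_grid_thw" is specific to
# Qwen2-VL-style processors and may be named differently for other model families.
cached_vision = {
    k: base_inputs[k] for k in ("pixel_values", "image_grid_thw") if k in base_inputs
}
```

Follow-up turns can then reuse `cached_vision` as-is, or skip the vision tensors entirely once the KV cache from the first turn is carried over.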
I hope this helps.
```python
    input_length = inputs["input_ids"].shape[1]
    if prompt:
        del inputs["pixel_values"]
        del inputs["image_grid_thw"]
```
If I don't delete the entries, there's a mismatch in the number of image tokens during generate(). It seems the past_key_values argument trims inputs['input_ids'], removing the image placeholders.
Can you recommend a way to handle this?
The mismatch makes sense given how past_key_values trims input_ids and assumes a text-only growing sequence. One way to handle this more robustly would be to treat the image tokens as fixed context, either by caching the image features separately from the text cache or by explicitly reinserting the image placeholders before calling generate(). Adding a short comment explaining this constraint in the example would also go a long way in helping users understand why this extra handling is needed in a multimodal setup.
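A rough sketch of that text-only follow-up turn (essentially what the sample already does, with the cache made explicit) might look like the snippet below. It assumes `cache` was obtained from the first image + text `generate()` call via `return_dict_in_generate=True`, and that `messages` already includes the assistant's previous reply so the re-tokenized conversation matches what the cache has seen; `processor` and `model` are the objects from this script.

```python
# Follow-up turn: the image context already lives in `cache` (the KV cache from
# the first image + text turn), so only text goes through the model again.
messages.append({"role": "user", "content": [{"type": "text", "text": "What's the dog's color?"}]})
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# Re-passing the vision tensors would make the model splice image features into
# placeholder positions the cache has already consumed; that is the source of
# the mismatch. Dropping them keeps input_ids consistent with the cached context.
inputs.pop("pixel_values", None)
inputs.pop("image_grid_thw", None)

output = model.generate(
    **inputs,
    past_key_values=cache,   # reuse the cached image + text context
    max_new_tokens=128,
)
```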
The rest looks normal.
One possible way to address both the repeated image recomputation and the image-token mismatch with `past_key_values` is the pattern below (a rough sketch, not a drop-in implementation):

```python
import torch

past_key_values = None
cached_vision = None       # vision tensors, computed exactly once
running_ids = None         # full token sequence so far (prompts + model replies)
generated_text = ""

for turn, user_message in enumerate(chat_history):
    if turn == 0:
        # --- First turn: process text AND image together, ONCE ---
        # Assumes user_message already contains the model's image placeholder
        # (e.g. text rendered via the chat template), so the processor can
        # expand it to match pixel_values.
        inputs = processor(text=user_message, images=image, return_tensors="pt").to(device)
        cached_vision = {k: v for k, v in inputs.items()
                         if k not in ("input_ids", "attention_mask")}
        new_ids = inputs["input_ids"]
    else:
        # --- Later turns: text-only processing, the image is never touched again ---
        new_ids = processor(text=user_message, return_tensors="pt").to(device)["input_ids"]

    # The cache covers everything seen so far, so input_ids must be the full
    # running sequence extended by the new message, not just the new text.
    running_ids = new_ids if running_ids is None else torch.cat([running_ids, new_ids], dim=-1)

    model_inputs = {
        "input_ids": running_ids,
        "attention_mask": torch.ones_like(running_ids),
    }
    if past_key_values is None:
        # Vision tensors go in on the first turn only: afterwards the image
        # context already lives in the KV cache, and re-passing pixel_values
        # would trigger the image-token mismatch discussed above.
        model_inputs.update(cached_vision)

    prompt_length = running_ids.shape[1]
    outputs = model.generate(
        **model_inputs,
        past_key_values=past_key_values,
        max_new_tokens=128,
        use_cache=True,
        return_dict_in_generate=True,   # needed to get past_key_values back
    )

    past_key_values = outputs.past_key_values
    running_ids = outputs.sequences     # includes this turn's generated reply
    generated_text += processor.decode(
        outputs.sequences[0, prompt_length:], skip_special_tokens=True
    )
```

If this pattern is adopted, the scope of the PR becomes very clear and strong: it serves as a reference-quality multimodal chat example that demonstrates best practices for image caching and correct use of text-only KV caching in a multimodal setting, without touching any core modules or APIs. This is only a simple workaround, though; any other approach is warmly welcome.
What does this PR do?
Adds a sample covering KV cache management.
Fixes #43045
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.