
Conversation

@Wovchena

What does this PR do?

Add a sample covering KV cache management.

Fixes #43045

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Author


Should the word "sample" be in the file name, given that it's already in the examples folder?
There's already examples/pytorch/text-generation/run_generation.py. Should they be merged?


I agree that having sample in the filename is a bit redundant since this already lives under examples/. Renaming it to something like multimodal_chat.py would make it more consistent with other examples. As for merging with run_generation.py, I’d lean toward keeping this separate. This script demonstrates a multimodal chat workflow (vision + text), which feels conceptually different from generic text generation. Keeping it standalone helps with clarity and discoverability.

for prompt in None, "What's the dog's color?":
    if prompt:
        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
    inputs = processor.apply_chat_template(
Author


The image is recomputed every time. Can you suggest a plan and the appropriate processor methods to call to avoid that?


The image does seem to be reprocessed on every iteration. A cleaner approach would be to preprocess the image once before entering the chat loop, cache the resulting vision features or image tokens, and then reuse them while only updating the text inputs each turn. This would avoid unnecessary computation and also make it clearer to readers that the image context is static while the conversation evolves.

I hope this helps.
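
To make the suggestion concrete, here is a minimal sketch, assuming the sample's processor, image, and messages objects are in scope; the exact tensor names (image_grid_thw in particular) are Qwen2-VL-specific, so treat it as an illustration rather than a drop-in patch.

# First turn: build the full multimodal inputs exactly once, before the loop.
first_text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
first_inputs = processor(text=[first_text], images=[image], return_tensors="pt")
# first_inputs now holds input_ids with the expanded image placeholders plus
# the static vision tensors (pixel_values and, for Qwen2-VL, image_grid_thw).

# Follow-up turns: only the new user text is tokenized; the image context is
# already covered by the KV cache, so no vision tensors are recomputed.
# (In the real sample the new turn would still go through the chat template.)
followup_inputs = processor.tokenizer("What's the dog's color?", return_tensors="pt")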

    input_length = inputs["input_ids"].shape[1]
    if prompt:
        del inputs["pixel_values"]
        del inputs["image_grid_thw"]
Author


If I don't delete the entries, there's a mismatch in the number of image tokens during generate(). It seems the past_key_values arg trims inputs['input_ids'], removing the image placeholders.

Can you recommend a way to handle this?


The mismatch makes sense given how past_key_values trims input_ids and assumes a text-only growing sequence. One way to handle this more robustly would be to treat the image tokens as fixed context, either by caching the image features separately from the text cache or by explicitly reinserting the image placeholders before calling generate(). Adding a short comment explaining this constraint in the example would also go a long way in helping users understand why this extra handling is needed in a multimodal setup.

The rest looks normal.
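
For concreteness, here is a rough sketch of the "treat the image as fixed context" handling, written in the shape of the sample's own loop. It assumes the sample's model, processor, and messages objects, reuses one DynamicCache across turns, and drops the vision tensors after the first turn because the image is already encoded in the cache; it also glosses over the caveat that the re-templated history has to line up token-for-token with what is already cached.

from transformers import DynamicCache

past_key_values = DynamicCache()  # one cache object reused across every turn

for turn, prompt in enumerate((None, "What's the dog's color?")):
    if prompt:
        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    )
    input_length = inputs["input_ids"].shape[1]
    if turn > 0:
        # The image is already in the cache from the first turn; resending the
        # vision tensors would re-expand the placeholders and mismatch the
        # input_ids that generate() trims against past_key_values.
        inputs.pop("pixel_values", None)
        inputs.pop("image_grid_thw", None)

    output_ids = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=128)
    reply = processor.decode(output_ids[0, input_length:], skip_special_tokens=True)
    messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})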

@CodersAcademy006

One possible way to address both the repeated image recomputation and the past_key_values / image-token mismatch is to treat the image as fixed context and cache its features outside the chat loop. The image is preprocessed once, its vision-related tensors are reused for every turn, and only the text inputs grow while updating past_key_values. This avoids recomputing vision features on each iteration and prevents image placeholder tokens from being trimmed during generation.

# --- Preprocess the image ONCE, outside the chat loop ---
vision_inputs = processor(
    images=image,
    return_tensors="pt"
).to(device)

# Cache the vision-related tensors (they stay static across turns)
cached_vision = {
    k: v for k, v in vision_inputs.items()
    if k != "input_ids"
}

past_key_values = None
generated_text = ""

for user_message in chat_history:
    # --- Text-only processing per turn ---
    text_inputs = processor(
        text=user_message,
        return_tensors="pt"
    ).to(device)

    # --- Merge cached vision tensors with the new text inputs ---
    model_inputs = {
        **cached_vision,
        "input_ids": text_inputs["input_ids"],
        "attention_mask": text_inputs["attention_mask"],
        "past_key_values": past_key_values,
    }

    outputs = model.generate(
        **model_inputs,
        max_new_tokens=128,
        use_cache=True,
        return_dict_in_generate=True,  # required to get past_key_values back
    )

    past_key_values = outputs.past_key_values
    generated_text += processor.decode(outputs.sequences[0], skip_special_tokens=True)

If the pattern is adopted, the scope of this PR becomes very clear and strong: it serves as a reference-quality multimodal chat example that demonstrates best practices for image caching and correct use of text-only KV caching in a multimodal setting. Remember not to touch the core modules or APIs in this PR.

Also, this is a simple workaround; any other approach is warmly welcome.
