
kv-cache : avoid modifying recurrent cells when setting inputs #13834


Merged: 3 commits merged into master on Jun 10, 2025

Conversation

@compilade compilade (Collaborator) commented May 27, 2025

NOTE: this targets #13746, not master. (master is now the base branch of this PR since #13746 was merged)

@ggerganov As discussed in #9126 (comment), this ports some of the changes from #9126 to #13746 for recurrent caches.

It mostly works, but there is still something wrong somewhere indicated by non-consecutive
token positions when running mamba-130m-hf with llama-parallel:

$ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
...
find_slot: non-consecutive token position 0 after 335 for sequence 1 with 1 new tokens
...
find_slot: non-consecutive token position 0 after 346 for sequence 3 with 1 new tokens
...

This was not a problem in #9126, so it might (or might not) be related to other recent changes in how the kv cache is handled.

I'll attempt to figure out what is wrong by updating #9126 to the latest master and seeing whether the problem also appears there (EDIT: it does; the problem appears there as well).

Does any recent change in how kv_cell.pos is handled come to mind that could cause this? (pos is not reset properly between re-uses of clients in parallel, at least for recurrent models.)

I'm suspecting #13598 might be related.

EDIT: it seems like the problem isn't there with -pps, which was the default behavior before #13598.



@ggerganov ggerganov (Member) commented May 27, 2025

The find_slot: non-consecutive token position 0 after 346 for sequence 3 with 1 new tokens warning was what I considered to be broken. I didn't investigate the source of the warning and assumed it was caused by the cells being modified during compute, mainly because when I disable the "preparation phase", the warning does not appear.

To clarify, the "preparation phase" is where we simply insert the ubatches into the cells to make sure they fit, and then revert back to the initial state as if no changes were ever made. We then start to insert the ubatches and compute them one by one. So I was confused about why the warning disappears when the "preparation phase" is skipped, and the only explanation I had was the updates during compute.
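
In pseudocode, that flow is roughly the following (a minimal sketch with placeholder types and functions, not the actual llama.cpp implementation):

#include <vector>

// Placeholders standing in for the real cache internals:
struct ubatch {};
struct cache_snapshot {};
cache_snapshot save_state();                 // remember the cell metadata
void restore_state(const cache_snapshot &);  // roll the metadata back
bool find_slot(const ubatch &);              // try to place one ubatch in the cells

// "Preparation phase": insert all ubatches only to check that they fit,
// then revert as if no changes were ever made; the real insert + compute
// happens later, one ubatch at a time.
bool prepare(const std::vector<ubatch> & ubatches) {
    const cache_snapshot snapshot = save_state();

    bool success = true;
    for (const auto & ub : ubatches) {
        if (!find_slot(ub)) {
            success = false;
            break;
        }
    }

    restore_state(snapshot);
    return success;
}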

Comment on lines 2311 to 2316
for (const auto & ubatch : ubatches) {
    if (!find_slot(ubatch)) {
        success = false;
        break;
    }
}
Member

This is the "preparation phase".

Collaborator Author

Right. When I leave this out, the non-consecutive token position problem is still there (the only difference is that the warnings are not duplicated), so I don't think the preparation phase is the source of the problem.

When using -pps the warnings are gone, though. Maybe this was an existing problem which was surfaced by not making -pps the default in parallel.cpp. I'll try to find the source of the kv_cell.pos discrepancy.

Member

Yes, I confirm that the warnings are still there even without preparation - both on this branch and on the target branch. It's possible that I hallucinated that the warnings disappear without preparation.

When using -pps the warnings are gone, though. Maybe this was an existing problem which was surfaced by not making -pps the default in parallel.cpp. I'll try to find the source of the kv_cell.pos discrepancy.

You are likely right. Will take a look as well tomorrow.

@compilade compilade (Collaborator Author) May 28, 2025

There's another problem, which is detectable by running llama-parallel with -pps and deterministic settings and comparing the output at different -ub values. -ub 1 is always fine, but -ub 2 seems to produce weird output on this branch (like swapped answers). The default -ub of 512 also manifests this problem.

For example:

Towards the end the output is weird.

$ ./bin/llama-parallel -m /path/to/mamba-130M-hf-F16.gguf -np 5 -ns 8 --temp 0 --repeat-penalty 1.1 -ub 2 -pps
...
llama_context: constructing llama_context
llama_context: n_seq_max     = 6
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 682
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 2
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (682) < n_ctx_train (1048576) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     1.15 MiB
llama_kv_cache_recurrent: kv_size = 6, n_seq_max = 6, type_k = 'f32', type_v = 'f32', n_layer = 24
llama_kv_cache_recurrent:        CPU KV buffer size =    16.03 MiB
llama_kv_cache_recurrent: KV self size  =   16.03 MiB, K (f32):    2.53 MiB, V (f32):   13.50 MiB
llama_context:        CPU compute buffer size =     1.34 MiB
llama_context: graph nodes  = 1278
llama_context: graph splits = 1
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
No new questions so proceed with build-in defaults.


main: Simulating parallel requests from clients:
main: n_parallel = 5, n_sequences = 8, cont_batching = 1, system tokens = 260

main: Evaluating the system prompt ...

Processing requests ...

main: clearing the KV cache
Client   0, seq    0, started decoding ...
Client   1, seq    1, started decoding ...
Client   2, seq    2, started decoding ...
Client   3, seq    3, started decoding ...
Client   4, seq    4, started decoding ...
Client   0, seq   0/  8, prompt   15 t, response   26 t, time  4.30 s, speed  9.54 t/s, cache miss 0  

Input:    What is the meaning of life?
Response: Life is a series of events, and the best way to understand them are by looking at what happens in each one.

Client   2, seq   2/  8, prompt   15 t, response   26 t, time  4.30 s, speed  9.54 t/s, cache miss 0  

Input:    What is the meaning of life?
Response: Life is a series of events, and the best way to understand them are by looking at what happens in each one.

Client   0, seq    5, started decoding ...
Client   2, seq    6, started decoding ...
Client   2, seq   6/  8, prompt   21 t, response   11 t, time  2.01 s, speed 15.93 t/s, cache miss 0  

Input:    If you could have any superpower, what would it be?
Response: I would like to be a superpower.

Client   2, seq    7, started decoding ...
Client   3, seq   3/  8, prompt   26 t, response   39 t, time  6.84 s, speed  9.51 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: I recommend the steak. It is a very good steak, and it's easy to cook with your hands on the stove or in a skillet if you have one handy at all times!

Client   2, seq   7/  8, prompt   28 t, response   10 t, time  1.39 s, speed 27.28 t/s, cache miss 0  

Input:    I want to learn how to play the piano. What would be the best way to do it?
Response: I would suggest you play the piano.

Client   4, seq   4/  8, prompt   16 t, response   49 t, time  7.70 s, speed  8.44 t/s, cache miss 0  

Input:    Recommend some interesting books to read.
Response: I recommend the book "The Golden Duck" by Richard Feynman. It is a fascinating and entertaining book that is written in a very readable style, and it is also a great way to learn about physics at school or college level!

Client   0, seq   5/  8, prompt   26 t, response   65 t, time  5.27 s, speed 17.28 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: I am a physicist and I have been working on this theory for over 20 years. It is a very simple theory that describes the behavior of particles in a vacuum, but it has many interesting properties such as the existence or nonexistence theorems etc., which are not known to us yet because we do know about them.

Client   1, seq   1/  8, prompt   18 t, response  128 t, time 10.40 s, speed 14.04 t/s, cache miss 0  

Input:    What is the best way to cook a steak?
Response: The meaning of life is the pursuit and enjoyment that comes from living. It is a state in which one lives, and it is an important part to be alive at all times; it is a few years ago I was asked by my friend who works for me how I would like to work with him on this project. He said he would love if we could do some research into the effects of alcohol consumption among young people in the UK, and that he would be happy when we had a chance interview him about it.
I have been working at The Institute of Alcohol Studies for over 20 years now (and I am still working there), so

main: clearing the KV cache

run parameters as of 2025-05-27 23:50:32

main: n_parallel = 5, n_sequences = 8, cont_batching = 1, system tokens = 260
External prompt file: used built-in defaults
Model and path used:  /path/to/mamba-130M-hf-F16.gguf

Total prompt tokens:    165, speed: 11.41 t/s
Total gen tokens:       354, speed: 24.48 t/s
Total speed (AVG):           speed: 35.88 t/s
Cache misses:             0

llama_perf_context_print:        load time =     120.10 ms
llama_perf_context_print: prompt eval time =   13517.87 ms /   743 tokens (   18.19 ms per token,    54.96 tokens per second)
llama_perf_context_print:        eval time =     823.75 ms /    36 runs   (   22.88 ms per token,    43.70 tokens per second)
llama_perf_context_print:       total time =   14465.21 ms /   779 tokens

Notice how the last question about the steak is somehow answered as if the meaning of life was asked.

This does not happen with -ub 1 or with -np 1.


The problem does not exist in #13746, but does exist in #9126 since at least 35d06fa (but not on the corresponding commit on master), which means the problem is most likely related to how I changed the way the recurrent states are copied (which is weird, because I think it did work at some point).

The fact that your branch doesn't manifest the same problem narrows this down to my new changes. I'm still not sure what exactly the root cause is; hopefully I'll find it soon.

This might or might not be related to the non-consecutive kv_cell.pos problem. It's likely a different problem.

(EDIT: it's definitely a different problem, because it still happens after bdbfb4e even though the non-consecutive kv_cell.pos problem has been fixed (the tail cell ids were not swapped correctly))

(EDIT2: Ok, I think I know what is happening: the first zeroed cell is sometimes used as a source for non-zeroed states, and this messes some things up. The fix will need to prevent that situation. (EDIT3: somehow that's not sufficient; that was not the root cause))

Member

Btw, one change in #13746 that could be relevant to this is that we now allocate separate ubatch data buffers for each llama_ubatch:

https://github.com/ggml-org/llama.cpp/pull/13746/files#diff-e86f4c320ddf096b16dccbc876be6a50f6d6fc4e690b7ebba8a526cd8caa8f14R52-R63

On master, the ubatches were simply views of a single buffer, the contents of which were updated on every split. But this was not compatible with the preparation logic, so now we allocate separate buffers.
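
Conceptually, the change looks something like this (a hypothetical sketch; ubatch_data is an illustrative name, not the real struct):

#include <cstdint>
#include <vector>

typedef int32_t llama_token;
typedef int32_t llama_pos;
typedef int32_t llama_seq_id;

// Each ubatch keeps its own copies of the split data instead of pointing into
// one shared buffer that the next split would overwrite.
struct ubatch_data {
    std::vector<llama_token>  token;
    std::vector<llama_pos>    pos;
    std::vector<llama_seq_id> seq_id;
};

With per-ubatch storage, a whole batch can be split and validated during the preparation phase before any of it is computed.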

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch 4 times, most recently from 9548d2a to 256f1b7 Compare May 30, 2025 08:08
@ggerganov ggerganov requested a review from ngxson as a code owner May 30, 2025 08:08
@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch 2 times, most recently from 9d05381 to 2b984f4 Compare May 30, 2025 08:29
@ggerganov ggerganov removed the request for review from ngxson May 30, 2025 12:26
@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from f23e4cc to 71619f2 Compare May 31, 2025 07:05
Base automatically changed from gg/kv-cache-simplify-part3 to master May 31, 2025 07:24
@compilade compilade marked this pull request as draft May 31, 2025 15:28
* kv-cache : remove inp_s_mask

It was replaced with equivalent and simpler functionality
with rs_z (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache
@gabe-l-hart gabe-l-hart (Contributor) commented Jun 9, 2025

@ggerganov @compilade I'd love to get an update on the state of this branch (and #9126) as we prepare for the Granite 4 launch. It seems like there have been a number of structural changes that need to be accommodated since this branch was last updated to master. If it would help, I'd be happy to take a pass at resolving these conflicts and open a PR into this branch.

@ggerganov ggerganov (Member)

@gabe-l-hart My understanding is that there is an issue in the logic of the recurrent cells. We should find and fix this before any additional changes to the recurrent cache. I will try to take a look tomorrow using this branch and the repro command from above.

@gabe-l-hart gabe-l-hart (Contributor)

@ggerganov Thanks! It's much appreciated. I may be able to dig a little today as well, so if I happen to get lucky and find something before you get to it, I'll update this PR.

@compilade compilade force-pushed the compilade/readonly-recurrent-inputs branch from bdbfb4e to dd6495d Compare June 9, 2025 20:36
@compilade compilade (Collaborator Author) commented Jun 9, 2025

My understanding is that there is an issue in the logic of the recurrent cells.

That is correct, there is a problem with the recurrent cells in multi-user cases, and from further tests it seems like it has been this way on master for a while (not sure exactly when it started, but at least 3 months ago). It's reproducible on master when not using -pps with the command in #13834 (comment). The states seem to leak into each other.

This PR does seem to fix part of the problem, though, since at -ub 1 the problem no longer manifests itself (although it's still there at bigger ubatch sizes, and it's also present when using -pps (at ubatch > 1), even though on master it doesn't seem as bad in that case).

I'm currently writing tests which will compare the outputs of batches at different concurrencies, so that it will be possible to notice problems in batch splits and the caches in the future (and hopefully this will also help with figuring out the source of the current problem by narrowing down the failing conditions).
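
The idea of such a test is roughly the following (a sketch with a hypothetical decode_prompts helper, not the actual test code):

#include <cassert>
#include <string>
#include <vector>

struct llama_model;

// Hypothetical helper: decode the prompts greedily with the given number of
// parallel sequences and ubatch size, returning one response per prompt.
std::vector<std::string> decode_prompts(llama_model * model,
                                        const std::vector<std::string> & prompts,
                                        int n_parallel, int n_ubatch);

// The outputs must not depend on how the requests are batched: compare every
// (n_parallel, n_ubatch) combination against a sequential -np 1 -ub 1 baseline.
void check_batch_consistency(llama_model * model, const std::vector<std::string> & prompts) {
    const std::vector<std::string> ref = decode_prompts(model, prompts, /*n_parallel=*/1, /*n_ubatch=*/1);

    for (int np : {2, 4}) {
        for (int ub : {1, 2, 512}) {
            const std::vector<std::string> out = decode_prompts(model, prompts, np, ub);
            for (size_t i = 0; i < prompts.size(); ++i) {
                assert(out[i] == ref[i]);
            }
        }
    }
}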

@gabe-l-hart gabe-l-hart (Contributor) commented Jun 10, 2025

(EDIT: I just realized I hadn't pulled your latest sync with master, so this might not be accurate with the latest changes)

Here's some more evidence that might be helpful, as I'm trying to repro the swapped answers mentioned in #13834 (comment). I don't see answers being swapped, but I do see some curious inconsistencies between running with and without -pps. My rough understanding is that we would expect identical results with and without -pps, since it simply attempts to reuse the cache for the system prompt when enabled.

On my machine, the (seeded) random selection of prompts is choosing (including the duplication):

  1. What is the meaning of life?
  2. If you could have any superpower, what would it be?
  3. I want to learn how to play the piano. What would be the best way to do it?
  4. I want to learn how to play the piano. What would be the best way to do it?
  5. Recommend some interesting books to read.
  6. Recommend some interesting books to read.
  7. What is the best way to cook a steak?
  8. Recommend some interesting books to read.

The commands I'm running are:

./bin/llama-parallel -m ~/models/mamba-130m-hf/mamba-130M-hf-F16.gguf -np 5 -ns 8 --temp 0 --repeat-penalty 1.1 -ub 2 --junk 0

./bin/llama-parallel -m ~/models/mamba-130m-hf/mamba-130M-hf-F16.gguf -np 5 -ns 8 --temp 0 --repeat-penalty 1.1 -ub 2 --junk 0 -pps

What I'm seeing is that for most of the prompts, the responses are identical with / without -pps, but the responses for (3) and (7) differ. The curious thing is that they differ deterministically (3: [without] I like to read. / [with] I would suggest you play the piano., 7: [without] The best way to cook a steak is to use the same as in [Section \[sec:model\]]{}.\n<|endoftext|> / [with] The best thing about this is that it's not a "new" game. It's an old one, and the new version has some great features like 3D graphics, voice acting, and more.\nIt's also very easy to play, and you can even use your mouse for the controls!<|endoftext|>). Furthermore, when running without -pps, the prompts for (3) and (4) should be identical, but the responses are not.

I'm not sure if this is actually helpful, but it might help zero in on a specific repro and/or hint at what's getting mixed up.

@gabe-l-hart gabe-l-hart (Contributor) commented Jun 10, 2025

I just rebuilt with your latest changes and I'm seeing different inconsistencies between running with / without -pps. In this instance, the order in which the sequences get assigned sequence IDs differs between the two.

I've been walking through running with/without -pps using lldb and dumping out the cell struct at llama-kv-cache-recurrent.cpp:158. When I start to see the with / without results diverge, the consistent thing I'm seeing is that the pos value is slightly different between the two. I saw this same symptom before pulling your latest changes, on the corresponding line in the old llama-kv-cache.cpp, so it seems that the pos value may somehow be getting corrupted, or at least be populated inconsistently with/without -pps.

The `state_copy` shuffle assumes everything is moved at once,
which is not true when `states_extra` is copied back to the cache
before copying the range of states between `head` and `head + n_seqs`.
This is only a problem if any of the cells in [`head`, `head + n_seqs`)
have an `src` in [`head + n_seqs`, `head + n_kv`),
which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2),
this will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size
and the number of states.
@compilade compilade (Collaborator Author) commented Jun 10, 2025

When I start to see the with / without results diverge, the consistent thing I'm seeing is that the pos value is slightly different between the two. I saw this same symptom before pulling your latest changes on the corresponding line in the old llama-kv-cache.cpp, so it seems that somehow the pos value may be getting corrupted or at least be inconsistently populated with/without -pps.

@gabe-l-hart Interesting! Internally, pos is there to keep track of where in the sequence the tokens come from (and how many tokens were processed), but it's not used directly in the inference (build_inp_pos is not used for Mamba). It's weird that this corruption is somehow not caught by the check in

if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
    // What should happen when the pos backtracks or skips a value?
    // Clearing the state mid-batch would require special-casing which isn't done.
    LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
        __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
}

dumping out the cell struct at llama-kv-cache-recurrent.cpp:158.

This is in seq_rm, so I would assume this happens at the end of a sequence, and if the contents differ, so might the lengths (and the pos).


Oh my, I might have found how the problem happens. Thank you for further narrowing what the cause could be. It really did help. The metadata seems to look fine, but how the states are used does not (because it leads to different output).

I was wondering why the problem did not trigger anymore at -ub 1, and then I remembered that llm_graph_context::build_recurrent_state now sometimes makes an extra ggml_cpy when n_kv > n_seqs (usually when states need to be moved around, and when more are moved than used). This does not happen at -ub 1, since the states never have to be moved to get one contiguous slot. (It would also not happen with single-sequence batches, because only one slot per sequence is needed for the recurrent cache.)

More precisely, the problem only happens when a slot has n_kv > n_seqs, and the cells between head and head + n have an src between head + n_seqs and head + n.

llama.cpp/src/llama-graph.cpp

Lines 1438 to 1457 in dd6495d

ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());

// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
ggml_build_forward_expand(gf,
    ggml_cpy(ctx0,
        states_extra,
        ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));

if (!avoid_copies) {
    // copy states
    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
    // this shrinks the tensor's ne[1] to n_seqs
    states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
}

(this is the problematic code)

The state_copy shuffle assumes everything is moved at once, but that is not true when states_extra is copied back into the source states buffer. A solution is to first perform the (copying) gather operation on the states which will be used (with ggml_get_rows), and only after that copy the extra states back to the cache.
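
In terms of the quoted snippet, the reordering would look roughly like this (a sketch reusing the names from the snippet above; the actual change landed in 62a9f34 and may differ in details):

// 1. gather the states which will actually be used, before anything is overwritten
ggml_tensor * states_used = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
ggml_build_forward_expand(gf, states_used);

// 2. only now gather the extra states from the original buffer and copy them back
//    into the cache, so a source cell in [head, head + n_seqs) can no longer be
//    clobbered while it is still needed
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
ggml_build_forward_expand(gf,
    ggml_cpy(ctx0,
        states_extra,
        ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));

Since ggml computes graph nodes in the order they are expanded into the graph, expanding the gather of the used states first guarantees it reads its sources before the copy-back writes over them.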

(on master, the problem is caused by something else because that logic isn't there, and it seems like applying only the swap tails fix is not sufficient. I did not investigate further.)

With the new changes in 62a9f34, I'm getting deterministic multi-user output!

Mamba2 (in #9126) uses the avoid_copies parameter of build_recurrent_state, though, and the fix there might need to involve passing a lambda function instead of avoid_copies, to make sure the data is used before it is potentially overwritten.
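
A possible shape for such an interface (purely hypothetical, just to illustrate the idea):

#include <functional>

struct ggml_context;
struct ggml_tensor;

// Hypothetical: instead of a boolean avoid_copies, the caller passes a callback which
// builds the part of the graph that consumes the gathered states; the builder would
// expand the resulting node into the graph before copying the extra states back to
// the cache, guaranteeing the sources are read before they can be overwritten.
using llm_state_consumer = std::function<ggml_tensor * (ggml_context * ctx0, ggml_tensor * states)>;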

@compilade compilade marked this pull request as ready for review June 10, 2025 04:39
@gabe-l-hart gabe-l-hart (Contributor) commented Jun 10, 2025

Hooray! So glad you found it. I'll test it on my side tomorrow for another set of eyes. Some day, I want to fully unpack all the moving parts so I can read this explanation in all its detail.

@ggerganov ggerganov (Member) left a comment

@compilade Feel free to merge.

@gabe-l-hart gabe-l-hart (Contributor)

Confirmed on my end. With that latest fix, I see both deterministic responses when the same prompt is repeated within a parallel batch, and deterministic responses between running with and without -pps.

@compilade compilade merged commit dad5c44 into master Jun 10, 2025
46 of 47 checks passed