Commit 5fb5e24

llama : minor sampling refactor (2) (ggml-org#9386)
1 parent 38ca6f6 commit 5fb5e24

File tree

12 files changed: +104 / -102 lines

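Every example below changes in the same way: llama_sampler_sample now also accepts the sampled token into the sampler's internal state, so the explicit llama_sampler_accept call that used to follow it at each call site is removed. A minimal sketch of the simplified call site, in the spirit of the batched examples, assuming a sampler chain smpl, a context ctx, and a model already set up as in those examples:

    // before this commit the examples did:
    //     const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
    //     llama_sampler_accept(smpl, new_token_id);   // manual accept after sampling
    //
    // after this commit, sampling alone is enough:
    const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); // accepts internally

    // is it an end of generation? -> mark the stream as finished
    if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
        i_batch[i] = -1;
    }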

examples/batched.swift/Sources/main.swift

Lines changed: 0 additions & 2 deletions
@@ -140,8 +140,6 @@ while n_cur <= n_len {
 
     let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
 
-    llama_sampler_accept(smpl, new_token_id)
-
     // is it an end of stream? -> mark the stream as finished
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
         i_batch[i] = -1

examples/batched/batched.cpp

Lines changed: 0 additions & 2 deletions
@@ -172,8 +172,6 @@ int main(int argc, char ** argv) {
 
         const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
 
-        llama_sampler_accept(smpl, new_token_id);
-
         // is it an end of generation? -> mark the stream as finished
         if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
             i_batch[i] = -1;

examples/gritlm/gritlm.cpp

Lines changed: 0 additions & 1 deletion
@@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
         llama_decode(ctx, bat);
 
         llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
-        llama_sampler_accept(smpl, token);
 
         if (token == eos_token) {
             break;

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 0 additions & 2 deletions
@@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     // sample the most likely token
     const auto new_token_id = llama_sampler_sample(sampler, context, -1);
 
-    llama_sampler_accept(sampler, new_token_id);
-
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
         return nullptr;

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 0 additions & 2 deletions
@@ -152,8 +152,6 @@ actor LlamaContext {
 
         new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
 
-        llama_sampler_accept(sampling, new_token_id)
-
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             print("\n")
             is_done = true

examples/passkey/passkey.cpp

Lines changed: 0 additions & 2 deletions
@@ -220,8 +220,6 @@ int main(int argc, char ** argv) {
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-            llama_sampler_accept(smpl, new_token_id);
-
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                 LOG_TEE("\n");

examples/save-load-state/save-load-state.cpp

Lines changed: 0 additions & 6 deletions
@@ -74,8 +74,6 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sampler_sample(smpl, ctx, -1);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
 
-        llama_sampler_accept(smpl, next_token);
-
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
@@ -132,8 +130,6 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
 
-        llama_sampler_accept(smpl2, next_token);
-
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
@@ -222,8 +218,6 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
         auto next_token_str = llama_token_to_piece(ctx3, next_token);
 
-        llama_sampler_accept(smpl3, next_token);
-
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -613,7 +613,7 @@ struct server_context {
 
     gpt_params params;
 
-    llama_batch batch;
+    llama_batch batch = {};
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
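The server change is unrelated to the sampler API: the batch member is now value-initialized. For a C aggregate like llama_batch, "= {}" zero-fills every field (null pointers, zero counts), whereas a bare member declaration leaves the fields indeterminate until they are assigned. A small sketch of the idiom, using a hypothetical aggregate Widget rather than the real llama_batch:

    #include <cstdint>

    struct Widget {        // hypothetical stand-in for a C aggregate such as llama_batch
        int32_t   n;
        int32_t * data;
    };

    struct Holder {
        Widget a;          // default-initialized: n and data hold indeterminate values
        Widget b = {};     // value-initialized:   n == 0, data == nullptr
    };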

examples/simple/simple.cpp

Lines changed: 0 additions & 2 deletions
@@ -118,8 +118,6 @@ int main(int argc, char ** argv) {
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-            llama_sampler_accept(smpl, new_token_id);
-
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
                 LOG_TEE("\n");

include/llama.h

Lines changed: 6 additions & 5 deletions
@@ -1127,15 +1127,16 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
-    // Shorthand for:
+    /// @details Sample and accept a token from the idx-th output of the last evaluation
     //
+    // Shorthand for:
     //    const auto * logits = llama_get_logits_ith(ctx, idx);
     //    llama_token_data_array cur_p = { ... init from logits ... };
     //    llama_sampler_apply(smpl, &cur_p);
-    //    return cur_p.data[cur_p.selected].id;
-    //
-    // At this point, this is mostly a convenience function.
-    //
+    //    auto token = cur_p.data[cur_p.selected].id;
+    //    llama_sampler_accept(smpl, token);
+    //    return token;
+    // Returns the sampled token
     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
     // TODO: extend in the future
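Per the updated comment, llama_sampler_sample now bundles sampling and accepting: it applies the sampler chain to the idx-th output logits, picks the selected token, feeds it back to the sampler with llama_sampler_accept, and returns it. A hedged sketch of how a caller uses it, assuming an initialized smpl and ctx; the cur_p construction is left elided exactly as in the header comment:

    // one call now does sample + accept:
    const llama_token tok = llama_sampler_sample(smpl, ctx, -1);

    // roughly the sequence it stands for, per the comment above:
    //     const auto * logits = llama_get_logits_ith(ctx, -1);
    //     llama_token_data_array cur_p = { ... init from logits ... };
    //     llama_sampler_apply(smpl, &cur_p);
    //     auto token = cur_p.data[cur_p.selected].id;
    //     llama_sampler_accept(smpl, token);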
