
Commit 09fe2e7

nvrxq and ngxson authored
server: allow filtering llama server response fields (ggml-org#10940)
* llama_server_response_fields
* llama_server_response_fields_fix_issues
* params fixes
* fix
* clarify docs
* change to "response_fields"

Co-authored-by: Xuan Son Nguyen <[email protected]>
1 parent 30caac3 commit 09fe2e7

File tree: 4 files changed, +63 -1 lines changed

4 files changed

+63
-1
lines changed

examples/server/README.md

Lines changed: 2 additions & 0 deletions

@@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to

 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.

+`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
+
 **Response format**

 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.

examples/server/server.cpp

Lines changed: 5 additions & 1 deletion

@@ -92,6 +92,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

     std::vector<std::string> antiprompt;
+    std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
     bool ignore_eos = false;

@@ -209,6 +210,7 @@ struct server_task {
     params.n_discard        = json_value(data, "n_discard", defaults.n_discard);
     //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+    params.response_fields  = json_value(data, "response_fields", std::vector<std::string>());

     params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
     params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);

@@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {

     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<std::string> response_fields;

     slot_params generation_params;

@@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
        if (!stream && !probs_output.empty()) {
            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
        }
-       return res;
+       return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
    }

    json to_json_oaicompat_chat() {

@@ -2066,6 +2069,7 @@ struct server_context {
            res->tokens  = slot.generated_tokens;
            res->timings = slot.get_timings();
            res->prompt  = common_detokenize(ctx, slot.prompt_tokens, true);
+           res->response_fields = slot.params.response_fields;

            res->truncated = slot.truncated;
            res->n_decoded = slot.n_decoded;

examples/server/tests/unit/test_completion.py

Lines changed: 34 additions & 0 deletions

@@ -257,6 +257,40 @@ def check_slots_status():
     # assert match_regex(re_content, res.body["content"])


+@pytest.mark.parametrize(
+    "prompt,n_predict,response_fields",
+    [
+        ("I believe the meaning of life is", 8, []),
+        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
+    ],
+)
+def test_completion_response_fields(
+    prompt: str, n_predict: int, response_fields: list[str]
+):
+    global server
+    server.start()
+    res = server.make_request(
+        "POST",
+        "/completion",
+        data={
+            "n_predict": n_predict,
+            "prompt": prompt,
+            "response_fields": response_fields,
+        },
+    )
+    assert res.status_code == 200
+    assert "content" in res.body
+    assert len(res.body["content"])
+    if len(response_fields):
+        assert res.body["generation_settings/n_predict"] == n_predict
+        assert res.body["prompt"] == "<s> " + prompt
+        assert isinstance(res.body["content"], str)
+        assert len(res.body) == len(response_fields)
+    else:
+        assert len(res.body)
+        assert "generation_settings" in res.body
+
+
 def test_n_probs():
     global server
     server.start()

examples/server/utils.hpp

Lines changed: 22 additions & 0 deletions

@@ -90,6 +90,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }

+// get value by path(key1 / key2)
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+    json result = json::object();
+
+    for (const std::string & path : paths) {
+        json current = js;
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
+        bool valid_path = true;
+        for (const std::string & k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"
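To make the lookup behaviour of the helper above easier to follow, here is a small Python sketch that mirrors `json_get_nested_values` (an illustration only, not code from this commit; the function name is hypothetical): each entry in `paths` is a '/'-separated key path, the full path string becomes the key in the result, and paths that do not resolve are silently skipped.

```python
# Python mirror of the C++ helper above, for illustration only.
def get_nested_values(paths: list[str], obj: dict) -> dict:
    result = {}
    for path in paths:
        current = obj
        valid_path = True
        for key in path.split("/"):
            # descend one level per key; stop as soon as a key is missing
            if valid_path and isinstance(current, dict) and key in current:
                current = current[key]
            else:
                valid_path = False
        if valid_path:
            # the full path string is used as the output key
            result[path] = current
    return result

body = {"content": "hello", "generation_settings": {"n_predict": 8}}
print(get_nested_values(["content", "generation_settings/n_predict", "missing/field"], body))
# -> {'content': 'hello', 'generation_settings/n_predict': 8}
#    ("missing/field" is dropped without raising an error)
```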
