Skip to content

Commit c8eeb93

Browse files
ulatekh and ggerganov authored
whisper : suppress tokens with a regex (ggml-org#1997)
* Allow a regular expression to describe tokens to suppress. Example: --suppress-regex "[,\.]|[ ]?[0-9]+" will suppress commas, periods, and numeric tokens. Technique inspired by openai/whisper#1041. Co-authored-by: Georgi Gerganov <[email protected]>
* Blind change to fix Java test.
--------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 319fe51 commit c8eeb93

File tree

5 files changed

+35
-1
lines changed

5 files changed

+35
-1
lines changed

bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,9 @@ public void tdrzEnable(boolean enable) {
148148
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
149149
}
150150

151+
/** Regular expression matching tokens to suppress. */
152+
public String suppress_regex;
153+
151154
/** Tokens to provide to the whisper decoder as an initial prompt.
152155
* These are prepended to any existing text context from a previous call. */
153156
public String initial_prompt;
@@ -319,7 +322,7 @@ protected List<String> getFieldOrder() {
319322
"no_context", "single_segment", "no_timestamps",
320323
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
321324
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
322-
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
325+
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
323326
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
324327
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
325328
"new_segment_callback", "new_segment_callback_user_data",

examples/command/command.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,9 @@ struct whisper_params {
5252
std::string prompt;
5353
std::string context;
5454
std::string grammar;
55+
56+
// A regular expression that matches tokens to suppress
57+
std::string suppress_regex;
5558
};
5659

5760
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -85,6 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
8588
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
8689
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
8790
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
91+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
8892
else {
8993
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
9094
whisper_print_usage(argc, argv, params);
@@ -122,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
122126
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
123127
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
124128
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
129+
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
125130
fprintf(stderr, "\n");
126131
}
127132

@@ -167,6 +172,8 @@ std::string transcribe(
167172

168173
wparams.initial_prompt = params.context.data();
169174

175+
wparams.suppress_regex = params.suppress_regex.c_str();
176+
170177
const auto & grammar_parsed = params.grammar_parsed;
171178
auto grammar_rules = grammar_parsed.c_rules();
172179

examples/main/main.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include <cmath>
77
#include <fstream>
88
#include <cstdio>
9+
#include <regex>
910
#include <string>
1011
#include <thread>
1112
#include <vector>
@@ -78,6 +79,9 @@ struct whisper_params {
7879
// [TDRZ] speaker turn string
7980
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
8081

82+
// A regular expression that matches tokens to suppress
83+
std::string suppress_regex;
84+
8185
std::string openvino_encode_device = "CPU";
8286

8387
std::string dtw = "";
@@ -160,6 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
160164
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
161165
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
162166
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
167+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
163168
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
164169
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
165170
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -223,6 +228,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
223228
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
224229
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
225230
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
231+
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
226232
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
227233
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
228234
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
@@ -1033,6 +1039,8 @@ int main(int argc, char ** argv) {
10331039

10341040
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
10351041

1042+
wparams.suppress_regex = params.suppress_regex.c_str();
1043+
10361044
wparams.initial_prompt = params.prompt.c_str();
10371045

10381046
wparams.greedy.best_of = params.best_of;

whisper.cpp

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
45534553

45544554
/*.tdrz_enable =*/ false,
45554555

4556+
/* suppress_regex =*/ nullptr,
4557+
45564558
/*.initial_prompt =*/ nullptr,
45574559
/*.prompt_tokens =*/ nullptr,
45584560
/*.prompt_n_tokens =*/ 0,
@@ -4796,6 +4798,17 @@ static void whisper_process_logits(
47964798
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
47974799
}
47984800

4801+
// suppress any tokens matching a regular expression
4802+
// ref: https://github.com/openai/whisper/discussions/1041
4803+
if (params.suppress_regex != nullptr) {
4804+
std::regex re(params.suppress_regex);
4805+
for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id) {
4806+
if (std::regex_match(token_id.first, re)) {
4807+
logits[token_id.second] = -INFINITY;
4808+
}
4809+
}
4810+
}
4811+
47994812
// suppress non-speech tokens
48004813
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
48014814
if (params.suppress_non_speech_tokens) {

whisper.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -505,6 +505,9 @@ extern "C" {
505505
// [EXPERIMENTAL] [TDRZ] tinydiarize
506506
bool tdrz_enable; // enable tinydiarize speaker turn detection
507507

508+
// A regular expression that matches tokens to suppress
509+
const char * suppress_regex;
510+
508511
// tokens to provide to the whisper decoder as initial prompt
509512
// these are prepended to any existing text context from a previous call
510513
// use whisper_tokenize() to convert text to tokens

0 commit comments

Comments
 (0)