Skip to content

Commit 6448430

Browse files
daniandtheweb and ursg
authored
feat: add break pseudo token support (leejet#422)
--------- Co-authored-by: Urs Ganse <[email protected]>
1 parent 347710f commit 6448430

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

conditioner.hpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
278278
const std::string& curr_text = item.first;
279279
float curr_weight = item.second;
280280
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
281+
int32_t clean_index = 0;
282+
if (curr_text == "BREAK" && curr_weight == -1.0f) {
283+
// Pad token array up to chunk size at this point.
284+
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
285+
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
286+
int padding_size = 75 - (tokens_acc % 75);
287+
for (int j = 0; j < padding_size; j++) {
288+
clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID);
289+
clean_index++;
290+
}
291+
292+
// After padding, continue to the next iteration to process the following text as a new segment
293+
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
294+
weights.insert(weights.end(), padding_size, curr_weight);
295+
continue;
296+
}
297+
298+
// Regular token, process normally
281299
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
282-
int32_t clean_index = 0;
283300
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
284301
int token_id = curr_tokens[i];
285-
if (token_id == image_token)
302+
if (token_id == image_token) {
286303
class_token_index.push_back(clean_index - 1);
287-
else {
304+
} else {
288305
clean_input_ids.push_back(token_id);
289306
clean_index++;
290307
}
@@ -387,6 +404,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
387404
for (const auto& item : parsed_attention) {
388405
const std::string& curr_text = item.first;
389406
float curr_weight = item.second;
407+
408+
if (curr_text == "BREAK" && curr_weight == -1.0f) {
409+
// Pad token array up to chunk size at this point.
410+
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
411+
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
412+
size_t current_size = tokens.size();
413+
size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding
414+
415+
if (padding_size > 0) {
416+
LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
417+
tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID);
418+
weights.insert(weights.end(), padding_size, 1.0f);
419+
}
420+
continue; // Skip to the next item after handling BREAK
421+
}
422+
390423
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
391424
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
392425
weights.insert(weights.end(), curr_tokens.size(), curr_weight);

util.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cstdarg>
66
#include <fstream>
77
#include <locale>
8+
#include <regex>
89
#include <sstream>
910
#include <string>
1011
#include <thread>
@@ -547,6 +548,8 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
547548
// (abc) - increases attention to abc by a multiplier of 1.1
548549
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
549550
// [abc] - decreases attention to abc by a multiplier of 1.1
551+
// BREAK - separates the prompt into conceptually distinct parts for sequential processing
552+
// B - internal helper pattern; prevents 'B' in 'BREAK' from being consumed as normal text
550553
// \( - literal character '('
551554
// \[ - literal character '['
552555
// \) - literal character ')'
@@ -582,7 +585,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
582585
float round_bracket_multiplier = 1.1f;
583586
float square_bracket_multiplier = 1 / 1.1f;
584587

585-
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
588+
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:B]+|:|\bB)");
586589
std::regex re_break(R"(\s*\bBREAK\b\s*)");
587590

588591
auto multiply_range = [&](int start_position, float multiplier) {
@@ -591,7 +594,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
591594
}
592595
};
593596

594-
std::smatch m;
597+
std::smatch m, m2;
595598
std::string remaining_text = text;
596599

597600
while (std::regex_search(remaining_text, m, re_attention)) {
@@ -615,6 +618,8 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
615618
square_brackets.pop_back();
616619
} else if (text == "\\(") {
617620
res.push_back({text.substr(1), 1.0f});
621+
} else if (std::regex_search(text, m2, re_break)) {
622+
res.push_back({"BREAK", -1.0f});
618623
} else {
619624
res.push_back({text, 1.0f});
620625
}
@@ -645,4 +650,4 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
645650
}
646651

647652
return res;
648-
}
653+
}

0 commit comments

Comments
 (0)