@@ -278,13 +278,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
278278 const std::string& curr_text = item.first ;
279279 float curr_weight = item.second ;
280280 // printf(" %s: %f \n", curr_text.c_str(), curr_weight);
281+ int32_t clean_index = 0 ;
282+ if (curr_text == "BREAK" && curr_weight == -1.0f) {
283+ // Pad token array up to chunk size at this point.
284+ // TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
285+ // Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
286+ int padding_size = 75 - (tokens_acc % 75 );
287+ for (int j = 0 ; j < padding_size; j++) {
288+ clean_input_ids.push_back (tokenizer.EOS_TOKEN_ID );
289+ clean_index++;
290+ }
291+
292+ // After padding, continue to the next iteration to process the following text as a new segment
293+ tokens.insert (tokens.end (), clean_input_ids.begin (), clean_input_ids.end ());
294+ weights.insert (weights.end (), padding_size, curr_weight);
295+ continue ;
296+ }
297+
298+ // Regular token, process normally
281299 std::vector<int > curr_tokens = tokenizer.encode (curr_text, on_new_token_cb);
282- int32_t clean_index = 0 ;
283300 for (uint32_t i = 0 ; i < curr_tokens.size (); i++) {
284301 int token_id = curr_tokens[i];
285- if (token_id == image_token)
302+ if (token_id == image_token) {
286303 class_token_index.push_back (clean_index - 1 );
287- else {
304+ } else {
288305 clean_input_ids.push_back (token_id);
289306 clean_index++;
290307 }
@@ -387,6 +404,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
387404 for (const auto & item : parsed_attention) {
388405 const std::string& curr_text = item.first ;
389406 float curr_weight = item.second ;
407+
408+ if (curr_text == "BREAK" && curr_weight == -1.0f) {
409+ // Pad token array up to chunk size at this point.
410+ // TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
411+ // Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
412+ size_t current_size = tokens.size ();
413+ size_t padding_size = (75 - (current_size % 75 )) % 75 ; // Ensure no negative padding
414+
415+ if (padding_size > 0 ) {
416+ LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
417+ tokens.insert (tokens.end (), padding_size, tokenizer.EOS_TOKEN_ID );
418+ weights.insert(weights.end(), padding_size, 1.0f);
419+ }
420+ continue ; // Skip to the next item after handling BREAK
421+ }
422+
390423 std::vector<int > curr_tokens = tokenizer.encode (curr_text, on_new_token_cb);
391424 tokens.insert (tokens.end (), curr_tokens.begin (), curr_tokens.end ());
392425 weights.insert (weights.end (), curr_tokens.size (), curr_weight);
0 commit comments