@@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
 const int EOS_TOKEN_ID = 49407;
 const int PAD_TOKEN_ID = 49407;
 
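+// GPT-2-style byte-to-unicode map, as used by CLIP's simple_tokenizer: every
+// byte value (0..255) is assigned a printable unicode code point so BPE can
+// operate on visible characters instead of raw bytes. Printable bytes
+// ('!'..'~', 161..172, 174..255) map to themselves; the remaining 68 bytes
+// are shifted up to 256 + n (e.g. byte 0x20, space, becomes U+0120).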
+std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
+    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
+    std::set<int> byte_set;
+    for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    for (int b = 161; b <= 172; ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    for (int b = 174; b <= 255; ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    int n = 0;
+    for (int b = 0; b < 256; ++b) {
+        if (byte_set.find(b) == byte_set.end()) {
+            byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
+            ++n;
+        }
+    }
+    // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
+    return byte_unicode_pairs;
+}
+
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
 // TODO: implement bpe
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
-    std::map<std::string, int32_t> encoder;
+    std::map<int, std::u32string> byte_encoder;
+    std::map<std::u32string, int> encoder;
+    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
     std::regex pat;
 
     static std::string strip(const std::string& str) {
@@ -521,19 +549,61 @@ class CLIPTokenizer {
 
 public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
-        : version(version){};
-    std::string bpe(std::string token) {
-        std::string word = token + "</w>";
+        : version(version) {}
+
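+    // Build the tokenizer tables from the raw merges.txt payload, following
+    // CLIP's simple_tokenizer: the vocab is the 256 byte symbols, the same
+    // symbols with "</w>" appended, one entry per merge, and the two special
+    // tokens, 49408 entries in total.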
+    void load_from_merges(const std::string& merges_utf8_str) {
+        auto byte_unicode_pairs = bytes_to_unicode();
+        byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
+        // for (auto& pair : byte_unicode_pairs) {
+        //     std::cout << pair.first << ": " << pair.second << std::endl;
+        // }
+        std::vector<std::u32string> merges;
+        size_t start = 0;
+        size_t pos;
+        std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
+        while ((pos = merges_utf32_str.find(U'\n', start)) != std::u32string::npos) {
+            merges.push_back(merges_utf32_str.substr(start, pos - start));
+            start = pos + 1;
+        }
+        // skip the header line and keep the 48894 standard CLIP merges
+        // (merges[1:49152-256-2+1] in simple_tokenizer.py)
+        merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
+        std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
+        for (const auto& merge : merges) {
+            size_t space_pos = merge.find(U' ');
+            merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
+            // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
+        }
+        std::vector<std::u32string> vocab;
+        for (const auto& pair : byte_unicode_pairs) {
+            vocab.push_back(pair.second);
+        }
+        for (const auto& pair : byte_unicode_pairs) {
+            vocab.push_back(pair.second + utf8_to_utf32("</w>"));
+        }
+        for (const auto& merge : merge_pairs) {
+            vocab.push_back(merge.first + merge.second);
+        }
+        vocab.push_back(utf8_to_utf32("<|startoftext|>"));
+        vocab.push_back(utf8_to_utf32("<|endoftext|>"));
+        LOG_DEBUG("vocab size: %llu", vocab.size());
+        int i = 0;
+        for (const auto& token : vocab) {
+            encoder[token] = i++;
+        }
+
+        int rank = 0;
+        for (const auto& merge : merge_pairs) {
+            bpe_ranks[merge] = rank++;
+        }
+    }
+
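+    // NOTE: still a placeholder (see the TODO above): it only resolves tokens
+    // that already exist in the vocab as "token</w>" or as-is, and falls back
+    // to UNK_TOKEN otherwise. A full BPE would repeatedly merge the
+    // lowest-ranked adjacent pair from bpe_ranks and return the resulting
+    // sub-tokens joined by U' ', which is what tokenize() below expects.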
+    std::u32string bpe(std::u32string token) {
+        std::u32string word = token + utf8_to_utf32("</w>");
         if (encoder.find(word) != encoder.end()) {
             return word;
         } else if (encoder.find(token) != encoder.end()) {
             return token;
         }
-        return UNK_TOKEN;
-    }
-
-    void add_token(std::string token, int32_t token_id) {
-        encoder[token] = token_id;
+        return utf8_to_utf32(UNK_TOKEN);
     }
 
     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
@@ -571,13 +641,25 @@ class CLIPTokenizer {
         std::vector<std::string> token_strs;
         while (std::regex_search(str, matches, pat)) {
             for (auto& token : matches) {
-                std::istringstream iss(bpe(token));
-                std::vector<std::string> tokens{std::istream_iterator<std::string>{iss},
-                                                std::istream_iterator<std::string>{}};
-                for (const auto& bpe_token : tokens) {
-                    bpe_tokens.push_back(encoder[bpe_token]);
-                    token_strs.push_back(bpe_token);
+                std::string token_str = token.str();
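+                // map every UTF-8 byte of the token to its printable unicode
+                // proxy so BPE sees the same alphabet the vocab was built from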
+                std::u32string utf32_token;
+                for (size_t i = 0; i < token_str.length(); i++) {
+                    unsigned char b = token_str[i];  // unsigned so bytes >= 0x80 don't become negative map keys
+                    utf32_token += byte_encoder[b];
                 }
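+                // bpe() returns sub-tokens joined by U' '; split them out and
+                // look each one up in the encoder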
+                auto bpe_strs = bpe(utf32_token);
+                size_t start = 0;
+                size_t pos;
+                while ((pos = bpe_strs.find(U' ', start)) != std::u32string::npos) {
+                    auto bpe_str = bpe_strs.substr(start, pos - start);
+                    bpe_tokens.push_back(encoder[bpe_str]);
+                    token_strs.push_back(utf32_to_utf8(bpe_str));
+
+                    start = pos + 1;
+                }
+                auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
+                bpe_tokens.push_back(encoder[bpe_str]);
+                token_strs.push_back(utf32_to_utf8(bpe_str));
             }
             str = matches.suffix();
         }
@@ -4323,15 +4405,14 @@ class StableDiffusionGGML {
         LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));
 
         LOG_DEBUG("loading vocab");
-        auto add_token = [&](const std::string& token, int32_t token_id) {
-            cond_stage_model.tokenizer.add_token(token, token_id);
-        };
-        bool success = model_loader.load_vocab(add_token);
-        if (!success) {
-            LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
+        std::string merges_utf8_str = model_loader.load_merges();
+        if (merges_utf8_str.size() == 0) {
+            LOG_ERROR("get merges failed: '%s'", model_path.c_str());
             return false;
         }
 
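+        // the merges list replaces the old per-token add_token() callback;
+        // load_from_merges() rebuilds the whole vocab from it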
+        cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
+
         // create the ggml context for network params
         LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
 
@@ -4431,7 +4512,7 @@ class StableDiffusionGGML {
 
         // print_ggml_tensor(alphas_cumprod_tensor);
 
-        success = model_loader.load_tensors(on_new_tensor_cb);
+        bool success = model_loader.load_tensors(on_new_tensor_cb);
         if (!success) {
             LOG_ERROR("load tensors from file failed");
             ggml_free(ctx);