@@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
 const int EOS_TOKEN_ID = 49407;
 const int PAD_TOKEN_ID = 49407;

+ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
+     std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
+     std::set<int> byte_set;
+     for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
+         byte_set.insert(b);
+         byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+     }
+     for (int b = 161; b <= 172; ++b) {
+         byte_set.insert(b);
+         byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+     }
+     for (int b = 174; b <= 255; ++b) {
+         byte_set.insert(b);
+         byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+     }
+     int n = 0;
+     for (int b = 0; b < 256; ++b) {
+         if (byte_set.find(b) == byte_set.end()) {
+             byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
+             ++n;
+         }
+     }
+     // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
+     return byte_unicode_pairs;
+ }
+
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
 // TODO: implement bpe
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
-     std::map<std::string, int32_t> encoder;
+     std::map<int, std::u32string> byte_encoder;
+     std::map<std::u32string, int> encoder;
+     std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
     std::regex pat;

     static std::string strip(const std::string& str) {
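
Note: `bytes_to_unicode()` ports the byte-level BPE alphabet from OpenAI's simple_tokenizer.py. Every byte value 0-255 gets a printable Unicode stand-in, so arbitrary UTF-8 input can run through string-based BPE with no unknown bytes: printable ASCII and two Latin-1 ranges map to themselves, and the remaining 68 bytes get fresh code points starting at 256 (U+0100). A minimal standalone sketch of the same table; `codepoint_to_utf32` below is a hypothetical stand-in for this file's `unicode_value_to_utf32`:

    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>

    // Hypothetical stand-in for unicode_value_to_utf32():
    // wrap a single code point in a std::u32string.
    static std::u32string codepoint_to_utf32(int cp) {
        return std::u32string(1, static_cast<char32_t>(cp));
    }

    int main() {
        std::map<int, std::u32string> byte_encoder;
        std::set<int> seen;
        // Printable bytes (plus two Latin-1 ranges) map to themselves.
        auto keep = [&](int lo, int hi) {
            for (int b = lo; b <= hi; ++b) {
                seen.insert(b);
                byte_encoder[b] = codepoint_to_utf32(b);
            }
        };
        keep('!', '~');
        keep(161, 172);
        keep(174, 255);
        // Remaining bytes get fresh code points starting at 256.
        int n = 0;
        for (int b = 0; b < 256; ++b) {
            if (!seen.count(b)) {
                byte_encoder[b] = codepoint_to_utf32(256 + n++);
            }
        }
        printf("space -> U+%04X\n", (unsigned)byte_encoder[' '][0]);
        return 0;
    }

Running this prints `space -> U+0120`, i.e. the `Ġ` marker familiar from GPT-2 vocabularies.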
@@ -521,19 +549,61 @@ class CLIPTokenizer {
 public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
-         : version(version){};
-     std::string bpe(std::string token) {
-         std::string word = token + "</w>";
+         : version(version) {}
+
+     void load_from_merges(const std::string& merges_utf8_str) {
+         auto byte_unicode_pairs = bytes_to_unicode();
+         byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
+         // for (auto& pair : byte_unicode_pairs) {
+         //     std::cout << pair.first << ": " << pair.second << std::endl;
+         // }
+         std::vector<std::u32string> merges;
+         size_t start = 0;
+         size_t pos;
+         std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
+         while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
+             merges.push_back(merges_utf32_str.substr(start, pos - start));
+             start = pos + 1;
+         }
+         merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
+         std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
+         for (const auto& merge : merges) {
+             size_t space_pos = merge.find(' ');
+             merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
+             // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
+         }
+         std::vector<std::u32string> vocab;
+         for (const auto& pair : byte_unicode_pairs) {
+             vocab.push_back(pair.second);
+         }
+         for (const auto& pair : byte_unicode_pairs) {
+             vocab.push_back(pair.second + utf8_to_utf32("</w>"));
+         }
+         for (const auto& merge : merge_pairs) {
+             vocab.push_back(merge.first + merge.second);
+         }
+         vocab.push_back(utf8_to_utf32("<|startoftext|>"));
+         vocab.push_back(utf8_to_utf32("<|endoftext|>"));
+         LOG_DEBUG("vocab size: %llu", vocab.size());
+         int i = 0;
+         for (const auto& token : vocab) {
+             encoder[token] = i++;
+         }
+
+         int rank = 0;
+         for (const auto& merge : merge_pairs) {
+             bpe_ranks[merge] = rank++;
+         }
+     };
+
+     std::u32string bpe(std::u32string token) {
+         std::u32string word = token + utf8_to_utf32("</w>");
         if (encoder.find(word) != encoder.end()) {
             return word;
         } else if (encoder.find(token) != encoder.end()) {
             return token;
         }
-         return UNK_TOKEN;
-     }
-
-     void add_token(std::string token, int32_t token_id) {
-         encoder[token] = token_id;
+         return utf8_to_utf32(UNK_TOKEN);
     }

     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
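
Note that `bpe()` is still the stub the `// TODO: implement bpe` comment refers to: it only recognizes tokens that already exist in the vocabulary as a whole word (with or without `</w>`) and otherwise falls back to `UNK_TOKEN`, so the `bpe_ranks` table built above is not consumed yet. For reference, the standard merge loop a full implementation would run looks roughly like this — a sketch, not part of this commit, assuming `word` arrives as one symbol per byte-encoded character with `</w>` already appended to the last one:

    #include <climits>
    #include <map>
    #include <string>
    #include <vector>

    // Sketch: repeatedly merge the adjacent symbol pair with the lowest
    // rank in bpe_ranks; stop when no adjacent pair is in the table, then
    // join the surviving symbols with spaces (the format bpe() returns).
    static std::u32string bpe_merge(std::vector<std::u32string> word,
                                    const std::map<std::pair<std::u32string, std::u32string>, int>& bpe_ranks) {
        while (word.size() > 1) {
            int best_rank = INT_MAX;
            size_t best_i = 0;
            for (size_t i = 0; i + 1 < word.size(); ++i) {
                auto it = bpe_ranks.find({word[i], word[i + 1]});
                if (it != bpe_ranks.end() && it->second < best_rank) {
                    best_rank = it->second;
                    best_i = i;
                }
            }
            if (best_rank == INT_MAX) {
                break;  // nothing left to merge
            }
            word[best_i] += word[best_i + 1];
            word.erase(word.begin() + best_i + 1);
        }
        std::u32string joined;
        for (size_t i = 0; i < word.size(); ++i) {
            if (i > 0) {
                joined += U' ';
            }
            joined += word[i];
        }
        return joined;
    }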
@@ -571,13 +641,25 @@ class CLIPTokenizer {
             std::vector<std::string> token_strs;
             while (std::regex_search(str, matches, pat)) {
                 for (auto& token : matches) {
-                     std::istringstream iss(bpe(token));
-                     std::vector<std::string> tokens{std::istream_iterator<std::string>{iss},
-                                                     std::istream_iterator<std::string>{}};
-                     for (const auto& bpe_token : tokens) {
-                         bpe_tokens.push_back(encoder[bpe_token]);
-                         token_strs.push_back(bpe_token);
+                     std::string token_str = token.str();
+                     std::u32string utf32_token;
+                     for (int i = 0; i < token_str.length(); i++) {
+                         unsigned char b = token_str[i];  // bytes must stay in [0, 255] to index byte_encoder
+                         utf32_token += byte_encoder[b];
                     }
+                     auto bpe_strs = bpe(utf32_token);
+                     size_t start = 0;
+                     size_t pos;
+                     while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
+                         auto bpe_str = bpe_strs.substr(start, pos - start);
+                         bpe_tokens.push_back(encoder[bpe_str]);
+                         token_strs.push_back(utf32_to_utf8(bpe_str));
+
+                         start = pos + 1;
+                     }
+                     auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
+                     bpe_tokens.push_back(encoder[bpe_str]);
+                     token_strs.push_back(utf32_to_utf8(bpe_str));
                 }
                 str = matches.suffix();
             }
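
Since `bpe()` returns its sub-words as one space-joined `std::u32string`, the caller has to split them back apart; the loop above plus the trailing `substr()` handles the last segment, which has no delimiter after it. The same logic could be factored into a small helper (a sketch, not in the commit):

    #include <string>
    #include <vector>

    // Split a u32string on a delimiter; the tail after the last
    // delimiter is included, mirroring the inline loop in tokenize().
    static std::vector<std::u32string> split_u32(const std::u32string& s, char32_t delim) {
        std::vector<std::u32string> parts;
        size_t start = 0, pos;
        while ((pos = s.find(delim, start)) != std::u32string::npos) {
            parts.push_back(s.substr(start, pos - start));
            start = pos + 1;
        }
        parts.push_back(s.substr(start));
        return parts;
    }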
@@ -4323,15 +4405,14 @@ class StableDiffusionGGML {
         LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));

         LOG_DEBUG("loading vocab");
-         auto add_token = [&](const std::string& token, int32_t token_id) {
-             cond_stage_model.tokenizer.add_token(token, token_id);
-         };
-         bool success = model_loader.load_vocab(add_token);
-         if (!success) {
-             LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
+         std::string merges_utf8_str = model_loader.load_merges();
+         if (merges_utf8_str.size() == 0) {
+             LOG_ERROR("get merges failed: '%s'", model_path.c_str());
             return false;
         }

+         cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
+
         // create the ggml context for network params
         LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));

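With this change the tokenizer no longer receives its vocabulary token-by-token through an `add_token` callback: `load_merges()` returns the raw UTF-8 merges data in a single string, and `load_from_merges()` derives both the encoder and the merge ranks from it. For standalone testing, the same entry point could be fed from a plain `merges.txt` on disk (a sketch; the path and helper name are illustrative):

    #include <fstream>
    #include <sstream>
    #include <string>

    // Illustrative helper: slurp an OpenAI-CLIP style merges.txt and
    // hand its contents to the tokenizer, mirroring the call above.
    static bool load_tokenizer_from_file(CLIPTokenizer& tokenizer, const std::string& path) {
        std::ifstream file(path, std::ios::binary);
        if (!file) {
            return false;
        }
        std::stringstream buffer;
        buffer << file.rdbuf();  // read the whole file as UTF-8 bytes
        tokenizer.load_from_merges(buffer.str());
        return true;
    }
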
@@ -4431,7 +4512,7 @@ class StableDiffusionGGML {

         // print_ggml_tensor(alphas_cumprod_tensor);

-         success = model_loader.load_tensors(on_new_tensor_cb);
+         bool success = model_loader.load_tensors(on_new_tensor_cb);
         if (!success) {
             LOG_ERROR("load tensors from file failed");
             ggml_free(ctx);