@@ -520,7 +520,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
 }
 
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
-// TODO: implement bpe
 class CLIPTokenizer {
    private:
     SDVersion version = VERSION_1_x;
@@ -547,6 +546,21 @@ class CLIPTokenizer {
         return text;
     }
 
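+    // Collect the set of adjacent subword pairs in a word, e.g.
+    // {"h", "e", "y</w>"} -> {("h", "e"), ("e", "y</w>")}.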
+    static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
+        std::set<std::pair<std::u32string, std::u32string>> pairs;
+        if (subwords.size() == 0) {
+            return pairs;
+        }
+        std::u32string prev_subword = subwords[0];
+        for (int i = 1; i < subwords.size(); i++) {
+            std::u32string subword = subwords[i];
+            std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
+            pairs.insert(pair);
+            prev_subword = subword;
+        }
+        return pairs;
+    }
+
    public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
         : version(version) {}
@@ -565,7 +579,9 @@ class CLIPTokenizer {
             merges.push_back(merges_utf32_str.substr(start, pos - start));
             start = pos + 1;
         }
-        merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
+        // LOG_DEBUG("merges size %llu", merges.size());
+        GGML_ASSERT(merges.size() == 48895);
+        merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
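+        // Note: 48895 = 49152 - 256 - 2 + 1, the upper bound previously hard-coded here;
+        // merges.begin() + 1 skips the first line of the merge list (presumably a version header).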
         std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
         for (const auto& merge : merges) {
             size_t space_pos = merge.find(' ');
@@ -596,14 +612,79 @@ class CLIPTokenizer {
         }
     };
 
-    std::u32string bpe(std::u32string token) {
-        std::u32string word = token + utf8_to_utf32("</w>");
-        if (encoder.find(word) != encoder.end()) {
-            return word;
-        } else if (encoder.find(token) != encoder.end()) {
-            return token;
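+    // Apply byte-pair encoding to a single token: split it into characters
+    // (the last one carrying "</w>"), then repeatedly merge the best-ranked
+    // adjacent pair until no merge rule applies.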
+    std::u32string bpe(const std::u32string& token) {
+        std::vector<std::u32string> word;
+
+        for (int i = 0; i < token.size() - 1; i++) {
+            word.emplace_back(1, token[i]);
+        }
+        word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
+
+        std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
+
+        if (pairs.empty()) {
+            return token + utf8_to_utf32("</w>");
         }
-        return utf8_to_utf32(UNK_TOKEN);
+
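+        // Find the adjacent pair with the lowest merge rank; pairs without
+        // a rank compare as worse than any ranked pair.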
+        while (true) {
+            auto min_pair_iter = std::min_element(pairs.begin(),
+                                                  pairs.end(),
+                                                  [&](const std::pair<std::u32string, std::u32string>& a,
+                                                      const std::pair<std::u32string, std::u32string>& b) {
+                                                      if (bpe_ranks.find(a) == bpe_ranks.end()) {
+                                                          return false;
+                                                      } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
+                                                          return true;
+                                                      }
+                                                      return bpe_ranks.at(a) < bpe_ranks.at(b);
+                                                  });
+
+            const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
+
+            if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
+                break;
+            }
+
+            std::u32string first  = bigram.first;
+            std::u32string second = bigram.second;
+            std::vector<std::u32string> new_word;
+            int32_t i = 0;
+
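+            // Rebuild the word, replacing each occurrence of (first, second)
+            // with the merged symbol first+second.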
+            while (i < word.size()) {
+                auto it = std::find(word.begin() + i, word.end(), first);
+                if (it == word.end()) {
+                    new_word.insert(new_word.end(), word.begin() + i, word.end());
+                    break;
+                }
+                new_word.insert(new_word.end(), word.begin() + i, it);
+                i = static_cast<int32_t>(std::distance(word.begin(), it));
+
+                if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
+                    new_word.push_back(first + second);
+                    i += 2;
+                } else {
+                    new_word.push_back(word[i]);
+                    i += 1;
+                }
+            }
+
+            word = new_word;
+
+            if (word.size() == 1) {
+                break;
+            }
+            pairs = get_pairs(word);
+        }
+
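+        // Join the remaining subwords with single spaces, mirroring
+        // ' '.join(word) in the referenced Python tokenizer.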
+        std::u32string result;
+        for (int i = 0; i < word.size(); i++) {
+            result += word[i];
+            if (i != word.size() - 1) {
+                result += utf8_to_utf32(" ");
+            }
+        }
+
+        return result;
     }
 
     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {