Skip to content

Commit 0e64238

Browse files
Cyberhan123leejet
andauthored
feat: implement the complete bpe function (leejet#119)
* implement the complete bpe function --------- Co-authored-by: leejet <leejet714@gmail.com>
1 parent 8f6b4a3 commit 0e64238

File tree

4 files changed

+95
-11
lines changed

4 files changed

+95
-11
lines changed

.clang-format

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ IndentCaseLabels: false
77
ColumnLimit: 0
88
AccessModifierOffset: -4
99
NamespaceIndentation: All
10-
FixNamespaceComments: false
10+
FixNamespaceComments: false
1111
AlignAfterOpenBracket: true
1212
AlignConsecutiveAssignments: true
1313
IndentCaseLabels: true

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
5151
- The current implementation of ggml_conv_2d is slow and has high memory usage
5252
- Implement Winograd Convolution 2D for 3x3 kernel filtering
5353
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
54-
- [ ] Implement BPE Tokenizer
5554
- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler
5655
- [ ] k-quants support
5756

stable-diffusion.cpp

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
520520
}
521521

522522
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
523-
// TODO: implement bpe
524523
class CLIPTokenizer {
525524
private:
526525
SDVersion version = VERSION_1_x;
@@ -547,6 +546,21 @@ class CLIPTokenizer {
547546
return text;
548547
}
549548

549+
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
550+
std::set<std::pair<std::u32string, std::u32string>> pairs;
551+
if (subwords.size() == 0) {
552+
return pairs;
553+
}
554+
std::u32string prev_subword = subwords[0];
555+
for (int i = 1; i < subwords.size(); i++) {
556+
std::u32string subword = subwords[i];
557+
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
558+
pairs.insert(pair);
559+
prev_subword = subword;
560+
}
561+
return pairs;
562+
}
563+
550564
public:
551565
CLIPTokenizer(SDVersion version = VERSION_1_x)
552566
: version(version) {}
@@ -565,7 +579,9 @@ class CLIPTokenizer {
565579
merges.push_back(merges_utf32_str.substr(start, pos - start));
566580
start = pos + 1;
567581
}
568-
merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
582+
// LOG_DEBUG("merges size %llu", merges.size());
583+
GGML_ASSERT(merges.size() == 48895);
584+
merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
569585
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
570586
for (const auto& merge : merges) {
571587
size_t space_pos = merge.find(' ');
@@ -596,14 +612,79 @@ class CLIPTokenizer {
596612
}
597613
};
598614

599-
std::u32string bpe(std::u32string token) {
600-
std::u32string word = token + utf8_to_utf32("</w>");
601-
if (encoder.find(word) != encoder.end()) {
602-
return word;
603-
} else if (encoder.find(token) != encoder.end()) {
604-
return token;
615+
std::u32string bpe(const std::u32string& token) {
616+
std::vector<std::u32string> word;
617+
618+
for (int i = 0; i < token.size() - 1; i++) {
619+
word.emplace_back(1, token[i]);
620+
}
621+
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
622+
623+
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
624+
625+
if (pairs.empty()) {
626+
return token + utf8_to_utf32("</w>");
605627
}
606-
return utf8_to_utf32(UNK_TOKEN);
628+
629+
while (true) {
630+
auto min_pair_iter = std::min_element(pairs.begin(),
631+
pairs.end(),
632+
[&](const std::pair<std::u32string, std::u32string>& a,
633+
const std::pair<std::u32string, std::u32string>& b) {
634+
if (bpe_ranks.find(a) == bpe_ranks.end()) {
635+
return false;
636+
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
637+
return true;
638+
}
639+
return bpe_ranks.at(a) < bpe_ranks.at(b);
640+
});
641+
642+
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
643+
644+
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
645+
break;
646+
}
647+
648+
std::u32string first = bigram.first;
649+
std::u32string second = bigram.second;
650+
std::vector<std::u32string> new_word;
651+
int32_t i = 0;
652+
653+
while (i < word.size()) {
654+
auto it = std::find(word.begin() + i, word.end(), first);
655+
if (it == word.end()) {
656+
new_word.insert(new_word.end(), word.begin() + i, word.end());
657+
break;
658+
}
659+
new_word.insert(new_word.end(), word.begin() + i, it);
660+
i = static_cast<int32_t>(std::distance(word.begin(), it));
661+
662+
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
663+
new_word.push_back(first + second);
664+
i += 2;
665+
} else {
666+
new_word.push_back(word[i]);
667+
i += 1;
668+
}
669+
}
670+
671+
word = new_word;
672+
673+
if (word.size() == 1) {
674+
break;
675+
}
676+
pairs = get_pairs(word);
677+
}
678+
679+
std::u32string result;
680+
for (int i = 0; i < word.size(); i++) {
681+
result += word[i];
682+
if (i != word.size() - 1) {
683+
result += utf8_to_utf32(" ");
684+
}
685+
}
686+
687+
return result;
607688
}
608689

609690
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {

stable-diffusion.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <string>
66
#include <vector>
77

8+
#include "ggml/ggml.h"
9+
810
enum RNGType {
911
STD_DEFAULT_RNG,
1012
CUDA_RNG
@@ -42,10 +44,12 @@ class StableDiffusion {
4244
bool free_params_immediately = false,
4345
std::string lora_model_dir = "",
4446
RNGType rng_type = STD_DEFAULT_RNG);
47+
4548
bool load_from_file(const std::string& model_path,
4649
const std::string& vae_path,
4750
ggml_type wtype,
4851
Schedule d = DEFAULT);
52+
4953
std::vector<uint8_t*> txt2img(
5054
std::string prompt,
5155
std::string negative_prompt,

0 commit comments

Comments
 (0)