Skip to content

Commit 8f6b4a3

Browse files
authored
fix: enhance the tokenizer's handling of Unicode (leejet#120)
1 parent 9842a3f commit 8f6b4a3

File tree

6 files changed

+43819
-80106
lines changed

6 files changed

+43819
-80106
lines changed

model.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,20 +1192,9 @@ ggml_type ModelLoader::get_sd_wtype() {
11921192
return GGML_TYPE_COUNT;
11931193
}
11941194

1195-
bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
1196-
char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
1197-
nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
1198-
std::map<char, int> decoder = unicode_to_byte();
1199-
for (auto& it : vocab.items()) {
1200-
int token_id = it.value();
1201-
std::string token_str = it.key();
1202-
std::string token = "";
1203-
for (char c : token_str) {
1204-
token += decoder[c];
1205-
}
1206-
on_new_token_cb(token, token_id);
1207-
}
1208-
return true;
1195+
std::string ModelLoader::load_merges() {
1196+
std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
1197+
return merges_utf8_str;
12091198
}
12101199

12111200
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {

model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ModelLoader {
115115
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
116116
SDVersion get_sd_version();
117117
ggml_type get_sd_wtype();
118-
bool load_vocab(on_new_token_cb_t on_new_token_cb);
118+
std::string load_merges();
119119
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
120120
int64_t cal_mem_size(ggml_backend_t backend);
121121
~ModelLoader() = default;

stable-diffusion.cpp

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
493493
const int EOS_TOKEN_ID = 49407;
494494
const int PAD_TOKEN_ID = 49407;
495495

496+
std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
    // GPT-2/CLIP byte -> unicode table: maps every byte 0..255 to a
    // printable unicode codepoint so arbitrary bytes can live inside a
    // text-based BPE vocabulary. Printable bytes keep their own codepoint;
    // the remaining bytes are shifted up into the 256+n range.
    std::vector<std::pair<int, std::u32string>> pairs;
    pairs.reserve(256);
    std::set<int> covered;
    auto add_identity_range = [&](int lo, int hi) {
        for (int b = lo; b <= hi; ++b) {
            covered.insert(b);
            pairs.emplace_back(b, std::u32string(1, static_cast<char32_t>(b)));
        }
    };
    add_identity_range('!', '~');  // printable ASCII
    add_identity_range(161, 172);  // printable latin-1, first run
    add_identity_range(174, 255);  // printable latin-1, second run
    // Everything not covered above (control bytes, space, 127, 173, ...)
    // is assigned 256, 257, ... in ascending byte order.
    int shifted = 0;
    for (int b = 0; b < 256; ++b) {
        if (covered.count(b) == 0) {
            pairs.emplace_back(b, std::u32string(1, static_cast<char32_t>(256 + shifted)));
            ++shifted;
        }
    }
    // LOG_DEBUG("byte_unicode_pairs %d", pairs.size());
    return pairs;
}
521+
496522
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
497523
// TODO: implement bpe
498524
class CLIPTokenizer {
499525
private:
500526
SDVersion version = VERSION_1_x;
501-
std::map<std::string, int32_t> encoder;
527+
std::map<int, std::u32string> byte_encoder;
528+
std::map<std::u32string, int> encoder;
529+
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
502530
std::regex pat;
503531

504532
static std::string strip(const std::string& str) {
@@ -521,19 +549,61 @@ class CLIPTokenizer {
521549

522550
public:
523551
CLIPTokenizer(SDVersion version = VERSION_1_x)
524-
: version(version){};
525-
std::string bpe(std::string token) {
526-
std::string word = token + "</w>";
552+
: version(version) {}
553+
554+
void load_from_merges(const std::string& merges_utf8_str) {
555+
auto byte_unicode_pairs = bytes_to_unicode();
556+
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
557+
// for (auto & pair: byte_unicode_pairs) {
558+
// std::cout << pair.first << ": " << pair.second << std::endl;
559+
// }
560+
std::vector<std::u32string> merges;
561+
size_t start = 0;
562+
size_t pos;
563+
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
564+
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
565+
merges.push_back(merges_utf32_str.substr(start, pos - start));
566+
start = pos + 1;
567+
}
568+
merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
569+
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
570+
for (const auto& merge : merges) {
571+
size_t space_pos = merge.find(' ');
572+
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
573+
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
574+
}
575+
std::vector<std::u32string> vocab;
576+
for (const auto& pair : byte_unicode_pairs) {
577+
vocab.push_back(pair.second);
578+
}
579+
for (const auto& pair : byte_unicode_pairs) {
580+
vocab.push_back(pair.second + utf8_to_utf32("</w>"));
581+
}
582+
for (const auto& merge : merge_pairs) {
583+
vocab.push_back(merge.first + merge.second);
584+
}
585+
vocab.push_back(utf8_to_utf32("<|startoftext|>"));
586+
vocab.push_back(utf8_to_utf32("<|endoftext|>"));
587+
LOG_DEBUG("vocab size: %llu", vocab.size());
588+
int i = 0;
589+
for (const auto& token : vocab) {
590+
encoder[token] = i++;
591+
}
592+
593+
int rank = 0;
594+
for (const auto& merge : merge_pairs) {
595+
bpe_ranks[merge] = rank++;
596+
}
597+
};
598+
599+
std::u32string bpe(std::u32string token) {
600+
std::u32string word = token + utf8_to_utf32("</w>");
527601
if (encoder.find(word) != encoder.end()) {
528602
return word;
529603
} else if (encoder.find(token) != encoder.end()) {
530604
return token;
531605
}
532-
return UNK_TOKEN;
533-
}
534-
535-
void add_token(std::string token, int32_t token_id) {
536-
encoder[token] = token_id;
606+
return utf8_to_utf32(UNK_TOKEN);
537607
}
538608

539609
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
@@ -571,13 +641,25 @@ class CLIPTokenizer {
571641
std::vector<std::string> token_strs;
572642
while (std::regex_search(str, matches, pat)) {
573643
for (auto& token : matches) {
574-
std::istringstream iss(bpe(token));
575-
std::vector<std::string> tokens{std::istream_iterator<std::string>{iss},
576-
std::istream_iterator<std::string>{}};
577-
for (const auto& bpe_token : tokens) {
578-
bpe_tokens.push_back(encoder[bpe_token]);
579-
token_strs.push_back(bpe_token);
644+
std::string token_str = token.str();
645+
std::u32string utf32_token;
646+
for (int i = 0; i < token_str.length(); i++) {
647+
char b = token_str[i];
648+
utf32_token += byte_encoder[b];
580649
}
650+
auto bpe_strs = bpe(utf32_token);
651+
size_t start = 0;
652+
size_t pos;
653+
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
654+
auto bpe_str = bpe_strs.substr(start, pos - start);
655+
bpe_tokens.push_back(encoder[bpe_str]);
656+
token_strs.push_back(utf32_to_utf8(bpe_str));
657+
658+
start = pos + 1;
659+
}
660+
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
661+
bpe_tokens.push_back(encoder[bpe_str]);
662+
token_strs.push_back(utf32_to_utf8(bpe_str));
581663
}
582664
str = matches.suffix();
583665
}
@@ -4323,15 +4405,14 @@ class StableDiffusionGGML {
43234405
LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));
43244406

43254407
LOG_DEBUG("loading vocab");
4326-
auto add_token = [&](const std::string& token, int32_t token_id) {
4327-
cond_stage_model.tokenizer.add_token(token, token_id);
4328-
};
4329-
bool success = model_loader.load_vocab(add_token);
4330-
if (!success) {
4331-
LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
4408+
std::string merges_utf8_str = model_loader.load_merges();
4409+
if (merges_utf8_str.size() == 0) {
4410+
LOG_ERROR("get merges failed: '%s'", model_path.c_str());
43324411
return false;
43334412
}
43344413

4414+
cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
4415+
43354416
// create the ggml context for network params
43364417
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
43374418

@@ -4431,7 +4512,7 @@ class StableDiffusionGGML {
44314512

44324513
// print_ggml_tensor(alphas_cumprod_tensor);
44334514

4434-
success = model_loader.load_tensors(on_new_tensor_cb);
4515+
bool success = model_loader.load_tensors(on_new_tensor_cb);
44354516
if (!success) {
44364517
LOG_ERROR("load tensors from file failed");
44374518
ggml_free(ctx);

util.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#include "util.h"
22

33
#include <stdarg.h>
4+
#include <codecvt>
45
#include <fstream>
6+
#include <locale>
57
#include <thread>
68
#include <unordered_set>
79
#include <vector>
@@ -119,6 +121,21 @@ int32_t get_num_physical_cores() {
119121
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
120122
}
121123

124+
std::u32string utf8_to_utf32(const std::string& utf8_str) {
    // Decode a UTF-8 byte string into one char32_t per codepoint.
    // NOTE(review): std::wstring_convert/<codecvt> are deprecated since
    // C++17 (still provided); from_bytes throws std::range_error on
    // malformed UTF-8 input.
    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> cvt;
    return cvt.from_bytes(utf8_str);
}
128+
129+
std::string utf32_to_utf8(const std::u32string& utf32_str) {
    // Encode a UTF-32 string (one char32_t per codepoint) back to UTF-8.
    // NOTE(review): std::wstring_convert/<codecvt> are deprecated since
    // C++17 (still provided); to_bytes throws std::range_error on
    // invalid codepoints.
    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> cvt;
    return cvt.to_bytes(utf32_str);
}
133+
134+
std::u32string unicode_value_to_utf32(int unicode_value) {
    // Wrap a single unicode scalar value in a one-codepoint UTF-32 string.
    return std::u32string(1, static_cast<char32_t>(unicode_value));
}
138+
122139
std::string basename(const std::string& path) {
123140
size_t pos = path.find_last_of('/');
124141
if (pos != std::string::npos) {

util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ void replace_all_chars(std::string& str, char target, char replacement);
1414
bool file_exists(const std::string& filename);
1515
bool is_directory(const std::string& path);
1616

17+
std::u32string utf8_to_utf32(const std::string& utf8_str);
18+
std::string utf32_to_utf8(const std::u32string& utf32_str);
19+
std::u32string unicode_value_to_utf32(int unicode_value);
20+
1721
std::string basename(const std::string& path);
1822

1923
std::string path_join(const std::string& p1, const std::string& p2);

0 commit comments

Comments
 (0)