Skip to content

Commit 17095dd

Browse files
authored
feat: add token weighting support (leejet#13)
1 parent 7132027 commit 17095dd

File tree

2 files changed

+200
-7
lines changed

2 files changed

+200
-7
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
1616
- AVX, AVX2 and AVX512 support for x86 architectures
1717
- Original `txt2img` and `img2img` mode
1818
- Negative prompt
19+
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
1920
- Sampling method
2021
- `Euler A`
2122
- Supported platforms
@@ -30,7 +31,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
3031
- [ ] Make inference faster
3132
- The current implementation of ggml_conv_2d is slow and has high memory usage
3233
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
33-
- [ ] [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (eg: token weighting, ...)
3434
- [ ] LoRA support
3535
- [ ] k-quants support
3636
- [ ] Cross-platform reproducibility (perhaps ensuring consistency with the original SD)

stable-diffusion.cpp

Lines changed: 199 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,113 @@ class CLIPTokenizer {
355355
}
356356
};
357357

358+
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
//
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
// Accepted tokens are:
//   (abc) - increases attention to abc by a multiplier of 1.1
//   (abc:3.12) - increases attention to abc by a multiplier of 3.12
//   [abc] - decreases attention to abc by a multiplier of 1.1
//   \( - literal character '('
//   \[ - literal character '['
//   \) - literal character ')'
//   \] - literal character ']'
//   \\ - literal character '\'
//   anything else - just text
//
// >>> parse_prompt_attention('normal text')
// [['normal text', 1.0]]
// >>> parse_prompt_attention('an (important) word')
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
// >>> parse_prompt_attention('(unbalanced')
// [['unbalanced', 1.1]]
// >>> parse_prompt_attention('\(literal\]')
// [['(literal]', 1.0]]
// >>> parse_prompt_attention('(unnecessary)(parens)')
// [['unnecessaryparens', 1.1]]
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
// [['a ', 1.0],
//  ['house', 1.5730000000000004],
//  [' ', 1.1],
//  ['on', 1.0],
//  [' a ', 1.1],
//  ['hill', 0.55],
//  [', sun, ', 1.1],
//  ['sky', 1.4641000000000006],
//  ['.', 1.1]]
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
    std::vector<std::pair<std::string, float>> res;
    std::vector<int> round_brackets;   // indices into res where each open '(' group began
    std::vector<int> square_brackets;  // indices into res where each open '[' group began

    float round_bracket_multiplier  = 1.1f;
    float square_bracket_multiplier = 1 / 1.1f;

    // Matches, in priority order: escaped specials (\( \) \[ \] \\), a lone
    // backslash, bracket tokens, an explicit weight ":<number>)", plain text.
    std::regex re_attention(R"(\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|]|[^\\()\[\]:]+|:)");

    // Multiply the weight of every chunk appended since start_position.
    auto multiply_range = [&](int start_position, float multiplier) {
        for (size_t p = start_position; p < res.size(); ++p) {
            res[p].second *= multiplier;
        }
    };

    std::smatch m;
    std::string remaining_text = text;

    while (std::regex_search(remaining_text, m, re_attention)) {
        std::string token  = m[0];  // whole match
        std::string weight = m[1];  // capture group: numeric weight of ":<num>)", if any

        if (token == "(") {
            round_brackets.push_back((int)res.size());
        } else if (token == "[") {
            square_brackets.push_back((int)res.size());
        } else if (!weight.empty()) {
            // Explicit weight: "(abc:1.3)" closes the innermost round bracket.
            if (!round_brackets.empty()) {
                float w = 1.0f;
                try {
                    w = std::stof(weight);
                } catch (const std::exception&) {
                    w = 1.0f;  // malformed number (e.g. "(a:.)"): treat as neutral instead of throwing
                }
                multiply_range(round_brackets.back(), w);
                round_brackets.pop_back();
            }
        } else if (token == ")" && !round_brackets.empty()) {
            multiply_range(round_brackets.back(), round_bracket_multiplier);
            round_brackets.pop_back();
        } else if (token == "]" && !square_brackets.empty()) {
            multiply_range(square_brackets.back(), square_bracket_multiplier);
            square_brackets.pop_back();
        } else if (token.size() > 1 && token[0] == '\\') {
            // Any escaped special (\( \) \[ \] \\): drop the backslash, keep the character.
            // (The previous check only handled "\\(", leaving "\)", "\[", "\]", "\\" escaped.)
            res.push_back({token.substr(1), 1.0f});
        } else {
            res.push_back({token, 1.0f});
        }

        remaining_text = m.suffix();
    }

    // Unbalanced opening brackets still apply their default multiplier.
    for (int pos : round_brackets) {
        multiply_range(pos, round_bracket_multiplier);
    }

    for (int pos : square_brackets) {
        multiply_range(pos, square_bracket_multiplier);
    }

    if (res.empty()) {
        res.push_back({"", 1.0f});
    }

    // Merge adjacent chunks that ended up with identical weights.
    size_t i = 0;
    while (i + 1 < res.size()) {
        if (res[i].second == res[i + 1].second) {
            res[i].first += res[i + 1].first;
            res.erase(res.begin() + i + 1);
        } else {
            ++i;
        }
    }

    return res;
}
358465
/*================================================ FrozenCLIPEmbedder ================================================*/
359466

360467
struct ResidualAttentionBlock {
@@ -639,6 +746,61 @@ struct FrozenCLIPEmbedder {
639746
}
640747
};
641748

749+
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
750+
struct FrozenCLIPEmbedderWithCustomWords {
751+
CLIPTokenizer tokenizer;
752+
CLIPTextModel text_model;
753+
754+
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
755+
size_t max_length = 0,
756+
bool padding = false) {
757+
auto parsed_attention = parse_prompt_attention(text);
758+
759+
{
760+
std::stringstream ss;
761+
ss << "[";
762+
for (const auto& item : parsed_attention) {
763+
ss << "['" << item.first << "', " << item.second << "], ";
764+
}
765+
ss << "]";
766+
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
767+
}
768+
769+
std::vector<int> tokens;
770+
std::vector<float> weights;
771+
for (const auto& item : parsed_attention) {
772+
const std::string& curr_text = item.first;
773+
float curr_weight = item.second;
774+
std::vector<int> curr_tokens = tokenizer.encode(curr_text);
775+
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
776+
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
777+
}
778+
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
779+
weights.insert(weights.begin(), 1.0);
780+
781+
if (max_length > 0) {
782+
if (tokens.size() > max_length - 1) {
783+
tokens.resize(max_length - 1);
784+
weights.resize(max_length - 1);
785+
} else {
786+
if (padding) {
787+
tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID);
788+
weights.insert(weights.end(), max_length - 1 - weights.size(), 1.0);
789+
}
790+
}
791+
}
792+
tokens.push_back(EOS_TOKEN_ID);
793+
weights.push_back(1.0);
794+
795+
// for (int i = 0; i < tokens.size(); i++) {
796+
// std::cout << tokens[i] << ":" << weights[i] << ", ";
797+
// }
798+
// std::cout << std::endl;
799+
800+
return {tokens, weights};
801+
}
802+
};
803+
642804
/*==================================================== UnetModel =====================================================*/
643805

644806
struct ResBlock {
@@ -2489,7 +2651,7 @@ class StableDiffusionGGML {
24892651
size_t max_params_mem_size = 0;
24902652
size_t max_rt_mem_size = 0;
24912653

2492-
FrozenCLIPEmbedder cond_stage_model;
2654+
FrozenCLIPEmbedderWithCustomWords cond_stage_model;
24932655
UNetModel diffusion_model;
24942656
AutoEncoderKL first_stage_model;
24952657

@@ -2784,9 +2946,11 @@ class StableDiffusionGGML {
27842946
}
27852947

27862948
ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
2787-
std::vector<int32_t> tokens = cond_stage_model.tokenizer.tokenize(text,
2788-
cond_stage_model.text_model.max_position_embeddings,
2789-
true);
2949+
auto tokens_and_weights = cond_stage_model.tokenize(text,
2950+
cond_stage_model.text_model.max_position_embeddings,
2951+
true);
2952+
std::vector<int>& tokens = tokens_and_weights.first;
2953+
std::vector<float>& weights = tokens_and_weights.second;
27902954
size_t ctx_size = 1 * 1024 * 1024; // 1MB
27912955
// calculate the amount of memory required
27922956
{
@@ -2848,10 +3012,39 @@ class StableDiffusionGGML {
28483012
int64_t t1 = ggml_time_ms();
28493013
LOG_DEBUG("computing condition graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
28503014

2851-
ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states);
2852-
copy_ggml_tensor(result, hidden_states);
3015+
ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states); // [N, n_token, hidden_size]
3016+
3017+
{
3018+
int64_t nelements = ggml_nelements(hidden_states);
3019+
float original_mean = 0.f;
3020+
float new_mean = 0.f;
3021+
float* vec = (float*)hidden_states->data;
3022+
for (int i = 0; i < nelements; i++) {
3023+
original_mean += vec[i] / nelements * 1.0f;
3024+
}
3025+
3026+
for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) {
3027+
for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) {
3028+
for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) {
3029+
float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2);
3030+
value *= weights[i1];
3031+
ggml_tensor_set_f32(result, value, i0, i1, i2);
3032+
}
3033+
}
3034+
}
3035+
3036+
vec = (float*)result->data;
3037+
for (int i = 0; i < nelements; i++) {
3038+
new_mean += vec[i] / nelements * 1.0f;
3039+
}
3040+
3041+
for (int i = 0; i < nelements; i++) {
3042+
vec[i] = vec[i] * (original_mean / new_mean);
3043+
}
3044+
}
28533045

28543046
// print_ggml_tensor(result);
3047+
28553048
size_t rt_mem_size = ctx_size + ggml_curr_max_dynamic_size();
28563049
if (rt_mem_size > max_rt_mem_size) {
28573050
max_rt_mem_size = rt_mem_size;

0 commit comments

Comments
 (0)