@@ -355,6 +355,113 @@ class CLIPTokenizer {
355
355
}
356
356
};
357
357
358
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
//
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
// Accepted tokens are:
//   (abc)      - increases attention to abc by a multiplier of 1.1
//   (abc:3.12) - increases attention to abc by a multiplier of 3.12
//   [abc]      - decreases attention to abc by a multiplier of 1.1
//   \(  \)  \[  \]  \\ - literal characters '(', ')', '[', ']', '\'
//   anything else      - just text
//
// >>> parse_prompt_attention('normal text')
// [['normal text', 1.0]]
// >>> parse_prompt_attention('an (important) word')
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
// >>> parse_prompt_attention('(unbalanced')
// [['unbalanced', 1.1]]
// >>> parse_prompt_attention('\(literal\]')
// [['(literal]', 1.0]]
// >>> parse_prompt_attention('(unnecessary)(parens)')
// [['unnecessaryparens', 1.1]]
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
// [['a ', 1.0],
//  ['house', 1.5730000000000004],
//  [' ', 1.1],
//  ['on', 1.0],
//  [' a ', 1.1],
//  ['hill', 0.55],
//  [', sun, ', 1.1],
//  ['sky', 1.4641000000000006],
//  ['.', 1.1]]
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
    std::vector<std::pair<std::string, float>> res;
    // Positions in `res` where a '(' / '[' group started; popped on the matching close.
    std::vector<size_t> round_brackets;
    std::vector<size_t> square_brackets;

    const float round_bracket_multiplier  = 1.1f;
    const float square_bracket_multiplier = 1 / 1.1f;

    // Alternation order matters: escaped characters first, then brackets and
    // explicit ':<weight>)' closers, then runs of plain text.
    std::regex re_attention(R"(\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|]|[^\\()\[\]:]+|:)");

    // Scale the weight of every chunk emitted since `start_position`.
    auto multiply_range = [&](size_t start_position, float multiplier) {
        for (size_t p = start_position; p < res.size(); ++p) {
            res[p].second *= multiplier;
        }
    };

    std::smatch m;
    std::string remaining_text = text;

    while (std::regex_search(remaining_text, m, re_attention)) {
        std::string token  = m[0];  // renamed from `text` to avoid shadowing the parameter
        std::string weight = m[1];  // capture group set only for the ':<number>)' alternative

        if (token == "(") {
            round_brackets.push_back(res.size());
        } else if (token == "[") {
            square_brackets.push_back(res.size());
        } else if (!weight.empty()) {
            // Explicit weight closes the innermost '(' group.
            if (!round_brackets.empty()) {
                multiply_range(round_brackets.back(), std::stof(weight));
                round_brackets.pop_back();
            }
        } else if (token == ")" && !round_brackets.empty()) {
            multiply_range(round_brackets.back(), round_bracket_multiplier);
            round_brackets.pop_back();
        } else if (token == "]" && !square_brackets.empty()) {
            multiply_range(square_brackets.back(), square_bracket_multiplier);
            square_brackets.pop_back();
        } else if (token.size() == 2 && token[0] == '\\') {
            // Escaped literal ('\(', '\)', '\[', '\]', '\\'): drop the backslash.
            // Previously only '\(' was unescaped, contradicting the documented
            // example parse_prompt_attention('\(literal\]') == [['(literal]', 1.0]].
            res.push_back({token.substr(1), 1.0f});
        } else {
            res.push_back({token, 1.0f});
        }

        remaining_text = m.suffix();
    }

    // Unbalanced brackets left open: apply their default multipliers anyway.
    for (size_t pos : round_brackets) {
        multiply_range(pos, round_bracket_multiplier);
    }

    for (size_t pos : square_brackets) {
        multiply_range(pos, square_bracket_multiplier);
    }

    if (res.empty()) {
        res.push_back({"", 1.0f});
    }

    // Merge adjacent chunks that ended up with identical weights.
    size_t i = 0;
    while (i + 1 < res.size()) {
        if (res[i].second == res[i + 1].second) {
            res[i].first += res[i + 1].first;
            res.erase(res.begin() + i + 1);
        } else {
            ++i;
        }
    }

    return res;
}
464
+
358
465
/* ================================================ FrozenCLIPEmbedder ================================================*/
359
466
360
467
struct ResidualAttentionBlock {
@@ -639,6 +746,61 @@ struct FrozenCLIPEmbedder {
639
746
}
640
747
};
641
748
749
+ // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
750
+ struct FrozenCLIPEmbedderWithCustomWords {
751
+ CLIPTokenizer tokenizer;
752
+ CLIPTextModel text_model;
753
+
754
+ std::pair<std::vector<int >, std::vector<float >> tokenize (std::string text,
755
+ size_t max_length = 0 ,
756
+ bool padding = false ) {
757
+ auto parsed_attention = parse_prompt_attention (text);
758
+
759
+ {
760
+ std::stringstream ss;
761
+ ss << " [" ;
762
+ for (const auto & item : parsed_attention) {
763
+ ss << " ['" << item.first << " ', " << item.second << " ], " ;
764
+ }
765
+ ss << " ]" ;
766
+ LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
767
+ }
768
+
769
+ std::vector<int > tokens;
770
+ std::vector<float > weights;
771
+ for (const auto & item : parsed_attention) {
772
+ const std::string& curr_text = item.first ;
773
+ float curr_weight = item.second ;
774
+ std::vector<int > curr_tokens = tokenizer.encode (curr_text);
775
+ tokens.insert (tokens.end (), curr_tokens.begin (), curr_tokens.end ());
776
+ weights.insert (weights.end (), curr_tokens.size (), curr_weight);
777
+ }
778
+ tokens.insert (tokens.begin (), BOS_TOKEN_ID);
779
+ weights.insert (weights.begin (), 1.0 );
780
+
781
+ if (max_length > 0 ) {
782
+ if (tokens.size () > max_length - 1 ) {
783
+ tokens.resize (max_length - 1 );
784
+ weights.resize (max_length - 1 );
785
+ } else {
786
+ if (padding) {
787
+ tokens.insert (tokens.end (), max_length - 1 - tokens.size (), PAD_TOKEN_ID);
788
+ weights.insert (weights.end (), max_length - 1 - weights.size (), 1.0 );
789
+ }
790
+ }
791
+ }
792
+ tokens.push_back (EOS_TOKEN_ID);
793
+ weights.push_back (1.0 );
794
+
795
+ // for (int i = 0; i < tokens.size(); i++) {
796
+ // std::cout << tokens[i] << ":" << weights[i] << ", ";
797
+ // }
798
+ // std::cout << std::endl;
799
+
800
+ return {tokens, weights};
801
+ }
802
+ };
803
+
642
804
/* ==================================================== UnetModel =====================================================*/
643
805
644
806
struct ResBlock {
@@ -2489,7 +2651,7 @@ class StableDiffusionGGML {
2489
2651
size_t max_params_mem_size = 0 ;
2490
2652
size_t max_rt_mem_size = 0 ;
2491
2653
2492
- FrozenCLIPEmbedder cond_stage_model;
2654
+ FrozenCLIPEmbedderWithCustomWords cond_stage_model;
2493
2655
UNetModel diffusion_model;
2494
2656
AutoEncoderKL first_stage_model;
2495
2657
@@ -2784,9 +2946,11 @@ class StableDiffusionGGML {
2784
2946
}
2785
2947
2786
2948
ggml_tensor* get_learned_condition (ggml_context* res_ctx, const std::string& text) {
2787
- std::vector<int32_t > tokens = cond_stage_model.tokenizer .tokenize (text,
2788
- cond_stage_model.text_model .max_position_embeddings ,
2789
- true );
2949
+ auto tokens_and_weights = cond_stage_model.tokenize (text,
2950
+ cond_stage_model.text_model .max_position_embeddings ,
2951
+ true );
2952
+ std::vector<int >& tokens = tokens_and_weights.first ;
2953
+ std::vector<float >& weights = tokens_and_weights.second ;
2790
2954
size_t ctx_size = 1 * 1024 * 1024 ; // 1MB
2791
2955
// calculate the amount of memory required
2792
2956
{
@@ -2848,10 +3012,39 @@ class StableDiffusionGGML {
2848
3012
int64_t t1 = ggml_time_ms ();
2849
3013
LOG_DEBUG (" computing condition graph completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
2850
3014
2851
- ggml_tensor* result = ggml_dup_tensor (res_ctx, hidden_states);
2852
- copy_ggml_tensor (result, hidden_states);
3015
+ ggml_tensor* result = ggml_dup_tensor (res_ctx, hidden_states); // [N, n_token, hidden_size]
3016
+
3017
+ {
3018
+ int64_t nelements = ggml_nelements (hidden_states);
3019
+ float original_mean = 0 .f ;
3020
+ float new_mean = 0 .f ;
3021
+ float * vec = (float *)hidden_states->data ;
3022
+ for (int i = 0 ; i < nelements; i++) {
3023
+ original_mean += vec[i] / nelements * 1 .0f ;
3024
+ }
3025
+
3026
+ for (int i2 = 0 ; i2 < hidden_states->ne [2 ]; i2++) {
3027
+ for (int i1 = 0 ; i1 < hidden_states->ne [1 ]; i1++) {
3028
+ for (int i0 = 0 ; i0 < hidden_states->ne [0 ]; i0++) {
3029
+ float value = ggml_tensor_get_f32 (hidden_states, i0, i1, i2);
3030
+ value *= weights[i1];
3031
+ ggml_tensor_set_f32 (result, value, i0, i1, i2);
3032
+ }
3033
+ }
3034
+ }
3035
+
3036
+ vec = (float *)result->data ;
3037
+ for (int i = 0 ; i < nelements; i++) {
3038
+ new_mean += vec[i] / nelements * 1 .0f ;
3039
+ }
3040
+
3041
+ for (int i = 0 ; i < nelements; i++) {
3042
+ vec[i] = vec[i] * (original_mean / new_mean);
3043
+ }
3044
+ }
2853
3045
2854
3046
// print_ggml_tensor(result);
3047
+
2855
3048
size_t rt_mem_size = ctx_size + ggml_curr_max_dynamic_size ();
2856
3049
if (rt_mem_size > max_rt_mem_size) {
2857
3050
max_rt_mem_size = rt_mem_size;
0 commit comments