Commit c0341f3

fix: clip pooling
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 4dbc441 commit c0341f3

File tree: 3 files changed, +38 -51 lines

clip.hpp

Lines changed: 4 additions & 2 deletions
@@ -711,8 +711,10 @@ class CLIPTextModel : public GGMLBlock {
         if (return_pooled) {
             auto text_projection = params["text_projection"];
             ggml_tensor* pooled  = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
-            pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled);
-            return pooled;
+            if (text_projection != NULL) {
+                pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
+            }
+            return pooled;  // [hidden_size, 1, 1]
         }

         return x;  // [N, n_token, hidden_size]
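
Note: the new path drops the explicit transpose and routes the projection through the repository's ggml_nn_linear helper, skipping it entirely when the checkpoint ships no text_projection tensor. A hedged sketch of what such a bias-free linear helper amounts to in ggml terms (the name linear_sketch and the layout comments are assumptions, not the repository's definition):

#include "ggml.h"

// Sketch of a ggml linear helper in the style of ggml_nn_linear: y = W x,
// plus an optional bias add. Passing b == NULL, as the pooling path above
// does, leaves only the matrix product.
static struct ggml_tensor* linear_sketch(struct ggml_context* ctx,
                                         struct ggml_tensor* x,   // input, e.g. the pooled vector
                                         struct ggml_tensor* w,   // weight, e.g. text_projection
                                         struct ggml_tensor* b) { // optional bias, may be NULL
    x = ggml_mul_mat(ctx, w, x);  // ggml contracts along the first dim of w and x
    if (b != NULL) {
        x = ggml_add(ctx, x, b);
    }
    return x;
}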

conditioner.hpp

Lines changed: 33 additions & 47 deletions
@@ -837,21 +837,16 @@ struct SD3CLIPEmbedder : public Conditioner {
             }

             if (chunk_idx == 0) {
-                // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-                // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                // clip_l->compute(n_threads,
-                //                 input_ids,
-                //                 0,
-                //                 NULL,
-                //                 max_token_idx,
-                //                 true,
-                //                 &pooled_l,
-                //                 work_ctx);
-
-                // clip_l.transformer.text_model.text_projection no in file, ignore
-                // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
-                pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
-                ggml_set_f32(pooled_l, 0.f);
+                auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+                max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                clip_l->compute(n_threads,
+                                input_ids,
+                                0,
+                                NULL,
+                                max_token_idx,
+                                true,
+                                &pooled_l,
+                                work_ctx);
             }
         }

@@ -891,21 +886,16 @@ struct SD3CLIPEmbedder : public Conditioner {
             }

             if (chunk_idx == 0) {
-                // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
-                // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                // clip_g->compute(n_threads,
-                //                 input_ids,
-                //                 0,
-                //                 NULL,
-                //                 max_token_idx,
-                //                 true,
-                //                 &pooled_g,
-                //                 work_ctx);
-                // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too
-
-                // TODO: fix pooled_g
-                pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
-                ggml_set_f32(pooled_g, 0.f);
+                auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
+                max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                clip_g->compute(n_threads,
+                                input_ids,
+                                0,
+                                NULL,
+                                max_token_idx,
+                                true,
+                                &pooled_g,
+                                work_ctx);
             }
         }
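
The two SD3 hunks above (and the FluxCLIPEmbedder hunk further down) re-enable the real pooled embedding instead of a zero-filled tensor: find the first EOS token in the chunk, clamp to the last index if none is present, and pool there. A standalone illustration of that index computation (token ids are illustrative; 49406/49407 are CLIP's conventional BOS/EOS ids):

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <vector>

int main() {
    const int EOS_TOKEN_ID = 49407;  // CLIP end-of-text id (conventional value)
    std::vector<int> chunk_tokens = {49406, 320, 1125, 49407, 0, 0};  // made-up prompt chunk
    // Pool at the first EOS; fall back to the last token if EOS is absent.
    auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), EOS_TOKEN_ID);
    size_t max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it),
                                            chunk_tokens.size() - 1);
    printf("pooled output taken from token index %zu\n", max_token_idx);  // prints 3
    return 0;
}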

@@ -1136,7 +1126,7 @@ struct FluxCLIPEmbedder : public Conditioner {
         struct ggml_tensor* pooled = NULL;  // [768,]
         std::vector<float> hidden_states_vec;

-        size_t chunk_len = 256;
+        size_t chunk_len = 255;
         size_t chunk_count = t5_tokens.size() / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
@@ -1150,21 +1140,17 @@ struct FluxCLIPEmbedder : public Conditioner {
             auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
             size_t max_token_idx = 0;

-            // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-            // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-            // clip_l->compute(n_threads,
-            //                 input_ids,
-            //                 0,
-            //                 NULL,
-            //                 max_token_idx,
-            //                 true,
-            //                 &pooled,
-            //                 work_ctx);
-
-            // clip_l.transformer.text_model.text_projection no in file, ignore
-            // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
-            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
-            ggml_set_f32(pooled, 0.f);
+            auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+            max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+
+            clip_l->compute(n_threads,
+                            input_ids,
+                            0,
+                            NULL,
+                            max_token_idx,
+                            true,
+                            &pooled,
+                            work_ctx);
         }

         // t5
@@ -1227,7 +1213,7 @@ struct FluxCLIPEmbedder : public Conditioner {
                                             int height,
                                             int adm_in_channels = -1,
                                             bool force_zero_embeddings = false) {
-        auto tokens_and_weights = tokenize(text, 256, true);
+        auto tokens_and_weights = tokenize(text, 255, true);
         return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
     }
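
The two 256 -> 255 edits belong together: tokenize(text, 255, true) and chunk_len = 255 must agree, since chunk_count comes from integer division and a mismatch would silently drop or misalign tokens. A quick standalone check of that arithmetic (sizes are illustrative, assuming the padding flag rounds the token count up to a multiple of the chunk length):

#include <cstdio>
#include <vector>

int main() {
    size_t chunk_len = 255;
    std::vector<int> t5_tokens(2 * chunk_len);           // padded to a multiple of chunk_len
    size_t chunk_count = t5_tokens.size() / chunk_len;   // integer division: exactly 2
    printf("%zu chunks of %zu tokens\n", chunk_count, chunk_len);
    return 0;
}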

t5.hpp

Lines changed: 1 addition & 2 deletions
@@ -418,8 +418,7 @@ class T5UniGramTokenizer {
         weights = new_weights;

         if (padding) {
-            int pad_token_id = pad_id_;
-            tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
+            tokens.insert(tokens.end(), length - tokens.size(), pad_id_);
             weights.insert(weights.end(), length - weights.size(), 1.0);
         }
     }
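
The t5.hpp change only removes a redundant local; behavior is unchanged: token ids are padded with pad_id_ and weights with 1.0 up to the requested length. A self-contained illustration (the ids are made up; 0 stands in for pad_id_, matching T5's conventional pad id):

#include <cstdio>
#include <vector>

int main() {
    const int pad_id_ = 0;                      // T5's conventional pad id
    size_t length = 8;                          // target padded length
    std::vector<int> tokens = {71, 4445, 1};    // made-up ids for a short prompt + EOS
    std::vector<float> weights = {1.0f, 1.0f, 1.0f};
    tokens.insert(tokens.end(), length - tokens.size(), pad_id_);
    weights.insert(weights.end(), length - weights.size(), 1.0f);
    printf("padded to %zu tokens, %zu weights\n", tokens.size(), weights.size());  // 8, 8
    return 0;
}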
