Skip to content

Commit 1b61764

Browse files
Refactor T5 model handling and attention mask integration
- Updated SD3CLIPEmbedder and FluxCLIPEmbedder to pass NULL for attention masks during T5 computation. - Enhanced ChromaT5Embedder to generate and utilize a padding mask for T5 embeddings. - Modified T5Stack and T5 classes to handle optional attention masks, allowing for NULL values. - Adjusted T5Runner to make attention masks optional in the build_graph and compute methods. - Added a new linear layer test example with comprehensive tensor operations and validation. - Removed obsolete Makefile for chroma_test example. - Updated main.cpp in chroma_test to reflect new model paths and configurations.
1 parent ee17ccb commit 1b61764

File tree

7 files changed

+798
-1045
lines changed

7 files changed

+798
-1045
lines changed

chroma.hpp

Lines changed: 503 additions & 700 deletions
Large diffs are not rendered by default.

conditioner.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,7 @@ struct SD3CLIPEmbedder : public Conditioner {
902902

903903
t5->compute(n_threads,
904904
input_ids,
905+
NULL, // Pass NULL for attention_mask_01
905906
&chunk_hidden_states_t5,
906907
work_ctx);
907908
{
@@ -1147,6 +1148,7 @@ struct FluxCLIPEmbedder : public Conditioner {
11471148

11481149
t5->compute(n_threads,
11491150
input_ids,
1151+
NULL, // Pass NULL for attention_mask_01
11501152
&chunk_hidden_states,
11511153
work_ctx);
11521154
{
@@ -1259,8 +1261,18 @@ struct ChromaT5Embedder : public Conditioner {
12591261
struct ggml_tensor* input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
12601262
struct ggml_tensor* hidden_states = NULL;
12611263

1262-
// Compute T5 embeddings
1263-
t5->compute(n_threads, input_ids, &hidden_states, work_ctx);
1264+
// Generate T5 padding mask (c_concat)
1265+
struct ggml_tensor* c_concat_tensor = NULL;
1266+
std::vector<float> padding_mask_vec(tokens.size());
1267+
for (size_t i = 0; i < tokens.size(); ++i) {
1268+
padding_mask_vec[i] = (tokens[i] == t5_tokenizer.pad_id_) ? 0.0f : 1.0f;
1269+
}
1270+
c_concat_tensor = vector_to_ggml_tensor(work_ctx, padding_mask_vec);
1271+
c_concat_tensor = ggml_reshape_2d(work_ctx, c_concat_tensor, 1, tokens.size()); // Reshape to [1, N_tokens]
1272+
1273+
1274+
// Compute T5 embeddings, passing the attention mask
1275+
t5->compute(n_threads, input_ids, c_concat_tensor, &hidden_states, work_ctx);
12641276

12651277
// Apply weights to hidden_states, similar to FluxCLIPEmbedder
12661278
if (!force_zero_embeddings) {
@@ -1283,15 +1295,6 @@ struct ChromaT5Embedder : public Conditioner {
12831295
}
12841296
}
12851297

1286-
// Generate T5 padding mask (c_concat)
1287-
struct ggml_tensor* c_concat_tensor = NULL;
1288-
std::vector<float> padding_mask_vec(tokens.size());
1289-
for (size_t i = 0; i < tokens.size(); ++i) {
1290-
padding_mask_vec[i] = (tokens[i] == t5_tokenizer.pad_id_) ? 0.0f : 1.0f;
1291-
}
1292-
c_concat_tensor = vector_to_ggml_tensor(work_ctx, padding_mask_vec);
1293-
c_concat_tensor = ggml_reshape_2d(work_ctx, c_concat_tensor, 1, tokens.size()); // Reshape to [1, N_tokens]
1294-
12951298
return SDCondition(hidden_states, NULL, c_concat_tensor);
12961299
}
12971300

examples/chroma_test/Makefile

Lines changed: 0 additions & 231 deletions
This file was deleted.

examples/chroma_test/main.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ int main() {
3535
std::cout << "\n--- VAE and T5 Model Testing via Stable Diffusion API ---" << std::endl;
3636

3737
// Define model paths
38-
const char* model_path = "weights/chroma-unlocked-v29.5-Q3_K_L.gguf"; // Main model path (can be empty if only VAE/T5 are loaded)
38+
const char* model_path = ""; // Main model path (can be empty if only VAE/T5 are loaded)
3939
const char* clip_l_path = "";
4040
const char* clip_g_path = "";
41-
const char* t5xxl_path = "weights/t5xxl_q3_k.gguf"; // New T5 model path
42-
const char* diffusion_model_path = "";
43-
const char* vae_path = "weights/ae.safetensors"; // Example VAE path (can be GGUF or safetensors)
41+
const char* t5xxl_path = "./weights/t5xxl_q3_k.gguf"; // New T5 model path
42+
const char* diffusion_model_path = "./weights/chroma-unlocked-v29.5-Q3_K_L.gguf";
43+
const char* vae_path = "./weights/ae.safetensors"; // Example VAE path (can be GGUF or safetensors)
4444
const char* control_net_path = "";
4545
const char* lora_model_dir = "";
4646
const char* embed_dir = "";
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Define the executable for the linear layer test
2+
# Add the source file main.c to the executable
3+
add_executable(linear_layer_test main.cpp)
4+
5+
# Link the executable against the ggml library
6+
# This ensures that the executable can use GGML functions
7+
target_link_libraries(linear_layer_test stable-diffusion)

0 commit comments

Comments
 (0)