@@ -822,117 +822,4 @@ namespace Chroma {
     };
 
 }  // namespace Chroma
-#endif  // __CHROMA_HPP__
-
-struct ChromaT5Embedder : public Conditioner {
-    T5UniGramTokenizer t5_tokenizer;
-    std::shared_ptr<T5Runner> t5;
-
-    ChromaT5Embedder(ggml_backend_t backend,
-                     std::map<std::string, enum ggml_type>& tensor_types) {
-        // Initialize prefix_
-        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
-    }
-
-    SDCondition get_learned_condition(ggml_context* work_ctx,
-                                      int n_threads,
-                                      const std::string& text,
-                                      int clip_skip,                      // Not used by T5
-                                      int width,                          // Not used by T5
-                                      int height,                         // Not used by T5
-                                      int adm_in_channels        = -1,    // Not used by T5
-                                      bool force_zero_embeddings = false) {
-        // Tokenize the text using T5UniGramTokenizer
-        auto parsed_attention = parse_prompt_attention(text);
-        std::vector<int> tokens;
-        std::vector<float> weights;
-
-        for (const auto& item : parsed_attention) {
-            const std::string& curr_text = item.first;
-            float curr_weight            = item.second;
-            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, false);
-            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
-        }
-
-        // Add EOS token and pad
-        int EOS_TOKEN_ID = 1;  // Assuming EOS_TOKEN_ID for T5 is 1
-        tokens.push_back(EOS_TOKEN_ID);
-        weights.push_back(1.0f);
-        t5_tokenizer.pad_tokens(tokens, weights, 256, true);  // Max length 256 for T5, enable padding
-
-        // Create input_ids tensor from tokens
-        struct ggml_tensor* input_ids     = vector_to_ggml_tensor_i32(work_ctx, tokens);
-        struct ggml_tensor* hidden_states = NULL;
-
-        // Compute T5 embeddings
-        t5->compute(n_threads, input_ids, &hidden_states, work_ctx);
-
-        // Apply weights to hidden_states, similar to FluxCLIPEmbedder
-        if (!force_zero_embeddings) {
-            auto tensor         = hidden_states;
-            float original_mean = ggml_tensor_mean(tensor);
-            // T5 output is [N, n_token, model_dim], so ne[0] is model_dim, ne[1] is n_token
-            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {      // Iterate over tokens
-                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {  // Iterate over hidden_size
-                    float value = ggml_tensor_get_f32(tensor, i0, i1, 0);  // Assuming 2D tensor
-                    value *= weights[i1];                                  // Apply weight for this token
-                    ggml_tensor_set_f32(tensor, value, i0, i1, 0);
-                }
-            }
-            float new_mean = ggml_tensor_mean(tensor);
-            ggml_tensor_scale(tensor, (original_mean / new_mean));
-        } else {
-            float* vec = (float*)hidden_states->data;
-            for (int i = 0; i < ggml_nelements(hidden_states); i++) {
-                vec[i] = 0;
-            }
-        }
-
-        // Generate T5 padding mask (c_concat)
-        struct ggml_tensor* c_concat_tensor = NULL;
-        std::vector<float> padding_mask_vec(tokens.size());
-        for (size_t i = 0; i < tokens.size(); ++i) {
-            padding_mask_vec[i] = (tokens[i] == 0) ? 0.0f : 1.0f;
-        }
-        c_concat_tensor = vector_to_ggml_tensor(work_ctx, padding_mask_vec);
-        c_concat_tensor = ggml_reshape_2d(work_ctx, c_concat_tensor, 1, tokens.size());  // Reshape to [1, N_tokens]
-
-        return SDCondition(hidden_states, NULL, c_concat_tensor);
-    }
-
-    void alloc_params_buffer() {
-        t5->alloc_params_buffer();
-    }
-
-    void free_params_buffer() {
-        t5->free_params_buffer();
-    }
-
-    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
-    }
-
-    size_t get_params_buffer_size() {
-        return t5->get_params_buffer_size();
-    }
-
-    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
-                                                                                  int n_threads,
-                                                                                  const std::string& text,
-                                                                                  int clip_skip,
-                                                                                  int width,
-                                                                                  int height,
-                                                                                  int num_input_imgs,
-                                                                                  int adm_in_channels        = -1,
-                                                                                  bool force_zero_embeddings = false) override {
-        GGML_ASSERT(0 && "Not implemented yet!");
-        return std::make_tuple(SDCondition(), std::vector<bool>());
-    }
-
-    std::string remove_trigger_from_prompt(ggml_context* work_ctx,
-                                           const std::string& prompt) override {
-        GGML_ASSERT(0 && "Not implemented yet!");
-        return "";
-    }
-};
+#endif  // __CHROMA_HPP__