Skip to content

Commit f6b9aa1

Browse files
committed
refactor: optimize the usage of tensor_types
1 parent 7eb30d0 commit f6b9aa1

16 files changed

+119
-111
lines changed

clip.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,9 @@ class CLIPEmbeddings : public GGMLBlock {
545545
int64_t vocab_size;
546546
int64_t num_positions;
547547

548-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
549-
enum ggml_type token_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
550-
enum ggml_type position_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
548+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
549+
enum ggml_type token_wtype = GGML_TYPE_F32;
550+
enum ggml_type position_wtype = GGML_TYPE_F32;
551551

552552
params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
553553
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
@@ -594,10 +594,10 @@ class CLIPVisionEmbeddings : public GGMLBlock {
594594
int64_t image_size;
595595
int64_t num_patches;
596596
int64_t num_positions;
597-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
598-
enum ggml_type patch_wtype = GGML_TYPE_F16; // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
599-
enum ggml_type class_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
600-
enum ggml_type position_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
597+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
598+
enum ggml_type patch_wtype = GGML_TYPE_F16;
599+
enum ggml_type class_wtype = GGML_TYPE_F32;
600+
enum ggml_type position_wtype = GGML_TYPE_F32;
601601

602602
params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
603603
params["class_embedding"] = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
@@ -657,9 +657,9 @@ enum CLIPVersion {
657657

658658
class CLIPTextModel : public GGMLBlock {
659659
protected:
660-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
660+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
661661
if (version == OPEN_CLIP_VIT_BIGG_14) {
662-
enum ggml_type wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
662+
enum ggml_type wtype = GGML_TYPE_F32;
663663
params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
664664
}
665665
}
@@ -805,8 +805,8 @@ class CLIPProjection : public UnaryBlock {
805805
int64_t out_features;
806806
bool transpose_weight;
807807

808-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
809-
enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
808+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
809+
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
810810
if (transpose_weight) {
811811
params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
812812
} else {
@@ -868,7 +868,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
868868
CLIPTextModel model;
869869

870870
CLIPTextModelRunner(ggml_backend_t backend,
871-
std::map<std::string, enum ggml_type>& tensor_types,
871+
const String2GGMLType& tensor_types,
872872
const std::string prefix,
873873
CLIPVersion version = OPENAI_CLIP_VIT_L_14,
874874
bool with_final_ln = true,

common.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ class GEGLU : public GGMLBlock {
182182
int64_t dim_in;
183183
int64_t dim_out;
184184

185-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
186-
enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
187-
enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
185+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
186+
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
187+
enum ggml_type bias_wtype = GGML_TYPE_F32;
188188
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
189189
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
190190
}
@@ -440,9 +440,9 @@ class SpatialTransformer : public GGMLBlock {
440440

441441
class AlphaBlender : public GGMLBlock {
442442
protected:
443-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
443+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
444444
// Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
445-
enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
445+
enum ggml_type wtype = GGML_TYPE_F32;
446446
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
447447
}
448448

conditioner.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
5757
std::vector<std::string> readed_embeddings;
5858

5959
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
60-
std::map<std::string, enum ggml_type>& tensor_types,
60+
const String2GGMLType& tensor_types,
6161
const std::string& embd_dir,
6262
SDVersion version = VERSION_SD1,
6363
PMVersion pv = PM_VERSION_1,
@@ -618,7 +618,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
618618
struct FrozenCLIPVisionEmbedder : public GGMLRunner {
619619
CLIPVisionModelProjection vision_model;
620620

621-
FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
621+
FrozenCLIPVisionEmbedder(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
622622
: vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) {
623623
vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
624624
}
@@ -663,8 +663,8 @@ struct SD3CLIPEmbedder : public Conditioner {
663663
std::shared_ptr<T5Runner> t5;
664664

665665
SD3CLIPEmbedder(ggml_backend_t backend,
666-
std::map<std::string, enum ggml_type>& tensor_types,
667-
int clip_skip = -1)
666+
const String2GGMLType& tensor_types = {},
667+
int clip_skip = -1)
668668
: clip_g_tokenizer(0) {
669669
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
670670
clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
@@ -1010,8 +1010,8 @@ struct FluxCLIPEmbedder : public Conditioner {
10101010
size_t chunk_len = 256;
10111011

10121012
FluxCLIPEmbedder(ggml_backend_t backend,
1013-
std::map<std::string, enum ggml_type>& tensor_types,
1014-
int clip_skip = -1) {
1013+
const String2GGMLType& tensor_types = {},
1014+
int clip_skip = -1) {
10151015
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
10161016
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
10171017
set_clip_skip(clip_skip);
@@ -1231,10 +1231,10 @@ struct PixArtCLIPEmbedder : public Conditioner {
12311231
int mask_pad = 1;
12321232

12331233
PixArtCLIPEmbedder(ggml_backend_t backend,
1234-
std::map<std::string, enum ggml_type>& tensor_types,
1235-
int clip_skip = -1,
1236-
bool use_mask = false,
1237-
int mask_pad = 1)
1234+
const String2GGMLType& tensor_types = {},
1235+
int clip_skip = -1,
1236+
bool use_mask = false,
1237+
int mask_pad = 1)
12381238
: use_mask(use_mask), mask_pad(mask_pad) {
12391239
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
12401240
}

control.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,8 @@ struct ControlNet : public GGMLRunner {
317317
bool guided_hint_cached = false;
318318

319319
ControlNet(ggml_backend_t backend,
320-
std::map<std::string, enum ggml_type>& tensor_types,
321-
SDVersion version = VERSION_SD1)
320+
const String2GGMLType& tensor_types = {},
321+
SDVersion version = VERSION_SD1)
322322
: GGMLRunner(backend), control_net(version) {
323323
control_net.init(params_ctx, tensor_types, "");
324324
}

diffusion_model.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ struct UNetModel : public DiffusionModel {
3232
UNetModelRunner unet;
3333

3434
UNetModel(ggml_backend_t backend,
35-
std::map<std::string, enum ggml_type>& tensor_types,
36-
SDVersion version = VERSION_SD1,
37-
bool flash_attn = false)
35+
const String2GGMLType& tensor_types = {},
36+
SDVersion version = VERSION_SD1,
37+
bool flash_attn = false)
3838
: unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
3939
}
4040

@@ -85,7 +85,7 @@ struct MMDiTModel : public DiffusionModel {
8585
MMDiTRunner mmdit;
8686

8787
MMDiTModel(ggml_backend_t backend,
88-
std::map<std::string, enum ggml_type>& tensor_types)
88+
const String2GGMLType& tensor_types = {})
8989
: mmdit(backend, tensor_types, "model.diffusion_model") {
9090
}
9191

@@ -135,10 +135,10 @@ struct FluxModel : public DiffusionModel {
135135
Flux::FluxRunner flux;
136136

137137
FluxModel(ggml_backend_t backend,
138-
std::map<std::string, enum ggml_type>& tensor_types,
139-
SDVersion version = VERSION_FLUX,
140-
bool flash_attn = false,
141-
bool use_mask = false)
138+
const String2GGMLType& tensor_types = {},
139+
SDVersion version = VERSION_FLUX,
140+
bool flash_attn = false,
141+
bool use_mask = false)
142142
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
143143
}
144144

esrgan.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ struct ESRGAN : public GGMLRunner {
142142
int scale = 4;
143143
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
144144

145-
ESRGAN(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
145+
ESRGAN(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
146146
: GGMLRunner(backend) {
147147
rrdb_net.init(params_ctx, tensor_types, "");
148148
}

flux.hpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ namespace Flux {
3535
int64_t hidden_size;
3636
float eps;
3737

38-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
39-
ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "scale") != tensor_types.end()) ? tensor_types[prefix + "scale"] : GGML_TYPE_F32;
38+
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
39+
ggml_type wtype = GGML_TYPE_F32;
4040
params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
4141
}
4242

@@ -1039,8 +1039,6 @@ namespace Flux {
10391039
};
10401040

10411041
struct FluxRunner : public GGMLRunner {
1042-
static std::map<std::string, enum ggml_type> empty_tensor_types;
1043-
10441042
public:
10451043
FluxParams flux_params;
10461044
Flux flux;
@@ -1050,11 +1048,11 @@ namespace Flux {
10501048
bool use_mask = false;
10511049

10521050
FluxRunner(ggml_backend_t backend,
1053-
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
1054-
const std::string prefix = "",
1055-
SDVersion version = VERSION_FLUX,
1056-
bool flash_attn = false,
1057-
bool use_mask = false)
1051+
const String2GGMLType& tensor_types = {},
1052+
const std::string prefix = "",
1053+
SDVersion version = VERSION_FLUX,
1054+
bool flash_attn = false,
1055+
bool use_mask = false)
10581056
: GGMLRunner(backend), use_mask(use_mask) {
10591057
flux_params.flash_attn = flash_attn;
10601058
flux_params.guidance_embed = false;

0 commit comments

Comments
 (0)