@@ -490,6 +490,7 @@ namespace Flux {
490
490
491
491
struct FluxParams {
492
492
int64_t in_channels = 64 ;
493
+ int64_t out_channels = 64 ;
493
494
int64_t vec_in_dim = 768 ;
494
495
int64_t context_in_dim = 4096 ;
495
496
int64_t hidden_size = 3072 ;
@@ -642,7 +643,6 @@ namespace Flux {
642
643
Flux () {}
643
644
Flux (FluxParams params)
644
645
: params(params) {
645
- int64_t out_channels = params.in_channels ;
646
646
int64_t pe_dim = params.hidden_size / params.num_heads ;
647
647
648
648
blocks[" img_in" ] = std::shared_ptr<GGMLBlock>(new Linear (params.in_channels , params.hidden_size , true ));
@@ -669,7 +669,7 @@ namespace Flux {
669
669
params.flash_attn ));
670
670
}
671
671
672
- blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer (params.hidden_size , 1 , out_channels));
672
+ blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer (params.hidden_size , 1 , params. out_channels ));
673
673
}
674
674
675
675
struct ggml_tensor * patchify (struct ggml_context * ctx,
@@ -834,12 +834,16 @@ namespace Flux {
834
834
FluxRunner (ggml_backend_t backend,
835
835
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836
836
const std::string prefix = " " ,
837
+ SDVersion version = VERSION_FLUX,
837
838
bool flash_attn = false )
838
839
: GGMLRunner(backend) {
839
840
flux_params.flash_attn = flash_attn;
840
841
flux_params.guidance_embed = false ;
841
842
flux_params.depth = 0 ;
842
843
flux_params.depth_single_blocks = 0 ;
844
+ if (version == VERSION_FLUX_INPAINT) {
845
+ flux_params.in_channels = 384 ;
846
+ }
843
847
for (auto pair : tensor_types) {
844
848
std::string tensor_name = pair.first ;
845
849
if (tensor_name.find (" model.diffusion_model." ) == std::string::npos)
0 commit comments