Skip to content

Commit cb48c3c

Browse files
committed
feat(tx): control freeing compute buffer
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 5ceb59c commit cb48c3c

File tree

4 files changed

+28
-7
lines changed

4 files changed

+28
-7
lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -907,6 +907,7 @@ int main(int argc, const char* argv[]) {
907907
vae_decode_only,
908908
params.vae_tiling,
909909
true,
910+
true,
910911
params.n_threads,
911912
params.wtype,
912913
params.rng_type,

examples/stream-cli/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -769,6 +769,7 @@ int main(int argc, const char* argv[]) {
769769
vae_decode_only,
770770
params.vae_tiling,
771771
true,
772+
true,
772773
params.n_threads,
773774
params.wtype,
774775
params.rng_type,

stable-diffusion.cpp

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -151,8 +151,9 @@ class StableDiffusionGGML {
151151
ggml_type vae_wtype = GGML_TYPE_COUNT;
152152

153153
SDVersion version;
154-
bool vae_decode_only = false;
155-
bool free_params_immediately = false;
154+
bool vae_decode_only = false;
155+
bool free_params_immediately = false;
156+
bool free_compute_immediately = true;
156157

157158
rng_type_t rng_type = STD_DEFAULT_RNG;
158159
int n_threads = -1;
@@ -186,11 +187,13 @@ class StableDiffusionGGML {
186187
StableDiffusionGGML(int n_threads,
187188
bool vae_decode_only,
188189
bool free_params_immediately,
190+
bool free_compute_immediately,
189191
std::string lora_model_dir,
190192
rng_type_t rng_type)
191193
: n_threads(n_threads),
192194
vae_decode_only(vae_decode_only),
193195
free_params_immediately(free_params_immediately),
196+
free_compute_immediately(free_compute_immediately),
194197
lora_model_dir(lora_model_dir),
195198
rng_type(rng_type) {
196199
}
@@ -979,7 +982,9 @@ class StableDiffusionGGML {
979982
} else {
980983
first_stage_model->compute(n_threads, latents, true, &result);
981984
}
982-
first_stage_model->free_compute_buffer();
985+
if (free_compute_immediately) {
986+
first_stage_model->free_compute_buffer();
987+
}
983988
ggml_tensor_scale(latents, scale_factor);
984989

985990
ggml_tensor_scale_output(result);
@@ -997,7 +1002,9 @@ class StableDiffusionGGML {
9971002
} else {
9981003
tae_first_stage->compute(n_threads, latents, true, &result);
9991004
}
1000-
tae_first_stage->free_compute_buffer();
1005+
if (free_compute_immediately) {
1006+
tae_first_stage->free_compute_buffer();
1007+
}
10011008
} else {
10021009
return;
10031010
}
@@ -1259,7 +1266,9 @@ class StableDiffusionGGML {
12591266
control_net->free_control_ctx();
12601267
control_net->free_compute_buffer();
12611268
}
1262-
diffusion_model->free_compute_buffer();
1269+
if (free_compute_immediately) {
1270+
diffusion_model->free_compute_buffer();
1271+
}
12631272
return x;
12641273
}
12651274

@@ -1329,7 +1338,9 @@ class StableDiffusionGGML {
13291338
} else {
13301339
first_stage_model->compute(n_threads, x, decode, &result);
13311340
}
1332-
first_stage_model->free_compute_buffer();
1341+
if (free_compute_immediately) {
1342+
first_stage_model->free_compute_buffer();
1343+
}
13331344
if (decode) {
13341345
ggml_tensor_scale_output(result);
13351346
}
@@ -1343,7 +1354,9 @@ class StableDiffusionGGML {
13431354
} else {
13441355
tae_first_stage->compute(n_threads, x, decode, &result);
13451356
}
1346-
tae_first_stage->free_compute_buffer();
1357+
if (free_compute_immediately) {
1358+
tae_first_stage->free_compute_buffer();
1359+
}
13471360
}
13481361

13491362
int64_t t1 = ggml_time_ms();
@@ -1383,6 +1396,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
13831396
bool vae_decode_only,
13841397
bool vae_tiling,
13851398
bool free_params_immediately,
1399+
bool free_compute_immediately,
13861400
int n_threads,
13871401
enum sd_type_t wtype,
13881402
enum rng_type_t rng_type,
@@ -1412,6 +1426,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
14121426
sd_ctx->sd = new StableDiffusionGGML(n_threads,
14131427
vae_decode_only,
14141428
free_params_immediately,
1429+
free_compute_immediately,
14151430
lora_model_dir,
14161431
rng_type);
14171432
if (sd_ctx->sd == NULL) {
@@ -2970,6 +2985,9 @@ bool sd_sampling_stream_sample(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
29702985
return true;
29712986
}
29722987
stream->x = sd_ctx->sd->denoiser->inverse_noise_scaling(stream->sigmas[stream->sigmas.size() - 1], stream->x);
2988+
if (sd_ctx->sd->free_compute_immediately) {
2989+
sd_ctx->sd->diffusion_model->free_compute_buffer();
2990+
}
29732991

29742992
size_t sampling_end = ggml_time_ms();
29752993
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - stream->sampling_start) * 1.0f / 1000);

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
156156
bool vae_decode_only,
157157
bool vae_tiling,
158158
bool free_params_immediately,
159+
bool free_compute_immediately,
159160
int n_threads,
160161
enum sd_type_t wtype,
161162
enum rng_type_t rng_type,

0 commit comments

Comments
 (0)