Skip to content

Commit 1b3faeb

Browse files
committed
feat(tx): control freeing compute buffer
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 0dac94f commit 1b3faeb

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -840,6 +840,7 @@ int main(int argc, const char* argv[]) {
840840
vae_decode_only,
841841
params.vae_tiling,
842842
true,
843+
true,
843844
params.n_threads,
844845
params.wtype,
845846
params.rng_type,

stable-diffusion.cpp

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,9 @@ class StableDiffusionGGML {
139139
ggml_backend_t vae_backend = NULL;
140140

141141
SDVersion version;
142-
bool vae_decode_only = false;
143-
bool free_params_immediately = false;
142+
bool vae_decode_only = false;
143+
bool free_params_immediately = false;
144+
bool free_compute_immediately = true;
144145

145146
rng_type_t rng_type = STD_DEFAULT_RNG;
146147
int n_threads = -1;
@@ -174,11 +175,13 @@ class StableDiffusionGGML {
174175
StableDiffusionGGML(int n_threads,
175176
bool vae_decode_only,
176177
bool free_params_immediately,
178+
bool free_compute_immediately,
177179
std::string lora_model_dir,
178180
rng_type_t rng_type)
179181
: n_threads(n_threads),
180182
vae_decode_only(vae_decode_only),
181183
free_params_immediately(free_params_immediately),
184+
free_compute_immediately(free_compute_immediately),
182185
lora_model_dir(lora_model_dir),
183186
rng_type(rng_type) {
184187
}
@@ -1094,7 +1097,9 @@ class StableDiffusionGGML {
10941097
control_net->free_control_ctx();
10951098
control_net->free_compute_buffer();
10961099
}
1097-
diffusion_model->free_compute_buffer();
1100+
if (free_compute_immediately) {
1101+
diffusion_model->free_compute_buffer();
1102+
}
10981103
return x;
10991104
}
11001105

@@ -1164,7 +1169,9 @@ class StableDiffusionGGML {
11641169
} else {
11651170
first_stage_model->compute(n_threads, x, decode, &result);
11661171
}
1167-
first_stage_model->free_compute_buffer();
1172+
if (free_compute_immediately) {
1173+
first_stage_model->free_compute_buffer();
1174+
}
11681175
if (decode) {
11691176
ggml_tensor_scale_output(result);
11701177
}
@@ -1178,7 +1185,9 @@ class StableDiffusionGGML {
11781185
} else {
11791186
tae_first_stage->compute(n_threads, x, decode, &result);
11801187
}
1181-
tae_first_stage->free_compute_buffer();
1188+
if (free_compute_immediately) {
1189+
tae_first_stage->free_compute_buffer();
1190+
}
11821191
}
11831192

11841193
int64_t t1 = ggml_time_ms();
@@ -1218,6 +1227,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
12181227
bool vae_decode_only,
12191228
bool vae_tiling,
12201229
bool free_params_immediately,
1230+
bool free_compute_immediately,
12211231
int n_threads,
12221232
enum sd_type_t wtype,
12231233
enum rng_type_t rng_type,
@@ -1246,6 +1256,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
12461256
sd_ctx->sd = new StableDiffusionGGML(n_threads,
12471257
vae_decode_only,
12481258
free_params_immediately,
1259+
free_compute_immediately,
12491260
lora_model_dir,
12501261
rng_type);
12511262
if (sd_ctx->sd == NULL) {
@@ -2541,6 +2552,9 @@ bool sd_sampling_stream_sample(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
25412552
return true;
25422553
}
25432554
stream->x = sd_ctx->sd->denoiser->inverse_noise_scaling(stream->sigmas[stream->sigmas.size() - 1], stream->x);
2555+
if (sd_ctx->sd->free_compute_immediately) {
2556+
sd_ctx->sd->diffusion_model->free_compute_buffer();
2557+
}
25442558

25452559
size_t sampling_end = ggml_time_ms();
25462560
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - stream->sampling_start) * 1.0f / 1000);

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
149149
bool vae_decode_only,
150150
bool vae_tiling,
151151
bool free_params_immediately,
152+
bool free_compute_immediately,
152153
int n_threads,
153154
enum sd_type_t wtype,
154155
enum rng_type_t rng_type,

0 commit comments

Comments
 (0)