Skip to content

Commit c798862

Browse files
committed
refactor: support main gpu
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent f782ffe commit c798862

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

examples/cli/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ int main(int argc, const char* argv[]) {
829829

830830
if (results == NULL) {
831831
printf("generate failed\n");
832-
free_sd_ctx(sd_ctx);
832+
sd_ctx_free(sd_ctx);
833833
return 1;
834834
}
835835

@@ -875,7 +875,7 @@ int main(int argc, const char* argv[]) {
875875
results[i].data = NULL;
876876
}
877877
free(results);
878-
free_sd_ctx(sd_ctx);
878+
sd_ctx_free(sd_ctx);
879879
free(control_image_buffer);
880880
free(input_image_buffer);
881881

stable-diffusion.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,23 +193,24 @@ class StableDiffusionGGML {
193193
schedule_t schedule,
194194
bool clip_on_cpu,
195195
bool control_net_cpu,
196-
bool vae_on_cpu) {
196+
bool vae_on_cpu,
197+
int main_gpu) {
197198
use_tiny_autoencoder = taesd_path.size() > 0;
198199
#ifdef SD_USE_CUBLAS
199200
LOG_DEBUG("Using CUDA backend");
200-
backend = ggml_backend_cuda_init(0);
201+
backend = ggml_backend_cuda_init(main_gpu);
201202
#endif
202203
#ifdef SD_USE_METAL
203204
LOG_DEBUG("Using Metal backend");
204205
backend = ggml_backend_metal_init();
205206
#endif
206207
#ifdef SD_USE_CANN
207208
LOG_DEBUG("Using CANN backend");
208-
backend = ggml_backend_cann_init(0);
209+
backend = ggml_backend_cann_init(main_gpu);
209210
#endif
210211
#ifdef SD_USE_SYCL
211212
LOG_DEBUG("Using SYCL backend");
212-
backend = ggml_backend_sycl_init(0);
213+
backend = ggml_backend_sycl_init(main_gpu);
213214
#endif
214215

215216
if (!backend) {
@@ -1058,7 +1059,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
10581059
enum schedule_t s,
10591060
bool keep_clip_on_cpu,
10601061
bool keep_control_net_cpu,
1061-
bool keep_vae_on_cpu) {
1062+
bool keep_vae_on_cpu,
1063+
int main_gpu) {
10621064
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
10631065
if (sd_ctx == NULL) {
10641066
return NULL;
@@ -1080,9 +1082,6 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
10801082
free_params_immediately,
10811083
lora_model_dir,
10821084
rng_type);
1083-
if (sd_ctx->sd == NULL) {
1084-
return NULL;
1085-
}
10861085

10871086
if (!sd_ctx->sd->load_from_file(model_path,
10881087
clip_l_path,
@@ -1099,7 +1098,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
10991098
s,
11001099
keep_clip_on_cpu,
11011100
keep_control_net_cpu,
1102-
keep_vae_on_cpu)) {
1101+
keep_vae_on_cpu,
1102+
main_gpu)) {
11031103
delete sd_ctx->sd;
11041104
sd_ctx->sd = NULL;
11051105
free(sd_ctx);
@@ -1108,7 +1108,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
11081108
return sd_ctx;
11091109
}
11101110

1111-
void free_sd_ctx(sd_ctx_t* sd_ctx) {
1111+
void sd_ctx_free(sd_ctx_t* sd_ctx) {
11121112
if (sd_ctx->sd != NULL) {
11131113
delete sd_ctx->sd;
11141114
sd_ctx->sd = NULL;

stable-diffusion.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
139139
enum schedule_t s,
140140
bool keep_clip_on_cpu,
141141
bool keep_control_net_cpu,
142-
bool keep_vae_on_cpu);
142+
bool keep_vae_on_cpu,
143+
int main_gpu = 0);
143144

144-
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
145+
SD_API void sd_ctx_free(sd_ctx_t* sd_ctx);
145146

146147
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
147148
const char* prompt,

0 commit comments

Comments
 (0)