Skip to content

Commit 1e0cf43

Browse files
committed
refactor: main gpu
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 7afa9b7 commit 1e0cf43

File tree

3 files changed: 14 additions and 15 deletions.

stable-diffusion.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class StableDiffusionGGML {
 bool clip_on_cpu,
 bool control_net_cpu,
 bool vae_on_cpu,
-int main_gpu) {
+int main_gpu = 0) {
 use_tiny_autoencoder = taesd_path.size() > 0;
 #ifdef SD_USE_CUBLAS
 LOG_DEBUG("Using CUDA backend");
@@ -219,13 +219,7 @@ class StableDiffusionGGML {
 LOG_DEBUG("Using CPU backend");
 backend = ggml_backend_cpu_init();
 }
-#ifdef SD_USE_FLASH_ATTENTION
-#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN) || defined(SD_USE_CANN)
-LOG_WARN("Flash Attention not supported with GPU Backend");
-#else
-LOG_INFO("Flash Attention enabled");
-#endif
-#endif
+
 ModelLoader model_loader;

 vae_tiling = vae_tiling_;

stable-diffusion.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;

 SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
 int n_threads,
-enum ggml_type wtype);
+enum ggml_type wtype,
+int main_gpu = 0);
+
 SD_API void upscaler_ctx_free(upscaler_ctx_t* upscaler_ctx);

 SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);

upscaler.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,24 @@ struct UpscalerGGML {
 : n_threads(n_threads) {
 }

-bool load_from_file(const std::string& esrgan_path) {
+bool load_from_file(
+const std::string& esrgan_path,
+int main_gpu = 0) {
 #ifdef SD_USE_CUBLAS
 LOG_DEBUG("Using CUDA backend");
-backend = ggml_backend_cuda_init(0);
+backend = ggml_backend_cuda_init(main_gpu);
 #endif
 #ifdef SD_USE_METAL
 LOG_DEBUG("Using Metal backend");
 backend = ggml_backend_metal_init();
 #endif
 #ifdef SD_USE_CANN
 LOG_DEBUG("Using CANN backend");
-backend = ggml_backend_cann_init(0);
+backend = ggml_backend_cann_init(main_gpu);
 #endif
 #ifdef SD_USE_SYCL
 LOG_DEBUG("Using SYCL backend");
-backend = ggml_backend_sycl_init(0);
+backend = ggml_backend_sycl_init(main_gpu);
 #endif

 if (!backend) {
@@ -96,7 +98,8 @@ struct upscaler_ctx_t {

 upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
 int n_threads,
-enum ggml_type wtype) {
+enum ggml_type wtype,
+int main_gpu) {
 upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
 if (upscaler_ctx == NULL) {
 return NULL;
@@ -108,7 +111,7 @@ upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
 return NULL;
 }

-if (!upscaler_ctx->upscaler->load_from_file(esrgan_path)) {
+if (!upscaler_ctx->upscaler->load_from_file(esrgan_path, main_gpu)) {
 delete upscaler_ctx->upscaler;
 upscaler_ctx->upscaler = NULL;
 free(upscaler_ctx);

0 commit comments

Comments
 (0)