Commit c06e8bd

feat(tx): support multiple devices
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent: 5c9e85f

3 files changed: 253 additions, 115 deletions

stable-diffusion.cpp

Lines changed: 166 additions & 70 deletions
@@ -145,8 +145,8 @@ class StableDiffusionGGML {
 public:
     ggml_backend_t backend = NULL;  // general backend
     ggml_backend_t clip_backend = NULL;
-    ggml_backend_t control_net_backend = NULL;
     ggml_backend_t vae_backend = NULL;
+    ggml_backend_t control_net_backend = NULL;
     ggml_type model_wtype = GGML_TYPE_COUNT;
     ggml_type clip_l_wtype = GGML_TYPE_COUNT;
     ggml_type clip_g_wtype = GGML_TYPE_COUNT;
@@ -234,55 +234,155 @@ class StableDiffusionGGML {
                         bool vae_on_cpu,
                         bool diffusion_flash_attn,
                         bool tae_preview_only,
-                        int main_gpu) {
+                        const std::vector<std::string>& rpc_servers,
+                        const float* tensor_split) {
         use_tiny_autoencoder = taesd_path.size() > 0;
 
         ggml_log_set(ggml_log_callback_default, nullptr);
-#ifdef SD_USE_CUDA
-#ifdef SD_USE_HIP
-        LOG_DEBUG("Using HIP backend");
-#elif defined(SD_USE_MUSA)
-        LOG_DEBUG("Using MUSA backend");
-#else
-        LOG_DEBUG("Using CUDA backend");
-#endif
-        backend = ggml_backend_cuda_init(main_gpu);
-        if (!backend) {
-            LOG_ERROR("CUDA backend init failed");
-        }
-#endif
-#ifdef SD_USE_METAL
-        LOG_DEBUG("Using Metal backend");
-        backend = ggml_backend_metal_init();
-        if (!backend) {
-            LOG_ERROR("Metal backend init failed");
-        }
-#endif
-#ifdef SD_USE_VULKAN
-        LOG_DEBUG("Using Vulkan backend");
-        backend = ggml_backend_vk_init(main_gpu);
-        if (!backend) {
-            LOG_ERROR("Vulkan backend init failed");
-        }
-#endif
-#ifdef SD_USE_SYCL
-        LOG_DEBUG("Using SYCL backend");
-        backend = ggml_backend_sycl_init(main_gpu);
-        if (!backend) {
-            LOG_ERROR("SYCL backend init failed");
-        }
-#endif
-#ifdef SD_USE_CANN
-        LOG_DEBUG("Using CANN backend");
-        backend = ggml_backend_cann_init(main_gpu);
-        if (!backend) {
-            LOG_ERROR("CANN backend init failed");
-        }
-#endif
-
-        if (!backend) {
-            LOG_DEBUG("Using CPU backend");
+
+        std::vector<ggml_backend_dev_t> devices;
+
+        if (!rpc_servers.empty()) {
+            ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+            if (!rpc_reg) {
+                LOG_ERROR("failed to find RPC backend");
+                return false;
+            }
+
+            typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char* endpoint);
+            ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t)ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+            if (!ggml_backend_rpc_add_device_fn) {
+                LOG_ERROR("failed to find RPC device add function");
+                return false;
+            }
+
+            for (const std::string& server : rpc_servers) {
+                ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+                if (dev) {
+                    devices.push_back(dev);
+                } else {
+                    LOG_ERROR("failed to add RPC device for server '%s'", server.c_str());
+                    return false;
+                }
+            }
+        }
+
+        // use all available devices
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            switch (ggml_backend_dev_type(dev)) {
+                case GGML_BACKEND_DEVICE_TYPE_CPU:
+                case GGML_BACKEND_DEVICE_TYPE_ACCEL:
+                    // skip CPU backends since they are handled separately
+                    break;
+
+                case GGML_BACKEND_DEVICE_TYPE_GPU:
+                    devices.push_back(dev);
+                    break;
+            }
+        }
+
+        for (auto* dev : devices) {
+            size_t free, total;  // NOLINT
+            ggml_backend_dev_memory(dev, &free, &total);
+            LOG_INFO("using device %s (%s) - %zu MiB free", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free / 1024 / 1024);
+        }
+
+        // build GPU devices buffer list
+        std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>> gpu_devices;
+        {
+            const bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + devices.size(), [](float x) { return x == 0.0f; });
+            // add GPU buffer types
+            for (size_t i = 0; i < devices.size(); ++i) {
+                if (!all_zero && tensor_split[i] <= 0.0f) {
+                    continue;
+                }
+                ggml_backend_device* dev = devices[i];
+                gpu_devices.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
+            }
+        }
+
+        // initialize the backend
+        if (gpu_devices.empty()) {
+            // no GPU devices available
             backend = ggml_backend_cpu_init();
+        } else if (gpu_devices.size() < 3) {
+            // use the last GPU device: device 0, device 1
+            backend = ggml_backend_dev_init(gpu_devices[gpu_devices.size() - 1].first, nullptr);
+        } else {
+            // use the 3rd GPU device: device 2
+            backend = ggml_backend_dev_init(gpu_devices[2].first, nullptr);
+        }
+        switch (gpu_devices.size()) {
+            case 0: {
+                clip_backend = backend;
+                vae_backend = backend;
+                control_net_backend = backend;
+                break;
+            }
+            case 1: {
+                // device 0: clip, vae, control_net
+                clip_backend = backend;
+                if (clip_on_cpu) {
+                    LOG_INFO("CLIP: Using CPU backend");
+                    clip_backend = ggml_backend_cpu_init();
+                }
+                vae_backend = backend;
+                if (vae_on_cpu) {
+                    LOG_INFO("VAE Autoencoder: Using CPU backend");
+                    vae_backend = ggml_backend_cpu_init();
+                }
+                control_net_backend = backend;
+                if (control_net_cpu) {
+                    LOG_INFO("ControlNet: Using CPU backend");
+                    control_net_backend = ggml_backend_cpu_init();
+                }
+                break;
+            }
+            case 2: {
+                // device 0: clip, vae, control_net
+                if (clip_on_cpu) {
+                    LOG_INFO("CLIP: Using CPU backend");
+                    clip_backend = ggml_backend_cpu_init();
+                } else {
+                    clip_backend = ggml_backend_dev_init(gpu_devices[0].first, nullptr);
+                }
+                if (vae_on_cpu) {
+                    LOG_INFO("VAE Autoencoder: Using CPU backend");
+                    vae_backend = ggml_backend_cpu_init();
+                } else {
+                    vae_backend = ggml_backend_dev_init(gpu_devices[0].first, nullptr);
+                }
+                if (control_net_cpu) {
+                    LOG_INFO("ControlNet: Using CPU backend");
+                    control_net_backend = ggml_backend_cpu_init();
+                } else {
+                    control_net_backend = ggml_backend_dev_init(gpu_devices[0].first, nullptr);
+                }
+                break;
+            }
+            default: {
+                // device 0: clip, control_net
+                // device 1: vae
+                if (clip_on_cpu) {
+                    LOG_INFO("CLIP: Using CPU backend");
+                    clip_backend = ggml_backend_cpu_init();
+                } else {
+                    clip_backend = ggml_backend_dev_init(gpu_devices[0].first, nullptr);
+                }
+                if (vae_on_cpu) {
+                    LOG_INFO("VAE Autoencoder: Using CPU backend");
+                    vae_backend = ggml_backend_cpu_init();
+                } else {
+                    vae_backend = ggml_backend_dev_init(gpu_devices[1].first, nullptr);
+                }
+                if (control_net_cpu) {
+                    LOG_INFO("ControlNet: Using CPU backend");
+                    control_net_backend = ggml_backend_cpu_init();
+                } else {
+                    control_net_backend = ggml_backend_dev_init(gpu_devices[0].first, nullptr);
+                }
+            }
         }
 
         ModelLoader model_loader;
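
Read in isolation, the switch above encodes a fixed placement policy. The following is a minimal standalone sketch of that policy; the place() helper and the driver are hypothetical, written only for illustration, and they ignore the clip_on_cpu / vae_on_cpu / control_net_cpu overrides:

    #include <cstddef>
    #include <cstdio>

    // Hypothetical mirror of the placement switch above: which index within
    // gpu_devices each component is initialized on. -1 stands for the CPU
    // fallback taken when no GPU device survives the tensor_split filter.
    struct placement {
        int diffusion, clip, vae, control_net;
    };

    static placement place(size_t n_gpus) {
        if (n_gpus == 0) return {-1, -1, -1, -1};  // CPU backend everywhere
        if (n_gpus == 1) return {0, 0, 0, 0};      // everything shares device 0
        if (n_gpus == 2) return {1, 0, 0, 0};      // diffusion alone on device 1
        return {2, 0, 1, 0};                       // diffusion on 2, vae on 1
    }

    int main() {
        for (size_t n : {0, 1, 2, 4}) {
            placement p = place(n);
            std::printf("%zu GPU(s): diffusion=%d clip=%d vae=%d control_net=%d\n",
                        n, p.diffusion, p.clip, p.vae, p.control_net);
        }
    }

In every multi-GPU case the diffusion model gets a device to itself, presumably because it dominates compute, while CLIP and ControlNet always stay together on device 0.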
@@ -443,24 +543,19 @@ class StableDiffusionGGML {
         auto cc_vae = model_loader.has_prefix_tensors("first_stage_model.") && !model_loader.has_prefix_tensors("vae.");
 
         if (version == VERSION_SVD) {
-            clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types, cc_clip_l);
+            clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(clip_backend, model_loader.tensor_storages_types, cc_clip_l);
             clip_vision->alloc_params_buffer();
             clip_vision->get_param_tensors(tensors);
 
             diffusion_model = std::make_shared<UNetModel>(backend, model_loader.tensor_storages_types, version);
             diffusion_model->alloc_params_buffer();
             diffusion_model->get_param_tensors(tensors);
 
-            first_stage_model = std::make_shared<AutoEncoderKL>(backend, model_loader.tensor_storages_types, vae_decode_only, true, version, cc_vae);
+            first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, model_loader.tensor_storages_types, vae_decode_only, true, version, cc_vae);
             LOG_DEBUG("vae_decode_only %d", vae_decode_only);
             first_stage_model->alloc_params_buffer();
             first_stage_model->get_param_tensors(tensors);
         } else {
-            clip_backend = backend;
-            if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
-                LOG_INFO("CLIP: Using CPU backend");
-                clip_backend = ggml_backend_cpu_init();
-            }
             if (diffusion_flash_attn) {
                 LOG_INFO("Using flash attention in the diffusion model");
             }
@@ -489,30 +584,17 @@ class StableDiffusionGGML {
             diffusion_model->get_param_tensors(tensors);
 
             if (!use_tiny_autoencoder || tae_preview_only) {
-                if (vae_on_cpu && !ggml_backend_is_cpu(backend)) {
-                    LOG_INFO("VAE Autoencoder: Using CPU backend");
-                    vae_backend = ggml_backend_cpu_init();
-                } else {
-                    vae_backend = backend;
-                }
                 first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, model_loader.tensor_storages_types, vae_decode_only, false, version, cc_vae);
                 first_stage_model->alloc_params_buffer();
                 first_stage_model->get_param_tensors(tensors);
             }
             if (use_tiny_autoencoder) {
-                tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, vae_decode_only, version, cc_vae);
+                tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend, model_loader.tensor_storages_types, vae_decode_only, version, cc_vae);
             }
             // first_stage_model->get_param_tensors(tensors, "first_stage_model.");
 
             if (control_net_path.size() > 0) {
-                ggml_backend_t controlnet_backend = NULL;
-                if (control_net_cpu && !ggml_backend_is_cpu(backend)) {
-                    LOG_DEBUG("ControlNet: Using CPU backend");
-                    controlnet_backend = ggml_backend_cpu_init();
-                } else {
-                    controlnet_backend = backend;
-                }
-                control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
+                control_net = std::make_shared<ControlNet>(control_net_backend, model_loader.tensor_storages_types, version);
             }
 
             if (id_embeddings_path.find("v2") != std::string::npos) {
@@ -1421,7 +1503,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                      bool keep_vae_on_cpu,
                      bool diffusion_flash_attn,
                      bool tae_preview_only,
-                     int main_gpu) {
+                     const char* rpc_servers,
+                     const float* tensor_splits) {
     sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
     if (sd_ctx == NULL) {
         return NULL;
@@ -1437,6 +1520,18 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
     std::string embd_path(embed_dir_c_str);
     std::string id_embd_path(id_embed_dir_c_str);
     std::string lora_model_dir(lora_model_dir_c_str);
+    std::vector<std::string> rpc_servers_vec;
+    if (rpc_servers != nullptr && rpc_servers[0] != '\0') {
+        // split the comma-separated server list into rpc_servers_vec
+        std::string servers(rpc_servers);
+        size_t pos = 0;
+        while ((pos = servers.find(',')) != std::string::npos) {
+            std::string server = servers.substr(0, pos);
+            rpc_servers_vec.push_back(server);
+            servers.erase(0, pos + 1);
+        }
+        rpc_servers_vec.push_back(servers);
+    }
 
     sd_ctx->sd = new StableDiffusionGGML(n_threads,
                                          vae_decode_only,
@@ -1466,7 +1561,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                                          keep_vae_on_cpu,
                                          diffusion_flash_attn,
                                          tae_preview_only,
-                                         main_gpu)) {
+                                         rpc_servers_vec,
+                                         tensor_splits)) {
         delete sd_ctx->sd;
         sd_ctx->sd = NULL;
         free(sd_ctx);
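
The gpu_devices list built in the first hunk is what the placement switch indexes into, so the tensor_split filter effectively decides which GPUs participate at all. Below is a self-contained sketch of just that predicate; surviving_devices is an invented name, not an API in this commit:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Sketch of the tensor_split filter: a null or all-zero split keeps every
    // device, otherwise only devices with a positive weight are kept.
    static std::vector<size_t> surviving_devices(const float* tensor_split, size_t n_devices) {
        const bool all_zero = tensor_split == nullptr ||
                              std::all_of(tensor_split, tensor_split + n_devices,
                                          [](float x) { return x == 0.0f; });
        std::vector<size_t> kept;
        for (size_t i = 0; i < n_devices; ++i) {
            if (all_zero || tensor_split[i] > 0.0f) {
                kept.push_back(i);
            }
        }
        return kept;
    }

    int main() {
        const float split[] = {1.0f, 0.0f, 1.0f};  // exclude device 1
        for (size_t i : surviving_devices(split, 3)) {
            std::printf("explicit split keeps device %zu\n", i);
        }
        for (size_t i : surviving_devices(nullptr, 3)) {
            std::printf("default keeps device %zu\n", i);
        }
    }

Note that in the hunks shown the split values act purely as an include/exclude mask; their magnitudes are never used to proportion tensors across devices.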

stable-diffusion.h

Lines changed: 4 additions & 2 deletions
@@ -168,7 +168,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
                             bool keep_vae_on_cpu,
                             bool diffusion_flash_attn,
                             bool tae_preview_only,
-                            int main_gpu = 0);
+                            const char* rpc_servers = nullptr,
+                            const float* tensor_splits = nullptr);
 
 SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
 
@@ -311,7 +312,8 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
 
 SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
                                         int n_threads,
-                                        int main_gpu = 0);
+                                        const char* rpc_servers = nullptr,
+                                        const float* tensor_splits = nullptr);
 SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
 
 SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
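
Both declarations default the new parameters to nullptr, so existing callers compile unchanged. What follows is a caller-side sketch of how the two arguments might be prepared, with placeholder endpoint values; the actual new_sd_ctx call, with its many unchanged parameters, is omitted:

    #include <cstdio>
    #include <string>
    #include <vector>

    // Caller-side sketch: rpc_servers is a comma-separated endpoint list,
    // split exactly as new_sd_ctx does; tensor_splits supplies one weight per
    // device. The endpoints below are placeholders, not from the commit.
    int main() {
        const std::string arg = "192.168.1.10:50052,192.168.1.11:50052";
        std::string servers = arg;
        std::vector<std::string> endpoints;
        size_t pos = 0;
        while ((pos = servers.find(',')) != std::string::npos) {
            endpoints.push_back(servers.substr(0, pos));
            servers.erase(0, pos + 1);
        }
        endpoints.push_back(servers);

        std::vector<float> tensor_splits(endpoints.size(), 1.0f);  // keep every device
        for (const std::string& e : endpoints) {
            std::printf("RPC endpoint: %s\n", e.c_str());
        }
        // arg.c_str() and tensor_splits.data() would be passed as the two new
        // trailing arguments of new_sd_ctx / new_upscaler_ctx.
        std::printf("%zu split weights prepared\n", tensor_splits.size());
    }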
