Skip to content

Commit 5ceb59c

Browse files
committed
refactor(tx): add streaming util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent c458bfd commit 5ceb59c

File tree

8 files changed

+2520
-29
lines changed

8 files changed

+2520
-29
lines changed

denoiser.hpp

Lines changed: 678 additions & 0 deletions
Large diffs are not rendered by default.

examples/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
22

3-
add_subdirectory(cli)
3+
add_subdirectory(cli)
4+
add_subdirectory(stream-cli)

examples/stream-cli/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(TARGET stable-diffusion-stream-cli)
2+
3+
add_executable(${TARGET} main.cpp)
4+
install(TARGETS ${TARGET} RUNTIME)
5+
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
6+
target_compile_features(${TARGET} PUBLIC cxx_std_11)

examples/stream-cli/main.cpp

Lines changed: 924 additions & 0 deletions
Large diffs are not rendered by default.

rng.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,28 @@
66

77
class RNG {
88
public:
9+
virtual uint64_t get_seed() = 0;
910
virtual void manual_seed(uint64_t seed) = 0;
1011
virtual std::vector<float> randn(uint32_t n) = 0;
1112
};
1213

1314
class STDDefaultRNG : public RNG {
1415
private:
16+
uint64_t seed;
1517
std::default_random_engine generator;
1618

1719
public:
20+
STDDefaultRNG(uint64_t seed = 0) {
21+
this->seed = seed;
22+
generator.seed((unsigned int)seed);
23+
}
24+
25+
uint64_t get_seed() {
26+
return seed;
27+
}
28+
1829
void manual_seed(uint64_t seed) {
30+
this->seed = seed;
1931
generator.seed((unsigned int)seed);
2032
}
2133

rng_philox.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ class PhiloxRNG : public RNG {
9393
this->offset = 0;
9494
}
9595

96+
uint64_t get_seed() {
97+
return seed;
98+
}
99+
96100
void manual_seed(uint64_t seed) {
97101
this->seed = seed;
98102
this->offset = 0;

stable-diffusion.cpp

Lines changed: 843 additions & 28 deletions
Large diffs are not rendered by default.

stable-diffusion.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,57 @@ SD_API void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter
251251

252252
SD_API int sd_get_version(sd_ctx_t* sd_ctx);
253253

254+
typedef struct sd_sampling_stream_t sd_sampling_stream_t;
255+
256+
SD_API sd_sampling_stream_t* txt2img_stream(sd_ctx_t* sd_ctx,
257+
const char* prompt_c_str,
258+
const char* negative_prompt_c_str,
259+
int clip_skip,
260+
float cfg_scale,
261+
float guidance,
262+
int width,
263+
int height,
264+
enum sample_method_t sample_method,
265+
enum schedule_t schedule,
266+
int sample_steps,
267+
int64_t seed,
268+
const sd_image_t* control_cond,
269+
float control_strength,
270+
int* skip_layers,
271+
size_t skip_layers_count,
272+
float slg_scale,
273+
float skip_layer_start,
274+
float skip_layer_end);
275+
SD_API sd_sampling_stream_t* img2img_stream(sd_ctx_t* sd_ctx,
276+
sd_image_t init_image,
277+
sd_image_t mask_image,
278+
const char* prompt_c_str,
279+
const char* negative_prompt_c_str,
280+
int clip_skip,
281+
float cfg_scale,
282+
float guidance,
283+
int width,
284+
int height,
285+
enum sample_method_t sample_method,
286+
enum schedule_t schedule,
287+
int sample_steps,
288+
float strength,
289+
int64_t seed,
290+
const sd_image_t* control_cond,
291+
float control_strength,
292+
int* skip_layers,
293+
size_t skip_layers_count,
294+
float slg_scale,
295+
float skip_layer_start,
296+
float skip_layer_end);
297+
SD_API int sd_sampling_stream_sampled_steps(sd_sampling_stream_t* stream);
298+
SD_API int sd_sampling_stream_steps(sd_sampling_stream_t* stream);
299+
SD_API void sd_sampling_stream_free(sd_sampling_stream_t* stream);
300+
SD_API bool sd_sampling_stream_sample(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream);
301+
SD_API sd_image_t sd_sampling_stream_get_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream, bool faster);
302+
SD_API sd_image_t sd_sampling_stream_get_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream);
303+
SD_API const char* sd_sampling_stream_get_parameters_str(sd_sampling_stream_t* stream);
304+
254305
typedef struct upscaler_ctx_t upscaler_ctx_t;
255306

256307
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,

0 commit comments

Comments
 (0)