Skip to content

Commit cb1338a

Browse files
committed
support custom logger
1 parent 0e64238 commit cb1338a

File tree

6 files changed

+80
-38
lines changed

6 files changed

+80
-38
lines changed

.clang-format

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ UseTab: Never
33
IndentWidth: 4
44
TabWidth: 4
55
AllowShortIfStatementsOnASingleLine: false
6-
IndentCaseLabels: false
76
ColumnLimit: 0
87
AccessModifierOffset: -4
98
NamespaceIndentation: All

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ test/
33
.vscode/
44
.cache/
55
*.swp
6-
.vscode/
76
*.bat
87
*.bin
98
*.exe

examples/cli/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
#include <stdio.h>
1+
#include <cstdio>
22
#include <ctime>
33
#include <random>
4+
45
#include "ggml/ggml.h"
56
#include "stable-diffusion.h"
67
#include "util.h"

stable-diffusion.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,11 @@ void pretty_progress(int step, int steps, float time) {
9797
}
9898
}
9999
progress += "|";
100-
printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
101-
progress.c_str(), step, steps,
102-
time > 1.0f || time == 0 ? time : (1.0f / time));
103-
fflush(stdout); // for linux
100+
LOG_DEFAULT(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
101+
progress.c_str(), step, steps,
102+
time > 1.0f || time == 0 ? time : (1.0f / time));
104103
if (step == steps) {
105-
printf("\n");
104+
LOG_DEFAULT("\n");
106105
}
107106
}
108107

@@ -1749,7 +1748,7 @@ struct SpatialTransformer {
17491748
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS)
17501749
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head]
17511750
#else
1752-
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position]
1751+
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position]
17531752
// kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
17541753
kq = ggml_soft_max_inplace(ctx, kq);
17551754

util.cpp

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "util.h"
22

3-
#include <stdarg.h>
43
#include <codecvt>
4+
#include <cstdarg>
55
#include <fstream>
66
#include <locale>
77
#include <thread>
@@ -170,34 +170,72 @@ void set_sd_log_level(SDLogLevel level) {
170170
log_level = level;
171171
}
172172

173-
void log_printf(SDLogLevel level, const char* file, int line, const char* format, ...) {
173+
void default_sd_logger(SDLogLevel level, const char* text) {
174+
if (level == SDLogLevel::ERROR) {
175+
fputs(text, stderr);
176+
fflush(stderr);
177+
} else {
178+
fputs(text, stdout);
179+
fflush(stdout);
180+
}
181+
}
182+
183+
static sd_logger_function_t sd_logger = &default_sd_logger;
184+
185+
std::string log_prefix(SDLogLevel level, const char* file, int line) {
186+
const char* format = nullptr;
187+
switch (level) {
188+
case SDLogLevel::DEBUG: {
189+
format = "[DEBUG] %s:%-4d - ";
190+
} break;
191+
case SDLogLevel::INFO: {
192+
format = "[INFO] %s:%-4d - ";
193+
} break;
194+
case SDLogLevel::WARN: {
195+
format = "[WARN] %s:%-4d - ";
196+
} break;
197+
case SDLogLevel::ERROR: {
198+
format = "[ERROR] %s:%-4d - ";
199+
} break;
200+
}
201+
202+
char buffer[128];
203+
const int len = std::snprintf(buffer, sizeof(buffer), format, basename(file).c_str(), line);
204+
if (len >= sizeof(buffer)) {
205+
std::string buffer2(len + 1, '\0');
206+
std::snprintf(&buffer2[0], len + 1, format, basename(file).c_str(), line);
207+
return buffer2;
208+
}
209+
return buffer;
210+
}
211+
212+
void log_printf(SDLogLevel level, bool enable_log_tag, const char* file, int line, const char* format, ...) {
174213
if (level < log_level) {
175214
return;
176215
}
216+
177217
va_list args;
178218
va_start(args, format);
179-
180-
if (level == SDLogLevel::DEBUG) {
181-
printf("[DEBUG] %s:%-4d - ", basename(file).c_str(), line);
182-
vprintf(format, args);
183-
printf("\n");
184-
fflush(stdout);
185-
} else if (level == SDLogLevel::INFO) {
186-
printf("[INFO] %s:%-4d - ", basename(file).c_str(), line);
187-
vprintf(format, args);
188-
printf("\n");
189-
fflush(stdout);
190-
} else if (level == SDLogLevel::WARN) {
191-
fprintf(stdout, "[WARN] %s:%-4d - ", basename(file).c_str(), line);
192-
vfprintf(stdout, format, args);
193-
fprintf(stdout, "\n");
194-
fflush(stdout);
219+
const char* log_prefix_str = "";
220+
if (enable_log_tag) {
221+
log_prefix_str = log_prefix(level, file, line).c_str();
222+
}
223+
char buffer[128];
224+
const int len = std::vsnprintf(buffer, sizeof(buffer), format, args);
225+
if (len < sizeof(buffer)) {
226+
const std::string log_message = log_prefix_str + std::string(buffer);
227+
sd_logger(level, log_message.c_str());
195228
} else {
196-
fprintf(stderr, "[ERROR] %s:%-4d - ", basename(file).c_str(), line);
197-
vfprintf(stderr, format, args);
198-
fprintf(stderr, "\n");
199-
fflush(stderr);
229+
char* buffer2 = new char[len + 2];
230+
std::vsnprintf(buffer2, len + 1, format, args);
231+
buffer2[len + 1] = 0;
232+
const std::string log_message = log_prefix_str + std::string(buffer2);
233+
sd_logger(level, log_message.c_str());
234+
delete[] buffer2;
200235
}
201-
202236
va_end(args);
203237
}
238+
239+
void set_sd_logger(const sd_logger_function_t& sd_logger_function) {
240+
sd_logger = sd_logger_function;
241+
}

util.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#ifndef __UTIL_H__
22
#define __UTIL_H__
33

4-
#include <string>
54
#include <cstdint>
5+
#include <functional>
6+
#include <string>
67

78
bool ends_with(const std::string& str, const std::string& ending);
89
bool starts_with(const std::string& str, const std::string& start);
@@ -33,10 +34,15 @@ enum SDLogLevel {
3334

3435
void set_sd_log_level(SDLogLevel level);
3536

36-
void log_printf(SDLogLevel level, const char* file, int line, const char* format, ...);
37+
void log_printf(SDLogLevel level, bool enable_log_tag, const char* file, int line, const char* format, ...);
38+
39+
typedef std::function<void(SDLogLevel level, const char* text)> sd_logger_function_t;
40+
41+
void set_sd_logger(const sd_logger_function_t& sd_logger_function);
3742

38-
#define LOG_DEBUG(format, ...) log_printf(SDLogLevel::DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
39-
#define LOG_INFO(format, ...) log_printf(SDLogLevel::INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
40-
#define LOG_WARN(format, ...) log_printf(SDLogLevel::WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
41-
#define LOG_ERROR(format, ...) log_printf(SDLogLevel::ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__)
43+
#define LOG_DEFAULT(format, ...) log_printf(SDLogLevel::INFO, false, __FILE__, __LINE__, format, ##__VA_ARGS__)
44+
#define LOG_DEBUG(format, ...) log_printf(SDLogLevel::DEBUG, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
45+
#define LOG_INFO(format, ...) log_printf(SDLogLevel::INFO, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
46+
#define LOG_WARN(format, ...) log_printf(SDLogLevel::WARN, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
47+
#define LOG_ERROR(format, ...) log_printf(SDLogLevel::ERROR, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
4248
#endif // __UTIL_H__

0 commit comments

Comments
 (0)