@@ -237,21 +237,14 @@ struct GITSSchedule : SigmaSchedule {
237
237
238
238
struct SGMUniformSchedule : SigmaSchedule {
239
239
std::vector<float > get_sigmas (uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
240
- // This schedule's core logic is now handled directly in Denoiser::get_sigmas
241
- // to ensure correct access to both sigma_to_t and t_to_sigma.
242
- // This method is overridden to fulfill the virtual contract but ideally should not be
243
- // the primary execution path for SGMUniform when called from Denoiser::get_sigmas.
244
- // If it IS called, it means the Denoiser::get_sigmas logic wasn't triggered, which is unexpected.
245
- LOG_WARN (" SGMUniformSchedule::get_sigmas was called directly. This might indicate an issue with Denoiser dispatch." );
246
- // Provide a default (potentially incorrect for SGMUniform's intent) or empty schedule to avoid crashes.
247
- // For safety, returning a simple discrete-like schedule in t-space if this is ever hit.
240
+
248
241
std::vector<float > result;
249
242
if (n == 0 ) {
250
243
result.push_back (0 .0f );
251
244
return result;
252
245
}
253
246
result.reserve (n + 1 );
254
- int t_max = TIMESTEPS -1 ; // A common max t value
247
+ int t_max = TIMESTEPS -1 ;
255
248
float step = static_cast <float >(t_max) / static_cast <float >(n > 1 ? (n -1 ) : 1 ) ;
256
249
for (uint32_t i=0 ; i<n; ++i) {
257
250
result.push_back (t_to_sigma_func (t_max - step * i));
@@ -284,39 +277,27 @@ struct SimpleSchedule : SigmaSchedule {
284
277
std::vector<float > result_sigmas;
285
278
286
279
if (n == 0 ) {
287
- return result_sigmas; // Return empty for n=0, consistent with DiscreteSchedule
280
+ return result_sigmas;
288
281
}
289
282
290
283
result_sigmas.reserve (n + 1 );
291
284
292
- // TIMESTEPS is the length of the model's internal sigmas array, typically 1000.
293
- // t_to_sigma(t) maps a timestep t (0 to TIMESTEPS-1) to its sigma value.
294
285
int model_sigmas_len = TIMESTEPS;
295
286
296
- // ss = len(s.sigmas) / steps in Python
297
287
float step_factor = static_cast <float >(model_sigmas_len) / static_cast <float >(n);
298
288
299
289
for (uint32_t i = 0 ; i < n; ++i) {
300
- // Python: s.sigmas[-(1 + int(x * ss))]
301
- // x corresponds to i (0 to n-1)
302
- // int(x * ss) in Python is static_cast<int>(static_cast<float>(i) * step_factor)
303
- // The index -(1 + offset) means (model_sigmas_len - 1 - offset) from the start of a 0-indexed array.
290
+
304
291
int offset_from_start_of_py_array = static_cast <int >(static_cast <float >(i) * step_factor);
305
292
int timestep_index = model_sigmas_len - 1 - offset_from_start_of_py_array;
306
293
307
- // Ensure the index is within valid bounds [0, model_sigmas_len - 1]
308
294
if (timestep_index < 0 ) {
309
295
timestep_index = 0 ;
310
296
}
311
- // No need for upper bound check like `timestep_index >= model_sigmas_len` because
312
- // max offset is for i=n-1: int((n-1)/n * model_sigmas_len) which is < model_sigmas_len.
313
- // So, model_sigmas_len - 1 - max_offset is >= 0 if model_sigmas_len/n >= 1.
314
- // If n > model_sigmas_len, then model_sigmas_len/n < 1, resulting in timestep_index potentially being <0,
315
- // which is handled by the clamp above.
316
297
317
298
result_sigmas.push_back (t_to_sigma (static_cast <float >(timestep_index)));
318
299
}
319
- result_sigmas.push_back (0 .0f ); // Append the final zero sigma
300
+ result_sigmas.push_back (0 .0f );
320
301
return result_sigmas;
321
302
}
322
303
};
@@ -334,7 +315,6 @@ struct Denoiser {
334
315
virtual std::vector<float > get_sigmas (uint32_t n) {
335
316
// Check if the current schedule is SGMUniformSchedule
336
317
if (std::dynamic_pointer_cast<SGMUniformSchedule>(schedule)) {
337
- LOG_DEBUG (" Denoiser::get_sigmas - Using SGM_UNIFORM specific logic" );
338
318
std::vector<float > sigs;
339
319
sigs.reserve (n + 1 );
340
320
@@ -347,26 +327,22 @@ struct Denoiser {
347
327
float start_t_val = this ->sigma_to_t (this ->sigma_max ());
348
328
float end_t_val = this ->sigma_to_t (this ->sigma_min ());
349
329
350
- // Python: torch.linspace(start, end, n + 1)[:-1]
351
- // This creates n points. The k-th point (0-indexed) is start_t_val + k * (end_t_val - start_t_val) / n.
352
330
float dt_per_step;
353
- if (n > 0 ) { // Avoid division by zero if n=0, though covered by earlier check
331
+ if (n > 0 ) {
354
332
dt_per_step = (end_t_val - start_t_val) / static_cast <float >(n);
355
333
} else {
356
334
dt_per_step = 0 .0f ;
357
335
}
358
336
359
-
360
337
for (uint32_t i = 0 ; i < n; ++i) {
361
338
float current_t = start_t_val + static_cast <float >(i) * dt_per_step;
362
339
sigs.push_back (this ->t_to_sigma (current_t ));
363
340
}
364
341
365
- sigs.push_back (0 .0f ); // Append the final zero sigma
342
+ sigs.push_back (0 .0f );
366
343
return sigs;
367
344
368
345
} else { // For all other schedules, use the existing virtual dispatch
369
- LOG_DEBUG (" Denoiser::get_sigmas - Using general schedule dispatch for %s" , typeid (*schedule.get ()).name ());
370
346
auto bound_t_to_sigma = std::bind (&Denoiser::t_to_sigma, this , std::placeholders::_1);
371
347
return schedule->get_sigmas (n, sigma_min (), sigma_max (), bound_t_to_sigma);
372
348
}
0 commit comments