Skip to content

Commit 53e8d81

Browse files
authored
Generalize SEEDS samplers (comfyanonymous#8529)
Restore VP algorithm for RF and refactor noise_coeffs and half-logSNR calculations
1 parent 29596bd commit 53e8d81

File tree

1 file changed

+77
-28
lines changed

1 file changed

+77
-28
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from functools import partial
23

34
from scipy import integrate
45
import torch
@@ -142,6 +143,33 @@ def __call__(self, sigma, sigma_next):
142143
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
143144

144145

146+
def sigma_to_half_log_snr(sigma, model_sampling):
147+
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
148+
if isinstance(model_sampling, comfy.model_sampling.CONST):
149+
# log((1 - t) / t) = log((1 - sigma) / sigma)
150+
return sigma.logit().neg()
151+
return sigma.log().neg()
152+
153+
154+
def half_log_snr_to_sigma(half_log_snr, model_sampling):
155+
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
156+
if isinstance(model_sampling, comfy.model_sampling.CONST):
157+
# 1 / (1 + exp(half_log_snr))
158+
return half_log_snr.neg().sigmoid()
159+
return half_log_snr.neg().exp()
160+
161+
162+
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
163+
"""Adjust the first sigma to avoid invalid logSNR."""
164+
if len(sigmas) <= 1:
165+
return sigmas
166+
if isinstance(model_sampling, comfy.model_sampling.CONST):
167+
if sigmas[0] >= 1:
168+
sigmas = sigmas.clone()
169+
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
170+
return sigmas
171+
172+
145173
@torch.no_grad()
146174
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
147175
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -1449,100 +1477,121 @@ def default_noise_scaler(sigma):
14491477
old_denoised = denoised
14501478
return x
14511479

1480+
14521481
@torch.no_grad()
14531482
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
1454-
'''
1455-
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
1456-
Arxiv: https://arxiv.org/abs/2305.14267
1457-
'''
1483+
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
1484+
arXiv: https://arxiv.org/abs/2305.14267
1485+
"""
14581486
extra_args = {} if extra_args is None else extra_args
14591487
seed = extra_args.get("seed", None)
14601488
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
14611489
s_in = x.new_ones([x.shape[0]])
14621490

14631491
inject_noise = eta > 0 and s_noise > 0
14641492

1493+
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
1494+
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
1495+
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
1496+
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
1497+
14651498
for i in trange(len(sigmas) - 1, disable=disable):
14661499
denoised = model(x, sigmas[i] * s_in, **extra_args)
14671500
if callback is not None:
14681501
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
14691502
if sigmas[i + 1] == 0:
14701503
x = denoised
14711504
else:
1472-
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
1473-
h = t_next - t
1505+
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
1506+
h = lambda_t - lambda_s
14741507
h_eta = h * (eta + 1)
1475-
s = t + r * h
1508+
lambda_s_1 = lambda_s + r * h
14761509
fac = 1 / (2 * r)
1477-
sigma_s = s.neg().exp()
1510+
sigma_s_1 = sigma_fn(lambda_s_1)
1511+
1512+
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
1513+
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
1514+
alpha_t = sigmas[i + 1] * lambda_t.exp()
14781515

14791516
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
14801517
if inject_noise:
1518+
# 0 < r < 1
14811519
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
1482-
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
1483-
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
1520+
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
1521+
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
14841522

14851523
# Step 1
1486-
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
1524+
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
14871525
if inject_noise:
1488-
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
1489-
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
1526+
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
1527+
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
14901528

14911529
# Step 2
14921530
denoised_d = (1 - fac) * denoised + fac * denoised_2
1493-
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
1531+
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
14941532
if inject_noise:
14951533
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
14961534
return x
14971535

1536+
14981537
@torch.no_grad()
14991538
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
1500-
'''
1501-
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
1502-
Arxiv: https://arxiv.org/abs/2305.14267
1503-
'''
1539+
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
1540+
arXiv: https://arxiv.org/abs/2305.14267
1541+
"""
15041542
extra_args = {} if extra_args is None else extra_args
15051543
seed = extra_args.get("seed", None)
15061544
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
15071545
s_in = x.new_ones([x.shape[0]])
15081546

15091547
inject_noise = eta > 0 and s_noise > 0
15101548

1549+
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
1550+
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
1551+
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
1552+
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
1553+
15111554
for i in trange(len(sigmas) - 1, disable=disable):
15121555
denoised = model(x, sigmas[i] * s_in, **extra_args)
15131556
if callback is not None:
15141557
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
15151558
if sigmas[i + 1] == 0:
15161559
x = denoised
15171560
else:
1518-
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
1519-
h = t_next - t
1561+
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
1562+
h = lambda_t - lambda_s
15201563
h_eta = h * (eta + 1)
1521-
s_1 = t + r_1 * h
1522-
s_2 = t + r_2 * h
1523-
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
1564+
lambda_s_1 = lambda_s + r_1 * h
1565+
lambda_s_2 = lambda_s + r_2 * h
1566+
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
1567+
1568+
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
1569+
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
1570+
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
1571+
alpha_t = sigmas[i + 1] * lambda_t.exp()
15241572

15251573
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
15261574
if inject_noise:
1575+
# 0 < r_1 < r_2 < 1
15271576
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
1528-
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
1529-
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
1577+
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
1578+
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
15301579
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
15311580

15321581
# Step 1
1533-
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
1582+
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
15341583
if inject_noise:
15351584
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
15361585
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
15371586

15381587
# Step 2
1539-
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
1588+
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
15401589
if inject_noise:
15411590
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
15421591
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
15431592

15441593
# Step 3
1545-
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
1594+
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
15461595
if inject_noise:
15471596
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
15481597
return x

0 commit comments

Comments
 (0)