|
1 | 1 | import math
|
| 2 | +from functools import partial |
2 | 3 |
|
3 | 4 | from scipy import integrate
|
4 | 5 | import torch
|
@@ -142,6 +143,33 @@ def __call__(self, sigma, sigma_next):
|
142 | 143 | return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
143 | 144 |
|
144 | 145 |
|
| 146 | +def sigma_to_half_log_snr(sigma, model_sampling): |
| 147 | + """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" |
| 148 | + if isinstance(model_sampling, comfy.model_sampling.CONST): |
| 149 | + # log((1 - t) / t) = log((1 - sigma) / sigma) |
| 150 | + return sigma.logit().neg() |
| 151 | + return sigma.log().neg() |
| 152 | + |
| 153 | + |
| 154 | +def half_log_snr_to_sigma(half_log_snr, model_sampling): |
| 155 | + """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" |
| 156 | + if isinstance(model_sampling, comfy.model_sampling.CONST): |
| 157 | + # 1 / (1 + exp(half_log_snr)) |
| 158 | + return half_log_snr.neg().sigmoid() |
| 159 | + return half_log_snr.neg().exp() |
| 160 | + |
| 161 | + |
| 162 | +def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): |
| 163 | + """Adjust the first sigma to avoid invalid logSNR.""" |
| 164 | + if len(sigmas) <= 1: |
| 165 | + return sigmas |
| 166 | + if isinstance(model_sampling, comfy.model_sampling.CONST): |
| 167 | + if sigmas[0] >= 1: |
| 168 | + sigmas = sigmas.clone() |
| 169 | + sigmas[0] = model_sampling.percent_to_sigma(percent_offset) |
| 170 | + return sigmas |
| 171 | + |
| 172 | + |
145 | 173 | @torch.no_grad()
|
146 | 174 | def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
147 | 175 | """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
@@ -1449,100 +1477,121 @@ def default_noise_scaler(sigma):
|
1449 | 1477 | old_denoised = denoised
|
1450 | 1478 | return x
|
1451 | 1479 |
|
| 1480 | + |
1452 | 1481 | @torch.no_grad()
|
1453 | 1482 | def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
1454 |
| - ''' |
1455 |
| - SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 |
1456 |
| - Arxiv: https://arxiv.org/abs/2305.14267 |
1457 |
| - ''' |
| 1483 | + """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. |
| 1484 | + arXiv: https://arxiv.org/abs/2305.14267 |
| 1485 | + """ |
1458 | 1486 | extra_args = {} if extra_args is None else extra_args
|
1459 | 1487 | seed = extra_args.get("seed", None)
|
1460 | 1488 | noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
1461 | 1489 | s_in = x.new_ones([x.shape[0]])
|
1462 | 1490 |
|
1463 | 1491 | inject_noise = eta > 0 and s_noise > 0
|
1464 | 1492 |
|
| 1493 | + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') |
| 1494 | + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) |
| 1495 | + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) |
| 1496 | + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) |
| 1497 | + |
1465 | 1498 | for i in trange(len(sigmas) - 1, disable=disable):
|
1466 | 1499 | denoised = model(x, sigmas[i] * s_in, **extra_args)
|
1467 | 1500 | if callback is not None:
|
1468 | 1501 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
1469 | 1502 | if sigmas[i + 1] == 0:
|
1470 | 1503 | x = denoised
|
1471 | 1504 | else:
|
1472 |
| - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() |
1473 |
| - h = t_next - t |
| 1505 | + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) |
| 1506 | + h = lambda_t - lambda_s |
1474 | 1507 | h_eta = h * (eta + 1)
|
1475 |
| - s = t + r * h |
| 1508 | + lambda_s_1 = lambda_s + r * h |
1476 | 1509 | fac = 1 / (2 * r)
|
1477 |
| - sigma_s = s.neg().exp() |
| 1510 | + sigma_s_1 = sigma_fn(lambda_s_1) |
| 1511 | + |
| 1512 | + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) |
| 1513 | + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() |
| 1514 | + alpha_t = sigmas[i + 1] * lambda_t.exp() |
1478 | 1515 |
|
1479 | 1516 | coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
1480 | 1517 | if inject_noise:
|
| 1518 | + # 0 < r < 1 |
1481 | 1519 | noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
1482 |
| - noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() |
1483 |
| - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) |
| 1520 | + noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() |
| 1521 | + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) |
1484 | 1522 |
|
1485 | 1523 | # Step 1
|
1486 |
| - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised |
| 1524 | + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised |
1487 | 1525 | if inject_noise:
|
1488 |
| - x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise |
1489 |
| - denoised_2 = model(x_2, sigma_s * s_in, **extra_args) |
| 1526 | + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise |
| 1527 | + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) |
1490 | 1528 |
|
1491 | 1529 | # Step 2
|
1492 | 1530 | denoised_d = (1 - fac) * denoised + fac * denoised_2
|
1493 |
| - x = (coeff_2 + 1) * x - coeff_2 * denoised_d |
| 1531 | + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d |
1494 | 1532 | if inject_noise:
|
1495 | 1533 | x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
1496 | 1534 | return x
|
1497 | 1535 |
|
| 1536 | + |
1498 | 1537 | @torch.no_grad()
|
1499 | 1538 | def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
1500 |
| - ''' |
1501 |
| - SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 |
1502 |
| - Arxiv: https://arxiv.org/abs/2305.14267 |
1503 |
| - ''' |
| 1539 | + """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. |
| 1540 | + arXiv: https://arxiv.org/abs/2305.14267 |
| 1541 | + """ |
1504 | 1542 | extra_args = {} if extra_args is None else extra_args
|
1505 | 1543 | seed = extra_args.get("seed", None)
|
1506 | 1544 | noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
1507 | 1545 | s_in = x.new_ones([x.shape[0]])
|
1508 | 1546 |
|
1509 | 1547 | inject_noise = eta > 0 and s_noise > 0
|
1510 | 1548 |
|
| 1549 | + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') |
| 1550 | + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) |
| 1551 | + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) |
| 1552 | + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) |
| 1553 | + |
1511 | 1554 | for i in trange(len(sigmas) - 1, disable=disable):
|
1512 | 1555 | denoised = model(x, sigmas[i] * s_in, **extra_args)
|
1513 | 1556 | if callback is not None:
|
1514 | 1557 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
1515 | 1558 | if sigmas[i + 1] == 0:
|
1516 | 1559 | x = denoised
|
1517 | 1560 | else:
|
1518 |
| - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() |
1519 |
| - h = t_next - t |
| 1561 | + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) |
| 1562 | + h = lambda_t - lambda_s |
1520 | 1563 | h_eta = h * (eta + 1)
|
1521 |
| - s_1 = t + r_1 * h |
1522 |
| - s_2 = t + r_2 * h |
1523 |
| - sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() |
| 1564 | + lambda_s_1 = lambda_s + r_1 * h |
| 1565 | + lambda_s_2 = lambda_s + r_2 * h |
| 1566 | + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) |
| 1567 | + |
| 1568 | + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) |
| 1569 | + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() |
| 1570 | + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() |
| 1571 | + alpha_t = sigmas[i + 1] * lambda_t.exp() |
1524 | 1572 |
|
1525 | 1573 | coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
1526 | 1574 | if inject_noise:
|
| 1575 | + # 0 < r_1 < r_2 < 1 |
1527 | 1576 | noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
1528 |
| - noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() |
1529 |
| - noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() |
| 1577 | + noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() |
| 1578 | + noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() |
1530 | 1579 | noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
1531 | 1580 |
|
1532 | 1581 | # Step 1
|
1533 |
| - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised |
| 1582 | + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised |
1534 | 1583 | if inject_noise:
|
1535 | 1584 | x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
1536 | 1585 | denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
1537 | 1586 |
|
1538 | 1587 | # Step 2
|
1539 |
| - x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) |
| 1588 | + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) |
1540 | 1589 | if inject_noise:
|
1541 | 1590 | x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
1542 | 1591 | denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
1543 | 1592 |
|
1544 | 1593 | # Step 3
|
1545 |
| - x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) |
| 1594 | + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) |
1546 | 1595 | if inject_noise:
|
1547 | 1596 | x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
1548 | 1597 | return x
|
0 commit comments