@@ -13,6 +13,7 @@ struct SigmaSchedule {
1313 float alphas_cumprod[TIMESTEPS];
1414 float sigmas[TIMESTEPS];
1515 float log_sigmas[TIMESTEPS];
16+ int version = 0 ;
1617
1718 virtual std::vector<float > get_sigmas (uint32_t n) = 0;
1819
@@ -75,6 +76,144 @@ struct DiscreteSchedule : SigmaSchedule {
7576 }
7677};
7778
79+ /*
80+ https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
81+ */
82+ struct AYSSchedule : SigmaSchedule {
83+ /* interp and linear_interp adapted from dpilger26's NumCpp library:
84+ * https://github.com/dpilger26/NumCpp/tree/5e40aab74d14e257d65d3dc385c9ff9e2120c60e */
85+ constexpr double interp (double left, double right, double perc) noexcept {
86+ return (left * (1 . - perc)) + (right * perc);
87+ }
88+
89+ /* This will make the assumption that the reference x and y values are
90+ * already sorted in ascending order because they are being generated as
91+ * such in the calling function */
92+ std::vector<double > linear_interp (std::vector<float > new_x,
93+ const std::vector<float > ref_x,
94+ const std::vector<float > ref_y) {
95+ const size_t len_x = new_x.size ();
96+ size_t i = 0 ;
97+ size_t j = 0 ;
98+ std::vector<double > new_y (len_x);
99+
100+ if (ref_x.size () != ref_y.size ()) {
101+ LOG_ERROR (" Linear Interoplation Failed: length mismatch" );
102+ return new_y;
103+ }
104+
105+ /* serves as the bounds checking for the below while loop */
106+ if ((new_x[0 ] < ref_x[0 ]) || (new_x[new_x.size () - 1 ] > ref_x[ref_x.size () - 1 ])) {
107+ LOG_ERROR (" Linear Interpolation Failed: bad bounds" );
108+ return new_y;
109+ }
110+
111+ while (i < len_x) {
112+ if ((ref_x[j] > new_x[i]) || (new_x[i] > ref_x[j + 1 ])) {
113+ j++;
114+ continue ;
115+ }
116+
117+ const double perc = static_cast <double >(new_x[i] - ref_x[j]) / static_cast <double >(ref_x[j + 1 ] - ref_x[j]);
118+
119+ new_y[i] = interp (ref_y[j], ref_y[j + 1 ], perc);
120+ i++;
121+ }
122+
123+ return new_y;
124+ }
125+
126+ std::vector<float > linear_space (const float start, const float end, const size_t num_points) {
127+ std::vector<float > result (num_points);
128+ const float inc = (end - start) / (static_cast <float >(num_points - 1 ));
129+
130+ if (num_points > 0 ) {
131+ result[0 ] = start;
132+
133+ for (size_t i = 1 ; i < num_points; i++) {
134+ result[i] = result[i - 1 ] + inc;
135+ }
136+ }
137+
138+ return result;
139+ }
140+
141+ std::vector<float > log_linear_interpolation (std::vector<float > sigma_in,
142+ const size_t new_len) {
143+ const size_t s_len = sigma_in.size ();
144+ std::vector<float > x_vals = linear_space (0 .f , 1 .f , s_len);
145+ std::vector<float > y_vals (s_len);
146+
147+ /* Reverses the input array to be ascending instead of descending,
148+ * also hits it with a log, it is log-linear interpolation after all */
149+ for (size_t i = 0 ; i < s_len; i++) {
150+ y_vals[i] = std::log (sigma_in[s_len - i - 1 ]);
151+ }
152+
153+ std::vector<float > new_x_vals = linear_space (0 .f , 1 .f , new_len);
154+ std::vector<double > new_y_vals = linear_interp (new_x_vals, x_vals, y_vals);
155+ std::vector<float > results (new_len);
156+
157+ for (size_t i = 0 ; i < new_len; i++) {
158+ results[i] = static_cast <float >(std::exp (new_y_vals[new_len - i - 1 ]));
159+ }
160+
161+ return results;
162+ }
163+
164+ std::vector<float > get_sigmas (uint32_t len) {
165+ const std::vector<float > noise_levels[] = {
166+ /* SD1.5 */
167+ {14 .6146412293f , 6 .4745760956f , 3 .8636745985f , 2 .6946151520f ,
168+ 1 .8841921177f , 1 .3943805092f , 0 .9642583904f , 0 .6523686016f ,
169+ 0 .3977456272f , 0 .1515232662f , 0 .0291671582f },
170+ /* SDXL */
171+ {14 .6146412293f , 6 .3184485287f , 3 .7681790315f , 2 .1811480769f ,
172+ 1 .3405244945f , 0 .8620721141f , 0 .5550693289f , 0 .3798540708f ,
173+ 0 .2332364134f , 0 .1114188177f , 0 .0291671582f },
174+ /* SVD */
175+ {700 .00f , 54 .5f , 15 .886f , 7 .977f , 4 .248f , 1 .789f , 0 .981f , 0 .403f ,
176+ 0 .173f , 0 .034f , 0 .002f },
177+ };
178+
179+ std::vector<float > inputs;
180+ std::vector<float > results (len + 1 );
181+
182+ switch (version) {
183+ case VERSION_2_x: /* fallthrough */
184+ LOG_WARN (" AYS not designed for SD2.X models" );
185+ case VERSION_1_x:
186+ LOG_INFO (" AYS using SD1.5 noise levels" );
187+ inputs = noise_levels[0 ];
188+ break ;
189+ case VERSION_XL:
190+ LOG_INFO (" AYS using SDXL noise levels" );
191+ inputs = noise_levels[1 ];
192+ break ;
193+ case VERSION_SVD:
194+ LOG_INFO (" AYS using SVD noise levels" );
195+ inputs = noise_levels[2 ];
196+ break ;
197+ default :
198+ LOG_ERROR (" Version not compatable with AYS scheduler" );
199+ return results;
200+ }
201+
202+ /* Stretches those pre-calculated reference levels out to the desired
203+ * size using log-linear interpolation */
204+ if ((len + 1 ) != inputs.size ()) {
205+ results = log_linear_interpolation (inputs, len + 1 );
206+ } else {
207+ results = inputs;
208+ }
209+
210+ /* Not sure if this is strictly neccessary */
211+ results[len] = 0 .0f ;
212+
213+ return results;
214+ }
215+ };
216+
78217struct KarrasSchedule : SigmaSchedule {
79218 std::vector<float > get_sigmas (uint32_t n) {
80219 // These *COULD* be function arguments here,
@@ -122,4 +261,4 @@ struct CompVisVDenoiser : public Denoiser {
122261 }
123262};
124263
125- #endif // __DENOISER_HPP__
264+ #endif // __DENOISER_HPP__
0 commit comments