Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,8 @@ namespace Flux {
set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data());
}
y = to_backend(y);

float current_timestep = ggml_get_f32_1d(timesteps, 0);
LOG_DEBUG("current_timestep %f", current_timestep);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed || flux_params.is_chroma) {
guidance = to_backend(guidance);
Expand All @@ -1275,6 +1276,30 @@ namespace Flux {
ref_latents[i] = to_backend(ref_latents[i]);
}

// get use_yarn, use_ntk and use_dype from env for now (TODO: add args)
// Env value could be one of yarn, dy_yarn, ntk or dy_ntk, (anything else means disabled)
const char* env_value = getenv("FLUX_ROPE");
bool use_yarn = false;
bool use_dype = false;
bool use_ntk = false;
if (env_value != nullptr) {
if (strcmp(env_value, "YARN") == 0) {
LOG_DEBUG("Using YARN RoPE");
use_yarn = true;
} else if (strcmp(env_value, "DY_YARN") == 0) {
LOG_DEBUG("Using DY YARN RoPE");
use_yarn = true;
use_dype = true;
} else if (strcmp(env_value, "NTK") == 0) {
LOG_DEBUG("Using NTK RoPE");
use_ntk = true;
} else if (strcmp(env_value, "DY_NTK") == 0) {
LOG_DEBUG("Using DY NTK RoPE");
use_ntk = true;
use_dype = true;
}
}

pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
Expand All @@ -1283,7 +1308,11 @@ namespace Flux {
ref_latents,
increase_ref_index,
flux_params.theta,
flux_params.axes_dim);
flux_params.axes_dim,
use_yarn,
use_dype,
use_ntk,
current_timestep);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
Expand Down
195 changes: 193 additions & 2 deletions rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,115 @@ namespace Rope {
return result;
}

float find_correction_factor(float num_rotations, int dim, float base, float max_position_embeddings) {
return (dim * std::log(max_position_embeddings / (num_rotations * 2 * 3.14159265358979323846))) / (2 * std::log(base));
}

std::pair<int, int> find_correction_range(float low_ratio, float high_ratio, int dim, float base, float ori_max_pe_len) {
float low = std::floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len));
float high = std::ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len));
return {std::max(0, static_cast<int>(low)), std::min(dim / 2, static_cast<int>(high))};
}

std::vector<float> linear_ramp_mask(int min, int max, int dim) {
if (min == max) {
max += 0.001f; // Prevent singularity
}
std::vector<float> ramp(dim);
for (int i = 0; i < dim; ++i) {
ramp[i] = std::max(0.0f, std::min(1.0f, static_cast<float>(i - min) / (max - min)));
}
return ramp;
}

__STATIC_INLINE__ std::vector<std::vector<float>> rope_ext(
const std::vector<float>& pos,
int dim,
float theta = 10000.0f,
bool use_real = false,
float linear_factor = 1.0f,
float ntk_factor = 1.0f,
bool repeat_interleave_real = true,
bool yarn = false,
int max_pe_len = -1,
int ori_max_pe_len = 64,
bool dype = false,
float current_timestep = 1.0f) {
assert(dim % 2 == 0);
int half_dim = dim / 2;

// Compute frequencies
std::vector<float> freqs_base(half_dim);
std::vector<float> freqs_linear(half_dim);
std::vector<float> freqs_ntk(half_dim);
std::vector<float> freqs(half_dim);

if (yarn && max_pe_len > ori_max_pe_len) {
float beta_0 = 1.25f;
float beta_1 = 0.75f;
float gamma_0 = 16.0f;
float gamma_1 = 2.0f;

float scale = std::max(1.0f, static_cast<float>(max_pe_len) / ori_max_pe_len);
// d,t,s
float new_base = theta * std::pow(scale, half_dim / (half_dim - 1));
for (int i = 0; i < half_dim; ++i) {
float exponent = static_cast<float>(i) / half_dim;
freqs_base[i] = 1.0f / std::pow(theta, exponent);
freqs_linear[i] = 1.0f / (scale * std::pow(theta, exponent));
freqs_ntk[i] = 1.0f / std::pow(new_base, exponent);
}

if (dype) {
beta_0 = std::pow(beta_0, 2.0f * current_timestep * current_timestep);
beta_1 = std::pow(beta_1, 2.0f * current_timestep * current_timestep);
gamma_0 = std::pow(gamma_0, 2.0f * current_timestep * current_timestep);
gamma_1 = std::pow(gamma_1, 2.0f * current_timestep * current_timestep);
}

// Apply correction range and linear ramp mask
auto [low, high] = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len);
auto mask = linear_ramp_mask(low, high, half_dim);
for (int i = 0; i < half_dim; ++i) {
freqs[i] = freqs_linear[i] * mask[i] + freqs_ntk[i] * (1.0f - mask[i]);
}

// Apply gamma correction
auto [low_gamma, high_gamma] = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len);
auto mask_gamma = linear_ramp_mask(low_gamma, high_gamma, half_dim);
for (int i = 0; i < half_dim; ++i) {
freqs[i] = freqs[i] * mask_gamma[i] + freqs_base[i] * (1.0f - mask_gamma[i]);
}
} else {
float theta_ntk = theta * ntk_factor;
for (int i = 0; i < half_dim; ++i) {
float exponent = static_cast<float>(i) / half_dim;
freqs[i] = 1.0f / std::pow(theta_ntk, exponent) / linear_factor;
}
}

// Outer product of pos and freqs
std::vector<std::vector<float>> freqs_outer(pos.size(), std::vector<float>(half_dim));
for (size_t i = 0; i < pos.size(); ++i) {
for (int j = 0; j < half_dim; ++j) {
freqs_outer[i][j] = pos[i] * freqs[j];
}
}

std::vector<std::vector<float>> result;
result.resize(pos.size(), std::vector<float>(half_dim * 4));
for (size_t i = 0; i < pos.size(); ++i) {
for (int j = 0; j < half_dim; ++j) {
result[i][4 * j] = std::cos(freqs_outer[i][j]); // cos
result[i][4 * j + 1] = -std::sin(freqs_outer[i][j]); // -sin
result[i][4 * j + 2] = std::sin(freqs_outer[i][j]); // sin
result[i][4 * j + 3] = std::cos(freqs_outer[i][j]); // cos
}
}

return result;
}

// Generate IDs for image patches and text
__STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
Expand Down Expand Up @@ -151,6 +260,53 @@ namespace Rope {
return flatten(emb);
}

std::vector<float> embed_nd_ext(
const std::vector<std::vector<float>>& ids,
int bs,
float theta,
const std::vector<int>& axes_dim,
bool yarn = false,
std::vector<int> max_pe_len = {},
int ori_max_pe_len = 64,
bool dype = false,
float current_timestep = 1.0f,
std::vector<float> ntk_factors = {}) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size();

if (ntk_factors.size() == 0) {
ntk_factors = std::vector<float>(num_axes, 1.0f);
}
if (max_pe_len.size() == 0) {
max_pe_len = std::vector<int>(num_axes, -1);
}

int emb_dim = 0;
for (int d : axes_dim) {
emb_dim += d;
}

std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2, 0.0f));
int offset = 0;

for (int i = 0; i < num_axes; ++i) {
std::vector<std::vector<float>> rope_emb = rope_ext(
trans_ids[i], axes_dim[i], theta, false, 1.0f, ntk_factors[i], true, yarn, max_pe_len[i], ori_max_pe_len, dype, current_timestep);

for (int b = 0; b < bs; ++b) {
for (size_t j = 0; j < pos_len; ++j) {
for (size_t k = 0; k < rope_emb[j].size(); ++k) {
emb[b * pos_len + j][offset + k] = rope_emb[j][k];
}
}
}
offset += static_cast<int>(axes_dim[i] * 2);
}

return flatten(emb);
}

__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs,
const std::vector<ggml_tensor*>& ref_latents,
Expand Down Expand Up @@ -210,9 +366,44 @@ namespace Rope {
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
const std::vector<int>& axes_dim,
bool use_yarn = false,
bool use_dype = false,
bool use_ntk = false,
float current_timestep = 1.0f) {
int base_resolution = 1024;
// set it via environment variable for now (TODO: arg)
const char* env_base_resolution = getenv("FLUX_DYPE_BASE_RESOLUTION");
if (env_base_resolution != nullptr) {
base_resolution = atoi(env_base_resolution);
}
int base_patches = base_resolution / 16;
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
std::vector<int> max_pos_vec = {};
std::vector<float> ntk_factor_vec = {};
for (int i = 0; i < axes_dim.size(); i++) {
float max_pos_f = 0.0f;
for (const auto& row : ids) {
float val = row[i];
if (val > max_pos_f) {
max_pos_f = val;
}
}
int max_pos = static_cast<int>(max_pos_f) + 1;
max_pos_vec.push_back(max_pos);
float ntk_factor = 1.0f;
if (use_ntk) {
float base_ntk = pow((float)max_pos / base_patches, (float)axes_dim[i] / (axes_dim[i] - 2));
ntk_factor = use_dype ? pow(base_ntk, 2.0f * current_timestep * current_timestep) : base_ntk;
ntk_factor = std::max(1.0f, ntk_factor);
}
ntk_factor_vec.push_back(ntk_factor);
}
if (use_yarn || use_ntk) {
return embed_nd_ext(ids, bs, theta, axes_dim, use_yarn, max_pos_vec, base_patches, use_dype, current_timestep, ntk_factor_vec);
} else {
return embed_nd(ids, bs, theta, axes_dim);
}
}

__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
Expand Down
Loading