@@ -182,8 +182,6 @@ struct SpatialTransformer {
182182
183183 std::vector<Transformer> transformers;
184184
185- struct ggml_tensor * attn_scale;
186-
187185 // proj_out
188186 struct ggml_tensor * proj_out_w; // [in_channels, in_channels, 1, 1]
189187 struct ggml_tensor * proj_out_b; // [in_channels,]
@@ -202,7 +200,6 @@ struct SpatialTransformer {
202200 mem_size += 2 * in_channels * ggml_type_sizef (GGML_TYPE_F32); // norm_w/norm_b
203201 mem_size += 2 * in_channels * in_channels * 1 * 1 * ggml_type_sizef (GGML_TYPE_F16); // proj_in_w/proj_out_w
204202 mem_size += 2 * in_channels * ggml_type_sizef (GGML_TYPE_F32); // proj_in_b/proj_out_b
205- mem_size += 1 * ggml_type_sizef (GGML_TYPE_F32); // attn_scale
206203
207204 // transformer
208205 for (auto & transformer : transformers) {
@@ -226,11 +223,6 @@ struct SpatialTransformer {
226223 proj_out_w = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, 1 , 1 , in_channels, in_channels);
227224 proj_out_b = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, in_channels);
228225
229- attn_scale = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 1 );
230- ggml_allocr_alloc (alloc, attn_scale);
231- float scale = 1.0f / sqrt ((float)d_head);
232- ggml_backend_tensor_set (attn_scale, &scale, 0 , sizeof (scale));
233-
234226 // transformer
235227 for (auto & transformer : transformers) {
236228 transformer.norm1_w = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, in_channels);
@@ -332,7 +324,7 @@ struct SpatialTransformer {
332324 x = ggml_reshape_2d (ctx, x, c, h * w * n); // [N * h * w, in_channels]
333325 struct ggml_tensor * q = ggml_mul_mat (ctx, transformer.attn1_q_w , x); // [N * h * w, in_channels]
334326#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
335- q = ggml_scale_inplace (ctx, q, attn_scale );
327+ q = ggml_scale_inplace (ctx, q, 1.0f / sqrt ((float)d_head));
336328#endif
337329 q = ggml_reshape_4d (ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head]
338330 q = ggml_cont (ctx, ggml_permute (ctx, q, 0 , 2 , 1 , 3 )); // [N, n_head, h * w, d_head]
@@ -380,7 +372,7 @@ struct SpatialTransformer {
380372 context = ggml_reshape_2d (ctx, context, context->ne [0 ], context->ne [1 ] * context->ne [2 ]); // [N * max_position, hidden_size]
381373 struct ggml_tensor * q = ggml_mul_mat (ctx, transformer.attn2_q_w , x); // [N * h * w, in_channels]
382374#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
383- q = ggml_scale_inplace (ctx, q, attn_scale );
375+ q = ggml_scale_inplace (ctx, q, 1.0f / sqrt ((float)d_head));
384376#endif
385377 q = ggml_reshape_4d (ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head]
386378 q = ggml_cont (ctx, ggml_permute (ctx, q, 0 , 2 , 1 , 3 )); // [N, n_head, h * w, d_head]
0 commit comments