
Commit 5fac4d5

Vithulep and ggerganov authored
ggml : vector length agnostic SVE support (ggml-org#9290)
* Implemented vector length agnostic SVE using switch case for 512-bit, 256-bit, 128-bit vector lengths

* Removed WhiteSpaces

* ggml : style changes + fix 512-bit nb loop check

- fix local scope in switch cases
- consistent predicate names
- empty lines when necessary
- opening braces, spaces
- const-correctness
- add asserts

* Update ggml/src/ggml-quants.c

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 5fb5e24 commit 5fac4d5
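
Note: the change replaces the old guard `if (ggml_sve_cnt_b == QK8_0)`, which only enabled the SVE path on 256-bit implementations, with a runtime switch over the detected vector length. A minimal sketch of that dispatch pattern, using the standard ACLE `svcntb()` intrinsic in place of ggml's cached `ggml_sve_cnt_b`, with hypothetical `dot_vl*` helpers standing in for the per-width loops in the diff:

    #include <arm_sve.h>  // ACLE intrinsics; compile with e.g. -march=armv8-a+sve
    #include <assert.h>

    // Hypothetical per-width kernels standing in for the loops in the diff below.
    float dot_vl128(int nb);  // 16-byte vectors
    float dot_vl256(int nb);  // 32-byte vectors
    float dot_vl512(int nb);  // 64-byte vectors

    float vec_dot_dispatch(int nb) {
        // svcntb() reports the SVE register width in bytes at run time,
        // so one binary can pick the right unrolling on 128/256/512-bit parts.
        const int vector_length = (int) svcntb() * 8;

        switch (vector_length) {
            case 128: return dot_vl128(nb);
            case 256: return dot_vl256(nb);
            case 512: return dot_vl512(nb);
            default:  assert(false && "Unsupported vector length"); return 0.0f;
        }
    }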

File tree

1 file changed: +242 -48 lines changed

ggml/src/ggml-quants.c

Lines changed: 242 additions & 48 deletions
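
Both q4_0 cases below load the 16 quant bytes with svld1rq_u8, which reads one 128-bit quadword and replicates it into every 128-bit segment of the register; the ph16/pl16 predicates then pick the low nibbles out of one copy and the high nibbles out of another. A portable model of that replication (illustrative helper, not part of the patch):

    #include <stdint.h>
    #include <string.h>

    // Model of svld1rq_u8 for a vector of vlen_bytes: every 16-byte segment
    // receives a copy of the same source quadword, e.g. at VL=256 the
    // register holds [ qs[0..15] | qs[0..15] ].
    static void ld1rq_model(const uint8_t src[16], uint8_t * dst, int vlen_bytes) {
        for (int off = 0; off < vlen_bytes; off += 16) {
            memcpy(dst + off, src, 16);
        }
    }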
@@ -4003,42 +4003,141 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;

 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_sve_cnt_b == QK8_0) {
-        const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
-        const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);

-        svfloat32_t sumv0 = svdup_n_f32(0.0f);
-        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    const int vector_length = ggml_sve_cnt_b*8;

-        for (; ib + 1 < nb; ib += 2) {
-            const block_q4_0 * restrict x0 = &x[ib + 0];
-            const block_q4_0 * restrict x1 = &x[ib + 1];
-            const block_q8_0 * restrict y0 = &y[ib + 0];
-            const block_q8_0 * restrict y1 = &y[ib + 1];
-
-            // load x
-            const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
-            const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
-
-            // 4-bit -> 8-bit
-            const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
-            const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
-
-            // sub 8
-            const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
-            const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+    // VLA Implementation using switch case
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating higher lanes for 4 float32 elements
+                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
+                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
+                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
+                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
+
+                    // sub 8
+                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
+                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
+                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
+                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
+
+                    // load y
+                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
+                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
+                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
+                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
+                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            // load y
-            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements
+                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            // dot product
-            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-        }
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating higher lanes for 32 int8 elements
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
+                const svbool_t pl16 = svnot_b_z(ph32, ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
+                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
+                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
+                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
+            } break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
+
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
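
The predicated and/shift pair in the 256- and 512-bit cases above unpacks both nibble halves in place: lanes covered by ph16 keep the low nibble (& 0x0F), lanes covered by pl16 take the high nibble (>> 4), and the subsequent subtract re-centers q4_0 values from [0,15] to [-8,7]. A scalar reference for the same transform (hypothetical helper, for illustration only):

    #include <stdint.h>

    // Scalar equivalent of the svand/svlsr/svsub sequence at VL=256:
    // out[0..15] hold the low nibbles, out[16..31] the high nibbles,
    // each shifted from [0,15] down to [-8,7].
    static void unpack_q4_0_block(const uint8_t qs[16], int8_t out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = (int8_t) ((qs[i] & 0x0F) - 8);  // ph16 lanes
            out[i + 16] = (int8_t) ((qs[i] >>   4) - 8);  // pl16 lanes
        }
    }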
@@ -5488,29 +5587,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;

 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_sve_cnt_b == QK8_0) {
-        svfloat32_t sumv0 = svdup_n_f32(0.0f);
-        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);

-        for (; ib + 1 < nb; ib += 2) {
-            const block_q8_0 * restrict x0 = &x[ib + 0];
-            const block_q8_0 * restrict x1 = &x[ib + 1];
-            const block_q8_0 * restrict y0 = &y[ib + 0];
-            const block_q8_0 * restrict y1 = &y[ib + 1];
+    const int vector_length = ggml_sve_cnt_b*8;

-            // load x
-            const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
-            const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+    //VLA Implemenation for SVE
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating lanes for 16 Int8 elements
+                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
+                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
+                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
+                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
+                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
+
+                    // load y
+                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
+                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
+                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
+                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
+
+                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
+                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
+                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            // load y
-            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                //printf("sve256");
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-        }
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating high 256 bit
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+                // predicate for activating low 256 bit
+                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
+
+                // predicate for activating high lanes for 8 float32 elements
+                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
+                // predicate for activating low lanes for 8 float32 elements
+                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
+
+                svfloat32_t sumv00 = svdup_n_f32(0.0f);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
+                    // and add them to make one 64 element vector
+                    // load x
+                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
+                    svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
+
+                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);

-        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+                    // load y
+                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
+                    svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
+
+                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
+
+                    // scale creation
+                    const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
+                    const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
+
+                    // duplicate deq1 in first half of vector and deq2 in second half of vector
+                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
+
+                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
+
+                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), sumv00);
+                break;
+            }
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
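
The least obvious line in the 512-bit q8_0 case is svld1_s8(pl32, x0->qs + 2): with only the upper 32 lanes active, lane 32 reads from x0->qs + 34, which is exactly x1->qs[0] because a block_q8_0 is 34 bytes (a 2-byte fp16 scale plus 32 quants). The two zero-filled loads are then merged with svadd so a single 64-lane dot product covers both blocks, and deq1/deq2 are broadcast into the matching float32 halves. A small check of that layout assumption (illustrative sketch mirroring the block definition in ggml-common.h, not part of the patch):

    #include <assert.h>
    #include <stdint.h>

    typedef uint16_t ggml_half;                                 // 2-byte fp16 storage, as in ggml
    typedef struct { ggml_half d; int8_t qs[32]; } block_q8_0;  // 34 bytes, no padding

    // With pl32 active on lanes 32..63, lane i of svld1_s8(pl32, x0->qs + 2)
    // loads (x0->qs + 2)[i]; for i == 32 that address is the next block's qs.
    static void check_q8_0_layout(const block_q8_0 * x0) {
        assert(sizeof(block_q8_0) == 34);
        assert(x0->qs + 2 + 32 == (x0 + 1)->qs);
    }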

Comments (0)