@@ -4003,42 +4003,141 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;

 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_sve_cnt_b == QK8_0) {
-        const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
-        const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);

-        svfloat32_t sumv0 = svdup_n_f32(0.0f);
-        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    const int vector_length = ggml_sve_cnt_b*8;

-        for (; ib + 1 < nb; ib += 2) {
-            const block_q4_0 * restrict x0 = &x[ib + 0];
-            const block_q4_0 * restrict x1 = &x[ib + 1];
-            const block_q8_0 * restrict y0 = &y[ib + 0];
-            const block_q8_0 * restrict y1 = &y[ib + 1];
-
-            // load x
-            const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
-            const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
-
-            // 4-bit -> 8-bit
-            const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
-            const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
-
-            // sub 8
-            const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
-            const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+    // VLA Implementation using switch case
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating higher lanes for 4 float32 elements
+                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
+                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
+                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
+                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
+
+                    // sub 8
+                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
+                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
+                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
+                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
+
+                    // load y
+                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
+                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
+                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
+                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
+                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            // load y
-            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements
+                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            // dot product
-            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-        }
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating higher lanes for 32 int8 elements
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
+                const svbool_t pl16 = svnot_b_z(ph32, ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
+                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
+            } break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
+
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
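Every case in the switch added above computes the same per-block quantity; what changes with the register width reported by ggml_sve_cnt_b (bytes, hence the *8 to get bits) is how many lanes each svdot_s32 covers and which predicates mask the rest. As a reference point, here is a minimal scalar sketch of one q4_0/q8_0 block pair; it assumes the block_q4_0/block_q8_0 layouts and the GGML_FP16_TO_FP32 macro already used in this file, and the helper name itself is illustrative rather than part of the patch:

// Scalar sketch of one block pair: the low nibble of x->qs[j] pairs with
// y->qs[j], the high nibble with y->qs[j + QK4_0/2], both offset by -8
// ("sub 8"), and the integer sum is scaled by the two fp16 block deltas.
static float q4_0_q8_0_block_dot_ref(const block_q4_0 * restrict x,
                                     const block_q8_0 * restrict y) {
    int32_t isum = 0;
    for (int j = 0; j < QK4_0/2; ++j) {
        const int32_t lo = (x->qs[j] & 0x0F) - 8;
        const int32_t hi = (x->qs[j] >>   4) - 8;
        isum += lo * y->qs[j] + hi * y->qs[j + QK4_0/2];
    }
    return GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d) * (float) isum;
}

The 128-bit case handles the two 16-byte halves with separate loads and a pair of svdot accumulations per block, while the 256- and 512-bit cases place both nibble halves in a single predicated vector before one svdot per block.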
@@ -5488,29 +5587,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;

 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_sve_cnt_b == QK8_0) {
-        svfloat32_t sumv0 = svdup_n_f32(0.0f);
-        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);

-        for (; ib + 1 < nb; ib += 2) {
-            const block_q8_0 * restrict x0 = &x[ib + 0];
-            const block_q8_0 * restrict x1 = &x[ib + 1];
-            const block_q8_0 * restrict y0 = &y[ib + 0];
-            const block_q8_0 * restrict y1 = &y[ib + 1];
+    const int vector_length = ggml_sve_cnt_b*8;

-            // load x
-            const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
-            const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+    //VLA Implemenation for SVE
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating lanes for 16 Int8 elements
+                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
+                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
+                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
+                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
+                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
+
+                    // load y
+                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
+                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
+                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
+                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
+
+                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
+                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
+                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            // load y
-            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                //printf("sve256");
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }

-            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-        }
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating high 256 bit
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+                // predicate for activating low 256 bit
+                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
+
+                // predicate for activating high lanes for 8 float32 elements
+                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
+                // predicate for activating low lanes for 8 float32 elements
+                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
+
+                svfloat32_t sumv00 = svdup_n_f32(0.0f);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
+                    // and add them to make one 64 element vector
+                    // load x
+                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
+                    svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
+
+                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);

-        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+                    // load y
+                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
+                    svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
+
+                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
+
+                    // scale creation
+                    const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
+                    const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
+
+                    // duplicate deq1 in first half of vector and deq2 in second half of vector
+                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
+
+                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
+
+                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), sumv00);
+                break;
+            }
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
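The q8_0·q8_0 hunk uses the same vector-length dispatch; with no nibbles to unpack, each block pair is a plain int8 dot product scaled by the two deltas, as in this scalar sketch (same assumptions as above: ggml's block_q8_0 layout with QK8_0 = 32 quants per block, and an illustrative helper name that is not part of the patch):

// Scalar sketch of one block pair: int32 dot product of the 32 quant pairs,
// scaled by both fp16 block deltas.
static float q8_0_q8_0_block_dot_ref(const block_q8_0 * restrict x,
                                     const block_q8_0 * restrict y) {
    int32_t isum = 0;
    for (int i = 0; i < QK8_0; ++i) {
        isum += (int32_t) x->qs[i] * (int32_t) y->qs[i];
    }
    return GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d) * (float) isum;
}

The 512-bit case appears to pack two consecutive blocks into one register: with the low 32 lanes masked off by pl32, the svld1_s8(pl32, x0->qs + 2) load places its first active lane at x0->qs + 34, i.e. at x[ib + 1].qs given the 34-byte block_q8_0 layout, so a single svdot/svmla pair covers both blocks, with deq1 and deq2 broadcast into the matching float32 halves of temp.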