Skip to content

Commit 3f20ffe

Browse files
Merge pull request tensorflow#7820 from benoitsteiner/master
Improved support for AVX512.
2 parents 850938b + 196c9b7 commit 3f20ffe

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ typedef struct Packet32q8i {
1111
Packet32q8i(__m256i val) : val(val) {}
1212
} Packet32q8i;
1313

14+
typedef struct Packet16q16i {
15+
__m256i val;
16+
operator __m256i() const { return val; }
17+
Packet16q16i();
18+
Packet16q16i(__m256i val) : val(val) {}
19+
} Packet16q16i;
20+
1421
typedef struct Packet32q8u {
1522
__m256i val;
1623
operator __m256i() const { return val; }
@@ -32,6 +39,13 @@ typedef struct Packet16q8u {
3239
Packet16q8u(__m128i val) : val(val) {}
3340
} Packet16q8u;
3441

42+
typedef struct Packet8q16i {
43+
__m128i val;
44+
operator __m128i() const { return val; }
45+
Packet8q16i();
46+
Packet8q16i(__m128i val) : val(val) {}
47+
} Packet8q16i;
48+
3549
typedef struct Packet8q32i {
3650
__m256i val;
3751
operator __m256i() const { return val; }
@@ -92,6 +106,28 @@ struct packet_traits<QUInt8> : default_packet_traits {
92106
};
93107
};
94108
template <>
109+
struct packet_traits<QInt16> : default_packet_traits {
110+
typedef Packet16q16i type;
111+
typedef Packet8q16i half;
112+
enum {
113+
Vectorizable = 1,
114+
AlignedOnScalar = 1,
115+
size = 16,
116+
};
117+
enum {
118+
HasAdd = 0,
119+
HasSub = 0,
120+
HasMul = 0,
121+
HasNegate = 0,
122+
HasAbs = 0,
123+
HasAbs2 = 0,
124+
HasMin = 1,
125+
HasMax = 1,
126+
HasConj = 0,
127+
HasSetLinear = 0
128+
};
129+
};
130+
template <>
95131
struct packet_traits<QInt32> : default_packet_traits {
96132
typedef Packet8q32i type;
97133
typedef Packet4q32i half;
@@ -122,6 +158,12 @@ struct unpacket_traits<Packet32q8i> {
122158
enum { size = 32, alignment=Aligned32 };
123159
};
124160
template <>
161+
struct unpacket_traits<Packet16q16i> {
162+
typedef QInt16 type;
163+
typedef Packet8q16i half;
164+
enum { size = 16, alignment=Aligned32 };
165+
};
166+
template <>
125167
struct unpacket_traits<Packet32q8u> {
126168
typedef QUInt8 type;
127169
typedef Packet16q8u half;
@@ -146,6 +188,11 @@ EIGEN_STRONG_INLINE Packet32q8u ploadu<Packet32q8u>(const QUInt8* from) {
146188
reinterpret_cast<const __m256i*>(from));
147189
}
148190
template <>
191+
EIGEN_STRONG_INLINE Packet16q16i ploadu<Packet16q16i>(const QInt16* from) {
192+
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
193+
reinterpret_cast<const __m256i*>(from));
194+
}
195+
template <>
149196
EIGEN_STRONG_INLINE Packet8q32i ploadu<Packet8q32i>(const QInt32* from) {
150197
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
151198
reinterpret_cast<const __m256i*>(from));
@@ -163,6 +210,11 @@ EIGEN_STRONG_INLINE Packet32q8u pload<Packet32q8u>(const QUInt8* from) {
163210
reinterpret_cast<const __m256i*>(from));
164211
}
165212
template <>
213+
EIGEN_STRONG_INLINE Packet16q16i pload<Packet16q16i>(const QInt16* from) {
214+
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
215+
reinterpret_cast<const __m256i*>(from));
216+
}
217+
template <>
166218
EIGEN_STRONG_INLINE Packet8q32i pload<Packet8q32i>(const QInt32* from) {
167219
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
168220
reinterpret_cast<const __m256i*>(from));
@@ -180,6 +232,11 @@ EIGEN_STRONG_INLINE void pstoreu<QUInt8>(QUInt8* to, const Packet32q8u& from) {
180232
reinterpret_cast<__m256i*>(to), from.val);
181233
}
182234
template <>
235+
EIGEN_STRONG_INLINE void pstoreu<QInt16>(QInt16* to, const Packet16q16i& from) {
236+
EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
237+
reinterpret_cast<__m256i*>(to), from.val);
238+
}
239+
template <>
183240
EIGEN_STRONG_INLINE void pstoreu<QInt32>(QInt32* to, const Packet8q32i& from) {
184241
EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
185242
reinterpret_cast<__m256i*>(to), from.val);
@@ -192,6 +249,11 @@ EIGEN_STRONG_INLINE void pstore<QInt32>(QInt32* to, const Packet8q32i& from) {
192249
from.val);
193250
}
194251
template <>
252+
EIGEN_STRONG_INLINE void pstore<QInt16>(QInt16* to, const Packet16q16i& from) {
253+
EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
254+
from.val);
255+
}
256+
template <>
195257
EIGEN_STRONG_INLINE void pstore<QUInt8>(QUInt8* to, const Packet32q8u& from) {
196258
EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
197259
from.val);
@@ -208,6 +270,10 @@ EIGEN_STRONG_INLINE QInt32 pfirst<Packet8q32i>(const Packet8q32i& a) {
208270
return _mm_cvtsi128_si32(_mm256_castsi256_si128(a));
209271
}
210272
template <>
273+
EIGEN_STRONG_INLINE QInt16 pfirst<Packet16q16i>(const Packet16q16i& a) {
274+
return _mm256_extract_epi16(a.val, 0);
275+
}
276+
template <>
211277
EIGEN_STRONG_INLINE QUInt8 pfirst<Packet32q8u>(const Packet32q8u& a) {
212278
return static_cast<uint8_t>(_mm256_extract_epi8(a.val, 0));
213279
}
@@ -237,6 +303,10 @@ EIGEN_STRONG_INLINE Packet8q32i padd<Packet8q32i>(const Packet8q32i& a,
237303
return _mm256_add_epi32(a.val, b.val);
238304
}
239305
template <>
306+
EIGEN_STRONG_INLINE Packet16q16i pset1<Packet16q16i>(const QInt16& from) {
307+
return _mm256_set1_epi16(from.value);
308+
}
309+
template <>
240310
EIGEN_STRONG_INLINE Packet8q32i psub<Packet8q32i>(const Packet8q32i& a,
241311
const Packet8q32i& b) {
242312
return _mm256_sub_epi32(a.val, b.val);
@@ -264,6 +334,17 @@ EIGEN_STRONG_INLINE Packet8q32i pmax<Packet8q32i>(const Packet8q32i& a,
264334
return _mm256_max_epi32(a.val, b.val);
265335
}
266336

337+
template <>
338+
EIGEN_STRONG_INLINE Packet16q16i pmin<Packet16q16i>(const Packet16q16i& a,
339+
const Packet16q16i& b) {
340+
return _mm256_min_epi16(a.val, b.val);
341+
}
342+
template <>
343+
EIGEN_STRONG_INLINE Packet16q16i pmax<Packet16q16i>(const Packet16q16i& a,
344+
const Packet16q16i& b) {
345+
return _mm256_max_epi16(a.val, b.val);
346+
}
347+
267348
template <>
268349
EIGEN_STRONG_INLINE Packet32q8u pmin<Packet32q8u>(const Packet32q8u& a,
269350
const Packet32q8u& b) {
@@ -304,6 +385,23 @@ EIGEN_STRONG_INLINE QInt32 predux_max<Packet8q32i>(const Packet8q32i& a) {
304385
_mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, 1)));
305386
}
306387

388+
template <>
389+
EIGEN_STRONG_INLINE QInt16 predux_min<Packet16q16i>(const Packet16q16i& a) {
390+
__m256i tmp = _mm256_min_epi16(a, _mm256_permute2f128_si256(a, a, 1));
391+
tmp =
392+
_mm256_min_epi16(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
393+
tmp = _mm256_min_epi16(tmp, _mm256_shuffle_epi32(tmp, 1));
394+
return std::min(_mm256_extract_epi16(tmp, 0), _mm256_extract_epi16(tmp, 1));
395+
}
396+
template <>
397+
EIGEN_STRONG_INLINE QInt16 predux_max<Packet16q16i>(const Packet16q16i& a) {
398+
__m256i tmp = _mm256_max_epi16(a, _mm256_permute2f128_si256(a, a, 1));
399+
tmp =
400+
_mm256_max_epi16(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
401+
tmp = _mm256_max_epi16(tmp, _mm256_shuffle_epi32(tmp, 1));
402+
return std::max(_mm256_extract_epi16(tmp, 0), _mm256_extract_epi16(tmp, 1));
403+
}
404+
307405
template <>
308406
EIGEN_STRONG_INLINE QUInt8 predux_min<Packet32q8u>(const Packet32q8u& a) {
309407
__m256i tmp = _mm256_min_epu8(a, _mm256_permute2f128_si256(a, a, 1));

third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) {
457457
std::uint32_t w =
458458
pfirst(
459459
_mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
460-
return std::min({
460+
return std::max({
461461
static_cast<std::int16_t>(w >> 16),
462462
static_cast<std::int16_t>(w)
463463
});
@@ -493,7 +493,7 @@ EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) {
493493
std::uint32_t w =
494494
pfirst(
495495
_mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
496-
return std::min({
496+
return std::max({
497497
static_cast<std::uint8_t>(w >> 24),
498498
static_cast<std::uint8_t>(w >> 16),
499499
static_cast<std::uint8_t>(w >> 8),

0 commit comments

Comments
 (0)