
Commit c83d0cd

dsharletg authored and xnnpack-bot committed
Change simd::vec to be recursively defined
This is a pretty big change; it does the following things:

- Defines a generic `vec<T, N>` that is a concatenation of two `vec<T, N/2>` vectors (see the sketch below).
- This doesn't seem like a big change, but it allows us to eliminate `multi_vec`, because `vec<T, N>` always exists and can represent multiples of the available specializations automatically.
- This helps in a bunch of ways:
  - We don't need to tell `load` what type we want; it only requires a number of lanes, and it will figure out the right type.
  - Similarly, this solves the problem of how to figure out what a vector of N elements of type T is (`multi_vec<?, ?>` vs. `vec<T, N>`). Now it's always just `vec<T, N>`.
  - This seems to make some partial loads/stores a bit faster.
- The default implementation of splitting an op into two sub-ops often does exactly what we want. This change enabled removing most of the existing `convert`, `horizontal_min`/`horizontal_max`, and a few other ops.

This also tweaks a few of the reduction implementations in a related way, and gives up to a ~4x speedup in some shapes that stress tail cases.

PiperOrigin-RevId: 846515240
1 parent: f52fbff · commit: c83d0cd
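As a rough illustration of the recursive definition described above, here is a minimal sketch with hypothetical member names (`lo`/`hi`); the real definitions live in `ynnpack/base/simd/vec.h` and are not shown in this diff:

```cpp
#include <cstddef>

// Primary template: a vec<T, N> is just the concatenation of two
// vec<T, N/2> halves. The lo/hi member names are assumptions of this
// sketch, not taken from the real header.
template <typename T, std::size_t N>
struct vec {
  static_assert(N % 2 == 0, "N must be even (a power of two overall)");
  vec<T, N / 2> lo;  // first N/2 lanes
  vec<T, N / 2> hi;  // last N/2 lanes
};

// The recursion needs a base case: a single lane here, or on a real
// target a hardware-sized specialization, e.g. on NEON:
//   template <> struct vec<float, 4> { float32x4_t v; };
template <typename T>
struct vec<T, 1> {
  T v;
};
```

With that shape, `vec<float, 8>` always exists as two `vec<float, 4>` halves, which is what `multi_vec<f32x4, 2>` expressed before this change.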


51 files changed: +1589 / -1846 lines

ynnpack/base/simd/BUILD

Lines changed: 3 additions & 4 deletions

@@ -23,18 +23,17 @@ cc_library(
 
 cc_library(
     name = "simd",
-    hdrs = [
-        "multi_vec.h",
-        "vec.h",
-    ],
+    hdrs = ["vec.h"],
     compatible_with = _COMPATIBLE_WITH,
     # These headers need to be textual because we can't compile them with the appropriate copts.
     textual_hdrs = [
         # These headers should not be directly included.
         "x86_avx2_base.h",
+        "x86_avx512f_base.h",
         "x86_avx_base.h",
         "x86_sse2_base.h",
         "x86_sse41_base.h",
+        "generic.inc",
         # For the most part, only one of these headers should be included. Multiple of these headers
         # may define the same operation and type, using a different implementation, depending on the
         # target. For example, f32x16 is provided by both avx512f as a single 512-bit vector, and

ynnpack/base/simd/arm_neon.h

Lines changed: 42 additions & 57 deletions

@@ -17,7 +17,6 @@
 #include "ynnpack/base/base.h"
 #include "ynnpack/base/bfloat16.h"
 #include "ynnpack/base/half.h"
-#include "ynnpack/base/simd/multi_vec.h"
 #include "ynnpack/base/simd/vec.h"
 
 namespace ynn {
@@ -162,32 +161,32 @@ YNN_ALWAYS_INLINE void vst1_lane(float* ptr, float32x2_t v) {
 
 }  // namespace internal
 
-YNN_ALWAYS_INLINE f32x4 load_aligned(const float* ptr, f32x4,
-                                     decltype(f32x4::N) = {}) {
+YNN_ALWAYS_INLINE f32x4 load_aligned(const float* ptr, decltype(f32x4::N),
+                                     f32x4 = {}) {
   return f32x4{vld1q_f32(ptr)};
 }
-YNN_ALWAYS_INLINE s32x4 load_aligned(const int32_t* ptr, s32x4,
-                                     decltype(s32x4::N) = {}) {
+YNN_ALWAYS_INLINE s32x4 load_aligned(const int32_t* ptr, decltype(s32x4::N),
+                                     s32x4 = {}) {
   return s32x4{vld1q_s32(ptr)};
 }
-YNN_ALWAYS_INLINE bf16x8 load_aligned(const bfloat16* ptr, bf16x8,
-                                      decltype(bf16x8::N) = {}) {
+YNN_ALWAYS_INLINE bf16x8 load_aligned(const bfloat16* ptr, decltype(bf16x8::N),
+                                      bf16x8 = {}) {
   return bf16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE f16x8 load_aligned(const half* ptr, f16x8,
-                                     decltype(f16x8::N) = {}) {
+YNN_ALWAYS_INLINE f16x8 load_aligned(const half* ptr, decltype(f16x8::N),
+                                     f16x8 = {}) {
   return f16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE s16x8 load_aligned(const int16_t* ptr, s16x8,
-                                     decltype(s16x8::N) = {}) {
+YNN_ALWAYS_INLINE s16x8 load_aligned(const int16_t* ptr, decltype(s16x8::N),
+                                     s16x8 = {}) {
   return s16x8{vld1q_s16(ptr)};
 }
-YNN_ALWAYS_INLINE u8x16 load_aligned(const uint8_t* ptr, u8x16,
-                                     decltype(u8x16::N) = {}) {
+YNN_ALWAYS_INLINE u8x16 load_aligned(const uint8_t* ptr, decltype(u8x16::N),
+                                     u8x16 = {}) {
   return u8x16{vld1q_u8(ptr)};
 }
-YNN_ALWAYS_INLINE s8x16 load_aligned(const int8_t* ptr, s8x16,
-                                     decltype(s8x16::N) = {}) {
+YNN_ALWAYS_INLINE s8x16 load_aligned(const int8_t* ptr, decltype(s8x16::N),
+                                     s8x16 = {}) {
   return s8x16{vld1q_s8(ptr)};
 }
 
@@ -220,30 +219,30 @@ YNN_ALWAYS_INLINE void store_aligned(int8_t* ptr, s8x16 b,
   vst1q_s8(ptr, b.v);
 }
 
-YNN_ALWAYS_INLINE f32x4 load(const float* ptr, f32x4, decltype(f32x4::N) = {}) {
+YNN_ALWAYS_INLINE f32x4 load(const float* ptr, decltype(f32x4::N), f32x4 = {}) {
   return f32x4{vld1q_f32(ptr)};
 }
-YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, s32x4,
-                             decltype(s32x4::N) = {}) {
+YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, decltype(s32x4::N),
+                             s32x4 = {}) {
   return s32x4{vld1q_s32(ptr)};
 }
-YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, bf16x8,
-                              decltype(f16x8::N) = {}) {
+YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, decltype(f16x8::N),
+                              bf16x8 = {}) {
   return bf16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE f16x8 load(const half* ptr, f16x8, decltype(f16x8::N) = {}) {
+YNN_ALWAYS_INLINE f16x8 load(const half* ptr, decltype(f16x8::N), f16x8 = {}) {
   return f16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, s16x8,
-                             decltype(s16x8::N) = {}) {
+YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, decltype(s16x8::N),
+                             s16x8 = {}) {
   return s16x8{vld1q_s16(ptr)};
 }
-YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, u8x16,
-                             decltype(u8x16::N) = {}) {
+YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, decltype(u8x16::N),
+                             u8x16 = {}) {
   return u8x16{vld1q_u8(ptr)};
 }
-YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, s8x16,
-                             decltype(s8x16::N) = {}) {
+YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, decltype(s8x16::N),
+                             s8x16 = {}) {
   return s8x16{vld1q_s8(ptr)};
 }
 
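A consequence visible in the hunk above: call sites now select a vector by element type and lane count alone, instead of naming the vector type. A hedged usage sketch, assuming `ynn::simd` is in scope and that `f32x4::N` is a `std::integral_constant<size_t, 4>`-style tag (as the `partial_load_lanes_x4` change below suggests):

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical call site for the reordered overloads above.
void example(const float* data) {
  using namespace ynn::simd;
  // Before: load(data, f32x4{}) -- the caller had to name the type.
  // After: the pointer's element type plus the lane count pick the
  // overload; the trailing vector-type tag is defaulted.
  f32x4 v = load(data, std::integral_constant<std::size_t, 4>{});
  (void)v;  // ...use v
}
```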
@@ -291,7 +290,7 @@ inline vec<T, 4> partial_load_lanes_x4(const T* ptr, vec<T, 4> src, size_t n) {
     default:
       break;
   }
-  return load_aligned(lanes, vec<T, 4>{});
+  return load_aligned(lanes, std::integral_constant<size_t, 4>{});
 }
 template <typename T>
 inline void partial_store_x32x4(T* ptr, vec<T, 4> b, size_t n) {
@@ -313,10 +312,10 @@ inline void partial_store_x32x4(T* ptr, vec<T, 4> b, size_t n) {
 
 }  // namespace internal
 
-YNN_ALWAYS_INLINE f32x4 load(const float* ptr, f32x4 src, size_t n) {
+YNN_ALWAYS_INLINE f32x4 load(const float* ptr, size_t n, f32x4 src) {
   return internal::partial_load_lanes_x4(ptr, src, n);
 }
-YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, s32x4 src, size_t n) {
+YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, size_t n, s32x4 src) {
   return internal::partial_load_lanes_x4(ptr, src, n);
 }
 YNN_ALWAYS_INLINE void store(float* ptr, f32x4 b, size_t n) {

@@ -326,13 +325,13 @@ YNN_ALWAYS_INLINE void store(int32_t* ptr, s32x4 b, size_t n) {
   internal::partial_store_x32x4(ptr, b, n);
 }
 
-YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, bf16x8 src, size_t n) {
+YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, size_t n, bf16x8 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
-YNN_ALWAYS_INLINE f16x8 load(const half* ptr, f16x8 src, size_t n) {
+YNN_ALWAYS_INLINE f16x8 load(const half* ptr, size_t n, f16x8 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
-YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, s16x8 src, size_t n) {
+YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, size_t n, s16x8 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
 YNN_ALWAYS_INLINE void store(bfloat16* ptr, bf16x8 value, size_t n) {

@@ -345,10 +344,10 @@ YNN_ALWAYS_INLINE void store(int16_t* ptr, s16x8 value, size_t n) {
   internal::partial_store_memcpy(ptr, value, n);
 }
 
-YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, u8x16 src, size_t n) {
+YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, size_t n, u8x16 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
-YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, s8x16 src, size_t n) {
+YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, size_t n, s8x16 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
 
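The partial overloads above follow the same convention: pointer, then the runtime count, then a source vector supplying the lanes beyond `n`. A hedged sketch of tail handling with these overloads (the framing is illustrative, not from this diff):

```cpp
#include <cstddef>

// Illustrative tail handling: n < 4 floats remain.
void copy_tail(const float* src, float* dst, std::size_t n) {
  using namespace ynn::simd;
  // Lanes past n come from the value-initialized f32x4 source argument.
  f32x4 v = load(src, n, f32x4{});
  store(dst, v, n);  // partial store of the first n lanes
}
```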
@@ -558,10 +557,10 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(
   }};
 }
 
-using f32x8 = multi_vec<f32x4, 2>;
-using s32x8 = multi_vec<s32x4, 2>;
-using s16x16 = multi_vec<s16x8, 2>;
-using s32x16 = multi_vec<s32x4, 4>;
+using f32x8 = vec<float, 8>;
+using s32x8 = vec<int32_t, 8>;
+using s16x16 = vec<int16_t, 16>;
+using s32x16 = vec<int32_t, 16>;
 
 YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) {
   uint16x8x2_t a_u32 = vzipq_u16(vdupq_n_u16(0), a.v);

@@ -593,31 +592,17 @@ YNN_ALWAYS_INLINE s32x8 convert(s16x8 b, int32_t) {
 }
 
 YNN_ALWAYS_INLINE s32x16 convert(s8x16 b, int32_t) {
-  s16x16 b_s16 = convert(b, int16_t{});
-  s32x8 lo = convert(extract<0>(b_s16, s16x8{}), int32_t{});
-  s32x8 hi = convert(extract<1>(b_s16, s16x8{}), int32_t{});
-  return {
-      extract<0>(lo, s32x4{}),
-      extract<1>(lo, s32x4{}),
-      extract<0>(hi, s32x4{}),
-      extract<1>(hi, s32x4{}),
-  };
+  return convert(convert(b, int16_t{}), int32_t{});
 }
 
 YNN_ALWAYS_INLINE s32x16 convert(u8x16 b, int32_t) {
-  s16x16 b_s16 = convert(b, int16_t{});
-  s32x8 lo = convert(extract<0>(b_s16, s16x8{}), int32_t{});
-  s32x8 hi = convert(extract<1>(b_s16, s16x8{}), int32_t{});
-  return {
-      extract<0>(lo, s32x4{}),
-      extract<1>(lo, s32x4{}),
-      extract<0>(hi, s32x4{}),
-      extract<1>(hi, s32x4{}),
-  };
+  return convert(convert(b, int16_t{}), int32_t{});
 }
 
 }  // namespace simd
 
 }  // namespace ynn
 
+#include "ynnpack/base/simd/generic.inc"  // IWYU pragma: export
+
 #endif  // XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
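The collapsed `convert` bodies above rely on a generic splitting fallback for the recursive `vec<T, N>`: an op with no full-width specialization is applied to each half. The actual contents of `generic.inc` are not part of this diff; a hedged sketch of what such a fallback might look like, reusing the hypothetical `lo`/`hi` layout from the earlier sketch:

```cpp
// Hypothetical splitting fallback in the spirit of generic.inc; the
// real file may differ.
template <typename T, std::size_t N, typename U>
YNN_ALWAYS_INLINE vec<U, N> convert(vec<T, N> a, U) {
  // Recurse on each half; the recursion bottoms out at the
  // hardware-specific overloads, e.g. convert(s16x8, int32_t) above.
  return {convert(a.lo, U{}), convert(a.hi, U{})};
}
```

Under that scheme, `convert(convert(b, int16_t{}), int32_t{})` behaves as before: the inner call produces an `s16x16`, and the fallback splits the outer call into two `convert(s16x8, int32_t)` calls, reassembling the same `s32x16` that the removed hand-written extracts built.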

ynnpack/base/simd/arm_neonfma.h

Lines changed: 2 additions & 0 deletions

@@ -23,4 +23,6 @@ YNN_ALWAYS_INLINE f32x4 fma(f32x4 a, f32x4 b, f32x4 acc) {
 
 }  // namespace ynn
 
+#include "ynnpack/base/simd/generic.inc"  // IWYU pragma: export
+
 #endif  // XNNPACK_YNNPACK_BASE_SIMD_ARM_NEONFMA_H_

ynnpack/base/simd/arm_neonfp16.h

Lines changed: 4 additions & 2 deletions

@@ -9,13 +9,13 @@
 #include <arm_neon.h>
 
 #include "ynnpack/base/simd/arm_neon.h"  // IWYU pragma: export
-#include "ynnpack/base/simd/multi_vec.h"
+#include "ynnpack/base/simd/vec.h"
 
 namespace ynn {
 
 namespace simd {
 
-using f32x8 = multi_vec<f32x4, 2>;
+using f32x8 = vec<float, 8>;
 
 YNN_ALWAYS_INLINE f32x8 convert(f16x8 a, float) {
   return {

@@ -28,4 +28,6 @@ YNN_ALWAYS_INLINE f32x8 convert(f16x8 a, float) {
 
 }  // namespace ynn
 
+#include "ynnpack/base/simd/generic.inc"  // IWYU pragma: export
+
 #endif  // XNNPACK_YNNPACK_BASE_SIMD_ARM_NEONFP16_H_
