 #include "ynnpack/base/base.h"
 #include "ynnpack/base/bfloat16.h"
 #include "ynnpack/base/half.h"
-#include "ynnpack/base/simd/multi_vec.h"
 #include "ynnpack/base/simd/vec.h"

 namespace ynn {
@@ -162,32 +161,32 @@ YNN_ALWAYS_INLINE void vst1_lane(float* ptr, float32x2_t v) {

 }  // namespace internal

-YNN_ALWAYS_INLINE f32x4 load_aligned(const float* ptr, f32x4,
-                                     decltype(f32x4::N) = {}) {
+YNN_ALWAYS_INLINE f32x4 load_aligned(const float* ptr, decltype(f32x4::N),
+                                     f32x4 = {}) {
   return f32x4{vld1q_f32(ptr)};
 }
-YNN_ALWAYS_INLINE s32x4 load_aligned(const int32_t* ptr, s32x4,
-                                     decltype(s32x4::N) = {}) {
+YNN_ALWAYS_INLINE s32x4 load_aligned(const int32_t* ptr, decltype(s32x4::N),
+                                     s32x4 = {}) {
   return s32x4{vld1q_s32(ptr)};
 }
-YNN_ALWAYS_INLINE bf16x8 load_aligned(const bfloat16* ptr, bf16x8,
-                                      decltype(bf16x8::N) = {}) {
+YNN_ALWAYS_INLINE bf16x8 load_aligned(const bfloat16* ptr, decltype(bf16x8::N),
+                                      bf16x8 = {}) {
   return bf16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE f16x8 load_aligned(const half* ptr, f16x8,
-                                     decltype(f16x8::N) = {}) {
+YNN_ALWAYS_INLINE f16x8 load_aligned(const half* ptr, decltype(f16x8::N),
+                                     f16x8 = {}) {
   return f16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE s16x8 load_aligned(const int16_t* ptr, s16x8,
-                                     decltype(s16x8::N) = {}) {
+YNN_ALWAYS_INLINE s16x8 load_aligned(const int16_t* ptr, decltype(s16x8::N),
+                                     s16x8 = {}) {
   return s16x8{vld1q_s16(ptr)};
 }
-YNN_ALWAYS_INLINE u8x16 load_aligned(const uint8_t* ptr, u8x16,
-                                     decltype(u8x16::N) = {}) {
+YNN_ALWAYS_INLINE u8x16 load_aligned(const uint8_t* ptr, decltype(u8x16::N),
+                                     u8x16 = {}) {
   return u8x16{vld1q_u8(ptr)};
 }
-YNN_ALWAYS_INLINE s8x16 load_aligned(const int8_t* ptr, s8x16,
-                                     decltype(s8x16::N) = {}) {
+YNN_ALWAYS_INLINE s8x16 load_aligned(const int8_t* ptr, decltype(s8x16::N),
+                                     s8x16 = {}) {
   return s8x16{vld1q_s8(ptr)};
 }

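Review note: after this hunk the lane-count tag `decltype(f32x4::N)` (seemingly an `std::integral_constant`, judging by the `partial_load_lanes_x4` hunk further down) becomes the second argument, and the vector-type selector moves to a defaulted trailing position. A minimal caller sketch under that assumption; the function names and the pointer `row` are hypothetical:

```cpp
// Hypothetical call sites for the reordered overloads. Passing the tag
// value f32x4::N (or s8x16::N) selects the overload by lane count; the
// defaulted trailing type selector no longer has to be spelled out.
f32x4 first_block(const float* row) {
  return load_aligned(row, f32x4::N);
}
s8x16 first_bytes(const int8_t* row) {
  return load_aligned(row, s8x16::N);
}
```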
@@ -220,30 +219,30 @@ YNN_ALWAYS_INLINE void store_aligned(int8_t* ptr, s8x16 b,
   vst1q_s8(ptr, b.v);
 }

-YNN_ALWAYS_INLINE f32x4 load(const float* ptr, f32x4, decltype(f32x4::N) = {}) {
+YNN_ALWAYS_INLINE f32x4 load(const float* ptr, decltype(f32x4::N), f32x4 = {}) {
   return f32x4{vld1q_f32(ptr)};
 }
-YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, s32x4,
-                             decltype(s32x4::N) = {}) {
+YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, decltype(s32x4::N),
+                             s32x4 = {}) {
   return s32x4{vld1q_s32(ptr)};
 }
-YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, bf16x8,
-                              decltype(f16x8::N) = {}) {
+YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, decltype(f16x8::N),
+                              bf16x8 = {}) {
   return bf16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE f16x8 load(const half* ptr, f16x8, decltype(f16x8::N) = {}) {
+YNN_ALWAYS_INLINE f16x8 load(const half* ptr, decltype(f16x8::N), f16x8 = {}) {
   return f16x8{vld1q_u16(reinterpret_cast<const uint16_t*>(ptr))};
 }
-YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, s16x8,
-                             decltype(s16x8::N) = {}) {
+YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, decltype(s16x8::N),
+                             s16x8 = {}) {
   return s16x8{vld1q_s16(ptr)};
 }
-YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, u8x16,
-                             decltype(u8x16::N) = {}) {
+YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, decltype(u8x16::N),
+                             u8x16 = {}) {
   return u8x16{vld1q_u8(ptr)};
 }
-YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, s8x16,
-                             decltype(s8x16::N) = {}) {
+YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, decltype(s8x16::N),
+                             s8x16 = {}) {
   return s8x16{vld1q_s8(ptr)};
 }

@@ -291,7 +290,7 @@ inline vec<T, 4> partial_load_lanes_x4(const T* ptr, vec<T, 4> src, size_t n) {
     default:
       break;
   }
-  return load_aligned(lanes, vec<T, 4>{});
+  return load_aligned(lanes, std::integral_constant<size_t, 4>{});
 }
 template <typename T>
 inline void partial_store_x32x4(T* ptr, vec<T, 4> b, size_t n) {
@@ -313,10 +312,10 @@ inline void partial_store_x32x4(T* ptr, vec<T, 4> b, size_t n) {

 }  // namespace internal

-YNN_ALWAYS_INLINE f32x4 load(const float* ptr, f32x4 src, size_t n) {
+YNN_ALWAYS_INLINE f32x4 load(const float* ptr, size_t n, f32x4 src) {
   return internal::partial_load_lanes_x4(ptr, src, n);
 }
-YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, s32x4 src, size_t n) {
+YNN_ALWAYS_INLINE s32x4 load(const int32_t* ptr, size_t n, s32x4 src) {
   return internal::partial_load_lanes_x4(ptr, src, n);
 }
 YNN_ALWAYS_INLINE void store(float* ptr, f32x4 b, size_t n) {
@@ -326,13 +325,13 @@ YNN_ALWAYS_INLINE void store(int32_t* ptr, s32x4 b, size_t n) {
   internal::partial_store_x32x4(ptr, b, n);
 }

-YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, bf16x8 src, size_t n) {
+YNN_ALWAYS_INLINE bf16x8 load(const bfloat16* ptr, size_t n, bf16x8 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
-YNN_ALWAYS_INLINE f16x8 load(const half* ptr, f16x8 src, size_t n) {
+YNN_ALWAYS_INLINE f16x8 load(const half* ptr, size_t n, f16x8 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
-YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, s16x8 src, size_t n) {
+YNN_ALWAYS_INLINE s16x8 load(const int16_t* ptr, size_t n, s16x8 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
 YNN_ALWAYS_INLINE void store(bfloat16* ptr, bf16x8 value, size_t n) {
@@ -345,10 +344,10 @@ YNN_ALWAYS_INLINE void store(int16_t* ptr, s16x8 value, size_t n) {
   internal::partial_store_memcpy(ptr, value, n);
 }

-YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, u8x16 src, size_t n) {
+YNN_ALWAYS_INLINE u8x16 load(const uint8_t* ptr, size_t n, u8x16 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }
-YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, s8x16 src, size_t n) {
+YNN_ALWAYS_INLINE s8x16 load(const int8_t* ptr, size_t n, s8x16 src) {
   return internal::partial_load_memcpy(ptr, src, n);
 }

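Worth noting for callers: only the partial loads move to the `(ptr, n, src)` order; the partial stores keep `(ptr, value, n)`. A small sketch of the resulting tail-handling idiom, using only overloads visible in this diff; the function name and the semantics of the `src` lanes are assumptions:

```cpp
// Assumed semantics: load(ptr, n, src) reads n lanes from ptr and takes
// the remaining lanes from src; store(ptr, value, n) writes only n lanes.
void copy_tail_f32(float* dst, const float* src, size_t n_remaining) {
  f32x4 v = load(src, n_remaining, f32x4{});  // unread lanes come from src
  store(dst, v, n_remaining);                 // store order is unchanged
}
```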
@@ -558,10 +557,10 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(
   }};
 }

-using f32x8 = multi_vec<f32x4, 2>;
-using s32x8 = multi_vec<s32x4, 2>;
-using s16x16 = multi_vec<s16x8, 2>;
-using s32x16 = multi_vec<s32x4, 4>;
+using f32x8 = vec<float, 8>;
+using s32x8 = vec<int32_t, 8>;
+using s16x16 = vec<int16_t, 16>;
+using s32x16 = vec<int32_t, 16>;

 YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) {
   uint16x8x2_t a_u32 = vzipq_u16(vdupq_n_u16(0), a.v);
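Design note on this hunk: the 8- and 16-lane aliases stop being `multi_vec` wrappers and become plain `vec` specializations, presumably supplied by the newly included `ynnpack/base/simd/generic.inc` (see the end of this diff). A sketch of the recursive layout such generic widening commonly uses; this is an assumption for illustration, not the actual contents of generic.inc:

```cpp
// Hypothetical widened-vector shape (illustrative only): a 2N-lane
// vector stored as two N-lane halves, so each f32x8 operation can
// decompose into two NEON float32x4_t operations.
template <typename T, size_t Lanes>
struct wide_vec {
  vec<T, Lanes / 2> lo;
  vec<T, Lanes / 2> hi;
};
```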
@@ -593,31 +592,17 @@ YNN_ALWAYS_INLINE s32x8 convert(s16x8 b, int32_t) {
 }

 YNN_ALWAYS_INLINE s32x16 convert(s8x16 b, int32_t) {
-  s16x16 b_s16 = convert(b, int16_t{});
-  s32x8 lo = convert(extract<0>(b_s16, s16x8{}), int32_t{});
-  s32x8 hi = convert(extract<1>(b_s16, s16x8{}), int32_t{});
-  return {
-      extract<0>(lo, s32x4{}),
-      extract<1>(lo, s32x4{}),
-      extract<0>(hi, s32x4{}),
-      extract<1>(hi, s32x4{}),
-  };
+  return convert(convert(b, int16_t{}), int32_t{});
 }

 YNN_ALWAYS_INLINE s32x16 convert(u8x16 b, int32_t) {
-  s16x16 b_s16 = convert(b, int16_t{});
-  s32x8 lo = convert(extract<0>(b_s16, s16x8{}), int32_t{});
-  s32x8 hi = convert(extract<1>(b_s16, s16x8{}), int32_t{});
-  return {
-      extract<0>(lo, s32x4{}),
-      extract<1>(lo, s32x4{}),
-      extract<0>(hi, s32x4{}),
-      extract<1>(hi, s32x4{}),
-  };
+  return convert(convert(b, int16_t{}), int32_t{});
 }

 }  // namespace simd

 }  // namespace ynn

+#include "ynnpack/base/simd/generic.inc"  // IWYU pragma: export
+
 #endif  // XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
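With the hand-written extract/recombine bodies replaced by nested `convert` calls, the widening chain is expressed entirely through the type tags; the intermediate `s16x16` to `s32x16` step must now come from the widened-vec overloads (presumably via generic.inc). A hypothetical caller:

```cpp
// Widen 16 packed int8 values to int32 lanes; each convert step is
// selected by the tag naming the destination element type.
s32x16 widen_block(const int8_t* p) {
  s8x16 b = load(p, s8x16::N);   // full 16-lane load (overload above)
  return convert(b, int32_t{});  // s8x16 -> s16x16 -> s32x16 internally
}
```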