@@ -74,43 +74,60 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
74
74
return f32_fma (a, to_bfloat16 (b), to_bfloat16 (c));
75
75
}
76
76
77
- #ifdef __ARM_FEATURE_BF16
78
- static ET_INLINE float32x4_t
77
// Per-function target attribute enabling the ARMv8.2-A BF16 extension, so
// BF16 intrinsics (e.g. vbfdotq_f32 below) compile even when the translation
// unit as a whole is not built with +bf16. Code guarded by this attribute
// must only run after a runtime CPU-feature check (see the
// cpuinfo_has_arm_bf16() dispatch in bf16_dot_with_fp32_arith).
#define ET_TARGET_ARM_BF16_ATTRIBUTE \
  __attribute__((target("arch=armv8.2-a+bf16")))
79
// Accumulates the BF16 dot product of b and c into the float32 accumulator a
// using the hardware BFDOT instruction (vbfdotq_f32). Carries the BF16 target
// attribute, so callers are responsible for ensuring the CPU actually
// supports BF16 before reaching this code.
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}
82
- #endif // __ARM_FEATURE_BF16
83
83
84
- template <bool useBfloat16Dot>
85
- static ET_INLINE void dot_with_fp32_arith_main_inner_loop (
84
// One unrolled step of the main dot-product loop, BFDOT variant: loads one
// bfloat16x8 register pair from each input vector and accumulates their dot
// product into sum[registerPairIndex] via f32_dot_bf16. BF16-capable CPUs
// only (inherits the target attribute through f32_dot_bf16).
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  // NOTE(review): the casts assume BFloat16 is layout-compatible with the
  // compiler's __bf16 — confirm against the BFloat16 definition.
  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  sum[registerPairIndex] =
      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}
97
+
98
// One unrolled step of the main dot-product loop, portable fallback (no
// BFDOT): loads the raw uint16 bit patterns of 8 BF16 elements per input,
// then widens and fuses-multiply-adds each half via f32_fma_bf16. The low
// half goes into accumulator 2*registerPairIndex and the high half into
// 2*registerPairIndex + 1, i.e. two f32 accumulators per loaded pair.
static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  // BF16 values are loaded as opaque 16-bit integers; f32_fma_bf16 handles
  // the conversion to float32.
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(
      sum[2 * registerPairIndex],
      vget_low_u16(temp_vec1),
      vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
      sum[2 * registerPairIndex + 1],
      vget_high_u16(temp_vec1),
      vget_high_u16(temp_vec2));
}
117
+
118
// Compile-time dispatch between the BFDOT and the plain-FMA inner-loop
// variants. if constexpr ensures only the selected call is emitted; the
// target attribute is still required so the useBfdot == true instantiation
// compiles without enabling +bf16 for the whole file.
template <bool useBfdot>
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  if constexpr (useBfdot) {
    dot_with_fp32_arith_main_inner_loop_bfdot(
        vec1, vec2, sum, registerPairIndex);
  } else {
    dot_with_fp32_arith_main_inner_loop_no_bfdot(
        vec1, vec2, sum, registerPairIndex);
  }
}
116
133
@@ -126,18 +143,40 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
126
143
*tailSum = f32_fma_bf16 (*tailSum, temp_vec1, temp_vec2);
127
144
}
128
145
129
- template <typename T, bool useBfloat16Dot>
130
- float dot_with_fp32_arith (const T* vec1, const T* vec2, int64_t len) {
146
namespace {
// Compile-time loop unroller: invokes f(0), f(1), ..., f(n - 1) with no
// runtime loop. Counterpart of utils::ForcedUnroll (which it replaces here),
// but with ET_TARGET_ARM_BF16_ATTRIBUTE on operator() so the whole inlining
// chain of the BF16 path carries the target feature.
template <int n>
struct ForcedUnrollTargetBFloat16 {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    // Recurse first so calls happen in ascending index order.
    ForcedUnrollTargetBFloat16<n - 1>{}(f);
    f(n - 1);
  }
};

// Base case: a single iteration just calls f(0).
template <>
struct ForcedUnrollTargetBFloat16<1> {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    f(0);
  }
};

} // namespace
165
+
166
+ template <typename T, bool useBFloat16Dot>
167
+ ET_TARGET_ARM_BF16_ATTRIBUTE float
168
+ dot_with_fp32_arith (const T* vec1, const T* vec2, int64_t len) {
131
169
float32x4_t sum[kF32RegistersPerIteration ] = {vdupq_n_f32 (0 )};
132
170
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1 );
133
171
for (int j = 0 ; j < len_aligned; j += kF32ElementsPerIteration ) {
134
172
const auto * vec1_ = vec1 + j;
135
173
const auto * vec2_ = vec2 + j;
136
- utils::ForcedUnroll<kF32RegisterPairsPerIteration >{}(
137
- [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
138
- dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
139
- vec1_, vec2_, sum, k);
140
- });
174
+ ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration >{}(
175
+ [vec1_, vec2_, &sum](auto k)
176
+ ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
177
+ dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
178
+ vec1_, vec2_, sum, k);
179
+ });
141
180
}
142
181
auto reducedSum = reduce (sum);
143
182
// Public entry point: dot product of two BFloat16 vectors of length len,
// accumulated in float32. Selects the BFDOT implementation at runtime when
// cpuinfo reports BF16 support on this CPU; otherwise uses the
// widen-to-f32 FMA fallback. Both paths share dot_with_fp32_arith.
float bf16_dot_with_fp32_arith(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
  if (cpuinfo_has_arm_bf16()) {
    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
  } else {
    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
  }
}
0 commit comments