Skip to content

Commit ab75531

Browse files
swolchokfacebook-github-bot
authored and committed
add quantized fast_hadamard_transform_28N (#5285)
Summary: Pull Request resolved: #5285 ghstack-source-id: 242230779 exported-using-ghexport Reviewed By: kimishpatel Differential Revision: D60943029 fbshipit-source-id: a961d24508e7b9a87bdc65dc7ad3ae59dccc250e
1 parent 327a5b6 commit ab75531

File tree

2 files changed

+69
-29
lines changed

2 files changed

+69
-29
lines changed

extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,43 @@
1111
#include <algorithm>
1212

1313
namespace executorch {
14+
namespace {
// Normalization step: divide by sqrt(1 << log2_vec_size). Similar
// to fast_sqrt above, if N is even, then the maximum-precision way
// to do this is right-shift by log2_vec_size / 2. If N is odd, we
// still do the right-shift, and then we have an extra division by
// sqrt(2) that we perform by making use of a sufficiently accurate
// rational approximation. Our initial idea was to divide by sqrt(2)
// by adjusting the quantization scale, but that would cause this
// function to tend to increase the magnitude of the elements of
// vec, which would result in clipping and therefore accuracy
// loss, especially compounded over 30+ transformer layers.
//
// Reads vec_size 32-bit transform outputs from tmp, scales each one
// down, and writes the clamped int16 results to out.
void quantized_normalize_after_fht(
    const int32_t* tmp,
    int16_t* out,
    int log2_vec_size,
    int vec_size) {
  const int log2_sqrt_vec_size = log2_vec_size / 2;
  // Clamp to a range symmetric around zero: -32768 is excluded so that
  // qmax == -qmin.
  constexpr int32_t qmin = -(1 << 15) + 1;
  constexpr int32_t qmax = -qmin;
  if (log2_vec_size % 2 != 0) {
    // 408 / 577 - 1.0 / sqrt(2) ~= -1.062e-06, which should be close enough.
    constexpr int64_t inv_sqrt_2_numerator = 408;
    constexpr int64_t inv_sqrt_2_denominator = 577;
    for (int ii = 0; ii < vec_size; ++ii) {
      // Widen to 64 bits before multiplying: tmp[ii] * 408 can overflow
      // int32_t (signed overflow is UB) for large-magnitude transform
      // outputs.
      const int64_t val_over_sqrt_vec_size =
          (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
          log2_sqrt_vec_size;
      out[ii] = static_cast<int16_t>(
          std::clamp<int64_t>(val_over_sqrt_vec_size, qmin, qmax));
    }
  } else {
    for (int ii = 0; ii < vec_size; ++ii) {
      // NOTE: right-shifting a negative value is arithmetic shift on the
      // targets we care about (and is guaranteed in C++20).
      out[ii] = static_cast<int16_t>(
          std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax));
    }
  }
}
} // namespace
50+
1451
void fast_hadamard_transform_symmetric_quantized_s16(
1552
int16_t* vec,
1653
int log2_vec_size) {
@@ -27,42 +64,38 @@ void fast_hadamard_transform_symmetric_quantized_s16(
2764
auto tmp = std::make_unique<int32_t[]>(vec_size);
2865
std::copy(vec, vec + vec_size, tmp.get());
2966

30-
// Per the function-level comment in the header, we can ignore the
67+
// Per the function-level comment above, we can ignore the
3168
// quantization scale, so we just delegate to the usual unnormalized
3269
// implementation.
3370
// NOTE: if we need this to be fast on CPU, we can use FFHT to
3471
// generate fht_uint32 similar to fht_float.
3572
internal::fast_hadamard_transform_unnormalized_simple_impl(
3673
tmp.get(), log2_vec_size);
3774

38-
// Normalization step: divide by sqrt(1 << log2_vec_size). Similar
39-
// to fast_sqrt, if N is even, then the maximum-precision way
40-
// to do this is right-shift by log2_vec_size / 2. If N is odd, we
41-
// still do the right-shift, and then we have an extra division by
42-
// sqrt(2) that we perform by making use of a sufficiently accurate
43-
// rational approximation. (Our initial idea was to divide by sqrt(2)
44-
// by adjusting the quantization scale, but that would cause this
45-
// function to tend to increase the magnitude of the elements of
46-
vec, which would result in clipping and therefore accuracy
47-
// loss, especially compounded over 30+ transformer layers.)
48-
const int log2_sqrt_vec_size = log2_vec_size / 2;
49-
constexpr int32_t qmin = -(1 << 15) + 1;
50-
constexpr int32_t qmax = -qmin;
51-
if (log2_vec_size % 2 != 0) {
52-
// 408 / 577 - 1.0 / sqrt(2) ~= -1.062e-06, which should be close enough.
53-
static const int32_t inv_sqrt_2_numerator = 408;
54-
static const int32_t inv_sqrt_2_denominator = 577;
55-
for (int ii = 0; ii < vec_size; ++ii) {
56-
const auto val_over_sqrt_vec_size =
57-
(tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
58-
log2_sqrt_vec_size;
59-
vec[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
60-
}
61-
} else {
62-
for (int ii = 0; ii < vec_size; ++ii) {
63-
vec[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
64-
}
75+
quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size);
76+
}
77+
78+
void fast_hadamard_transform_symmetric_quantized_s16_28N(
79+
int16_t* vec,
80+
int log2_vec_size) {
81+
if (log2_vec_size == 0) {
82+
return;
6583
}
66-
return;
84+
const int vec_size = (1 << log2_vec_size);
85+
86+
auto tmp = std::make_unique<int32_t[]>(vec_size * 28);
87+
std::copy(vec, vec + vec_size * 28, tmp.get());
88+
89+
for (int ii = 0; ii < 28; ++ii) {
90+
internal::fast_hadamard_transform_unnormalized_simple_impl(
91+
&tmp[ii * vec_size], log2_vec_size);
92+
}
93+
94+
for (int ii = 0; ii < vec_size; ++ii) {
95+
hadamard_mult_28_strided(&tmp[ii], vec_size);
96+
}
97+
98+
quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size * 28);
6799
}
100+
68101
} // namespace executorch

extension/llm/custom_ops/spinquant/fast_hadamard_transform.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,11 @@ void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
112112
}
113113
}
114114

115+
// We don't need the quantization scale; see the function-level
// comment on fast_hadamard_transform_symmetric_quantized_s16 for
// details.
//
// Transforms vec in place; vec must point to 28 * (1 << log2_vec_size)
// int16_t elements.
void fast_hadamard_transform_symmetric_quantized_s16_28N(
    int16_t* vec,
    int log2_vec_size);
121+
115122
} // namespace executorch

0 commit comments

Comments
 (0)