Commit 327a5b6

swolchok authored and facebook-github-bot committed
Quantized fast hadamard transform (#5284)
Summary: Pull Request resolved: #5284

Demonstrate that we can calculate a quantized fast Hadamard transform with integer math only, except for adjusting the scale of the result. (Not sure if there is a reason to actually commit this -- do we have a use case for quantized FHT on CPU?)

Reviewed By: kimishpatel

Differential Revision: D60866280

fbshipit-source-id: ba114f10c544b5d04c19f7d94a32043c237a2c65
1 parent b904833 commit 327a5b6

File tree

3 files changed: +94 -1 lines changed
extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fast_hadamard_transform.h"

#include <algorithm>

namespace executorch {
void fast_hadamard_transform_symmetric_quantized_s16(
    int16_t* vec,
    int log2_vec_size) {
  if (log2_vec_size == 0) {
    return;
  }

  const int vec_size = 1 << log2_vec_size;
  // We perform log2_vec_size rounds where each round's maximum output
  // is at most double the maximum input, so we can at most multiply
  // the maximum input by vec_size. Performing intermediate arithmetic
  // in 32-bit precision should prevent overflow, since 16 +
  // log2_vec_size should be much less than 32.
  auto tmp = std::make_unique<int32_t[]>(vec_size);
  std::copy(vec, vec + vec_size, tmp.get());

  // Per the function-level comment in the header, we can ignore the
  // quantization scale, so we just delegate to the usual unnormalized
  // implementation.
  // NOTE: if we need this to be fast on CPU, we can use FFHT to
  // generate fht_uint32 similar to fht_float.
  internal::fast_hadamard_transform_unnormalized_simple_impl(
      tmp.get(), log2_vec_size);

  // Normalization step: divide by sqrt(1 << log2_vec_size). Similar
  // to fast_sqrt, if N is even, then the maximum-precision way
  // to do this is to right-shift by log2_vec_size / 2. If N is odd, we
  // still do the right-shift, and then we have an extra division by
  // sqrt(2) that we perform by making use of a sufficiently accurate
  // rational approximation. (Our initial idea was to divide by sqrt(2)
  // by adjusting the quantization scale, but that would cause this
  // function to tend to increase the magnitude of the elements of
  // vec, which would result in clipping and therefore accuracy
  // loss, especially compounded over 30+ transformer layers.)
  const int log2_sqrt_vec_size = log2_vec_size / 2;
  constexpr int32_t qmin = -(1 << 15) + 1;
  constexpr int32_t qmax = -qmin;
  if (log2_vec_size % 2 != 0) {
    // 408 / 577 - 1.0 / sqrt(2) ~= 1.062e-06, which should be close enough.
    static const int32_t inv_sqrt_2_numerator = 408;
    static const int32_t inv_sqrt_2_denominator = 577;
    for (int ii = 0; ii < vec_size; ++ii) {
      const auto val_over_sqrt_vec_size =
          (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
          log2_sqrt_vec_size;
      vec[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
    }
  } else {
    for (int ii = 0; ii < vec_size; ++ii) {
      vec[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
    }
  }
  return;
}
} // namespace executorch
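
For context, a minimal sketch of how the new kernel might be exercised end to end. It is illustrative only: the input values and the 0.01f quantization scale are arbitrary choices of mine, and it assumes that both fast_hadamard_transform and fast_hadamard_transform_symmetric_quantized_s16 are reachable through the header above in the executorch namespace.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

#include "fast_hadamard_transform.h"

int main() {
  constexpr int log2_vec_size = 3; // 8 elements
  constexpr int vec_size = 1 << log2_vec_size;
  constexpr float scale = 0.01f; // example symmetric quantization scale

  // Example float data and its int16 symmetric quantization.
  std::vector<float> reference = {
      0.12f, -0.5f, 0.33f, 0.0f, 0.91f, -0.07f, 0.25f, -0.6f};
  std::vector<int16_t> quantized(vec_size);
  for (int i = 0; i < vec_size; ++i) {
    quantized[i] = static_cast<int16_t>(std::lround(reference[i] / scale));
  }

  // Normalized float FHT as a reference; integer-only FHT on the quantized copy.
  executorch::fast_hadamard_transform(reference.data(), log2_vec_size);
  executorch::fast_hadamard_transform_symmetric_quantized_s16(
      quantized.data(), log2_vec_size);

  // Dequantize with the *same* scale: the kernel never needed to know it.
  for (int i = 0; i < vec_size; ++i) {
    std::printf(
        "float: % .4f  dequantized int16: % .4f\n",
        reference[i],
        quantized[i] * scale);
  }
  return 0;
}

The two printed columns should agree up to int16 quantization error, which is the point of the summary: the kernel itself does integer arithmetic only, and the floating-point scale stays entirely with the caller.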

extension/llm/custom_ops/spinquant/fast_hadamard_transform.h

Lines changed: 23 additions & 1 deletion

@@ -12,6 +12,7 @@
 #include <cassert>
 #include <cmath>
 #include <cstdint>
+#include <memory>

 #include "fast_hadamard_transform_special.h"

@@ -41,7 +42,9 @@ void normalize_after_fht(T* out, int log2_vec_size) {
 }

 template <typename T>
-void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
+void fast_hadamard_transform_unnormalized_simple_impl(
+    T* vec,
+    int log2_vec_size) {
   if (log2_vec_size == 0) {
     return;
   }

@@ -59,7 +62,11 @@ void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
     }
     step *= 2;
   }
+}

+template <typename T>
+void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
+  fast_hadamard_transform_unnormalized_simple_impl(vec, log2_vec_size);
   normalize_after_fht(vec, log2_vec_size);
 }

@@ -73,6 +80,21 @@ void fast_hadamard_transform(T* vec, int log2_vec_size) {
   internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
 }

+// Compute a quantized fast Walsh-Hadamard transform of vec, which
+// must be of length (1 << log2_vec_size) and symmetrically quantized.
+//
+// Note that we do not need to know the quantization scale, because
+// the Fast Hadamard transform is a series of additions and
+// subtractions with a final multiplication step, and we have the
+// following trivial identities:
+//
+// scale * a + scale * b = scale * (a + b) (addition doesn't need the scale)
+// alpha * (scale * a) = scale * (alpha * a) (multiplication doesn't need the
+// scale)
+void fast_hadamard_transform_symmetric_quantized_s16(
+    int16_t* vec,
+    int log2_vec_size);
+
 // Like fast_hadamard_transform, but vec must be of length 28 * (1 <<
 // log2_vec_size) and the transform is computed by interpreting vec as
 // a (28, 1 << log2_vec_size) matrix and performing 28 FHTs, followed
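
The scale identities documented in the new header comment are easy to check numerically. Below is a self-contained sketch (not code from this commit) with an arbitrary stand-in integer scale and small example inputs, showing that an unnormalized FHT applied to scale * x equals scale * FHT(x), which is why the quantized kernel never needs to know the quantization scale.

#include <cstdint>
#include <cstdio>
#include <vector>

// Minimal unnormalized FHT butterfly, mirroring the structure of the
// simple_impl in the header; written out here only so this identity
// check is self-contained.
void unnormalized_fht(int32_t* vec, int log2_vec_size) {
  const int vec_size = 1 << log2_vec_size;
  for (int step = 1; step < vec_size; step *= 2) {
    for (int i = 0; i < vec_size; i += step * 2) {
      for (int j = i; j < i + step; ++j) {
        const int32_t a = vec[j];
        const int32_t b = vec[j + step];
        vec[j] = a + b;
        vec[j + step] = a - b;
      }
    }
  }
}

int main() {
  const int log2_vec_size = 2;
  const int32_t scale = 3; // stand-in integer scale for the example

  std::vector<int32_t> x = {7, -2, 5, 1};
  std::vector<int32_t> scaled = {21, -6, 15, 3}; // scale * x

  unnormalized_fht(x.data(), log2_vec_size);
  unnormalized_fht(scaled.data(), log2_vec_size);

  // FHT(scale * x) == scale * FHT(x): the scale factors straight out.
  for (int i = 0; i < 4; ++i) {
    std::printf("%d * %d == %d\n", scale, x[i], scaled[i]);
  }
  return 0;
}

The same commuting argument covers the final normalization in the .cpp above: the right shift and the 408/577 rational approximation of 1/sqrt(2) are fixed multiplicative factors, so they also pass through the scale unchanged.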

extension/llm/custom_ops/spinquant/targets.bzl

Lines changed: 3 additions & 0 deletions

@@ -12,5 +12,8 @@ def define_common_targets():
             "fast_hadamard_transform.h",
             "fast_hadamard_transform_special.h",
         ],
+        srcs = [
+            "fast_hadamard_transform.cpp",
+        ],
         visibility = ["@EXECUTORCH_CLIENTS"],
     )
