Skip to content

Commit b904833

Browse files
swolchokfacebook-github-bot
authored andcommitted
Add fast_hadamard_transform and fast_hadamard_transform_28N kernels (#5283)
Summary: Pull Request resolved: #5283 This adds a pair of kernels needed for SpinQuant (and QuaRot). ghstack-source-id: 242165579 exported-using-ghexport Reviewed By: kimishpatel, helunwencser Differential Revision: D60194968 fbshipit-source-id: a6b666dd3cc09dc4f5a5f14edd0a8807f7149417
1 parent 1d46d72 commit b904833

File tree

7 files changed

+654
-0
lines changed

7 files changed

+654
-0
lines changed

.lintrunner.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ exclude_patterns = [
7474
# NB: Objective-C is not supported
7575
'examples/apple/**',
7676
'examples/demo-apps/apple_ios/**',
77+
# File contains @generated
78+
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
7779
]
7880
command = [
7981
'python',
@@ -177,6 +179,8 @@ exclude_patterns = [
177179
'**/*.bat',
178180
'**/*.jpg',
179181
'**/*.jar',
182+
# File contains @generated
183+
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
180184
]
181185
command = [
182186
'python',
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SpinQuant
2+
3+
This is an implementation of the [Fast Hadamard
4+
Transform](https://en.wikipedia.org/wiki/Fast_Walsh–Hadamard_transform)
5+
as used in [SpinQuant](https://arxiv.org/abs/2405.16406) (for the R3
6+
and R4 matrices), [QuaRot](https://arxiv.org/abs/2404.00456), and
7+
[QuIP#](https://arxiv.org/pdf/2402.04396). We follow those papers'
8+
method (as implemented in
9+
https://github.com/Dao-AILab/fast-hadamard-transform/) for extending
10+
the transform to non-power-of-two input sizes. CUDA is not considered
11+
because https://github.com/Dao-AILab/fast-hadamard-transform/ is
12+
already available.
13+
14+
The intended long-term destination for this code is pytorch/ao; it is
15+
in ExecuTorch temporarily until we get a C++ dependency from ExecuTorch
16+
on torchao figured out.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
load(":targets.bzl", "define_common_targets")
2+
3+
oncall("executorch")
4+
5+
define_common_targets()
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// (c) Meta Platforms, Inc. and affiliates.
10+
#pragma once
11+
12+
#include <cassert>
13+
#include <cmath>
14+
#include <cstdint>
15+
16+
#include "fast_hadamard_transform_special.h"
17+
18+
namespace executorch {
19+
namespace internal {
20+
21+
// Square root of 1 << log2_n.
template <typename T>
T fast_sqrt_of_power_of_2(int log2_n) {
  // sqrt(2**N) == 2**(N/2). When N is even, that is an exact integer
  // power of two, which a left shift produces directly. When N is
  // odd, peel off the leftover half-power:
  //   2**(N/2) = 2**(floor(N/2)) * 2**(1/2) = 2**(floor(N/2)) * sqrt(2).
  // Both factors are cheap to compute, so no std::pow/std::exp2 needed.
  const T even_factor = T(1 << (log2_n / 2));
  const T odd_factor = (log2_n % 2) ? T(std::sqrt(2)) : T(1);
  return even_factor * odd_factor;
}
33+
34+
template <typename T>
35+
void normalize_after_fht(T* out, int log2_vec_size) {
36+
const T inv_sqrt = T(1) / fast_sqrt_of_power_of_2<T>(log2_vec_size);
37+
const int vec_size = 1 << log2_vec_size;
38+
for (int ii = 0; ii < vec_size; ++ii) {
39+
out[ii] *= inv_sqrt;
40+
}
41+
}
42+
43+
// In-place fast Walsh-Hadamard transform of vec, which must have
// length (1 << log2_vec_size), followed by orthonormal scaling.
// Classic iterative butterfly: at each stage, elements `half` apart
// are replaced by their sum and difference; the stage width doubles
// until it covers the whole vector.
template <typename T>
void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
  if (log2_vec_size == 0) {
    // A length-1 transform is the identity; nothing to normalize.
    return;
  }

  const int n = 1 << log2_vec_size;
  for (int half = 1; half < n; half *= 2) {
    for (int block = 0; block < n; block += 2 * half) {
      for (int k = block; k < block + half; ++k) {
        const T a = vec[k];
        const T b = vec[k + half];
        vec[k] = a + b;
        vec[k + half] = a - b;
      }
    }
  }

  normalize_after_fht(vec, log2_vec_size);
}
65+
66+
} // namespace internal
67+
68+
// Compute the fast Walsh-Hadamard transform
69+
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
70+
// of vec, which must be of length (1 << log2_vec_size).
71+
template <typename T>
72+
void fast_hadamard_transform(T* vec, int log2_vec_size) {
73+
internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
74+
}
75+
76+
// Like fast_hadamard_transform, but vec must be of length 28 * (1 <<
// log2_vec_size) and the transform is computed by interpreting vec as
// a (28, 1 << log2_vec_size) matrix and performing 28 FHTs, followed
// by (1 << log2_vec_size) multiplications by a particular Hadamard
// matrix of size 28x28 (see special_hadamard_code_gen.py for the
// exact matrix).
template <typename T>
void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
  const int vec_size = (1 << log2_vec_size);
  // Pass 1: a power-of-two FHT over each of the 28 contiguous rows.
  for (int row = 0; row < 28; ++row) {
    fast_hadamard_transform(vec + row * vec_size, log2_vec_size);
  }
  // Pass 2: multiply each column (strided through vec with stride
  // vec_size) by the fixed 28x28 Hadamard matrix from the generated
  // header.
  for (int col = 0; col < vec_size; ++col) {
    hadamard_mult_28_strided(vec + col, vec_size);
  }
}
92+
93+
} // namespace executorch

0 commit comments

Comments
 (0)