Skip to content

Commit 25939d2

Browse files
committed
FFHT enhancements to fast hadamard transform kernels
Pull Request resolved: #5290 Use FFHT to speed up Fast Hadamard Transform on CPU. fast_hadamard_test was delayed to here because it was a source for a reference implementation. ghstack-source-id: 242165582 @exported-using-ghexport Differential Revision: [D61029709](https://our.internmc.facebook.com/intern/diff/D61029709/)
1 parent ad3948d commit 25939d2

File tree

11 files changed

+354
-58
lines changed

11 files changed

+354
-58
lines changed

extension/llm/custom_ops/spinquant/fast_hadamard_transform.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <cstdint>
1515
#include <memory>
1616

17+
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h>
18+
1719
#include "fast_hadamard_transform_special.h"
1820

1921
namespace executorch {
@@ -41,10 +43,22 @@ void normalize_after_fht(T* out, int log2_vec_size) {
4143
}
4244
}
4345

46+
// Float-specialized entry point: run the FFHT kernel on vec (length
// 1 << log2_vec_size), then normalize the result.
inline void fast_hadamard_transform_ffht_impl(float* vec, int log2_vec_size) {
  // Degenerate sizes (a single element, or an invalid negative size)
  // leave vec untouched.
  if (log2_vec_size > 0) {
    fht_float(vec, log2_vec_size);
    normalize_after_fht(vec, log2_vec_size);
  }
}
54+
4455
template <typename T>
4556
void fast_hadamard_transform_unnormalized_simple_impl(
4657
T* vec,
4758
int log2_vec_size) {
59+
// NOTE: If you're here because you're profiling a model and this is
60+
// slow, consider updating FFHT to generate efficient assembly for
61+
// your data type!
4862
if (log2_vec_size == 0) {
4963
return;
5064
}
@@ -77,7 +91,11 @@ void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
7791
// of vec, which must be of length (1 << log2_vec_size).
7892
template <typename T>
7993
void fast_hadamard_transform(T* vec, int log2_vec_size) {
80-
internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
94+
if constexpr (std::is_same_v<T, float>) {
95+
internal::fast_hadamard_transform_ffht_impl(vec, log2_vec_size);
96+
} else {
97+
internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
98+
}
8199
}
82100

83101
// Compute a quantized fast Walsh-Hadamard transform of vec, which

extension/llm/custom_ops/spinquant/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ def define_common_targets():
1515
srcs = [
1616
"fast_hadamard_transform.cpp",
1717
],
18+
exported_deps = [
19+
"//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:fht",
20+
],
1821
visibility = ["@EXECUTORCH_CLIENTS"],
1922
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Buck entry point: all targets for this directory are defined in targets.bzl.
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cmath>
10+
#include <iostream>
11+
#include <random>
12+
13+
#include <gtest/gtest.h>
14+
15+
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
16+
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h>
17+
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/dumb_fht.h>
18+
19+
namespace {
20+
// Reference implementation: unnormalized dumb_fht followed by the
// 1/sqrt(n) normalization. n is assumed to be a power of two — TODO
// confirm; all callers in this file pass (1 << k) sizes.
void reference_fht_impl(float* buf, int n) {
  // Compute log2(n) with integer arithmetic. The original code called
  // std::log2<int>(n), but std::log2 is not specified as a template, so
  // the explicit template argument relies on a libstdc++ implementation
  // detail and does not compile portably.
  int log2_n = 0;
  while ((1 << (log2_n + 1)) <= n) {
    ++log2_n;
  }
  dumb_fht(buf, log2_n);
  const auto root_n = std::sqrt(n);
  for (int ii = 0; ii < n; ++ii) {
    buf[ii] /= root_n;
  }
}
27+
28+
// Alternate implementation of fast_hadamard_transform_28N to mutation
29+
// test against. Benchmarking suggests this one is slower, which is
30+
// why it's in the test and the strided implementation is in the
31+
// header.
32+
template <typename T>
33+
void fast_hadamard_transform_28N_with_transpose(T* vec, int log2_vec_size) {
34+
const int vec_size = (1 << log2_vec_size);
35+
for (int ii = 0; ii < 28; ++ii) {
36+
executorch::fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size);
37+
}
38+
std::unique_ptr<T[]> transposed = std::make_unique<T[]>(28 * vec_size);
39+
for (int ii = 0; ii < 28; ++ii) {
40+
for (int jj = 0; jj < vec_size; ++jj) {
41+
transposed[jj * 28 + ii] = vec[ii * vec_size + jj];
42+
}
43+
}
44+
for (int ii = 0; ii < vec_size; ++ii) {
45+
hadamard_mult_28(&transposed[ii * 28]);
46+
}
47+
for (int jj = 0; jj < vec_size; ++jj) {
48+
for (int ii = 0; ii < 28; ++ii) {
49+
vec[ii * vec_size + jj] = transposed[jj * 28 + ii];
50+
}
51+
}
52+
}
53+
54+
// Returns `howMany` samples from a standard normal distribution,
// freshly seeded from std::random_device on each call.
std::vector<float> randomFloats(int howMany) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(howMany);
  // std::generate avoids the original's signed/unsigned comparison
  // between an int index and data.size().
  std::generate(data.begin(), data.end(), [&] { return dist(gen); });
  return data;
}
64+
} // namespace
65+
66+
TEST(FastHadamardTransformTest, SingleElement) {
  // The Hadamard transform of a single element is the identity.
  float buf[1] = {42};
  executorch::fast_hadamard_transform(buf, 0);
  EXPECT_EQ(buf[0], 42);
}
72+
73+
TEST(FastHadamardTransformTest, LargerInput) {
  // Validate the production transform against the dumb_fht-based
  // reference on a 2^12-element random input.
  const std::vector<float> input = randomFloats(4096);

  std::vector<float> expected(input);
  reference_fht_impl(expected.data(), expected.size());

  std::vector<float> actual(input);
  executorch::fast_hadamard_transform(actual.data(), 12);

  for (size_t ii = 0; ii < expected.size(); ++ii) {
    EXPECT_FLOAT_EQ(actual[ii], expected[ii]);
  }
}
86+
87+
TEST(FastHadamardTransform28NTest, Basic) {
  // 28 rows of 1024 elements: validate the strided 28N kernel against
  // the transpose-based alternate implementation above.
  const std::vector<float> input = randomFloats(1024 * 28);

  std::vector<float> expected(input);
  fast_hadamard_transform_28N_with_transpose(expected.data(), 10);

  std::vector<float> actual(input);
  executorch::fast_hadamard_transform_28N(actual.data(), 10);

  for (size_t ii = 0; ii < actual.size(); ++ii) {
    EXPECT_FLOAT_EQ(actual[ii], expected[ii]);
  }
}
100+
101+
namespace {
102+
// Symmetric int16 quantization range: [-32767, 32767].
constexpr int32_t qmin = -(1 << 15) + 1;
constexpr int32_t qmax = -qmin;

// Quantizes x with the given scale, saturating to [qmin, qmax].
int16_t quantize(float x, float scale) {
  float scaled = x / scale;
  // XXX: Supposed to round ties to even, but this is just test code.
  // NOTE: std::lround is not specified as a template; the original
  // std::lround<int32_t>(scaled) only compiled via a libstdc++
  // implementation detail. Call it normally and narrow the long result.
  int32_t scaled_int =
      std::clamp(static_cast<int32_t>(std::lround(scaled)), qmin, qmax);
  return static_cast<int16_t>(scaled_int);
}
112+
113+
template <typename T>
114+
std::vector<T> quantize(const std::vector<float>& data, float scale) {
115+
std::vector<T> result;
116+
result.reserve(data.size());
117+
for (const float unquant : data) {
118+
result.push_back(quantize(unquant, scale));
119+
}
120+
return result;
121+
}
122+
123+
template <typename T>
124+
std::pair<std::vector<T>, float> quantize(const std::vector<float>& data) {
125+
auto [minIt, maxIt] = std::minmax_element(data.begin(), data.end());
126+
float scale = (*maxIt - *minIt) / (qmax - qmin);
127+
return {quantize<T>(data, scale), scale};
128+
}
129+
130+
// Maps a quantized value back into float space.
template <typename T>
float dequantize(T x, float scale) {
  return scale * x;
}
134+
135+
template <typename T>
136+
std::vector<float> dequantize(const std::vector<T>& data, float scale) {
137+
static_assert(!std::is_same_v<T, float>);
138+
std::vector<float> result;
139+
result.reserve(data.size());
140+
for (const T quant : data) {
141+
result.push_back(dequantize(quant, scale));
142+
}
143+
return result;
144+
}
145+
146+
// Asserts |a - b| <= atol + rtol * |b| (numpy.isclose-style tolerance).
// Macro arguments are parenthesized so compound expressions expand safely.
#define EXPECT_CLOSE_IMPL(a, b, atol, rtol)                      \
  EXPECT_LE(std::abs((a) - (b)), (atol) + (rtol)*std::abs((b))) \
      << "a: " << (a) << ", b: " << (b)
#define EXPECT_CLOSE(a, b) EXPECT_CLOSE_IMPL(a, b, 2e-4, 1e-4)
150+
151+
// End-to-end check of the symmetric-quantized s16 FHT: quantize random
// floats, compute a float reference transform on the dequantized values,
// and require the quantized kernel to match within EXPECT_CLOSE tolerance.
void testQuantizedFastHadamardTransform(int logN) {
  std::vector<float> data = randomFloats(1 << logN);

  auto [qdata, scale] = quantize<int16_t>(data);

  // Reference path: dequantize, run the float reference FHT, re-quantize
  // with the same scale.
  auto expected_unquant = dequantize(qdata, scale);
  reference_fht_impl(expected_unquant.data(), expected_unquant.size());
  auto expected = quantize<int16_t>(expected_unquant, scale);

  auto actual = qdata;
  executorch::fast_hadamard_transform_symmetric_quantized_s16(
      actual.data(), logN);

  // Compare in dequantized (float) space so the tolerance is meaningful
  // regardless of the chosen scale.
  for (int ii = 0; ii < expected.size(); ++ii) {
    EXPECT_CLOSE(
        dequantize(actual[ii], scale), dequantize(expected[ii], scale));
  }
}
169+
170+
} // namespace
171+
172+
// Even power-of-two size.
TEST(QuantizedFastHadamardTransformTest, Basic) {
  testQuantizedFastHadamardTransform(12); // 4096
}
175+
176+
// Odd log2 size exercises any pairing/merging step that differs from the
// even case.
TEST(QuantizedFastHadamardTransformTest, OddLogN) {
  testQuantizedFastHadamardTransform(11); // 2048
}
179+
180+
// End-to-end check of the symmetric-quantized s16 28N kernel against the
// transpose-based float reference, compared in dequantized space.
TEST(QuantizedFastHadamardTransform28NTest, Basic) {
  std::vector<float> data = randomFloats(1024 * 28);

  auto [qdata, scale] = quantize<int16_t>(data);

  auto expected_unquant = dequantize(qdata, scale);
  fast_hadamard_transform_28N_with_transpose(expected_unquant.data(), 10);
  auto expected = quantize<int16_t>(expected_unquant, scale);

  auto actual = qdata;
  executorch::fast_hadamard_transform_symmetric_quantized_s16_28N(
      actual.data(), 10);

  // NOTE: the original looped a std::cerr dump of every element (28672
  // lines per run) — leftover debug logging, removed. EXPECT_CLOSE already
  // prints both values on failure.
  for (size_t ii = 0; ii < expected.size(); ++ii) {
    EXPECT_CLOSE(
        dequantize(actual[ii], scale), dequantize(expected[ii], scale));
  }
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # gtest suite for the float and quantized fast Hadamard transforms;
    # :dumb_fht provides the reference implementation the tests compare
    # against.
    runtime.cxx_test(
        name = "fast_hadamard_transform_test",
        srcs = ["fast_hadamard_transform_test.cpp"],
        headers = ["fast_hadamard_transform_special_unstrided_cpu.h"],
        deps = [
            "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
            "//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
        ],
    )

extension/llm/custom_ops/spinquant/third-party/FFHT/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ CFLAGS = -O3 -march=native -std=c99 -pedantic -Wall -Wextra -Wshadow -Wpointer-a
33

44
all: test_float test_double fast_copy.o fht.o
55

6-
OBJ := fast_copy.o fht.o
6+
OBJ := dumb_fht.o fast_copy.o fht.o
77

88
%.o: %.c
99
$(CC) $< -o $@ -c $(CFLAGS)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Buck entry point: all targets for this directory are defined in targets.bzl.
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "dumb_fht.h"
2+
3+
/* In-place, unnormalized fast Hadamard transform of buf, which must hold
 * (1 << log_n) floats. Straightforward butterfly implementation intended
 * as a readable reference, not a fast one. */
void dumb_fht(float* buf, int log_n) {
  const int n = 1 << log_n;
  /* half doubles each pass: 1, 2, 4, ... — one pass per bit of log_n. */
  for (int half = 1; half < n; half <<= 1) {
    const int stride = half << 1;
    for (int base = 0; base < n; base += stride) {
      for (int off = 0; off < half; ++off) {
        const float a = buf[base + off];
        const float b = buf[base + off + half];
        buf[base + off] = a + b;
        buf[base + off + half] = a - b;
      }
    }
  }
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef DUMB_FHT_H
#define DUMB_FHT_H

#ifdef __cplusplus
extern "C" {
#endif

/* In-place, unnormalized fast Hadamard transform reference
 * implementation. buf must hold (1 << log_n) floats. */
void dumb_fht(float* buf, int log_n);

#ifdef __cplusplus
} // extern "C"
#endif

#endif /* DUMB_FHT_H */
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # Simple reference FHT implementation used by tests.
    runtime.cxx_library(
        name = "dumb_fht",
        srcs = ["dumb_fht.c"],
        exported_headers = ["dumb_fht.h"],
        visibility = ["@EXECUTORCH_CLIENTS"],
    )

    # Optimized FHT kernels; the source file is selected by target CPU
    # (AVX variant by default, NEON variant on arm64).
    runtime.cxx_library(
        name = "fht",
        srcs = select({
            "DEFAULT": ["fht_avx.c"],
            "ovr_config//cpu:arm64": ["fht_neon.c"],
        }),
        exported_headers = ["fht.h"],
        visibility = ["@EXECUTORCH_CLIENTS"],
    )

    # Float test binary; links both :fht and :dumb_fht (presumably to
    # cross-check them — confirm in test_float.c).
    runtime.cxx_binary(
        name = "test_float",
        srcs = ["test_float.c"],
        deps = [
            ":dumb_fht",
            ":fht",
        ],
    )

0 commit comments

Comments
 (0)