Skip to content

Commit 4218d45

Browse files
swolchokfacebook-github-bot
authored andcommitted
FFHT enhancements to fast hadamard transform kernels (#5290)
Summary: Pull Request resolved: #5290 Use FFHT to speed up Fast Hadamard Transform on CPU. fast_hadamard_test was delayed to here because it was a source for a reference implementation. ghstack-source-id: 242230781 exported-using-ghexport Reviewed By: mergennachin Differential Revision: D61029709 fbshipit-source-id: e62b2c98a921b9512bfda3e490dd807e8ab0d291
1 parent eedc38a commit 4218d45

File tree

12 files changed

+497
-58
lines changed

12 files changed

+497
-58
lines changed

extension/llm/custom_ops/spinquant/fast_hadamard_transform.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <cstdint>
#include <memory>
#include <type_traits>

#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h>

#include "fast_hadamard_transform_special.h"
1820

1921
namespace executorch {
@@ -45,6 +47,9 @@ template <typename T>
4547
void fast_hadamard_transform_unnormalized_simple_impl(
4648
T* vec,
4749
int log2_vec_size) {
50+
// NOTE: If you're here because you're profiling a model and this is
51+
// slow, consider updating FFHT to generate efficient assembly for
52+
// your data type!
4853
if (log2_vec_size == 0) {
4954
return;
5055
}
@@ -70,14 +75,31 @@ void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
7075
normalize_after_fht(vec, log2_vec_size);
7176
}
7277

78+
inline void fast_hadamard_transform_ffht_impl(float* vec, int log2_vec_size) {
79+
#if defined(__aarch64__) || defined(__x86_64__)
80+
if (log2_vec_size <= 0) {
81+
return;
82+
}
83+
84+
fht_float(vec, log2_vec_size);
85+
normalize_after_fht(vec, log2_vec_size);
86+
#else
87+
fast_hadamard_transform_simple_impl(vec, log2_vec_size);
88+
#endif
89+
}
90+
7391
} // namespace internal
7492

7593
// Compute the fast Walsh-Hadamard transform
7694
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
7795
// of vec, which must be of length (1 << log2_vec_size).
7896
template <typename T>
7997
void fast_hadamard_transform(T* vec, int log2_vec_size) {
80-
internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
98+
if constexpr (std::is_same_v<T, float>) {
99+
internal::fast_hadamard_transform_ffht_impl(vec, log2_vec_size);
100+
} else {
101+
internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
102+
}
81103
}
82104

83105
// Compute a quantized fast Walsh-Hadamard transform of vec, which

extension/llm/custom_ops/spinquant/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ def define_common_targets():
1515
srcs = [
1616
"fast_hadamard_transform.cpp",
1717
],
18+
exported_deps = [
19+
"//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:fht",
20+
],
1821
visibility = ["@EXECUTORCH_CLIENTS"],
1922
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Build entry point for this directory: loads the shared target definitions
# from the sibling targets.bzl, declares the owning oncall, and instantiates
# the targets.
load(":targets.bzl", "define_common_targets")
2+
3+
oncall("executorch")
4+
5+
define_common_targets()

extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h

Lines changed: 137 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <algorithm>
10+
#include <cmath>
11+
#include <iostream>
12+
#include <random>
13+
14+
#include <gtest/gtest.h>
15+
16+
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
17+
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h>
18+
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/dumb_fht.h>
19+
20+
namespace {
21+
// Ground-truth transform for the tests: runs dumb_fht over buf and then
// divides every element by sqrt(n). n is the number of elements in buf and
// is expected to be a power of two; n <= 0 leaves buf untouched.
void reference_fht_impl(float* buf, int n) {
  // Integer log2. The previous std::log2<int>(n) call passed explicit
  // template arguments to a standard-library function (the standard does not
  // promise that overload is a template) and then relied on an implicit
  // double -> int truncation of the result.
  int log2_n = 0;
  for (int remaining = n; remaining > 1; remaining >>= 1) {
    ++log2_n;
  }
  dumb_fht(buf, log2_n);

  // Keep the divisor in double, as before, so results are bit-identical.
  const double root_n = std::sqrt(static_cast<double>(n));
  for (int ii = 0; ii < n; ++ii) {
    buf[ii] /= root_n;
  }
}
28+
29+
// Alternate implementation of fast_hadamard_transform_28N to mutation
30+
// test against. Benchmarking suggests this one is slower, which is
31+
// why it's in the test and the strided implementation is in the
32+
// header.
33+
template <typename T>
34+
void fast_hadamard_transform_28N_with_transpose(T* vec, int log2_vec_size) {
35+
const int vec_size = (1 << log2_vec_size);
36+
for (int ii = 0; ii < 28; ++ii) {
37+
executorch::fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size);
38+
}
39+
std::unique_ptr<T[]> transposed = std::make_unique<T[]>(28 * vec_size);
40+
for (int ii = 0; ii < 28; ++ii) {
41+
for (int jj = 0; jj < vec_size; ++jj) {
42+
transposed[jj * 28 + ii] = vec[ii * vec_size + jj];
43+
}
44+
}
45+
for (int ii = 0; ii < vec_size; ++ii) {
46+
hadamard_mult_28(&transposed[ii * 28]);
47+
}
48+
for (int jj = 0; jj < vec_size; ++jj) {
49+
for (int ii = 0; ii < 28; ++ii) {
50+
vec[ii * vec_size + jj] = transposed[jj * 28 + ii];
51+
}
52+
}
53+
}
54+
55+
// Returns howMany samples drawn from a default-constructed
// std::normal_distribution<float>. The generator is seeded from
// std::random_device, so every call produces different data.
std::vector<float> randomFloats(int howMany) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(howMany);
  // Range-for instead of `int ii < data.size()`, which compared a signed
  // index against an unsigned size.
  for (float& sample : data) {
    sample = dist(gen);
  }
  return data;
}
65+
} // namespace
66+
67+
TEST(FastHadamardTransformTest, SingleElement) {
  // A length-1 transform (log2 size 0) must leave its only element as-is.
  float data[1] = {42};
  float* const only = data;
  executorch::fast_hadamard_transform(only, 0);
  EXPECT_EQ(*only, 42);
}
73+
74+
// Checks the production transform against the reference implementation on
// random data of 4096 elements.
TEST(FastHadamardTransformTest, LargerInput) {
  // Derive the element count from the log2 size so the two can never drift
  // apart (they used to be the separate literals 4096 and 12).
  constexpr int kLog2Size = 12;
  constexpr int kSize = 1 << kLog2Size; // 4096
  const std::vector<float> data = randomFloats(kSize);

  auto expected = data;
  reference_fht_impl(expected.data(), kSize);

  auto actual = data;
  executorch::fast_hadamard_transform(actual.data(), kLog2Size);

  ASSERT_EQ(actual.size(), expected.size());
  // size_t index: the old `int ii < expected.size()` mixed signedness.
  for (size_t ii = 0; ii < expected.size(); ++ii) {
    EXPECT_FLOAT_EQ(actual[ii], expected[ii]);
  }
}
87+
88+
// Checks fast_hadamard_transform_28N against the transpose-based alternate
// implementation on random data of 28 * 1024 elements.
TEST(FastHadamardTransform28NTest, Basic) {
  // Single source of truth for the size (previously 1024 and 10 were
  // written as unrelated literals).
  constexpr int kLog2Size = 10;
  constexpr int kSize = (1 << kLog2Size) * 28; // 1024 * 28
  const std::vector<float> data = randomFloats(kSize);

  auto expected = data;
  fast_hadamard_transform_28N_with_transpose(expected.data(), kLog2Size);

  auto actual = data;
  executorch::fast_hadamard_transform_28N(actual.data(), kLog2Size);

  ASSERT_EQ(actual.size(), expected.size());
  // size_t index: the old `int ii < actual.size()` mixed signedness.
  for (size_t ii = 0; ii < actual.size(); ++ii) {
    EXPECT_FLOAT_EQ(actual[ii], expected[ii]);
  }
}
101+
102+
namespace {
103+
// Symmetric 16-bit quantization range used by the helpers below:
// [-(2^15 - 1), 2^15 - 1], i.e. [-32767, 32767].
constexpr int32_t qmin = -(1 << 15) + 1;
constexpr int32_t qmax = -qmin;

// Quantizes a single value: divides by scale, rounds to the nearest integer
// and clamps the result into [qmin, qmax].
int16_t quantize(float x, float scale) {
  const float scaled = x / scale;
  // XXX: Supposed to round ties to even, but this is just test code.
  //
  // Use the floating-point std::lround overload. The former
  // std::lround<int32_t>(scaled) selected the integer overload, which
  // converted `scaled` to int32_t (truncating toward zero) before any
  // rounding could happen, so e.g. 0.6 quantized to 0 instead of 1.
  const long rounded = std::lround(scaled);
  const long lo = qmin;
  const long hi = qmax;
  return static_cast<int16_t>(std::clamp(rounded, lo, hi));
}

// Quantizes every element of data with the given scale.
template <typename T>
std::vector<T> quantize(const std::vector<float>& data, float scale) {
  std::vector<T> result;
  result.reserve(data.size());
  for (const float unquant : data) {
    result.push_back(quantize(unquant, scale));
  }
  return result;
}

// Derives a scale from the value range of data,
// (max - min) / (qmax - qmin), and returns the quantized values together
// with that scale. Empty input yields an empty vector and a scale of 1.
template <typename T>
std::pair<std::vector<T>, float> quantize(const std::vector<float>& data) {
  if (data.empty()) {
    // minmax_element would return end() iterators; never dereference them.
    return {std::vector<T>{}, 1.0f};
  }
  const auto [minIt, maxIt] = std::minmax_element(data.begin(), data.end());
  float scale = (*maxIt - *minIt) / (qmax - qmin);
  if (!(scale > 0.0f)) {
    // All elements equal (or a non-finite range): avoid dividing by zero.
    scale = 1.0f;
  }
  return {quantize<T>(data, scale), scale};
}
130+
131+
// Maps one quantized value back to float by multiplying it with scale.
template <typename T>
float dequantize(T x, float scale) {
  const float restored = x * scale;
  return restored;
}

// Element-wise dequantize of a whole vector. float element types are
// rejected at compile time.
template <typename T>
std::vector<float> dequantize(const std::vector<T>& data, float scale) {
  static_assert(!std::is_same_v<T, float>);
  std::vector<float> result(data.size());
  std::transform(
      data.begin(), data.end(), result.begin(), [scale](const T quant) {
        return dequantize(quant, scale);
      });
  return result;
}
146+
147+
// EXPECT_CLOSE_IMPL(a, b, atol, rtol): expects |a - b| <= atol + rtol * |b|
// and appends both values to the failure message. Every argument is
// parenthesized so compound expressions (e.g. `x - y` passed as b) expand
// with the intended precedence. Note that a and b are evaluated more than
// once, so avoid arguments with side effects.
#define EXPECT_CLOSE_IMPL(a, b, atol, rtol)                        \
  EXPECT_LE(std::abs((a) - (b)), (atol) + (rtol) * std::abs((b))) \
      << "a: " << (a) << ", b: " << (b)
// EXPECT_CLOSE(a, b): EXPECT_CLOSE_IMPL with atol = 2e-4 and rtol = 1e-4.
#define EXPECT_CLOSE(a, b) EXPECT_CLOSE_IMPL(a, b, 2e-4, 1e-4)
151+
152+
void testQuantizedFastHadamardTransform(int logN) {
153+
std::vector<float> data = randomFloats(1 << logN);
154+
155+
auto [qdata, scale] = quantize<int16_t>(data);
156+
157+
auto expected_unquant = dequantize(qdata, scale);
158+
reference_fht_impl(expected_unquant.data(), expected_unquant.size());
159+
auto expected = quantize<int16_t>(expected_unquant, scale);
160+
161+
auto actual = qdata;
162+
executorch::fast_hadamard_transform_symmetric_quantized_s16(
163+
actual.data(), logN);
164+
165+
for (int ii = 0; ii < expected.size(); ++ii) {
166+
EXPECT_CLOSE(
167+
dequantize(actual[ii], scale), dequantize(expected[ii], scale));
168+
}
169+
}
170+
171+
} // namespace
172+
173+
// Even log2 size: 2^12 = 4096 elements.
TEST(QuantizedFastHadamardTransformTest, Basic) {
  constexpr int kLogN = 12;
  testQuantizedFastHadamardTransform(kLogN);
}

// Odd log2 size: 2^11 = 2048 elements.
TEST(QuantizedFastHadamardTransformTest, OddLogN) {
  constexpr int kLogN = 11;
  testQuantizedFastHadamardTransform(kLogN);
}
180+
181+
// Checks the quantized 28N transform against the float transpose-based
// alternate implementation (re-quantized with the same scale) on 28 * 1024
// random elements.
TEST(QuantizedFastHadamardTransform28NTest, Basic) {
  constexpr int kLog2Size = 10;
  const std::vector<float> data = randomFloats((1 << kLog2Size) * 28);

  auto [qdata, scale] = quantize<int16_t>(data);

  auto expected_unquant = dequantize(qdata, scale);
  fast_hadamard_transform_28N_with_transpose(
      expected_unquant.data(), kLog2Size);
  const auto expected = quantize<int16_t>(expected_unquant, scale);

  auto actual = qdata;
  executorch::fast_hadamard_transform_symmetric_quantized_s16_28N(
      actual.data(), kLog2Size);

  ASSERT_EQ(actual.size(), expected.size());
  for (size_t ii = 0; ii < expected.size(); ++ii) {
    // The per-element details are attached to the expectation so they are
    // printed only on failure. The old code wrote (and flushed) one
    // std::cerr line for each of the 28672 elements on every run.
    EXPECT_CLOSE(
        dequantize(actual[ii], scale), dequantize(expected[ii], scale))
        << " (element " << ii << ": actual: " << actual[ii]
        << ", expected: " << expected[ii] << ")";
  }
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
# Unit test binary for the fast Hadamard transform kernels.
runtime.cxx_test(
10+
name = "fast_hadamard_transform_test",
11+
srcs = ["fast_hadamard_transform_test.cpp"],
12+
# Header that lives in this test directory and is included by the test
# source.
headers = ["fast_hadamard_transform_special_unstrided_cpu.h"],
13+
deps = [
14+
# Library under test.
"//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
15+
# dumb_fht, which the test uses as its reference transform.
"//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
16+
],
17+
)

extension/llm/custom_ops/spinquant/third-party/FFHT/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ CFLAGS = -O3 -march=native -std=c99 -pedantic -Wall -Wextra -Wshadow -Wpointer-a
33

44
all: test_float test_double fast_copy.o fht.o
55

6-
OBJ := fast_copy.o fht.o
6+
OBJ := dumb_fht.o fast_copy.o fht.o
77

88
%.o: %.c
99
$(CC) $< -o $@ -c $(CFLAGS)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Build entry point for this directory: loads the shared target definitions
# from the sibling targets.bzl, declares the owning oncall, and instantiates
# the targets.
load(":targets.bzl", "define_common_targets")
2+
3+
oncall("executorch")
4+
5+
define_common_targets()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "dumb_fht.h"
2+
3+
/*
 * Straightforward in-place, unnormalized fast Hadamard transform of the
 * 2^log_n floats at buf. It runs log_n butterfly passes; pass `stage`
 * replaces every pair (u, v) that is 2^stage elements apart with
 * (u + v, u - v).
 *
 * The call is a no-op when buf is null, when log_n is not positive, or when
 * log_n is so large that `1 << log_n` would overflow int (shifting by a
 * negative or oversized amount is undefined behavior).
 */
void dumb_fht(float* buf, int log_n) {
  /* Largest exponent whose power of two still fits in a signed int
   * (assumes 8-bit bytes). */
  const int max_log_n = (int)(sizeof(int) * 8) - 2;
  if (!buf || log_n <= 0 || log_n > max_log_n) {
    return;
  }

  const int n = 1 << log_n;
  for (int stage = 0; stage < log_n; ++stage) {
    const int half = 1 << stage; /* distance between butterfly partners */
    const int span = half << 1; /* size of one butterfly group */
    for (int base = 0; base < n; base += span) {
      for (int k = 0; k < half; ++k) {
        const float u = buf[base + k];
        const float v = buf[base + k + half];
        buf[base + k] = u + v;
        buf[base + k + half] = u - v;
      }
    }
  }
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef DUMB_FHT_H
2+
#define DUMB_FHT_H
3+
4+
#ifdef __cplusplus
5+
extern "C" {
6+
#endif
7+
8+
/* In-place, unnormalized fast Hadamard transform of the 2^log_n floats at
 * buf (plain reference implementation; defined in dumb_fht.c). */
void dumb_fht(float* buf, int log_n);
9+
10+
#ifdef __cplusplus
11+
} // extern "C"
12+
#endif
13+
14+
#endif /* DUMB_FHT_H */
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
# Plain C reference transform (dumb_fht.c / dumb_fht.h).
runtime.cxx_library(
10+
name = "dumb_fht",
11+
srcs = ["dumb_fht.c"],
12+
exported_headers = ["dumb_fht.h"],
13+
visibility = ["@EXECUTORCH_CLIENTS"],
14+
)
15+
16+
# FFHT kernels. Sources are picked per CPU; any configuration other than
# arm64 or x86_64 gets an empty source list, i.e. only the exported header.
runtime.cxx_library(
17+
name = "fht",
18+
srcs = select({
19+
"DEFAULT": [],
20+
"ovr_config//cpu:arm64": ["fht_neon.c"],
21+
"ovr_config//cpu:x86_64": ["fht_avx.c"],
22+
}),
23+
exported_headers = ["fht.h"],
24+
visibility = ["@EXECUTORCH_CLIENTS"],
25+
)
26+
27+
# Binary built from test_float.c, linked against both libraries above.
runtime.cxx_binary(
28+
name = "test_float",
29+
srcs = ["test_float.c"],
30+
deps = [
31+
":dumb_fht",
32+
":fht",
33+
],
34+
)

0 commit comments

Comments
 (0)