Skip to content

Commit b1526d5

Browse files
committed
[ExecuTorch] Refactor fast_hadamard_transform_test shared implementation functions
A follow-up diff will use these shared helper functions to test the ExecuTorch operator as well. Differential Revision: [D62760050](https://our.internmc.facebook.com/intern/diff/D62760050/) ghstack-source-id: 242784012 Pull Request resolved: #5388
1 parent 54ba1b0 commit b1526d5

File tree

4 files changed

+114
-60
lines changed

4 files changed

+114
-60
lines changed

extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_test.cpp

Lines changed: 14 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,72 +7,30 @@
77
*/
88

99
#include <algorithm>
10+
#include <array>
1011
#include <cmath>
11-
#include <iostream>
12-
#include <random>
12+
#include <type_traits>
13+
#include <utility>
14+
#include <vector>
1315

1416
#include <gtest/gtest.h>
1517

1618
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
17-
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h>
18-
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/dumb_fht.h>
19+
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_test_impl.h>
1920

20-
namespace {
21-
void reference_fht_impl(float* buf, int n) {
22-
dumb_fht(buf, std::log2<int>(n));
23-
const auto root_n = std::sqrt(n);
24-
for (int ii = 0; ii < n; ++ii) {
25-
buf[ii] /= root_n;
26-
}
27-
}
28-
29-
// Alternate implementation of fast_hadamard_transform_28N to mutation
30-
// test against. Benchmarking suggests this one is slower, which is
31-
// why it's in the test and the strided implementation is in the
32-
// header.
33-
template <typename T>
34-
void fast_hadamard_transform_28N_with_transpose(T* vec, int log2_vec_size) {
35-
const int vec_size = (1 << log2_vec_size);
36-
for (int ii = 0; ii < 28; ++ii) {
37-
executorch::fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size);
38-
}
39-
std::unique_ptr<T[]> transposed = std::make_unique<T[]>(28 * vec_size);
40-
for (int ii = 0; ii < 28; ++ii) {
41-
for (int jj = 0; jj < vec_size; ++jj) {
42-
transposed[jj * 28 + ii] = vec[ii * vec_size + jj];
43-
}
44-
}
45-
for (int ii = 0; ii < vec_size; ++ii) {
46-
hadamard_mult_28(&transposed[ii * 28]);
47-
}
48-
for (int jj = 0; jj < vec_size; ++jj) {
49-
for (int ii = 0; ii < 28; ++ii) {
50-
vec[ii * vec_size + jj] = transposed[jj * 28 + ii];
51-
}
52-
}
53-
}
54-
55-
std::vector<float> randomFloats(int howMany) {
56-
std::random_device rd;
57-
std::mt19937 gen(rd());
58-
std::normal_distribution<float> dist;
59-
std::vector<float> data(howMany);
60-
for (int ii = 0; ii < data.size(); ++ii) {
61-
data[ii] = dist(gen);
62-
}
63-
return data;
64-
}
65-
} // namespace
21+
using executorch::runtime::testing::fast_hadamard_transform_28N_with_transpose;
22+
using executorch::runtime::testing::random_floats;
23+
using executorch::runtime::testing::reference_fht_impl;
6624

6725
TEST(FastHadamardTransformTest, SingleElement) {
6826
// FHT of a single element is a no-op.
69-
float data[1] = {42};
70-
executorch::fast_hadamard_transform(data, 0);
27+
std::array<float, 1> data = {{42}};
28+
executorch::fast_hadamard_transform(data.data(), 0);
7129
EXPECT_EQ(data[0], 42);
7230
}
7331

7432
TEST(FastHadamardTransformTest, LargerInput) {
75-
std::vector<float> data = randomFloats(4096);
33+
std::vector<float> data = random_floats(4096);
7634

7735
auto expected = data;
7836
reference_fht_impl(expected.data(), expected.size());
@@ -86,7 +44,7 @@ TEST(FastHadamardTransformTest, LargerInput) {
8644
}
8745

8846
TEST(FastHadamardTransform28NTest, Basic) {
89-
std::vector<float> data = randomFloats(1024 * 28);
47+
std::vector<float> data = random_floats(1024 * 28);
9048

9149
auto expected = data;
9250
fast_hadamard_transform_28N_with_transpose(expected.data(), 10);
@@ -150,7 +108,7 @@ std::vector<float> dequantize(const std::vector<T>& data, float scale) {
150108
#define EXPECT_CLOSE(a, b) EXPECT_CLOSE_IMPL(a, b, 2e-4, 1e-4)
151109

152110
void testQuantizedFastHadamardTransform(int logN) {
153-
std::vector<float> data = randomFloats(1 << logN);
111+
std::vector<float> data = random_floats(1 << logN);
154112

155113
auto [qdata, scale] = quantize<int16_t>(data);
156114

@@ -179,7 +137,7 @@ TEST(QuantizedFastHadamardTransformTest, OddLogN) {
179137
}
180138

181139
TEST(QuantizedFastHadamardTransform28NTest, Basic) {
182-
std::vector<float> data = randomFloats(1024 * 28);
140+
std::vector<float> data = random_floats(1024 * 28);
183141

184142
auto [qdata, scale] = quantize<int16_t>(data);
185143

@@ -192,8 +150,6 @@ TEST(QuantizedFastHadamardTransform28NTest, Basic) {
192150
actual.data(), 10);
193151

194152
for (int ii = 0; ii < expected.size(); ++ii) {
195-
std::cerr << "element " << ii << ": actual: " << actual[ii]
196-
<< ", expected: " << expected[ii] << std::endl;
197153
EXPECT_CLOSE(
198154
dequantize(actual[ii], scale), dequantize(expected[ii], scale));
199155
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_test_impl.h>
10+
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/dumb_fht.h>
11+
12+
#include <cmath>
13+
#include <random>
14+
#include <vector>
15+
16+
namespace executorch::runtime::testing {
17+
18+
// Reference implementation used by the tests: runs dumb_fht on buf
// with log2(n), then divides each of the first n elements of buf by
// sqrt(n). buf is modified in place.
//
// NOTE(review): presumably n is a power of two, as in the visible
// callers; for other values the log2 result is truncated toward zero.
// Non-positive n is not handled here.
void reference_fht_impl(float* buf, int n) {
  // Use the standard integer overload of std::log2 (which computes in
  // double) rather than an explicit template argument, and make the
  // narrowing to int explicit.
  // NOTE(review): assumes dumb_fht takes an integral log2 size -- confirm
  // against dumb_fht.h.
  const int log2_n = static_cast<int>(std::log2(n));
  dumb_fht(buf, log2_n);

  // std::sqrt on an int argument yields a double.
  const auto root_n = std::sqrt(n);
  for (int ii = 0; ii < n; ++ii) {
    buf[ii] /= root_n;
  }
}
25+
26+
// Returns a vector of howMany floats, each drawn from a
// default-constructed std::normal_distribution<float>.
//
// The generator is a std::mt19937 seeded from std::random_device, so
// the values are generally not reproducible from run to run.
// howMany is passed straight to the vector's size constructor; negative
// values are not handled here.
std::vector<float> random_floats(int howMany) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(howMany);
  // Range-for avoids comparing a signed index against data.size().
  for (float& value : data) {
    value = dist(gen);
  }
  return data;
}
36+
37+
} // namespace executorch::runtime::testing
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <memory>
12+
#include <vector>
13+
14+
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
15+
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h>
16+
17+
namespace executorch::runtime::testing {
18+
// Reference implementation for tests: applies dumb_fht to buf using
// log2(n), then divides each of the n elements by sqrt(n). buf is
// modified in place.
void reference_fht_impl(float* buf, int n);
19+
20+
// Alternate implementation of fast_hadamard_transform_28N to mutation
21+
// test against. Benchmarking suggests this one is slower, which is
22+
// why it's in the test.
23+
template <typename T>
24+
void fast_hadamard_transform_28N_with_transpose(T* vec, int log2_vec_size) {
25+
const int vec_size = (1 << log2_vec_size);
26+
for (int ii = 0; ii < 28; ++ii) {
27+
executorch::fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size);
28+
}
29+
std::unique_ptr<T[]> transposed = std::make_unique<T[]>(28 * vec_size);
30+
for (int ii = 0; ii < 28; ++ii) {
31+
for (int jj = 0; jj < vec_size; ++jj) {
32+
transposed[jj * 28 + ii] = vec[ii * vec_size + jj];
33+
}
34+
}
35+
for (int ii = 0; ii < vec_size; ++ii) {
36+
hadamard_mult_28(&transposed[ii * 28]);
37+
}
38+
for (int jj = 0; jj < vec_size; ++jj) {
39+
for (int ii = 0; ii < 28; ++ii) {
40+
vec[ii * vec_size + jj] = transposed[jj * 28 + ii];
41+
}
42+
}
43+
}
44+
45+
// Returns a vector of howMany floats drawn from a default-constructed
// std::normal_distribution<float>, using a std::mt19937 seeded from
// std::random_device (so results are generally not reproducible across
// runs).
std::vector<float> random_floats(int howMany);
46+
47+
} // namespace executorch::runtime::testing

extension/llm/custom_ops/spinquant/test/targets.bzl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,26 @@ def define_common_targets():
66
The directory containing this targets.bzl file should also contain both
77
TARGETS and BUCK files that call this function.
88
"""
9+
runtime.cxx_library(
10+
name = "fast_hadamard_transform_test_impl",
11+
srcs = ["fast_hadamard_transform_test_impl.cpp"],
12+
headers = [
13+
"fast_hadamard_transform_special_unstrided_cpu.h",
14+
"fast_hadamard_transform_test_impl.h",
15+
],
16+
exported_deps = [
17+
"//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
18+
],
19+
deps = [
20+
"//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
21+
],
22+
)
23+
924
runtime.cxx_test(
1025
name = "fast_hadamard_transform_test",
1126
srcs = ["fast_hadamard_transform_test.cpp"],
12-
headers = ["fast_hadamard_transform_special_unstrided_cpu.h"],
1327
deps = [
28+
":fast_hadamard_transform_test_impl",
1429
"//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
15-
"//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
1630
],
1731
)

0 commit comments

Comments
 (0)