
Commit 3a38ad1

[executorch] Custom op for fast hadamard transform kernel
Pull Request resolved: #5291

Custom op support for Fast Hadamard Transform.

ghstack-source-id: 242537091
@exported-using-ghexport
Differential Revision: [D60530438](https://our.internmc.facebook.com/intern/diff/D60530438/)
1 parent 7b779ec commit 3a38ad1

12 files changed, +316 -8 lines

build/cmake_deps.toml

Lines changed: 4 additions & 1 deletion
@@ -324,7 +324,10 @@ buck_targets = [
   "//extension/llm/custom_ops:custom_ops",
 ]
 filters = [
-  ".cpp$",
+  # Second clause is to pick up fht_neon.c/fht_avx.c from FFHT. TODO:
+  # remove filters and patch extract_sources.py's Buck query to fetch
+  # srcs; presumably filters is here to remove .h files.
+  "(.cpp$)|(fht.*\\.c$)",
 ]
 excludes = [
   "^codegen",

build/extract_sources.py

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ def main():
         buck_args = ["--target-platforms"]
         buck_args.append(args.target_platforms)
     for name, target in graph.by_name.items():
-        target_to_srcs[name] = sorted(target.get_sources(graph, runner))
+        target_to_srcs[name] = sorted(target.get_sources(graph, runner, buck_args))

     # Generate the requested format.
     output: bytes

extension/llm/custom_ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
   add_library(
     custom_ops_aot_lib SHARED
     ${_custom_ops__srcs} ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/op_fast_hadamard_transform_aten.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop.cpp
   )
   target_include_directories(

extension/llm/custom_ops/op_fast_hadamard_transform.cpp

Lines changed: 93 additions & 0 deletions
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
#include <executorch/kernels/optimized/utils/llvmMathExtras.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h> // For apply_over_dim.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, mat.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ET_KERNEL_CHECK(
      ctx, mat.scalar_type() == out.scalar_type(), InvalidArgument, out);

  if (mat.dim() == 0 || mat.numel() == 0) {
    return out;
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      mat.strides().back() == 1,
      InvalidArgument,
      out,
      "input matrix that isn't contiguous in the last dimension is not supported!");

  const auto last_dim_size = mat.sizes().back();
  const auto divisible_by_28 = last_dim_size % 28 == 0;
  auto power_of_two_size = divisible_by_28 ? last_dim_size / 28 : last_dim_size;
  ET_KERNEL_CHECK_MSG(
      ctx,
      (power_of_two_size & (power_of_two_size - 1)) == 0,
      InvalidArgument,
      out,
      "This implementation requires power-of-2 (or power-of-2 * 28) input size in the last dimension!");

  const auto log2_power_of_two_size = executorch::llvm::countTrailingZeros(
      static_cast<unsigned int>(power_of_two_size),
      executorch::llvm::ZeroBehavior::ZB_Undefined);

  ET_SWITCH_FLOATH_TYPES(mat.scalar_type(), ctx, __func__, CTYPE, [&] {
    const CTYPE* const mat_data = mat.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    std::memcpy(out_data, mat_data, mat.numel() * sizeof(CTYPE));

    if (divisible_by_28) {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform_28N(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    } else {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    }
  });
  return out;
}
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    llama,
    "fast_hadamard_transform.out",
    torch::executor::native::fast_hadamard_transform_out);
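
The kernel above splits the last-dimension size into an optional factor of 28 and a power-of-two remainder, then passes log2 of that remainder to the FFHT-based routines. As an illustration only (a hypothetical helper, not part of this commit), the same derivation in Python, applied to the two llama3 demo sizes mentioned later in this diff:

# Hypothetical helper mirroring the kernel's dispatch-parameter derivation.
def hadamard_dispatch(last_dim_size: int):
    divisible_by_28 = last_dim_size % 28 == 0
    power_of_two_size = last_dim_size // 28 if divisible_by_28 else last_dim_size
    if power_of_two_size & (power_of_two_size - 1):
        raise ValueError("last dim must be a power of 2, or 28 * a power of 2")
    # bit_length() - 1 plays the role of countTrailingZeros for a power of two.
    log2_power_of_two_size = power_of_two_size.bit_length() - 1
    kernel = "fast_hadamard_transform_28N" if divisible_by_28 else "fast_hadamard_transform"
    return kernel, log2_power_of_two_size

print(hadamard_dispatch(128))    # ('fast_hadamard_transform', 7)
print(hadamard_dispatch(14336))  # ('fast_hadamard_transform_28N', 9); 14336 == 28 * 512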

extension/llm/custom_ops/op_fast_hadamard_transform.h

Lines changed: 25 additions & 0 deletions
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {

// Compute the fast Walsh-Hadamard transform
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
// of mat along the last dimension (which must be contiguous).
//
// mat.sizes().back() is currently required to be either a power of
// two, or 28 * a power of two.
Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out);
} // namespace torch::executor::native
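
For readers unfamiliar with the transform the header references, a minimal reference sketch (not from this commit) of the power-of-two case in Python, normalized by 1/sqrt(n) to match the reference used in the test file further down:

# Minimal sketch of the normalized Walsh-Hadamard transform along the last dim.
import numpy as np

def fwht_reference(x: np.ndarray) -> np.ndarray:
    out = np.array(x, dtype=np.float64, copy=True)
    n = out.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "this sketch handles power-of-two sizes only"
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for j in range(start, start + h):
                # Butterfly: replace the pair (a, b) with (a + b, a - b).
                a = out[..., j] + out[..., j + h]
                b = out[..., j] - out[..., j + h]
                out[..., j], out[..., j + h] = a, b
        h *= 2
    return out / np.sqrt(n)  # same 1/sqrt(n) scaling as reference_fht_impl in the test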

extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp

Lines changed: 43 additions & 0 deletions
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>

#include <torch/library.h>

namespace torch::executor::native {
namespace {
Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return fast_hadamard_transform_out(context, vec, out);
}
at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
  auto out = at::empty_like(vec);
  WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
  (vec, out);
  return out;
}
} // namespace
} // namespace torch::executor::native

TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def("fast_hadamard_transform(Tensor mat) -> Tensor");
  m.def(
      "fast_hadamard_transform.out(Tensor mat, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "fast_hadamard_transform",
      torch::executor::native::fast_hadamard_transform_aten);
  m.impl(
      "fast_hadamard_transform.out",
      WRAP_TO_ATEN(
          torch::executor::native::fast_hadamard_transform_out_no_context, 1));
}
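
Assuming the shared library from sdpa_with_kv_cache.py (below) has been loaded, the two schemas registered here should be callable from eager Python roughly as follows (illustrative sketch, not part of the commit):

import torch

x = torch.randn(4, 128)  # last dim is a power of two
y = torch.ops.llama.fast_hadamard_transform(x)            # functional variant
out = torch.empty_like(x)
torch.ops.llama.fast_hadamard_transform.out(x, out=out)   # out= variant
torch.testing.assert_close(y, out)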

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@ namespace torch {
 namespace executor {

 namespace native {
-
+namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,12 +81,12 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-
+} // namespace
 } // namespace native
 } // namespace executor
 } // namespace torch

-TORCH_LIBRARY(llama, m) {
+TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 12 additions & 0 deletions
@@ -20,13 +20,17 @@
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None
 except:
     libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
     assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
     logging.info(f"Loading custom ops library: {libs[0]}")
     torch.ops.load_library(libs[0])
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None

 custom_ops_lib = torch.library.Library("llama", "IMPL")

@@ -126,3 +130,11 @@ def sdpa_with_kv_cache_meta(
     )

     return torch.empty_like(query)
+
+
+@impl(custom_ops_lib, "fast_hadamard_transform", "Meta")
+def fast_hadamard_transform_meta(mat):
+    # assert(mat.strides[-1] == 1, "input matrix must be contiguous in the last dimension!")
+    # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
+    # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
+    return torch.empty_like(mat)
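
The Meta registration above only needs to propagate shape and dtype, which is what tracing/export relies on. A hedged sanity-check sketch (not part of the commit, assuming meta-device tensors dispatch to fast_hadamard_transform_meta):

import torch

mat = torch.empty(2, 14336, device="meta")  # 14336 == 28 * 512, as noted above
result = torch.ops.llama.fast_hadamard_transform(mat)
assert result.shape == mat.shape and result.device.type == "meta"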

extension/llm/custom_ops/spinquant/test/op_fast_hadamard_transform_test.cpp

Lines changed: 102 additions & 0 deletions
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/dumb_fht.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>

#include <gtest/gtest.h>

#include <cmath>
#include <random>

using exec_aten::Tensor;

namespace {
Tensor& fast_hadamard_transform_nocontext(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return torch::executor::native::fast_hadamard_transform_out(
      context, vec, out);
}

void reference_fht_impl(float* buf, int n) {
  dumb_fht(buf, std::log2<int>(n));
  const auto root_n = std::sqrt(n);
  for (int ii = 0; ii < n; ++ii) {
    buf[ii] /= root_n;
  }
}
} // namespace

TEST(FastHadamardTransformTest, EmptyInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto vec = tfFloat.zeros({0});
  auto out = tfFloat.zeros({0});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 0);
}

TEST(FastHadamardTransformTest, SingleElementInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto vec = tfFloat.ones({1});
  auto out = tfFloat.zeros({1});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 1);
  // FHT of a single element is a no-op.
  EXPECT_EQ(result.const_data_ptr<float>()[0], 1);
}

TEST(FastHadamardTransformTest, FourKInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(4096);
  for (int ii = 0; ii < data.size(); ++ii) {
    data[ii] = dist(gen);
  }
  auto vec = tfFloat.make({4096}, data);
  auto out = tfFloat.zeros({4096});
  auto result = fast_hadamard_transform_nocontext(vec, out);

  std::vector<float> reference_result = data;
  reference_fht_impl(reference_result.data(), reference_result.size());

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < 4096; ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(FastHadamardTransformTest, MultipleRows) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(8 * 8 * 8);
  for (int ii = 0; ii < data.size(); ++ii) {
    data[ii] = dist(gen);
  }
  auto mat = tfFloat.make({8, 8, 8}, data);
  auto out = tfFloat.zeros({8, 8, 8});

  auto result = fast_hadamard_transform_nocontext(mat, out);

  std::vector<float> reference_result = data;
  for (int ii = 0; ii < 8; ++ii) {
    for (int jj = 0; jj < 8; ++jj) {
      reference_fht_impl(&reference_result[ii * 64 + jj * 8], 8);
    }
  }

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

extension/llm/custom_ops/spinquant/test/targets.bzl

Lines changed: 11 additions & 0 deletions
@@ -15,3 +15,14 @@ def define_common_targets():
             "//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
         ],
     )
+
+    runtime.cxx_test(
+        name = "op_fast_hadamard_transform_test",
+        srcs = ["op_fast_hadamard_transform_test.cpp"],
+        deps = [
+            "//executorch/extension/llm/custom_ops:custom_ops",
+            "//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
+            "//executorch/kernels/test:test_util",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+    )
