Commit 37492bd

[executorch] Custom op for fast hadamard transform kernel
Pull Request resolved: #5291

Custom op support for Fast Hadamard Transform.

ghstack-source-id: 242138454
@exported-using-ghexport

Differential Revision: [D60530438](https://our.internmc.facebook.com/intern/diff/D60530438/)
1 parent 4c1d3de commit 37492bd

File tree

10 files changed: +316, -6 lines

extension/llm/custom_ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
   add_library(
     custom_ops_aot_lib SHARED
     ${_custom_ops__srcs} ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/op_fast_hadamard_transform_aten.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop.cpp
   )
   target_include_directories(
extension/llm/custom_ops/op_fast_hadamard_transform.cpp

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
#include <executorch/kernels/optimized/utils/llvmMathExtras.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h> // For apply_over_dim.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, mat.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ET_KERNEL_CHECK(
      ctx, mat.scalar_type() == out.scalar_type(), InvalidArgument, out);

  if (mat.dim() == 0) {
    return out;
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      mat.strides().back() == 1,
      InvalidArgument,
      out,
      "input matrix that isn't contiguous in the last dimension is not supported!");

  const auto last_dim_size = mat.sizes().back();
  const auto divisible_by_28 = last_dim_size % 28 == 0;
  auto power_of_two_size = divisible_by_28 ? last_dim_size / 28 : last_dim_size;
  ET_KERNEL_CHECK_MSG(
      ctx,
      (power_of_two_size & (power_of_two_size - 1)) == 0,
      InvalidArgument,
      out,
      "This implementation requires power-of-2 (or power-of-2 * 28) input size in the last dimension!");

  const auto log2_power_of_two_size = executorch::llvm::countTrailingZeros(
      static_cast<unsigned int>(power_of_two_size),
      executorch::llvm::ZeroBehavior::ZB_Undefined);

  ET_SWITCH_FLOATH_TYPES(mat.scalar_type(), ctx, __func__, CTYPE, [&] {
    const CTYPE* const mat_data = mat.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    std::memcpy(out_data, mat_data, mat.numel() * sizeof(CTYPE));

    if (divisible_by_28) {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform_28N(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    } else {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    }
  });
  return out;
}
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    llama,
    "fast_hadamard_transform.out",
    torch::executor::native::fast_hadamard_transform_out);
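
The only shape constraint the kernel enforces is on the last dimension: after optionally dividing out a factor of 28, the size must be a power of two. The 28N path covers sizes such as 14336 (28 * 512), one of the llama3 demo sizes the meta kernel later in this diff expects, while 128 takes the plain power-of-two path. A minimal Python sketch of that acceptance check, mirroring the kernel's logic (the helper name last_dim_supported is illustrative, not part of this commit):

def last_dim_supported(n: int) -> bool:
    # Mirror the kernel: strip an optional factor of 28, then require the
    # remainder to be a power of two (size 0 is also accepted, like the kernel).
    if n % 28 == 0:
        n //= 28
    return (n & (n - 1)) == 0

assert last_dim_supported(128)     # power of two
assert last_dim_supported(14336)   # 28 * 512
assert not last_dim_supported(96)  # neither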
extension/llm/custom_ops/op_fast_hadamard_transform.h

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {

// Compute the fast Walsh-Hadamard transform
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
// of mat along the last dimension (which must be contiguous).
//
// mat.sizes().back() is currently required to be either a power of
// two, or 28 * a power of two.
Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out);
} // namespace torch::executor::native
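
As a point of reference, the normalized Walsh-Hadamard transform described by this header can be sketched in a few lines of Python/PyTorch. This is a hedged reference sketch, not part of the commit: it assumes a power-of-two last dimension and the same 1/sqrt(n) scaling used by the C++ test's reference further down; reference_fwht is an illustrative name.

import math
import torch

def reference_fwht(x: torch.Tensor) -> torch.Tensor:
    # Normalized fast Walsh-Hadamard transform over the last dimension.
    # Assumes x.shape[-1] is a power of two.
    n = x.shape[-1]
    out = x.clone().contiguous()
    h = 1
    while h < n:
        # View each row as blocks of 2*h elements, split into two halves of h.
        blocks = out.view(*x.shape[:-1], n // (2 * h), 2, h)
        a = blocks[..., 0, :].clone()
        b = blocks[..., 1, :].clone()
        blocks[..., 0, :] = a + b  # butterfly: sums into the first half
        blocks[..., 1, :] = a - b  # differences into the second half
        h *= 2
    return out / math.sqrt(n)

For a power-of-two last dimension this should agree with the custom op's output up to floating-point error once the AOT library registered below is loaded.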
extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>

#include <torch/library.h>

namespace torch::executor::native {
namespace {
Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return fast_hadamard_transform_out(context, vec, out);
}
at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
  auto out = at::empty_like(vec);
  WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
  (vec, out);
  return out;
}
} // namespace
} // namespace torch::executor::native

TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def("fast_hadamard_transform(Tensor mat) -> Tensor");
  m.def(
      "fast_hadamard_transform.out(Tensor mat, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "fast_hadamard_transform",
      torch::executor::native::fast_hadamard_transform_aten);
  m.impl(
      "fast_hadamard_transform.out",
      WRAP_TO_ATEN(
          torch::executor::native::fast_hadamard_transform_out_no_context, 1));
}
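
With these registrations, both overloads should be callable from a regular PyTorch process once the shared library is loaded (the Python wrapper further down does exactly that). A minimal, hedged usage sketch; the shapes and the allclose check are illustrative:

import torch

# Assumes libcustom_ops_aot_lib has already been loaded, e.g. by importing
# the sdpa_with_kv_cache module shown below.
x = torch.randn(4, 128)  # last dimension is a power of two
y = torch.ops.llama.fast_hadamard_transform(x)

y_out = torch.empty_like(x)
torch.ops.llama.fast_hadamard_transform.out(x, out=y_out)
assert torch.allclose(y, y_out)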

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@ namespace torch {
 namespace executor {
 
 namespace native {
-
+namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,12 +81,12 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-
+} // namespace
 } // namespace native
 } // namespace executor
 } // namespace torch
 
-TORCH_LIBRARY(llama, m) {
+TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 12 additions & 0 deletions
@@ -20,13 +20,17 @@
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None
 except:
     libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
     assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
     logging.info(f"Loading custom ops library: {libs[0]}")
     torch.ops.load_library(libs[0])
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None
 
 custom_ops_lib = torch.library.Library("llama", "IMPL")
 
@@ -126,3 +130,11 @@ def sdpa_with_kv_cache_meta(
     )
 
     return torch.empty_like(query)
+
+
+@impl(custom_ops_lib, "fast_hadamard_transform", "Meta")
+def fast_hadamard_transform_meta(mat):
+    # assert(mat.strides[-1] == 1, "input matrix must be contiguous in the last dimension!")
+    # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
+    # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
+    return torch.empty_like(mat)
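
The Meta registration above is what allows the op to run under fake tensors during tracing and export. A hedged sketch of exercising it with torch.export, assuming the custom ops library has been loaded by the try/except block earlier in this file; the Rotate module and the 2x128 input are illustrative:

import torch

class Rotate(torch.nn.Module):
    def forward(self, x):
        # Eager mode dispatches to the AOT kernel; tracing uses the meta kernel.
        return torch.ops.llama.fast_hadamard_transform(x)

example = (torch.randn(2, 128),)
exported = torch.export.export(Rotate(), example)
print(exported.graph_module.graph)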
extension/llm/custom_ops/spinquant/test/op_fast_hadamard_transform_test.cpp

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>
#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/dumb_fht.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

#include <cmath>
#include <random>

using exec_aten::Tensor;

namespace {
Tensor& fast_hadamard_transform_nocontext(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return torch::executor::native::fast_hadamard_transform_out(
      context, vec, out);
}

void reference_fht_impl(float* buf, int n) {
  dumb_fht(buf, std::log2<int>(n));
  const auto root_n = std::sqrt(n);
  for (int ii = 0; ii < n; ++ii) {
    buf[ii] /= root_n;
  }
}
} // namespace

TEST(FastHadamardTransformTest, EmptyInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Byte> tfByte;
  auto vec = tfFloat.zeros({0});
  auto out = tfFloat.zeros({0});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 0);
}

TEST(FastHadamardTransformTest, SingleElementInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Byte> tfByte;
  auto vec = tfFloat.ones({1});
  auto out = tfFloat.zeros({1});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 1);
  // FHT of a single element is a no-op.
  EXPECT_EQ(result.const_data_ptr<float>()[0], 1);
}

TEST(FastHadamardTransformTest, FourKInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Byte> tfByte;
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(4096);
  for (int ii = 0; ii < data.size(); ++ii) {
    data[ii] = dist(gen);
  }
  auto vec = tfFloat.make({4096}, data);
  auto out = tfFloat.zeros({4096});
  auto result = fast_hadamard_transform_nocontext(vec, out);

  std::vector<float> reference_result = data;
  reference_fht_impl(reference_result.data(), reference_result.size());

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < 4096; ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(FastHadamardTransformTest, MultipleRows) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Byte> tfByte;
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> dist;
  std::vector<float> data(8 * 8 * 8);
  for (int ii = 0; ii < data.size(); ++ii) {
    data[ii] = dist(gen);
  }
  auto mat = tfFloat.make({8, 8, 8}, data);
  auto out = tfFloat.zeros({8, 8, 8});

  auto result = fast_hadamard_transform_nocontext(mat, out);

  std::vector<float> reference_result = data;
  for (int ii = 0; ii < 8; ++ii) {
    for (int jj = 0; jj < 8; ++jj) {
      reference_fht_impl(&reference_result[ii * 64 + jj * 8], 8);
    }
  }

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

extension/llm/custom_ops/spinquant/test/targets.bzl

Lines changed: 10 additions & 0 deletions
@@ -15,3 +15,13 @@ def define_common_targets():
             "//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
         ],
     )
+
+    runtime.cxx_test(
+        name = "op_fast_hadamard_transform_test",
+        srcs = ["op_fast_hadamard_transform_test.cpp"],
+        deps = [
+            "//executorch/extension/llm/custom_ops:custom_ops",
+            "//executorch/extension/llm/custom_ops/spinquant/third-party/FFHT:dumb_fht",
+            "//executorch/kernels/test:test_util",
+        ],
+    )

extension/llm/custom_ops/targets.bzl

Lines changed: 16 additions & 2 deletions
@@ -9,8 +9,16 @@ def define_common_targets():
     for mkl_dep in ["", "_mkl_noomp"]:
         runtime.cxx_library(
             name = "custom_ops" + mkl_dep,
-            srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
-            exported_headers = ["op_sdpa.h", "op_fallback.h"],
+            srcs = [
+                "op_fallback.cpp",
+                "op_fast_hadamard_transform.cpp",
+                "op_sdpa.cpp",
+            ],
+            exported_headers = [
+                "op_fallback.h",
+                "op_fast_hadamard_transform.h",
+                "op_sdpa.h",
+            ],
             exported_deps = [
                 "//executorch/runtime/kernel:kernel_includes",
                 "//executorch/kernels/portable/cpu:scalar_utils",
@@ -20,6 +28,10 @@ def define_common_targets():
                 "//executorch/extension/parallel:thread_parallel",
                 "//executorch/extension/threadpool:threadpool",
             ],
+            deps = [
+                "//executorch/kernels/portable/cpu/util:reduce_util",
+                "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
+            ],
             compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
             visibility = [
                 "//executorch/...",
@@ -35,7 +47,9 @@ def define_common_targets():
             name = "custom_ops_aot_lib" + mkl_dep,
             srcs = [
                 "op_sdpa_aot.cpp",
+                "op_fast_hadamard_transform_aten.cpp",
             ],
+            compiler_flags = ["-Wno-global-constructors"],
             visibility = [
                 "//executorch/...",
                 "@EXECUTORCH_CLIENTS",

kernels/portable/cpu/util/targets.bzl

Lines changed: 5 additions & 1 deletion
@@ -249,5 +249,9 @@ def define_common_targets():
                 "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
             ],
             exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
-            visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/quantized/..."],
+            visibility = [
+                "//executorch/extension/llm/custom_ops/...",
+                "//executorch/kernels/portable/cpu/...",
+                "//executorch/kernels/quantized/...",
+            ],
         )
