Commit acfe0ba

swolchok authored and facebook-github-bot committed
Custom op for fast hadamard transform kernel (#5291)
Summary:
Pull Request resolved: #5291

Custom op support for Fast Hadamard Transform.

ghstack-source-id: 243051223
exported-using-ghexport

Reviewed By: kimishpatel, helunwencser

Differential Revision: D60530438

fbshipit-source-id: a483eea58e9897c4e2042157a80b0ddaed79a17b
1 parent 19f5ed8 commit acfe0ba

File tree: 11 files changed, +333 −9 lines changed

build/cmake_deps.toml

Lines changed: 4 additions & 1 deletion
```diff
@@ -337,7 +337,10 @@ buck_targets = [
   "//extension/llm/custom_ops:custom_ops",
 ]
 filters = [
-  ".cpp$",
+  # Second clause is to pick up fht_neon.c/fht_avx.c from FFHT. TODO:
+  # remove filters and patch extract_sources.py's Buck query to fetch
+  # srcs; presumably filters is here to remove .h files.
+  "(.cpp$)|(fht.*\\.c$)",
 ]
 excludes = [
   "^codegen",
```

extension/llm/custom_ops/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -75,8 +75,9 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
   add_library(
     custom_ops_aot_lib SHARED
     ${_custom_ops__srcs} ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop_aot.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/op_fast_hadamard_transform_aten.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop_aot.cpp
   )
   target_include_directories(
     custom_ops_aot_lib PUBLIC "${_common_include_directories}"
```
extension/llm/custom_ops/op_fast_hadamard_transform.cpp

Lines changed: 93 additions & 0 deletions (new file)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
#include <executorch/kernels/optimized/utils/llvmMathExtras.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h> // For apply_over_dim.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, mat.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ET_KERNEL_CHECK(
      ctx, mat.scalar_type() == out.scalar_type(), InvalidArgument, out);

  if (mat.dim() == 0 || mat.numel() == 0) {
    return out;
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      mat.strides().back() == 1,
      InvalidArgument,
      out,
      "input matrix that isn't contiguous in the last dimension is not supported!");

  const auto last_dim_size = mat.sizes().back();
  const auto divisible_by_28 = last_dim_size % 28 == 0;
  auto power_of_two_size = divisible_by_28 ? last_dim_size / 28 : last_dim_size;
  ET_KERNEL_CHECK_MSG(
      ctx,
      (power_of_two_size & (power_of_two_size - 1)) == 0,
      InvalidArgument,
      out,
      "This implementation requires power-of-2 (or power-of-2 * 28) input size in the last dimension!");

  const auto log2_power_of_two_size = executorch::llvm::countTrailingZeros(
      static_cast<unsigned int>(power_of_two_size),
      executorch::llvm::ZeroBehavior::ZB_Undefined);

  ET_SWITCH_FLOATH_TYPES(mat.scalar_type(), ctx, __func__, CTYPE, [&] {
    const CTYPE* const mat_data = mat.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    std::memcpy(out_data, mat_data, mat.numel() * sizeof(CTYPE));

    if (divisible_by_28) {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform_28N(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    } else {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    }
  });
  return out;
}
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    llama,
    "fast_hadamard_transform.out",
    torch::executor::native::fast_hadamard_transform_out);
```
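The size handling above is the subtle part of the kernel: the last dimension must be either 2**N or 28 * 2**N, and only the log2 of the power-of-two factor is handed to the FFHT routines. A minimal Python sketch of that logic (mine, for illustration; the C++ above is authoritative):

```python
def check_and_split_size(last_dim_size: int) -> tuple[bool, int]:
    """Mirror of the kernel's size validation: returns (divisible_by_28, log2)."""
    divisible_by_28 = last_dim_size % 28 == 0
    power_of_two_size = last_dim_size // 28 if divisible_by_28 else last_dim_size
    # Same power-of-two test as (x & (x - 1)) == 0 in the C++ above.
    if power_of_two_size <= 0 or power_of_two_size & (power_of_two_size - 1):
        raise ValueError("last dim must be 2**N or 28 * 2**N")
    # Stands in for llvm::countTrailingZeros on an exact power of two.
    return divisible_by_28, power_of_two_size.bit_length() - 1

print(check_and_split_size(4096))   # (False, 12)
print(check_and_split_size(14336))  # (True, 9) -- 14336 == 28 * 512
```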
extension/llm/custom_ops/op_fast_hadamard_transform.h

Lines changed: 25 additions & 0 deletions (new file)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {

// Compute the fast Walsh-Hadamard transform
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
// of mat along the last dimension (which must be contiguous).
//
// mat.sizes().back() is currently required to be either a power of
// two, or 28 * a power of two.
Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out);
} // namespace torch::executor::native
```
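To make the header comment concrete, here is a reference fast Walsh-Hadamard transform in pure NumPy (an illustrative, unnormalized butterfly; the SpinQuant kernel's exact scaling convention is not shown in this diff, so treat the scale as an assumption):

```python
import numpy as np

def fwht(vec: np.ndarray) -> np.ndarray:
    """Unnormalized FWHT over the last axis; length must be a power of two."""
    out = np.array(vec, dtype=np.float64, copy=True)
    n = out.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            a = out[..., start : start + h].copy()
            b = out[..., start + h : start + 2 * h].copy()
            out[..., start : start + h] = a + b          # butterfly: sum
            out[..., start + h : start + 2 * h] = a - b  # butterfly: difference
        h *= 2
    return out

print(fwht(np.array([1.0, 0.0, 1.0, 0.0])))  # [2. 2. 0. 0.]
```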
extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp

Lines changed: 43 additions & 0 deletions (new file)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>

#include <torch/library.h>

namespace torch::executor::native {
namespace {
Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return fast_hadamard_transform_out(context, vec, out);
}
at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
  auto out = at::empty_like(vec);
  WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
  (vec, out);
  return out;
}
} // namespace
} // namespace torch::executor::native

TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def("fast_hadamard_transform(Tensor mat) -> Tensor");
  m.def(
      "fast_hadamard_transform.out(Tensor mat, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "fast_hadamard_transform",
      torch::executor::native::fast_hadamard_transform_aten);
  m.impl(
      "fast_hadamard_transform.out",
      WRAP_TO_ATEN(
          torch::executor::native::fast_hadamard_transform_out_no_context, 1));
}
```
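With the TORCH_LIBRARY_FRAGMENT definitions above, the op is callable from eager PyTorch once the AOT library is loaded. A hedged usage sketch (the library path is illustrative and build-dependent):

```python
import torch

# Illustrative path; in practice the build drops libcustom_ops_aot_lib
# next to sdpa_with_kv_cache.py or under your cmake output directory.
torch.ops.load_library(
    "cmake-out/extension/llm/custom_ops/libcustom_ops_aot_lib.so")

x = torch.randn(4, 128)  # last dim is a power of two, as the kernel requires
y = torch.ops.llama.fast_hadamard_transform(x)  # functional variant
out = torch.empty_like(x)
torch.ops.llama.fast_hadamard_transform.out(x, out=out)  # out variant
assert torch.allclose(y, out)
```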

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 3 additions & 3 deletions
```diff
@@ -16,7 +16,7 @@ namespace torch {
 namespace executor {

 namespace native {
-
+namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,12 +81,12 @@ at::Tensor sdpa_with_kv_cache_aten(
     output);
   return output;
 }
-
+} // namespace
 } // namespace native
 } // namespace executor
 } // namespace torch

-TORCH_LIBRARY(llama, m) {
+TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
```

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -20,13 +20,17 @@
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None
 except:
     libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
     assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
     logging.info(f"Loading custom ops library: {libs[0]}")
     torch.ops.load_library(libs[0])
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None

 custom_ops_lib = torch.library.Library("llama", "IMPL")

@@ -126,3 +130,11 @@ def sdpa_with_kv_cache_meta(
     )

     return torch.empty_like(query)
+
+
+@impl(custom_ops_lib, "fast_hadamard_transform", "Meta")
+def fast_hadamard_transform_meta(mat):
+    # assert(mat.strides[-1] == 1, "input matrix must be contiguous in the last dimension!")
+    # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
+    # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
+    return torch.empty_like(mat)
```
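The Meta registration above is what allows the op to appear in traced or exported graphs without running the CPU kernel. A hedged sketch of the effect (assumes the custom-ops library has already been loaded, as in the try/except block above):

```python
import torch

# Shape propagation only: tensors on the "meta" device carry sizes and
# dtypes but no data, so the Meta kernel's empty_like is all that runs.
mat = torch.empty(2, 14336, device="meta")  # llama3 demo size from the comment
out = torch.ops.llama.fast_hadamard_transform(mat)
print(out.shape, out.device)  # torch.Size([2, 14336]) meta
```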
extension/llm/custom_ops/spinquant/test/op_fast_hadamard_transform_test.cpp

Lines changed: 118 additions & 0 deletions (new file)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_test_impl.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>

#include <gtest/gtest.h>

#include <cmath>

using exec_aten::Tensor;

using executorch::runtime::testing::fast_hadamard_transform_28N_with_transpose;
using executorch::runtime::testing::random_floats;
using executorch::runtime::testing::reference_fht_impl;

namespace {
Tensor& fast_hadamard_transform_nocontext(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return torch::executor::native::fast_hadamard_transform_out(
      context, vec, out);
}
} // namespace

TEST(OpFastHadamardTransformTest, EmptyInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto vec = tfFloat.zeros({0});
  auto out = tfFloat.zeros({0});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 0);
}

TEST(OpFastHadamardTransformTest, SingleElementInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto vec = tfFloat.ones({1});
  auto out = tfFloat.zeros({1});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 1);
  // FHT of a single element is a no-op.
  EXPECT_EQ(result.const_data_ptr<float>()[0], 1);
}

TEST(OpFastHadamardTransformTest, FourKInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  std::vector<float> data = random_floats(4096);
  auto vec = tfFloat.make({4096}, data);
  auto out = tfFloat.zeros({4096});
  auto result = fast_hadamard_transform_nocontext(vec, out);

  std::vector<float> reference_result = data;
  reference_fht_impl(reference_result.data(), reference_result.size());

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(OpFastHadamardTransformTest, MultipleRows) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  std::vector<float> data = random_floats(8 * 8 * 8);
  auto mat = tfFloat.make({8, 8, 8}, data);
  auto out = tfFloat.zeros({8, 8, 8});

  auto result = fast_hadamard_transform_nocontext(mat, out);

  std::vector<float> reference_result = data;
  for (int ii = 0; ii < 8; ++ii) {
    for (int jj = 0; jj < 8; ++jj) {
      reference_fht_impl(&reference_result[ii * 64 + jj * 8], 8);
    }
  }

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(OpFastHadamardTransformTest, Basic28N) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  constexpr int kTestLogSize = 7;
  constexpr int kTestPowerOfTwoSize = 1 << kTestLogSize;
  constexpr int kTestTotalSize = kTestPowerOfTwoSize * 28;
  std::vector<float> data = random_floats(kTestTotalSize);
  auto vec = tfFloat.make({kTestTotalSize}, data);
  auto out = tfFloat.zeros({kTestTotalSize});

  // The operator is supposed to autodetect 28 * 2**N size and handle
  // accordingly.
  auto result = fast_hadamard_transform_nocontext(vec, out);

  std::vector<float> reference_result = data;
  fast_hadamard_transform_28N_with_transpose(
      reference_result.data(), kTestLogSize);

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(OpFastHadamardTransformTest, InvalidSize) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto mat = tfFloat.zeros({3});
  auto out = tfFloat.zeros({3});

  exec_aten::RuntimeContext context;
  torch::executor::native::fast_hadamard_transform_out(context, mat, out);
  EXPECT_NE(context.failure_state(), executorch::runtime::Error::Ok);
}
```

extension/llm/custom_ops/spinquant/test/targets.bzl

Lines changed: 11 additions & 0 deletions
```diff
@@ -29,3 +29,14 @@ def define_common_targets():
             "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
         ],
     )
+
+    runtime.cxx_test(
+        name = "op_fast_hadamard_transform_test",
+        srcs = ["op_fast_hadamard_transform_test.cpp"],
+        deps = [
+            ":fast_hadamard_transform_test_impl",
+            "//executorch/extension/llm/custom_ops:custom_ops",
+            "//executorch/kernels/test:test_util",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+    )
```
