[executorch] Custom op for fast hadamard transform kernel #5291

Closed · wants to merge 13 commits
5 changes: 4 additions & 1 deletion build/cmake_deps.toml
@@ -337,7 +337,10 @@ buck_targets = [
  "//extension/llm/custom_ops:custom_ops",
]
filters = [
  ".cpp$",
  # Second clause is to pick up fht_neon.c/fht_avx.c from FFHT. TODO:
  # remove filters and patch extract_sources.py's Buck query to fetch
  # srcs; presumably filters is here to remove .h files.
  "(.cpp$)|(fht.*\\.c$)",
]
excludes = [
  "^codegen",
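As a sanity check on the new filter clause, a quick Python sketch of what the combined pattern does and does not match (file names are illustrative):

import re

# First alternative keeps the existing .cpp sources; the second picks up
# FFHT's C sources, as the TOML comment above describes.
pattern = re.compile(r"(.cpp$)|(fht.*\.c$)")

assert pattern.search("op_fast_hadamard_transform.cpp")
assert pattern.search("fht_neon.c")
assert pattern.search("fht_avx.c")
assert not pattern.search("fast_hadamard_transform.h")  # headers stay excluded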
3 changes: 2 additions & 1 deletion extension/llm/custom_ops/CMakeLists.txt
@@ -75,8 +75,9 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
  add_library(
    custom_ops_aot_lib SHARED
    ${_custom_ops__srcs} ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop_aot.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/op_fast_hadamard_transform_aten.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop_aot.cpp
  )
  target_include_directories(
    custom_ops_aot_lib PUBLIC "${_common_include_directories}"
93 changes: 93 additions & 0 deletions extension/llm/custom_ops/op_fast_hadamard_transform.cpp
@@ -0,0 +1,93 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
#include <executorch/kernels/optimized/utils/llvmMathExtras.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h> // For apply_over_dim.
#include <executorch/runtime/kernel/kernel_includes.h>

#include <cstring> // For std::memcpy.

namespace torch {
namespace executor {
namespace native {

Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, mat.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ET_KERNEL_CHECK(
      ctx, mat.scalar_type() == out.scalar_type(), InvalidArgument, out);

  if (mat.dim() == 0 || mat.numel() == 0) {
    return out;
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      mat.strides().back() == 1,
      InvalidArgument,
      out,
      "input matrix must be contiguous in the last dimension!");

  const auto last_dim_size = mat.sizes().back();
  const auto divisible_by_28 = last_dim_size % 28 == 0;
  auto power_of_two_size = divisible_by_28 ? last_dim_size / 28 : last_dim_size;
  ET_KERNEL_CHECK_MSG(
      ctx,
      (power_of_two_size & (power_of_two_size - 1)) == 0,
      InvalidArgument,
      out,
      "This implementation requires power-of-2 (or power-of-2 * 28) input size in the last dimension!");

  const auto log2_power_of_two_size = executorch::llvm::countTrailingZeros(
      static_cast<unsigned int>(power_of_two_size),
      executorch::llvm::ZeroBehavior::ZB_Undefined);

  ET_SWITCH_FLOATH_TYPES(mat.scalar_type(), ctx, __func__, CTYPE, [&] {
    const CTYPE* const mat_data = mat.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    std::memcpy(out_data, mat_data, mat.numel() * sizeof(CTYPE));

    if (divisible_by_28) {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform_28N(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    } else {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    }
  });
  return out;
}
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    llama,
    "fast_hadamard_transform.out",
    torch::executor::native::fast_hadamard_transform_out);
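The size validation and dispatch above reduce to a little integer logic. A minimal Python sketch of the same decision; plan_fht is a hypothetical helper name, not part of this PR:

def plan_fht(last_dim_size: int):
    # Mirrors fast_hadamard_transform_out's checks: accept sizes of the
    # form 2**k or 28 * 2**k, and recover k (the C++ code gets the log2
    # via countTrailingZeros).
    divisible_by_28 = last_dim_size % 28 == 0
    power_of_two = last_dim_size // 28 if divisible_by_28 else last_dim_size
    if power_of_two & (power_of_two - 1) != 0:
        raise ValueError("size must be 2**k or 28 * 2**k")
    log2_size = power_of_two.bit_length() - 1
    kernel = "fast_hadamard_transform_28N" if divisible_by_28 else "fast_hadamard_transform"
    return kernel, log2_size

assert plan_fht(4096) == ("fast_hadamard_transform", 12)
assert plan_fht(28 * 128) == ("fast_hadamard_transform_28N", 7)
# plan_fht(3) raises, matching the InvalidSize test below.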
25 changes: 25 additions & 0 deletions extension/llm/custom_ops/op_fast_hadamard_transform.h
@@ -0,0 +1,25 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {

// Compute the fast Walsh-Hadamard transform
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
// of mat along the last dimension (which must be contiguous).
//
// mat.sizes().back() is currently required to be either a power of
// two, or 28 * a power of two.
Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out);
} // namespace torch::executor::native
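For reference, the transform named in the header comment is the standard O(n log n) butterfly. A minimal unnormalized Python sketch (the kernel's own normalization, if any, may differ):

def fwht(vec):
    # Iterative in-place fast Walsh-Hadamard transform (unnormalized),
    # per the Wikipedia reference above; len(vec) must be a power of two.
    n = len(vec)
    assert n & (n - 1) == 0, "length must be a power of two"
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x, y = vec[j], vec[j + h]
                vec[j], vec[j + h] = x + y, x - y
        h *= 2
    return vec

assert fwht([1.0, 1.0, 1.0, 1.0]) == [4.0, 0.0, 0.0, 0.0]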
43 changes: 43 additions & 0 deletions extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp
@@ -0,0 +1,43 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>

#include <torch/library.h>

namespace torch::executor::native {
namespace {
Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return fast_hadamard_transform_out(context, vec, out);
}
at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
  auto out = at::empty_like(vec);
  WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
  (vec, out);
  return out;
}
} // namespace
} // namespace torch::executor::native

TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def("fast_hadamard_transform(Tensor mat) -> Tensor");
  m.def(
      "fast_hadamard_transform.out(Tensor mat, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "fast_hadamard_transform",
      torch::executor::native::fast_hadamard_transform_aten);
  m.impl(
      "fast_hadamard_transform.out",
      WRAP_TO_ATEN(
          torch::executor::native::fast_hadamard_transform_out_no_context, 1));
}
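With these registrations in place, the op is callable from eager PyTorch. A small usage sketch, assuming custom_ops_aot_lib has been built (the path below is a placeholder; sdpa_with_kv_cache.py resolves the real one by globbing):

import torch

# Placeholder path; substitute the actual built library location.
torch.ops.load_library("libcustom_ops_aot_lib.so")

x = torch.randn(2, 128)  # last dimension is a power of two
y = torch.ops.llama.fast_hadamard_transform(x)

out = torch.empty_like(x)
torch.ops.llama.fast_hadamard_transform.out(x, out=out)  # the .out variant
assert torch.allclose(y, out)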
6 changes: 3 additions & 3 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -16,7 +16,7 @@ namespace torch {
namespace executor {

namespace native {

namespace {
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
@@ -81,12 +81,12 @@ at::Tensor sdpa_with_kv_cache_aten(
      output);
  return output;
}

} // namespace
} // namespace native
} // namespace executor
} // namespace torch

TORCH_LIBRARY(llama, m) {
TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def(
      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
12 changes: 12 additions & 0 deletions extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -20,13 +20,17 @@
try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
    op2 = torch.ops.llama.fast_hadamard_transform.default
    assert op2 is not None
except:
    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
    op2 = torch.ops.llama.fast_hadamard_transform.default
    assert op2 is not None

custom_ops_lib = torch.library.Library("llama", "IMPL")

@@ -126,3 +130,11 @@ def sdpa_with_kv_cache_meta(
    )

    return torch.empty_like(query)


@impl(custom_ops_lib, "fast_hadamard_transform", "Meta")
def fast_hadamard_transform_meta(mat):
    # assert mat.stride(-1) == 1, "input matrix must be contiguous in the last dimension!"
    # assert mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!"
    # assert mat.is_contiguous(), "input matrix must be contiguous currently!"
    return torch.empty_like(mat)
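The meta kernel above is what lets the op trace under fake tensors (as used by export/compile). A small sketch, assuming the AOT library is loaded as in the try/except block above:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under fake tensors, only the meta kernel runs: shapes and dtypes
# propagate, but no real computation happens.
with FakeTensorMode():
    mat = torch.empty(4, 128)
    out = torch.ops.llama.fast_hadamard_transform(mat)
    assert out.shape == mat.shape and out.dtype == mat.dtype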
118 changes: 118 additions & 0 deletions extension/llm/custom_ops/spinquant/test/op_fast_hadamard_transform_test.cpp
@@ -0,0 +1,118 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>
#include <executorch/extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_test_impl.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>

#include <gtest/gtest.h>

#include <cmath>

using exec_aten::Tensor;

using executorch::runtime::testing::fast_hadamard_transform_28N_with_transpose;
using executorch::runtime::testing::random_floats;
using executorch::runtime::testing::reference_fht_impl;

namespace {
Tensor& fast_hadamard_transform_nocontext(const Tensor& vec, Tensor& out) {
  exec_aten::RuntimeContext context;
  return torch::executor::native::fast_hadamard_transform_out(
      context, vec, out);
}
} // namespace

TEST(OpFastHadamardTransformTest, EmptyInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto vec = tfFloat.zeros({0});
  auto out = tfFloat.zeros({0});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 0);
}

TEST(OpFastHadamardTransformTest, SingleElementInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto vec = tfFloat.ones({1});
  auto out = tfFloat.zeros({1});
  auto result = fast_hadamard_transform_nocontext(vec, out);
  EXPECT_EQ(result.numel(), 1);
  // FHT of a single element is a no-op.
  EXPECT_EQ(result.const_data_ptr<float>()[0], 1);
}

TEST(OpFastHadamardTransformTest, FourKInput) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  std::vector<float> data = random_floats(4096);
  auto vec = tfFloat.make({4096}, data);
  auto out = tfFloat.zeros({4096});
  auto result = fast_hadamard_transform_nocontext(vec, out);

  std::vector<float> reference_result = data;
  reference_fht_impl(reference_result.data(), reference_result.size());

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(OpFastHadamardTransformTest, MultipleRows) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  std::vector<float> data = random_floats(8 * 8 * 8);
  auto mat = tfFloat.make({8, 8, 8}, data);
  auto out = tfFloat.zeros({8, 8, 8});

  auto result = fast_hadamard_transform_nocontext(mat, out);

  std::vector<float> reference_result = data;
  for (int ii = 0; ii < 8; ++ii) {
    for (int jj = 0; jj < 8; ++jj) {
      reference_fht_impl(&reference_result[ii * 64 + jj * 8], 8);
    }
  }

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(OpFastHadamardTransformTest, Basic28N) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  constexpr int kTestLogSize = 7;
  constexpr int kTestPowerOfTwoSize = 1 << kTestLogSize;
  constexpr int kTestTotalSize = kTestPowerOfTwoSize * 28;
  std::vector<float> data = random_floats(kTestTotalSize);
  auto vec = tfFloat.make({kTestTotalSize}, data);
  auto out = tfFloat.zeros({kTestTotalSize});

  // The operator is supposed to autodetect 28 * 2**N size and handle
  // accordingly.
  auto result = fast_hadamard_transform_nocontext(vec, out);

  std::vector<float> reference_result = data;
  fast_hadamard_transform_28N_with_transpose(
      reference_result.data(), kTestLogSize);

  const float* const result_data = result.const_data_ptr<float>();
  for (int ii = 0; ii < data.size(); ++ii) {
    EXPECT_FLOAT_EQ(result_data[ii], reference_result[ii]);
  }
}

TEST(OpFastHadamardTransformTest, InvalidSize) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  auto mat = tfFloat.zeros({3});
  auto out = tfFloat.zeros({3});

  exec_aten::RuntimeContext context;
  torch::executor::native::fast_hadamard_transform_out(context, mat, out);
  EXPECT_NE(context.failure_state(), executorch::runtime::Error::Ok);
}
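A Python-level analogue of the FourKInput check, validating against an explicit Hadamard matrix up to a global scale so it stays agnostic to whatever normalization the kernel applies; assumes scipy is available and the AOT library is loaded as above:

import torch
from scipy.linalg import hadamard  # Sylvester-ordered H_n with +/-1 entries

n = 128
x = torch.randn(n)
y = torch.ops.llama.fast_hadamard_transform(x)

ref = torch.from_numpy(hadamard(n)).float() @ x
# Fit a single global scale so the check is independent of normalization.
scale = torch.dot(y, ref) / torch.dot(ref, ref)
assert torch.allclose(y, scale * ref, rtol=1e-4, atol=1e-4)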
11 changes: 11 additions & 0 deletions extension/llm/custom_ops/spinquant/test/targets.bzl
@@ -29,3 +29,14 @@ def define_common_targets():
            "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
        ],
    )

    runtime.cxx_test(
        name = "op_fast_hadamard_transform_test",
        srcs = ["op_fast_hadamard_transform_test.cpp"],
        deps = [
            ":fast_hadamard_transform_test_impl",
            "//executorch/extension/llm/custom_ops:custom_ops",
            "//executorch/kernels/test:test_util",
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
        ],
    )