[ExecuTorch] Add fast_hadamard_transform and fast_hadamard_transform_28N kernels #5283

Closed · wants to merge 2 commits
4 changes: 4 additions & 0 deletions .lintrunner.toml
@@ -74,6 +74,8 @@ exclude_patterns = [
# NB: Objective-C is not supported
'examples/apple/**',
'examples/demo-apps/apple_ios/**',
# File contains @generated
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
]
command = [
'python',
@@ -177,6 +179,8 @@ exclude_patterns = [
'**/*.bat',
'**/*.jpg',
'**/*.jar',
# File contains @generated
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
]
command = [
'python',
16 changes: 16 additions & 0 deletions extension/llm/custom_ops/spinquant/README.md
@@ -0,0 +1,16 @@
# SpinQuant

This is an implementation of the [Fast Hadamard
Transform](https://en.wikipedia.org/wiki/Fast_Walsh–Hadamard_transform)
as used in [SpinQuant](https://arxiv.org/abs/2405.16406) (for the R3
and R4 matrices), [QuaRot](https://arxiv.org/abs/2404.00456), and
[QuIP#](https://arxiv.org/pdf/2402.04396). We follow those papers'
method (as implemented in
https://github.com/Dao-AILab/fast-hadamard-transform/) for extending
the transform to non-power-of-two input sizes. A CUDA implementation is
not provided here because the Dao-AILab repository above already
includes one.

The intended long-term destination for this code is pytorch/ao; it
lives in ExecuTorch temporarily until we figure out how ExecuTorch can
take a C++ dependency on torchao.
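For reference, the power-of-two core of the transform is the standard
butterfly recursion. Below is a minimal standalone C++ sketch (not part
of this PR; the function name `fwht` is illustrative) that applies the
butterflies in place and then rescales by `1/sqrt(n)` so the transform
is orthonormal:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// In-place fast Walsh-Hadamard transform of a power-of-two-length
// vector, with orthonormal scaling (divide the result by sqrt(n)).
void fwht(std::vector<float>& v) {
  const size_t n = v.size();
  for (size_t step = 1; step < n; step *= 2) {
    for (size_t i = 0; i < n; i += 2 * step) {
      for (size_t j = i; j < i + step; ++j) {
        const float x = v[j];
        const float y = v[j + step];
        v[j] = x + y;        // butterfly: sum
        v[j + step] = x - y; // butterfly: difference
      }
    }
  }
  const float inv_sqrt_n = 1.0f / std::sqrt(static_cast<float>(n));
  for (float& e : v) {
    e *= inv_sqrt_n; // orthonormal normalization
  }
}
```

Because the normalized Hadamard matrix is its own inverse, applying
`fwht` twice returns the original vector, which makes a convenient
sanity check.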
5 changes: 5 additions & 0 deletions extension/llm/custom_ops/spinquant/TARGETS
@@ -0,0 +1,5 @@
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
93 changes: 93 additions & 0 deletions extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
@@ -0,0 +1,93 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// (c) Meta Platforms, Inc. and affiliates.
#pragma once

#include <cassert>
#include <cmath>
#include <cstdint>

#include "fast_hadamard_transform_special.h"

namespace executorch {
namespace internal {

// Square root of 1 << log2_n.
template <typename T>
T fast_sqrt_of_power_of_2(int log2_n) {
  // The square root of 2**N is, by definition, 2**(N/2), which is
  // trivial to compute for even N using a left shift.
  //
  // For odd N, 2**(N/2) = 2**(floor(N/2) + 1/2)
  //                     = 2**(floor(N/2)) * (2 ** (1/2))
  //                     = 2**(floor(N/2)) * sqrt(2)
  // which is again fast to compute.
  return T(1 << (log2_n / 2)) * ((log2_n % 2) ? T(std::sqrt(2)) : T(1));
}

template <typename T>
void normalize_after_fht(T* out, int log2_vec_size) {
  const T inv_sqrt = T(1) / fast_sqrt_of_power_of_2<T>(log2_vec_size);
  const int vec_size = 1 << log2_vec_size;
  for (int ii = 0; ii < vec_size; ++ii) {
    out[ii] *= inv_sqrt;
  }
}

template <typename T>
void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
  if (log2_vec_size == 0) {
    return;
  }

  int step = 1;
  const auto vec_size = 1 << log2_vec_size;
  while (step < vec_size) {
    for (int ii = 0; ii < vec_size; ii += step * 2) {
      for (int jj = ii; jj < ii + step; ++jj) {
        auto x = vec[jj];
        auto y = vec[jj + step];
        vec[jj] = x + y;
        vec[jj + step] = x - y;
      }
    }
    step *= 2;
  }

  normalize_after_fht(vec, log2_vec_size);
}

} // namespace internal

// Compute the fast Walsh-Hadamard transform
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
// of vec, which must be of length (1 << log2_vec_size).
template <typename T>
void fast_hadamard_transform(T* vec, int log2_vec_size) {
  internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
}

// Like fast_hadamard_transform, but vec must be of length 28 * (1 <<
// log2_vec_size) and the transform is computed by interpreting vec as
// a (28, 1 << log2_vec_size) matrix and performing 28 FHTs, followed
// by (1 << log2_vec_size) multiplications by a particular Hadamard
// matrix of size 28x28 (see special_hadamard_code_gen.py for the
// exact matrix).
template <typename T>
void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
  const int vec_size = (1 << log2_vec_size);
  for (int ii = 0; ii < 28; ++ii) {
    fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size);
  }
  for (int ii = 0; ii < vec_size; ++ii) {
    hadamard_mult_28_strided(&vec[ii], vec_size);
  }
}

} // namespace executorch
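To illustrate the shape of `fast_hadamard_transform_28N` without the
generated 28x28 matrix, here is a standalone sketch that substitutes a
2x2 Hadamard multiply for `hadamard_mult_28_strided`. All function
names here (`fht_row`, `hadamard_mult_2_strided`,
`fast_hadamard_transform_2N`) and the 2x2 stand-in are illustrative
assumptions, not part of the PR:

```cpp
#include <cassert>
#include <cmath>

// Per-row FHT with orthonormal scaling, restating the butterfly from
// the diff above so this sketch is self-contained.
void fht_row(float* vec, int log2_vec_size) {
  const int n = 1 << log2_vec_size;
  for (int step = 1; step < n; step *= 2) {
    for (int ii = 0; ii < n; ii += 2 * step) {
      for (int jj = ii; jj < ii + step; ++jj) {
        const float x = vec[jj];
        const float y = vec[jj + step];
        vec[jj] = x + y;
        vec[jj + step] = x - y;
      }
    }
  }
  const float inv_sqrt = 1.0f / std::sqrt(static_cast<float>(n));
  for (int ii = 0; ii < n; ++ii) {
    vec[ii] *= inv_sqrt;
  }
}

// Stand-in for hadamard_mult_28_strided: multiply the 2-element column
// starting at p (elements p[0] and p[stride]) by [[1, 1], [1, -1]].
void hadamard_mult_2_strided(float* p, int stride) {
  const float a = p[0];
  const float b = p[stride];
  p[0] = a + b;
  p[stride] = a - b;
}

// Same structure as fast_hadamard_transform_28N, with 28 replaced by 2:
// treat vec as a (2, vec_size) row-major matrix, run an FHT over each
// row, then mix each column with the small Hadamard matrix.
void fast_hadamard_transform_2N(float* vec, int log2_vec_size) {
  const int vec_size = 1 << log2_vec_size;
  for (int row = 0; row < 2; ++row) {
    fht_row(&vec[row * vec_size], log2_vec_size);
  }
  for (int col = 0; col < vec_size; ++col) {
    hadamard_mult_2_strided(&vec[col], vec_size);
  }
}
```

Note that the strided step in this sketch is unnormalized; whether the
real `hadamard_mult_28` folds a `1/sqrt(28)` factor into the generated
matrix is determined by `special_hadamard_code_gen.py` and is not
assumed here.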