Commit 5cab322

Implement _fft_r2c core ATen op (#8277)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]

1 parent 832f855

File tree

12 files changed, +395 -0 lines changed


.gitmodules

Lines changed: 3 additions & 0 deletions

@@ -67,3 +67,6 @@
 [submodule "backends/cadence/utils/FACTO"]
 	path = backends/cadence/utils/FACTO
 	url = https://github.com/pytorch-labs/FACTO.git
+[submodule "third-party/pocketfft"]
+	path = third-party/pocketfft
+	url = https://github.com/mreineck/pocketfft
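
Note (not part of the commit): pocketfft is a header-only library vendored here as a plain git submodule, so existing checkouts presumably need a git submodule update --init third-party/pocketfft (or a full --init --recursive) before the new kernel's #include <pocketfft_hdronly.h> can resolve.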

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,8 @@
 
 - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out
 
+- op: _fft_r2c.out
+
 - op: _linalg_det.result
 
 - op: _linalg_svd.U
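
Note (not part of the commit): kernels/aten/functions.yaml lists the ops exposed by the ATen-mode kernel library, so this entry only declares that _fft_r2c.out is supported there; no kernel is bound in this file. For reference, the upstream ATen schema is roughly _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor, matching the arguments handled by opt_fft_r2c_out below.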

kernels/optimized/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ message("Generated files ${gen_command_sources}")
 
 list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
 add_library(optimized_kernels ${_optimized_kernels__srcs})
+target_include_directories(optimized_kernels PRIVATE "${EXECUTORCH_ROOT}/third-party/pocketfft")
 target_link_libraries(
   optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
 )
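
Note (not part of the commit): this include path is what lets #include <pocketfft_hdronly.h> in op_fft_r2c.cpp resolve against the new submodule; since pocketfft is header-only, no additional link dependency appears to be needed.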

kernels/optimized/cpu/op_fft_r2c.cpp

Lines changed: 189 additions & 0 deletions (new file)

@@ -0,0 +1,189 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <pocketfft_hdronly.h>

#include <optional>

namespace torch::executor::native {

// TODO: contents of this anonymous namespace are copy/pasted from
// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
// portions (the parts that don't depend on Tensor) could be reused;
// refactor to enable that once we can share headers from PyTorch
// core.
namespace {
pocketfft::stride_t stride_from_tensor(const Tensor& t) {
  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

pocketfft::shape_t shape_from_tensor(const Tensor& t) {
  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
}

// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
// PyTorch core does and I'm not aware of a portable way to do this
// that doesn't rely on UB.
template <typename T>
inline std::complex<T>* tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(
      t.data_ptr<executorch::runtime::etensor::complex<T>>());
}

template <typename T>
inline const std::complex<T>* tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(
      t.const_data_ptr<executorch::runtime::etensor::complex<T>>());
}

// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
// could be shared immediately.
enum class fft_norm_mode {
  none, // No normalization
  by_root_n, // Divide by sqrt(signal_size)
  by_n, // Divide by signal_size
};

// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
// upstream with TORCH_CHECK will be fine to use once we have code
// sharing.
template <typename T>
std::optional<T>
compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return one;
    case fft_norm_mode::by_n:
      return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n:
      return one / std::sqrt(static_cast<T>(size));
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      false,
      InvalidArgument,
      std::nullopt,
      "Unsupported normalization type: %" PRId64,
      normalization);
}

template <typename T>
std::optional<T> compute_fct(
    KernelRuntimeContext& ctx,
    const Tensor& t,
    IntArrayRef dim,
    int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(ctx, n, normalization);
}

} // namespace

Tensor& opt_fft_r2c_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef dim,
    int64_t normalization,
    bool onesided,
    Tensor& out) {
  auto in_sizes = in.sizes();
  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);

  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage;
  executorch::runtime::Span<Tensor::SizesType> out_sizes(
      out_sizes_storage.data(), in_sizes.size());
  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      onesided,
      InvalidArgument,
      out,
      "onesided=False is not supported yet in _fft_r2c");

  ET_KERNEL_CHECK_MSG(
      ctx,
      out.scalar_type() == executorch::runtime::toComplexType(in.scalar_type()),
      InvalidArgument,
      out,
      "the output type for _fft_r2c must be the Complex type corresponding to the input type");

  for (auto d : dim) {
    ET_KERNEL_CHECK_MSG(
        ctx,
        d >= 0 && d < in.dim(),
        InvalidArgument,
        out,
        "dims must be in bounds (got %" PRId64 ")",
        d);
  }

  if (onesided) {
    out_sizes[dim.back()] = out_sizes[dim.back()] / 2 + 1;
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out,
          executorch::runtime::ArrayRef<Tensor::SizesType>(
              out_sizes.data(), out_sizes.size())) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor (last dim %d).",
      out_sizes[dim.back()]);

  pocketfft::shape_t axes(dim.begin(), dim.end());
  auto in_shape = shape_from_tensor(in);
  // TODO: if arbitrary strides are a possibility, we need to validate
  // these, because pocketfft README says "Strides that lead to
  // multiple accesses of the same memory address are not allowed."
  auto in_stride = stride_from_tensor(in);
  auto out_stride = stride_from_tensor(out);
  // NOTE: as of this writing, upstream PyTorch only supports
  // float/double, so we follow suit.
  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "_fft_r2c.out", CTYPE_IN, [&] {
    auto fct = compute_fct<CTYPE_IN>(ctx, in, dim, normalization);
    if (!fct) {
      // Check failed, just bail out of the lambda.
      return;
    }
    pocketfft::r2c<CTYPE_IN>(
        in_shape,
        in_stride,
        out_stride,
        axes,
        true,
        in.const_data_ptr<CTYPE_IN>(),
        tensor_cdata<CTYPE_IN>(out),
        *fct);

    // TODO: fill with conjugate symmetry if not onesided; see
    // ATen/native/mkl/SpectralOps.cpp
  });
  return out;
}
} // namespace torch::executor::native
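
Note (not part of the commit): to make the pocketfft conventions relied on above concrete, here is a minimal standalone C++ sketch, independent of the ExecuTorch tensor types and assuming only the vendored pocketfft_hdronly.h. Shapes are element counts, strides are given in bytes (hence the element_size() multiplication in stride_from_tensor), and a onesided real-to-complex transform of length n yields n/2 + 1 complex bins, which is exactly the out_sizes[dim.back()] / 2 + 1 resize in the kernel. The fct argument plays the role of compute_fct's result; 1.0f corresponds to fft_norm_mode::none.

// Hypothetical illustration, not part of this commit.
#include <complex>
#include <cstddef>
#include <cstdio>
#include <vector>

#include <pocketfft_hdronly.h>

int main() {
  const std::size_t n = 8;
  std::vector<float> in(n);
  for (std::size_t i = 0; i < n; ++i) {
    in[i] = static_cast<float>(i); // arbitrary real input signal
  }

  // Onesided output length along the transformed axis: n / 2 + 1.
  std::vector<std::complex<float>> out(n / 2 + 1);

  pocketfft::shape_t shape{n};                  // logical sizes, in elements
  pocketfft::stride_t in_stride{sizeof(float)}; // strides, in bytes
  pocketfft::stride_t out_stride{sizeof(std::complex<float>)};
  pocketfft::shape_t axes{0};                   // transform along dim 0

  // fct = 1.0f corresponds to fft_norm_mode::none above.
  pocketfft::r2c<float>(
      shape, in_stride, out_stride, axes, /*forward=*/true,
      in.data(), out.data(), /*fct=*/1.0f);

  for (const auto& c : out) {
    std::printf("%f %+fi\n", c.real(), c.imag());
  }
  return 0;
}

The kernel above issues the same call, just with the shape, byte strides, and axes derived from the ExecuTorch tensors via shape_from_tensor, stride_from_tensor, and the dim argument.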

kernels/optimized/cpu/targets.bzl

Lines changed: 4 additions & 0 deletions

@@ -25,6 +25,10 @@ _OPTIMIZED_ATEN_OPS = (
         ],
     ),
     op_target(name = "op_exp"),
+    op_target(
+        name = "op_fft_r2c",
+        deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+    ),
     op_target(name = "op_sigmoid"),
     op_target(
         name = "op_gelu",

kernels/optimized/optimized-oss.yaml

Lines changed: 5 additions & 0 deletions

@@ -5,6 +5,11 @@
 # log_softmax, due to the OSS build not currently including sleef.
 # TODO (T183193812)
 
+- op: _fft_r2c.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_r2c_out
+
 - op: add.out
   kernels:
     - arg_meta: null
kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions

@@ -2,6 +2,11 @@
 #
 # This yaml file contains operators that have optimized kernels available.
 
+- op: _fft_r2c.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_r2c_out
+
 - op: _log_softmax.out
   kernels:
     - arg_meta: null
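
Note (not part of the commit): these two registrations (optimized.yaml and optimized-oss.yaml) bind the _fft_r2c.out schema to torch::executor::opt_fft_r2c_out, the kernel defined in op_fft_r2c.cpp above; the OSS variant keeps its own list because, per its header comment, the OSS build omits some ops (e.g. log_softmax, pending sleef support).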

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -265,6 +265,7 @@ set(_optimized_kernels_test_sources
   "op_bmm_test.cpp"
   "op_div_test.cpp"
   "op_exp_test.cpp"
+  "op_fft_r2c_test.cpp"
   "op_gelu_test.cpp"
   "op_le_test.cpp"
   "op_log_softmax_test.cpp"
