Skip to content

Commit ee0284e

Browse files
authored
[TOSA] FFT2D/RFFT2D accuracy increased (#88510)
This PR increases accurasy of FFT2D/RFFT2D calculation by removing periodic part of sin/cos
1 parent d556ed5 commit ee0284e

File tree

3 files changed

+212
-204
lines changed

3 files changed

+212
-204
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
1414
#include "mlir/Dialect/Arith/IR/Arith.h"
1515
#include "mlir/Dialect/Arith/Utils/Utils.h"
16+
#include "mlir/Dialect/Index/IR/IndexOps.h"
1617
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1718
#include "mlir/Dialect/Math/IR/Math.h"
1819
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -2367,16 +2368,25 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
23672368
Value sumImag = args[2];
23682369

23692370
// Indices for angle computation
2370-
auto oy = createLinalgIndex(builder, loc, elementType, 1);
2371-
auto ox = createLinalgIndex(builder, loc, elementType, 2);
2372-
auto iy = createLinalgIndex(builder, loc, elementType, 3);
2373-
auto ix = createLinalgIndex(builder, loc, elementType, 4);
2374-
2375-
// angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
2376-
auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
2377-
auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
2378-
auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
2379-
auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
2371+
Value oy = builder.create<linalg::IndexOp>(loc, 1);
2372+
Value ox = builder.create<linalg::IndexOp>(loc, 2);
2373+
Value iy = builder.create<linalg::IndexOp>(loc, 3);
2374+
Value ix = builder.create<linalg::IndexOp>(loc, 4);
2375+
2376+
// Calculating angle without integer parts of components as sin/cos are
2377+
// periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2378+
// / W);
2379+
auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2380+
auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2381+
2382+
auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2383+
auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2384+
2385+
auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2386+
auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2387+
2388+
auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2389+
auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
23802390
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
23812391
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
23822392

@@ -2481,22 +2491,30 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
24812491
Value sumImag = args[3];
24822492

24832493
// Indices for angle computation
2484-
Value oy =
2485-
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
2486-
Value ox =
2487-
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
2488-
Value iy =
2489-
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
2490-
Value ix =
2491-
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
2492-
2493-
// float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
2494-
auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
2495-
auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
2496-
auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
2497-
auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
2494+
Value oy = builder.create<linalg::IndexOp>(loc, 1);
2495+
Value ox = builder.create<linalg::IndexOp>(loc, 2);
2496+
Value iy = builder.create<linalg::IndexOp>(loc, 3);
2497+
Value ix = builder.create<linalg::IndexOp>(loc, 4);
2498+
2499+
// float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2500+
// ox) % W ) / W);
2501+
auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2502+
auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2503+
2504+
auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2505+
auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2506+
2507+
auto iyRemFloat =
2508+
RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2509+
auto ixRemFloat =
2510+
RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2511+
2512+
auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2513+
auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2514+
24982515
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
24992516
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2517+
25002518
if (inverse.getValue()) {
25012519
angle = builder.create<arith::MulFOp>(
25022520
loc, angle,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1718
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1819
#include "mlir/Dialect/Math/IR/Math.h"
1920
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -40,7 +41,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
4041
void getDependentDialects(DialectRegistry &registry) const override {
4142
registry
4243
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
43-
tensor::TensorDialect, scf::SCFDialect>();
44+
index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
4445
}
4546

4647
void runOnOperation() override {

0 commit comments

Comments
 (0)