-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[TOSA] FFT2D/RFFT2D accuracy increased #88510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Increased the accuracy of the FFT2D/RFFT2D calculation by removing the periodic part of the sin/cos argument.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Dmitriy Smirnov (d-smirnov) Changes: This PR increases the accuracy of the FFT2D/RFFT2D calculation by removing the periodic part of the sin/cos argument. Patch is 35.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88510.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..d63218f4a0420c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -2364,16 +2365,24 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
Value sumImag = args[2];
// Indices for angle computation
- auto oy = createLinalgIndex(builder, loc, elementType, 1);
- auto ox = createLinalgIndex(builder, loc, elementType, 2);
- auto iy = createLinalgIndex(builder, loc, elementType, 3);
- auto ix = createLinalgIndex(builder, loc, elementType, 4);
-
- // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
- auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
- auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+ Value oy = builder.create<linalg::IndexOp>(loc, 1);
+ Value ox = builder.create<linalg::IndexOp>(loc, 2);
+ Value iy = builder.create<linalg::IndexOp>(loc, 3);
+ Value ix = builder.create<linalg::IndexOp>(loc, 4);
+
+ // float_t angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) /
+ // W);
+ auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
+ auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+
+ auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
+ auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+
+ auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
+ auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
+
+ auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
+ auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
@@ -2478,22 +2487,30 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
Value sumImag = args[3];
// Indices for angle computation
- Value oy =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
- Value ox =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
- Value iy =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
- Value ix =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
-
- // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
- auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
- auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+ Value oy = builder.create<linalg::IndexOp>(loc, 1);
+ Value ox = builder.create<linalg::IndexOp>(loc, 2);
+ Value iy = builder.create<linalg::IndexOp>(loc, 3);
+ Value ix = builder.create<linalg::IndexOp>(loc, 4);
+
+ // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
+ // ox) % W ) / W);
+ auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
+ auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+
+ auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
+ auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+
+ auto iyRemFloat =
+ RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
+ auto ixRemFloat =
+ RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
+
+ auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
+ auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
+
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+
if (inverse.getValue()) {
angle = builder.create<arith::MulFOp>(
loc, angle,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 687477810030d4..ad7f6cf84e5edc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -40,7 +41,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
- tensor::TensorDialect, scf::SCFDialect>();
+ index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
}
void runOnOperation() override {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04ee..9e6112de20932f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1622,138 +1622,132 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
}
// -----
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_static_rfft2d
-// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+// CHECK-LABEL: func.func @test_static_rfft2d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
+// CHECK: %[[VAL_17:.*]] = arith.index_castui %[[VAL_13]] : index to i32
+// CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
+// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
+// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
+// CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
+// CHECK: %[[VAL_26:.*]] = linalg.index 2 : index
+// CHECK: %[[VAL_27:.*]] = linalg.index 3 : index
+// CHECK: %[[VAL_28:.*]] = linalg.index 4 : index
+// CHECK: %[[VAL_29:.*]] = index.mul %[[VAL_27]], %[[VAL_25]]
+// CHECK: %[[VAL_30:.*]] = index.mul %[[VAL_28]], %[[VAL_26]]
+// CHECK: %[[VAL_31:.*]] = index.remu %[[VAL_29]], %[[VAL_13]]
+// CHECK: %[[VAL_32:.*]] = index.remu %[[VAL_30]], %[[VAL_15]]
+// CHECK: %[[VAL_33:.*]] = arith.index_castui %[[VAL_31]] : index to i32
+// CHECK: %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : i32 to f32
+// CHECK: %[[VAL_35:.*]] = arith.index_castui %[[VAL_32]] : index to i32
+// CHECK: %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32
+// CHECK: %[[VAL_37:.*]] = arith.divf %[[VAL_34]], %[[VAL_18]] : f32
+// CHECK: %[[VAL_38:.*]] = arith.divf %[[VAL_36]], %[[VAL_20]] : f32
+// CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_37]], %[[VAL_38]] : f32
+// CHECK: %[[VAL_40:.*]] = arith.mulf %[[VAL_16]], %[[VAL_39]] : f32
+// CHECK: %[[VAL_41:.*]] = math.cos %[[VAL_40]] : f32
+// CHECK: %[[VAL_42:.*]] = math.sin %[[VAL_40]] : f32
+// CHECK: %[[VAL_43:.*]] = arith.mulf %[[VAL_22]], %[[VAL_41]] : f32
+// CHECK: %[[VAL_44:.*]] = arith.mulf %[[VAL_22]], %[[VAL_42]] : f32
+// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
+// CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
+// CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
+// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: }
func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
-// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
-// CHECK: %[[CST_8:.*]] = arith.constant 8 : index
-// CHECK: %[[CST_4:.*]] = arith.constant 4 : index
-// CHECK: %[[CST_5:.*]] = arith.constant 5 : index
-// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<5x5x5xf32>
-// CHECK: %[[CST_ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR_1:.*]] = linalg.fill ins(%[[CST_ZERO:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
-// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
-// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
-// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
-// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
-// CHECK: linalg.generic {
-// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
-// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
-// CHECK: ins(%[[ARG_0]] : tensor<5x5x8xf32>)
-// CHECK: outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
-// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
-// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
-// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
-// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
-// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
-// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
-// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
-// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
-// CHECK: %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
-// CHECK: %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
-// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
-// CHECK: %[[YCOMP:.*]] = arith.divf %[[VAR_22]], %[[VAR_8]] : f32
-// CHECK: %[[VAR_25:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
-// CHECK: %[[ALPHA:.*]] = arith.mulf %[[CST_PI]], %[[VAR_25]] : f32
-// CHECK: %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
-// CHECK: %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
-// CHECK: %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
-// CHECK: %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
-// CHECK: %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
-// CHECK: %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
-// CHECK: linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
}
// -----
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_dynamic_rfft2d
-// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+// CHECK-LABEL: func.func @test_dynamic_rfft2d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_9:.*]] = arith.divui %[[VAL_6]], %[[VAL_8]] : index
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_7]] : index
+// CHECK: %[[VAL_11:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_14:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
+// CHECK: %[[VAL_15:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_15]] : f32) outs(%[[VAL_14]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_18:.*]] = tensor.dim %[[VAL_0]], %[[VAL_17]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_19:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_21:.*]] = arith.constant 6.28318548 : f32
+// CHECK: %[[VAL_22:.*]] = arith.index_castui %[[VAL_18]] : index to i32
+// CHECK: %[[VAL_23:.*]] = arith.uitofp %[[VAL_22]] : i32 to f32
+// CHECK: %[[VAL_24:.*]] = arith.index_castui %[[VAL_20]] : index to i32
+// CHECK: %[[VAL_25:.*]] = arith.uitofp %[[VAL_24]] : i32 to f32
+// CHECK: %[[VAL_26:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<?x?x?xf32>) outs(%[[VAL_13]], %[[VAL_16]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%[[VAL_27:.*]]: f32, %[[VAL_28:.*]]: f32, %[[VAL_29:.*]]: f32):
+// CHECK: %[[VAL_30:.*]] = linalg.index 1 : index
+// CHECK: %[[VAL_31:.*]] = linalg.index 2 : index
+// CHECK: %[[VAL_32:.*]] = linalg.index 3 : index
+// CHECK: %[[VAL_33:.*]] = linalg.index 4 : index
+// CHECK: %[[VAL_34:.*]] = index.mul %[[VAL_32]], %[[VAL_30]]
+// CHECK: %[[VAL_35:.*]] = index.mul %[[VAL_33]], %[[VAL_31]]
+// CHECK: %[[VAL_36:.*]] = index.remu %[[VAL_34]], %[[VAL_18]]
+// CHECK: %[[VAL_37:.*]] = index.remu %[[VAL_35]], %[[VAL_20]]
+// CHECK: %[[VAL_38:.*]] = arith.index_castui %[[VAL_36]] : index to i32
+// CHECK: %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
+// CHECK: %[[VAL_40:.*]] = arith.index_castui %[[VAL_37]] : index to i32
+// CHECK: %[[VAL_41:.*]] = arith.uitofp %[[VAL_40]] : i32 to f32
+// CHECK: %[[VAL_42:.*]] = arith.divf %[[VAL_39]], %[[VAL_23]] : f32
+// CHECK: %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_25]] : f32
+// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
+// CHECK: %[[VAL_45:.*]] = arith.mulf %[[VAL_21]], %[[VAL_44]] : f32
+// CHECK: %[[VAL_46:.*]] = math.cos %[[VAL_45]] : f32
+// CHECK: %[[VAL_47:.*]] = math.sin %[[VAL_45]] : f32
+// CHECK: %[[VAL_48:.*]] = arith.mulf %[[VAL_27]], %[[VAL_46]] : f32
+// CHECK: %[[VAL_49:.*]] = arith.mulf %[[VAL_27]], %[[VAL_47]] : f32
+// CHECK: %[[VAL_50:.*]] = arith.addf %[[VAL_28]], %[[VAL_48]] : f32
+// CHECK: %[[VAL_51:.*]] = arith.subf %[[VAL_29]], %[[VAL_49]] : f32
+// CHECK: linalg.yield %[[VAL_50]], %[[VAL_51]] : f32, f32
+// CHECK: } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK: return %[[VAL_52:.*]]#0, %[[VAL_52]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+// CHECK: }
func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
-// CHECK: %[[CST_0:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[CST_0]] : tensor<?x?x?xf32>
-// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_0]], %[[CST_1]] : tensor<?x?x?xf32>
-// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
-// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
-// CHECK: %[[CST_1_2:.*]] = arith.constant 1 : index
-// CHECK: %[[CST_2_3:.*]] = arith.constant 2 : index
-// CHECK: %[[VAR_0:.*]] = arith.divui %[[DIM_1]], %[[CST_2_3]] : index
-// CHECK: %[[VAR_1:.*]] = arith.addi %[[VAR_0]], %[[CST_1_2]] : index
-// CHECK: %[[EMPTY_0:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[[EMPTY_1:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
-// CHECK: %[[CST_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR_5:.*]] = linalg.fill ins(%[[CST_4]] : f32) outs(%[[EMPTY_1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
This reverts commit ee0284e.
This PR increases the accuracy of the FFT2D/RFFT2D calculation by removing the periodic part of the sin/cos argument.