Skip to content

[TOSA] FFT2D/RFFT2D accuracy increased #88510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2024
Merged

Conversation

d-smirnov
Copy link
Contributor

This PR increases accurasy of FFT2D/RFFT2D calculation by removing periodic part of sin/cos

Increased accurasy of FFT2D/RFFT2D calculation by removing periodic part
of sin/cos
@llvmbot
Copy link
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Dmitriy Smirnov (d-smirnov)

Changes

This PR increases accurasy of FFT2D/RFFT2D calculation by removing periodic part of sin/cos


Patch is 35.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88510.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+41-24)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+2-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+168-179)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..d63218f4a0420c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -2364,16 +2365,24 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
       Value sumImag = args[2];
 
       // Indices for angle computation
-      auto oy = createLinalgIndex(builder, loc, elementType, 1);
-      auto ox = createLinalgIndex(builder, loc, elementType, 2);
-      auto iy = createLinalgIndex(builder, loc, elementType, 3);
-      auto ix = createLinalgIndex(builder, loc, elementType, 4);
-
-      // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
-      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
-      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
-      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
-      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+      Value oy = builder.create<linalg::IndexOp>(loc, 1);
+      Value ox = builder.create<linalg::IndexOp>(loc, 2);
+      Value iy = builder.create<linalg::IndexOp>(loc, 3);
+      Value ix = builder.create<linalg::IndexOp>(loc, 4);
+
+      // float_t angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) /
+      // W);
+      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
+      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+
+      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
+      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+
+      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
+      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
+
+      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
+      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
       auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
       auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
 
@@ -2478,22 +2487,30 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
       Value sumImag = args[3];
 
       // Indices for angle computation
-      Value oy =
-          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
-      Value ox =
-          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
-      Value iy =
-          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
-      Value ix =
-          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
-
-      // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
-      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
-      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
-      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
-      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+      Value oy = builder.create<linalg::IndexOp>(loc, 1);
+      Value ox = builder.create<linalg::IndexOp>(loc, 2);
+      Value iy = builder.create<linalg::IndexOp>(loc, 3);
+      Value ix = builder.create<linalg::IndexOp>(loc, 4);
+
+      // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
+      // ox) % W ) / W);
+      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
+      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+
+      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
+      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+
+      auto iyRemFloat =
+          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
+      auto ixRemFloat =
+          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
+
+      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
+      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
+
       auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
       auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+
       if (inverse.getValue()) {
         angle = builder.create<arith::MulFOp>(
             loc, angle,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 687477810030d4..ad7f6cf84e5edc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -40,7 +41,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
-                tensor::TensorDialect, scf::SCFDialect>();
+                index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
   }
 
   void runOnOperation() override {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04ee..9e6112de20932f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1622,138 +1622,132 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 }
 
 // -----
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_static_rfft2d
-// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+// CHECK-LABEL:   func.func @test_static_rfft2d(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_14:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
+// CHECK:           %[[VAL_17:.*]] = arith.index_castui %[[VAL_13]] : index to i32
+// CHECK:           %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
+// CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
+// CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
+// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:           ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
+// CHECK:             %[[VAL_25:.*]] = linalg.index 1 : index
+// CHECK:             %[[VAL_26:.*]] = linalg.index 2 : index
+// CHECK:             %[[VAL_27:.*]] = linalg.index 3 : index
+// CHECK:             %[[VAL_28:.*]] = linalg.index 4 : index
+// CHECK:             %[[VAL_29:.*]] = index.mul %[[VAL_27]], %[[VAL_25]]
+// CHECK:             %[[VAL_30:.*]] = index.mul %[[VAL_28]], %[[VAL_26]]
+// CHECK:             %[[VAL_31:.*]] = index.remu %[[VAL_29]], %[[VAL_13]]
+// CHECK:             %[[VAL_32:.*]] = index.remu %[[VAL_30]], %[[VAL_15]]
+// CHECK:             %[[VAL_33:.*]] = arith.index_castui %[[VAL_31]] : index to i32
+// CHECK:             %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : i32 to f32
+// CHECK:             %[[VAL_35:.*]] = arith.index_castui %[[VAL_32]] : index to i32
+// CHECK:             %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32
+// CHECK:             %[[VAL_37:.*]] = arith.divf %[[VAL_34]], %[[VAL_18]] : f32
+// CHECK:             %[[VAL_38:.*]] = arith.divf %[[VAL_36]], %[[VAL_20]] : f32
+// CHECK:             %[[VAL_39:.*]] = arith.addf %[[VAL_37]], %[[VAL_38]] : f32
+// CHECK:             %[[VAL_40:.*]] = arith.mulf %[[VAL_16]], %[[VAL_39]] : f32
+// CHECK:             %[[VAL_41:.*]] = math.cos %[[VAL_40]] : f32
+// CHECK:             %[[VAL_42:.*]] = math.sin %[[VAL_40]] : f32
+// CHECK:             %[[VAL_43:.*]] = arith.mulf %[[VAL_22]], %[[VAL_41]] : f32
+// CHECK:             %[[VAL_44:.*]] = arith.mulf %[[VAL_22]], %[[VAL_42]] : f32
+// CHECK:             %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
+// CHECK:             %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
+// CHECK:             linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
+// CHECK:           } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK:         }
 func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-// CHECK:   %[[CST_1:.*]] = arith.constant 1 : index
-// CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
-// CHECK:   %[[CST_8:.*]] = arith.constant 8 : index
-// CHECK:   %[[CST_4:.*]] = arith.constant 4 : index
-// CHECK:   %[[CST_5:.*]] = arith.constant 5 : index
-// CHECK:   %[[EMPTY_0:.*]] = tensor.empty() : tensor<5x5x5xf32>
-// CHECK:   %[[CST_ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[VAR_1:.*]] = linalg.fill ins(%[[CST_ZERO:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
-// CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
-// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
-// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
-// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
-// CHECK:   linalg.generic {
-// CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
-// CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
-// CHECK:     ins(%[[ARG_0]] : tensor<5x5x8xf32>)
-// CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-// CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
-// CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
-// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
-// CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
-// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
-// CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
-// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
-// CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
-// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
-// CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
-// CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
-// CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
-// CHECK:     %[[YCOMP:.*]] = arith.divf %[[VAR_22]], %[[VAR_8]] : f32
-// CHECK:     %[[VAR_25:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
-// CHECK:     %[[ALPHA:.*]] = arith.mulf %[[CST_PI]], %[[VAR_25]] : f32
-// CHECK:     %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
-// CHECK:     %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
-// CHECK:     %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
-// CHECK:     %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
-// CHECK:     %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
-// CHECK:     %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
-// CHECK:     linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
-// CHECK:   } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-
   %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
   return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
 }
 
 // -----
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_dynamic_rfft2d
-// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+// CHECK-LABEL:   func.func @test_dynamic_rfft2d(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_9:.*]] = arith.divui %[[VAL_6]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_7]] : index
+// CHECK:           %[[VAL_11:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:           %[[VAL_14:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_15:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_15]] : f32) outs(%[[VAL_14]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:           %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_18:.*]] = tensor.dim %[[VAL_0]], %[[VAL_17]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_19:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_21:.*]] = arith.constant 6.28318548 : f32
+// CHECK:           %[[VAL_22:.*]] = arith.index_castui %[[VAL_18]] : index to i32
+// CHECK:           %[[VAL_23:.*]] = arith.uitofp %[[VAL_22]] : i32 to f32
+// CHECK:           %[[VAL_24:.*]] = arith.index_castui %[[VAL_20]] : index to i32
+// CHECK:           %[[VAL_25:.*]] = arith.uitofp %[[VAL_24]] : i32 to f32
+// CHECK:           %[[VAL_26:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<?x?x?xf32>) outs(%[[VAL_13]], %[[VAL_16]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK:           ^bb0(%[[VAL_27:.*]]: f32, %[[VAL_28:.*]]: f32, %[[VAL_29:.*]]: f32):
+// CHECK:             %[[VAL_30:.*]] = linalg.index 1 : index
+// CHECK:             %[[VAL_31:.*]] = linalg.index 2 : index
+// CHECK:             %[[VAL_32:.*]] = linalg.index 3 : index
+// CHECK:             %[[VAL_33:.*]] = linalg.index 4 : index
+// CHECK:             %[[VAL_34:.*]] = index.mul %[[VAL_32]], %[[VAL_30]]
+// CHECK:             %[[VAL_35:.*]] = index.mul %[[VAL_33]], %[[VAL_31]]
+// CHECK:             %[[VAL_36:.*]] = index.remu %[[VAL_34]], %[[VAL_18]]
+// CHECK:             %[[VAL_37:.*]] = index.remu %[[VAL_35]], %[[VAL_20]]
+// CHECK:             %[[VAL_38:.*]] = arith.index_castui %[[VAL_36]] : index to i32
+// CHECK:             %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
+// CHECK:             %[[VAL_40:.*]] = arith.index_castui %[[VAL_37]] : index to i32
+// CHECK:             %[[VAL_41:.*]] = arith.uitofp %[[VAL_40]] : i32 to f32
+// CHECK:             %[[VAL_42:.*]] = arith.divf %[[VAL_39]], %[[VAL_23]] : f32
+// CHECK:             %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_25]] : f32
+// CHECK:             %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
+// CHECK:             %[[VAL_45:.*]] = arith.mulf %[[VAL_21]], %[[VAL_44]] : f32
+// CHECK:             %[[VAL_46:.*]] = math.cos %[[VAL_45]] : f32
+// CHECK:             %[[VAL_47:.*]] = math.sin %[[VAL_45]] : f32
+// CHECK:             %[[VAL_48:.*]] = arith.mulf %[[VAL_27]], %[[VAL_46]] : f32
+// CHECK:             %[[VAL_49:.*]] = arith.mulf %[[VAL_27]], %[[VAL_47]] : f32
+// CHECK:             %[[VAL_50:.*]] = arith.addf %[[VAL_28]], %[[VAL_48]] : f32
+// CHECK:             %[[VAL_51:.*]] = arith.subf %[[VAL_29]], %[[VAL_49]] : f32
+// CHECK:             linalg.yield %[[VAL_50]], %[[VAL_51]] : f32, f32
+// CHECK:           } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK:           return %[[VAL_52:.*]]#0, %[[VAL_52]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+// CHECK:         }
 func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
-// CHECK:   %[[CST_0:.*]] = arith.constant 0 : index
-// CHECK:   %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[CST_0]] : tensor<?x?x?xf32>
-// CHECK:   %[[CST_1:.*]] = arith.constant 1 : index
-// CHECK:   %[[DIM_0:.*]] = tensor.dim %[[ARG_0]], %[[CST_1]] : tensor<?x?x?xf32>
-// CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
-// CHECK:   %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
-// CHECK:   %[[CST_1_2:.*]] = arith.constant 1 : index
-// CHECK:   %[[CST_2_3:.*]] = arith.constant 2 : index
-// CHECK:   %[[VAR_0:.*]] = arith.divui %[[DIM_1]], %[[CST_2_3]] : index
-// CHECK:   %[[VAR_1:.*]] = arith.addi %[[VAR_0]], %[[CST_1_2]] : index
-// CHECK:   %[[EMPTY_0:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
-// CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK:   %[[EMPTY_1:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
-// CHECK:   %[[CST_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[VAR_5:.*]] = linalg.fill ins(%[[CST_4]] : f32) outs(%[[EMPTY_1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK:   %[...
[truncated]

@d-smirnov
Copy link
Contributor Author

@d-smirnov d-smirnov requested a review from GeorgeARM April 18, 2024 08:14
@d-smirnov
Copy link
Contributor Author

@eric-k256

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

@GeorgeARM GeorgeARM merged commit ee0284e into llvm:main Apr 18, 2024
joker-eph added a commit that referenced this pull request Apr 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants