Skip to content

[TOSA] Usage of 32bit integer for 'index to float' in rfft2d #75098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2024

Conversation

d-smirnov
Copy link
Contributor

Lowering of rfft2d to linalg now uses index to i32 cast if an output float is of 32bit and cast to i64 otherwise.
@eric-k256 @GeorgeARM

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2023

@llvm/pr-subscribers-mlir-linalg

Author: Dmitriy Smirnov (d-smirnov)

Changes

Lowering of rfft2d to linalg now uses index to i32 cast if an output float is of 32bit and cast to i64 otherwise.
@eric-k256 @GeorgeARM


Full diff: https://github.com/llvm/llvm-project/pull/75098.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+24-24)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96f..f0e4a00bdc5179 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2284,8 +2284,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 
   static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                 FloatType type, Value value) {
-    auto integerVal =
-        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+    auto integerVal = builder.create<arith::IndexCastUIOp>(
+        loc,
+        32 <= type.getIntOrFloatBitWidth() ? builder.getI32Type()
+                                           : builder.getI64Type(),
+        value);
 
     return builder.create<arith::UIToFPOp>(loc, type, integerVal);
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da68..6f9597dd399af6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1691,10 +1691,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
 // CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
 // CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
+// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1702,17 +1702,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
 // CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
 // CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1761,10 +1761,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
 // CHECK:   %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
 // CHECK:   %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1772,17 +1772,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:     outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
 // CHECK:     %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
 // CHECK:     %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2023

@llvm/pr-subscribers-mlir

Author: Dmitriy Smirnov (d-smirnov)

Changes

Lowering of rfft2d to linalg now uses index to i32 cast if an output float is of 32bit and cast to i64 otherwise.
@eric-k256 @GeorgeARM


Full diff: https://github.com/llvm/llvm-project/pull/75098.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+24-24)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96f..f0e4a00bdc5179 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2284,8 +2284,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 
   static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                 FloatType type, Value value) {
-    auto integerVal =
-        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+    auto integerVal = builder.create<arith::IndexCastUIOp>(
+        loc,
+        32 <= type.getIntOrFloatBitWidth() ? builder.getI32Type()
+                                           : builder.getI64Type(),
+        value);
 
     return builder.create<arith::UIToFPOp>(loc, type, integerVal);
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da68..6f9597dd399af6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1691,10 +1691,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
 // CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
 // CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
+// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1702,17 +1702,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
 // CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
 // CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1761,10 +1761,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
 // CHECK:   %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
 // CHECK:   %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1772,17 +1772,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:     outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
 // CHECK:     %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
 // CHECK:     %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks okay to me.

@GeorgeARM
Copy link
Contributor

GeorgeARM commented Dec 12, 2023

Thanks for this PR. Overall it looks fine to me.
Could you please elaborate in your commit-msg what is the reasoning behind this change?

If this is due to conversion targets not being able to handle 64bit integer arithmetic "easily" or at all (e.g. SPIRV) then I think that this might require more discussion. Not sure how we can handle or if we need a more immediate casting operator from index to floating point that can be lowered later on as per target needs.

@MaheshRavishankar
Copy link
Contributor

Thanks for the tag @GeorgeARM . I am on vacation right now but @antiagainst or @kuhar can help here

Copy link
Contributor

@RoboTux RoboTux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@kuhar
Copy link
Member

kuhar commented Dec 12, 2023

If this is due to conversion targets not being able to handle 64bit integer arithmetic "easily" or at all (e.g. SPIRV) then I think that this might require more discussion. Not sure how we can handle or if we need a more immediate casting operator from index to floating point that can be lowered later on as per target needs.

FYI, I have patterns designed to help with this set of constraints:

//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//
template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : NarrowingPattern<IToFPOp> {
using NarrowingPattern<IToFPOp>::NarrowingPattern;
LogicalResult matchAndRewrite(IToFPOp op,
PatternRewriter &rewriter) const override {
FailureOr<unsigned> narrowestWidth =
calculateBitsRequired(op.getIn(), Extension);
if (failed(narrowestWidth))
return failure();
FailureOr<Type> narrowTy =
this->getNarrowType(*narrowestWidth, op.getIn().getType());
if (failed(narrowTy))
return failure();
Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
op.getIn());
rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
return success();
}
};
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;

@d-smirnov
Copy link
Contributor Author

d-smirnov commented Dec 12, 2023 via email

Copy link

github-actions bot commented Dec 14, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Lowering of rfft2d to linalg now uses index to i32 cast
if an output float is of 32bit and 64bit otherwise.
@d-smirnov d-smirnov reopened this Jan 3, 2024
@d-smirnov
Copy link
Contributor Author

PR was accidentally closed. Reopened now.

@d-smirnov d-smirnov mentioned this pull request Jan 4, 2024
@GeorgeARM
Copy link
Contributor

LGTM

@GeorgeARM GeorgeARM merged commit 2952fb3 into llvm:main Jan 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants