Skip to content

Commit 2952fb3

Browse files
authored
[TOSA] Usage of 32bit integer for 'index to float' in rfft2d (#75098)
Lowering of rfft2d to linalg now uses index to i32 cast if an output float is of 32bit and cast to i64 otherwise.
1 parent 4319e19 commit 2952fb3

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2235,8 +2235,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
22352235

22362236
static Value castIndexToFloat(OpBuilder &builder, Location loc,
22372237
FloatType type, Value value) {
2238-
auto integerVal =
2239-
builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
2238+
auto integerVal = builder.create<arith::IndexCastUIOp>(
2239+
loc,
2240+
type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2241+
: builder.getI32Type(),
2242+
value);
22402243

22412244
return builder.create<arith::UIToFPOp>(loc, type, integerVal);
22422245
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,28 +1629,28 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
16291629
// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
16301630
// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
16311631
// CHECK: %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
1632-
// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
1633-
// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
1634-
// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
1635-
// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
1632+
// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
1633+
// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
1634+
// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
1635+
// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
16361636
// CHECK: linalg.generic {
16371637
// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
16381638
// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
16391639
// CHECK: ins(%[[ARG_0]] : tensor<5x5x8xf32>)
16401640
// CHECK: outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
16411641
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
16421642
// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
1643-
// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
1644-
// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
1643+
// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
1644+
// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
16451645
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
1646-
// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
1647-
// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
1646+
// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
1647+
// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
16481648
// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
1649-
// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
1650-
// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
1649+
// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
1650+
// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
16511651
// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
1652-
// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
1653-
// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
1652+
// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
1653+
// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
16541654
// CHECK: %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
16551655
// CHECK: %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
16561656
// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1699,28 +1699,28 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
16991699
// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
17001700
// CHECK: %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
17011701
// CHECK: %[[CST_9:.*]] = arith.constant 6.28318548 : f32
1702-
// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
1703-
// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
1704-
// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
1705-
// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
1702+
// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
1703+
// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
1704+
// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
1705+
// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
17061706
// CHECK: linalg.generic {
17071707
// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
17081708
// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
17091709
// CHECK: ins(%[[ARG_0]] : tensor<?x?x?xf32>)
17101710
// CHECK: outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
17111711
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
17121712
// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
1713-
// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
1714-
// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
1713+
// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
1714+
// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
17151715
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
1716-
// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
1717-
// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
1716+
// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
1717+
// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
17181718
// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
1719-
// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
1720-
// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
1719+
// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
1720+
// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
17211721
// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
1722-
// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
1723-
// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
1722+
// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
1723+
// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
17241724
// CHECK: %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
17251725
// CHECK: %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
17261726
// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32

0 commit comments

Comments
 (0)