Skip to content

Commit 915fce0

Browse files
authored
[mlir][affine] Enable ConvertAffineToStandard pass to handle affine.delinearize_index Op. (#82189)
This PR, aims to enable the `ConvertAffineToStandard` to handle `affine.dilinearize_index` Operation. Fixes #78458
1 parent a2efb68 commit 915fce0

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1515

1616
#include "mlir/Dialect/Affine/IR/AffineOps.h"
17+
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1718
#include "mlir/Dialect/Affine/Utils.h"
1819
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1920
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -558,6 +559,7 @@ class LowerAffinePass
558559
RewritePatternSet patterns(&getContext());
559560
populateAffineToStdConversionPatterns(patterns);
560561
populateAffineToVectorConversionPatterns(patterns);
562+
populateAffineExpandIndexOpsPatterns(patterns);
561563
ConversionTarget target(getContext());
562564
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
563565
scf::SCFDialect, VectorDialect>();

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,57 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
927927
// CHECK: scf.reduce.return %[[RES]] : i64
928928
// CHECK: }
929929
// CHECK: }
930+
931+
///////////////////////////////////////////////////////////////////////
932+
933+
func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
934+
%b0 = arith.constant 16 : index
935+
%b1 = arith.constant 224 : index
936+
%b2 = arith.constant 224 : index
937+
%1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
938+
return %1#0, %1#1, %1#2 : index, index, index
939+
}
940+
// CHECK-LABEL: func.func @test_dilinearize_index(
941+
// CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) {
942+
// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
943+
// CHECK: %[[VAL_2:.*]] = arith.constant 224 : index
944+
// CHECK: %[[VAL_3:.*]] = arith.constant 224 : index
945+
// CHECK: %[[VAL_4:.*]] = arith.constant 50176 : index
946+
// CHECK: %[[VAL_5:.*]] = arith.constant 50176 : index
947+
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
948+
// CHECK: %[[VAL_7:.*]] = arith.constant -1 : index
949+
// CHECK: %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index
950+
// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index
951+
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index
952+
// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index
953+
// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
954+
// CHECK: %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index
955+
// CHECK: %[[VAL_14:.*]] = arith.constant 50176 : index
956+
// CHECK: %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index
957+
// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
958+
// CHECK: %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index
959+
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
960+
// CHECK: %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index
961+
// CHECK: %[[VAL_20:.*]] = arith.constant 50176 : index
962+
// CHECK: %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index
963+
// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
964+
// CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index
965+
// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
966+
// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index
967+
// CHECK: %[[VAL_26:.*]] = arith.constant 224 : index
968+
// CHECK: %[[VAL_27:.*]] = arith.constant 0 : index
969+
// CHECK: %[[VAL_28:.*]] = arith.constant -1 : index
970+
// CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index
971+
// CHECK: %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index
972+
// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], %[[VAL_25]] : index
973+
// CHECK: %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index
974+
// CHECK: %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index
975+
// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index
976+
// CHECK: %[[VAL_35:.*]] = arith.constant 224 : index
977+
// CHECK: %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index
978+
// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
979+
// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index
980+
// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index
981+
// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
982+
// CHECK: return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
983+
// CHECK: }

0 commit comments

Comments
 (0)