Skip to content

[mlir][Transforms][NFC] Dialect Conversion: Resolve insertion point TODO #95653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024

Conversation

matthias-springer
Copy link
Member

Remove a TODO in the dialect conversion code base when materializing unresolved conversions:

// FIXME: Determine a suitable insertion location when there are multiple
// inputs.

The implementation used to select an insertion point as follows:

  • If the cast has exactly one operand: right after the definition of the SSA value.
  • Otherwise: right before the cast op.

However, it is not necessary to change the insertion point. Unresolved materializations (UnrealizedConversionCastOp) are built during buildUnresolvedArgumentMaterialization or buildUnresolvedTargetMaterialization. In the former case, the op is inserted at the beginning of the block. In the latter case, only one operand is supported in the dialect conversion, and the op is inserted right after the definition of the SSA value. I.e., the UnrealizedConversionCastOp is already inserted at the right place and it is not necessary to change the insertion point for the resolved materialization op.

Note: The IR change changes slightly because the unrealized_conversion_cast ops at the beginning of a block are no longer doubly-inverted (by setting the insertion to the beginning of the block when inserting the unrealized_conversion_cast and again when inserting the resolved conversion op). All affected test cases were fixed by using CHECK-DAG instead of CHECK.

Also improve the quality of multiple test cases that did not check for the correct operands.

Note: This commit is in preparation of decoupling the argument/source/target materialization logic of the type converter from the dialect conversion (to reduce its complexity and make that functionality usable from a new dialect conversion driver).

Remove a TODO in the dialect conversion code base when materializing unresolved conversions:
```
// FIXME: Determine a suitable insertion location when there are multiple
// inputs.
```

The implementation used to select an insertion point as follows:
- If the cast has exactly one operand: right after the definition of the SSA value.
- Otherwise: right before the cast op.

However, it is not necessary to change the insertion point. Unresolved materializations (`UnrealizedConversionCastOp`) are built during `buildUnresolvedArgumentMaterialization` or `buildUnresolvedTargetMaterialization`. In the former case, the op is inserted at the beginning of the block. In the latter case, only one operand is supported in the dialect conversion, and the op is inserted right after the definition of the SSA value. I.e., the `UnrealizedConversionCastOp` is already inserted at the right place and it is not necessary to change the insertion point for the resolved materialization op.

Note: The IR change changes slightly because the `unrealized_conversion_cast` ops at the beginning of a block are no longer doubly-inverted (by setting the insertion to the beginning of the block when inserting the `unrealized_conversion_cast` and again when inserting the resolved conversion op). All affected test cases were fixed by using `CHECK-DAG` instead of `CHECK`.

Also improve the quality of multiple test cases that did not check for the correct operands.

Note: This commit is in preparation of decoupling the argument/source/target materialization logic of the type converter from the dialect conversion (to reduce its complexity and make that functionality usable from a new dialect conversion driver).
@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Remove a TODO in the dialect conversion code base when materializing unresolved conversions:

// FIXME: Determine a suitable insertion location when there are multiple
// inputs.

The implementation used to select an insertion point as follows:

  • If the cast has exactly one operand: right after the definition of the SSA value.
  • Otherwise: right before the cast op.

However, it is not necessary to change the insertion point. Unresolved materializations (UnrealizedConversionCastOp) are built during buildUnresolvedArgumentMaterialization or buildUnresolvedTargetMaterialization. In the former case, the op is inserted at the beginning of the block. In the latter case, only one operand is supported in the dialect conversion, and the op is inserted right after the definition of the SSA value. I.e., the UnrealizedConversionCastOp is already inserted at the right place and it is not necessary to change the insertion point for the resolved materialization op.

Note: The IR change changes slightly because the unrealized_conversion_cast ops at the beginning of a block are no longer doubly-inverted (by setting the insertion to the beginning of the block when inserting the unrealized_conversion_cast and again when inserting the resolved conversion op). All affected test cases were fixed by using CHECK-DAG instead of CHECK.

Also improve the quality of multiple test cases that did not check for the correct operands.

Note: This commit is in preparation of decoupling the argument/source/target materialization logic of the type converter from the dialect conversion (to reduce its complexity and make that functionality usable from a new dialect conversion driver).


Patch is 58.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95653.diff

18 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-7)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+9-6)
  • (modified) mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir (+3-3)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+2-2)
  • (modified) mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir (+9-9)
  • (modified) mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir (+6-6)
  • (modified) mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir (+2-2)
  • (modified) mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir (+4-3)
  • (modified) mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir (+11-13)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+5-5)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir (+2-2)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+25-19)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+17-13)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-scalable-memcpy.mlir (+2-2)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+8-8)
  • (modified) mlir/test/Dialect/SCF/bufferize.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+10-10)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2f0efe1b1e454..1c0cb128aeabe 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2857,13 +2857,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
 
   // Try to materialize the conversion.
   if (const TypeConverter *converter = mat.getConverter()) {
-    // FIXME: Determine a suitable insertion location when there are multiple
-    // inputs.
-    if (inputOperands.size() == 1)
-      rewriter.setInsertionPointAfterValue(inputOperands.front());
-    else
-      rewriter.setInsertionPoint(op);
-
+    rewriter.setInsertionPoint(op);
     Value newMaterialization;
     switch (mat.getMaterializationKind()) {
     case MaterializationKind::Argument:
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 56ae930e6d627..d3bdbe89a5487 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -478,9 +478,10 @@ func.func @mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -
 // -----
 
 // CHECK-LABEL: func @cmpf_2dvector(
+//  CHECK-SAME:     %[[OARG0:.*]]: vector<4x3xf32>, %[[OARG1:.*]]: vector<4x3xf32>)
 func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[OARG0]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[OARG1]]
   // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>>
   // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>>
   // CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32>
@@ -492,9 +493,10 @@ func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
 // -----
 
 // CHECK-LABEL: func @cmpi_0dvector(
+//  CHECK-SAME:     %[[OARG0:.*]]: vector<i32>, %[[OARG1:.*]]: vector<i32>)
 func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[OARG0]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[OARG1]]
   // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
   %0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
   func.return
@@ -503,9 +505,10 @@ func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
 // -----
 
 // CHECK-LABEL: func @cmpi_2dvector(
+//  CHECK-SAME:     %[[OARG0:.*]]: vector<4x3xi32>, %[[OARG1:.*]]: vector<4x3xi32>)
 func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[OARG0]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[OARG1]]
   // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>>
   // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
   // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32>
diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
index 63989347567b5..b234cbbb35f32 100644
--- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -199,9 +199,9 @@ func.func @bitcast_2d(%arg0: vector<2x4xf32>) {
 
 // CHECK-LABEL: func @select_2d(
 func.func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1
-  // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1
+  // CHECK-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2
   // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi1>>
   // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
   // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[ARG2]][0] : !llvm.array<4 x vector<3xi32>>
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..beb2c8d2d242c 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -60,8 +60,8 @@ func.func @index_scalar(%lhs: index, %rhs: index) {
 // CHECK-LABEL: @index_scalar_srem
 // CHECK-SAME: (%[[A:.+]]: index, %[[B:.+]]: index)
 func.func @index_scalar_srem(%lhs: index, %rhs: index) {
-  // CHECK: %[[LHS:.+]] = builtin.unrealized_conversion_cast %[[A]] : index to i32
-  // CHECK: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[B]] : index to i32
+  // CHECK-DAG: %[[LHS:.+]] = builtin.unrealized_conversion_cast %[[A]] : index to i32
+  // CHECK-DAG: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[B]] : index to i32
   // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
   // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
   // CHECK:  %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
index 1b13ebb38dc9e..26abb3bdc23a1 100644
--- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
+++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
@@ -50,8 +50,8 @@ func.func @trivial_ops(%a: index, %b: index) {
 // CHECK-LABEL: @ceildivs
 // CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
 func.func @ceildivs(%n: index, %m: index) -> index {
-  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
-  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
   // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
   // CHECK: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
   // CHECK: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :
@@ -82,8 +82,8 @@ func.func @ceildivs(%n: index, %m: index) -> index {
 // CHECK-LABEL: @ceildivu
 // CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
 func.func @ceildivu(%n: index, %m: index) -> index {
-  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
-  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
   // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 :
 
@@ -103,11 +103,11 @@ func.func @ceildivu(%n: index, %m: index) -> index {
 // CHECK-LABEL: @floordivs
 // CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
 func.func @floordivs(%n: index, %m: index) -> index {
-  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
-  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
-  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
-  // CHECK: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
-  // CHECK: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :
+  // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK-DAG: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
+  // CHECK-DAG: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :
 
   // CHECK: %[[M_NEG:.*]] = llvm.icmp "slt" %[[M]], %[[ZERO]]
   // CHECK: %[[X:.*]] = llvm.select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]]
diff --git a/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir b/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir
index 53dc896e98c7d..7b26f4fc13696 100644
--- a/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir
+++ b/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir
@@ -67,8 +67,8 @@ func.func @constant_ops() {
 // CHECK-LABEL: @ceildivs
 // CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
 func.func @ceildivs(%n: index, %m: index) -> index {
-  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
-  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
   // CHECK: %[[ZERO:.*]] = spirv.Constant 0
   // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
   // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
@@ -99,8 +99,8 @@ func.func @ceildivs(%n: index, %m: index) -> index {
 // CHECK-LABEL: @ceildivu
 // CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
 func.func @ceildivu(%n: index, %m: index) -> index {
-  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
-  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
   // CHECK: %[[ZERO:.*]] = spirv.Constant 0
   // CHECK: %[[ONE:.*]] = spirv.Constant 1
 
@@ -120,8 +120,8 @@ func.func @ceildivu(%n: index, %m: index) -> index {
 // CHECK-LABEL: @floordivs
 // CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
 func.func @floordivs(%n: index, %m: index) -> index {
-  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
-  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
   // CHECK: %[[ZERO:.*]] = spirv.Constant 0
   // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
   // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
index f0119afa42f69..586460fd58180 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -78,8 +78,8 @@ func.func @copy_sign_vector_0D(%value: vector<1xf16>, %sign: vector<1xf16>) -> v
 
 // CHECK-LABEL: func @copy_sign_vector_0D
 //  CHECK-SAME: (%[[VALUE:.+]]: vector<1xf16>, %[[SIGN:.+]]: vector<1xf16>)
-//       CHECK:   %[[CASTVAL:.+]] = builtin.unrealized_conversion_cast %[[VALUE]] : vector<1xf16> to f16
-//       CHECK:   %[[CASTSIGN:.+]] = builtin.unrealized_conversion_cast %[[SIGN]] : vector<1xf16> to f16
+//   CHECK-DAG:   %[[CASTVAL:.+]] = builtin.unrealized_conversion_cast %[[VALUE]] : vector<1xf16> to f16
+//   CHECK-DAG:   %[[CASTSIGN:.+]] = builtin.unrealized_conversion_cast %[[SIGN]] : vector<1xf16> to f16
 //       CHECK:   %[[SMASK:.+]] = spirv.Constant -32768 : i16
 //       CHECK:   %[[VMASK:.+]] = spirv.Constant 32767 : i16
 //       CHECK:   %[[VCAST:.+]] = spirv.Bitcast %[[CASTVAL]] : f16 to i16
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 9d8f4266adf27..ebfc3b95d6eff 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -506,14 +506,15 @@ func.func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
 
 // -----
 
-// CHECK-LABEL: @memref_reshape
+// CHECK-LABEL: @memref_reshape(
+//  CHECK-SAME:     %[[ARG0:.*]]: memref<2x3xf32>, %[[ARG1:.*]]: memref<?xindex>)
 func.func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
   %output = memref.reshape %input(%shape)
                 : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
   return
 }
-// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}} to [[INPUT_TY:!.*]]
-// CHECK: [[SHAPE:%.*]] = builtin.unrealized_conversion_cast %{{.*}} to [[SHAPE_TY:!.*]]
+// CHECK-DAG: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : {{.*}} to [[INPUT_TY:!.*]]
+// CHECK-DAG: [[SHAPE:%.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : {{.*}} to [[SHAPE_TY:!.*]]
 // CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]]
 // CHECK: [[UNRANKED_OUT_O:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
 // CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_O]][0] : !llvm.struct<(i64, ptr)>
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index f1600d43e7bfb..96cf2264e9c2f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -115,13 +115,11 @@ func.func @zero_d_load(%arg0: memref<f32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: func @static_load
-// CHECK:         %[[MEMREF:.*]]: memref<10x42xf32>,
-// CHECK:         %[[I:.*]]: index,
-// CHECK:         %[[J:.*]]: index)
+// CHECK-LABEL: func @static_load(
+// CHECK-SAME:    %[[MEMREF:.*]]: memref<10x42xf32>, %[[I:.*]]: index, %[[J:.*]]: index)
 func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
-// CHECK:  %[[II:.*]] = builtin.unrealized_conversion_cast %[[I]]
-// CHECK:  %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
+// CHECK-DAG:  %[[II:.*]] = builtin.unrealized_conversion_cast %[[I]]
+// CHECK-DAG:  %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
 // CHECK:  %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
 // CHECK:  %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
@@ -148,8 +146,8 @@ func.func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
 // CHECK:         %[[MEMREF:.*]]: memref<10x42xf32>,
 // CHECK-SAME:    %[[I:.*]]: index, %[[J:.*]]: index,
 func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
-// CHECK: %[[II:.*]] = builtin.unrealized_conversion_cast %[[I]]
-// CHECK: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
+// CHECK-DAG: %[[II:.*]] = builtin.unrealized_conversion_cast %[[I]]
+// CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
 // CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
@@ -205,7 +203,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
   func.func @address() {
     %c1 = arith.constant 1 : index
     %0 = memref.alloc(%c1) : memref<? x vector<2xf32>>
-    // CHECK: %[[CST_S:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1 : index
     // CHECK: %[[CST:.*]] = builtin.unrealized_conversion_cast
     // CHECK: llvm.mlir.zero
     // CHECK: llvm.getelementptr %{{.*}}[[CST]]
@@ -269,8 +267,8 @@ func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
 // CHECK-LABEL: func @memref.reshape.dynamic.dim
 // CHECK-SAME:    %[[arg:.*]]: memref<?x?x?xf32>, %[[shape:.*]]: memref<4xi64>) -> memref<?x?x12x32xf32>
 func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4xi64>) -> memref<?x?x12x32xf32> {
-  // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?x?xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<4xi64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?x?xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK-DAG: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<4xi64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
   // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
   // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -318,8 +316,8 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
 // CHECK-LABEL: func @memref.reshape_index
 // CHECK-SAME:    %[[arg:.*]]: memref<?x?xi32>, %[[shape:.*]]: memref<1xindex>
 func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>) ->  memref<?xi32> {
-  // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 882804132e66d..9dc22abf143bf 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -10,9 +10,9 @@
 // CHECK-LABEL: func @view(
 // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
 func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
-  // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]]
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0F:.*]]
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]]
+  // CHECK-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F]]
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0F]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F]]
   // CHECK: llvm.mlir.constant(2048 : index) : i64
   // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   %0 = memref.alloc() : memref<2048xi8>
@@ -408,8 +408,8 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>,
 // CHECK-SAME:   %[[ARG0:.+]]: memref<10xi32, strided<[1], offset: 5>>
 // CHECK-SAME:   %[[ARG1:.+]]: i32
 // CHECK-SAME:   %[[ARG2:.+]]: index
-// CHECK:        %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK:        %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
+// CHECK-DAG:    %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-DAG:    %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
 // CHECK:        %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
 // CHECK:        %[[OFFSET:.+]] = llvm.mlir.consta...
[truncated]

@matthias-springer matthias-springer requested a review from ftynse June 15, 2024 11:50
@matthias-springer matthias-springer changed the title [mlir][Transforms] Dialect Conversion: Resolve insertion point TODO [mlir][Transforms][NFC] Dialect Conversion: Resolve insertion point TODO Jun 15, 2024
@matthias-springer matthias-springer merged commit 13d983e into main Jun 17, 2024
15 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conv_mat_ip branch June 17, 2024 17:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants