Skip to content

Commit bb6d5c2

Browse files
[mlir][Transforms] GreedyPatternRewriteDriver: Do not CSE constants during iterations (#75897)
The `GreedyPatternRewriteDriver` tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in `OperationFolder`. To allow for efficient CSE'ing, `OperationFolder` maintains an internal lookup data structure to find the existing constant ops with the same value for each `IsolatedFromAbove` region: ```c++ /// A mapping between an insertion region and the constants that have been /// created within it. DenseMap<Region *, ConstantMap> foldScopes; ``` Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent. We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different `IsolatedFromAbove` region (that is not accessible). This commit changes the behavior of the `GreedyPatternRewriteDriver` such that `OperationFolder` is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. `OperationFolder` is no longer used after populating the worklist, so we do not have to care about inconsistent state in the `OperationFolder` due to IR rewrites. The `GreedyPatternRewriteDriver` now performs the op folding by itself instead of calling `OperationFolder::tryToFold`. This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning `CHECK` into `CHECK-DAG`. Alternatives considered: The state of `OperationFolder` could be partially invalidated with every `notifyOperationModified` notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug `IsolatedFromAbove` violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of `OperationFolder` over and over. Migration guide: Turn `CHECK` into `CHECK-DAG` in test cases. Constant ops are no longer folded during a greedy pattern rewrite. If you rely on folding (and rematerialization) of constant ops during a greedy pattern rewrite, turn the folder into a pattern.
1 parent e96e7a9 commit bb6d5c2

30 files changed

+357
-315
lines changed

flang/test/Lower/array-temp.f90

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ subroutine ss4(N)
4343

4444
! CHECK-LABEL: func @_QPss2(
4545
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
46-
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
47-
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
48-
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
49-
! CHECK: %[[C_27_i32:[-0-9a-z_]+]] = arith.constant 27 : i32
50-
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
51-
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
52-
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
53-
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
54-
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
46+
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
47+
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
48+
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
49+
! CHECK-DAG: %[[C_27_i32:[-0-9a-z_]+]] = arith.constant 27 : i32
50+
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
51+
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
52+
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
53+
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
54+
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
5555
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
5656
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
5757
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
@@ -137,15 +137,15 @@ subroutine ss4(N)
137137

138138
! CHECK-LABEL: func @_QPss3(
139139
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
140-
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
141-
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
142-
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
143-
! CHECK: %[[C_34_i32:[-0-9a-z_]+]] = arith.constant 34 : i32
144-
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
145-
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
146-
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
147-
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
148-
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
140+
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
141+
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
142+
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
143+
! CHECK-DAG: %[[C_34_i32:[-0-9a-z_]+]] = arith.constant 34 : i32
144+
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
145+
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
146+
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
147+
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
148+
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
149149
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
150150
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
151151
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
@@ -263,15 +263,15 @@ subroutine ss4(N)
263263

264264
! CHECK-LABEL: func @_QPss4(
265265
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
266-
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
267-
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
268-
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
269-
! CHECK: %[[C_41_i32:[-0-9a-z_]+]] = arith.constant 41 : i32
270-
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
271-
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
272-
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
273-
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
274-
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
266+
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
267+
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
268+
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
269+
! CHECK-DAG: %[[C_41_i32:[-0-9a-z_]+]] = arith.constant 41 : i32
270+
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
271+
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
272+
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
273+
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
274+
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
275275
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
276276
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
277277
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index

mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,11 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
312312
if (origYield.getDefiningOp() == peeledScalarOperation) {
313313
yieldedVals.push_back(origYield);
314314
} else {
315+
// Do not materialize any new ops inside of the decomposed LinalgOp,
316+
// as that would trigger another application of the rewrite pattern
317+
// (infinite loop).
318+
OpBuilder::InsertionGuard g(rewriter);
319+
rewriter.setInsertionPoint(peeledGenericOp);
315320
yieldedVals.push_back(
316321
getZero(rewriter, genericOp.getLoc(), origYield.getType()));
317322
}

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
314314
Worklist worklist;
315315
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
316316

317-
/// Non-pattern based folder for operations.
318-
OperationFolder folder;
319-
320317
/// Configuration information for how to simplify.
321318
const GreedyRewriteConfig config;
322319

@@ -358,7 +355,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
358355
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
359356
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
360357
const GreedyRewriteConfig &config)
361-
: PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
358+
: PatternRewriter(ctx), config(config), matcher(patterns)
362359
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
363360
// clang-format off
364361
, debugFingerPrints(this)
@@ -429,15 +426,55 @@ bool GreedyPatternRewriteDriver::processWorklist() {
429426
continue;
430427
}
431428

432-
// Try to fold this op.
433-
if (succeeded(folder.tryToFold(op))) {
434-
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
435-
changed = true;
429+
// Try to fold this op. Do not fold constant ops. That would lead to an
430+
// infinite folding loop, as every constant op would be folded to an
431+
// Attribute and then immediately be rematerialized as a constant op, which
432+
// is then put on the worklist.
433+
if (!op->hasTrait<OpTrait::ConstantLike>()) {
434+
SmallVector<OpFoldResult> foldResults;
435+
if (succeeded(op->fold(foldResults))) {
436+
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
437+
changed = true;
438+
if (foldResults.empty()) {
439+
// Op was modified in-place.
440+
notifyOperationModified(op);
436441
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
437-
if (config.scope && failed(verify(config.scope->getParentOp())))
438-
llvm::report_fatal_error("IR failed to verify after folding");
442+
if (config.scope && failed(verify(config.scope->getParentOp())))
443+
llvm::report_fatal_error("IR failed to verify after folding");
439444
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
440-
continue;
445+
continue;
446+
}
447+
448+
// Op results can be replaced with `foldResults`.
449+
assert(foldResults.size() == op->getNumResults() &&
450+
"folder produced incorrect number of results");
451+
OpBuilder::InsertionGuard g(*this);
452+
setInsertionPoint(op);
453+
SmallVector<Value> replacements;
454+
for (auto [ofr, resultType] :
455+
llvm::zip_equal(foldResults, op->getResultTypes())) {
456+
if (auto value = ofr.dyn_cast<Value>()) {
457+
assert(value.getType() == resultType &&
458+
"folder produced value of incorrect type");
459+
replacements.push_back(value);
460+
continue;
461+
}
462+
// Materialize Attributes as SSA values.
463+
Operation *constOp = op->getDialect()->materializeConstant(
464+
*this, ofr.get<Attribute>(), resultType, op->getLoc());
465+
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
466+
"materializeConstant produced op that is not a ConstantLike");
467+
assert(constOp->getResultTypes()[0] == resultType &&
468+
"materializeConstant produced incorrect result type");
469+
replacements.push_back(constOp->getResult(0));
470+
}
471+
replaceOp(op, replacements);
472+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
473+
if (config.scope && failed(verify(config.scope->getParentOp())))
474+
llvm::report_fatal_error("IR failed to verify after folding");
475+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
476+
continue;
477+
}
441478
}
442479

443480
// Try to match one of the patterns. The rewriter is automatically
@@ -592,7 +629,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
592629

593630
addOperandsToWorklist(op->getOperands());
594631
worklist.remove(op);
595-
folder.notifyRemoval(op);
596632

597633
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
598634
strictModeFilteredOps.erase(op);
@@ -672,16 +708,6 @@ class GreedyPatternRewriteIteration
672708
} // namespace
673709

674710
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
675-
auto insertKnownConstant = [&](Operation *op) {
676-
// Check for existing constants when populating the worklist. This avoids
677-
// accidentally reversing the constant order during processing.
678-
Attribute constValue;
679-
if (matchPattern(op, m_Constant(&constValue)))
680-
if (!folder.insertKnownConstant(op, constValue))
681-
return true;
682-
return false;
683-
};
684-
685711
bool continueRewrites = false;
686712
int64_t iteration = 0;
687713
MLIRContext *ctx = getContext();
@@ -691,8 +717,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
691717
config.maxIterations != GreedyRewriteConfig::kNoLimit)
692718
break;
693719

720+
// New iteration: start with an empty worklist.
694721
worklist.clear();
695722

723+
// `OperationFolder` CSE's constant ops (and may move them into parents
724+
// regions to enable more aggressive CSE'ing).
725+
OperationFolder folder(getContext(), this);
726+
auto insertKnownConstant = [&](Operation *op) {
727+
// Check for existing constants when populating the worklist. This avoids
728+
// accidentally reversing the constant order during processing.
729+
Attribute constValue;
730+
if (matchPattern(op, m_Constant(&constValue)))
731+
if (!folder.insertKnownConstant(op, constValue))
732+
return true;
733+
return false;
734+
};
735+
696736
if (!config.useTopDownTraversal) {
697737
// Add operations to the worklist in postorder.
698738
region.walk([&](Operation *op) {

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
309309

310310
// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
311311
// CHECK-SAME: %[[SRC:.*]]: i32) {
312-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
313-
// CHECK: %[[C4:.*]] = arith.constant 4 : index
314-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
312+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
313+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
314+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
315315
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
316316
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
317317
// CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
393393

394394
// CHECK-LABEL: func.func @transpose_i8(
395395
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>)
396-
// CHECK: %[[C16:.*]] = arith.constant 16 : index
397-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
396+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
397+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
398398
// CHECK: %[[VSCALE:.*]] = vector.vscale
399399
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
400400
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
196196
}
197197
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
198198
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
199-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
200-
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
201-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
202-
// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
199+
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
200+
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
201+
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
202+
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
203203

204204
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
205205
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
533533
}
534534

535535
// CHECK-SAME: %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
536-
// CHECK: %[[C_0:.*]] = arith.constant 0 : index
537-
// CHECK: %[[C_16:.*]] = arith.constant 16 : index
538-
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
536+
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
537+
// CHECK-DAG: %[[C_16:.*]] = arith.constant 16 : index
538+
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
539539
// CHECK: %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
540540
// CHECK: %[[VSCALE:.*]] = vector.vscale
541541
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
556556
}
557557
// CHECK-LABEL: func.func @vector_print_vector_0d(
558558
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
559-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
560-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
559+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
560+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
561561
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
562562
// CHECK: vector.print punctuation <open>
563563
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
581581
}
582582
// CHECK-LABEL: func.func @vector_print_vector(
583583
// CHECK-SAME: %[[VEC:.*]]: vector<2x2xf32>) {
584-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
585-
// CHECK: %[[C2:.*]] = arith.constant 2 : index
586-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
584+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
585+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
586+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
587587
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
588588
// CHECK: vector.print punctuation <open>
589589
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
650650
}
651651
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
652652
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
653-
// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
654-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
655-
// CHECK: %[[C3:.*]] = arith.constant 3 : index
656-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
653+
// CHECK-DAG: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
654+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
655+
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
656+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
657657
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
658658
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
659659
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
684684
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
685685
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
686686
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
687-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
688-
// CHECK: %[[C3:.*]] = arith.constant 3 : index
689-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
687+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
688+
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
689+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
690690
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
691691
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
692692
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>

mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
9191
// -----
9292

9393
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
94-
// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
95-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
96-
// CHECK: %[[C16:.*]] = arith.constant 16 : index
97-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
94+
// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
95+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
96+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
97+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
9898
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
9999
// CHECK: %[[VSCALE:.*]] = vector.vscale
100100
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
111111
// -----
112112

113113
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
114-
// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
115-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
116-
// CHECK: %[[C2:.*]] = arith.constant 2 : index
117-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
114+
// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
115+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
116+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
117+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
118118
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
119119
// CHECK: %[[VSCALE:.*]] = vector.vscale
120120
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index

0 commit comments

Comments
 (0)