Skip to content

[mlir][Transforms] GreedyPatternRewriteDriver: Do not CSE constants during iterations #75897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions flang/test/Lower/array-temp.f90
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ subroutine ss4(N)

! CHECK-LABEL: func @_QPss2(
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK: %[[C_27_i32:[-0-9a-z_]+]] = arith.constant 27 : i32
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK-DAG: %[[C_27_i32:[-0-9a-z_]+]] = arith.constant 27 : i32
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
Expand Down Expand Up @@ -137,15 +137,15 @@ subroutine ss4(N)

! CHECK-LABEL: func @_QPss3(
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK: %[[C_34_i32:[-0-9a-z_]+]] = arith.constant 34 : i32
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK-DAG: %[[C_34_i32:[-0-9a-z_]+]] = arith.constant 34 : i32
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
Expand Down Expand Up @@ -263,15 +263,15 @@ subroutine ss4(N)

! CHECK-LABEL: func @_QPss4(
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK: %[[C_41_i32:[-0-9a-z_]+]] = arith.constant 41 : i32
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK-DAG: %[[C_41_i32:[-0-9a-z_]+]] = arith.constant 41 : i32
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
if (origYield.getDefiningOp() == peeledScalarOperation) {
yieldedVals.push_back(origYield);
} else {
// Do not materialize any new ops inside of the decomposed LinalgOp,
// as that would trigger another application of the rewrite pattern
// (infinite loop).
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(peeledGenericOp);
yieldedVals.push_back(
getZero(rewriter, genericOp.getLoc(), origYield.getType()));
}
Expand Down
84 changes: 62 additions & 22 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED

/// Non-pattern based folder for operations.
OperationFolder folder;

/// Configuration information for how to simplify.
const GreedyRewriteConfig config;

Expand Down Expand Up @@ -358,7 +355,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
: PatternRewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, debugFingerPrints(this)
Expand Down Expand Up @@ -429,15 +426,55 @@ bool GreedyPatternRewriteDriver::processWorklist() {
continue;
}

// Try to fold this op.
if (succeeded(folder.tryToFold(op))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
changed = true;
// Try to fold this op. Do not fold constant ops. That would lead to an
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
if (!op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
changed = true;
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
continue;
}

// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
SmallVector<Value> replacements;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
if (auto value = ofr.dyn_cast<Value>()) {
assert(value.getType() == resultType &&
"folder produced value of incorrect type");
replacements.push_back(value);
continue;
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
replaceOp(op, replacements);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
}

// Try to match one of the patterns. The rewriter is automatically
Expand Down Expand Up @@ -592,7 +629,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {

addOperandsToWorklist(op->getOperands());
worklist.remove(op);
folder.notifyRemoval(op);

if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
Expand Down Expand Up @@ -672,16 +708,6 @@ class GreedyPatternRewriteIteration
} // namespace

LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return true;
return false;
};

bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
Expand All @@ -691,8 +717,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;

// New iteration: start with an empty worklist.
worklist.clear();

// `OperationFolder` CSE's constant ops (and may move them into parents
// regions to enable more aggressive CSE'ing).
OperationFolder folder(getContext(), this);
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return true;
return false;
};

if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb

// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
Expand Down Expand Up @@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {

// CHECK-LABEL: func.func @transpose_i8(
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>)
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>

// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
Expand Down
30 changes: 15 additions & 15 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
}

// CHECK-SAME: %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
// CHECK: %[[C_0:.*]] = arith.constant 0 : index
// CHECK: %[[C_16:.*]] = arith.constant 16 : index
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C_16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
Expand All @@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
}
// CHECK-LABEL: func.func @vector_print_vector_0d(
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
Expand All @@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
}
// CHECK-LABEL: func.func @vector_print_vector(
// CHECK-SAME: %[[VEC:.*]]: vector<2x2xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
Expand Down Expand Up @@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
}
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
Expand Down Expand Up @@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
// -----

// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
Expand All @@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// -----

// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
Expand Down
Loading