Skip to content

Commit 10063c5

Browse files
authored
[mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME (#74063)
This moves the SME tile vector.print lowering from `-convert-arm-sme-to-scf` to `-convert-vector-to-arm-sme`. This seems like a more logical place, as this is lowering a vector op to ArmSME, and it also prevents vector.print from blocking tile allocation.
1 parent 1623292 commit 10063c5

File tree

4 files changed

+93
-95
lines changed

4 files changed

+93
-95
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 4 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -447,75 +447,13 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
447447
}
448448
};
449449

450-
/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
451-
/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
452-
/// a 1D `vector.print`.
453-
///
454-
/// BEFORE:
455-
/// ```mlir
456-
/// vector.print %tile : vector<[4]x[4]xf32>
457-
/// ```
458-
/// AFTER:
459-
/// ```mlir
460-
/// %c0 = arith.constant 0 : index
461-
/// %c1 = arith.constant 1 : index
462-
/// %c4 = arith.constant 4 : index
463-
/// %vscale = vector.vscale
464-
/// %svl_s = arith.muli %c4, %vscale : index
465-
/// scf.for %i = %c0 to %svl_s step %c1 {
466-
/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
467-
/// : vector<[4]xf32> from vector<[4]x[4]xf32>
468-
/// vector.print %tile_slice : vector<[4]xf32>
469-
/// }
470-
/// ```
471-
struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
472-
using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
473-
474-
LogicalResult matchAndRewrite(vector::PrintOp printOp,
475-
PatternRewriter &rewriter) const override {
476-
if (!printOp.getSource())
477-
return failure();
478-
479-
VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
480-
if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
481-
return failure();
482-
483-
auto loc = printOp.getLoc();
484-
485-
// Create a loop over the rows of the tile.
486-
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
487-
auto minTileRows =
488-
rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
489-
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
490-
auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
491-
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
492-
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
493-
{
494-
// Loop body.
495-
rewriter.setInsertionPointToStart(forOp.getBody());
496-
// Extract the current row from the tile.
497-
Value rowIndex = forOp.getInductionVar();
498-
// FIXME: Forward tile IDs.
499-
// For now, if you vector.print a SME tile you need to do
500-
// -allocate-arm-sme-tiles after -convert-arm-sme-to-scf.
501-
auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
502-
loc, printOp.getSource(), rowIndex);
503-
// Print the row with a 1D vector.print.
504-
rewriter.create<vector::PrintOp>(loc, tileSlice,
505-
printOp.getPunctuation());
506-
}
507-
508-
rewriter.eraseOp(printOp);
509-
return success();
510-
}
511-
};
512-
513450
} // namespace
514451

515452
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
516-
patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
517-
TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
518-
TileVectorPrintOpConversion>(patterns.getContext());
453+
patterns
454+
.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
455+
TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
456+
patterns.getContext());
519457
}
520458

521459
namespace {

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -666,15 +666,75 @@ struct VectorInsertToArmSMELowering
666666
}
667667
};
668668

669+
/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
670+
/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
671+
/// a 1D `vector.print`.
672+
///
673+
/// BEFORE:
674+
/// ```mlir
675+
/// vector.print %tile : vector<[4]x[4]xf32>
676+
/// ```
677+
/// AFTER:
678+
/// ```mlir
679+
/// %c0 = arith.constant 0 : index
680+
/// %c1 = arith.constant 1 : index
681+
/// %c4 = arith.constant 4 : index
682+
/// %vscale = vector.vscale
683+
/// %svl_s = arith.muli %c4, %vscale : index
684+
/// scf.for %i = %c0 to %svl_s step %c1 {
685+
/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
686+
/// : vector<[4]xf32> from vector<[4]x[4]xf32>
687+
/// vector.print %tile_slice : vector<[4]xf32>
688+
/// }
689+
/// ```
690+
struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
691+
using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
692+
693+
LogicalResult matchAndRewrite(vector::PrintOp printOp,
694+
PatternRewriter &rewriter) const override {
695+
if (!printOp.getSource())
696+
return failure();
697+
698+
VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
699+
if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
700+
return failure();
701+
702+
auto loc = printOp.getLoc();
703+
704+
// Create a loop over the rows of the tile.
705+
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
706+
auto minTileRows =
707+
rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
708+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
709+
auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
710+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
711+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
712+
{
713+
// Loop body.
714+
rewriter.setInsertionPointToStart(forOp.getBody());
715+
// Extract the current row from the tile.
716+
Value rowIndex = forOp.getInductionVar();
717+
auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
718+
loc, printOp.getSource(), rowIndex);
719+
// Print the row with a 1D vector.print.
720+
rewriter.create<vector::PrintOp>(loc, tileSlice,
721+
printOp.getPunctuation());
722+
}
723+
724+
rewriter.eraseOp(printOp);
725+
return success();
726+
}
727+
};
728+
669729
} // namespace
670730

671731
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
672732
MLIRContext &ctx) {
673-
patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
674-
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
675-
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
676-
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
677-
VectorOuterProductToArmSMELowering,
678-
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
679-
&ctx);
733+
patterns
734+
.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
735+
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
736+
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
737+
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
738+
VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
739+
VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
680740
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -160,25 +160,3 @@ func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest :
160160
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
161161
return
162162
}
163-
164-
//===----------------------------------------------------------------------===//
165-
// vector.print
166-
//===----------------------------------------------------------------------===//
167-
168-
// -----
169-
170-
func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
171-
{
172-
vector.print %tile : vector<[4]x[4]xf32>
173-
return
174-
}
175-
// CHECK-LABEL: func.func @arm_sme_tile_print(
176-
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
177-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
178-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
179-
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
180-
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
181-
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
182-
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
183-
// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
184-
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,3 +736,25 @@ func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64
736736
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
737737
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
738738
}
739+
740+
//===----------------------------------------------------------------------===//
741+
// vector.print
742+
//===----------------------------------------------------------------------===//
743+
744+
// -----
745+
746+
func.func @vector_print_tile(%tile: vector<[4]x[4]xf32>)
747+
{
748+
vector.print %tile : vector<[4]x[4]xf32>
749+
return
750+
}
751+
// CHECK-LABEL: func.func @vector_print_tile(
752+
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
753+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
754+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
755+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
756+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
757+
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
758+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
759+
// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
760+
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>

0 commit comments

Comments
 (0)