Skip to content

[mlir][ArmSME] Add custom vector.print lowering for SME tiles #66691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);

/// Converts `tile`, which should be the result of an `arm_sme::GetTileID` or
/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
/// integer, to an i32 that can be passed as the `tile` parameter to the SME
/// intrinsics. Tile IDs narrower than 32 bits are zero-extended, wider ones
/// are truncated, and a tile ID that is already i32 is returned unchanged.
/// Any extension/truncation op is created with `rewriter` at location `loc`.
Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);

} // namespace arm_sme
} // namespace mlir

Expand Down
93 changes: 91 additions & 2 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,94 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
}
};

/// Rewrites a `vector.print` of an SME tile as an `scf.for` loop over the rows
/// of the tile. Every iteration extracts one horizontal tile slice via a MOVA
/// (the `arm_sme.intr.read.horiz` intrinsic) and prints that slice with a 1D
/// `vector.print`.
///
/// BEFORE:
/// ```mlir
/// vector.print %tile : vector<[4]x[4]xf32>
/// ```
/// AFTER:
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
/// %vscale = vector.vscale
/// %svl_s = arith.muli %c4, %vscale : index
/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
/// scf.for %i = %c0 to %svl_s step %c1 {
///   %slice_idx = arith.index_cast %i : index to i32
///   %tile_slice = "arm_sme.intr.read.horiz"
///       (%cst, %ptrue, %tile_id, %slice_idx)
///     : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
///   vector.print %tile_slice : vector<[4]xf32>
/// }
/// ```
struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    // A print without a source operand has nothing to lower.
    if (!printOp.getSource())
      return failure();

    // Only handle prints whose operand type is a valid SME tile vector type.
    auto tileType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!tileType)
      return failure();
    if (!arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    auto loc = printOp.getLoc();

    // Predicate with every lane enabled, one lane per element of a tile row.
    auto rowPredicateType = VectorType::get(tileType.getDimSize(1),
                                            rewriter.getI1Type(), true);
    auto ptrue = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(rowPredicateType, true));

    // Obtain the tile ID of the printed tile as an i32.
    auto tileIdOp =
        rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
    Value tileIdAsI32 = castTileIDToI32(tileIdOp, loc, rewriter);

    // All-zero row, passed as the first operand of the slice read below.
    auto sliceType = VectorType::get(tileType.getDimSize(1),
                                     tileType.getElementType(), true);
    auto zeroSlice = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Loop bounds: iterate from 0 to (dim 0 of the tile type) * vscale in
    // steps of 1.
    auto vscaleOp = rewriter.create<vector::VectorScaleOp>(loc);
    auto minNumRows =
        rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
    auto loopStart = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto loopEnd = rewriter.create<arith::MulIOp>(loc, minNumRows, vscaleOp);
    auto loopStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto rowLoop =
        rewriter.create<scf::ForOp>(loc, loopStart, loopEnd, loopStep);

    // Everything created from here on goes into the loop body.
    rewriter.setInsertionPointToStart(rowLoop.getBody());

    // Read the row selected by the induction variable out of the tile.
    Value sliceIndex = rowLoop.getInductionVar();
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    auto rowValue = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, zeroSlice, ptrue, tileIdAsI32, sliceIndexI32);

    // Emit a 1D vector.print for the row, keeping the original punctuation.
    rewriter.create<vector::PrintOp>(loc, rowValue, printOp.getPunctuation());

    // The 2D print has been fully replaced by the loop.
    rewriter.eraseOp(printOp);
    return success();
  }
};

} // namespace

/// Adds the ArmSME-to-SCF lowering patterns to `patterns`: the tile load and
/// tile store conversions plus the custom `vector.print` lowering for SME
/// tiles. Each pattern is registered exactly once.
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
               TileVectorPrintOpConversion>(patterns.getContext());
}

namespace {
Expand All @@ -208,6 +291,12 @@ struct ConvertArmSMEToSCFPass
target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
arith::ArithDialect, scf::SCFDialect>();
target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
if (!op.getSource())
return true;
VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
Expand Down
17 changes: 0 additions & 17 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
}
};

/// Produces an i32 tile ID from `tile`, which should be the result of an
/// `arm_sme::GetTileID` or `arm_sme::CastVectorToTile` op returning an
/// 8/16/32/64/128-bit scalar integer, so that it can be passed as the `tile`
/// parameter to the SME intrinsics. Narrower values are zero-extended, wider
/// values are truncated, and an i32 `tile` is returned as-is.
Value castTileIDToI32(Value tile, Location loc,
                      ConversionPatternRewriter &rewriter) {
  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
             tile.getDefiningOp())) &&
         "expected ArmSME GetTileID or CastVectorToTile op!");
  const unsigned bitWidth = tile.getType().getIntOrFloatBitWidth();
  // Already the width the intrinsics expect: nothing to create.
  if (bitWidth == 32)
    return tile;
  // Narrower than i32: widen with an unsigned extension.
  if (bitWidth < 32)
    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
  // Wider than i32: drop the high bits.
  return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
}

/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/ArmSME/Utils/Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"

using namespace mlir;
Expand Down Expand Up @@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {

return true;
}

/// Returns `tile` as an i32 value: zero-extends it when it is narrower than 32
/// bits, truncates it when it is wider, and returns it untouched when it is
/// already 32 bits wide. `tile` is expected to be defined by an
/// `arm_sme::GetTileID` or `arm_sme::CastVectorToTile` op (checked by the
/// assertion below).
Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
                                     RewriterBase &rewriter) {
  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
             tile.getDefiningOp())) &&
         "expected ArmSME GetTileID or CastVectorToTile op!");
  const unsigned idBitWidth = tile.getType().getIntOrFloatBitWidth();
  if (idBitWidth != 32) {
    // A conversion op is required; pick the direction from the source width.
    auto i32Type = rewriter.getI32Type();
    if (idBitWidth < 32)
      return rewriter.create<arith::ExtUIOp>(loc, i32Type, tile);
    return rewriter.create<arith::TruncIOp>(loc, i32Type, tile);
  }
  return tile;
}
22 changes: 22 additions & 0 deletions mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,25 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

// -----

// Test input: prints a 2-D scalable f32 tile with a single `vector.print`.
// The FileCheck directives after this function pin the expected lowering: an
// `scf.for` loop that reads one horizontal tile slice per iteration and
// prints it with a 1-D `vector.print`.
func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
{
vector.print %tile : vector<[4]x[4]xf32>
return
}
// CHECK-LABEL: func.func @arm_sme_tile_print(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
// CHECK-DAG: %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
// CHECK-NEXT: %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}

func.func @printTileEnd() {
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -44,7 +44,6 @@ func.func @entry() {

// Allocate memory.
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
%mem2 = memref.alloca(%za_s_size) : memref<?xi32>

// Fill each "row" of "mem1" with row number.
//
Expand All @@ -66,11 +65,6 @@ func.func @entry() {
// Load tile from "mem1" vertically.
%0 = arm_sme.tile_load %mem1[%c0, %c0], <vertical> : memref<?xi32>, vector<[4]x[4]xi32>

// Store tile back to "mem2" to print.
// TODO: Support vector.print for 2-D scalable vectors so don't have to spill
// to memory and reload to print.
vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

// 1. ORIGINAL HORIZONTAL LAYOUT
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
// 4x4xi32.
Expand Down Expand Up @@ -99,10 +93,7 @@ func.func @entry() {
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
vector.print %tileslice : vector<[4]xi32>
}
vector.print %0 : vector<[4]x[4]xi32>
func.call @printTileEnd() : () -> ()

return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -25,7 +25,7 @@ func.func @printTileBegin() {
return
}

func.func @printTileEnd() {
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%vscale = vector.vscale
%min_elts_s = arith.constant 4 : index
%svl_s = arith.muli %min_elts_s, %vscale : index
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem = memref.alloca(%za_s_size) : memref<?xf32>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 4x4xf32.
// Print the tile. The smallest SVL is 128-bits so the tile will be at least
// 4x4xf32.
//
// WITHOUT-ACC: TILE BEGIN
// WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
Expand All @@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
// WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
// WITHOUT-ACC: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
vector.print %tileslice : vector<[4]xf32>
}
vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()

return
Expand All @@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%vscale = vector.vscale
%min_elts_s = arith.constant 4 : index
%svl_s = arith.muli %min_elts_s, %vscale : index
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem = memref.alloca(%za_s_size) : memref<?xf32>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 4x4xf32.
// Print the tile. The smallest SVL is 128-bits so the tile will be at least
// 4x4xf32.
//
// WITH-ACC: TILE BEGIN
// WITH-ACC-NEXT: ( 10, 10, 10, 10
Expand All @@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
// WITH-ACC-NEXT: ( 10, 13, 16, 19
// WITH-ACC: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
vector.print %tileslice : vector<[4]xf32>
}
vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()

return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}

func.func @printTileEnd() {
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -32,7 +32,6 @@ func.func @printTileEnd() {
}

func.func @test_outerproduct_with_accumulator_2x2xf64() {
%c0 = arith.constant 0 : index
%f1 = arith.constant 1.0 : f64
%f2 = arith.constant 2.0 : f64
%f10 = arith.constant 10.0 : f64
Expand All @@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {

%tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>

// Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
%vscale = vector.vscale
%min_elts_d = arith.constant 2 : index
%svl_d = arith.muli %min_elts_d, %vscale : index
%za_d_size = arith.muli %svl_d, %svl_d : index

// Allocate memory.
%mem = memref.alloca(%za_d_size) : memref<?xf64>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 2x2xf64.
// Print the tile. The smallest SVL is 128-bits so the tile will be at least
// 2x2xf64.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 12, 12
// CHECK-NEXT: ( 12, 12
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_d_size step %svl_d {
%tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
vector.print %tileslice : vector<[2]xf64>
}
vector.print %tile : vector<[2]x[2]xf64>
func.call @printTileEnd() : () -> ()

return
Expand Down
Loading