
[mlir][ArmSME] Add support for lowering masked tile_load ops #70915

Merged: 4 commits, Nov 8, 2023

236 changes: 232 additions & 4 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
-  // TODO: add masked patterns.
-  return rewriter.notifyMatchFailure(
-      tileLoadOp, "op has mask, needs masked pattern(s)");
+  return rewriter.notifyMatchFailure(tileLoadOp,
+                                     "op has mask, apply masked patterns");

OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,234 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
}
};

/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
///
/// BEFORE:
/// ```mlir
/// %pad = arith.constant 0 : i32
/// %num_rows = arith.constant 2 : index
/// %num_cols = arith.constant 4 : index
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
/// %num_rows = arith.constant 2 : index
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
/// %tile_update = arm_sme.load_tile_slice
/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
/// memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
/// }
/// ```
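///
/// The loop visits only the first `num_rows` tile slices; the remaining rows
/// are never loaded and keep the zero written by 'arm_sme.zero', which
/// matches the zero pad.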
///
/// NOTE: Only masks created by a 'vector.create_mask' op are currently
/// supported.
struct TileLoadOpWithMaskAndPadZeroConversion
: public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();

auto maskOp = tileLoadOp.getMask();
if (!maskOp)
return rewriter.notifyMatchFailure(
tileLoadOp, "op has no mask, needs unmasked pattern");

auto padOp = tileLoadOp.getPadding();
assert(padOp && "expected padding when masking!");

auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
"currently supported");

auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
if (!constPadOp || constPadOp.getValue() !=
rewriter.getZeroAttr(tileType.getElementType()))
return rewriter.notifyMatchFailure(
tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");

auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];

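// The column predicate is the same for every tile slice, so build the 1-D
// create_mask once, ahead of the loop over rows.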
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
auto numColsOp =
rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);

// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive rows
// however, no load will occur so these need to be zeroed.
auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);

// Create a loop to load the active tile slices from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = numRows;
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);

rewriter.setInsertionPointToStart(forOp.getBody());

// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
// tile.
SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
upperBound, memrefIndices, loc, rewriter);
rewriter.create<arm_sme::LoadTileSliceOp>(
loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
tileSliceIndex, tileLoadOp.getLayout());

rewriter.setInsertionPointAfter(forOp);

// Replace 'arm_sme.tile_load' with the tile.
rewriter.replaceOp(tileLoadOp, tile);

return success();
}
};

/// Lower `arm_sme.tile_load` with mask and non-zero pad.
///
/// BEFORE:
/// ```mlir
/// %pad = arith.constant 1 : i32
/// %num_rows = arith.constant 2 : index
/// %num_cols = arith.constant 4 : index
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
/// ...
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// ...
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
/// : memref<?x?xi32>, vector<[4]xi1>,
/// vector<[4]xi32> into vector<[4]xi32>
/// // Insert slice into tile
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
/// }
/// ```
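///
/// The `<combined_mask>` is `%num_cols` for tile slices where
/// `%tile_slice_idx < %num_rows` and 0 otherwise, so rows outside the mask
/// load nothing and are filled entirely with the splatted pad value.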
struct TileLoadOpWithMaskAndPadNonZeroConversion
: public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
auto tileElementType = tileType.getElementType();
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();

auto maskOp = tileLoadOp.getMask();
if (!maskOp)
return rewriter.notifyMatchFailure(
tileLoadOp, "op has no mask, needs unmasked pattern");

auto padOp = tileLoadOp.getPadding();
assert(padOp && "expected padding when masking!");

auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
"currently supported");

auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
if (constPadOp &&
constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
return rewriter.notifyMatchFailure(
tileLoadOp, "op has constant zero pad, needs zero pad pattern");

auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];

auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), numCols);

// Create 'arm_sme.get_tile_id' op.
auto tileId = rewriter.create<arm_sme::GetTileID>(
loc, rewriter.getIntegerType(tileElementWidth));

// Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
// use as input tile to 'arm_sme.load_tile_slice' ops.
auto tile =
rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);

// Create a loop that loads each ZA tile slice from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);

rewriter.setInsertionPointToStart(forOp.getBody());

auto tileSliceIndex = forOp.getInductionVar();

// Combine masks.
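// 'rowIsActive' is an i1; sign-extending it to i32 gives all-ones when the
// row is active and zero otherwise, so ANDing it with 'numColsI32' yields
// the number of active columns for active rows and 0 for inactive rows.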
auto rowIsActive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
loc, rewriter.getI32Type(), rowIsActive);
auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
auto maskIndex =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
loc, predicateType, maskIndex.getResult());

SmallVector<Value> memrefIndices;
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);

// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
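// The pad vector is used below as the masked-load passthru, so lanes
// disabled by the combined mask (including every lane of an inactive row)
// take the pad value.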

auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
/*passthru=*/pad1DOp);

// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
tileLoadOp.getLayout());

rewriter.setInsertionPointAfter(forOp);

// Replace 'arm_sme.tile_load' with the tile.
rewriter.replaceOp(tileLoadOp, tile);

return success();
}
};

/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
/// slice using `arm_sme.store_tile_slice`.
///
@@ -294,7 +521,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
} // namespace

void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
TileVectorPrintOpConversion>(patterns.getContext());
}

77 changes: 76 additions & 1 deletion mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file -verify-diagnostics | FileCheck %s

//===----------------------------------------------------------------------===//
// arm_sme.tile_load
@@ -33,6 +33,81 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
return
}

// -----

// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%pad = arith.constant 0 : i32
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

// -----

// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
// CHECK-SAME: %[[PAD:.*]]: i32) {
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[NUM_COLS:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
// CHECK-NEXT: %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32
// CHECK-NEXT: %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

// -----

func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %mask : vector<[4]x[4]xi1>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i32
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

// -----

func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %pad : i32, %mask : vector<[4]x[4]xi1>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//