Skip to content

Commit 9783cf4

Browse files
authored
[mlir][ArmSME] Add support for lowering masked tile_load ops (#70915)
This patch extends ArmSMEToSCF to support lowering of masked tile_load ops. Only masks created by 'vector.create_mask' are currently supported. There are two lowerings depending on the pad. For a pad of constant zero, the tile is first zeroed, then only the active rows are loaded. For a non-zero pad, the scalar pad is broadcast to a 1-D vector and a regular 'vector.maskedload' (which will be lowered to SVE, not SME) loads each slice, with the padding specified as a passthru and the 2-D mask combined into a 1-D mask. The resulting slice is then inserted into the tile with 'arm_sme.move_vector_to_tile_slice'.
1 parent 3b905a0 commit 9783cf4

File tree

3 files changed

+520
-5
lines changed

3 files changed

+520
-5
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 232 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
8080
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
8181
PatternRewriter &rewriter) const override {
8282
if (tileLoadOp.getMask())
83-
// TODO: add masked patterns.
84-
return rewriter.notifyMatchFailure(
85-
tileLoadOp, "op has mask, needs masked pattern(s)");
83+
return rewriter.notifyMatchFailure(tileLoadOp,
84+
"op has mask, apply masked patterns");
8685

8786
OpBuilder::InsertionGuard g(rewriter);
8887
auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,234 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
142141
}
143142
};
144143

144+
/// Lower a masked `arm_sme.tile_load` whose pad is a constant zero.
145+
///
146+
/// BEFORE:
147+
/// ```mlir
148+
/// %pad = arith.constant 0 : i32
149+
/// %num_rows = arith.constant 2 : index
150+
/// %num_cols = arith.constant 4 : index
151+
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
152+
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
153+
/// memref<?x?xi32>, vector<[4]x[4]xi32>
154+
/// ```
155+
///
156+
/// AFTER:
157+
/// ```mlir
158+
/// %c0 = arith.constant 0 : index
159+
/// %c1 = arith.constant 1 : index
160+
/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
161+
/// %num_rows = arith.constant 2 : index
162+
/// %col_mask = vector.create_mask %num_cols : vector<[4]xi1>
163+
/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
164+
/// %tile_update = arm_sme.load_tile_slice
165+
/// %src[%tile_slice_idx, %c0], %col_mask, %tile, %tile_slice_idx :
166+
/// memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
167+
/// }
168+
/// ```
169+
///
170+
/// NOTE: Only a mask defined by a 'vector.create_mask' op is currently
/// supported. The pattern reports a match failure when the op has no mask,
/// when the mask is defined by any other op, or when the pad is not an
/// 'arith.constant' equal to zero.
171+
struct TileLoadOpWithMaskAndPadZeroConversion
172+
: public OpRewritePattern<arm_sme::TileLoadOp> {
173+
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
174+
175+
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
176+
PatternRewriter &rewriter) const override {
177+
// Restores the rewriter's insertion point when this function returns.
OpBuilder::InsertionGuard g(rewriter);
178+
auto loc = tileLoadOp.getLoc();
179+
auto tileType = tileLoadOp.getVectorType();
180+
181+
// Loads without a mask are not handled by this pattern.
auto maskOp = tileLoadOp.getMask();
182+
if (!maskOp)
183+
return rewriter.notifyMatchFailure(
184+
tileLoadOp, "op has no mask, needs unmasked pattern");
185+
186+
auto padOp = tileLoadOp.getPadding();
187+
assert(padOp && "expected padding when masking!");
188+
189+
// The row and column counts are read from the operands of the op that
// defines the mask, so that op has to be a 'vector.create_mask'.
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
190+
if (!createMaskOp)
191+
return rewriter.notifyMatchFailure(
192+
tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
193+
"currently supported");
194+
195+
// Only a pad defined by an 'arith.constant' whose value equals the zero
// attribute of the tile element type is handled here; a non-constant pad is
// rejected as well.
auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
196+
if (!constPadOp || constPadOp.getValue() !=
197+
rewriter.getZeroAttr(tileType.getElementType()))
198+
return rewriter.notifyMatchFailure(
199+
tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
200+
201+
// Operand 0 of the 2-D 'vector.create_mask' bounds the rows, operand 1 the
// columns.
auto numRows = createMaskOp.getOperands()[0];
202+
auto numCols = createMaskOp.getOperands()[1];
203+
204+
// 1-D scalable predicate for a single tile slice, built from the column
// count only; the row count is used as the loop bound below.
auto predicateType =
205+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
206+
auto numColsOp =
207+
rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
208+
209+
// Initialize tile with zero to satisfy padding. Inactive cols will be
210+
// zeroed anyway since the loads use zeroing predication. For inactive rows
211+
// however, no load will occur so these need to be zeroed.
212+
auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
213+
214+
// Create a loop to load the active tile slices from memory.
215+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
216+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
217+
// NOTE(review): the loop bound is the mask's row count as-is; nothing here
// clamps it to the number of slices in the tile. Confirm that a row count
// larger than the tile cannot reach this pattern.
auto upperBound = numRows;
218+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
219+
220+
rewriter.setInsertionPointToStart(forOp.getBody());
221+
222+
// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
223+
// tile.
224+
SmallVector<Value> memrefIndices;
225+
auto tileSliceIndex = forOp.getInductionVar();
226+
// NOTE(review): 'upperBound' (the mask's row count) is passed here, whereas
// TileLoadOpWithMaskAndPadNonZeroConversion passes the total number of tile
// slices in the same argument position. getMemrefIndices is defined outside
// this view; confirm it does not rely on this argument being the slice count.
getMemrefIndices(tileLoadOp.getIndices(),
227+
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
228+
upperBound, memrefIndices, loc, rewriter);
229+
rewriter.create<arm_sme::LoadTileSliceOp>(
230+
loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
231+
tileSliceIndex, tileLoadOp.getLayout());
232+
233+
rewriter.setInsertionPointAfter(forOp);
234+
235+
// Replace 'arm_sme.tile_load' with the tile.
236+
rewriter.replaceOp(tileLoadOp, tile);
237+
238+
return success();
239+
}
240+
};
241+
242+
/// Lower a masked `arm_sme.tile_load` whose pad is not a constant zero (the
/// pad may be a non-zero constant or a value that is not a constant at all).
243+
///
244+
/// BEFORE:
245+
/// ```mlir
246+
/// %pad = arith.constant 1 : i32
247+
/// %num_rows = arith.constant 2 : index
248+
/// %num_cols = arith.constant 4 : index
249+
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
250+
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
251+
/// memref<?x?xi32>, vector<[4]x[4]xi32>
252+
/// ```
253+
///
254+
/// AFTER:
255+
/// ```mlir
256+
/// ...
257+
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
258+
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
259+
/// ...
260+
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
/// %pad_1d = vector.splat %pad : vector<[4]xi32>
261+
/// %slice = vector.maskedload %src[%tile_slice_idx, %c0], %mask_1d, %pad_1d
262+
/// : memref<?x?xi32>, vector<[4]xi1>,
263+
/// vector<[4]xi32> into vector<[4]xi32>
264+
/// // Insert slice into tile
265+
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
266+
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
267+
/// }
268+
/// ```
///
/// NOTE: Only a mask defined by a 'vector.create_mask' op is currently
/// supported; any other mask makes this pattern report a match failure.
269+
struct TileLoadOpWithMaskAndPadNonZeroConversion
270+
: public OpRewritePattern<arm_sme::TileLoadOp> {
271+
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
272+
273+
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
274+
PatternRewriter &rewriter) const override {
275+
// Restores the rewriter's insertion point when this function returns.
OpBuilder::InsertionGuard g(rewriter);
276+
auto loc = tileLoadOp.getLoc();
277+
auto tileType = tileLoadOp.getVectorType();
278+
auto tileElementType = tileType.getElementType();
279+
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
280+
281+
// Loads without a mask are not handled by this pattern.
auto maskOp = tileLoadOp.getMask();
282+
if (!maskOp)
283+
return rewriter.notifyMatchFailure(
284+
tileLoadOp, "op has no mask, needs unmasked pattern");
285+
286+
auto padOp = tileLoadOp.getPadding();
287+
assert(padOp && "expected padding when masking!");
288+
289+
// The row and column counts are read from the operands of the op that
// defines the mask, so that op has to be a 'vector.create_mask'.
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
290+
if (!createMaskOp)
291+
return rewriter.notifyMatchFailure(
292+
tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
293+
"currently supported");
294+
295+
// Reject only a pad that is an 'arith.constant' equal to zero; a pad with no
// defining 'arith.constant' (e.g. a function argument) is accepted.
auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
296+
if (constPadOp &&
297+
constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
298+
return rewriter.notifyMatchFailure(
299+
tileLoadOp, "op has constant zero pad, needs zero pad pattern");
300+
301+
// Operand 0 of the 2-D 'vector.create_mask' bounds the rows, operand 1 the
// columns.
auto numRows = createMaskOp.getOperands()[0];
302+
auto numCols = createMaskOp.getOperands()[1];
303+
304+
// The column count is converted to i32 so it can be AND-ed with the
// sign-extended row predicate inside the loop.
// NOTE(review): assumes the column count fits in 32 bits -- confirm.
auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
305+
loc, rewriter.getI32Type(), numCols);
306+
307+
// Create 'arm_sme.get_tile_id' op.
308+
auto tileId = rewriter.create<arm_sme::GetTileID>(
309+
loc, rewriter.getIntegerType(tileElementWidth));
310+
311+
// Create 'arm_sme.cast_tile_to_vector' to cast tile ID to a vector type to
312+
// use as input tile to the 'arm_sme.move_vector_to_tile_slice' ops.
313+
auto tile =
314+
rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
315+
316+
// Create a loop that loads each ZA tile slice from memory. Unlike the zero
// pad pattern, the loop runs over every tile slice (minimum slice count
// multiplied by vscale), not just the rows selected by the mask.
317+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
318+
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
319+
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
320+
auto vscale =
321+
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
322+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
323+
auto numTileSlices =
324+
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
325+
auto forOp =
326+
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
327+
328+
rewriter.setInsertionPointToStart(forOp.getBody());
329+
330+
auto tileSliceIndex = forOp.getInductionVar();
331+
332+
// Combine masks. The i1 row predicate is sign-extended to i32 (all ones for
// an active row, zero otherwise) and AND-ed with the column count, giving the
// column count for active rows and 0 for inactive rows. That value is the
// operand of the 1-D 'vector.create_mask' used for this slice.
333+
auto rowIsActive = rewriter.create<arith::CmpIOp>(
334+
loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
335+
auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
336+
loc, rewriter.getI32Type(), rowIsActive);
337+
auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
338+
auto maskIndex =
339+
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
340+
auto predicateType =
341+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
342+
auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
343+
loc, predicateType, maskIndex.getResult());
344+
345+
SmallVector<Value> memrefIndices;
346+
getMemrefIndices(tileLoadOp.getIndices(),
347+
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
348+
numTileSlices, memrefIndices, loc, rewriter);
349+
350+
// Splat pad into 1-D vector matching type of tile slice. The splat is
// created at the current insertion point, i.e. inside the loop body.
351+
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
352+
auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
353+
354+
// Load one slice; the splatted pad is supplied as the pass-through operand.
auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
355+
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
356+
/*passthru=*/pad1DOp);
357+
358+
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
359+
rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
360+
loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
361+
tileLoadOp.getLayout());
362+
363+
rewriter.setInsertionPointAfter(forOp);
364+
365+
// Replace 'arm_sme.tile_load' with the tile.
366+
rewriter.replaceOp(tileLoadOp, tile);
367+
368+
return success();
369+
}
370+
};
371+
145372
/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
146373
/// slice using `arm_sme.store_tile_slice`.
147374
///
@@ -294,7 +521,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
294521
} // namespace
295522

296523
/// Adds the rewrite patterns listed in the body to `patterns`, constructing
/// each of them with the context obtained from `patterns`.
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
297-
patterns.add<TileLoadOpConversion, TileStoreOpConversion,
524+
patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
525+
TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
298526
TileVectorPrintOpConversion>(patterns.getContext());
299527
}
300528

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file -verify-diagnostics | FileCheck %s
22

33
//===----------------------------------------------------------------------===//
44
// arm_sme.tile_load
@@ -33,6 +33,81 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
3333
return
3434
}
3535

36+
// -----
37+
38+
// Masked load with a constant zero pad: the tile is zeroed and only the rows
// selected by the mask are loaded. The column count is captured rather than
// matched by its printed SSA name, so the test does not depend on how the
// constant happens to be named in the output.
// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
39+
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
40+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
41+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
42+
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[NUM_COLS:.*]] = arith.constant 2 : index
43+
// CHECK-DAG: %[[COL_MASK:.*]] = vector.create_mask %[[NUM_COLS]] : vector<[4]xi1>
44+
// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
45+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
46+
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
47+
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[COL_MASK]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
48+
func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
49+
%c0 = arith.constant 0 : index
50+
%c2 = arith.constant 2 : index
51+
%c3 = arith.constant 3 : index
52+
%pad = arith.constant 0 : i32
53+
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
54+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
55+
return
56+
}
57+
58+
// -----
59+
60+
// Masked load whose pad is a function argument rather than a constant zero:
// the loop runs over all tile slices, each slice is loaded with
// 'vector.maskedload' using the splatted pad as pass-through, and the result is
// moved into the tile with 'arm_sme.move_vector_to_tile_slice'.
// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
61+
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
62+
// CHECK-SAME: %[[PAD:.*]]: i32) {
63+
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
64+
// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
65+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
66+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
67+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
68+
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
69+
// CHECK-DAG: %[[NUM_COLS:.*]] = arith.constant 2 : index
70+
// CHECK-DAG: %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32
71+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
72+
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
73+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
74+
// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
75+
// CHECK-NEXT: %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32
76+
// CHECK-NEXT: %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32
77+
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
78+
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
79+
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
80+
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
81+
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
82+
// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
83+
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
84+
%c0 = arith.constant 0 : index
85+
%c2 = arith.constant 2 : index
86+
%c3 = arith.constant 3 : index
87+
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
88+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
89+
return
90+
}
91+
92+
// -----
93+
94+
// Negative test (constant zero pad): the mask is a function argument, so it is
// not defined by 'vector.create_mask' and the load is expected to fail
// legalization.
func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %mask : vector<[4]x[4]xi1>) {
95+
%c0 = arith.constant 0 : index
96+
%pad = arith.constant 0 : i32
97+
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
98+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
99+
return
100+
}
101+
102+
// -----
103+
104+
// Negative test (pad passed as a function argument): the mask is a function
// argument, so it is not defined by 'vector.create_mask' and the load is
// expected to fail legalization.
func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %pad : i32, %mask : vector<[4]x[4]xi1>) {
105+
%c0 = arith.constant 0 : index
106+
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
107+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
108+
return
109+
}
110+
36111
//===----------------------------------------------------------------------===//
37112
// arm_sme.tile_store
38113
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)