Skip to content

Commit c07bc6a

Browse files
committed
[mlir][ArmSME] Add support for lowering masked tile_load ops
This patch extends ArmSMEToSCF to support lowering of masked tile_load ops. Only masks created by 'vector.create_mask' are currently supported. There are two lowerings, one for pad of constant zero and another for non-zero pad. For the following example: %pad = arith.constant 0 : i32 %num_rows = arith.constant 2 : index %num_cols = arith.constant 4 : index %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1> %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32> The former (constant zero pad) is lowered as follows: --------------------------------------------------------- %tile = arm_sme.zero : vector<[4]x[4]xi32> %num_cols = vector.create_mask %c4 : vector<[4]xi1> scf.for %slice_idx = %c0 to %num_rows step %c1 %tile_update = arm_sme.load_tile_slice %src[%slice_idx], %num_cols, %tile, %tile_slice_idx : memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32> The tile is zeroed to satisfy the padding and only active rows are loaded. The latter (non-zero pad) is lowered as follows: ------------------------------------------------ scf.for %slice_idx = %c0 to %num_tile_slices step %c1 { %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index %slice = scf.if %row_is_active -> vector<[4]xi32> { %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> scf.yield %slice : vector<[4]xi32> } else { scf.yield %pad_1d : vector<[4]xi32> } arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx : vector<[4]xi32> into vector<[4]x[4]xi32> The scalar pad is broadcast to a 1-D vector and a regular 'vector.maskedload' (will be lowered to SVE, not SME) loads each slice for active rows, with padding specified as a passthru. For non-active rows the slice is the 1-D pad. The resulting slice is inserted into the tile with 'arm_sme.move_vector_to_tile_slice'.
1 parent a5c1eca commit c07bc6a

File tree

3 files changed

+519
-3
lines changed

3 files changed

+519
-3
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 251 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
8080
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
8181
PatternRewriter &rewriter) const override {
8282
if (tileLoadOp.getMask())
83-
// TODO: add masked patterns.
8483
return rewriter.notifyMatchFailure(
85-
tileLoadOp, "op has mask, needs masked pattern(s)");
84+
tileLoadOp, "op has mask, apply masked patterns");
8685

8786
OpBuilder::InsertionGuard g(rewriter);
8887
auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
142141
}
143142
};
144143

144+
/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
145+
///
146+
/// BEFORE:
147+
/// ```mlir
148+
/// %pad = arith.constant 0 : i32
149+
/// %num_rows = arith.constant 2 : index
150+
/// %num_cols = arith.constant 4 : index
151+
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
152+
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
153+
/// memref<?x?xi32>, vector<[4]x[4]xi32>
154+
/// ```
155+
///
156+
/// AFTER:
157+
/// ```mlir
158+
/// %c0 = arith.constant 0 : index
159+
/// %c1 = arith.constant 1 : index
160+
/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
161+
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
162+
/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
163+
/// %tile_update = arm_sme.load_tile_slice
164+
/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
165+
/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
166+
/// }
167+
/// ```
168+
///
169+
/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
170+
struct TileLoadOpWithMaskAndPadZeroConversion
171+
: public OpRewritePattern<arm_sme::TileLoadOp> {
172+
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
173+
174+
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
175+
PatternRewriter &rewriter) const override {
176+
OpBuilder::InsertionGuard g(rewriter);
177+
auto loc = tileLoadOp.getLoc();
178+
auto tileType = tileLoadOp.getVectorType();
179+
180+
auto maskOp = tileLoadOp.getMask();
181+
if (!maskOp)
182+
return rewriter.notifyMatchFailure(
183+
tileLoadOp, "op has no mask, needs unmasked pattern");
184+
185+
auto padOp = tileLoadOp.getPadding();
186+
assert(padOp && "expected padding when masking!");
187+
188+
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
189+
if (!createMaskOp)
190+
return rewriter.notifyMatchFailure(
191+
tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
192+
"currently supported");
193+
194+
auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
195+
if (!constPadOp || constPadOp.getValue() !=
196+
rewriter.getZeroAttr(tileType.getElementType()))
197+
return rewriter.notifyMatchFailure(
198+
tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
199+
200+
auto numRows = createMaskOp.getOperands()[0];
201+
auto numCols = createMaskOp.getOperands()[1];
202+
203+
auto predicateType =
204+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
205+
auto numColsOp =
206+
rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
207+
208+
// Initialize tile with zero to satisfy padding. Inactive cols will be
209+
// zeroed anyway since the loads use zeroing predication. For inactive rows
210+
// however, no load will occur so these need to be zeroed.
211+
auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
212+
213+
// Create a loop to load the active tile slices from memory.
214+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
215+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
216+
auto upperBound = numRows;
217+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
218+
219+
rewriter.setInsertionPointToStart(forOp.getBody());
220+
221+
// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
222+
// tile.
223+
SmallVector<Value> memrefIndices;
224+
auto tileSliceIndex = forOp.getInductionVar();
225+
getMemrefIndices(tileLoadOp.getIndices(),
226+
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
227+
upperBound, memrefIndices, loc, rewriter);
228+
rewriter.create<arm_sme::LoadTileSliceOp>(
229+
loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
230+
tileSliceIndex, tileLoadOp.getLayout());
231+
232+
rewriter.setInsertionPointAfter(forOp);
233+
234+
// Replace 'arm_sme.tile_load' with the tile.
235+
rewriter.replaceOp(tileLoadOp, tile);
236+
237+
return success();
238+
}
239+
};
240+
241+
/// Lower `arm_sme.tile_load` with mask and non-zero pad.
242+
///
243+
/// BEFORE:
244+
/// ```mlir
245+
/// %pad = arith.constant 1 : i32
246+
/// %num_rows = arith.constant 2 : index
247+
/// %num_cols = arith.constant 4 : index
248+
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
249+
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
250+
/// memref<?x?xi32>, vector<[4]x[4]xi32>
251+
/// ```
252+
///
253+
/// AFTER:
254+
/// ```mlir
255+
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
256+
/// %num_rows = arith.constant 2 : index
257+
/// %num_cols = arith.constant 4 : index
258+
/// %tile_id = arm_sme.get_tile_id : i32
259+
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
260+
/// %vscale = vector.vscale
261+
/// %c0 = arith.constant 0 : index
262+
/// %c1 = arith.constant 1 : index
263+
/// %min_svl_s = arith.constant 4 : index
264+
/// %svl_s = arith.muli %min_svl_s, %vscale : index
265+
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
266+
/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
267+
/// %slice = scf.if %row_is_active -> vector<[4]xi32> {
268+
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
269+
/// : memref<?x?xi32>, vector<[4]xi1>,
270+
/// vector<[4]xi32> into vector<[4]xi32>
271+
/// scf.yield %slice : vector<[4]xi32>
272+
/// } else {
273+
/// scf.yield %pad_1d : vector<[4]xi32>
274+
/// }
275+
/// // Insert slice into tile
276+
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
277+
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
278+
/// }
279+
/// ```
280+
struct TileLoadOpWithMaskAndPadNonZeroConversion
281+
: public OpRewritePattern<arm_sme::TileLoadOp> {
282+
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
283+
284+
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
285+
PatternRewriter &rewriter) const override {
286+
OpBuilder::InsertionGuard g(rewriter);
287+
auto loc = tileLoadOp.getLoc();
288+
auto tileType = tileLoadOp.getVectorType();
289+
auto tileElementType = tileType.getElementType();
290+
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
291+
292+
auto maskOp = tileLoadOp.getMask();
293+
if (!maskOp)
294+
return rewriter.notifyMatchFailure(
295+
tileLoadOp, "op has no mask, needs unmasked pattern");
296+
297+
auto padOp = tileLoadOp.getPadding();
298+
assert(padOp && "expected padding when masking!");
299+
300+
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
301+
if (!createMaskOp)
302+
return rewriter.notifyMatchFailure(
303+
tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
304+
"currently supported");
305+
306+
auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
307+
if (constPadOp &&
308+
constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
309+
return rewriter.notifyMatchFailure(
310+
tileLoadOp, "op has constant zero pad, needs zero pad pattern");
311+
312+
auto numRows = createMaskOp.getOperands()[0];
313+
auto numCols = createMaskOp.getOperands()[1];
314+
315+
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
316+
auto predicateType =
317+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
318+
auto numColsOp =
319+
rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
320+
321+
// Create 'arm_sme.get_tile' op.
322+
auto tileId = rewriter.create<arm_sme::GetTileID>(
323+
loc, rewriter.getIntegerType(tileElementWidth));
324+
325+
// Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
326+
// use as input tile to 'arm_sme.load_tile_slice' ops.
327+
auto tile =
328+
rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
329+
330+
// Create a loop that loads each ZA tile slice from memory.
331+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
332+
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
333+
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
334+
auto vscale =
335+
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
336+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
337+
auto numTileSlices =
338+
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
339+
auto forOp =
340+
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
341+
342+
rewriter.setInsertionPointToStart(forOp.getBody());
343+
344+
auto tileSliceIndex = forOp.getInductionVar();
345+
346+
auto rowIsActive = rewriter.create<arith::CmpIOp>(
347+
loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
348+
349+
SmallVector<Value> memrefIndices;
350+
getMemrefIndices(tileLoadOp.getIndices(),
351+
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
352+
numTileSlices, memrefIndices, loc, rewriter);
353+
354+
// Splat pad into 1-D vector matching type of tile slice.
355+
auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
356+
357+
Operation *slice = rewriter.create<scf::IfOp>(
358+
loc, rowIsActive,
359+
[&](OpBuilder &b, Location loc) {
360+
// If the row is active, emit a masked load where the predicate is
361+
// 'numCols'. Pad is used for inactive elements, taken from
362+
// passthru.
363+
auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
364+
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
365+
numColsOp, /*passthru=*/pad1DOp);
366+
rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
367+
},
368+
[&](OpBuilder &b, Location loc) {
369+
// Inactive rows are filled with pad.
370+
rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
371+
});
372+
373+
// TODO: If the load is vertical the transpose can't be done in-flight with
374+
// a regular (SVE) maskedload. Propagate layout to
375+
// 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
376+
// is currently broken.
377+
378+
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
379+
rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
380+
loc, tileType, slice->getResult(0), tile, tileSliceIndex,
381+
tileLoadOp.getLayout());
382+
383+
rewriter.setInsertionPointAfter(forOp);
384+
385+
// Replace 'arm_sme.tile_load' with the tile.
386+
rewriter.replaceOp(tileLoadOp, tile);
387+
388+
return success();
389+
}
390+
};
391+
145392
/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
146393
/// slice using `arm_sme.store_tile_slice`.
147394
///
@@ -294,7 +541,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
294541
} // namespace
295542

296543
/// Register all ArmSME-to-SCF lowering patterns, including the masked
/// tile_load patterns (constant-zero pad and non-zero pad variants).
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
               TileVectorPrintOpConversion>(context);
}
300548

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
3333
return
3434
}
3535

36+
// -----
37+
38+
// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
39+
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
40+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
41+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
42+
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
43+
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
44+
// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
45+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
46+
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
47+
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
48+
func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
49+
%c0 = arith.constant 0 : index
50+
%c2 = arith.constant 2 : index
51+
%c3 = arith.constant 3 : index
52+
%pad = arith.constant 0 : i32
53+
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
54+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
55+
return
56+
}
57+
58+
// -----
59+
60+
// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
61+
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
62+
// CHECK-SAME: %[[PAD:.*]]: i32) {
63+
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
64+
// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
65+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
66+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
67+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
68+
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
69+
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
70+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
71+
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
72+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
73+
// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
74+
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
75+
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
76+
// CHECK: %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
77+
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
78+
// CHECK: scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
79+
// CHECK: } else {
80+
// CHECK: scf.yield %[[PAD_1D]] : vector<[4]xi32>
81+
// CHECK: }
82+
// CHECK: arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
83+
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
84+
%c0 = arith.constant 0 : index
85+
%c2 = arith.constant 2 : index
86+
%c3 = arith.constant 3 : index
87+
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
88+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
89+
return
90+
}
91+
3692
//===----------------------------------------------------------------------===//
3793
// arm_sme.tile_store
3894
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)