Skip to content

Commit a4e1541

Browse files
authored
[mlir][ArmSME] Move creation of load/store intrinsics to helpers (NFC) (#76168)
Also, for consistency make the ZeroOp lowering switch on the ArmSMETileType, rather than the element bit width.
1 parent 88151dd commit a4e1541

File tree

1 file changed

+108
-119
lines changed

1 file changed

+108
-119
lines changed

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 108 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,95 @@ using namespace mlir;
3232

3333
namespace {
3434

35+
/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
36+
static Operation *createLoadTileSliceIntrinsic(
37+
RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
38+
arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
39+
IntegerAttr tileId, Value tileSliceI32) {
40+
if (layout == arm_sme::TileSliceLayout::Horizontal) {
41+
switch (type) {
42+
case arm_sme::ArmSMETileType::ZAB:
43+
return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
44+
loc, maskOp, ptr, tileId, tileSliceI32);
45+
case arm_sme::ArmSMETileType::ZAH:
46+
return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
47+
loc, maskOp, ptr, tileId, tileSliceI32);
48+
case arm_sme::ArmSMETileType::ZAS:
49+
return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
50+
loc, maskOp, ptr, tileId, tileSliceI32);
51+
case arm_sme::ArmSMETileType::ZAD:
52+
return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
53+
loc, maskOp, ptr, tileId, tileSliceI32);
54+
case arm_sme::ArmSMETileType::ZAQ:
55+
return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
56+
loc, maskOp, ptr, tileId, tileSliceI32);
57+
}
58+
} else {
59+
switch (type) {
60+
case arm_sme::ArmSMETileType::ZAB:
61+
return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
62+
loc, maskOp, ptr, tileId, tileSliceI32);
63+
case arm_sme::ArmSMETileType::ZAH:
64+
return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
65+
loc, maskOp, ptr, tileId, tileSliceI32);
66+
case arm_sme::ArmSMETileType::ZAS:
67+
return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
68+
loc, maskOp, ptr, tileId, tileSliceI32);
69+
case arm_sme::ArmSMETileType::ZAD:
70+
return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
71+
loc, maskOp, ptr, tileId, tileSliceI32);
72+
case arm_sme::ArmSMETileType::ZAQ:
73+
return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
74+
loc, maskOp, ptr, tileId, tileSliceI32);
75+
break;
76+
}
77+
}
78+
}
79+
80+
/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
81+
static Operation *createStoreTileSliceIntrinsic(
82+
RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
83+
arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
84+
IntegerAttr tileId, Value tileSliceI32) {
85+
if (layout == arm_sme::TileSliceLayout::Horizontal) {
86+
switch (type) {
87+
case arm_sme::ArmSMETileType::ZAB:
88+
return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
89+
loc, maskOp, ptr, tileId, tileSliceI32);
90+
case arm_sme::ArmSMETileType::ZAH:
91+
return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
92+
loc, maskOp, ptr, tileId, tileSliceI32);
93+
case arm_sme::ArmSMETileType::ZAS:
94+
return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
95+
loc, maskOp, ptr, tileId, tileSliceI32);
96+
case arm_sme::ArmSMETileType::ZAD:
97+
return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
98+
loc, maskOp, ptr, tileId, tileSliceI32);
99+
case arm_sme::ArmSMETileType::ZAQ:
100+
return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
101+
loc, maskOp, ptr, tileId, tileSliceI32);
102+
}
103+
} else {
104+
switch (type) {
105+
case arm_sme::ArmSMETileType::ZAB:
106+
return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
107+
loc, maskOp, ptr, tileId, tileSliceI32);
108+
case arm_sme::ArmSMETileType::ZAH:
109+
return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
110+
loc, maskOp, ptr, tileId, tileSliceI32);
111+
case arm_sme::ArmSMETileType::ZAS:
112+
return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
113+
loc, maskOp, ptr, tileId, tileSliceI32);
114+
case arm_sme::ArmSMETileType::ZAD:
115+
return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
116+
loc, maskOp, ptr, tileId, tileSliceI32);
117+
case arm_sme::ArmSMETileType::ZAQ:
118+
return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
119+
loc, maskOp, ptr, tileId, tileSliceI32);
120+
}
121+
}
122+
}
123+
35124
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
36125
auto tileId = op.getTileId();
37126
if (!tileId)
@@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
75164
ConversionPatternRewriter &rewriter) const override {
76165
auto loc = zero.getLoc();
77166

78-
unsigned tileElementWidth =
79-
zero.getVectorType().getElementType().getIntOrFloatBitWidth();
80-
81167
auto tileId = getTileIdOrError(zero);
82168
if (!tileId)
83169
return failure();
@@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
86172
// The base mask is just the mask to zero the first tile (of a size).
87173
// These masks are derived from:
88174
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
175+
arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
89176
auto baseMaskForSize = [&] {
90-
switch (tileElementWidth) {
91-
case 8:
177+
switch (tileType) {
178+
case arm_sme::ArmSMETileType::ZAB:
92179
// Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
93180
// 64-bit element tiles named ZA0.D to ZA7.D.
94181
return 0b1111'1111;
95-
case 16:
96-
// Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
97-
// tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
98-
// Shift this left once for ZA1.H.
182+
case arm_sme::ArmSMETileType::ZAH:
183+
// Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
184+
// element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
185+
// once for ZA1.H.
99186
return 0b0101'0101;
100-
case 32:
187+
case arm_sme::ArmSMETileType::ZAS:
101188
// Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
102189
// element tiles named ZA0.D and ZA4.D.
103190
// Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
104191
return 0b0001'0001;
105-
case 64:
192+
case arm_sme::ArmSMETileType::ZAD:
106193
// Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
107194
// setting the bit for that tile.
108195
return 0b0000'0001;
@@ -172,63 +259,13 @@ struct LoadTileSliceConversion
172259
// Create all active predicate mask.
173260
auto maskOp = loadTileSliceOp.getMask();
174261

175-
auto tileType = loadTileSliceOp.getVectorType();
176-
auto tileElementType = tileType.getElementType();
177-
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
262+
auto tileVectorType = loadTileSliceOp.getVectorType();
263+
arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
178264
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
179265

180266
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
181-
if (layout == arm_sme::TileSliceLayout::Horizontal) {
182-
switch (tileElementWidth) {
183-
default:
184-
llvm_unreachable("unexpected element type!");
185-
case 8:
186-
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
187-
tileId, tileSliceI32);
188-
break;
189-
case 16:
190-
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
191-
tileId, tileSliceI32);
192-
break;
193-
case 32:
194-
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
195-
tileId, tileSliceI32);
196-
break;
197-
case 64:
198-
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
199-
tileId, tileSliceI32);
200-
break;
201-
case 128:
202-
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
203-
tileId, tileSliceI32);
204-
break;
205-
}
206-
} else {
207-
switch (tileElementWidth) {
208-
default:
209-
llvm_unreachable("unexpected element type!");
210-
case 8:
211-
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
212-
tileId, tileSliceI32);
213-
break;
214-
case 16:
215-
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
216-
tileId, tileSliceI32);
217-
break;
218-
case 32:
219-
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
220-
tileId, tileSliceI32);
221-
break;
222-
case 64:
223-
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
224-
tileId, tileSliceI32);
225-
break;
226-
case 128:
227-
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
228-
tileId, tileSliceI32);
229-
break;
230-
}
231-
}
267+
createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
268+
tileId, tileSliceI32);
232269

233270
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
234271
// the input tile to preserve dataflow.
@@ -249,9 +286,7 @@ struct StoreTileSliceConversion
249286
arm_sme::StoreTileSliceOp::Adaptor adaptor,
250287
ConversionPatternRewriter &rewriter) const override {
251288
auto loc = storeTileSliceOp.getLoc();
252-
auto tileType = storeTileSliceOp.getVectorType();
253-
auto tileElementType = tileType.getElementType();
254-
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
289+
auto tileVectorType = storeTileSliceOp.getVectorType();
255290

256291
auto tileId = getTileIdOrError(storeTileSliceOp);
257292
if (!tileId)
@@ -271,58 +306,12 @@ struct StoreTileSliceConversion
271306
auto maskOp = storeTileSliceOp.getMask();
272307

273308
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
309+
arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
274310

275-
if (layout == arm_sme::TileSliceLayout::Horizontal) {
276-
switch (tileElementWidth) {
277-
default:
278-
llvm_unreachable("unexpected element type!");
279-
case 8:
280-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
281-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
282-
break;
283-
case 16:
284-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
285-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
286-
break;
287-
case 32:
288-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
289-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
290-
break;
291-
case 64:
292-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
293-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
294-
break;
295-
case 128:
296-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
297-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
298-
break;
299-
}
300-
} else {
301-
switch (tileElementWidth) {
302-
default:
303-
llvm_unreachable("unexpected element type!");
304-
case 8:
305-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
306-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
307-
break;
308-
case 16:
309-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
310-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
311-
break;
312-
case 32:
313-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
314-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
315-
break;
316-
case 64:
317-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
318-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
319-
break;
320-
case 128:
321-
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
322-
storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
323-
break;
324-
}
325-
}
311+
rewriter.replaceOp(storeTileSliceOp,
312+
createStoreTileSliceIntrinsic(rewriter, loc, tileType,
313+
layout, maskOp, ptr,
314+
tileId, tileSliceI32));
326315

327316
return success();
328317
}

0 commit comments

Comments
 (0)