@@ -32,6 +32,95 @@ using namespace mlir;
32
32
33
33
namespace {
34
34
35
+ // / Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
36
+ static Operation *createLoadTileSliceIntrinsic (
37
+ RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
38
+ arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
39
+ IntegerAttr tileId, Value tileSliceI32) {
40
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
41
+ switch (type) {
42
+ case arm_sme::ArmSMETileType::ZAB:
43
+ return rewriter.create <arm_sme::aarch64_sme_ld1b_horiz>(
44
+ loc, maskOp, ptr, tileId, tileSliceI32);
45
+ case arm_sme::ArmSMETileType::ZAH:
46
+ return rewriter.create <arm_sme::aarch64_sme_ld1h_horiz>(
47
+ loc, maskOp, ptr, tileId, tileSliceI32);
48
+ case arm_sme::ArmSMETileType::ZAS:
49
+ return rewriter.create <arm_sme::aarch64_sme_ld1w_horiz>(
50
+ loc, maskOp, ptr, tileId, tileSliceI32);
51
+ case arm_sme::ArmSMETileType::ZAD:
52
+ return rewriter.create <arm_sme::aarch64_sme_ld1d_horiz>(
53
+ loc, maskOp, ptr, tileId, tileSliceI32);
54
+ case arm_sme::ArmSMETileType::ZAQ:
55
+ return rewriter.create <arm_sme::aarch64_sme_ld1q_horiz>(
56
+ loc, maskOp, ptr, tileId, tileSliceI32);
57
+ }
58
+ } else {
59
+ switch (type) {
60
+ case arm_sme::ArmSMETileType::ZAB:
61
+ return rewriter.create <arm_sme::aarch64_sme_ld1b_vert>(
62
+ loc, maskOp, ptr, tileId, tileSliceI32);
63
+ case arm_sme::ArmSMETileType::ZAH:
64
+ return rewriter.create <arm_sme::aarch64_sme_ld1h_vert>(
65
+ loc, maskOp, ptr, tileId, tileSliceI32);
66
+ case arm_sme::ArmSMETileType::ZAS:
67
+ return rewriter.create <arm_sme::aarch64_sme_ld1w_vert>(
68
+ loc, maskOp, ptr, tileId, tileSliceI32);
69
+ case arm_sme::ArmSMETileType::ZAD:
70
+ return rewriter.create <arm_sme::aarch64_sme_ld1d_vert>(
71
+ loc, maskOp, ptr, tileId, tileSliceI32);
72
+ case arm_sme::ArmSMETileType::ZAQ:
73
+ return rewriter.create <arm_sme::aarch64_sme_ld1q_vert>(
74
+ loc, maskOp, ptr, tileId, tileSliceI32);
75
+ break ;
76
+ }
77
+ }
78
+ }
79
+
80
+ // / Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
81
+ static Operation *createStoreTileSliceIntrinsic (
82
+ RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
83
+ arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
84
+ IntegerAttr tileId, Value tileSliceI32) {
85
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
86
+ switch (type) {
87
+ case arm_sme::ArmSMETileType::ZAB:
88
+ return rewriter.create <arm_sme::aarch64_sme_st1b_horiz>(
89
+ loc, maskOp, ptr, tileId, tileSliceI32);
90
+ case arm_sme::ArmSMETileType::ZAH:
91
+ return rewriter.create <arm_sme::aarch64_sme_st1h_horiz>(
92
+ loc, maskOp, ptr, tileId, tileSliceI32);
93
+ case arm_sme::ArmSMETileType::ZAS:
94
+ return rewriter.create <arm_sme::aarch64_sme_st1w_horiz>(
95
+ loc, maskOp, ptr, tileId, tileSliceI32);
96
+ case arm_sme::ArmSMETileType::ZAD:
97
+ return rewriter.create <arm_sme::aarch64_sme_st1d_horiz>(
98
+ loc, maskOp, ptr, tileId, tileSliceI32);
99
+ case arm_sme::ArmSMETileType::ZAQ:
100
+ return rewriter.create <arm_sme::aarch64_sme_st1q_horiz>(
101
+ loc, maskOp, ptr, tileId, tileSliceI32);
102
+ }
103
+ } else {
104
+ switch (type) {
105
+ case arm_sme::ArmSMETileType::ZAB:
106
+ return rewriter.create <arm_sme::aarch64_sme_st1b_vert>(
107
+ loc, maskOp, ptr, tileId, tileSliceI32);
108
+ case arm_sme::ArmSMETileType::ZAH:
109
+ return rewriter.create <arm_sme::aarch64_sme_st1h_vert>(
110
+ loc, maskOp, ptr, tileId, tileSliceI32);
111
+ case arm_sme::ArmSMETileType::ZAS:
112
+ return rewriter.create <arm_sme::aarch64_sme_st1w_vert>(
113
+ loc, maskOp, ptr, tileId, tileSliceI32);
114
+ case arm_sme::ArmSMETileType::ZAD:
115
+ return rewriter.create <arm_sme::aarch64_sme_st1d_vert>(
116
+ loc, maskOp, ptr, tileId, tileSliceI32);
117
+ case arm_sme::ArmSMETileType::ZAQ:
118
+ return rewriter.create <arm_sme::aarch64_sme_st1q_vert>(
119
+ loc, maskOp, ptr, tileId, tileSliceI32);
120
+ }
121
+ }
122
+ }
123
+
35
124
IntegerAttr getTileIdOrError (arm_sme::ArmSMETileOpInterface op) {
36
125
auto tileId = op.getTileId ();
37
126
if (!tileId)
@@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
75
164
ConversionPatternRewriter &rewriter) const override {
76
165
auto loc = zero.getLoc ();
77
166
78
- unsigned tileElementWidth =
79
- zero.getVectorType ().getElementType ().getIntOrFloatBitWidth ();
80
-
81
167
auto tileId = getTileIdOrError (zero);
82
168
if (!tileId)
83
169
return failure ();
@@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
86
172
// The base mask is just the mask to zero the first tile (of a size).
87
173
// These masks are derived from:
88
174
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
175
+ arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType ();
89
176
auto baseMaskForSize = [&] {
90
- switch (tileElementWidth ) {
91
- case 8 :
177
+ switch (tileType ) {
178
+ case arm_sme::ArmSMETileType::ZAB :
92
179
// Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
93
180
// 64-bit element tiles named ZA0.D to ZA7.D.
94
181
return 0b1111'1111 ;
95
- case 16 :
96
- // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
97
- // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
98
- // Shift this left once for ZA1.H.
182
+ case arm_sme::ArmSMETileType::ZAH :
183
+ // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
184
+ // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
185
+ // once for ZA1.H.
99
186
return 0b0101'0101 ;
100
- case 32 :
187
+ case arm_sme::ArmSMETileType::ZAS :
101
188
// Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
102
189
// element tiles named ZA0.D and ZA4.D.
103
190
// Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
104
191
return 0b0001'0001 ;
105
- case 64 :
192
+ case arm_sme::ArmSMETileType::ZAD :
106
193
// Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
107
194
// setting the bit for that tile.
108
195
return 0b0000'0001 ;
@@ -172,63 +259,13 @@ struct LoadTileSliceConversion
172
259
// Create all active predicate mask.
173
260
auto maskOp = loadTileSliceOp.getMask ();
174
261
175
- auto tileType = loadTileSliceOp.getVectorType ();
176
- auto tileElementType = tileType.getElementType ();
177
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth ();
262
+ auto tileVectorType = loadTileSliceOp.getVectorType ();
263
+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType (tileVectorType);
178
264
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout ();
179
265
180
266
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
181
- if (layout == arm_sme::TileSliceLayout::Horizontal) {
182
- switch (tileElementWidth) {
183
- default :
184
- llvm_unreachable (" unexpected element type!" );
185
- case 8 :
186
- rewriter.create <arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
187
- tileId, tileSliceI32);
188
- break ;
189
- case 16 :
190
- rewriter.create <arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
191
- tileId, tileSliceI32);
192
- break ;
193
- case 32 :
194
- rewriter.create <arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
195
- tileId, tileSliceI32);
196
- break ;
197
- case 64 :
198
- rewriter.create <arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
199
- tileId, tileSliceI32);
200
- break ;
201
- case 128 :
202
- rewriter.create <arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
203
- tileId, tileSliceI32);
204
- break ;
205
- }
206
- } else {
207
- switch (tileElementWidth) {
208
- default :
209
- llvm_unreachable (" unexpected element type!" );
210
- case 8 :
211
- rewriter.create <arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
212
- tileId, tileSliceI32);
213
- break ;
214
- case 16 :
215
- rewriter.create <arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
216
- tileId, tileSliceI32);
217
- break ;
218
- case 32 :
219
- rewriter.create <arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
220
- tileId, tileSliceI32);
221
- break ;
222
- case 64 :
223
- rewriter.create <arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
224
- tileId, tileSliceI32);
225
- break ;
226
- case 128 :
227
- rewriter.create <arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
228
- tileId, tileSliceI32);
229
- break ;
230
- }
231
- }
267
+ createLoadTileSliceIntrinsic (rewriter, loc, tileType, layout, maskOp, ptr,
268
+ tileId, tileSliceI32);
232
269
233
270
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
234
271
// the input tile to preserve dataflow.
@@ -249,9 +286,7 @@ struct StoreTileSliceConversion
249
286
arm_sme::StoreTileSliceOp::Adaptor adaptor,
250
287
ConversionPatternRewriter &rewriter) const override {
251
288
auto loc = storeTileSliceOp.getLoc ();
252
- auto tileType = storeTileSliceOp.getVectorType ();
253
- auto tileElementType = tileType.getElementType ();
254
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth ();
289
+ auto tileVectorType = storeTileSliceOp.getVectorType ();
255
290
256
291
auto tileId = getTileIdOrError (storeTileSliceOp);
257
292
if (!tileId)
@@ -271,58 +306,12 @@ struct StoreTileSliceConversion
271
306
auto maskOp = storeTileSliceOp.getMask ();
272
307
273
308
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout ();
309
+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType (tileVectorType);
274
310
275
- if (layout == arm_sme::TileSliceLayout::Horizontal) {
276
- switch (tileElementWidth) {
277
- default :
278
- llvm_unreachable (" unexpected element type!" );
279
- case 8 :
280
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1b_horiz>(
281
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
282
- break ;
283
- case 16 :
284
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1h_horiz>(
285
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
286
- break ;
287
- case 32 :
288
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1w_horiz>(
289
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
290
- break ;
291
- case 64 :
292
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1d_horiz>(
293
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
294
- break ;
295
- case 128 :
296
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1q_horiz>(
297
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
298
- break ;
299
- }
300
- } else {
301
- switch (tileElementWidth) {
302
- default :
303
- llvm_unreachable (" unexpected element type!" );
304
- case 8 :
305
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1b_vert>(
306
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
307
- break ;
308
- case 16 :
309
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1h_vert>(
310
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
311
- break ;
312
- case 32 :
313
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1w_vert>(
314
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
315
- break ;
316
- case 64 :
317
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1d_vert>(
318
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
319
- break ;
320
- case 128 :
321
- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1q_vert>(
322
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
323
- break ;
324
- }
325
- }
311
+ rewriter.replaceOp (storeTileSliceOp,
312
+ createStoreTileSliceIntrinsic (rewriter, loc, tileType,
313
+ layout, maskOp, ptr,
314
+ tileId, tileSliceI32));
326
315
327
316
return success ();
328
317
}
0 commit comments