20
20
using namespace mlir ;
21
21
using namespace mlir ::arm_sme;
22
22
23
- static constexpr unsigned kZeroZAMask = 255 ;
24
-
25
23
namespace {
26
24
// / Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
27
25
// / ops to enable the ZA storage array.
@@ -51,21 +49,41 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
51
49
}
52
50
};
53
51
54
- // / Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
55
- // / value. The latter is a nop, which should be folded away (e.g. during
56
- // / canonicalisation).
52
+ // / Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
53
+ // / `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
54
+ // / integer, to an i32 that can be passed as the `tile` parameter to the SME
55
+ // / intrinsics. Or returns `tile` if already i32.
56
+ Value castTileIDToI32 (Value tile, Location loc,
57
+ ConversionPatternRewriter &rewriter) {
58
+ assert ((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
59
+ tile.getDefiningOp ())) &&
60
+ " expected ArmSME GetTileID or CastVectorToTile op!" );
61
+ unsigned tileElementWidth = tile.getType ().getIntOrFloatBitWidth ();
62
+ if (tileElementWidth < 32 )
63
+ return rewriter.create <arith::ExtUIOp>(loc, rewriter.getI32Type (), tile);
64
+ if (tileElementWidth > 32 )
65
+ return rewriter.create <arith::TruncIOp>(loc, rewriter.getI32Type (), tile);
66
+ return tile;
67
+ }
68
+
69
+ // / Lower 'arm_sme.zero' to SME intrinsics.
57
70
// /
58
71
// / BEFORE:
59
72
// / ```mlir
60
- // / %0 = arm_sme.zero : vector<[16 ]x[16]xi8 >
73
+ // / %v = arm_sme.zero : vector<[4 ]x[4]xi32 >
61
74
// / ```
62
75
// /
63
76
// / AFTER:
64
77
// / ```mlir
65
- // / %1 = arm_sme.get_tile_id : i8
66
- // / %2 = arm_sme.cast_tile_to_vector %1 : i8 to vector<[16]x[16]xi8>
67
- // / "arm_sme.intr.zero"(%c255_i32) : (i32) -> ()
78
+ // / %tile_id = arm_sme.get_tile_id : i32
79
+ // / %zero_mask = arith.shli %c17_i32, %tile_id : i32
80
+ // / "arm_sme.intr.zero"(%zero_mask) : (i32) -> ()
81
+ // / %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
68
82
// / ```
83
+ // /
84
+ // / The 'arm_sme.cast_tile_to_vector' (which models the return) and the
85
+ // / 'arith.shli' (which generates the mask) will be folded away after tile
86
+ // / allocation and canonization.
69
87
struct ZeroOpConversion : public ConvertOpToLLVMPattern <ZeroOp> {
70
88
using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
71
89
@@ -75,42 +93,76 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
75
93
auto loc = zero.getLoc ();
76
94
77
95
// Get Tile ID for the `zero` intrinsic.
78
- // TODO: Map this to a valid `mask` for the `zero` intrinsic.
79
96
auto tileId = rewriter.create <arm_sme::GetTileID>(
80
97
loc, zero.getVectorType ().getElementType ());
81
98
82
- // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
83
- // FIXME: Replace the hard-coded mask with a valid value based
84
- // on `tileId`.
85
- auto mask = rewriter.create <arith::ConstantOp>(
86
- loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (kZeroZAMask ));
87
- rewriter.create <arm_sme::aarch64_sme_zero>(loc, mask);
88
-
89
- // Create `CastTileToVectorOp` to use it as the output
99
+ auto tileElementWidth = tileId.getType ().getIntOrFloatBitWidth ();
100
+
101
+ // Get the base mask for tile based on the element size.
102
+ // The base mask is just the mask to zero the first tile (of a size).
103
+ // These masks are derived from:
104
+ // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
105
+ auto baseMaskForSize = [&] {
106
+ switch (tileElementWidth) {
107
+ case 8 :
108
+ // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
109
+ // 64-bit element tiles named ZA0.D to ZA7.D.
110
+ return 0b1111'1111 ;
111
+ case 16 :
112
+ // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
113
+ // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
114
+ // Shift this left once for ZA1.H.
115
+ return 0b0101'0101 ;
116
+ case 32 :
117
+ // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
118
+ // element tiles named ZA0.D and ZA4.D.
119
+ // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
120
+ return 0b0001'0001 ;
121
+ case 64 :
122
+ // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
123
+ // setting the bit for that tile.
124
+ return 0b0000'0001 ;
125
+ default :
126
+ llvm_unreachable (" bad element size" );
127
+ }
128
+ }();
129
+ auto maskType = rewriter.getI32Type ();
130
+ auto baseMask = rewriter.create <arith::ConstantOp>(
131
+ loc, maskType, rewriter.getIntegerAttr (maskType, baseMaskForSize));
132
+
133
+ // The actual mask is just the base mask shifted by the tile ID.
134
+ // This will be folded to a constant after tile allocation.
135
+ //
136
+ // The shift is just derived from the layout of the tiles, and that the tile
137
+ // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
138
+ // tiles:
139
+ //
140
+ // ZA0.S = ZA0.D and ZA4.D
141
+ // * Tile ID -> 0
142
+ // * Mask -> 00010001 = (00010001 << 0)
143
+ // ZA1.S = ZA1.D and ZA5.D
144
+ // * Tile ID -> 1
145
+ // * Mask -> 00100010 = (00010001 << 1)
146
+ // ZA2.S = ZA2.D and ZA6.D
147
+ // * Tile ID -> 2
148
+ // * Mask -> 01000100 = (00010001 << 2)
149
+ // ZA3.S = ZA3.D and ZA7.D
150
+ // * Tile ID -> 3
151
+ // * Mask -> 10001000 = (00010001 << 3)
152
+ //
153
+ // This holds for all tile sizes.
154
+ auto tileMask = rewriter.create <arith::ShLIOp>(
155
+ loc, baseMask, castTileIDToI32 (tileId, loc, rewriter));
156
+ rewriter.create <arm_sme::aarch64_sme_zero>(loc, tileMask);
157
+
158
+ // Create `CastTileToVectorOp` to use as the output.
90
159
rewriter.replaceOpWithNewOp <arm_sme::CastTileToVector>(zero, zero.getType (),
91
160
tileId);
92
161
93
162
return success ();
94
163
}
95
164
};
96
165
97
- // / Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
98
- // / `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
99
- // / integer, to an i32 that can be passed as the `tile` parameter to the SME
100
- // / intrinsics. Or returns `tile` if already i32.
101
- Value castTileIDToI32 (Value tile, Location loc,
102
- ConversionPatternRewriter &rewriter) {
103
- assert ((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
104
- tile.getDefiningOp ())) &&
105
- " expected ArmSME GetTileID or CastVectorToTile op!" );
106
- unsigned tileElementWidth = tile.getType ().getIntOrFloatBitWidth ();
107
- if (tileElementWidth < 32 )
108
- return rewriter.create <arith::ExtUIOp>(loc, rewriter.getI32Type (), tile);
109
- if (tileElementWidth > 32 )
110
- return rewriter.create <arith::TruncIOp>(loc, rewriter.getI32Type (), tile);
111
- return tile;
112
- }
113
-
114
166
// / Lower `arm_sme.load_tile_slice` to SME intrinsics.
115
167
struct LoadTileSliceToArmSMELowering
116
168
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
0 commit comments