Commit a4d87e3
[mlir][ArmSME] Calculate correct tile mask when lowering arm_sme.zero
This patch updates the lowering of arm_sme.zero to intrinsics so that it
calculates the correct mask for the tile to zero. The zero instruction takes
an 8-bit mask which specifies which 64-bit tiles to zero; ZA0.D to ZA7.D
correspond to bits 0 to 7.

Zeroing tiles with element sizes of 8-bit to 32-bit just requires zeroing the
right 64-bit tiles. This is quite easy to calculate: each size has a "base
mask" which can be shifted left by the tile ID to get the mask for that tile:

  base_mask << tile_id

After tile allocation, this will be folded to a constant mask.

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D157902
1 parent ad9eed1 commit a4d87e3
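The base-mask-shifted-by-tile-ID computation is simple enough to model outside the compiler. Below is a minimal standalone C++ sketch (not part of the patch; the helper name `zeroMaskForTile` is hypothetical), with the expected values taken from the constants the new test file in this commit checks for:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical standalone model of the lowering's mask computation: each
// element size has a "base mask" selecting the 64-bit tiles that make up
// the first tile of that size; shifting it left by the tile ID selects
// the 64-bit tiles making up that tile.
uint32_t zeroMaskForTile(unsigned elementWidth, unsigned tileId) {
  uint32_t baseMask = 0;
  switch (elementWidth) {
  case 8:  baseMask = 0b11111111; break; // ZA0.B covers ZA0.D-ZA7.D
  case 16: baseMask = 0b01010101; break; // ZA0.H covers ZA0/2/4/6.D
  case 32: baseMask = 0b00010001; break; // ZA0.S covers ZA0.D and ZA4.D
  case 64: baseMask = 0b00000001; break; // each ZAx.D is a single bit
  }
  return baseMask << tileId;
}

int main() {
  // These match the folded constants in the new test added by this patch.
  assert(zeroMaskForTile(8, 0) == 255);  // ZA0.B
  assert(zeroMaskForTile(16, 1) == 170); // ZA1.H = 0b01010101 << 1
  assert(zeroMaskForTile(32, 3) == 136); // ZA3.S = 0b00010001 << 3
  assert(zeroMaskForTile(64, 7) == 128); // ZA7.D = 1 << 7
  return 0;
}
```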

File tree

3 files changed (+219, −36)

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 87 additions & 35 deletions
@@ -20,8 +20,6 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
-static constexpr unsigned kZeroZAMask = 255;
-
 namespace {
 /// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
 /// ops to enable the ZA storage array.
@@ -51,21 +49,41 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
   }
 };
 
-/// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
-/// value. The latter is a nop, which should be folded away (e.g. during
-/// canonicalisation).
+/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
+/// integer, to an i32 that can be passed as the `tile` parameter to the SME
+/// intrinsics. Or returns `tile` if already i32.
+Value castTileIDToI32(Value tile, Location loc,
+                      ConversionPatternRewriter &rewriter) {
+  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+             tile.getDefiningOp())) &&
+         "expected ArmSME GetTileID or CastVectorToTile op!");
+  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+  if (tileElementWidth < 32)
+    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+  if (tileElementWidth > 32)
+    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+  return tile;
+}
+
+/// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 /// BEFORE:
 /// ```mlir
-///   %0 = arm_sme.zero : vector<[16]x[16]xi8>
+///   %v = arm_sme.zero : vector<[4]x[4]xi32>
 /// ```
 ///
 /// AFTER:
 /// ```mlir
-///   %1 = arm_sme.get_tile_id : i8
-///   %2 = arm_sme.cast_tile_to_vector %1 : i8 to vector<[16]x[16]xi8>
-///   "arm_sme.intr.zero"(%c255_i32) : (i32) -> ()
+///   %tile_id = arm_sme.get_tile_id : i32
+///   %zero_mask = arith.shli %c17_i32, %tile_id : i32
+///   "arm_sme.intr.zero"(%zero_mask) : (i32) -> ()
+///   %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
 /// ```
+///
+/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
+/// 'arith.shli' (which generates the mask) will be folded away after tile
+/// allocation and canonicalisation.
 struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
   using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
 
@@ -75,42 +93,76 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
     auto loc = zero.getLoc();
 
     // Get Tile ID for the `zero` intrinsic.
-    // TODO: Map this to a valid `mask` for the `zero` intrinsic.
     auto tileId = rewriter.create<arm_sme::GetTileID>(
        loc, zero.getVectorType().getElementType());
 
-    // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
-    // FIXME: Replace the hard-coded mask with a valid value based
-    // on `tileId`.
-    auto mask = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
-    rewriter.create<arm_sme::aarch64_sme_zero>(loc, mask);
-
-    // Create `CastTileToVectorOp` to use it as the output
+    auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth();
+
+    // Get the base mask for the tile based on the element size.
+    // The base mask is just the mask to zero the first tile (of a given size).
+    // These masks are derived from:
+    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
+    auto baseMaskForSize = [&] {
+      switch (tileElementWidth) {
+      case 8:
+        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
+        // 64-bit element tiles named ZA0.D to ZA7.D.
+        return 0b1111'1111;
+      case 16:
+        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing the 64-bit
+        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
+        // Shift this left once for ZA1.H.
+        return 0b0101'0101;
+      case 32:
+        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing the 64-bit
+        // element tiles named ZA0.D and ZA4.D.
+        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
+        return 0b0001'0001;
+      case 64:
+        // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
+        // setting the bit for that tile.
+        return 0b0000'0001;
+      default:
+        llvm_unreachable("bad element size");
+      }
+    }();
+    auto maskType = rewriter.getI32Type();
+    auto baseMask = rewriter.create<arith::ConstantOp>(
+        loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));
+
+    // The actual mask is just the base mask shifted by the tile ID.
+    // This will be folded to a constant after tile allocation.
+    //
+    // The shift is derived from the layout of the tiles and the fact that the
+    // tile ID is the index of the tile. For example, looking at the 32-bit
+    // ZAx.S tiles:
+    //
+    //  ZA0.S = ZA0.D and ZA4.D
+    //   * Tile ID -> 0
+    //   * Mask    -> 00010001 = (00010001 << 0)
+    //  ZA1.S = ZA1.D and ZA5.D
+    //   * Tile ID -> 1
+    //   * Mask    -> 00100010 = (00010001 << 1)
+    //  ZA2.S = ZA2.D and ZA6.D
+    //   * Tile ID -> 2
+    //   * Mask    -> 01000100 = (00010001 << 2)
+    //  ZA3.S = ZA3.D and ZA7.D
+    //   * Tile ID -> 3
+    //   * Mask    -> 10001000 = (00010001 << 3)
+    //
+    // This holds for all tile sizes.
+    auto tileMask = rewriter.create<arith::ShLIOp>(
+        loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
+    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
+
+    // Create `CastTileToVectorOp` to use as the output.
     rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
                                                            tileId);
 
     return success();
   }
 };
 
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc,
-                      ConversionPatternRewriter &rewriter) {
-  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
-             tile.getDefiningOp())) &&
-         "expected ArmSME GetTileID or CastVectorToTile op!");
-  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
-  if (tileElementWidth < 32)
-    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
-  if (tileElementWidth > 32)
-    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
-  return tile;
-}
-
 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
 struct LoadTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -allocate-arm-sme-tiles -canonicalize \
+// RUN:   -allow-unregistered-dialect \
+// RUN:   | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: zero_za_b
+func.func @zero_za_b() {
+  // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8
+  // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32
+
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+  %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
+  "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: zero_za_h
+func.func @zero_za_h() {
+  // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16
+  // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16
+
+  // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32
+
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16>
+  %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
+  "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16>
+  %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16>
+  "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: zero_za_s
+func.func @zero_za_s() {
+  // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32
+  // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32
+  // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32
+
+  // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32
+
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32>
+  %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32>
+  %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32>
+  %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32>
+  %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: zero_za_d
+func.func @zero_za_d() {
+  // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64
+  // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64
+
+  // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32
+  // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32
+
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64>
+  %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
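The folded mask constants in the test above all follow the same base_mask << tile_id rule. As a cross-check, here is a small standalone C++ sketch (not part of the patch; names are hypothetical) that prints the full mask table, mirroring the lambda in the lowering:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // One row per element width: the base mask for tile 0 of that width,
  // and how many tiles of that width exist.
  struct Row { unsigned width; uint32_t baseMask; unsigned numTiles; };
  const Row rows[] = {
      {8, 0b11111111, 1},  // ZA0.B
      {16, 0b01010101, 2}, // ZA0.H-ZA1.H
      {32, 0b00010001, 4}, // ZA0.S-ZA3.S
      {64, 0b00000001, 8}, // ZA0.D-ZA7.D
  };
  for (const Row &row : rows)
    for (unsigned tileId = 0; tileId < row.numTiles; ++tileId)
      // e.g. width=32, tile=3 prints mask=136, matching ZERO_MASK_ZA3S.
      std::printf("width=%u tile=%u mask=%u\n", row.width, tileId,
                  (unsigned)(row.baseMask << tileId));
  return 0;
}
```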

mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Lines changed: 3 additions & 1 deletion
@@ -9,8 +9,10 @@
 // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG: %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK-DAG: %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
+// CHECK-DAG: "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
 // CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
