Skip to content

Commit d875a1a

Browse files
committed
Fixups
1 parent d32a38b commit d875a1a

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ static constexpr StringLiteral
4545

4646
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
4747
/// larger vector type. The (`row`, `col`) are the position of the tile in the
48-
/// original vector type. For example for an [8]x[8] tile would have four
49-
/// [4]x[4] sub-tiles.
48+
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
49+
/// sub-tiles, we would have:
5050
///
5151
/// 8 x vscale
5252
/// ┌─────────────┬─────────────┐
@@ -104,6 +104,7 @@ SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
104104

105105
/// Returns true if `mask` is generated by an operation that can be decomposed
106106
/// for SME. Currently, that is just no mask, or vector.create_mask.
107+
/// TODO: Add support for vector.constant_mask once required for SME.
107108
bool isSupportedMaskOp(Value mask) {
108109
return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
109110
}

mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
1313
// RUN: FileCheck %s
1414

15+
/// This is very similar to the SME matmul.mlir test, except that it uses a tile
16+
/// size of [8]x[8]xf32, which is larger than a 32-bit SME virtual tile, which
17+
/// would be [4]x[4]xf32. The [8]x[8] tile will be decomposed into four
18+
/// by the `-arm-sme-vector-legalization` pass, which should then use all 32-bit
19+
/// SME accumulators.
20+
1521
func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
1622
%res = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
1723
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
2929
return
3030
}
3131

32-
func.func @testTransposedReadWithMask() {
32+
func.func @testTransposedReadWithMask(%maskRows: index, %maskCols: index) {
3333
%in = memref.alloca() : memref<4x16xf32>
3434
%out = memref.alloca() : memref<16x4xf32>
3535

@@ -38,9 +38,7 @@ func.func @testTransposedReadWithMask() {
3838

3939
func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
4040

41-
/// A mask so we only read the first 2x15 portion of %in.
42-
%maskRows = arith.constant 2 : index
43-
%maskCols = arith.constant 15 : index
41+
/// A mask so we only read the first maskRows x maskCols portion of %in.
4442
%mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
4543
%pad = arith.constant 0.0 : f32
4644
%c0 = arith.constant 0 : index
@@ -59,35 +57,31 @@ func.func @testTransposedReadWithMask() {
5957
call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
6058

6159
/// Print the result memref.
62-
vector.print str "(Masked 15x2) transposed result:"
60+
vector.print str "Masked transposed result:"
6361
%outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
6462
call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
6563

6664
return
6765
}
6866

69-
func.func @testTransposedWriteWithMask() {
67+
func.func @testTransposedWriteWithMask(%maskRows: index, %maskCols: index) {
7068
%in = memref.alloca() : memref<16x4xf32>
7169
%out = memref.alloca() : memref<4x16xf32>
7270

73-
%fill = arith.constant -1.0 : f32
74-
linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
71+
%c0_f32 = arith.constant 0.0 : f32
72+
linalg.fill ins(%c0_f32 : f32) outs(%out : memref<4x16xf32>)
7573

7674
%inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
7775
%outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
7876

7977
func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
8078

81-
%pad = arith.constant 0.0 : f32
82-
%c0 = arith.constant 0 : index
83-
8479
/// A regular read.
85-
%read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
80+
%c0 = arith.constant 0 : index
81+
%read = vector.transfer_read %inDyn[%c0, %c0], %c0_f32 {in_bounds = [true, true]}
8682
: memref<?x?xf32>, vector<[16]x[4]xf32>
8783

88-
/// A mask so we only write the first 3x8 portion of transpose(%in).
89-
%maskRows = arith.constant 3 : index
90-
%maskCols = arith.constant 8 : index
84+
/// A mask so we only write the first maskRows x maskCols portion of transpose(%in).
9185
%mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
9286

9387
/// Write out the data with a transpose. Here (like the read test) the mask
@@ -101,7 +95,7 @@ func.func @testTransposedWriteWithMask() {
10195
call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
10296

10397
/// Print the result memref.
104-
vector.print str "(Masked 3x8) transposed result:"
98+
vector.print str "Masked transposed result:"
10599
%outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
106100
call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
107101

@@ -120,7 +114,7 @@ func.func @main() {
120114
// CHECK-NEXT: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
121115
// CHECK-NEXT: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
122116
//
123-
// CHECK: (Masked 15x2) transposed result:
117+
// CHECK: Masked transposed result:
124118
// CHECK: [1, 2, 0, 0]
125119
// CHECK-NEXT: [1, 2, 0, 0]
126120
// CHECK-NEXT: [1, 2, 0, 0]
@@ -137,7 +131,9 @@ func.func @main() {
137131
// CHECK-NEXT: [1, 2, 0, 0]
138132
// CHECK-NEXT: [1, 2, 0, 0]
139133
// CHECK-NEXT: [0, 0, 0, 0]
140-
func.call @testTransposedReadWithMask() : () -> ()
134+
%readMaskRows = arith.constant 2 : index
135+
%readMaskCols = arith.constant 15 : index
136+
func.call @testTransposedReadWithMask(%readMaskRows, %readMaskCols) : (index, index) -> ()
141137

142138
// CHECK: Input memref:
143139
// CHECK: [1, 1, 1, 1]
@@ -157,12 +153,14 @@ func.func @main() {
157153
// CHECK-NEXT: [15, 15, 15, 15]
158154
// CHECK-NEXT: [16, 16, 16, 16]
159155
//
160-
// CHECK: (Masked 3x8) transposed result:
161-
// CHECK: [1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1]
162-
// CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1]
163-
// CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1]
164-
// CHECK-NEXT: [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
165-
func.call @testTransposedWriteWithMask() : () -> ()
156+
// CHECK: Masked transposed result:
157+
// CHECK: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
158+
// CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
159+
// CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
160+
// CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
161+
%writeMaskRows = arith.constant 3 : index
162+
%writeMaskCols = arith.constant 8 : index
163+
func.call @testTransposedWriteWithMask(%writeMaskRows, %writeMaskCols) : (index, index) -> ()
166164

167165
return
168166
}

0 commit comments

Comments
 (0)