Skip to content

Commit 9a5cdf0

Browse files
committed
address comments
1 parent ec829d7 commit 9a5cdf0

File tree

3 files changed

+31
-23
lines changed

3 files changed

+31
-23
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
158158
/// %c0 = arith.constant 0 : index
159159
/// %c1 = arith.constant 1 : index
160160
/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
161+
/// %num_rows = arith.constant 2 : index
161162
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
162163
/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
163164
/// %tile_update = arm_sme.load_tile_slice
@@ -252,24 +253,12 @@ struct TileLoadOpWithMaskAndPadZeroConversion
252253
///
253254
/// AFTER:
254255
/// ```mlir
256+
/// ...
255257
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
256-
/// %num_rows = arith.constant 2 : index
257-
/// %num_cols = arith.constant 4 : index
258-
/// %num_cols_i32 = arith.index_castui %num_cols : index to i32
259-
/// %tile_id = arm_sme.get_tile_id : i32
260-
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
261-
/// %vscale = vector.vscale
262-
/// %c0 = arith.constant 0 : index
263-
/// %c1 = arith.constant 1 : index
264-
/// %min_svl_s = arith.constant 4 : index
265-
/// %svl_s = arith.muli %min_svl_s, %vscale : index
266258
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
267-
/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
268-
/// %row_is_active_i32 = arith.extsi %row_is_active : i1 to i32
269-
/// %mask = arith.andi %row_is_active_i32, %num_cols_i32 : i32
270-
/// %mask_index = arith.index_cast %mask : i32 to index
271-
/// %mask_1d = vector.create_mask %mask_index : vector<[4]xi1>
272-
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad
259+
/// ...
260+
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
261+
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
273262
/// : memref<?x?xi32>, vector<[4]xi1>,
274263
/// vector<[4]xi32> into vector<[4]xi32>
275264
/// // Insert slice into tile

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file -verify-diagnostics | FileCheck %s
22

33
//===----------------------------------------------------------------------===//
44
// arm_sme.tile_load
@@ -89,6 +89,25 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
8989
return
9090
}
9191

92+
// -----
93+
94+
func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %mask : vector<[4]x[4]xi1>) {
95+
%c0 = arith.constant 0 : index
96+
%pad = arith.constant 0 : i32
97+
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
98+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
99+
return
100+
}
101+
102+
// -----
103+
104+
func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %pad : i32, %mask : vector<[4]x[4]xi1>) {
105+
%c0 = arith.constant 0 : index
106+
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
107+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
108+
return
109+
}
110+
92111
//===----------------------------------------------------------------------===//
93112
// arm_sme.tile_store
94113
//===----------------------------------------------------------------------===//

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
// RUN: %{compile} | %{run} | FileCheck %s
1313

14-
// Vector load.
14+
// 2-D vector load (SME tile).
1515
func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
1616
%c4 = arith.constant 4 : index
1717
%pad = arith.constant 0.0 : f32
@@ -24,7 +24,7 @@ func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index)
2424
return
2525
}
2626

27-
// Vector load + transpose.
27+
// 2-D vector load (SME tile) + transpose.
2828
func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
2929
%pad = arith.constant 0.0 : f32
3030
%0 = vector.transfer_read %A[%base1, %base2], %pad
@@ -37,7 +37,7 @@ func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %bas
3737
return
3838
}
3939

40-
// Vector load with mask and pad of zero.
40+
// 2-D vector load (SME tile) with mask and pad of zero.
4141
func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
4242
%c2 = arith.constant 2 : index
4343
%c3 = arith.constant 3 : index
@@ -52,7 +52,7 @@ func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: in
5252
return
5353
}
5454

55-
// Vector load with mask and pad of zero + transpose.
55+
// 2-D vector load (SME tile) with mask and pad of zero + transpose.
5656
func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
5757
%c2 = arith.constant 2 : index
5858
%c3 = arith.constant 3 : index
@@ -68,7 +68,7 @@ func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index,
6868
return
6969
}
7070

71-
// Vector load with mask and non-zero pad.
71+
// 2-D vector load (SME tile) with mask and non-zero pad.
7272
func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
7373
%c2 = arith.constant 2 : index
7474
%c3 = arith.constant 3 : index
@@ -83,7 +83,7 @@ func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: inde
8383
return
8484
}
8585

86-
// Vector load with mask and non-zero pad + transpose.
86+
// 2-D vector load (SME tile) with mask and non-zero pad + transpose.
8787
func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
8888
%c2 = arith.constant 2 : index
8989
%c3 = arith.constant 3 : index

0 commit comments

Comments
 (0)