Skip to content

Commit fed0df9

Browse files
committed
resolve rebase conflicts
1 parent 544e77e commit fed0df9

File tree

2 files changed

+29
-30
lines changed

2 files changed

+29
-30
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ struct BroadcastOpToArmSMELowering
254254
/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
255255
/// arm_sme.tile_store %src, %alloca[%c0, %c0]
256256
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
257-
/// %transposed_src = arm_sme.tile_load <ver>, %alloca[%c0, %c0]
257+
/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
258258
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
259259
///
260260
/// NOTE: Transposing via memory is obviously expensive, the current intention
@@ -277,7 +277,7 @@ struct TransposeOpToArmSMELowering
277277
transp.push_back(cast<IntegerAttr>(attr).getInt());
278278

279279
// Bail unless this is a true 2-D matrix transpose.
280-
if (transp[0] != 1 && transp[1] != 0)
280+
if (transp[0] != 1 || transp[1] != 0)
281281
return failure();
282282

283283
OpBuilder::InsertionGuard g(rewriter);
@@ -302,13 +302,12 @@ struct TransposeOpToArmSMELowering
302302

303303
// Store input tile.
304304
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
305-
loc, input, arm_sme::TileSliceLayout::Horizontal, buffer,
306-
ValueRange{c0, c0});
305+
loc, input, buffer, ValueRange{c0, c0});
307306

308307
// Reload input tile vertically.
309308
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
310-
transposeOp, tileType, arm_sme::TileSliceLayout::Vertical,
311-
tileStoreOp.getBase(), tileStoreOp.getIndices());
309+
transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
310+
arm_sme::TileSliceLayout::Vertical);
312311

313312
return success();
314313
}

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
22

3-
// =============================================================================
3+
//===----------------------------------------------------------------------===//
44
// vector.transfer_write
5-
// =============================================================================
5+
//===----------------------------------------------------------------------===//
66

77
// CHECK-LABEL: func.func @transfer_write_2d_i8(
88
// CHECK-SAME: %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
@@ -169,9 +169,9 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
169169
return
170170
}
171171

172-
// =============================================================================
172+
//===----------------------------------------------------------------------===//
173173
// vector.broadcast
174-
// =============================================================================
174+
//===----------------------------------------------------------------------===//
175175

176176
// -----
177177

@@ -220,9 +220,9 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
220220
return
221221
}
222222

223-
// =============================================================================
223+
//===----------------------------------------------------------------------===//
224224
// vector.transpose
225-
// =============================================================================
225+
//===----------------------------------------------------------------------===//
226226

227227
// -----
228228

@@ -233,8 +233,8 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
233233
// CHECK: %[[VSCALE:.*]] = vector.vscale
234234
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
235235
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
236-
// CHECK: arm_sme.tile_store %[[TILE]], <hor>, %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
237-
// CHECK: arm_sme.tile_load <ver>, %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
236+
// CHECK: arm_sme.tile_store %[[TILE]], %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
237+
// CHECK: arm_sme.tile_load %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
238238
func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
239239
%0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
240240
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -245,8 +245,8 @@ func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
245245

246246
// CHECK-LABEL: @transpose_i16
247247
// CHECK: arith.constant 8
248-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
249-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
248+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
249+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
250250
func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
251251
%0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
252252
"prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
@@ -257,8 +257,8 @@ func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
257257

258258
// CHECK-LABEL: @transpose_i32
259259
// CHECK: arith.constant 4
260-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
261-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
260+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
261+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
262262
func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
263263
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
264264
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
@@ -269,8 +269,8 @@ func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
269269

270270
// CHECK-LABEL: @transpose_i64
271271
// CHECK: arith.constant 2
272-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
273-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
272+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
273+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
274274
func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
275275
%0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
276276
"prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
@@ -282,8 +282,8 @@ func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
282282
// CHECK-LABEL: @transpose_i128
283283
// CHECK: %[[VSCALE:.*]] = vector.vscale
284284
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
285-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
286-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
285+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
286+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
287287
func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
288288
%0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
289289
"prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
@@ -294,8 +294,8 @@ func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
294294

295295
// CHECK-LABEL: @transpose_f16
296296
// CHECK: arith.constant 8
297-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
298-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
297+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
298+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
299299
func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
300300
%0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
301301
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
@@ -306,8 +306,8 @@ func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
306306

307307
// CHECK-LABEL: @transpose_bf16
308308
// CHECK: arith.constant 8
309-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
310-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
309+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
310+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
311311
func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
312312
%0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
313313
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
@@ -318,8 +318,8 @@ func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
318318

319319
// CHECK-LABEL: @transpose_f32
320320
// CHECK: arith.constant 4
321-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
322-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
321+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
322+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
323323
func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
324324
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
325325
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
@@ -330,8 +330,8 @@ func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
330330

331331
// CHECK-LABEL: @transpose_f64
332332
// CHECK: arith.constant 2
333-
// CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
334-
// CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
333+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
334+
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
335335
func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
336336
%0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
337337
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()

0 commit comments

Comments
 (0)