Skip to content

Commit 08dd0e4

Browse files
committed
Fixups
1 parent 84d6843 commit 08dd0e4

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,11 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind
305305

306306
// -----
307307

308-
// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
309-
// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
310-
// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
311-
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
312-
func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
308+
// CHECK-LABEL: @lift_illegal_transpose_to_memory(
309+
// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
310+
// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
311+
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
312+
func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
313313
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
314314
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
315315
// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
@@ -328,23 +328,17 @@ func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memre
328328

329329
// -----
330330

331-
// CHECK-LABEL: @lift_illegal_transpose_to_memory(
332-
// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
333-
// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
334-
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
335-
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
336-
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
337-
func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
338-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
339-
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
340-
// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
341-
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
342-
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
343-
// CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
344-
// CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
345-
// CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
346-
// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
347-
// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
331+
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
332+
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
333+
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
334+
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
335+
func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
336+
// CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
337+
// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
338+
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
339+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
340+
// CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
341+
// CHECK-SAME: %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
348342
// CHECK-NEXT: return %[[LEGAL_READ]]
349343
%pad = arith.constant 0.0 : f32
350344
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
@@ -356,19 +350,12 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index,
356350
// -----
357351

358352
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
359-
// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
360-
// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
361-
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
353+
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
362354
func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
363-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
364-
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
365-
// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
366-
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
367-
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
368-
// CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
369-
// CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
370-
// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
371-
// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
355+
// CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
356+
// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
357+
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
358+
// CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
372359
// CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
373360
// CHECK-NEXT: return %[[EXT_TYPE]]
374361
%pad = arith.constant 0 : i8
@@ -377,3 +364,16 @@ func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: inde
377364
%legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
378365
return %legalType : vector<4x[8]xi32>
379366
}
367+
368+
// -----
369+
370+
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr
371+
func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
372+
// CHECK: vector.transfer_read
373+
// CHECK-SAME: in_bounds = [true, false]
374+
// CHECK-NOT: in_bounds = [false, true]
375+
%pad = arith.constant 0.0 : f32
376+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32>
377+
%legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
378+
return %legalType : vector<4x[8]xf32>
379+
}

0 commit comments

Comments
 (0)